From 1e45aa425be2697d04bf40680ffd8ae82ea583ca Mon Sep 17 00:00:00 2001 From: jyelon Date: Fri, 18 Mar 2022 18:16:21 -0400 Subject: [PATCH] fix accidental reformatting of driver-common.cpp --- luprex/core/cpp/driver-common.cpp | 441 +++++++++++------------------- 1 file changed, 161 insertions(+), 280 deletions(-) diff --git a/luprex/core/cpp/driver-common.cpp b/luprex/core/cpp/driver-common.cpp index 5c888f41..ed45e865 100644 --- a/luprex/core/cpp/driver-common.cpp +++ b/luprex/core/cpp/driver-common.cpp @@ -4,42 +4,33 @@ static MonoClock monoclock; -namespace util -{ - double profiling_clock() - { - return monoclock.get(); - } -} +namespace util { -static void if_error_print_and_exit(const std::string &str) -{ - if (!str.empty()) - { - std::cerr << std::endl - << "error: " << str << std::endl; +double profiling_clock() { return monoclock.get(); } + +} // namespace util + +static void if_error_print_and_exit(const std::string &str) { + if (!str.empty()) { + std::cerr << std::endl << "error: " << str << std::endl; exit(1); } } -static std::string_view read_file(const char *fn, char *buf, int bufsize, std::string &err) -{ +static std::string_view read_file(const char *fn, char *buf, int bufsize, std::string &err) { FILE *f = fopen(fn, "r"); - if (f == 0) - { + if (f == 0) { err = std::string("cannot read file") + fn; buf[0] = 0; return std::string_view(buf, 0); } int nread = fread(buf, 1, bufsize, f); - if (nread < 0) - { + if (nread < 0) { err = std::string("cannot read file: ") + fn; buf[0] = 0; return std::string_view(buf, 0); } - if (nread == bufsize) - { + if (nread == bufsize) { err = std::string("file too large: ") + fn; buf[0] = 0; return std::string_view(buf, 0); @@ -48,66 +39,50 @@ static std::string_view read_file(const char *fn, char *buf, int bufsize, std::s return std::string_view(buf, nread); } -struct SSL_CTX_Deleter -{ - void operator()(SSL_CTX *ctx) - { - SSL_CTX_free(ctx); - } +struct SSL_CTX_Deleter { + void operator()(SSL_CTX *ctx) { SSL_CTX_free(ctx); } }; using UniqueSSLCTX = std::unique_ptr; -static std::string ssl_errors_string(bool lastonly = true) -{ +static std::string ssl_errors_string(bool lastonly = true) { std::string err; const char *file, *data, *func; int line, flags; - while (true) - { - unsigned long code = ERR_get_error_all(&file, &line, &func, &data, &flags); - if (code == 0) - break; + while (true) { + unsigned long code = + ERR_get_error_all(&file, &line, &func, &data, &flags); + if (code == 0) break; std::string reason; - if (ERR_SYSTEM_ERROR(code)) - { + if (ERR_SYSTEM_ERROR(code)) { reason = strerror_str(ERR_GET_REASON(code)); - } - else - { + } else { const char *rc = ERR_reason_error_string(code); reason = (rc == nullptr) ? "unknown" : rc; } - if (err.empty() || lastonly) - { + if (err.empty() || lastonly) { err = reason; - } - else - { + } else { err = err + ", " + reason; } - if (data != nullptr) - { + if (data != nullptr) { err = err + " " + data; } } return err; } -void assert_ssl_errors_empty() -{ +void assert_ssl_errors_empty() { int code = ERR_peek_error(); - if (code != 0) - { + if (code != 0) { std::cerr << "SSL should not have errors at this point." << std::endl; ERR_print_errors_fp(stderr); exit(1); } } -static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str) -{ +static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str) { BIO *bio = BIO_new(BIO_s_mem()); BIO_puts(bio, str); X509 *certificate = PEM_read_bio_X509(bio, NULL, NULL, NULL); @@ -117,8 +92,7 @@ static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str) return status; } -static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str) -{ +static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str) { BIO *bio = BIO_new(BIO_s_mem()); BIO_puts(bio, str); EVP_PKEY *pkey = PEM_read_bio_PrivateKey(bio, NULL, NULL, NULL); @@ -128,33 +102,27 @@ static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str) return status; } -static void ssl_ctx_use_dummycert(SSL_CTX *ctx) -{ - if (ssl_ctx_use_certificate_str(ctx, dummycert::certificate) <= 0) - { +static void ssl_ctx_use_dummycert(SSL_CTX *ctx) { + if (ssl_ctx_use_certificate_str(ctx, dummycert::certificate) <= 0) { ERR_print_errors_fp(stderr); exit(1); } - if (ssl_ctx_use_privatekey_str(ctx, dummycert::privatekey) <= 0) - { + if (ssl_ctx_use_privatekey_str(ctx, dummycert::privatekey) <= 0) { ERR_print_errors_fp(stderr); exit(1); } } -class Driver -{ -public: - enum ChanState - { +class Driver { + public: + enum ChanState { CHAN_INACTIVE, CHAN_PLAINTEXT, CHAN_SSL_CONNECTING, CHAN_SSL_ACCEPTING, CHAN_SSL_READWRITE, }; - struct ChanInfo - { + struct ChanInfo { int chid; SOCKET socket; SSL *ssl; @@ -182,13 +150,10 @@ public: UniqueSSLCTX ssl_client_secure_ctx_; UniqueSSLCTX ssl_client_insecure_ctx_; - void handle_listen_ports() - { + void handle_listen_ports() { const auto &listenports = recorder_.drv_get_listen_ports(); - for (int port : listenports) - { - if (listen_sockets_.find(port) == listen_sockets_.end()) - { + for (int port : listenports) { + if (listen_sockets_.find(port) == listen_sockets_.end()) { std::string err; SOCKET sock = listen_on_port(port, err); if_error_print_and_exit(err); @@ -198,32 +163,29 @@ public: } } - void handle_lua_source() - { - if (recorder_.drv_get_rescan_lua_source()) - { + void handle_lua_source() { + if (recorder_.drv_get_rescan_lua_source()) { std::string err; - std::string_view ctrl = read_file("lua/control.lst", chbuf_.get(), CHBUF_SIZE, err); + std::string_view ctrl = + read_file("lua/control.lst", chbuf_.get(), CHBUF_SIZE, err); if_error_print_and_exit(err); std::vector names = drv::parse_control_lst(ctrl); recorder_.drv_clear_lua_source(); - for (const std::string &str : names) - { + for (const std::string &str : names) { std::string lfn = std::string("lua/") + str; - std::string_view data = read_file(lfn.c_str(), chbuf_.get(), CHBUF_SIZE, err); + std::string_view data = + read_file(lfn.c_str(), chbuf_.get(), CHBUF_SIZE, err); if_error_print_and_exit(err); recorder_.drv_add_lua_source(str, data); } } } - void close_channel(ChanInfo &chan, std::string_view err) - { + void close_channel(ChanInfo &chan, std::string_view err) { // std::cerr << "Closing channel " << chan.chid << std::endl; assert(chan.state != CHAN_INACTIVE); // Close and release the SSL channel. - if (chan.ssl != nullptr) - { + if (chan.ssl != nullptr) { SSL_free(chan.ssl); chan.ssl = nullptr; } @@ -246,52 +208,39 @@ public: chan.last_write_nbytes = 0; } - void cleanup_channels() - { - for (int i = 0; i < int(chans_.size());) - { - if (chans_[i].state == CHAN_INACTIVE) - { + void cleanup_channels() { + for (int i = 0; i < int(chans_.size());) { + if (chans_[i].state == CHAN_INACTIVE) { chans_[i] = chans_.back(); chans_.pop_back(); - } - else - { + } else { i += 1; } } } - void handle_console_output() - { - while (true) - { + void handle_console_output() { + while (true) { std::string_view s = recorder_.drv_peek_outgoing(0); - if (s.size() == 0) - break; + if (s.size() == 0) break; int nwrote = console_write(s.data(), s.size()); - if (nwrote <= 0) - break; + if (nwrote <= 0) break; recorder_.drv_sent_outgoing(0, nwrote); } } - void handle_console_input() - { + void handle_console_input() { char buffer[256]; read_console_recently_ = false; - while (true) - { + while (true) { int nread = console_read(buffer, 256); - if (nread <= 0) - break; + if (nread <= 0) break; read_console_recently_ = true; recorder_.drv_recv_incoming(0, std::string_view(buffer, nread)); } } - void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) - { + void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) { ChanInfo newchan; newchan.chid = chid; newchan.socket = sock; @@ -312,16 +261,15 @@ public: chans_.push_back(newchan); } - void handle_new_outgoing_sockets() - { + void handle_new_outgoing_sockets() { const auto &chans = recorder_.drv_get_new_outgoing(); - for (int chid : chans) - { + for (int chid : chans) { std::string err, cert, host, port; std::string target(recorder_.drv_get_target(chid)); drv::split_target(target, cert, host, port); if (cert.empty() || host.empty() || port.empty()) { - recorder_.drv_notify_close(chid, std::string("invalid target: ") + target); + recorder_.drv_notify_close( + chid, std::string("invalid target: ") + target); continue; } SSL_CTX *ctx = nullptr; @@ -330,63 +278,54 @@ public: } else if (cert == "nocert") { ctx = ssl_client_insecure_ctx_.get(); } else { - recorder_.drv_notify_close(chid, std::string("invalid cert rule: ") + target); + recorder_.drv_notify_close( + chid, std::string("invalid cert rule: ") + target); continue; } SOCKET sock = open_connection(host.c_str(), port.c_str(), err); - if (sock == INVALID_SOCKET) - { + if (sock == INVALID_SOCKET) { recorder_.drv_notify_close(chid, err); continue; } // std::cerr << "Opening channel " << chid << std::endl; make_channel(sock, chid, ctx, CHAN_SSL_CONNECTING); } - if (!chans.empty()) - { + if (!chans.empty()) { recorder_.drv_clear_new_outgoing(); } } - void accept_connection(int port, SOCKET sock) - { + void accept_connection(int port, SOCKET sock) { std::string err; SOCKET socket = accept_on_socket(sock, err); if_error_print_and_exit(err); - if (socket != INVALID_SOCKET) - { + if (socket != INVALID_SOCKET) { int chid = recorder_.drv_notify_accept(port); // std::cerr << "Accepted channel " << chid << std::endl; - make_channel(socket, chid, ssl_server_ctx_.get(), CHAN_SSL_ACCEPTING); + make_channel(socket, chid, ssl_server_ctx_.get(), + CHAN_SSL_ACCEPTING); } } - void advance_plaintext(ChanInfo &chan) - { + void advance_plaintext(ChanInfo &chan) { std::string err; // If the channel has no outgoing bytes and has been released, // just close it. - if (chan.released) - { + if (chan.released) { close_channel(chan, ""); return; } // Try to write plaintext to the channel. std::string_view s = recorder_.drv_peek_outgoing(chan.chid); - if (s.size() > 0) - { + if (s.size() > 0) { int sbytes = s.size(); - if (sbytes > 65536) - sbytes = 65536; + if (sbytes > 65536) sbytes = 65536; int wbytes = socket_send(chan.socket, s.data(), sbytes, err); - if (wbytes < 0) - { + if (wbytes < 0) { close_channel(chan, err); - } - else - { + } else { recorder_.drv_sent_outgoing(chan.chid, wbytes); } } @@ -394,13 +333,11 @@ public: // Try to read plaintext from the channel. // Someday, find a way to avoid this copy. int nrecv = socket_recv(chan.socket, chbuf_.get(), 65536, err); - if (nrecv < 0) - { + if (nrecv < 0) { close_channel(chan, err); - } - else - { - recorder_.drv_recv_incoming(chan.chid, std::string_view(chbuf_.get(), nrecv)); + } else { + recorder_.drv_recv_incoming(chan.chid, + std::string_view(chbuf_.get(), nrecv)); } // Update the ready-flags for next time. @@ -408,149 +345,117 @@ public: chan.ready_on_pollin = true; } - void process_ssl_error(ChanInfo &chan, int retval) - { + void process_ssl_error(ChanInfo &chan, int retval) { int error = SSL_get_error(chan.ssl, retval); // std::cerr << "SSL error code = " << error << " "; - if (error == SSL_ERROR_WANT_READ) - { + if (error == SSL_ERROR_WANT_READ) { chan.ready_on_pollin = true; - } - else if (error == SSL_ERROR_WANT_WRITE) - { + } else if (error == SSL_ERROR_WANT_WRITE) { chan.ready_on_pollout = true; - } - else - { + } else { close_channel(chan, ssl_errors_string()); } } - void advance_ssl_connecting(ChanInfo &chan) - { + void advance_ssl_connecting(ChanInfo &chan) { // std::cerr << "In advance_ssl_connecting" << std::endl; int retval = SSL_connect(chan.ssl); - if (retval == 1) - { + if (retval == 1) { // Connection successful. chan.state = CHAN_SSL_READWRITE; chan.ready_now = true; - } - else - { + } else { // std::cerr << "ssl_connect_error"; process_ssl_error(chan, retval); } } - void advance_ssl_accepting(ChanInfo &chan) - { + void advance_ssl_accepting(ChanInfo &chan) { // std::cerr << "In advance_ssl_accepting" << std::endl; int retval = SSL_accept(chan.ssl); - if (retval == 1) - { + if (retval == 1) { // Connection successful. chan.state = CHAN_SSL_READWRITE; chan.ready_now = true; - } - else - { + } else { process_ssl_error(chan, retval); } } - void advance_ssl_readwrite(ChanInfo &chan) - { + void advance_ssl_readwrite(ChanInfo &chan) { // std::cerr << "In advance_ssl_readwrite" << std::endl; // Try to read data. int read_result = SSL_read(chan.ssl, chbuf_.get(), 65536); - if (read_result > 0) - { - recorder_.drv_recv_incoming(chan.chid, std::string_view(chbuf_.get(), read_result)); + if (read_result > 0) { + recorder_.drv_recv_incoming( + chan.chid, std::string_view(chbuf_.get(), read_result)); chan.ready_now = true; - } - else - { + } else { process_ssl_error(chan, read_result); - if (chan.state == CHAN_INACTIVE) - return; + if (chan.state == CHAN_INACTIVE) return; } // Try to write data. int wbytes; - if (chan.last_write_nbytes > 0) - { + if (chan.last_write_nbytes > 0) { wbytes = chan.last_write_nbytes; assert(wbytes < chan.nbytes); - } - else - { + } else { wbytes = chan.nbytes; - if (wbytes > 65536) - wbytes = 65536; + if (wbytes > 65536) wbytes = 65536; } - if (wbytes > 0) - { + if (wbytes > 0) { int write_result = SSL_write(chan.ssl, chan.bytes, wbytes); - if (write_result > 0) - { + if (write_result > 0) { recorder_.drv_sent_outgoing(chan.chid, write_result); chan.last_write_nbytes = 0; chan.ready_on_outgoing = true; - } - else - { + } else { chan.last_write_nbytes = wbytes; process_ssl_error(chan, write_result); - if (chan.state == CHAN_INACTIVE) - return; + if (chan.state == CHAN_INACTIVE) return; } - } - else - { + } else { chan.ready_on_outgoing = true; } - // std::cerr << "rpi=" << chan.ready_on_pollin << ".rpo=" << chan.ready_on_pollout << ".rn=" << chan.ready_now << ".rog=" << chan.ready_on_outgoing << " "; + // std::cerr << "rpi=" << chan.ready_on_pollin << ".rpo=" << + // chan.ready_on_pollout << ".rn=" << chan.ready_now << ".rog=" << + // chan.ready_on_outgoing << " "; } - void advance_channel(ChanInfo &chan) - { + void advance_channel(ChanInfo &chan) { assert_ssl_errors_empty(); - switch (chan.state) - { - case CHAN_PLAINTEXT: - advance_plaintext(chan); - break; - case CHAN_SSL_CONNECTING: - advance_ssl_connecting(chan); - break; - case CHAN_SSL_ACCEPTING: - advance_ssl_accepting(chan); - break; - case CHAN_SSL_READWRITE: - advance_ssl_readwrite(chan); - break; - default: - assert(false); - break; + switch (chan.state) { + case CHAN_PLAINTEXT: + advance_plaintext(chan); + break; + case CHAN_SSL_CONNECTING: + advance_ssl_connecting(chan); + break; + case CHAN_SSL_ACCEPTING: + advance_ssl_accepting(chan); + break; + case CHAN_SSL_READWRITE: + advance_ssl_readwrite(chan); + break; + default: + assert(false); + break; } assert_ssl_errors_empty(); } - void handle_socket_input_output() - { + void handle_socket_input_output() { std::string err; int mstimeout = read_console_recently_ ? 100 : 1000; // Peek output buffers and determine channel release flags. - for (ChanInfo &chan : chans_) - { + for (ChanInfo &chan : chans_) { std::string_view s = recorder_.drv_peek_outgoing(chan.chid); chan.nbytes = s.size(); chan.bytes = s.data(); chan.just_released = false; - if ((chan.nbytes == 0) && (!chan.released)) - { + if ((chan.nbytes == 0) && (!chan.released)) { chan.released = recorder_.drv_get_channel_released(chan.chid); chan.just_released = chan.released; } @@ -558,31 +463,26 @@ public: // Construct the struct pollfd vector. int pollsize = 0; - for (const auto &p : listen_sockets_) - { + for (const auto &p : listen_sockets_) { struct pollfd &pfd = pollvec_[pollsize++]; pfd.fd = p.second; pfd.events = POLLIN; pfd.revents = 0; } - for (const ChanInfo &chan : chans_) - { + for (const ChanInfo &chan : chans_) { struct pollfd &pfd = pollvec_[pollsize++]; assert(chan.socket != INVALID_SOCKET); pfd.fd = chan.socket; pfd.events = 0; pfd.revents = 0; - if (chan.ready_now) - mstimeout = 0; - if (chan.just_released) - mstimeout = 0; - if (chan.ready_on_pollin) - pfd.events |= POLLIN; - if (chan.ready_on_pollout) - pfd.events |= POLLOUT; + if (chan.ready_now) mstimeout = 0; + if (chan.just_released) mstimeout = 0; + if (chan.ready_on_pollin) pfd.events |= POLLIN; + if (chan.ready_on_pollout) pfd.events |= POLLOUT; if (chan.ready_on_outgoing && (chan.nbytes > 0)) pfd.events |= POLLOUT; - // std::cerr << "evt=" << pfd.events << ".nb=" << chan.nbytes << " "; + // std::cerr << "evt=" << pfd.events << ".nb=" << chan.nbytes << " + // "; } // Do the poll. @@ -591,18 +491,15 @@ public: // Check listening sockets. int index = 0; - for (auto &p : listen_sockets_) - { + for (auto &p : listen_sockets_) { struct pollfd &pfd = pollvec_[index++]; - if (pfd.revents & (POLLIN | POLLERR)) - { + if (pfd.revents & (POLLIN | POLLERR)) { accept_connection(p.first, p.second); } } // Advance channels where possible. - for (ChanInfo &chan : chans_) - { + for (ChanInfo &chan : chans_) { struct pollfd &pfd = pollvec_[index++]; bool pollin = ((pfd.revents & POLLIN) != 0); bool pollout = ((pfd.revents & POLLOUT) != 0); @@ -610,8 +507,7 @@ public: if (chan.ready_now || pollerr || chan.just_released || (chan.ready_on_pollin && pollin) || (chan.ready_on_pollout && pollout) || - (chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) - { + (chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) { chan.ready_now = false; chan.ready_on_pollin = false; chan.ready_on_pollout = false; @@ -626,30 +522,24 @@ public: cleanup_channels(); } - int replay_logfile(const char *fn, bool verbose) - { + int replay_logfile(const char *fn, bool verbose) { drv::ReplayPlayer player; player.open_logfile(fn); - if (verbose) - { + if (verbose) { player.enable_stdout(); } - while (true) - { + while (true) { drv::ReplayPlayer::Status st = player.step(); - if (st != drv::ReplayPlayer::ST_REPLAYING) - { + if (st != drv::ReplayPlayer::ST_REPLAYING) { player.print_status(std::cerr); return (st == drv::ReplayPlayer::ST_CLEAN_EXIT) ? 0 : 1; } } } - int drive(int argc, char *argv[]) - { + int drive(int argc, char *argv[]) { // Remove the program name from argv. - if (argc < 1) - { + if (argc < 1) { DrivenEngine::print_usage(std::cerr, ""); exit(1); } @@ -659,14 +549,12 @@ public: // If argv contains "replay ", do a replay, // and then skip everything else. - if (argc >= 1) - { + if (argc >= 1) { std::string cmd(argv[0]); - if ((cmd == "replay") || (cmd == "vreplay")) - { - if (argc != 2) - { - std::cerr << "usage: " << program << " replay " << std::endl; + if ((cmd == "replay") || (cmd == "vreplay")) { + if (argc != 2) { + std::cerr << "usage: " << program << " replay " + << std::endl; return 1; } return replay_logfile(argv[1], cmd == "vreplay"); @@ -675,20 +563,17 @@ public: // If argv contains "record ", start recording, // and remove the "record " from argv. - if (argc >= 1) - { + if (argc >= 1) { std::string cmd = argv[0]; - if (cmd == "record") - { - if (argc < 2) - { + if (cmd == "record") { + if (argc < 2) { DrivenEngine::print_usage(std::cerr, program); return 1; } bool ok = recorder_.open_logfile(argv[1]); - if (!ok) - { - std::cerr << "Could not open logfile: " << argv[1] << std::endl; + if (!ok) { + std::cerr << "Could not open logfile: " << argv[1] + << std::endl; return 1; } argc -= 2; @@ -697,14 +582,12 @@ public: } // Create the engine. - if (argc < 1) - { + if (argc < 1) { DrivenEngine::print_usage(std::cerr, program); return 1; } bool engine_made = recorder_.create_engine(argv[0]); - if (!engine_made) - { + if (!engine_made) { DrivenEngine::print_usage(std::cerr, program); return 1; } @@ -723,8 +606,7 @@ public: recorder_.drv_invoke_event_init(argc, argv); handle_listen_ports(); - while (!recorder_.drv_get_stop_driver()) - { + while (!recorder_.drv_get_stop_driver()) { handle_lua_source(); handle_console_output(); handle_new_outgoing_sockets(); @@ -734,8 +616,7 @@ public: recorder_.drv_invoke_event_update(monoclock.get()); } - for (ChanInfo &chan : chans_) - { + for (ChanInfo &chan : chans_) { close_channel(chan, ""); }