#define POLLVEC_SIZE (DRV_MAX_CHAN + 1) static void if_error_print_and_exit(const std::string_view str) { if (!str.empty()) { std::cerr << std::endl << "error: " << str << std::endl; exit(1); } } class Driver { public: enum ChanState { CHAN_INACTIVE, CHAN_PLAINTEXT, CHAN_SSL_CONNECTING, CHAN_SSL_ACCEPTING, CHAN_SSL_READWRITE, }; struct ChanInfo { int chid; SOCKET socket; SSL *ssl; ChanState state; uint32_t nbytes; const char *bytes; bool ready_now; bool ready_on_pollin; bool ready_on_pollout; bool ready_on_outgoing; uint32_t last_write_nbytes; bool marked_for_deletion() const { return state == CHAN_INACTIVE; } }; EngineWrapper engw; std::vector chans_; std::map listen_sockets_; bool read_console_recently_; std::unique_ptr pollvec_; std::unique_ptr chbuf_; sslutil::UniqueCTX ssl_server_ctx_; sslutil::UniqueCTX ssl_client_secure_ctx_; sslutil::UniqueCTX ssl_client_insecure_ctx_; void handle_listen_ports() { uint32_t nports; const uint32_t *ports; engw.get_listen_ports(&engw, &nports, &ports); for (uint32_t i = 0; i < nports; i++) { int port = ports[i]; if (listen_sockets_.find(port) == listen_sockets_.end()) { std::string err; SOCKET sock = listen_on_port(port, err); if_error_print_and_exit(err); assert(sock != INVALID_SOCKET); listen_sockets_[port] = sock; } } } void handle_lua_source() { if (engw.get_rescan_lua_source(&engw)) { drvutil::ostringstream oss; std::string err = drvutil::package_lua_source(".", &oss); if_error_print_and_exit(err); engw.play_set_lua_source(&engw, oss.size(), oss.c_str()); } } void close_channel(ChanInfo &chan, std::string_view err) { // std::cerr << "Closing channel " << chan.chid << std::endl; assert(chan.state != CHAN_INACTIVE); // Close and release the SSL channel. if (chan.ssl != nullptr) { SSL_free(chan.ssl); chan.ssl = nullptr; } // Close and release the socket. assert(chan.socket != INVALID_SOCKET); assert(socket_close(chan.socket) == 0); chan.socket = INVALID_SOCKET; // Close everything else. engw.play_notify_close(&engw, chan.chid, err.size(), err.data()); chan.state = CHAN_INACTIVE; chan.chid = -1; chan.nbytes = 0; chan.bytes = 0; chan.ready_now = false; chan.ready_on_pollin = false; chan.ready_on_pollout = false; chan.ready_on_outgoing = false; chan.last_write_nbytes = 0; } void handle_console_output() { while (true) { uint32_t ndata; const char *data; engw.get_outgoing(&engw, 0, &ndata, &data); if (ndata == 0) break; if (ndata > DRV_SHORTSTRING_SIZE) ndata = DRV_SHORTSTRING_SIZE; int nwrote = console_write(data, ndata); if (nwrote <= 0) break; engw.play_sent_outgoing(&engw, 0, nwrote); } } void handle_console_input() { char buffer[256]; read_console_recently_ = false; while (true) { int nread = console_read(buffer, 256); if (nread <= 0) break; read_console_recently_ = true; engw.play_recv_incoming(&engw, 0, nread, buffer); } } void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) { ChanInfo newchan; newchan.chid = chid; newchan.socket = sock; newchan.ssl = SSL_new(ctx); newchan.state = state; newchan.nbytes = 0; newchan.bytes = 0; newchan.ready_now = false; newchan.ready_on_pollin = false; newchan.ready_on_pollout = true; newchan.ready_on_outgoing = false; newchan.last_write_nbytes = 0; SSL_set_fd(newchan.ssl, newchan.socket); // SSL_set_msg_callback(newchan.ssl, SSL_trace); // SSL_set_msg_callback_arg(newchan.ssl, BIO_new_fp(stderr,0)); chans_.push_back(newchan); } void handle_new_outgoing_sockets() { uint32_t nchids; const uint32_t *chids; engw.get_new_outgoing(&engw, &nchids, &chids); for (uint32_t i = 0; i < nchids; i++) { uint32_t chid = chids[i]; std::string err, cert, host, port; const char *target = engw.get_target(&engw, chid); drvutil::split_target(target, cert, host, port); if (cert.empty() || host.empty() || port.empty()) { std::string message = "invalid target: "; message += target; engw.play_notify_close(&engw, chid, message.size(), message.c_str()); continue; } SSL_CTX *ctx = nullptr; if (cert == "cert") { ctx = ssl_client_secure_ctx_.get(); } else if (cert == "nocert") { ctx = ssl_client_insecure_ctx_.get(); } else { std::string message = "invalid cert rule: "; message += target; engw.play_notify_close(&engw, chid, message.size(), message.c_str()); continue; } SOCKET sock = open_connection(host.c_str(), port.c_str(), err); if (sock == INVALID_SOCKET) { engw.play_notify_close(&engw, chid, err.size(), err.c_str()); continue; } // std::cerr << "Opening channel " << chid << std::endl; make_channel(sock, chid, ctx, CHAN_SSL_CONNECTING); } engw.play_clear_new_outgoing(&engw); } void accept_connection(int port, SOCKET sock) { std::string err; SOCKET socket = accept_on_socket(sock, err); if_error_print_and_exit(err); if (socket != INVALID_SOCKET) { uint32_t chid = engw.play_notify_accept(&engw, port); // std::cerr << "Accepted channel " << chid << std::endl; make_channel(socket, chid, ssl_server_ctx_.get(), CHAN_SSL_ACCEPTING); } } void advance_plaintext(ChanInfo &chan) { std::string err; // Try to write plaintext to the channel. uint32_t ndata; const char *data; engw.get_outgoing(&engw, chan.chid, &ndata, &data); if (ndata > 0) { int sbytes = ndata; if (sbytes > DRV_SHORTSTRING_SIZE) sbytes = DRV_SHORTSTRING_SIZE; int wbytes = socket_send(chan.socket, data, sbytes, err); if (wbytes < 0) { close_channel(chan, err.c_str()); } else { engw.play_sent_outgoing(&engw, chan.chid, wbytes); } } // Try to read plaintext from the channel. // Someday, find a way to avoid this copy. int nrecv = socket_recv(chan.socket, chbuf_.get(), DRV_SHORTSTRING_SIZE, err); if (nrecv < 0) { close_channel(chan, err.c_str()); } else { engw.play_recv_incoming(&engw, chan.chid, nrecv, chbuf_.get()); } // Update the ready-flags for next time. chan.ready_on_outgoing = true; chan.ready_on_pollin = true; } void process_ssl_error(ChanInfo &chan, int retval) { int error = SSL_get_error(chan.ssl, retval); // std::cerr << "SSL error code = " << error << " "; if (error == SSL_ERROR_WANT_READ) { chan.ready_on_pollin = true; } else if (error == SSL_ERROR_WANT_WRITE) { chan.ready_on_pollout = true; } else { std::string error = sslutil::error_string(); if (error == "") error = "unknown error"; close_channel(chan, error); } } void advance_ssl_connecting(ChanInfo &chan) { // std::cerr << "In advance_ssl_connecting" << std::endl; int retval = SSL_connect(chan.ssl); if (retval == 1) { // Connection successful. chan.state = CHAN_SSL_READWRITE; chan.ready_now = true; } else { // std::cerr << "ssl_connect_error"; process_ssl_error(chan, retval); } } void advance_ssl_accepting(ChanInfo &chan) { // std::cerr << "In advance_ssl_accepting" << std::endl; int retval = SSL_accept(chan.ssl); if (retval == 1) { // Connection successful. chan.state = CHAN_SSL_READWRITE; chan.ready_now = true; } else { process_ssl_error(chan, retval); } } void advance_ssl_readwrite(ChanInfo &chan) { // std::cerr << "In advance_ssl_readwrite" << std::endl; // Try to read data. int read_result = SSL_read(chan.ssl, chbuf_.get(), DRV_SHORTSTRING_SIZE); if (read_result > 0) { engw.play_recv_incoming(&engw, chan.chid, read_result, chbuf_.get()); chan.ready_now = true; } else { process_ssl_error(chan, read_result); if (chan.state == CHAN_INACTIVE) return; } // Try to write data. uint32_t wbytes; if (chan.last_write_nbytes > 0) { wbytes = chan.last_write_nbytes; assert(wbytes < chan.nbytes); } else { wbytes = chan.nbytes; if (wbytes > 65536) wbytes = 65536; } if (wbytes > 0) { int write_result = SSL_write(chan.ssl, chan.bytes, wbytes); if (write_result > 0) { engw.play_sent_outgoing(&engw, chan.chid, write_result); chan.last_write_nbytes = 0; chan.ready_on_outgoing = true; } else { chan.last_write_nbytes = wbytes; process_ssl_error(chan, write_result); if (chan.state == CHAN_INACTIVE) return; } } else { chan.ready_on_outgoing = true; } // std::cerr << "rpi=" << chan.ready_on_pollin << ".rpo=" << // chan.ready_on_pollout << ".rn=" << chan.ready_now << ".rog=" << // chan.ready_on_outgoing << " "; } void advance_channel(ChanInfo &chan) { sslutil::clear_all_errors(); switch (chan.state) { case CHAN_PLAINTEXT: advance_plaintext(chan); break; case CHAN_SSL_CONNECTING: advance_ssl_connecting(chan); break; case CHAN_SSL_ACCEPTING: advance_ssl_accepting(chan); break; case CHAN_SSL_READWRITE: advance_ssl_readwrite(chan); break; default: assert(false); break; } } void handle_socket_input_output() { std::string err; int mstimeout = read_console_recently_ ? 100 : 1000; // Peek output buffers and determine channel release flags. bool any_released = false; for (ChanInfo &chan : chans_) { engw.get_outgoing(&engw, chan.chid, &chan.nbytes, &chan.bytes); if (chan.nbytes == 0) { if (engw.get_channel_released(&engw, chan.chid)) { close_channel(chan, ""); any_released = true; } } } // Delete any released channels if (any_released) { drvutil::remove_marked_items(chans_); } // Construct the struct pollfd vector. int pollsize = 0; for (const auto &p : listen_sockets_) { struct pollfd &pfd = pollvec_[pollsize++]; pfd.fd = p.second; pfd.events = POLLIN; pfd.revents = 0; } for (const ChanInfo &chan : chans_) { struct pollfd &pfd = pollvec_[pollsize++]; assert(chan.socket != INVALID_SOCKET); pfd.fd = chan.socket; pfd.events = 0; pfd.revents = 0; if (chan.ready_now) mstimeout = 0; if (chan.ready_on_pollin) pfd.events |= POLLIN; if (chan.ready_on_pollout) pfd.events |= POLLOUT; if (chan.ready_on_outgoing && (chan.nbytes > 0)) pfd.events |= POLLOUT; // std::cerr << "evt=" << pfd.events << ".nb=" << chan.nbytes << // std::endl; } // Do the poll. socket_poll(pollvec_.get(), pollsize, mstimeout, err); if_error_print_and_exit(err); // Check listening sockets. int index = 0; for (auto &p : listen_sockets_) { struct pollfd &pfd = pollvec_[index++]; if (pfd.revents & (POLLIN | POLLERR)) { accept_connection(p.first, p.second); } } // Advance channels where possible. for (ChanInfo &chan : chans_) { struct pollfd &pfd = pollvec_[index++]; bool pollin = ((pfd.revents & POLLIN) != 0); bool pollout = ((pfd.revents & POLLOUT) != 0); bool pollerr = ((pfd.revents & (POLLERR | POLLHUP)) != 0); if (chan.ready_now || pollerr || (chan.ready_on_pollin && pollin) || (chan.ready_on_pollout && pollout) || (chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) { chan.ready_now = false; chan.ready_on_pollin = false; chan.ready_on_pollout = false; chan.ready_on_outgoing = false; advance_channel(chan); } chan.nbytes = 0; chan.bytes = 0; } // Delete any newly-inactive channels drvutil::remove_marked_items(chans_); } int replay_logfile(const char *fn, bool verbose) { engw.replay_initialize(&engw, fn); if_error_print_and_exit(engw.error); while (engw.rlog) { engw.replay_step(&engw); } if_error_print_and_exit(engw.error); return 0; } static void replay_cb_sent_outgoing(void *vp, int chid, int ndata, const char *data) { if (chid == 0) { std::cerr.write(data, ndata); } } int drive(int argc, char *argv[]) { // Remove the program name from argv. std::string program = argv[0]; argc -= 1; argv += 1; // Load the DLL and gain access to its functions. call_init_engine_wrapper(&engw); engw.replay_cb_sent_outgoing = replay_cb_sent_outgoing; // If argv contains "replay ", do a replay, // and then skip everything else. if (argc >= 1) { std::string cmd(argv[0]); if ((cmd == "replay") || (cmd == "vreplay")) { if (argc != 2) { std::cerr << "usage: " << program << " replay " << std::endl; return 1; } return replay_logfile(argv[1], cmd == "vreplay"); } } // If argv contains "record ", start recording, // and remove the "record " from argv. std::string replaylogfn; if (argc >= 1) { std::string cmd = argv[0]; if (cmd == "record") { if (argc < 2) { std::cerr << "The 'record' command must be followed by a filename" << std::endl; return 1; } replaylogfn = argv[1]; argc -= 2; argv += 2; } } // Initialize state variables. read_console_recently_ = false; chbuf_.reset(new char[DRV_SHORTSTRING_SIZE]); pollvec_.reset(new struct pollfd[POLLVEC_SIZE]); ssl_server_ctx_.reset(sslutil::new_context(SSL_VERIFY_NONE)); ssl_client_secure_ctx_.reset(sslutil::new_context(SSL_VERIFY_PEER)); ssl_client_insecure_ctx_.reset(sslutil::new_context(SSL_VERIFY_NONE)); ssl_load_certificate_authorities(ssl_client_secure_ctx_.get()); sslutil::ctx_load_dummy_cert(ssl_server_ctx_.get()); // Read the initial lua source code. drvutil::ostringstream srcpak; std::string srcpakerr = drvutil::package_lua_source(".", &srcpak); if_error_print_and_exit(srcpakerr); // Initialize the engine. engw.play_initialize(&engw, argc, argv, srcpak.size(), srcpak.c_str(), replaylogfn.c_str()); if_error_print_and_exit(engw.error); // Set up listening ports. handle_listen_ports(); // Main loop. while (!engw.get_stop_driver(&engw)) { handle_lua_source(); handle_console_output(); handle_new_outgoing_sockets(); handle_socket_input_output(); handle_console_input(); handle_console_output(); engw.play_invoke_event_update(&engw, drvutil::get_monotonic_clock()); } // Cleanup engw.release(&engw); for (ChanInfo &chan : chans_) { close_channel(chan, ""); } return 0; } };