#if defined(__linux__) #include "driver-linux.cpp" #elif defined(_WIN32) #include "driver-windows.cpp" #endif #define POLLVEC_SIZE (DRV_MAX_CHAN + 1) #define MAX_BIO_BUFFER (128 * 1024) static void if_error_print_and_exit(const std::string_view str) { if (!str.empty()) { std::cerr << std::endl << "error: " << str << std::endl; exit(1); } } // DPrints are currently not going through the readline device. // doing so would not currently be thread-safe. Do I care about // that? I'm not sure. static void dprint_callback(const char *oneline, size_t size) { fwrite("**", 1, 2, stderr); fwrite(oneline, 1, size, stderr); fwrite("\n", 1, 1, stderr); fflush(stderr); } inline bool file_exists(const std::filesystem::path &name) { std::ifstream f(name); return f.good(); } std::filesystem::path find_luprex_root(std::filesystem::path exepath) { std::filesystem::path pp = exepath.parent_path(); if (file_exists(pp / "lua/control.lst")) { return pp; } pp = pp.parent_path(); if (file_exists(pp / "lua/control.lst")) { return pp; } pp = pp.parent_path(); if (file_exists(pp / "lua/control.lst")) { return pp; } assert(false && "Could not find lua/control.lst"); return ""; } class Driver { public: enum ChanState { CHAN_INACTIVE, CHAN_PLAINTEXT, CHAN_SSL_CONNECTING, CHAN_SSL_ACCEPTING, CHAN_SSL_READWRITE, }; struct ChanInfo { int chid; SOCKET socket; SSL *ssl; BIO *recv_bio; BIO *send_bio; // If recent_error is set, that means that a recent IO operation generated // an error. As a special case, EOF on read is considered an error, we use // the string "EOF" for this case. std::string recent_error; // OpenSSL has a rule: if you try to SSL_write and it returns // SSL_ERROR_WANT_READ, then you have to retry the write with the same // number of bytes. In this event, we record how many bytes we // attempted to write, which will enable us to retry. int retry_write_nbytes; // True if the channel needs to be advanced. bool need_advance; ChanState state; uint32_t nbytes; const char *bytes; bool marked_for_deletion() const { return state == CHAN_INACTIVE; } }; std::filesystem::path luprexroot; EngineWrapper engw; std::vector chans_; std::map listen_sockets_; bool read_console_recently_; std::unique_ptr pollvec_; std::unique_ptr chbuf_; ReadlineDevice readline_device_; std::string console_command_; sslutil::UniqueCTX ssl_server_ctx_; sslutil::UniqueCTX ssl_client_secure_ctx_; sslutil::UniqueCTX ssl_client_insecure_ctx_; // Return the amount of 'space left' in a BIO. This is a fiction, // because MEM BIOs technically have unlimited capacity. We're // artificially limiting them to a certain size because there's no // reason to buffer huge amounts of data. // int bio_space(BIO *bio) { int space = (MAX_BIO_BUFFER) - BIO_pending(bio); if (space < 0) space = 0; return space; } // This is a terribly inefficient way to discard data that has // already been processed. There has to be something better. // void bio_discard(BIO *b, int nbytes) { while (nbytes > 0) { int nread = nbytes; if (nread > DRV_SHORTSTRING_SIZE) nread = DRV_SHORTSTRING_SIZE; int ndropped = BIO_read(b, chbuf_.get(), nread); assert(ndropped == nread); nbytes -= ndropped; } } void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) { ChanInfo newchan; newchan.chid = chid; newchan.socket = sock; newchan.recv_bio = BIO_new(BIO_s_mem()); newchan.send_bio = BIO_new(BIO_s_mem()); newchan.recent_error.clear(); newchan.retry_write_nbytes = 0; newchan.need_advance = true; if (state == CHAN_PLAINTEXT) { newchan.ssl = nullptr; } else { newchan.ssl = SSL_new(ctx); SSL_set_bio(newchan.ssl, newchan.recv_bio, newchan.send_bio); } newchan.state = state; newchan.nbytes = 0; newchan.bytes = 0; chans_.push_back(newchan); } void close_channel(ChanInfo &chan, std::string_view err) { // std::cerr << "Closing channel " << chan.chid << " with " << err << std::endl; assert(chan.state != CHAN_INACTIVE); // Close and release the SSL channel. // This frees the BIO objects as well. if (chan.ssl != nullptr) { SSL_free(chan.ssl); chan.ssl = nullptr; } chan.recv_bio = nullptr; chan.send_bio = nullptr; chan.recent_error.clear(); chan.retry_write_nbytes = 0; chan.need_advance = false; // Close and release the socket. assert(chan.socket != INVALID_SOCKET); assert(socket_close(chan.socket) == 0); chan.socket = INVALID_SOCKET; // Close everything else. engw.play_notify_close(&engw, chan.chid, err.size(), err.data()); chan.state = CHAN_INACTIVE; chan.chid = -1; chan.nbytes = 0; chan.bytes = 0; } void handle_listen_ports() { uint32_t nports; const uint32_t *ports; engw.get_listen_ports(&engw, &nports, &ports); for (uint32_t i = 0; i < nports; i++) { int port = ports[i]; if (listen_sockets_.find(port) == listen_sockets_.end()) { std::string err; SOCKET sock = listen_on_port(port, err); if_error_print_and_exit(err); assert(sock != INVALID_SOCKET); listen_sockets_[port] = sock; } } } void handle_lua_source() { if (engw.get_rescan_lua_source(&engw)) { drvutil::ostringstream oss; std::string err = drvutil::package_lua_source(".", &oss); if_error_print_and_exit(err); std::string_view ossv = oss.view(); engw.play_access(&engw, AccessKind::INVOKE_LUA_SOURCE, 0, ossv.size(), ossv.data(), nullptr, nullptr); } } void channel_printbuffer() { if (engw.get_have_prints(&engw)) { uint32_t ndata; const char *data; engw.play_access(&engw, AccessKind::CHANNEL_PRINTS, 0, 0, "", &ndata, &data); if (ndata > 0) { if (ndata > DRV_SHORTSTRING_SIZE) ndata = DRV_SHORTSTRING_SIZE; readline_device_.printline(std::string_view(data, ndata)); } } } void add_console_command(std::string_view addition) { std::string cmd = console_command_ + std::string(addition); console_command_.clear(); uint32_t ndata; const char *data; engw.play_access(&engw, AccessKind::VALIDATE_LUA_EXPR, 0, cmd.size(), cmd.c_str(), &ndata, &data); std::string_view message(data, ndata); // Handle the command. if (message == "truncated lua") { console_command_ = cmd; } else if (message == "white space") { readline_device_.printline("white space."); } else if (message == "slash command") { readline_device_.printline("slash command."); } else if (message.empty()) { readline_device_.printline("valid lua"); } else { readline_device_.printline(message); } if (console_command_.empty()) { readline_device_.set_prompt(">"); } else { readline_device_.set_prompt(">>"); } } void handle_console_input() { read_console_recently_ = false; while (true) { std::u32string cps = console_read(); if (cps.size() == 0) break; read_console_recently_ = true; for (char32_t c : cps) { std::string line = readline_device_.putcode(c); if (!line.empty()) { add_console_command(line); } } } } void handle_new_outgoing_sockets() { uint32_t nchids; const uint32_t *chids; engw.get_new_outgoing(&engw, &nchids, &chids); for (uint32_t i = 0; i < nchids; i++) { uint32_t chid = chids[i]; std::string err, cert, host, port; const char *target = engw.get_target(&engw, chid); drvutil::split_target(target, cert, host, port); if (cert.empty() || host.empty() || port.empty()) { std::string message = "invalid target: "; message += target; engw.play_notify_close(&engw, chid, message.size(), message.c_str()); continue; } SSL_CTX *ctx = nullptr; if (cert == "cert") { ctx = ssl_client_secure_ctx_.get(); } else if (cert == "nocert") { ctx = ssl_client_insecure_ctx_.get(); } else { std::string message = "invalid cert rule: "; message += target; engw.play_notify_close(&engw, chid, message.size(), message.c_str()); continue; } SOCKET sock = open_connection(host.c_str(), port.c_str(), err); if (sock == INVALID_SOCKET) { engw.play_notify_close(&engw, chid, err.size(), err.c_str()); continue; } make_channel(sock, chid, ctx, CHAN_SSL_CONNECTING); } engw.play_clear_new_outgoing(&engw); } void accept_connection(int port, SOCKET sock) { std::string err; SOCKET socket = accept_on_socket(sock, err); if_error_print_and_exit(err); if (socket != INVALID_SOCKET) { uint32_t chid = engw.play_notify_accept(&engw, port); make_channel(socket, chid, ssl_server_ctx_.get(), CHAN_SSL_ACCEPTING); } } // Copy data from the socket into the recv bio. // // If it detects an error or EOF, sets the recent_errno flag. // void transfer_socket_to_recv_bio(ChanInfo &chan) { if ((chan.state == CHAN_INACTIVE) || (!chan.recent_error.empty())) { return; } std::string err; int nread = socket_recv(chan.socket, chbuf_.get(), DRV_SHORTSTRING_SIZE, err); // std::cerr << "chan " << chan.chid << " recv " << nread << " err=" << err << std::endl; if (nread < 0) { chan.recent_error = err; } else { if (nread == 0) { chan.recent_error = "EOF"; } else { int nstored = BIO_write(chan.recv_bio, chbuf_.get(), nread); assert(nstored == nread); chan.need_advance = true; // std::cerr << "chan " << chan.chid << " stored " << nread << " bytes" << std::endl; } } } // Copy data from the send BIO into the socket. // // If it detects an error, sets the recent_errno flag. // It is an error to call this when there is nothing in the send BIO. // void transfer_send_bio_to_socket(ChanInfo &chan) { if ((chan.state == CHAN_INACTIVE) || (!chan.recent_error.empty())) { return; } char *data; int ndata = BIO_get_mem_data(chan.send_bio, &data); if (ndata > DRV_SHORTSTRING_SIZE) ndata = DRV_SHORTSTRING_SIZE; // It is an error to call this function when there is nothing in the send BIO. assert(ndata > 0); std::string err; int nwrote = socket_send(chan.socket, data, ndata, err); // std::cerr << "chan " << chan.chid << " send " << nwrote << " err=" << err << std::endl; if (nwrote < 0) { chan.recent_error = err; } else { assert(nwrote != 0); bio_discard(chan.send_bio, nwrote); chan.need_advance = true; } } // Close the channel if there's a serious OpenSSL error. // // The 'retval' is the return value of the SSL function that returned an // error. // // All errors are considered serious except for SSL_ERROR_WANT_READ, which // is not serious because it is transient. However, if you get an // SSL_ERROR_WANT_READ when there's tons of data available in the read // buffer, that's inexplicable and therefore serious. // void if_error_is_serious_close_channel(ChanInfo &chan, int retval) { int error = SSL_get_error(chan.ssl, retval); //std::cerr << "chan " << chan.chid << " ssl error = " << error << std::endl; // Should never have write errors, because we're // using a memory BIO with unlimited capacity. assert(error != SSL_ERROR_WANT_WRITE); // If we get a read error, make sure it's plausible: // if the recv bio is full, that makes no sense. if (error == SSL_ERROR_WANT_READ) { if (bio_space(chan.recv_bio) == 0) { close_channel(chan, "ssl waiting for data, but there's tons of data"); } return; } // Any other error is an actual error. Close // the channel. std::string errstr = sslutil::error_string(); if (errstr == "") errstr = "unknown error"; close_channel(chan, errstr); } void advance_plaintext(ChanInfo &chan) { uint32_t ndata; const char *data; // Transfer all data from the recv BIO into the channel. ndata = BIO_get_mem_data(chan.recv_bio, &data); if (ndata > 0) { engw.play_recv_incoming(&engw, chan.chid, ndata, data); bio_discard(chan.recv_bio, ndata); } // Transfer all data from the channel to the send BIO. engw.get_outgoing(&engw, chan.chid, &ndata, &data); if (ndata > 0) { int nwrote = BIO_write(chan.send_bio, data, ndata); assert(nwrote == int(ndata)); engw.play_sent_outgoing(&engw, chan.chid, ndata); } } void advance_ssl_connecting(ChanInfo &chan) { int retval = SSL_connect(chan.ssl); //std::cerr << "chan " << chan.chid << " ssl_connect returns " << retval << std::endl; if (retval == 1) { chan.state = CHAN_SSL_READWRITE; chan.need_advance = true; } else { if_error_is_serious_close_channel(chan, retval); } } void advance_ssl_accepting(ChanInfo &chan) { int retval = SSL_accept(chan.ssl); //std::cerr << "chan " << chan.chid << " ssl_accept returns " << retval << std::endl; if (retval == 1) { chan.state = CHAN_SSL_READWRITE; chan.need_advance = true; } else { if_error_is_serious_close_channel(chan, retval); } } void advance_ssl_readwrite(ChanInfo &chan) { // Read as much as we can, which of course will be limited // by the fact that the recv_bio contains finite data. while (true) { int read_result = SSL_read(chan.ssl, chbuf_.get(), DRV_SHORTSTRING_SIZE); if (read_result > 0) { engw.play_recv_incoming(&engw, chan.chid, read_result, chbuf_.get()); } else { if_error_is_serious_close_channel(chan, read_result); break; } } // The read process could have generated an error which could // have closed the channel. If so, don't try writing. if (chan.state == CHAN_INACTIVE) { return; } // Try to write data. while (chan.nbytes) { uint32_t wbytes; if (chan.retry_write_nbytes > 0) { wbytes = chan.retry_write_nbytes; assert(wbytes < chan.nbytes); } else { wbytes = chan.nbytes; if (wbytes > DRV_SHORTSTRING_SIZE) wbytes = DRV_SHORTSTRING_SIZE; } if (wbytes == 0) break; int write_result = SSL_write(chan.ssl, chan.bytes, wbytes); if (write_result > 0) { engw.play_sent_outgoing(&engw, chan.chid, write_result); chan.retry_write_nbytes = 0; chan.nbytes -= write_result; chan.bytes += write_result; } else { if_error_is_serious_close_channel(chan, write_result); chan.retry_write_nbytes = wbytes; break; } } } void advance_channel(ChanInfo &chan) { sslutil::clear_all_errors(); // We set the need_advance flag to false here, but // the rest of the advance routine is allowed to set // it back to true in the event that the advance routine // only processes some of the available data. chan.need_advance = false; switch (chan.state) { case CHAN_PLAINTEXT: advance_plaintext(chan); break; case CHAN_SSL_CONNECTING: advance_ssl_connecting(chan); break; case CHAN_SSL_ACCEPTING: advance_ssl_accepting(chan); break; case CHAN_SSL_READWRITE: advance_ssl_readwrite(chan); break; default: assert(false); break; } } void handle_socket_input_output() { std::string err; int mstimeout = read_console_recently_ ? 10 : 100; // Peek output buffers and determine channel release flags. bool any_released = false; for (ChanInfo &chan : chans_) { engw.get_outgoing(&engw, chan.chid, &chan.nbytes, &chan.bytes); if (chan.nbytes > 0) { chan.need_advance = true; } if (chan.nbytes == 0) { if (engw.get_channel_released(&engw, chan.chid)) { if (BIO_pending(chan.send_bio) == 0) { close_channel(chan, ""); any_released = true; } } } } // Delete any released channels if (any_released) { drvutil::remove_marked_items(chans_); } // Construct the struct pollfd vector. int pollsize = 0; for (const ChanInfo &chan : chans_) { struct pollfd &pfd = pollvec_[pollsize++]; assert(chan.socket != INVALID_SOCKET); pfd.fd = chan.socket; pfd.events = 0; pfd.revents = 0; // If there's room in the receive buffer, set POLLIN if (bio_space(chan.recv_bio) > 0) { pfd.events |= POLLIN; } // If there's data in the outgoing buffer, set POLLOUT if (BIO_pending(chan.send_bio) > 0) { pfd.events |= POLLOUT; } if (chan.need_advance) { mstimeout = 0; } } for (const auto &p : listen_sockets_) { struct pollfd &pfd = pollvec_[pollsize++]; pfd.fd = p.second; pfd.events = POLLIN; pfd.revents = 0; } // Do the poll. socket_poll(pollvec_.get(), pollsize, mstimeout, err); if_error_print_and_exit(err); // Advance channels where possible and then check listen sockets. int index = 0; for (ChanInfo &chan : chans_) { struct pollfd &pfd = pollvec_[index++]; if ((pfd.revents & POLLIN) != 0) { transfer_socket_to_recv_bio(chan); } if ((pfd.revents & POLLOUT) != 0) { transfer_send_bio_to_socket(chan); } if (chan.need_advance || (!chan.recent_error.empty())) { advance_channel(chan); } if (!chan.recent_error.empty()) { if (chan.recent_error == "EOF") { close_channel(chan, ""); } else { close_channel(chan, chan.recent_error); } chan.recent_error.clear(); } chan.nbytes = 0; chan.bytes = 0; } for (auto &p : listen_sockets_) { struct pollfd &pfd = pollvec_[index++]; if (pfd.revents & (POLLIN | POLLERR)) { accept_connection(p.first, p.second); } } // Delete any newly-inactive channels drvutil::remove_marked_items(chans_); } int replay_logfile(const char *fn, bool verbose) { engw.replay_initialize(&engw, fn); if_error_print_and_exit(engw.error); while (engw.rlog) { engw.replay_step(&engw); } if_error_print_and_exit(engw.error); return 0; } static void replay_cb_sent_outgoing(void *vp, int chid, int ndata, const char *data) { if (chid == 0) { std::cerr.write(data, ndata); } } int drive(int argc, char *argv[]) { // Set up the console readline device. readline_device_.set_print_callback(console_write); readline_device_.set_prompt(">"); console_command_.clear(); // Remove the program name from argv. std::string program = argv[0]; argc -= 1; argv += 1; // Find the root of the luprex tree. luprexroot = find_luprex_root(get_exe_path()); // Load the DLL and gain access to its functions. call_init_engine_wrapper(luprexroot, &engw); engw.replay_cb_sent_outgoing = replay_cb_sent_outgoing; engw.hook_dprint(dprint_callback); // If argv contains "replay ", do a replay, // and then skip everything else. if (argc >= 1) { std::string cmd(argv[0]); if ((cmd == "replay") || (cmd == "vreplay")) { if (argc != 2) { std::cerr << "usage: " << program << " replay " << std::endl; return 1; } return replay_logfile(argv[1], cmd == "vreplay"); } } // If argv contains "record ", start recording, // and remove the "record " from argv. std::string replaylogfn; if (argc >= 1) { std::string cmd = argv[0]; if (cmd == "record") { if (argc < 2) { std::cerr << "The 'record' command must be followed by a filename" << std::endl; return 1; } replaylogfn = argv[1]; argc -= 2; argv += 2; } } // Make sure there's exactly one argument left for the engine type. if (argc != 1) { std::cerr << "Must specify the engine type" << std::endl; return 1; } // Initialize state variables. read_console_recently_ = false; chbuf_.reset(new char[DRV_SHORTSTRING_SIZE]); pollvec_.reset(new struct pollfd[POLLVEC_SIZE]); ssl_server_ctx_.reset(sslutil::new_context(SSL_VERIFY_NONE)); ssl_client_secure_ctx_.reset(sslutil::new_context(SSL_VERIFY_PEER)); ssl_client_insecure_ctx_.reset(sslutil::new_context(SSL_VERIFY_NONE)); ssl_load_certificate_authorities(ssl_client_secure_ctx_.get()); sslutil::ctx_load_dummy_cert(ssl_server_ctx_.get()); // Initialize the engine. engw.play_initialize(&engw, argv[0], replaylogfn.c_str()); if_error_print_and_exit(engw.error); // Set up listening ports. handle_listen_ports(); // Main loop. while (!engw.get_stop_driver(&engw)) { handle_lua_source(); handle_new_outgoing_sockets(); handle_socket_input_output(); handle_console_input(); engw.play_update(&engw, drvutil::get_monotonic_clock()); channel_printbuffer(); } // Cleanup for (ChanInfo &chan : chans_) { close_channel(chan, ""); } engw.release(&engw); return 0; } }; int main(int argc, char **argv) { os_initialize(argc, argv); assert(OPENSSL_init_ssl(0, NULL) == 1); sslutil::clear_all_errors(); Driver driver; return driver.drive(argc, argv); }