#define OPENSSL_HEAP_SIZE (4*1024*1024) #define CHBUF_SIZE (256*1024) #define POLLVEC_SIZE (DrivenEngine::MAX_CHAN+1) int mallocstate(int n) { int64_t result = 0; for (int i = 0; i < n; i++) { int64_t n = int64_t(malloc(1)); result = (result * 17) + n; } return result & 0x7fffffff; } static MonoClock monoclock; namespace util { double profiling_clock() { return monoclock.get(); } } static void if_error_print_and_exit(const UmmString &str) { if (!str.empty()) { std::cerr << std::endl << "error: " << str << std::endl; exit(1); } } static SSL_CTX *new_ssl_context(bool server_cert, bool root_certs, std::string_view require_cert) { SSL_CTX *ctx = SSL_CTX_new(TLS_method()); SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE); SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, nullptr); // server_cert is not implemented yet. if (root_certs) { SSL_CTX_set_default_verify_paths(ctx); SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, NULL); } // require_cert is not implemented yet. return ctx; } static UmmString err_print_errors_str() { BIO *bio = BIO_new(BIO_s_mem()); ERR_print_errors(bio); char *buf; size_t len = BIO_get_mem_data(bio, &buf); UmmString ret(buf, len); BIO_free(bio); return ret; } static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str) { BIO *bio = BIO_new(BIO_s_mem()); BIO_puts(bio, str); X509 *certificate = PEM_read_bio_X509(bio, NULL, NULL, NULL); BIO_free(bio); int status = SSL_CTX_use_certificate(ctx, certificate); X509_free(certificate); return status; } static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str) { BIO *bio = BIO_new(BIO_s_mem()); BIO_puts(bio, str); EVP_PKEY *pkey = PEM_read_bio_PrivateKey(bio, NULL, NULL, NULL); BIO_free(bio); int status = SSL_CTX_use_PrivateKey(ctx, pkey); EVP_PKEY_free(pkey); return status; } static std::unique_ptr chbuf; static std::unique_ptr pollvec; class Driver { public: enum ChanState { CHAN_INACTIVE, CHAN_PLAINTEXT, CHAN_SSL_CONNECTING, CHAN_SSL_ACCEPTING, CHAN_SSL_READWRITE, }; struct ChanInfo { int chid; SOCKET socket; SSL *ssl; ChanState state; int nbytes; const char *bytes; bool released; bool just_released; bool ready_now; bool ready_on_pollin; bool ready_on_pollout; bool ready_on_outgoing; int last_write_nbytes; }; DrivenEngine *driven_; UmmVector chans_; UmmMap listen_sockets_; bool read_console_recently_; SSL_CTX *ssl_ctx_with_root_certs_; SSL_CTX *ssl_ctx_with_server_certs_; SSL_CTX *ssl_ctx_with_no_certs_; void handle_listen_ports() { const std::vector &listenports = driven_->drv_get_listen_ports(); for (int port : listenports) { if (listen_sockets_.find(port) == listen_sockets_.end()) { UmmString err; SOCKET sock = listen_on_port(port, err); if_error_print_and_exit(err); assert(sock != INVALID_SOCKET); listen_sockets_[port] = sock; } } } void handle_lua_source() { if (driven_->drv_get_rescan_lua_source()) { UmmString err; std::string_view ctrl = read_file("lua/control.lst", chbuf.get(), CHBUF_SIZE, err); if_error_print_and_exit(err); UmmStringVec names = drv::parse_control_lst(ctrl); driven_->drv_clear_lua_source(); for (const UmmString &str : names) { UmmString lfn = UmmString("lua/") + str; std::string_view data = read_file(lfn.c_str(), chbuf.get(), CHBUF_SIZE, err); if_error_print_and_exit(err); driven_->drv_add_lua_source(str, data); } } } void close_channel(ChanInfo &chan, std::string_view err) { // std::cerr << "Closing channel " << chan.chid << std::endl; assert(chan.state != CHAN_INACTIVE); // Close and release the SSL channel. if (chan.ssl != nullptr) { SSL_free(chan.ssl); chan.ssl = nullptr; } // Close and release the socket. assert(chan.socket != INVALID_SOCKET); assert(socket_close(chan.socket) == 0); chan.socket = INVALID_SOCKET; // Close everything else. driven_->drv_notify_close(chan.chid, err); chan.state = CHAN_INACTIVE; chan.chid = -1; chan.nbytes = 0; chan.bytes = 0; chan.released = false; chan.just_released = false; chan.ready_now = false; chan.ready_on_pollin = false; chan.ready_on_pollout = false; chan.ready_on_outgoing = false; chan.last_write_nbytes = 0; } void cleanup_channels() { for (int i = 0; i < int(chans_.size()); ) { if (chans_[i].state == CHAN_INACTIVE) { chans_[i] = chans_.back(); chans_.pop_back(); } else { i += 1; } } } void handle_console_output() { while (true) { int nbytes; const char *bytes; driven_->drv_peek_outgoing(0, &nbytes, &bytes); if (nbytes == 0) break; int nwrote = console_write(bytes, nbytes); if (nwrote <= 0) break; driven_->drv_sent_outgoing(0, nwrote); } } void handle_console_input() { char buffer[256]; read_console_recently_ = false; while (true) { int nread = console_read(buffer, 256); if (nread <= 0) break; read_console_recently_ = true; driven_->drv_recv_incoming(0, nread, buffer); } } void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) { ChanInfo newchan; newchan.chid = chid; newchan.socket = sock; newchan.ssl = SSL_new(ctx); newchan.state = state; newchan.nbytes = 0; newchan.bytes = 0; newchan.released = false; newchan.just_released = false; newchan.ready_now = false; newchan.ready_on_pollin = false; newchan.ready_on_pollout = true; newchan.ready_on_outgoing = false; newchan.last_write_nbytes = 0; SSL_set_fd(newchan.ssl, newchan.socket); // SSL_set_msg_callback(newchan.ssl, SSL_trace); // SSL_set_msg_callback_arg(newchan.ssl, BIO_new_fp(stderr,0)); chans_.push_back(newchan); } void handle_new_outgoing_sockets() { const std::vector &chans = driven_->drv_get_new_outgoing(); for (int chid : chans) { UmmString err; SOCKET sock = open_connection(driven_->drv_get_target(chid), err); if (sock == INVALID_SOCKET) { driven_->drv_notify_close(chid, err); } else { //std::cerr << "Opening channel " << chid << std::endl; make_channel(sock, chid, ssl_ctx_with_no_certs_, CHAN_SSL_CONNECTING); } } if (!chans.empty()) { driven_->drv_clear_new_outgoing(); } } void accept_connection(int port, SOCKET sock) { UmmString err; SOCKET socket = accept_on_socket(sock, err); if_error_print_and_exit(err); if (socket != INVALID_SOCKET) { int chid = driven_->drv_notify_accept(port); // std::cerr << "Accepted channel " << chid << std::endl; make_channel(socket, chid, ssl_ctx_with_server_certs_, CHAN_SSL_ACCEPTING); } } void advance_plaintext(ChanInfo &chan) { UmmString err; // If the channel has no outgoing bytes and has been released, // just close it. if (chan.released) { close_channel(chan, ""); return; } // Try to write plaintext to the channel. int nbytes; const char *bytes; driven_->drv_peek_outgoing(chan.chid, &nbytes, &bytes); if (nbytes > 0) { int sbytes = nbytes; if (sbytes > 65536) sbytes = 65536; int wbytes = socket_send(chan.socket, bytes, sbytes, err); // std::cerr << "send.bytes="<< wbytes << ".errno=" << errno << " "; if (wbytes < 0) { close_channel(chan, err); } else { driven_->drv_sent_outgoing(chan.chid, wbytes); } } // Try to read plaintext from the channel. // Someday, find a way to avoid this copy. int nrecv = socket_recv(chan.socket, chbuf.get(), 65536, err); // std::cerr << "recv.bytes="<< nrecv << ".errno=" << errno << " "; if (nrecv < 0) { close_channel(chan, err); } else { driven_->drv_recv_incoming(chan.chid, nrecv, chbuf.get()); } // Update the ready-flags for next time. chan.ready_on_outgoing = true; chan.ready_on_pollin = true; } void process_ssl_error(ChanInfo &chan, int retval) { int error = SSL_get_error(chan.ssl, retval); // std::cerr << "SSL error code = " << error << " "; if (error == SSL_ERROR_WANT_READ) { chan.ready_on_pollin = true; } else if (error == SSL_ERROR_WANT_WRITE) { chan.ready_on_pollout = true; } else { close_channel(chan, err_print_errors_str()); } } void advance_ssl_connecting(ChanInfo &chan) { // std::cerr << "In advance_ssl_connecting" << std::endl; int retval = SSL_connect(chan.ssl); if (retval == 1) { // Connection successful. chan.state = CHAN_SSL_READWRITE; chan.ready_now = true; } else { // std::cerr << "ssl_connect_error"; process_ssl_error(chan, retval); } } void advance_ssl_accepting(ChanInfo &chan) { // std::cerr << "In advance_ssl_accepting" << std::endl; int retval = SSL_accept(chan.ssl); if (retval == 1) { // Connection successful. chan.state = CHAN_SSL_READWRITE; chan.ready_now = true; } else { process_ssl_error(chan, retval); } } void advance_ssl_readwrite(ChanInfo &chan) { // std::cerr << "In advance_ssl_readwrite" << std::endl; // Try to read data. int read_result = SSL_read(chan.ssl, chbuf.get(), 65536); if (read_result > 0) { driven_->drv_recv_incoming(chan.chid, read_result, chbuf.get()); chan.ready_now = true; } else { process_ssl_error(chan, read_result); if (chan.state == CHAN_INACTIVE) return; } // Try to write data. int wbytes; if (chan.last_write_nbytes > 0) { wbytes = chan.last_write_nbytes; assert(wbytes < chan.nbytes); } else { wbytes = chan.nbytes; if (wbytes > 65536) wbytes = 65536; } if (wbytes > 0) { int write_result = SSL_write(chan.ssl, chan.bytes, wbytes); if (write_result > 0) { driven_->drv_sent_outgoing(chan.chid, write_result); chan.last_write_nbytes = 0; chan.ready_on_outgoing = true; } else { chan.last_write_nbytes = wbytes; process_ssl_error(chan, write_result); if (chan.state == CHAN_INACTIVE) return; } } else { chan.ready_on_outgoing = true; } // std::cerr << "rpi=" << chan.ready_on_pollin << ".rpo=" << chan.ready_on_pollout << ".rn=" << chan.ready_now << ".rog=" << chan.ready_on_outgoing << " "; } void advance_channel(ChanInfo &chan) { switch(chan.state) { case CHAN_PLAINTEXT: advance_plaintext(chan); break; case CHAN_SSL_CONNECTING: advance_ssl_connecting(chan); break; case CHAN_SSL_ACCEPTING: advance_ssl_accepting(chan); break; case CHAN_SSL_READWRITE: advance_ssl_readwrite(chan); break; default: assert(false); break; } } void handle_socket_input_output() { UmmString err; int mstimeout = read_console_recently_ ? 100 : 1000; // Peek output buffers and determine channel release flags. for (ChanInfo &chan : chans_) { driven_->drv_peek_outgoing(chan.chid, &chan.nbytes, &chan.bytes); chan.just_released = false; if ((chan.nbytes == 0)&&(!chan.released)) { chan.released = driven_->drv_get_channel_released(chan.chid); chan.just_released = chan.released; } } // Construct the struct pollfd vector. int pollsize = 0; for (const auto &p : listen_sockets_) { struct pollfd &pfd = pollvec[pollsize++]; pfd.fd = p.second; pfd.events = POLLIN; pfd.revents = 0; } for (const ChanInfo &chan : chans_) { struct pollfd &pfd = pollvec[pollsize++]; assert(chan.socket != INVALID_SOCKET); pfd.fd = chan.socket; pfd.events = 0; pfd.revents = 0; if (chan.ready_now) mstimeout = 0; if (chan.just_released) mstimeout = 0; if (chan.ready_on_pollin) pfd.events |= POLLIN; if (chan.ready_on_pollout) pfd.events |= POLLOUT; if (chan.ready_on_outgoing && (chan.nbytes > 0)) pfd.events |= POLLOUT; // std::cerr << "evt=" << pfd.events << ".nb=" << chan.nbytes << " "; } // Do the poll. socket_poll(pollvec.get(), pollsize, mstimeout, err); if_error_print_and_exit(err); // Check listening sockets. int index = 0; for (auto &p : listen_sockets_) { struct pollfd &pfd = pollvec[index++]; if (pfd.revents & (POLLIN | POLLERR)) { accept_connection(p.first, p.second); } } // Advance channels where possible. for (ChanInfo &chan : chans_) { struct pollfd &pfd = pollvec[index++]; bool pollin = ((pfd.revents & POLLIN) != 0); bool pollout = ((pfd.revents & POLLOUT) != 0); bool pollerr = ((pfd.revents & (POLLERR | POLLHUP)) != 0); if (chan.ready_now || pollerr || chan.just_released || (chan.ready_on_pollin && pollin) || (chan.ready_on_pollout && pollout) || (chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) { chan.ready_now = false; chan.ready_on_pollin = false; chan.ready_on_pollout = false; chan.ready_on_outgoing = false; advance_channel(chan); } chan.nbytes = 0; chan.bytes = 0; } // Delete any newly-inactive channels cleanup_channels(); } void drive(DrivenEngine *de, int argc, char *argv[]) { driven_ = de; read_console_recently_ = false; ssl_ctx_with_root_certs_ = new_ssl_context(false, true, ""); ssl_ctx_with_server_certs_ = new_ssl_context(true, false, ""); ssl_ctx_with_no_certs_ = new_ssl_context(false, false, ""); if (ssl_ctx_use_certificate_str(ssl_ctx_with_server_certs_, dummycert::certificate) <= 0) { ERR_print_errors_fp(stderr); exit(1); } if (ssl_ctx_use_privatekey_str(ssl_ctx_with_server_certs_, dummycert::privatekey) <= 0 ) { ERR_print_errors_fp(stderr); exit(1); } DrivenEngine::set(de); handle_lua_source(); driven_->drv_invoke_event_init(argc, argv); handle_listen_ports(); while (!de->drv_get_stop_driver()) { handle_lua_source(); handle_console_output(); handle_new_outgoing_sockets(); handle_socket_input_output(); handle_console_input(); handle_console_output(); de->drv_invoke_event_update(monoclock.get()); } for (ChanInfo &chan : chans_) { close_channel(chan, ""); } SSL_CTX_free(ssl_ctx_with_no_certs_); SSL_CTX_free(ssl_ctx_with_root_certs_); SSL_CTX_free(ssl_ctx_with_server_certs_); DrivenEngine::set(nullptr); } }; void driver_drive(int argc, char *argv[]) { // The only place in the driver where we're allowed to use malloc // is here, before even looking at the arguments. That way, the // impact on the malloc heap is always exactly the same, which // doesn't break the determinism of the execution during replay. umm_init_heap(malloc(OPENSSL_HEAP_SIZE), OPENSSL_HEAP_SIZE); CRYPTO_set_mem_functions(umm_malloc_ssl, umm_realloc_ssl, umm_free_ssl); chbuf.reset(new char[CHBUF_SIZE]); pollvec.reset(new struct pollfd[POLLVEC_SIZE]); ERR_load_crypto_strings(); SSL_load_error_strings(); std::cerr << "#2 " << std::hex << mallocstate(1) << std::endl; Driver driver; if (argc < 2) { DrivenEngine::print_usage(std::cerr, argv[0]); exit(1); } UniqueDrivenEngine engine = DrivenEngine::make(argv[1]); if (engine == nullptr) { DrivenEngine::print_usage(std::cerr, argv[0]); exit(1); } driver.drive(engine.get(), argc, argv); }