diff --git a/luprex/core/cpp/drivenengine.cpp b/luprex/core/cpp/drivenengine.cpp index 1aa30f98..a3be997d 100644 --- a/luprex/core/cpp/drivenengine.cpp +++ b/luprex/core/cpp/drivenengine.cpp @@ -212,7 +212,7 @@ void DrivenEngine::drv_clear_new_outgoing() { new_outgoing_.clear(); } -const eng::string &DrivenEngine::drv_get_target(int chid) const { +std::string_view DrivenEngine::drv_get_target(int chid) const { return get_chid(chid)->target_; } diff --git a/luprex/core/cpp/drivenengine.hpp b/luprex/core/cpp/drivenengine.hpp index 5942913c..9d7a58cc 100644 --- a/luprex/core/cpp/drivenengine.hpp +++ b/luprex/core/cpp/drivenengine.hpp @@ -115,14 +115,15 @@ public: // The channel ID. These are reused. // - int chid() { return chid_; } + int chid() const { return chid_; } // If this is a socket connection, the receiver's port number. // - int port() { return port_; } + int port() const { return port_; } // If this is an outgoing socket connection, get the target host. - const eng::string &target() { return target_; } + // + const eng::string &target() const { return target_; } // True if the remote closed the connection, or a failure occurred. // @@ -133,7 +134,7 @@ public: // If this is an empty string, there is no error. If this is set, // then the channel is also closed. // - eng::string error() const { return error_; } + const eng::string &error() const { return error_; } // Set the prompt for readline mode. // @@ -314,12 +315,13 @@ public: void drv_clear_new_outgoing(); // Get the target of a channel. A target is a string like - // "www.whatever.com:80". It indicates the host and port that the channel - // is supposed to be talking to. Non-socket channels and incoming channels - // have empty targets. + // "cert:whatever.com:80" or "nocert:whatever.com:80". + // The first word indicate whether or not a valid SSL certificate + // is required. The second word is the hostname. The third word is + // the port number. // - const eng::string &drv_get_target(int chid) const; - + std::string_view drv_get_target(int chid) const; + // Return true if the outgoing buffer is empty. // bool drv_outgoing_empty(int chid) const; diff --git a/luprex/core/cpp/driver-common.cpp b/luprex/core/cpp/driver-common.cpp index f9f98f67..5c888f41 100644 --- a/luprex/core/cpp/driver-common.cpp +++ b/luprex/core/cpp/driver-common.cpp @@ -1,36 +1,45 @@ -#define CHBUF_SIZE (256*1024) -#define POLLVEC_SIZE (DrivenEngine::MAX_CHAN+1) +#define CHBUF_SIZE (256 * 1024) +#define POLLVEC_SIZE (DrivenEngine::MAX_CHAN + 1) static MonoClock monoclock; -namespace util { - double profiling_clock() { +namespace util +{ + double profiling_clock() + { return monoclock.get(); } } -static void if_error_print_and_exit(const std::string &str) { - if (!str.empty()) { - std::cerr << std::endl << "error: " << str << std::endl; +static void if_error_print_and_exit(const std::string &str) +{ + if (!str.empty()) + { + std::cerr << std::endl + << "error: " << str << std::endl; exit(1); } } -static std::string_view read_file(const char *fn, char *buf, int bufsize, std::string &err) { +static std::string_view read_file(const char *fn, char *buf, int bufsize, std::string &err) +{ FILE *f = fopen(fn, "r"); - if (f == 0) { + if (f == 0) + { err = std::string("cannot read file") + fn; buf[0] = 0; - return std::string_view(buf, 0); + return std::string_view(buf, 0); } int nread = fread(buf, 1, bufsize, f); - if (nread < 0) { + if (nread < 0) + { err = std::string("cannot read file: ") + fn; buf[0] = 0; return std::string_view(buf, 0); } - if (nread == bufsize) { + if (nread == bufsize) + { err = std::string("file too large: ") + fn; buf[0] = 0; return std::string_view(buf, 0); @@ -39,42 +48,66 @@ static std::string_view read_file(const char *fn, char *buf, int bufsize, std::s return std::string_view(buf, nread); } -struct SSL_CTX_Deleter { - void operator()(SSL_CTX *ctx) { +struct SSL_CTX_Deleter +{ + void operator()(SSL_CTX *ctx) + { SSL_CTX_free(ctx); } }; using UniqueSSLCTX = std::unique_ptr; -static UniqueSSLCTX new_ssl_context(bool server_cert, bool root_certs, std::string_view require_cert) { - SSL_CTX *ctx = SSL_CTX_new(TLS_method()); - SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); - SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE); - SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, nullptr); - // server_cert is not implemented yet. - if (root_certs) { - SSL_CTX_set_default_verify_paths(ctx); - SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, NULL); +static std::string ssl_errors_string(bool lastonly = true) +{ + std::string err; + const char *file, *data, *func; + int line, flags; + + while (true) + { + unsigned long code = ERR_get_error_all(&file, &line, &func, &data, &flags); + if (code == 0) + break; + std::string reason; + if (ERR_SYSTEM_ERROR(code)) + { + reason = strerror_str(ERR_GET_REASON(code)); + } + else + { + const char *rc = ERR_reason_error_string(code); + reason = (rc == nullptr) ? "unknown" : rc; + } + if (err.empty() || lastonly) + { + err = reason; + } + else + { + err = err + ", " + reason; + } + if (data != nullptr) + { + err = err + " " + data; + } } - // require_cert is not implemented yet. - return UniqueSSLCTX(ctx); + return err; } - - - -static std::string err_print_errors_str() { - BIO *bio = BIO_new(BIO_s_mem()); - ERR_print_errors(bio); - char *buf; - size_t len = BIO_get_mem_data(bio, &buf); - std::string ret(buf, len); - BIO_free(bio); - return ret; +void assert_ssl_errors_empty() +{ + int code = ERR_peek_error(); + if (code != 0) + { + std::cerr << "SSL should not have errors at this point." << std::endl; + ERR_print_errors_fp(stderr); + exit(1); + } } -static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str) { +static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str) +{ BIO *bio = BIO_new(BIO_s_mem()); BIO_puts(bio, str); X509 *certificate = PEM_read_bio_X509(bio, NULL, NULL, NULL); @@ -84,7 +117,8 @@ static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str) { return status; } -static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str) { +static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str) +{ BIO *bio = BIO_new(BIO_s_mem()); BIO_puts(bio, str); EVP_PKEY *pkey = PEM_read_bio_PrivateKey(bio, NULL, NULL, NULL); @@ -94,18 +128,33 @@ static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str) { return status; } +static void ssl_ctx_use_dummycert(SSL_CTX *ctx) +{ + if (ssl_ctx_use_certificate_str(ctx, dummycert::certificate) <= 0) + { + ERR_print_errors_fp(stderr); + exit(1); + } + if (ssl_ctx_use_privatekey_str(ctx, dummycert::privatekey) <= 0) + { + ERR_print_errors_fp(stderr); + exit(1); + } +} -class Driver { +class Driver +{ public: - - enum ChanState { + enum ChanState + { CHAN_INACTIVE, CHAN_PLAINTEXT, CHAN_SSL_CONNECTING, CHAN_SSL_ACCEPTING, CHAN_SSL_READWRITE, }; - struct ChanInfo { + struct ChanInfo + { int chid; SOCKET socket; SSL *ssl; @@ -129,15 +178,17 @@ public: std::unique_ptr pollvec_; drv::ReplayRecorder recorder_; - UniqueSSLCTX ssl_ctx_with_root_certs_; - UniqueSSLCTX ssl_ctx_with_server_certs_; - UniqueSSLCTX ssl_ctx_with_no_certs_; + UniqueSSLCTX ssl_server_ctx_; + UniqueSSLCTX ssl_client_secure_ctx_; + UniqueSSLCTX ssl_client_insecure_ctx_; - - void handle_listen_ports() { + void handle_listen_ports() + { const auto &listenports = recorder_.drv_get_listen_ports(); - for (int port : listenports) { - if (listen_sockets_.find(port) == listen_sockets_.end()) { + for (int port : listenports) + { + if (listen_sockets_.find(port) == listen_sockets_.end()) + { std::string err; SOCKET sock = listen_on_port(port, err); if_error_print_and_exit(err); @@ -147,14 +198,17 @@ public: } } - void handle_lua_source() { - if (recorder_.drv_get_rescan_lua_source()) { + void handle_lua_source() + { + if (recorder_.drv_get_rescan_lua_source()) + { std::string err; std::string_view ctrl = read_file("lua/control.lst", chbuf_.get(), CHBUF_SIZE, err); if_error_print_and_exit(err); std::vector names = drv::parse_control_lst(ctrl); recorder_.drv_clear_lua_source(); - for (const std::string &str : names) { + for (const std::string &str : names) + { std::string lfn = std::string("lua/") + str; std::string_view data = read_file(lfn.c_str(), chbuf_.get(), CHBUF_SIZE, err); if_error_print_and_exit(err); @@ -162,12 +216,14 @@ public: } } } - - void close_channel(ChanInfo &chan, std::string_view err) { + + void close_channel(ChanInfo &chan, std::string_view err) + { // std::cerr << "Closing channel " << chan.chid << std::endl; assert(chan.state != CHAN_INACTIVE); // Close and release the SSL channel. - if (chan.ssl != nullptr) { + if (chan.ssl != nullptr) + { SSL_free(chan.ssl); chan.ssl = nullptr; } @@ -190,39 +246,52 @@ public: chan.last_write_nbytes = 0; } - void cleanup_channels() { - for (int i = 0; i < int(chans_.size()); ) { - if (chans_[i].state == CHAN_INACTIVE) { + void cleanup_channels() + { + for (int i = 0; i < int(chans_.size());) + { + if (chans_[i].state == CHAN_INACTIVE) + { chans_[i] = chans_.back(); chans_.pop_back(); - } else { + } + else + { i += 1; } } } - void handle_console_output() { - while (true) { + void handle_console_output() + { + while (true) + { std::string_view s = recorder_.drv_peek_outgoing(0); - if (s.size() == 0) break; + if (s.size() == 0) + break; int nwrote = console_write(s.data(), s.size()); - if (nwrote <= 0) break; + if (nwrote <= 0) + break; recorder_.drv_sent_outgoing(0, nwrote); } } - void handle_console_input() { + void handle_console_input() + { char buffer[256]; read_console_recently_ = false; - while (true) { + while (true) + { int nread = console_read(buffer, 256); - if (nread <= 0) break; + if (nread <= 0) + break; read_console_recently_ = true; recorder_.drv_recv_incoming(0, std::string_view(buffer, nread)); } } - void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) { + void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) + { ChanInfo newchan; newchan.chid = chid; newchan.socket = sock; @@ -243,53 +312,81 @@ public: chans_.push_back(newchan); } - void handle_new_outgoing_sockets() { + void handle_new_outgoing_sockets() + { const auto &chans = recorder_.drv_get_new_outgoing(); - for (int chid : chans) { - std::string err; - SOCKET sock = open_connection(recorder_.drv_get_target(chid), err); - if (sock == INVALID_SOCKET) { - recorder_.drv_notify_close(chid, err); - } else { - //std::cerr << "Opening channel " << chid << std::endl; - make_channel(sock, chid, ssl_ctx_with_no_certs_.get(), CHAN_SSL_CONNECTING); + for (int chid : chans) + { + std::string err, cert, host, port; + std::string target(recorder_.drv_get_target(chid)); + drv::split_target(target, cert, host, port); + if (cert.empty() || host.empty() || port.empty()) { + recorder_.drv_notify_close(chid, std::string("invalid target: ") + target); + continue; } + SSL_CTX *ctx = nullptr; + if (cert == "cert") { + ctx = ssl_client_secure_ctx_.get(); + } else if (cert == "nocert") { + ctx = ssl_client_insecure_ctx_.get(); + } else { + recorder_.drv_notify_close(chid, std::string("invalid cert rule: ") + target); + continue; + } + SOCKET sock = open_connection(host.c_str(), port.c_str(), err); + if (sock == INVALID_SOCKET) + { + recorder_.drv_notify_close(chid, err); + continue; + } + // std::cerr << "Opening channel " << chid << std::endl; + make_channel(sock, chid, ctx, CHAN_SSL_CONNECTING); } - if (!chans.empty()) { + if (!chans.empty()) + { recorder_.drv_clear_new_outgoing(); } } - void accept_connection(int port, SOCKET sock) { + void accept_connection(int port, SOCKET sock) + { std::string err; SOCKET socket = accept_on_socket(sock, err); if_error_print_and_exit(err); - if (socket != INVALID_SOCKET) { + if (socket != INVALID_SOCKET) + { int chid = recorder_.drv_notify_accept(port); // std::cerr << "Accepted channel " << chid << std::endl; - make_channel(socket, chid, ssl_ctx_with_server_certs_.get(), CHAN_SSL_ACCEPTING); + make_channel(socket, chid, ssl_server_ctx_.get(), CHAN_SSL_ACCEPTING); } } - void advance_plaintext(ChanInfo &chan) { + void advance_plaintext(ChanInfo &chan) + { std::string err; // If the channel has no outgoing bytes and has been released, // just close it. - if (chan.released) { + if (chan.released) + { close_channel(chan, ""); return; } // Try to write plaintext to the channel. std::string_view s = recorder_.drv_peek_outgoing(chan.chid); - if (s.size() > 0) { + if (s.size() > 0) + { int sbytes = s.size(); - if (sbytes > 65536) sbytes = 65536; + if (sbytes > 65536) + sbytes = 65536; int wbytes = socket_send(chan.socket, s.data(), sbytes, err); - if (wbytes < 0) { + if (wbytes < 0) + { close_channel(chan, err); - } else { + } + else + { recorder_.drv_sent_outgoing(chan.chid, wbytes); } } @@ -297,9 +394,12 @@ public: // Try to read plaintext from the channel. // Someday, find a way to avoid this copy. int nrecv = socket_recv(chan.socket, chbuf_.get(), 65536, err); - if (nrecv < 0) { + if (nrecv < 0) + { close_channel(chan, err); - } else { + } + else + { recorder_.drv_recv_incoming(chan.chid, std::string_view(chbuf_.get(), nrecv)); } @@ -308,84 +408,116 @@ public: chan.ready_on_pollin = true; } - void process_ssl_error(ChanInfo &chan, int retval) { + void process_ssl_error(ChanInfo &chan, int retval) + { int error = SSL_get_error(chan.ssl, retval); // std::cerr << "SSL error code = " << error << " "; - if (error == SSL_ERROR_WANT_READ) { + if (error == SSL_ERROR_WANT_READ) + { chan.ready_on_pollin = true; - } else if (error == SSL_ERROR_WANT_WRITE) { + } + else if (error == SSL_ERROR_WANT_WRITE) + { chan.ready_on_pollout = true; - } else { - close_channel(chan, err_print_errors_str()); + } + else + { + close_channel(chan, ssl_errors_string()); } } - void advance_ssl_connecting(ChanInfo &chan) { + void advance_ssl_connecting(ChanInfo &chan) + { // std::cerr << "In advance_ssl_connecting" << std::endl; int retval = SSL_connect(chan.ssl); - if (retval == 1) { + if (retval == 1) + { // Connection successful. chan.state = CHAN_SSL_READWRITE; chan.ready_now = true; - } else { + } + else + { // std::cerr << "ssl_connect_error"; process_ssl_error(chan, retval); } } - - void advance_ssl_accepting(ChanInfo &chan) { + + void advance_ssl_accepting(ChanInfo &chan) + { // std::cerr << "In advance_ssl_accepting" << std::endl; int retval = SSL_accept(chan.ssl); - if (retval == 1) { + if (retval == 1) + { // Connection successful. chan.state = CHAN_SSL_READWRITE; chan.ready_now = true; - } else { + } + else + { process_ssl_error(chan, retval); } } - - void advance_ssl_readwrite(ChanInfo &chan) { + + void advance_ssl_readwrite(ChanInfo &chan) + { // std::cerr << "In advance_ssl_readwrite" << std::endl; // Try to read data. int read_result = SSL_read(chan.ssl, chbuf_.get(), 65536); - if (read_result > 0) { + if (read_result > 0) + { recorder_.drv_recv_incoming(chan.chid, std::string_view(chbuf_.get(), read_result)); chan.ready_now = true; - } else { + } + else + { process_ssl_error(chan, read_result); - if (chan.state == CHAN_INACTIVE) return; + if (chan.state == CHAN_INACTIVE) + return; } // Try to write data. int wbytes; - if (chan.last_write_nbytes > 0) { + if (chan.last_write_nbytes > 0) + { wbytes = chan.last_write_nbytes; assert(wbytes < chan.nbytes); - } else { - wbytes = chan.nbytes; - if (wbytes > 65536) wbytes = 65536; } - if (wbytes > 0) { + else + { + wbytes = chan.nbytes; + if (wbytes > 65536) + wbytes = 65536; + } + if (wbytes > 0) + { int write_result = SSL_write(chan.ssl, chan.bytes, wbytes); - if (write_result > 0) { + if (write_result > 0) + { recorder_.drv_sent_outgoing(chan.chid, write_result); chan.last_write_nbytes = 0; chan.ready_on_outgoing = true; - } else { + } + else + { chan.last_write_nbytes = wbytes; process_ssl_error(chan, write_result); - if (chan.state == CHAN_INACTIVE) return; + if (chan.state == CHAN_INACTIVE) + return; } - } else { + } + else + { chan.ready_on_outgoing = true; } - // std::cerr << "rpi=" << chan.ready_on_pollin << ".rpo=" << chan.ready_on_pollout << ".rn=" << chan.ready_now << ".rog=" << chan.ready_on_outgoing << " "; } - - void advance_channel(ChanInfo &chan) { - switch(chan.state) { + + void advance_channel(ChanInfo &chan) + { + assert_ssl_errors_empty(); + switch (chan.state) + { case CHAN_PLAINTEXT: advance_plaintext(chan); break; @@ -402,62 +534,75 @@ public: assert(false); break; } + assert_ssl_errors_empty(); } - - void handle_socket_input_output() { + void handle_socket_input_output() + { std::string err; int mstimeout = read_console_recently_ ? 100 : 1000; // Peek output buffers and determine channel release flags. - for (ChanInfo &chan : chans_) { + for (ChanInfo &chan : chans_) + { std::string_view s = recorder_.drv_peek_outgoing(chan.chid); chan.nbytes = s.size(); chan.bytes = s.data(); chan.just_released = false; - if ((chan.nbytes == 0)&&(!chan.released)) { + if ((chan.nbytes == 0) && (!chan.released)) + { chan.released = recorder_.drv_get_channel_released(chan.chid); chan.just_released = chan.released; } } - + // Construct the struct pollfd vector. int pollsize = 0; - for (const auto &p : listen_sockets_) { + for (const auto &p : listen_sockets_) + { struct pollfd &pfd = pollvec_[pollsize++]; pfd.fd = p.second; pfd.events = POLLIN; pfd.revents = 0; } - for (const ChanInfo &chan : chans_) { + for (const ChanInfo &chan : chans_) + { struct pollfd &pfd = pollvec_[pollsize++]; assert(chan.socket != INVALID_SOCKET); pfd.fd = chan.socket; pfd.events = 0; pfd.revents = 0; - if (chan.ready_now) mstimeout = 0; - if (chan.just_released) mstimeout = 0; - if (chan.ready_on_pollin) pfd.events |= POLLIN; - if (chan.ready_on_pollout) pfd.events |= POLLOUT; - if (chan.ready_on_outgoing && (chan.nbytes > 0)) pfd.events |= POLLOUT; + if (chan.ready_now) + mstimeout = 0; + if (chan.just_released) + mstimeout = 0; + if (chan.ready_on_pollin) + pfd.events |= POLLIN; + if (chan.ready_on_pollout) + pfd.events |= POLLOUT; + if (chan.ready_on_outgoing && (chan.nbytes > 0)) + pfd.events |= POLLOUT; // std::cerr << "evt=" << pfd.events << ".nb=" << chan.nbytes << " "; } // Do the poll. socket_poll(pollvec_.get(), pollsize, mstimeout, err); if_error_print_and_exit(err); - + // Check listening sockets. int index = 0; - for (auto &p : listen_sockets_) { + for (auto &p : listen_sockets_) + { struct pollfd &pfd = pollvec_[index++]; - if (pfd.revents & (POLLIN | POLLERR)) { + if (pfd.revents & (POLLIN | POLLERR)) + { accept_connection(p.first, p.second); } } // Advance channels where possible. - for (ChanInfo &chan : chans_) { + for (ChanInfo &chan : chans_) + { struct pollfd &pfd = pollvec_[index++]; bool pollin = ((pfd.revents & POLLIN) != 0); bool pollout = ((pfd.revents & POLLOUT) != 0); @@ -465,7 +610,8 @@ public: if (chan.ready_now || pollerr || chan.just_released || (chan.ready_on_pollin && pollin) || (chan.ready_on_pollout && pollout) || - (chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) { + (chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) + { chan.ready_now = false; chan.ready_on_pollin = false; chan.ready_on_pollout = false; @@ -480,36 +626,46 @@ public: cleanup_channels(); } - int replay_logfile(const char *fn, bool verbose) { + int replay_logfile(const char *fn, bool verbose) + { drv::ReplayPlayer player; player.open_logfile(fn); - if (verbose) { + if (verbose) + { player.enable_stdout(); } - while (true) { + while (true) + { drv::ReplayPlayer::Status st = player.step(); - if (st != drv::ReplayPlayer::ST_REPLAYING) { + if (st != drv::ReplayPlayer::ST_REPLAYING) + { player.print_status(std::cerr); return (st == drv::ReplayPlayer::ST_CLEAN_EXIT) ? 0 : 1; } } } - int drive(int argc, char *argv[]) { + int drive(int argc, char *argv[]) + { // Remove the program name from argv. - if (argc < 1) { + if (argc < 1) + { DrivenEngine::print_usage(std::cerr, ""); exit(1); } std::string program = argv[0]; - argc -= 1; argv += 1; + argc -= 1; + argv += 1; // If argv contains "replay ", do a replay, // and then skip everything else. - if (argc >= 1) { + if (argc >= 1) + { std::string cmd(argv[0]); - if ((cmd == "replay") || (cmd == "vreplay")) { - if (argc != 2) { + if ((cmd == "replay") || (cmd == "vreplay")) + { + if (argc != 2) + { std::cerr << "usage: " << program << " replay " << std::endl; return 1; } @@ -519,29 +675,36 @@ public: // If argv contains "record ", start recording, // and remove the "record " from argv. - if (argc >= 1) { + if (argc >= 1) + { std::string cmd = argv[0]; - if (cmd == "record") { - if (argc < 2) { + if (cmd == "record") + { + if (argc < 2) + { DrivenEngine::print_usage(std::cerr, program); return 1; } bool ok = recorder_.open_logfile(argv[1]); - if (!ok) { + if (!ok) + { std::cerr << "Could not open logfile: " << argv[1] << std::endl; return 1; } - argc -= 2; argv += 2; + argc -= 2; + argv += 2; } } // Create the engine. - if (argc < 1) { + if (argc < 1) + { DrivenEngine::print_usage(std::cerr, program); return 1; } bool engine_made = recorder_.create_engine(argv[0]); - if (!engine_made) { + if (!engine_made) + { DrivenEngine::print_usage(std::cerr, program); return 1; } @@ -551,25 +714,17 @@ public: chbuf_.reset(new char[CHBUF_SIZE]); pollvec_.reset(new struct pollfd[POLLVEC_SIZE]); - ssl_ctx_with_root_certs_ = new_ssl_context(false, true, ""); - ssl_ctx_with_server_certs_ = new_ssl_context(true, false, ""); - ssl_ctx_with_no_certs_ = new_ssl_context(false, false, ""); - - if (ssl_ctx_use_certificate_str(ssl_ctx_with_server_certs_.get(), dummycert::certificate) <= 0) { - ERR_print_errors_fp(stderr); - return 1; - } - - if (ssl_ctx_use_privatekey_str(ssl_ctx_with_server_certs_.get(), dummycert::privatekey) <= 0 ) { - ERR_print_errors_fp(stderr); - return 1; - } + ssl_server_ctx_.reset(new_ssl_server_context()); + ssl_client_secure_ctx_.reset(new_ssl_client_context(SSL_VERIFY_PEER)); + ssl_client_insecure_ctx_.reset(new_ssl_client_context(SSL_VERIFY_NONE)); + assert_ssl_errors_empty(); handle_lua_source(); recorder_.drv_invoke_event_init(argc, argv); handle_listen_ports(); - while (!recorder_.drv_get_stop_driver()) { + while (!recorder_.drv_get_stop_driver()) + { handle_lua_source(); handle_console_output(); handle_new_outgoing_sockets(); @@ -579,7 +734,8 @@ public: recorder_.drv_invoke_event_update(monoclock.get()); } - for (ChanInfo &chan : chans_) { + for (ChanInfo &chan : chans_) + { close_channel(chan, ""); } @@ -588,5 +744,3 @@ public: return 0; } }; - - diff --git a/luprex/core/cpp/driver-linux.cpp b/luprex/core/cpp/driver-linux.cpp index 152fab5d..eae420bf 100644 --- a/luprex/core/cpp/driver-linux.cpp +++ b/luprex/core/cpp/driver-linux.cpp @@ -41,7 +41,7 @@ struct termios orig_termios; static std::string strerror_str(int err) { char errbuf[256]; - return strerror_r(errno, errbuf, 256); + return strerror_r(err, errbuf, 256); } void set_nonblocking(int fd) { @@ -69,7 +69,7 @@ static void enable_tty_raw() { assert(status >= 0); } -static SOCKET open_connection(std::string_view target, std::string &err) { +static SOCKET open_connection(const char *host, const char *port, std::string &err) { struct addrinfo *addrs = nullptr; struct addrinfo *goodaddr = nullptr; struct addrinfo hints; @@ -82,9 +82,7 @@ static SOCKET open_connection(std::string_view target, std::string &err) { hints.ai_flags = AI_NUMERICSERV; err.clear(); - std::string host, port; - drv::split_host_port(target, host, port); - int status = getaddrinfo(host.c_str(), port.c_str(), &hints, &addrs); + int status = getaddrinfo(host, port, &hints, &addrs); if (status != 0) { err = gai_strerror(status); goto error_general; @@ -228,6 +226,25 @@ static int console_read(char *bytes, int nbytes) { return read(0, bytes, nbytes); } +static void ssl_ctx_use_dummycert(SSL_CTX *ctx); + +static SSL_CTX *new_ssl_server_context() { + SSL_CTX *ctx = SSL_CTX_new(TLS_method()); + SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); + SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE); + SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, nullptr); + ssl_ctx_use_dummycert(ctx); + return ctx; +} + +static SSL_CTX *new_ssl_client_context(int verify) { + SSL_CTX *ctx = SSL_CTX_new(TLS_method()); + SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); + SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE); + SSL_CTX_set_default_verify_paths(ctx); + SSL_CTX_set_verify(ctx, verify, nullptr); + return ctx; +} static void disable_randomization(int argc, char *argv[]) { const int old_personality = personality(ADDR_NO_RANDOMIZE); diff --git a/luprex/core/cpp/driver-mingw.cpp b/luprex/core/cpp/driver-mingw.cpp index 1fa476b4..6a217ab7 100644 --- a/luprex/core/cpp/driver-mingw.cpp +++ b/luprex/core/cpp/driver-mingw.cpp @@ -55,17 +55,15 @@ static PADDRINFOA find_good_addr(PADDRINFOA addrinfo) { return nullptr; } -static SOCKET open_connection(std::string_view target, std::string &err) { +static SOCKET open_connection(const char *host, const char *port, std::string &err) { PADDRINFOA addrs = nullptr; PADDRINFOA goodaddr = nullptr; SOCKET sock = INVALID_SOCKET; - std::string_view host, port; err.clear(); - util::split_host_port(target, host, port); - int status = getaddrinfo(host.data(), port.data(), nullptr, &addrs); + int status = getaddrinfo(host, port, nullptr, &addrs); while (status == WSATRY_AGAIN) { - status = getaddrinfo(host.data(), port.data(), nullptr, &addrs); + status = getaddrinfo(host, port, nullptr, &addrs); } if (status == WSAHOST_NOT_FOUND) { err = "host not found"; diff --git a/luprex/core/cpp/driver-util.cpp b/luprex/core/cpp/driver-util.cpp index bfcaf30d..ba905f27 100644 --- a/luprex/core/cpp/driver-util.cpp +++ b/luprex/core/cpp/driver-util.cpp @@ -17,16 +17,31 @@ namespace drv { -void split_host_port(std::string_view target, std::string &host, std::string &port) { - size_t lastcolon = target.rfind(':'); - if (lastcolon == std::string_view::npos) { - host = ""; port = ""; return; +std::vector split_view(std::string_view v, char sep) { + std::vector result; + while (true) { + size_t pos = v.find(sep); + if (pos == std::string_view::npos) break; + result.push_back(v.substr(0, pos)); + v = v.substr(pos + 1); } - host = target.substr(0, lastcolon); - port = target.substr(lastcolon + 1); - if ((host == "") || (port == "")) { - host = ""; port = ""; return; + result.push_back(v); + return result; +} + +void split_target(std::string_view target, std::string &cert, std::string &host, std::string &port) { + std::vector split = split_view(target, ':'); + if (split.size() != 3) { + cert.clear(); host.clear(); port.clear(); + return; } + if (split[0].empty() || split[1].empty() || split[2].empty()) { + cert.clear(); host.clear(); port.clear(); + return; + } + cert = std::string(split[0]); + host = std::string(split[1]); + port = std::string(split[2]); } std::vector parse_control_lst(std::string_view ctrl) { @@ -502,9 +517,10 @@ void ReplayPlayer::drv_invoke_event_update() { } // namespace drv LuaDefine(unittests_driverutil, "", "some unit tests") { - // Test split_host_port - std::string host, port; - drv::split_host_port("stanford.edu:80", host, port); + // Test split_target + std::string cert, host, port; + drv::split_target("cert:stanford.edu:80", cert, host, port); + LuaAssertStrEq(L, cert, "cert"); LuaAssertStrEq(L, host, "stanford.edu"); LuaAssertStrEq(L, port, "80"); diff --git a/luprex/core/cpp/driver-util.hpp b/luprex/core/cpp/driver-util.hpp index a70814ef..88666aae 100644 --- a/luprex/core/cpp/driver-util.hpp +++ b/luprex/core/cpp/driver-util.hpp @@ -11,7 +11,8 @@ namespace drv { -void split_host_port(std::string_view target, std::string &host, std::string &port); +void split_target(std::string_view target, std::string &cert, std::string &host, std::string &port); + std::vector parse_control_lst(std::string_view ctrl); @@ -58,7 +59,7 @@ public: // const eng::vector &drv_get_listen_ports() const { return e_->drv_get_listen_ports(); } const eng::vector &drv_get_new_outgoing() const { return e_->drv_get_new_outgoing(); } - const eng::string &drv_get_target(int chid) const { return e_->drv_get_target(chid); } + std::string_view drv_get_target(int chid) const { return e_->drv_get_target(chid); } bool drv_outgoing_empty(int chid) const { return e_->drv_outgoing_empty(chid); } bool drv_get_channel_released(int chid) const { return e_->drv_get_channel_released(chid); } std::string_view drv_peek_outgoing(int chid) const { return e_->drv_peek_outgoing(chid); } diff --git a/luprex/core/cpp/lpxclient.cpp b/luprex/core/cpp/lpxclient.cpp index 4ed6ed89..9b0deb32 100644 --- a/luprex/core/cpp/lpxclient.cpp +++ b/luprex/core/cpp/lpxclient.cpp @@ -74,7 +74,7 @@ public: set_initial_state(); // Establish a connection to the server. - channel_ = new_outgoing_channel("localhost:8085"); + channel_ = new_outgoing_channel("cert:localhost:8085"); // Set the console prompt get_stdio_channel()->set_prompt(console_.get_prompt()); @@ -262,7 +262,7 @@ public: // Check for communication from server.. if (channel_ != nullptr) { if (channel_->closed()) { - stdostream() << "Server closed connection " << channel_->error() << std::endl; + stdostream() << "server closed connection: " << channel_->error() << std::endl; abandon_server(); } else { while (true) {