From 2b59d8a4a372866c6e19baac031e94ddb849829b Mon Sep 17 00:00:00 2001 From: jyelon Date: Tue, 11 Jan 2022 13:59:13 -0500 Subject: [PATCH] refactor driver-linux to separate machine dependent and independent code. --- luprex/core/cpp/driver-common.cpp | 487 ++++++++++++++++++++++++ luprex/core/cpp/driver-linux.cpp | 602 +++++------------------------- 2 files changed, 577 insertions(+), 512 deletions(-) create mode 100644 luprex/core/cpp/driver-common.cpp diff --git a/luprex/core/cpp/driver-common.cpp b/luprex/core/cpp/driver-common.cpp new file mode 100644 index 00000000..24fe6ab5 --- /dev/null +++ b/luprex/core/cpp/driver-common.cpp @@ -0,0 +1,487 @@ + +static MonoClock monoclock; + +namespace util { + double profiling_clock() { + return monoclock.get(); + } +} + +static void if_error_print_and_exit(const std::string &str) { + if (!str.empty()) { + std::cerr << std::endl << "error: " << str << std::endl; + exit(1); + } +} + +static SSL_CTX *new_ssl_context(bool server_cert, bool root_certs, const std::string &require_cert) { + SSL_CTX *ctx = SSL_CTX_new(TLS_method()); + SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); + SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE); + SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, nullptr); + // server_cert is not implemented yet. + if (root_certs) { + SSL_CTX_set_default_verify_paths(ctx); + SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, NULL); + } + // require_cert is not implemented yet. + return ctx; +} + +static std::string err_print_errors_str() { + BIO *bio = BIO_new(BIO_s_mem()); + ERR_print_errors(bio); + char *buf; + size_t len = BIO_get_mem_data(bio, &buf); + std::string ret(buf, len); + BIO_free(bio); + return ret; +} + +static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str) { + BIO *bio = BIO_new(BIO_s_mem()); + BIO_puts(bio, str); + X509 *certificate = PEM_read_bio_X509(bio, NULL, NULL, NULL); + BIO_free(bio); + int status = SSL_CTX_use_certificate(ctx, certificate); + X509_free(certificate); + return status; +} + +static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str) { + BIO *bio = BIO_new(BIO_s_mem()); + BIO_puts(bio, str); + EVP_PKEY *pkey = PEM_read_bio_PrivateKey(bio, NULL, NULL, NULL); + BIO_free(bio); + int status = SSL_CTX_use_PrivateKey(ctx, pkey); + EVP_PKEY_free(pkey); + return status; +} + +class Driver { +public: + enum ChanState { + CHAN_INACTIVE, + CHAN_PLAINTEXT, + CHAN_SSL_CONNECTING, + CHAN_SSL_ACCEPTING, + CHAN_SSL_READWRITE, + }; + struct ChanInfo { + int chid; + SOCKET socket; + SSL *ssl; + + ChanState state; + int nbytes; + const char *bytes; + bool released; + bool just_released; + bool ready_now; + bool ready_on_pollin; + bool ready_on_pollout; + bool ready_on_outgoing; + int last_write_nbytes; + }; + + DrivenEngine *driven_; + std::vector chans_; + std::map listen_sockets_; + std::unique_ptr chbuf; + bool read_console_recently_; + + SSL_CTX *ssl_ctx_with_root_certs_; + SSL_CTX *ssl_ctx_with_server_certs_; + SSL_CTX *ssl_ctx_with_no_certs_; + + void handle_listen_ports() { + std::set listenports; + driven_->drv_get_listen_ports(listenports); + for (int port : listenports) { + if (listen_sockets_.find(port) == listen_sockets_.end()) { + std::string err; + SOCKET sock = listen_on_port(port, err); + if_error_print_and_exit(err); + assert(sock != INVALID_SOCKET); + listen_sockets_[port] = sock; + } + } + } + + void handle_lua_source() { + if (driven_->drv_get_rescan_lua_source()) { + driven_->drv_set_lua_source(util::read_lua_source("lua")); + } + } + + void close_channel(ChanInfo &chan, const std::string &err) { + // std::cerr << "Closing channel " << chan.chid << std::endl; + assert(chan.state != CHAN_INACTIVE); + // Close and release the SSL channel. + if (chan.ssl != nullptr) { + SSL_free(chan.ssl); + chan.ssl = nullptr; + } + // Close and release the socket. + assert(chan.socket != INVALID_SOCKET); + assert(socket_close(chan.socket) == 0); + chan.socket = INVALID_SOCKET; + // Close everything else. + driven_->drv_notify_close(chan.chid, err); + chan.state = CHAN_INACTIVE; + chan.chid = -1; + chan.nbytes = 0; + chan.bytes = 0; + chan.released = false; + chan.just_released = false; + chan.ready_now = false; + chan.ready_on_pollin = false; + chan.ready_on_pollout = false; + chan.ready_on_outgoing = false; + chan.last_write_nbytes = 0; + } + + void cleanup_channels() { + for (int i = 0; i < int(chans_.size()); ) { + if (chans_[i].state == CHAN_INACTIVE) { + chans_[i] = chans_.back(); + chans_.pop_back(); + } else { + i += 1; + } + } + } + + void handle_console_output() { + while (true) { + int nbytes; const char *bytes; + driven_->drv_peek_outgoing(0, &nbytes, &bytes); + if (nbytes == 0) break; + int nwrote = console_write(bytes, nbytes); + assert(nwrote > 0); + driven_->drv_sent_outgoing(0, nwrote); + } + } + + void handle_console_input() { + char buffer[256]; + read_console_recently_ = false; + while (true) { + int nread = console_read(buffer, 256); + if (nread == 0) break; + assert(nread > 0); + read_console_recently_ = true; + driven_->drv_recv_incoming(0, nread, buffer); + } + } + + void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) { + ChanInfo newchan; + newchan.chid = chid; + newchan.socket = sock; + newchan.ssl = SSL_new(ctx); + newchan.state = state; + newchan.nbytes = 0; + newchan.bytes = 0; + newchan.released = false; + newchan.just_released = false; + newchan.ready_now = false; + newchan.ready_on_pollin = false; + newchan.ready_on_pollout = true; + newchan.ready_on_outgoing = false; + newchan.last_write_nbytes = 0; + SSL_set_fd(newchan.ssl, newchan.socket); + // SSL_set_msg_callback(newchan.ssl, SSL_trace); + // SSL_set_msg_callback_arg(newchan.ssl, BIO_new_fp(stderr,0)); + chans_.push_back(newchan); + } + + void handle_new_outgoing_sockets() { + std::set chans; + driven_->drv_get_new_outgoing(chans); + for (int chid : chans) { + std::string err; + SOCKET sock = open_connection(driven_->drv_get_target(chid), err); + if (sock == INVALID_SOCKET) { + driven_->drv_notify_close(chid, err); + } else { + //std::cerr << "Opening channel " << chid << std::endl; + make_channel(sock, chid, ssl_ctx_with_no_certs_, CHAN_SSL_CONNECTING); + } + } + } + + void accept_connection(int port, SOCKET sock) { + std::string err; + SOCKET socket = accept_on_socket(sock, err); + if_error_print_and_exit(err); + if (socket != INVALID_SOCKET) { + int chid = driven_->drv_notify_accept(port); + // std::cerr << "Accepted channel " << chid << std::endl; + make_channel(socket, chid, ssl_ctx_with_server_certs_, CHAN_SSL_ACCEPTING); + } + } + + void advance_plaintext(ChanInfo &chan) { + std::string err; + + // If the channel has no outgoing bytes and has been released, + // just close it. + if (chan.released) { + close_channel(chan, ""); + return; + } + + // Try to write plaintext to the channel. + int nbytes; const char *bytes; + driven_->drv_peek_outgoing(chan.chid, &nbytes, &bytes); + if (nbytes > 0) { + int sbytes = nbytes; + if (sbytes > 65536) sbytes = 65536; + int wbytes = socket_send(chan.socket, bytes, sbytes, err); + // std::cerr << "send.bytes="<< wbytes << ".errno=" << errno << " "; + if (wbytes < 0) { + close_channel(chan, err); + } else { + driven_->drv_sent_outgoing(chan.chid, wbytes); + } + } + + // Try to read plaintext from the channel. + // Someday, find a way to avoid this copy. + int nrecv = socket_recv(chan.socket, chbuf.get(), 65536, err); + // std::cerr << "recv.bytes="<< nrecv << ".errno=" << errno << " "; + if (nrecv < 0) { + close_channel(chan, err); + } else { + driven_->drv_recv_incoming(chan.chid, nrecv, chbuf.get()); + } + + // Update the ready-flags for next time. + chan.ready_on_outgoing = true; + chan.ready_on_pollin = true; + } + + void process_ssl_error(ChanInfo &chan, int retval) { + int error = SSL_get_error(chan.ssl, retval); + // std::cerr << "SSL error code = " << error << " "; + if (error == SSL_ERROR_WANT_READ) { + chan.ready_on_pollin = true; + } else if (error == SSL_ERROR_WANT_WRITE) { + chan.ready_on_pollout = true; + } else { + close_channel(chan, err_print_errors_str()); + } + } + + void advance_ssl_connecting(ChanInfo &chan) { + // std::cerr << "In advance_ssl_connecting" << std::endl; + int retval = SSL_connect(chan.ssl); + if (retval == 1) { + // Connection successful. + chan.state = CHAN_SSL_READWRITE; + chan.ready_now = true; + } else { + // std::cerr << "ssl_connect_error"; + process_ssl_error(chan, retval); + } + } + + void advance_ssl_accepting(ChanInfo &chan) { + // std::cerr << "In advance_ssl_accepting" << std::endl; + int retval = SSL_accept(chan.ssl); + if (retval == 1) { + // Connection successful. + chan.state = CHAN_SSL_READWRITE; + chan.ready_now = true; + } else { + process_ssl_error(chan, retval); + } + } + + void advance_ssl_readwrite(ChanInfo &chan) { + // std::cerr << "In advance_ssl_readwrite" << std::endl; + // Try to read data. + int read_result = SSL_read(chan.ssl, chbuf.get(), 65536); + if (read_result > 0) { + driven_->drv_recv_incoming(chan.chid, read_result, chbuf.get()); + chan.ready_now = true; + } else { + process_ssl_error(chan, read_result); + if (chan.state == CHAN_INACTIVE) return; + } + + // Try to write data. + int wbytes; + if (chan.last_write_nbytes > 0) { + wbytes = chan.last_write_nbytes; + assert(wbytes < chan.nbytes); + } else { + wbytes = chan.nbytes; + if (wbytes > 65536) wbytes = 65536; + } + if (wbytes > 0) { + int write_result = SSL_write(chan.ssl, chan.bytes, wbytes); + if (write_result > 0) { + driven_->drv_sent_outgoing(chan.chid, write_result); + chan.last_write_nbytes = 0; + chan.ready_on_outgoing = true; + } else { + chan.last_write_nbytes = wbytes; + process_ssl_error(chan, write_result); + if (chan.state == CHAN_INACTIVE) return; + } + } else { + chan.ready_on_outgoing = true; + } + + // std::cerr << "rpi=" << chan.ready_on_pollin << ".rpo=" << chan.ready_on_pollout << ".rn=" << chan.ready_now << ".rog=" << chan.ready_on_outgoing << " "; + } + + void advance_channel(ChanInfo &chan) { + switch(chan.state) { + case CHAN_PLAINTEXT: + advance_plaintext(chan); + break; + case CHAN_SSL_CONNECTING: + advance_ssl_connecting(chan); + break; + case CHAN_SSL_ACCEPTING: + advance_ssl_accepting(chan); + break; + case CHAN_SSL_READWRITE: + advance_ssl_readwrite(chan); + break; + default: + assert(false); + break; + } + } + + + void handle_socket_input_output() { + int mstimeout = 1000; + + // Peek output buffers and determine channel release flags. + for (ChanInfo &chan : chans_) { + driven_->drv_peek_outgoing(chan.chid, &chan.nbytes, &chan.bytes); + chan.just_released = false; + if ((chan.nbytes == 0)&&(!chan.released)) { + chan.released = driven_->drv_get_channel_released(chan.chid); + chan.just_released = chan.released; + } + } + + // Construct the pollfd vector. + int pollsize = listen_sockets_.size() + chans_.size() + 1; + std::vector pollvec(pollsize); + int index = 0; + for (const auto &p : listen_sockets_) { + struct pollfd &pfd = pollvec[index++]; + pfd.fd = p.second; + pfd.events = POLLIN; + } + for (const ChanInfo &chan : chans_) { + struct pollfd &pfd = pollvec[index++]; + assert(chan.socket != INVALID_SOCKET); + pfd.fd = chan.socket; + pfd.events = POLLERR; + if (chan.ready_now) mstimeout = 0; + if (chan.just_released) mstimeout = 0; + if (chan.ready_on_pollin) pfd.events |= POLLIN; + if (chan.ready_on_pollout) pfd.events |= POLLOUT; + if (chan.ready_on_outgoing && (chan.nbytes > 0)) pfd.events |= POLLOUT; + // std::cerr << "evt=" << pfd.events << ".nb=" << chan.nbytes << " "; + } + fill_stdio_pollfd(pollvec, mstimeout, read_console_recently_); + + // Do the poll. + int status = poll(&pollvec[0], pollvec.size(), mstimeout); + assert(status >= 0); + + // Check listening sockets. + index = 0; + for (auto &p : listen_sockets_) { + struct pollfd &pfd = pollvec[index++]; + if (pfd.revents & (POLLIN | POLLERR)) { + accept_connection(p.first, p.second); + } + } + + // Advance channels where possible. + for (ChanInfo &chan : chans_) { + struct pollfd &pfd = pollvec[index++]; + bool pollin = ((pfd.revents & POLLIN) != 0); + bool pollout = ((pfd.revents & POLLOUT) != 0); + bool pollerr = ((pfd.revents & POLLERR) != 0); + if (chan.ready_now || pollerr || chan.just_released || + (chan.ready_on_pollin && pollin) || + (chan.ready_on_pollout && pollout) || + (chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) { + chan.ready_now = false; + chan.ready_on_pollin = false; + chan.ready_on_pollout = false; + chan.ready_on_outgoing = false; + advance_channel(chan); + } + chan.nbytes = 0; + chan.bytes = 0; + } + + // Delete any newly-inactive channels + cleanup_channels(); + } + + void drive(DrivenEngine *de, int argc, char *argv[]) { + SSL_load_error_strings(); + ERR_load_crypto_strings(); + enable_tty_raw(); + driven_ = de; + read_console_recently_ = false; + chbuf.reset(new char[65536]); + + ssl_ctx_with_root_certs_ = new_ssl_context(false, true, ""); + ssl_ctx_with_server_certs_ = new_ssl_context(true, false, ""); + ssl_ctx_with_no_certs_ = new_ssl_context(false, false, ""); + + if (ssl_ctx_use_certificate_str(ssl_ctx_with_server_certs_, dummycert::certificate) <= 0) { + ERR_print_errors_fp(stderr); + exit(1); + } + + if (ssl_ctx_use_privatekey_str(ssl_ctx_with_server_certs_, dummycert::privatekey) <= 0 ) { + ERR_print_errors_fp(stderr); + exit(1); + } + + DrivenEngine::set(de); + driven_->drv_set_lua_source(util::read_lua_source("lua")); + driven_->drv_invoke_event_init(argc, argv); + handle_listen_ports(); + + while (!de->drv_get_stop_driver()) { + handle_lua_source(); + handle_console_output(); + handle_new_outgoing_sockets(); + handle_socket_input_output(); + handle_console_input(); + handle_console_output(); + de->drv_invoke_event_update(monoclock.get()); + } + + for (ChanInfo &chan : chans_) { + close_channel(chan, ""); + } + SSL_CTX_free(ssl_ctx_with_no_certs_); + SSL_CTX_free(ssl_ctx_with_root_certs_); + SSL_CTX_free(ssl_ctx_with_server_certs_); + DrivenEngine::set(nullptr); + } +}; + + +void driver_drive(DrivenEngine *de, int argc, char *argv[]) { + Driver driver; + driver.drive(de, argc, argv); +} + diff --git a/luprex/core/cpp/driver-linux.cpp b/luprex/core/cpp/driver-linux.cpp index cab66669..2828fb2c 100644 --- a/luprex/core/cpp/driver-linux.cpp +++ b/luprex/core/cpp/driver-linux.cpp @@ -25,6 +25,8 @@ #include #include #include +#include +#include using SOCKET=int; const int INVALID_SOCKET = -1; @@ -32,6 +34,11 @@ using PollVector = std::vector; struct termios orig_termios; +static std::string strerror_str(int err) { + char errbuf[256]; + return strerror_r(errno, errbuf, 256); +} + void set_nonblocking(int fd) { int flags = fcntl(fd, F_GETFL, 0); assert(flags != -1); @@ -39,14 +46,14 @@ void set_nonblocking(int fd) { assert(status != -1); } -static void disableRawMode() { +static void disable_tty_raw() { tcsetattr(0, TCSAFLUSH, &orig_termios); } -static void enableRawMode() { +static void enable_tty_raw() { int status = tcgetattr(0, &orig_termios); assert(status >= 0); - atexit(disableRawMode); + atexit(disable_tty_raw); struct termios raw = orig_termios; raw.c_iflag &= ~(BRKINT | ICRNL | INPCK | ISTRIP | IXON); raw.c_lflag &= ~(ECHO | ICANON); @@ -57,12 +64,7 @@ static void enableRawMode() { assert(status >= 0); } -static std::string strerror_str(int err) { - char errbuf[256]; - return strerror_r(errno, errbuf, 256); -} - -SOCKET open_connection(const std::string &target, std::string &err) { +static SOCKET open_connection(const std::string &target, std::string &err) { struct addrinfo *addrs = nullptr; struct addrinfo *goodaddr = nullptr; struct addrinfo hints; @@ -80,58 +82,66 @@ SOCKET open_connection(const std::string &target, std::string &err) { int status = getaddrinfo(host.c_str(), port.c_str(), &hints, &addrs); if (status != 0) { err = gai_strerror(status); - goto error; + goto error_general; } if (addrs == nullptr) { err = "no such host found"; - goto error; + goto error_general; } goodaddr = addrs; assert(goodaddr->ai_family == AF_INET); assert(goodaddr->ai_socktype == SOCK_STREAM); assert(goodaddr->ai_protocol == IPPROTO_TCP); sock = socket(goodaddr->ai_family, goodaddr->ai_socktype, goodaddr->ai_protocol); - assert(sock > 0); + if (sock <= 0) goto error_errno; + set_nonblocking(sock); status = connect(sock, goodaddr->ai_addr, goodaddr->ai_addrlen); - if ((status != 0) && (errno != EINPROGRESS)) { - goto error_errno; - } + if ((status != 0) && (errno != EINPROGRESS)) goto error_errno; + freeaddrinfo(addrs); return sock; error_errno: err = strerror_str(errno); -error: +error_general: if (sock != INVALID_SOCKET) close(sock); if (addrs != nullptr) freeaddrinfo(addrs); return INVALID_SOCKET; } -SOCKET listen_on_port(int port, std::string &err) { - int status; +static SOCKET listen_on_port(int port, std::string &err) { + int status, enable; err.clear(); + SOCKET sock = socket(AF_INET, SOCK_STREAM, 0); - assert(sock > 0); + if (sock <= 0) goto error_errno; - int enable = 1; + enable = 1; status = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)); - assert(status == 0); - + if (status != 0) goto error_errno; + struct sockaddr_in server; server.sin_family = AF_INET; server.sin_addr.s_addr = INADDR_ANY; server.sin_port = htons(port); status = bind(sock, (struct sockaddr *)&server, sizeof(server)); - assert(status == 0); + if (status != 0) goto error_errno; + status = listen(sock, 10); - assert(status == 0); + if (status != 0) goto error_errno; + set_nonblocking(sock); return sock; + +error_errno: + err = strerror_str(errno); + if (sock >= 0) close(sock); + return INVALID_SOCKET; } -SOCKET accept_on_socket(SOCKET listen_socket, std::string &err) { +static SOCKET accept_on_socket(SOCKET listen_socket, std::string &err) { err.clear(); SOCKET chsock = accept(listen_socket, nullptr, nullptr); if (chsock >= 0) { @@ -145,60 +155,66 @@ SOCKET accept_on_socket(SOCKET listen_socket, std::string &err) { } } -SSL_CTX *new_ssl_context(bool server_cert, bool root_certs, const std::string &require_cert) { - SSL_CTX *ctx = SSL_CTX_new(TLS_method()); - SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); - SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE); - SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, nullptr); - // server_cert is not implemented yet. - if (root_certs) { - SSL_CTX_set_default_verify_paths(ctx); - SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, NULL); +// the return values for socket_send and socket_recv are: +// +// positive: sent or received bytes successfully +// zero: would block +// negative: channel closed, possibly cleanly or possibly with error +// +static int socket_send(SOCKET socket, const char *bytes, int nbytes, std::string &err) { + err.clear(); + int wbytes = send(socket, bytes, nbytes, 0); + if (wbytes < 0) { + if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) { + return 0; + } else { + err = strerror_str(errno); + return -1; + } + } else { + return wbytes; } - // require_cert is not implemented yet. - return ctx; } -std::string err_print_errors_str() { - BIO *bio = BIO_new(BIO_s_mem()); - ERR_print_errors(bio); - char *buf; - size_t len = BIO_get_mem_data(bio, &buf); - std::string ret(buf, len); - BIO_free(bio); - return ret; -} -#include -#include -#include - -int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str) { - BIO *bio = BIO_new(BIO_s_mem()); - BIO_puts(bio, str); - X509 *certificate = PEM_read_bio_X509(bio, NULL, NULL, NULL); - BIO_free(bio); - int status = SSL_CTX_use_certificate(ctx, certificate); - X509_free(certificate); - return status; -} - -int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str) { - BIO *bio = BIO_new(BIO_s_mem()); - BIO_puts(bio, str); - EVP_PKEY *pkey = PEM_read_bio_PrivateKey(bio, NULL, NULL, NULL); - BIO_free(bio); - int status = SSL_CTX_use_PrivateKey(ctx, pkey); - EVP_PKEY_free(pkey); - return status; -} - -static void print_error_and_exit(const std::string &str) { - if (!str.empty()) { - std::cerr << "error: " << str << std::endl; - exit(1); +static int socket_recv(SOCKET socket, char *bytes, int nbytes, std::string &err) { + err.clear(); + int nrecv = recv(socket, bytes, nbytes, 0); + if (nrecv < 0) { + if ((errno == EWOULDBLOCK) || (errno == EAGAIN)) { + err = strerror_str(errno); + return -1; + } else { + return 0; + } + } else if (nrecv == 0) { + return -1; + } else { + return nrecv; } } +static int socket_close(SOCKET socket) { + return close(socket); +} + +static int console_write(const char *bytes, int nbytes) { + return write(1, bytes, nbytes); +} + +static int console_read(char *bytes, int nbytes) { + return read(0, bytes, nbytes); +} + +// The last element in the vector is supposed to be +// for polling stdio. But on windows, you can't poll +// stdio, so on windows, we remove the last element from +// the vector and we reduce mstimeout instead. +static void fill_stdio_pollfd(PollVector &pollvec, int &mstimeout, bool read_console_recently) { + struct pollfd &stdiopoll = pollvec.back(); + stdiopoll.fd = 0; + stdiopoll.events = POLLIN; +} + class MonoClock { private: struct timespec base_; @@ -217,442 +233,4 @@ public: } }; -static MonoClock monoclock; - -namespace util { - double profiling_clock() { - return monoclock.get(); - } -} - -class Driver { -public: - enum ChanState { - CHAN_INACTIVE, - CHAN_PLAINTEXT, - CHAN_SSL_CONNECTING, - CHAN_SSL_ACCEPTING, - CHAN_SSL_READWRITE, - }; - struct ChanInfo { - int chid; - SOCKET socket; - SSL *ssl; - - ChanState state; - int nbytes; - const char *bytes; - bool released; - bool just_released; - bool ready_now; - bool ready_on_pollin; - bool ready_on_pollout; - bool ready_on_outgoing; - int last_write_nbytes; - }; - - DrivenEngine *driven_; - std::vector chans_; - std::map listen_sockets_; - std::unique_ptr chbuf; - - SSL_CTX *ssl_ctx_with_root_certs_; - SSL_CTX *ssl_ctx_with_server_certs_; - SSL_CTX *ssl_ctx_with_no_certs_; - - void handle_listen_ports() { - std::set listenports; - driven_->drv_get_listen_ports(listenports); - for (int port : listenports) { - if (listen_sockets_.find(port) == listen_sockets_.end()) { - std::string err; - SOCKET sock = listen_on_port(port, err); - print_error_and_exit(err); - assert(sock != INVALID_SOCKET); - listen_sockets_[port] = sock; - } - } - } - - void handle_lua_source() { - if (driven_->drv_get_rescan_lua_source()) { - driven_->drv_set_lua_source(util::read_lua_source("lua")); - } - } - - void close_channel(ChanInfo &chan, const std::string &err) { - // std::cerr << "Closing channel " << chan.chid << std::endl; - assert(chan.state != CHAN_INACTIVE); - // Close the SSL channel. - if (chan.ssl != nullptr) { - SSL_free(chan.ssl); - chan.ssl = nullptr; - } - // Close the socket. - assert(chan.socket != INVALID_SOCKET); - assert(close(chan.socket) == 0); - chan.socket = INVALID_SOCKET; - // Close everything else. - driven_->drv_notify_close(chan.chid, err); - chan.state = CHAN_INACTIVE; - chan.chid = -1; - chan.nbytes = 0; - chan.bytes = 0; - chan.released = false; - chan.just_released = false; - chan.ready_now = false; - chan.ready_on_pollin = false; - chan.ready_on_pollout = false; - chan.ready_on_outgoing = false; - chan.last_write_nbytes = 0; - } - - void cleanup_channels() { - for (int i = 0; i < int(chans_.size()); ) { - if (chans_[i].state == CHAN_INACTIVE) { - chans_[i] = chans_.back(); - chans_.pop_back(); - } else { - i += 1; - } - } - } - - void handle_console_output() { - while (true) { - int nbytes; const char *bytes; - driven_->drv_peek_outgoing(0, &nbytes, &bytes); - if (nbytes == 0) break; - int nwrote = write(1, bytes, nbytes); - assert(nwrote > 0); - driven_->drv_sent_outgoing(0, nwrote); - } - } - - void handle_console_input() { - char buffer[256]; - while (true) { - int nread = read(0, buffer, 256); - if (nread == 0) break; - assert(nread > 0); - driven_->drv_recv_incoming(0, nread, buffer); - } - } - - void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) { - ChanInfo newchan; - newchan.chid = chid; - newchan.socket = sock; - newchan.ssl = SSL_new(ctx); - newchan.state = state; - newchan.nbytes = 0; - newchan.bytes = 0; - newchan.released = false; - newchan.just_released = false; - newchan.ready_now = false; - newchan.ready_on_pollin = false; - newchan.ready_on_pollout = true; - newchan.ready_on_outgoing = false; - newchan.last_write_nbytes = 0; - SSL_set_fd(newchan.ssl, newchan.socket); - // SSL_set_msg_callback(newchan.ssl, SSL_trace); - // SSL_set_msg_callback_arg(newchan.ssl, BIO_new_fp(stderr,0)); - chans_.push_back(newchan); - } - - void handle_new_outgoing_sockets() { - std::set chans; - driven_->drv_get_new_outgoing(chans); - for (int chid : chans) { - std::string err; - SOCKET sock = open_connection(driven_->drv_get_target(chid), err); - if (sock == INVALID_SOCKET) { - driven_->drv_notify_close(chid, err); - } else { - //std::cerr << "Opening channel " << chid << std::endl; - make_channel(sock, chid, ssl_ctx_with_no_certs_, CHAN_SSL_CONNECTING); - } - } - } - - void accept_connections(int port, SOCKET sock) { - std::string err; - SOCKET socket = accept_on_socket(sock, err); - print_error_and_exit(err); - if (socket != INVALID_SOCKET) { - int chid = driven_->drv_notify_accept(port); - // std::cerr << "Accepted channel " << chid << std::endl; - make_channel(socket, chid, ssl_ctx_with_server_certs_, CHAN_SSL_ACCEPTING); - } - } - - void advance_plaintext(ChanInfo &chan) { - // If the channel has no outgoing bytes and has been released, - // just close it. - if (chan.released) { - close_channel(chan, ""); - return; - } - - // Try to write plaintext to the channel. - int nbytes; const char *bytes; - driven_->drv_peek_outgoing(chan.chid, &nbytes, &bytes); - if (nbytes > 0) { - int sbytes = nbytes; - if (sbytes > 65536) sbytes = 65536; - int wbytes = send(chan.socket, bytes, sbytes, 0); - // std::cerr << "send.bytes="<< wbytes << ".errno=" << errno << " "; - if (wbytes < 0) { - if ((errno != EWOULDBLOCK) && (errno != EAGAIN)) { - close_channel(chan, "send failure"); - return; - } - } else { - driven_->drv_sent_outgoing(chan.chid, wbytes); - } - } - - // Try to read plaintext from the channel. - // Someday, find a way to avoid this copy. - int nrecv = recv(chan.socket, chbuf.get(), 65536, 0); - // std::cerr << "recv.bytes="<< nrecv << ".errno=" << errno << " "; - if (nrecv < 0) { - if ((errno != EWOULDBLOCK) && (errno != EAGAIN)) { - close_channel(chan, "recv failure"); - return; - } - } else if (nrecv == 0) { - close_channel(chan, ""); - return; - } else { - driven_->drv_recv_incoming(chan.chid, nrecv, chbuf.get()); - } - - // Update the ready-flags for next time. - chan.ready_on_outgoing = true; - chan.ready_on_pollin = true; - } - - void process_ssl_error(ChanInfo &chan, int retval) { - int error = SSL_get_error(chan.ssl, retval); - // std::cerr << "SSL error code = " << error << " "; - if (error == SSL_ERROR_WANT_READ) { - chan.ready_on_pollin = true; - } else if (error == SSL_ERROR_WANT_WRITE) { - chan.ready_on_pollout = true; - } else { - close_channel(chan, err_print_errors_str()); - } - } - - void advance_ssl_connecting(ChanInfo &chan) { - // std::cerr << "In advance_ssl_connecting" << std::endl; - int retval = SSL_connect(chan.ssl); - if (retval == 1) { - // Connection successful. - chan.state = CHAN_SSL_READWRITE; - chan.ready_now = true; - } else { - // std::cerr << "ssl_connect_error"; - process_ssl_error(chan, retval); - } - } - - void advance_ssl_accepting(ChanInfo &chan) { - // std::cerr << "In advance_ssl_accepting" << std::endl; - int retval = SSL_accept(chan.ssl); - if (retval == 1) { - // Connection successful. - chan.state = CHAN_SSL_READWRITE; - chan.ready_now = true; - } else { - process_ssl_error(chan, retval); - } - } - - void advance_ssl_readwrite(ChanInfo &chan) { - // std::cerr << "In advance_ssl_readwrite" << std::endl; - // Try to read data. - int read_result = SSL_read(chan.ssl, chbuf.get(), 65536); - if (read_result > 0) { - driven_->drv_recv_incoming(chan.chid, read_result, chbuf.get()); - chan.ready_now = true; - } else { - process_ssl_error(chan, read_result); - if (chan.state == CHAN_INACTIVE) return; - } - - // Try to write data. - int wbytes; - if (chan.last_write_nbytes > 0) { - wbytes = chan.last_write_nbytes; - assert(wbytes < chan.nbytes); - } else { - wbytes = chan.nbytes; - if (wbytes > 65536) wbytes = 65536; - } - if (wbytes > 0) { - int write_result = SSL_write(chan.ssl, chan.bytes, wbytes); - if (write_result > 0) { - driven_->drv_sent_outgoing(chan.chid, write_result); - chan.last_write_nbytes = 0; - chan.ready_on_outgoing = true; - } else { - chan.last_write_nbytes = wbytes; - process_ssl_error(chan, write_result); - if (chan.state == CHAN_INACTIVE) return; - } - } else { - chan.ready_on_outgoing = true; - } - - // std::cerr << "rpi=" << chan.ready_on_pollin << ".rpo=" << chan.ready_on_pollout << ".rn=" << chan.ready_now << ".rog=" << chan.ready_on_outgoing << " "; - } - - void advance_channel(ChanInfo &chan) { - switch(chan.state) { - case CHAN_PLAINTEXT: - advance_plaintext(chan); - break; - case CHAN_SSL_CONNECTING: - advance_ssl_connecting(chan); - break; - case CHAN_SSL_ACCEPTING: - advance_ssl_accepting(chan); - break; - case CHAN_SSL_READWRITE: - advance_ssl_readwrite(chan); - break; - default: - assert(false); - break; - } - } - - - void handle_socket_input_output() { - int mstimeout = 1000; - - // Peek output buffers and determine channel release flags. - for (ChanInfo &chan : chans_) { - driven_->drv_peek_outgoing(chan.chid, &chan.nbytes, &chan.bytes); - chan.just_released = false; - if ((chan.nbytes == 0)&&(!chan.released)) { - chan.released = driven_->drv_get_channel_released(chan.chid); - chan.just_released = chan.released; - } - } - - // Construct the pollfd vector. - std::vector pollvec; - pollvec.resize(listen_sockets_.size() + chans_.size() + 1); - int index = 0; - for (const auto &p : listen_sockets_) { - struct pollfd &pfd = pollvec[index++]; - pfd.fd = p.second; - pfd.events = POLLIN; - } - for (const ChanInfo &chan : chans_) { - struct pollfd &pfd = pollvec[index++]; - assert(chan.socket != INVALID_SOCKET); - pfd.fd = chan.socket; - pfd.events = POLLERR; - if (chan.ready_now) mstimeout = 0; - if (chan.just_released) mstimeout = 0; - if (chan.ready_on_pollin) pfd.events |= POLLIN; - if (chan.ready_on_pollout) pfd.events |= POLLOUT; - if (chan.ready_on_outgoing && (chan.nbytes > 0)) pfd.events |= POLLOUT; - // std::cerr << "evt=" << pfd.events << ".nb=" << chan.nbytes << " "; - } - struct pollfd &stdiopoll = pollvec[index++]; - stdiopoll.fd = 0; - stdiopoll.events = POLLIN; - - // Do the poll. - int status = poll(&pollvec[0], pollvec.size(), mstimeout); - assert(status >= 0); - - // Check listening sockets. - index = 0; - for (auto &p : listen_sockets_) { - struct pollfd &pfd = pollvec[index++]; - if (pfd.revents & (POLLIN | POLLERR)) { - accept_connections(p.first, p.second); - } - } - - // Advance channels where possible. - for (ChanInfo &chan : chans_) { - struct pollfd &pfd = pollvec[index++]; - bool pollin = ((pfd.revents & POLLIN) != 0); - bool pollout = ((pfd.revents & POLLOUT) != 0); - bool pollerr = ((pfd.revents & POLLERR) != 0); - if (chan.ready_now || pollerr || chan.just_released || - (chan.ready_on_pollin && pollin) || - (chan.ready_on_pollout && pollout) || - (chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) { - chan.ready_now = false; - chan.ready_on_pollin = false; - chan.ready_on_pollout = false; - chan.ready_on_outgoing = false; - advance_channel(chan); - chan.nbytes = false; - chan.bytes = 0; - } - } - - // Delete any newly-inactive channels - cleanup_channels(); - } - - void drive(DrivenEngine *de, int argc, char *argv[]) { - SSL_load_error_strings(); - ERR_load_crypto_strings(); - enableRawMode(); - driven_ = de; - chbuf.reset(new char[65536]); - ssl_ctx_with_root_certs_ = new_ssl_context(false, true, ""); - ssl_ctx_with_server_certs_ = new_ssl_context(true, false, ""); - ssl_ctx_with_no_certs_ = new_ssl_context(false, false, ""); - - if (ssl_ctx_use_certificate_str(ssl_ctx_with_server_certs_, dummycert::certificate) <= 0) { - ERR_print_errors_fp(stderr); - exit(1); - } - - if (ssl_ctx_use_privatekey_str(ssl_ctx_with_server_certs_, dummycert::privatekey) <= 0 ) { - ERR_print_errors_fp(stderr); - exit(1); - } - - DrivenEngine::set(de); - driven_->drv_set_lua_source(util::read_lua_source("lua")); - driven_->drv_invoke_event_init(argc, argv); - handle_listen_ports(); - - while (!de->drv_get_stop_driver()) { - handle_lua_source(); - handle_console_output(); - handle_new_outgoing_sockets(); - handle_socket_input_output(); - handle_console_input(); - handle_console_output(); - de->drv_invoke_event_update(monoclock.get()); - } - - for (ChanInfo &chan : chans_) { - close_channel(chan, ""); - } - SSL_CTX_free(ssl_ctx_with_no_certs_); - SSL_CTX_free(ssl_ctx_with_root_certs_); - SSL_CTX_free(ssl_ctx_with_server_certs_); - DrivenEngine::set(nullptr); - } -}; - - -void driver_drive(DrivenEngine *de, int argc, char *argv[]) { - Driver driver; - driver.drive(de, argc, argv); -} - +#include "driver-common.cpp"