diff --git a/luprex/TODO b/luprex/TODO index d8f0f169..d9181688 100644 --- a/luprex/TODO +++ b/luprex/TODO @@ -8,4 +8,4 @@ Do something about std::cerr && std::cout once and for all. Fix math.random (?) - +Do a better job handling 'close' in the driver (need some equivalent of SSL_shutdown) \ No newline at end of file diff --git a/luprex/cpp/core/drivenengine.cpp b/luprex/cpp/core/drivenengine.cpp index 623061b6..10d57634 100644 --- a/luprex/cpp/core/drivenengine.cpp +++ b/luprex/cpp/core/drivenengine.cpp @@ -505,12 +505,13 @@ void DrivenEngine::drv_sent_outgoing(uint32_t chid, uint32_t nbytes) { } void DrivenEngine::drv_recv_incoming(uint32_t chid, uint32_t nbytes, const char *bytes) { + std::string_view sbytes(bytes, nbytes); if (nbytes > 0) { Channel *ch = get_chid(chid); if (ch->sb_drvout_ != ch->sb_out_) { - ch->feed_readline(bytes); + ch->feed_readline(sbytes); } else { - ch->sb_in_->write_bytes(bytes); + ch->sb_in_->write_bytes(sbytes); } } } diff --git a/luprex/cpp/core/drivenengine.hpp b/luprex/cpp/core/drivenengine.hpp index a24ec9c6..42f2cc8b 100644 --- a/luprex/cpp/core/drivenengine.hpp +++ b/luprex/cpp/core/drivenengine.hpp @@ -77,7 +77,7 @@ public: // const eng::string &target() const { return target_; } - // True if the remote closed the connection, or a failure occurred. + // True if the remote has closed the connection. // bool closed() const { return closed_; } diff --git a/luprex/cpp/drv/driver-common.cpp b/luprex/cpp/drv/driver-common.cpp index 669ff5cb..2228ca9f 100644 --- a/luprex/cpp/drv/driver-common.cpp +++ b/luprex/cpp/drv/driver-common.cpp @@ -1,4 +1,5 @@ #define POLLVEC_SIZE (DRV_MAX_CHAN + 1) +#define MAX_BIO_BUFFER (128 * 1024) static void if_error_print_and_exit(const std::string_view str) { @@ -21,15 +22,26 @@ class Driver { int chid; SOCKET socket; SSL *ssl; + BIO *recv_bio; + BIO *send_bio; + + // If recent_error is set, that means that a recent IO operation generated + // an error. As a special case, EOF on read is considered an error, we use + // the string "EOF" for this case. + std::string recent_error; + + // OpenSSL has a rule: if you try to SSL_write and it returns + // SSL_ERROR_WANT_READ, then you have to retry the write with the same + // number of bytes. In this event, we record how many bytes we + // attempted to write, which will enable us to retry. + int retry_write_nbytes; + + // True if the channel needs to be advanced. + bool need_advance; ChanState state; uint32_t nbytes; const char *bytes; - bool ready_now; - bool ready_on_pollin; - bool ready_on_pollout; - bool ready_on_outgoing; - uint32_t last_write_nbytes; bool marked_for_deletion() const { return state == CHAN_INACTIVE; } }; @@ -45,6 +57,82 @@ class Driver { sslutil::UniqueCTX ssl_client_secure_ctx_; sslutil::UniqueCTX ssl_client_insecure_ctx_; + // Return the amount of 'space left' in a BIO. This is a fiction, + // because MEM BIOs technically have unlimited capacity. We're + // artificially limiting them to a certain size because there's no + // reason to buffer huge amounts of data. + // + int bio_space(BIO *bio) { + int space = (MAX_BIO_BUFFER) - BIO_pending(bio); + if (space < 0) space = 0; + return space; + } + + // This is a terribly inefficient way to discard data that has + // already been processed. There has to be something better. + // + void bio_discard(BIO *b, int nbytes) { + while (nbytes > 0) { + int nread = nbytes; + if (nread > DRV_SHORTSTRING_SIZE) nread = DRV_SHORTSTRING_SIZE; + int ndropped = BIO_read(b, chbuf_.get(), nread); + assert(ndropped == nread); + nbytes -= ndropped; + } + } + + void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) { + ChanInfo newchan; + newchan.chid = chid; + newchan.socket = sock; + newchan.recv_bio = BIO_new(BIO_s_mem()); + newchan.send_bio = BIO_new(BIO_s_mem()); + newchan.recent_error.clear(); + newchan.retry_write_nbytes = 0; + newchan.need_advance = true; + + if (state == CHAN_PLAINTEXT) { + newchan.ssl = nullptr; + } else { + newchan.ssl = SSL_new(ctx); + SSL_set_bio(newchan.ssl, newchan.recv_bio, newchan.send_bio); + } + + newchan.state = state; + newchan.nbytes = 0; + newchan.bytes = 0; + chans_.push_back(newchan); + } + + void close_channel(ChanInfo &chan, std::string_view err) { + // std::cerr << "Closing channel " << chan.chid << " with " << err << std::endl; + assert(chan.state != CHAN_INACTIVE); + + // Close and release the SSL channel. + // This frees the BIO objects as well. + if (chan.ssl != nullptr) { + SSL_free(chan.ssl); + chan.ssl = nullptr; + } + chan.recv_bio = nullptr; + chan.send_bio = nullptr; + chan.recent_error.clear(); + chan.retry_write_nbytes = 0; + chan.need_advance = false; + + // Close and release the socket. + assert(chan.socket != INVALID_SOCKET); + assert(socket_close(chan.socket) == 0); + chan.socket = INVALID_SOCKET; + + // Close everything else. + engw.play_notify_close(&engw, chan.chid, err.size(), err.data()); + chan.state = CHAN_INACTIVE; + chan.chid = -1; + chan.nbytes = 0; + chan.bytes = 0; + } + void handle_listen_ports() { uint32_t nports; const uint32_t *ports; engw.get_listen_ports(&engw, &nports, &ports); @@ -69,30 +157,6 @@ class Driver { } } - void close_channel(ChanInfo &chan, std::string_view err) { - // std::cerr << "Closing channel " << chan.chid << std::endl; - assert(chan.state != CHAN_INACTIVE); - // Close and release the SSL channel. - if (chan.ssl != nullptr) { - SSL_free(chan.ssl); - chan.ssl = nullptr; - } - // Close and release the socket. - assert(chan.socket != INVALID_SOCKET); - assert(socket_close(chan.socket) == 0); - chan.socket = INVALID_SOCKET; - // Close everything else. - engw.play_notify_close(&engw, chan.chid, err.size(), err.data()); - chan.state = CHAN_INACTIVE; - chan.chid = -1; - chan.nbytes = 0; - chan.bytes = 0; - chan.ready_now = false; - chan.ready_on_pollin = false; - chan.ready_on_pollout = false; - chan.ready_on_outgoing = false; - chan.last_write_nbytes = 0; - } void handle_console_output() { while (true) { @@ -117,25 +181,6 @@ class Driver { } } - void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) { - ChanInfo newchan; - newchan.chid = chid; - newchan.socket = sock; - newchan.ssl = SSL_new(ctx); - newchan.state = state; - newchan.nbytes = 0; - newchan.bytes = 0; - newchan.ready_now = false; - newchan.ready_on_pollin = false; - newchan.ready_on_pollout = true; - newchan.ready_on_outgoing = false; - newchan.last_write_nbytes = 0; - SSL_set_fd(newchan.ssl, newchan.socket); - // SSL_set_msg_callback(newchan.ssl, SSL_trace); - // SSL_set_msg_callback_arg(newchan.ssl, BIO_new_fp(stderr,0)); - chans_.push_back(newchan); - } - void handle_new_outgoing_sockets() { uint32_t nchids; const uint32_t *chids; engw.get_new_outgoing(&engw, &nchids, &chids); @@ -166,7 +211,6 @@ class Driver { engw.play_notify_close(&engw, chid, err.size(), err.c_str()); continue; } - // std::cerr << "Opening channel " << chid << std::endl; make_channel(sock, chid, ctx, CHAN_SSL_CONNECTING); } engw.play_clear_new_outgoing(&engw); @@ -178,123 +222,188 @@ class Driver { if_error_print_and_exit(err); if (socket != INVALID_SOCKET) { uint32_t chid = engw.play_notify_accept(&engw, port); - // std::cerr << "Accepted channel " << chid << std::endl; make_channel(socket, chid, ssl_server_ctx_.get(), CHAN_SSL_ACCEPTING); } } - void advance_plaintext(ChanInfo &chan) { - std::string err; + // Copy data from the socket into the recv bio. + // + // If it detects an error or EOF, sets the recent_errno flag. + // + void transfer_socket_to_recv_bio(ChanInfo &chan) { + if ((chan.state == CHAN_INACTIVE) || (!chan.recent_error.empty())) { + return; + } - // Try to write plaintext to the channel. - uint32_t ndata; const char *data; - engw.get_outgoing(&engw, chan.chid, &ndata, &data); - if (ndata > 0) { - int sbytes = ndata; - if (sbytes > DRV_SHORTSTRING_SIZE) sbytes = DRV_SHORTSTRING_SIZE; - int wbytes = socket_send(chan.socket, data, sbytes, err); - if (wbytes < 0) { - close_channel(chan, err.c_str()); + std::string err; + int nread = socket_recv(chan.socket, chbuf_.get(), DRV_SHORTSTRING_SIZE, err); + // std::cerr << "chan " << chan.chid << " recv " << nread << " err=" << err << std::endl; + if (nread < 0) { + chan.recent_error = err; + } else { + if (nread == 0) { + chan.recent_error = "EOF"; } else { - engw.play_sent_outgoing(&engw, chan.chid, wbytes); + int nstored = BIO_write(chan.recv_bio, chbuf_.get(), nread); + assert(nstored == nread); + chan.need_advance = true; + // std::cerr << "chan " << chan.chid << " stored " << nread << " bytes" << std::endl; } } - - // Try to read plaintext from the channel. - // Someday, find a way to avoid this copy. - int nrecv = socket_recv(chan.socket, chbuf_.get(), DRV_SHORTSTRING_SIZE, err); - if (nrecv < 0) { - close_channel(chan, err.c_str()); - } else { - engw.play_recv_incoming(&engw, chan.chid, nrecv, chbuf_.get()); - } - - // Update the ready-flags for next time. - chan.ready_on_outgoing = true; - chan.ready_on_pollin = true; } - void process_ssl_error(ChanInfo &chan, int retval) { - int error = SSL_get_error(chan.ssl, retval); - // std::cerr << "SSL error code = " << error << " "; - if (error == SSL_ERROR_WANT_READ) { - chan.ready_on_pollin = true; - } else if (error == SSL_ERROR_WANT_WRITE) { - chan.ready_on_pollout = true; + // Copy data from the send BIO into the socket. + // + // If it detects an error, sets the recent_errno flag. + // + void transfer_send_bio_to_socket(ChanInfo &chan) { + if ((chan.state == CHAN_INACTIVE) || (!chan.recent_error.empty())) { + return; + } + + char *data; + int ndata = BIO_get_mem_data(chan.send_bio, &data); + if (ndata > DRV_SHORTSTRING_SIZE) ndata = DRV_SHORTSTRING_SIZE; + std::string err; + int nwrote = socket_send(chan.socket, data, ndata, err); + // std::cerr << "chan " << chan.chid << " send " << nwrote << " err=" << err << std::endl; + if (nwrote < 0) { + chan.recent_error = err; } else { - std::string error = sslutil::error_string(); - if (error == "") error = "unknown error"; - close_channel(chan, error); + assert(nwrote != 0); + bio_discard(chan.send_bio, nwrote); + chan.need_advance = true; + } + } + + // Close the channel if there's a serious OpenSSL error. + // + // The 'retval' is the return value of the SSL function that returned an + // error. + // + // All errors are considered serious except for SSL_ERROR_WANT_READ, which + // is not serious because it is transient. However, if you get an + // SSL_ERROR_WANT_READ when there's tons of data available in the read + // buffer, that's inexplicable and therefore serious. + // + void if_error_is_serious_close_channel(ChanInfo &chan, int retval) { + int error = SSL_get_error(chan.ssl, retval); + //std::cerr << "chan " << chan.chid << " ssl error = " << error << std::endl; + + // Should never have write errors, because we're + // using a memory BIO with unlimited capacity. + assert(error != SSL_ERROR_WANT_WRITE); + + // If we get a read error, make sure it's plausible: + // if the recv bio is full, that makes no sense. + if (error == SSL_ERROR_WANT_READ) { + if (bio_space(chan.recv_bio) == 0) { + close_channel(chan, "ssl waiting for data, but there's tons of data"); + } + return; + } + + // Any other error is an actual error. Close + // the channel. + std::string errstr = sslutil::error_string(); + if (errstr == "") errstr = "unknown error"; + close_channel(chan, errstr); + } + + void advance_plaintext(ChanInfo &chan) { + uint32_t ndata; const char *data; + + // Transfer all data from the recv BIO into the channel. + ndata = BIO_get_mem_data(chan.recv_bio, &data); + if (ndata > 0) { + engw.play_recv_incoming(&engw, chan.chid, ndata, data); + bio_discard(chan.recv_bio, ndata); + } + + // Transfer all data from the channel to the send BIO. + engw.get_outgoing(&engw, chan.chid, &ndata, &data); + if (ndata > 0) { + int nwrote = BIO_write(chan.send_bio, data, ndata); + assert(nwrote == int(ndata)); + engw.play_sent_outgoing(&engw, chan.chid, ndata); } } void advance_ssl_connecting(ChanInfo &chan) { - // std::cerr << "In advance_ssl_connecting" << std::endl; int retval = SSL_connect(chan.ssl); + //std::cerr << "chan " << chan.chid << " ssl_connect returns " << retval << std::endl; if (retval == 1) { - // Connection successful. chan.state = CHAN_SSL_READWRITE; - chan.ready_now = true; + chan.need_advance = true; } else { - // std::cerr << "ssl_connect_error"; - process_ssl_error(chan, retval); + if_error_is_serious_close_channel(chan, retval); } } void advance_ssl_accepting(ChanInfo &chan) { - // std::cerr << "In advance_ssl_accepting" << std::endl; int retval = SSL_accept(chan.ssl); + //std::cerr << "chan " << chan.chid << " ssl_accept returns " << retval << std::endl; if (retval == 1) { - // Connection successful. chan.state = CHAN_SSL_READWRITE; - chan.ready_now = true; + chan.need_advance = true; } else { - process_ssl_error(chan, retval); + if_error_is_serious_close_channel(chan, retval); } } void advance_ssl_readwrite(ChanInfo &chan) { - // std::cerr << "In advance_ssl_readwrite" << std::endl; - // Try to read data. - int read_result = SSL_read(chan.ssl, chbuf_.get(), DRV_SHORTSTRING_SIZE); - if (read_result > 0) { - engw.play_recv_incoming(&engw, chan.chid, read_result, chbuf_.get()); - chan.ready_now = true; - } else { - process_ssl_error(chan, read_result); - if (chan.state == CHAN_INACTIVE) return; + // Read as much as we can, which of course will be limited + // by the fact that the recv_bio contains finite data. + while (true) { + int read_result = SSL_read(chan.ssl, chbuf_.get(), DRV_SHORTSTRING_SIZE); + if (read_result > 0) { + engw.play_recv_incoming(&engw, chan.chid, read_result, chbuf_.get()); + } else { + if_error_is_serious_close_channel(chan, read_result); + break; + } + } + + // The read process could have generated an error which could + // have closed the channel. If so, don't try writing. + if (chan.state == CHAN_INACTIVE) { + return; } // Try to write data. - uint32_t wbytes; - if (chan.last_write_nbytes > 0) { - wbytes = chan.last_write_nbytes; - assert(wbytes < chan.nbytes); - } else { - wbytes = chan.nbytes; - if (wbytes > 65536) wbytes = 65536; - } - if (wbytes > 0) { + while (chan.nbytes) { + uint32_t wbytes; + if (chan.retry_write_nbytes > 0) { + wbytes = chan.retry_write_nbytes; + assert(wbytes < chan.nbytes); + } else { + wbytes = chan.nbytes; + if (wbytes > DRV_SHORTSTRING_SIZE) wbytes = DRV_SHORTSTRING_SIZE; + } + if (wbytes == 0) break; int write_result = SSL_write(chan.ssl, chan.bytes, wbytes); if (write_result > 0) { engw.play_sent_outgoing(&engw, chan.chid, write_result); - chan.last_write_nbytes = 0; - chan.ready_on_outgoing = true; + chan.retry_write_nbytes = 0; + chan.nbytes -= write_result; + chan.bytes += write_result; } else { - chan.last_write_nbytes = wbytes; - process_ssl_error(chan, write_result); - if (chan.state == CHAN_INACTIVE) return; + if_error_is_serious_close_channel(chan, write_result); + chan.retry_write_nbytes = wbytes; + break; } - } else { - chan.ready_on_outgoing = true; } - // std::cerr << "rpi=" << chan.ready_on_pollin << ".rpo=" << - // chan.ready_on_pollout << ".rn=" << chan.ready_now << ".rog=" << - // chan.ready_on_outgoing << " "; } void advance_channel(ChanInfo &chan) { sslutil::clear_all_errors(); + + // We set the need_advance flag to false here, but + // the rest of the advance routine is allowed to set + // it back to true in the event that the advance routine + // only processes some of the available data. + chan.need_advance = false; + switch (chan.state) { case CHAN_PLAINTEXT: advance_plaintext(chan); @@ -349,13 +458,17 @@ class Driver { pfd.fd = chan.socket; pfd.events = 0; pfd.revents = 0; - if (chan.ready_now) mstimeout = 0; - if (chan.ready_on_pollin) pfd.events |= POLLIN; - if (chan.ready_on_pollout) pfd.events |= POLLOUT; - if (chan.ready_on_outgoing && (chan.nbytes > 0)) + // If there's room in the receive buffer, set POLLIN + if (bio_space(chan.recv_bio) > 0) { + pfd.events |= POLLIN; + } + // If there's data in the outgoing buffer, set POLLOUT + if (BIO_pending(chan.send_bio) > 0) { pfd.events |= POLLOUT; - // std::cerr << "evt=" << pfd.events << ".nb=" << chan.nbytes << - // std::endl; + } + if (chan.need_advance) { + mstimeout = 0; + } } // Do the poll. @@ -370,23 +483,26 @@ class Driver { accept_connection(p.first, p.second); } } - // Advance channels where possible. for (ChanInfo &chan : chans_) { struct pollfd &pfd = pollvec_[index++]; - bool pollin = ((pfd.revents & POLLIN) != 0); - bool pollout = ((pfd.revents & POLLOUT) != 0); - bool pollerr = ((pfd.revents & (POLLERR | POLLHUP)) != 0); - if (chan.ready_now || pollerr || - (chan.ready_on_pollin && pollin) || - (chan.ready_on_pollout && pollout) || - (chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) { - chan.ready_now = false; - chan.ready_on_pollin = false; - chan.ready_on_pollout = false; - chan.ready_on_outgoing = false; + if ((pfd.revents & POLLIN) != 0) { + transfer_socket_to_recv_bio(chan); + } + if ((pfd.revents & POLLOUT) != 0) { + transfer_send_bio_to_socket(chan); + } + if (chan.need_advance || (!chan.recent_error.empty())) { advance_channel(chan); } + if (!chan.recent_error.empty()) { + if (chan.recent_error == "EOF") { + close_channel(chan, ""); + } else { + close_channel(chan, chan.recent_error); + } + chan.recent_error.clear(); + } chan.nbytes = 0; chan.bytes = 0; } @@ -486,10 +602,10 @@ class Driver { } // Cleanup - engw.release(&engw); for (ChanInfo &chan : chans_) { close_channel(chan, ""); } + engw.release(&engw); return 0; } }; diff --git a/luprex/cpp/drv/driver-linux.cpp b/luprex/cpp/drv/driver-linux.cpp index 55663724..589e3862 100644 --- a/luprex/cpp/drv/driver-linux.cpp +++ b/luprex/cpp/drv/driver-linux.cpp @@ -149,40 +149,42 @@ static SOCKET accept_on_socket(SOCKET listen_socket, std::string &err) { } } -// the return values for socket_send and socket_recv are: +// the return values for socket_send: // -// positive: sent or received bytes successfully -// zero: would block -// negative: channel closed, possibly cleanly or possibly with error +// positive: sent bytes successfully +// negative: error. +// If the error message is empty, then it's "would block" +// Any other error generates an error message. // static int socket_send(SOCKET socket, const char *bytes, int nbytes, std::string &err) { - err.clear(); int wbytes = send(socket, bytes, nbytes, 0); if (wbytes < 0) { if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) { - return 0; + err.clear(); } else { err = drvutil::strerror_str(errno); - return -1; } + return -1; } else { + err.clear(); return wbytes; } } static int socket_recv(SOCKET socket, char *bytes, int nbytes, std::string &err) { - err.clear(); int nrecv = recv(socket, bytes, nbytes, 0); if (nrecv < 0) { - if ((errno == EWOULDBLOCK) || (errno == EAGAIN)) { - err = drvutil::strerror_str(errno); - return -1; + if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) { + err.clear(); } else { - return 0; + err = drvutil::strerror_str(errno); } - } else if (nrecv == 0) { return -1; + } else if (nrecv == 0) { + err.clear(); + return 0; } else { + err.clear(); return nrecv; } } diff --git a/luprex/cpp/drv/driver-mingw.cpp b/luprex/cpp/drv/driver-mingw.cpp index a1e4318e..706a404c 100644 --- a/luprex/cpp/drv/driver-mingw.cpp +++ b/luprex/cpp/drv/driver-mingw.cpp @@ -152,37 +152,46 @@ static SOCKET accept_on_socket(SOCKET listen_socket, std::string &err) { } } +// the return values for socket_send: +// +// positive: sent bytes successfully +// negative: error. +// If the error message is empty, then it's "would block" +// Any other error generates an error message. +// + static int socket_send(SOCKET socket, const char *bytes, int nbytes, std::string &err) { - err.clear(); int wbytes = send(socket, bytes, nbytes, 0); if (wbytes == SOCKET_ERROR) { int errcode = WSAGetLastError(); if (errcode == WSAEWOULDBLOCK) { - return 0; + err.clear(); } else { err = "send failure"; - return -1; } + return -1; } else { assert(wbytes > 0); + err.clear(); return wbytes; } } static int socket_recv(SOCKET socket, char *bytes, int nbytes, std::string &err) { - err.clear(); int nrecv = recv(socket, bytes, nbytes, 0); if (nrecv < 0) { int errcode = WSAGetLastError(); if (errcode == WSAEWOULDBLOCK) { - return 0; + err = ""; } else { err = "recv failure"; - return -1; } - } else if (nrecv == 0) { return -1; + } else if (nrecv == 0) { + err.clear(); + return 0; } else { + err.clear(); return nrecv; } }