Partway through the SSL refactor, code is operational again (in plaintext)

This commit is contained in:
2022-01-09 21:41:28 -05:00
parent c5bf032db0
commit 2e6f924737
3 changed files with 190 additions and 110 deletions

View File

@@ -236,10 +236,6 @@ int DrivenEngine::drv_notify_accept(int port) {
return chid; return chid;
} }
void DrivenEngine::drv_set_clock(double t) {
clock_ = t;
}
void DrivenEngine::drv_set_lua_source(util::LuaSourcePtr source) { void DrivenEngine::drv_set_lua_source(util::LuaSourcePtr source) {
lua_source_ = std::move(source); lua_source_ = std::move(source);
rescan_lua_source_ = false; rescan_lua_source_ = false;
@@ -249,7 +245,8 @@ void DrivenEngine::drv_invoke_event_init(int argc, char *argv[]) {
event_init(argc, argv); event_init(argc, argv);
} }
void DrivenEngine::drv_invoke_event_update() { void DrivenEngine::drv_invoke_event_update(double clock) {
clock_ = clock;
event_update(); event_update();
} }

View File

@@ -359,11 +359,6 @@ public:
// //
int drv_notify_accept(int port); int drv_notify_accept(int port);
// Set the clock. The driver is expected to periodically check the system
// clock and feed the value into the engine.
//
void drv_set_clock(double t);
// Set the lua source code. The driver is expected to read the lua source // Set the lua source code. The driver is expected to read the lua source
// code and store it (using this function) once before invoking // code and store it (using this function) once before invoking
// //
@@ -372,7 +367,7 @@ public:
// Invoke the init or update event. // Invoke the init or update event.
// //
void drv_invoke_event_init(int argc, char *argv[]); void drv_invoke_event_init(int argc, char *argv[]);
void drv_invoke_event_update(); void drv_invoke_event_update(double clock);
// Check the 'rescan_lua_source' flag. If this flag is set, it means // Check the 'rescan_lua_source' flag. If this flag is set, it means
// that the engine wants the driver to rescan the lua source code. // that the engine wants the driver to rescan the lua source code.

View File

@@ -130,6 +130,7 @@ SocketVector accept_on_socket(SOCKET listen_socket) {
while (true) { while (true) {
SOCKET chsock = accept(listen_socket, nullptr, nullptr); SOCKET chsock = accept(listen_socket, nullptr, nullptr);
if (chsock >= 0) { if (chsock >= 0) {
set_nonblocking(chsock);
result.push_back(chsock); result.push_back(chsock);
} else { } else {
if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) { if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) {
@@ -150,6 +151,7 @@ SSL_CTX *new_ssl_context(bool server_cert, bool root_certs, const std::string &r
SSL_CTX *ctx = SSL_CTX_new(TLS_method()); SSL_CTX *ctx = SSL_CTX_new(TLS_method());
SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE); SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE);
SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, nullptr);
// server_cert is not implemented yet. // server_cert is not implemented yet.
if (root_certs) { if (root_certs) {
SSL_CTX_set_default_verify_paths(ctx); SSL_CTX_set_default_verify_paths(ctx);
@@ -189,21 +191,32 @@ class Driver {
public: public:
enum ChanState { enum ChanState {
CHAN_INACTIVE, CHAN_INACTIVE,
CHAN_CONNECTING, CHAN_PLAINTEXT,
CHAN_OPEN, CHAN_SSL_CONNECTING,
CHAN_SSL_ACCEPTING,
CHAN_SSL_READWRITE,
CHAN_SSL_SHUTDOWN
}; };
struct ChanInfo { struct ChanInfo {
int chid; int chid;
ChanState state;
SOCKET socket; SOCKET socket;
SSL_CTX *ssl_ctx; SSL_CTX *ssl_ctx;
SSL *ssl; SSL *ssl;
ChanState state;
int nbytes;
const char *bytes;
bool released;
bool just_released;
bool ready_now;
bool ready_on_pollin;
bool ready_on_pollout;
bool ready_on_outgoing;
}; };
DrivenEngine *driven_; DrivenEngine *driven_;
std::vector<ChanInfo> chans_; std::vector<ChanInfo> chans_;
bool any_inactive_; bool any_inactive_;
bool short_sleep_;
std::map<int, SOCKET> listen_sockets_; std::map<int, SOCKET> listen_sockets_;
std::unique_ptr<char[]> chbuf; std::unique_ptr<char[]> chbuf;
@@ -228,11 +241,11 @@ public:
void handle_lua_source() { void handle_lua_source() {
if (driven_->drv_get_rescan_lua_source()) { if (driven_->drv_get_rescan_lua_source()) {
driven_->drv_set_lua_source(util::read_lua_source("lua")); driven_->drv_set_lua_source(util::read_lua_source("lua"));
short_sleep_ = true;
} }
} }
void close_channel(ChanInfo &chan, const std::string &err) { void close_channel(ChanInfo &chan, const std::string &err) {
// std::cerr << "Closing channel " << chan.chid << std::endl;
assert(chan.state != CHAN_INACTIVE); assert(chan.state != CHAN_INACTIVE);
// Close the SSL channel. // Close the SSL channel.
if (chan.ssl != nullptr) { if (chan.ssl != nullptr) {
@@ -252,9 +265,16 @@ public:
driven_->drv_notify_close(chan.chid, err); driven_->drv_notify_close(chan.chid, err);
chan.state = CHAN_INACTIVE; chan.state = CHAN_INACTIVE;
chan.chid = -1; chan.chid = -1;
chan.nbytes = 0;
chan.bytes = 0;
chan.released = false;
chan.just_released = false;
chan.ready_now = false;
chan.ready_on_pollin = false;
chan.ready_on_pollout = false;
chan.ready_on_outgoing = false;
// Set global variables. // Set global variables.
any_inactive_ = true; any_inactive_ = true;
short_sleep_ = true;
} }
void cleanup_channels() { void cleanup_channels() {
@@ -271,34 +291,6 @@ public:
} }
} }
void handle_released_channels() {
for (ChanInfo &chan : chans_) {
if (driven_->drv_get_channel_released(chan.chid)) {
close_channel(chan, "");
}
}
cleanup_channels();
}
void handle_new_outgoing_sockets() {
std::set<int> chans;
driven_->drv_get_new_outgoing(chans);
for (int chid : chans) {
std::string err;
SOCKET sock = open_connection(driven_->drv_get_target(chid), err);
if (sock == INVALID_SOCKET) {
driven_->drv_notify_close(chid, err);
short_sleep_ = true;
} else {
ChanInfo newchan;
newchan.chid = chid;
newchan.state = CHAN_CONNECTING;
newchan.socket = sock;
chans_.push_back(newchan);
}
}
}
void handle_console_output() { void handle_console_output() {
while (true) { while (true) {
int nbytes; const char *bytes; int nbytes; const char *bytes;
@@ -320,43 +312,152 @@ public:
} }
} }
ChanInfo make_channel(SOCKET sock, int chid, SSL_CTX *ctx, SSL *ssl, ChanState state) {
ChanInfo newchan;
newchan.chid = chid;
newchan.socket = sock;
newchan.ssl_ctx = ctx;
newchan.ssl = ssl;
newchan.state = state;
newchan.nbytes = 0;
newchan.bytes = 0;
newchan.released = false;
newchan.just_released = false;
newchan.ready_now = true;
newchan.ready_on_pollin = false;
newchan.ready_on_pollout = false;
newchan.ready_on_outgoing = false;
return newchan;
}
void handle_new_outgoing_sockets() {
std::set<int> chans;
driven_->drv_get_new_outgoing(chans);
for (int chid : chans) {
std::string err;
SOCKET sock = open_connection(driven_->drv_get_target(chid), err);
if (sock == INVALID_SOCKET) {
driven_->drv_notify_close(chid, err);
} else {
//std::cerr << "Opening channel " << chid << std::endl;
SSL_CTX *ctx = nullptr;
SSL *ssl = SSL_new(ssl_ctx_with_no_certs_);
chans_.push_back(make_channel(sock, chid, ctx, ssl, CHAN_PLAINTEXT));
}
}
}
void accept_connections(int port, SOCKET sock) { void accept_connections(int port, SOCKET sock) {
SocketVector sockets = accept_on_socket(sock); SocketVector sockets = accept_on_socket(sock);
for (SOCKET sock : sockets) { for (SOCKET sock : sockets) {
int chid = driven_->drv_notify_accept(port); int chid = driven_->drv_notify_accept(port);
ChanInfo newchan; // std::cerr << "Accepted channel " << chid << std::endl;
newchan.chid = chid; SSL_CTX *ctx = nullptr;
newchan.state = CHAN_OPEN; SSL *ssl = SSL_new(ssl_ctx_with_server_certs_);
newchan.socket = sock; chans_.push_back(make_channel(sock, chid, ctx, ssl, CHAN_PLAINTEXT));
chans_.push_back(newchan);
short_sleep_ = true;
} }
} }
int calc_select_sets(fd_set &rfds, fd_set &wfds, fd_set &efds) const { void advance_plaintext(ChanInfo &chan) {
FD_ZERO(&rfds); // If the channel has no outgoing bytes and has been released,
FD_ZERO(&wfds); // just close it.
FD_ZERO(&efds); if (chan.released) {
int largest = -1; close_channel(chan, "");
for (const auto &p : listen_sockets_) { return;
FD_SET(p.second, &rfds); }
FD_SET(p.second, &efds);
if (p.second > largest) largest = p.second; // Try to write plaintext to the channel.
} int nbytes; const char *bytes;
for (const ChanInfo &chan : chans_) { driven_->drv_peek_outgoing(chan.chid, &nbytes, &bytes);
SOCKET sock = chan.socket; if (nbytes > 0) {
if (sock == INVALID_SOCKET) continue; int sbytes = nbytes;
FD_SET(sock, &rfds); if (sbytes > 65536) sbytes = 65536;
FD_SET(sock, &efds); int wbytes = send(chan.socket, bytes, sbytes, 0);
if (!driven_->drv_outgoing_empty(chan.chid)) { // std::cerr << "send.bytes="<< wbytes << ".errno=" << errno << " ";
FD_SET(sock, &wfds); if (wbytes < 0) {
} if ((errno != EWOULDBLOCK) && (errno != EAGAIN)) {
if (sock > largest) largest = sock; close_channel(chan, "send failure");
} return;
return largest + 1; }
} else {
driven_->drv_sent_outgoing(chan.chid, wbytes);
}
}
// Try to read plaintext from the channel.
// Someday, find a way to avoid this copy.
int nrecv = recv(chan.socket, chbuf.get(), 65536, 0);
// std::cerr << "recv.bytes="<< nrecv << ".errno=" << errno << " ";
if (nrecv < 0) {
if ((errno != EWOULDBLOCK) && (errno != EAGAIN)) {
close_channel(chan, "recv failure");
return;
}
} else if (nrecv == 0) {
close_channel(chan, "");
return;
} else {
driven_->drv_recv_incoming(chan.chid, nrecv, chbuf.get());
}
// Update the ready-flags for next time.
chan.ready_on_outgoing = true;
chan.ready_on_pollin = true;
}
void advance_ssl_connecting(ChanInfo &chan) {
assert(false);
}
void advance_ssl_accepting(ChanInfo &chan) {
assert(false);
}
void advance_ssl_readwrite(ChanInfo &chan) {
assert(false);
}
void advance_ssl_shutdown(ChanInfo &chan) {
assert(false);
}
void advance_channel(ChanInfo &chan) {
switch(chan.state) {
case CHAN_PLAINTEXT:
advance_plaintext(chan);
break;
case CHAN_SSL_CONNECTING:
advance_ssl_connecting(chan);
break;
case CHAN_SSL_ACCEPTING:
advance_ssl_accepting(chan);
break;
case CHAN_SSL_READWRITE:
advance_ssl_readwrite(chan);
break;
case CHAN_SSL_SHUTDOWN:
advance_ssl_shutdown(chan);
break;
default:
assert(false);
break;
}
}
void handle_socket_input_output() {
int mstimeout = 1000;
// Peek output buffers and determine channel release flags.
for (ChanInfo &chan : chans_) {
driven_->drv_peek_outgoing(chan.chid, &chan.nbytes, &chan.bytes);
chan.just_released = false;
if ((chan.nbytes == 0)&&(!chan.released)) {
chan.released = driven_->drv_get_channel_released(chan.chid);
chan.just_released = chan.released;
}
} }
void handle_socket_input_output(int mstimeout) {
// Construct the pollfd vector. // Construct the pollfd vector.
std::vector<struct pollfd> pollvec; std::vector<struct pollfd> pollvec;
pollvec.resize(listen_sockets_.size() + chans_.size() + 1); pollvec.resize(listen_sockets_.size() + chans_.size() + 1);
@@ -368,11 +469,15 @@ public:
} }
for (const ChanInfo &chan : chans_) { for (const ChanInfo &chan : chans_) {
struct pollfd &pfd = pollvec[index++]; struct pollfd &pfd = pollvec[index++];
assert(chan.socket != INVALID_SOCKET);
pfd.fd = chan.socket; pfd.fd = chan.socket;
pfd.events = POLLIN; pfd.events = POLLERR;
if (!driven_->drv_outgoing_empty(chan.chid)) { if (chan.ready_now) mstimeout = 0;
pfd.events |= POLLOUT; if (chan.just_released) mstimeout = 0;
} if (chan.ready_on_pollin) pfd.events |= POLLIN;
if (chan.ready_on_pollout) pfd.events |= POLLOUT;
if (chan.ready_on_outgoing && (chan.nbytes > 0)) pfd.events |= POLLOUT;
// std::cerr << "evt=" << pfd.events << " ";
} }
struct pollfd &stdiopoll = pollvec[index++]; struct pollfd &stdiopoll = pollvec[index++];
stdiopoll.fd = 0; stdiopoll.fd = 0;
@@ -391,35 +496,23 @@ public:
} }
} }
// Transfer bytes wherever possible. // Advance channels where possible.
for (ChanInfo &chan : chans_) { for (ChanInfo &chan : chans_) {
struct pollfd &pfd = pollvec[index++]; struct pollfd &pfd = pollvec[index++];
SOCKET sock = chan.socket; bool pollin = ((pfd.revents & POLLIN) != 0);
if (sock == INVALID_SOCKET) continue; bool pollout = ((pfd.revents & POLLOUT) != 0);
if (pfd.revents & POLLOUT) { bool pollerr = ((pfd.revents & POLLERR) != 0);
int nbytes; const char *bytes; if (chan.ready_now || pollerr || chan.just_released ||
chan.state = CHAN_OPEN; (chan.ready_on_pollin && pollin) ||
driven_->drv_peek_outgoing(chan.chid, &nbytes, &bytes); (chan.ready_on_pollout && pollout) ||
if (nbytes > 0) { (chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) {
int wbytes = send(sock, bytes, nbytes, 0); chan.ready_now = false;
if (wbytes < 0) { chan.ready_on_pollin = false;
close_channel(chan, "send failure"); chan.ready_on_pollout = false;
continue; chan.ready_on_outgoing = false;
} else { advance_channel(chan);
driven_->drv_sent_outgoing(chan.chid, wbytes); chan.nbytes = false;
} chan.bytes = 0;
}
}
if (pfd.revents & (POLLIN | POLLERR)) {
// Someday, find a way to avoid this copy.
int nrecv = recv(sock, chbuf.get(), 65536, 0);
if (nrecv <= 0) {
close_channel(chan, "recv failure");
continue;
} else {
driven_->drv_recv_incoming(chan.chid, nrecv, chbuf.get());
short_sleep_ = true;
}
} }
} }
@@ -431,7 +524,6 @@ public:
enableRawMode(); enableRawMode();
driven_ = de; driven_ = de;
any_inactive_ = false; any_inactive_ = false;
short_sleep_ = false;
chbuf.reset(new char[65536]); chbuf.reset(new char[65536]);
ssl_ctx_with_root_certs_ = new_ssl_context(false, true, ""); ssl_ctx_with_root_certs_ = new_ssl_context(false, true, "");
ssl_ctx_with_server_certs_ = new_ssl_context(true, false, ""); ssl_ctx_with_server_certs_ = new_ssl_context(true, false, "");
@@ -442,17 +534,13 @@ public:
handle_listen_ports(); handle_listen_ports();
while (!de->drv_get_stop_driver()) { while (!de->drv_get_stop_driver()) {
short_sleep_ = false;
handle_lua_source(); handle_lua_source();
handle_console_output(); handle_console_output();
handle_new_outgoing_sockets();
handle_socket_input_output();
handle_console_input(); handle_console_input();
handle_console_output(); handle_console_output();
handle_released_channels(); de->drv_invoke_event_update(monoclock.get());
handle_new_outgoing_sockets();
int mstimeout = short_sleep_ ? 0 : 100;
handle_socket_input_output(mstimeout);
driven_->drv_set_clock(monoclock.get());
de->drv_invoke_event_update();
} }
for (ChanInfo &chan : chans_) { for (ChanInfo &chan : chans_) {