Partway through the SSL refactor, code is operational again (in plaintext)

This commit is contained in:
2022-01-09 21:41:28 -05:00
parent c5bf032db0
commit 2e6f924737
3 changed files with 190 additions and 110 deletions

View File

@@ -236,10 +236,6 @@ int DrivenEngine::drv_notify_accept(int port) {
return chid;
}
void DrivenEngine::drv_set_clock(double t) {
clock_ = t;
}
void DrivenEngine::drv_set_lua_source(util::LuaSourcePtr source) {
lua_source_ = std::move(source);
rescan_lua_source_ = false;
@@ -249,7 +245,8 @@ void DrivenEngine::drv_invoke_event_init(int argc, char *argv[]) {
event_init(argc, argv);
}
void DrivenEngine::drv_invoke_event_update() {
void DrivenEngine::drv_invoke_event_update(double clock) {
clock_ = clock;
event_update();
}

View File

@@ -359,11 +359,6 @@ public:
//
int drv_notify_accept(int port);
// Set the clock. The driver is expected to periodically check the system
// clock and feed the value into the engine.
//
void drv_set_clock(double t);
// Set the lua source code. The driver is expected to read the lua source
// code and store it (using this function) once before invoking
//
@@ -372,7 +367,7 @@ public:
// Invoke the init or update event.
//
void drv_invoke_event_init(int argc, char *argv[]);
void drv_invoke_event_update();
void drv_invoke_event_update(double clock);
// Check the 'rescan_lua_source' flag. If this flag is set, it means
// that the engine wants the driver to rescan the lua source code.

View File

@@ -130,6 +130,7 @@ SocketVector accept_on_socket(SOCKET listen_socket) {
while (true) {
SOCKET chsock = accept(listen_socket, nullptr, nullptr);
if (chsock >= 0) {
set_nonblocking(chsock);
result.push_back(chsock);
} else {
if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) {
@@ -150,6 +151,7 @@ SSL_CTX *new_ssl_context(bool server_cert, bool root_certs, const std::string &r
SSL_CTX *ctx = SSL_CTX_new(TLS_method());
SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE);
SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, nullptr);
// server_cert is not implemented yet.
if (root_certs) {
SSL_CTX_set_default_verify_paths(ctx);
@@ -189,21 +191,32 @@ class Driver {
public:
enum ChanState {
CHAN_INACTIVE,
CHAN_CONNECTING,
CHAN_OPEN,
CHAN_PLAINTEXT,
CHAN_SSL_CONNECTING,
CHAN_SSL_ACCEPTING,
CHAN_SSL_READWRITE,
CHAN_SSL_SHUTDOWN
};
struct ChanInfo {
int chid;
ChanState state;
SOCKET socket;
SSL_CTX *ssl_ctx;
SSL *ssl;
ChanState state;
int nbytes;
const char *bytes;
bool released;
bool just_released;
bool ready_now;
bool ready_on_pollin;
bool ready_on_pollout;
bool ready_on_outgoing;
};
DrivenEngine *driven_;
std::vector<ChanInfo> chans_;
bool any_inactive_;
bool short_sleep_;
std::map<int, SOCKET> listen_sockets_;
std::unique_ptr<char[]> chbuf;
@@ -228,11 +241,11 @@ public:
void handle_lua_source() {
if (driven_->drv_get_rescan_lua_source()) {
driven_->drv_set_lua_source(util::read_lua_source("lua"));
short_sleep_ = true;
}
}
void close_channel(ChanInfo &chan, const std::string &err) {
// std::cerr << "Closing channel " << chan.chid << std::endl;
assert(chan.state != CHAN_INACTIVE);
// Close the SSL channel.
if (chan.ssl != nullptr) {
@@ -252,9 +265,16 @@ public:
driven_->drv_notify_close(chan.chid, err);
chan.state = CHAN_INACTIVE;
chan.chid = -1;
chan.nbytes = 0;
chan.bytes = 0;
chan.released = false;
chan.just_released = false;
chan.ready_now = false;
chan.ready_on_pollin = false;
chan.ready_on_pollout = false;
chan.ready_on_outgoing = false;
// Set global variables.
any_inactive_ = true;
short_sleep_ = true;
}
void cleanup_channels() {
@@ -271,34 +291,6 @@ public:
}
}
void handle_released_channels() {
for (ChanInfo &chan : chans_) {
if (driven_->drv_get_channel_released(chan.chid)) {
close_channel(chan, "");
}
}
cleanup_channels();
}
void handle_new_outgoing_sockets() {
std::set<int> chans;
driven_->drv_get_new_outgoing(chans);
for (int chid : chans) {
std::string err;
SOCKET sock = open_connection(driven_->drv_get_target(chid), err);
if (sock == INVALID_SOCKET) {
driven_->drv_notify_close(chid, err);
short_sleep_ = true;
} else {
ChanInfo newchan;
newchan.chid = chid;
newchan.state = CHAN_CONNECTING;
newchan.socket = sock;
chans_.push_back(newchan);
}
}
}
void handle_console_output() {
while (true) {
int nbytes; const char *bytes;
@@ -320,43 +312,152 @@ public:
}
}
ChanInfo make_channel(SOCKET sock, int chid, SSL_CTX *ctx, SSL *ssl, ChanState state) {
ChanInfo newchan;
newchan.chid = chid;
newchan.socket = sock;
newchan.ssl_ctx = ctx;
newchan.ssl = ssl;
newchan.state = state;
newchan.nbytes = 0;
newchan.bytes = 0;
newchan.released = false;
newchan.just_released = false;
newchan.ready_now = true;
newchan.ready_on_pollin = false;
newchan.ready_on_pollout = false;
newchan.ready_on_outgoing = false;
return newchan;
}
void handle_new_outgoing_sockets() {
std::set<int> chans;
driven_->drv_get_new_outgoing(chans);
for (int chid : chans) {
std::string err;
SOCKET sock = open_connection(driven_->drv_get_target(chid), err);
if (sock == INVALID_SOCKET) {
driven_->drv_notify_close(chid, err);
} else {
//std::cerr << "Opening channel " << chid << std::endl;
SSL_CTX *ctx = nullptr;
SSL *ssl = SSL_new(ssl_ctx_with_no_certs_);
chans_.push_back(make_channel(sock, chid, ctx, ssl, CHAN_PLAINTEXT));
}
}
}
void accept_connections(int port, SOCKET sock) {
SocketVector sockets = accept_on_socket(sock);
for (SOCKET sock : sockets) {
int chid = driven_->drv_notify_accept(port);
ChanInfo newchan;
newchan.chid = chid;
newchan.state = CHAN_OPEN;
newchan.socket = sock;
chans_.push_back(newchan);
short_sleep_ = true;
// std::cerr << "Accepted channel " << chid << std::endl;
SSL_CTX *ctx = nullptr;
SSL *ssl = SSL_new(ssl_ctx_with_server_certs_);
chans_.push_back(make_channel(sock, chid, ctx, ssl, CHAN_PLAINTEXT));
}
}
int calc_select_sets(fd_set &rfds, fd_set &wfds, fd_set &efds) const {
FD_ZERO(&rfds);
FD_ZERO(&wfds);
FD_ZERO(&efds);
int largest = -1;
for (const auto &p : listen_sockets_) {
FD_SET(p.second, &rfds);
FD_SET(p.second, &efds);
if (p.second > largest) largest = p.second;
}
for (const ChanInfo &chan : chans_) {
SOCKET sock = chan.socket;
if (sock == INVALID_SOCKET) continue;
FD_SET(sock, &rfds);
FD_SET(sock, &efds);
if (!driven_->drv_outgoing_empty(chan.chid)) {
FD_SET(sock, &wfds);
}
if (sock > largest) largest = sock;
}
return largest + 1;
void advance_plaintext(ChanInfo &chan) {
// If the channel has no outgoing bytes and has been released,
// just close it.
if (chan.released) {
close_channel(chan, "");
return;
}
// Try to write plaintext to the channel.
int nbytes; const char *bytes;
driven_->drv_peek_outgoing(chan.chid, &nbytes, &bytes);
if (nbytes > 0) {
int sbytes = nbytes;
if (sbytes > 65536) sbytes = 65536;
int wbytes = send(chan.socket, bytes, sbytes, 0);
// std::cerr << "send.bytes="<< wbytes << ".errno=" << errno << " ";
if (wbytes < 0) {
if ((errno != EWOULDBLOCK) && (errno != EAGAIN)) {
close_channel(chan, "send failure");
return;
}
} else {
driven_->drv_sent_outgoing(chan.chid, wbytes);
}
}
// Try to read plaintext from the channel.
// Someday, find a way to avoid this copy.
int nrecv = recv(chan.socket, chbuf.get(), 65536, 0);
// std::cerr << "recv.bytes="<< nrecv << ".errno=" << errno << " ";
if (nrecv < 0) {
if ((errno != EWOULDBLOCK) && (errno != EAGAIN)) {
close_channel(chan, "recv failure");
return;
}
} else if (nrecv == 0) {
close_channel(chan, "");
return;
} else {
driven_->drv_recv_incoming(chan.chid, nrecv, chbuf.get());
}
// Update the ready-flags for next time.
chan.ready_on_outgoing = true;
chan.ready_on_pollin = true;
}
void advance_ssl_connecting(ChanInfo &chan) {
assert(false);
}
void advance_ssl_accepting(ChanInfo &chan) {
assert(false);
}
void advance_ssl_readwrite(ChanInfo &chan) {
assert(false);
}
void advance_ssl_shutdown(ChanInfo &chan) {
assert(false);
}
void advance_channel(ChanInfo &chan) {
switch(chan.state) {
case CHAN_PLAINTEXT:
advance_plaintext(chan);
break;
case CHAN_SSL_CONNECTING:
advance_ssl_connecting(chan);
break;
case CHAN_SSL_ACCEPTING:
advance_ssl_accepting(chan);
break;
case CHAN_SSL_READWRITE:
advance_ssl_readwrite(chan);
break;
case CHAN_SSL_SHUTDOWN:
advance_ssl_shutdown(chan);
break;
default:
assert(false);
break;
}
}
void handle_socket_input_output() {
int mstimeout = 1000;
// Peek output buffers and determine channel release flags.
for (ChanInfo &chan : chans_) {
driven_->drv_peek_outgoing(chan.chid, &chan.nbytes, &chan.bytes);
chan.just_released = false;
if ((chan.nbytes == 0)&&(!chan.released)) {
chan.released = driven_->drv_get_channel_released(chan.chid);
chan.just_released = chan.released;
}
}
void handle_socket_input_output(int mstimeout) {
// Construct the pollfd vector.
std::vector<struct pollfd> pollvec;
pollvec.resize(listen_sockets_.size() + chans_.size() + 1);
@@ -368,11 +469,15 @@ public:
}
for (const ChanInfo &chan : chans_) {
struct pollfd &pfd = pollvec[index++];
assert(chan.socket != INVALID_SOCKET);
pfd.fd = chan.socket;
pfd.events = POLLIN;
if (!driven_->drv_outgoing_empty(chan.chid)) {
pfd.events |= POLLOUT;
}
pfd.events = POLLERR;
if (chan.ready_now) mstimeout = 0;
if (chan.just_released) mstimeout = 0;
if (chan.ready_on_pollin) pfd.events |= POLLIN;
if (chan.ready_on_pollout) pfd.events |= POLLOUT;
if (chan.ready_on_outgoing && (chan.nbytes > 0)) pfd.events |= POLLOUT;
// std::cerr << "evt=" << pfd.events << " ";
}
struct pollfd &stdiopoll = pollvec[index++];
stdiopoll.fd = 0;
@@ -391,35 +496,23 @@ public:
}
}
// Transfer bytes wherever possible.
// Advance channels where possible.
for (ChanInfo &chan : chans_) {
struct pollfd &pfd = pollvec[index++];
SOCKET sock = chan.socket;
if (sock == INVALID_SOCKET) continue;
if (pfd.revents & POLLOUT) {
int nbytes; const char *bytes;
chan.state = CHAN_OPEN;
driven_->drv_peek_outgoing(chan.chid, &nbytes, &bytes);
if (nbytes > 0) {
int wbytes = send(sock, bytes, nbytes, 0);
if (wbytes < 0) {
close_channel(chan, "send failure");
continue;
} else {
driven_->drv_sent_outgoing(chan.chid, wbytes);
}
}
}
if (pfd.revents & (POLLIN | POLLERR)) {
// Someday, find a way to avoid this copy.
int nrecv = recv(sock, chbuf.get(), 65536, 0);
if (nrecv <= 0) {
close_channel(chan, "recv failure");
continue;
} else {
driven_->drv_recv_incoming(chan.chid, nrecv, chbuf.get());
short_sleep_ = true;
}
bool pollin = ((pfd.revents & POLLIN) != 0);
bool pollout = ((pfd.revents & POLLOUT) != 0);
bool pollerr = ((pfd.revents & POLLERR) != 0);
if (chan.ready_now || pollerr || chan.just_released ||
(chan.ready_on_pollin && pollin) ||
(chan.ready_on_pollout && pollout) ||
(chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) {
chan.ready_now = false;
chan.ready_on_pollin = false;
chan.ready_on_pollout = false;
chan.ready_on_outgoing = false;
advance_channel(chan);
chan.nbytes = false;
chan.bytes = 0;
}
}
@@ -431,7 +524,6 @@ public:
enableRawMode();
driven_ = de;
any_inactive_ = false;
short_sleep_ = false;
chbuf.reset(new char[65536]);
ssl_ctx_with_root_certs_ = new_ssl_context(false, true, "");
ssl_ctx_with_server_certs_ = new_ssl_context(true, false, "");
@@ -442,17 +534,13 @@ public:
handle_listen_ports();
while (!de->drv_get_stop_driver()) {
short_sleep_ = false;
handle_lua_source();
handle_console_output();
handle_new_outgoing_sockets();
handle_socket_input_output();
handle_console_input();
handle_console_output();
handle_released_channels();
handle_new_outgoing_sockets();
int mstimeout = short_sleep_ ? 0 : 100;
handle_socket_input_output(mstimeout);
driven_->drv_set_clock(monoclock.get());
de->drv_invoke_event_update();
de->drv_invoke_event_update(monoclock.get());
}
for (ChanInfo &chan : chans_) {