refactor driver-linux to separate machine dependent and independent code.

This commit is contained in:
2022-01-11 13:59:13 -05:00
parent c47e13691d
commit 2b59d8a4a3
2 changed files with 577 additions and 512 deletions

View File

@@ -0,0 +1,487 @@
static MonoClock monoclock;
namespace util {
double profiling_clock() {
return monoclock.get();
}
}
static void if_error_print_and_exit(const std::string &str) {
if (!str.empty()) {
std::cerr << std::endl << "error: " << str << std::endl;
exit(1);
}
}
static SSL_CTX *new_ssl_context(bool server_cert, bool root_certs, const std::string &require_cert) {
SSL_CTX *ctx = SSL_CTX_new(TLS_method());
SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE);
SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, nullptr);
// server_cert is not implemented yet.
if (root_certs) {
SSL_CTX_set_default_verify_paths(ctx);
SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, NULL);
}
// require_cert is not implemented yet.
return ctx;
}
static std::string err_print_errors_str() {
BIO *bio = BIO_new(BIO_s_mem());
ERR_print_errors(bio);
char *buf;
size_t len = BIO_get_mem_data(bio, &buf);
std::string ret(buf, len);
BIO_free(bio);
return ret;
}
static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str) {
BIO *bio = BIO_new(BIO_s_mem());
BIO_puts(bio, str);
X509 *certificate = PEM_read_bio_X509(bio, NULL, NULL, NULL);
BIO_free(bio);
int status = SSL_CTX_use_certificate(ctx, certificate);
X509_free(certificate);
return status;
}
static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str) {
BIO *bio = BIO_new(BIO_s_mem());
BIO_puts(bio, str);
EVP_PKEY *pkey = PEM_read_bio_PrivateKey(bio, NULL, NULL, NULL);
BIO_free(bio);
int status = SSL_CTX_use_PrivateKey(ctx, pkey);
EVP_PKEY_free(pkey);
return status;
}
class Driver {
public:
enum ChanState {
CHAN_INACTIVE,
CHAN_PLAINTEXT,
CHAN_SSL_CONNECTING,
CHAN_SSL_ACCEPTING,
CHAN_SSL_READWRITE,
};
struct ChanInfo {
int chid;
SOCKET socket;
SSL *ssl;
ChanState state;
int nbytes;
const char *bytes;
bool released;
bool just_released;
bool ready_now;
bool ready_on_pollin;
bool ready_on_pollout;
bool ready_on_outgoing;
int last_write_nbytes;
};
DrivenEngine *driven_;
std::vector<ChanInfo> chans_;
std::map<int, SOCKET> listen_sockets_;
std::unique_ptr<char[]> chbuf;
bool read_console_recently_;
SSL_CTX *ssl_ctx_with_root_certs_;
SSL_CTX *ssl_ctx_with_server_certs_;
SSL_CTX *ssl_ctx_with_no_certs_;
void handle_listen_ports() {
std::set<int> listenports;
driven_->drv_get_listen_ports(listenports);
for (int port : listenports) {
if (listen_sockets_.find(port) == listen_sockets_.end()) {
std::string err;
SOCKET sock = listen_on_port(port, err);
if_error_print_and_exit(err);
assert(sock != INVALID_SOCKET);
listen_sockets_[port] = sock;
}
}
}
void handle_lua_source() {
if (driven_->drv_get_rescan_lua_source()) {
driven_->drv_set_lua_source(util::read_lua_source("lua"));
}
}
void close_channel(ChanInfo &chan, const std::string &err) {
// std::cerr << "Closing channel " << chan.chid << std::endl;
assert(chan.state != CHAN_INACTIVE);
// Close and release the SSL channel.
if (chan.ssl != nullptr) {
SSL_free(chan.ssl);
chan.ssl = nullptr;
}
// Close and release the socket.
assert(chan.socket != INVALID_SOCKET);
assert(socket_close(chan.socket) == 0);
chan.socket = INVALID_SOCKET;
// Close everything else.
driven_->drv_notify_close(chan.chid, err);
chan.state = CHAN_INACTIVE;
chan.chid = -1;
chan.nbytes = 0;
chan.bytes = 0;
chan.released = false;
chan.just_released = false;
chan.ready_now = false;
chan.ready_on_pollin = false;
chan.ready_on_pollout = false;
chan.ready_on_outgoing = false;
chan.last_write_nbytes = 0;
}
void cleanup_channels() {
for (int i = 0; i < int(chans_.size()); ) {
if (chans_[i].state == CHAN_INACTIVE) {
chans_[i] = chans_.back();
chans_.pop_back();
} else {
i += 1;
}
}
}
void handle_console_output() {
while (true) {
int nbytes; const char *bytes;
driven_->drv_peek_outgoing(0, &nbytes, &bytes);
if (nbytes == 0) break;
int nwrote = console_write(bytes, nbytes);
assert(nwrote > 0);
driven_->drv_sent_outgoing(0, nwrote);
}
}
void handle_console_input() {
char buffer[256];
read_console_recently_ = false;
while (true) {
int nread = console_read(buffer, 256);
if (nread == 0) break;
assert(nread > 0);
read_console_recently_ = true;
driven_->drv_recv_incoming(0, nread, buffer);
}
}
void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) {
ChanInfo newchan;
newchan.chid = chid;
newchan.socket = sock;
newchan.ssl = SSL_new(ctx);
newchan.state = state;
newchan.nbytes = 0;
newchan.bytes = 0;
newchan.released = false;
newchan.just_released = false;
newchan.ready_now = false;
newchan.ready_on_pollin = false;
newchan.ready_on_pollout = true;
newchan.ready_on_outgoing = false;
newchan.last_write_nbytes = 0;
SSL_set_fd(newchan.ssl, newchan.socket);
// SSL_set_msg_callback(newchan.ssl, SSL_trace);
// SSL_set_msg_callback_arg(newchan.ssl, BIO_new_fp(stderr,0));
chans_.push_back(newchan);
}
void handle_new_outgoing_sockets() {
std::set<int> chans;
driven_->drv_get_new_outgoing(chans);
for (int chid : chans) {
std::string err;
SOCKET sock = open_connection(driven_->drv_get_target(chid), err);
if (sock == INVALID_SOCKET) {
driven_->drv_notify_close(chid, err);
} else {
//std::cerr << "Opening channel " << chid << std::endl;
make_channel(sock, chid, ssl_ctx_with_no_certs_, CHAN_SSL_CONNECTING);
}
}
}
void accept_connection(int port, SOCKET sock) {
std::string err;
SOCKET socket = accept_on_socket(sock, err);
if_error_print_and_exit(err);
if (socket != INVALID_SOCKET) {
int chid = driven_->drv_notify_accept(port);
// std::cerr << "Accepted channel " << chid << std::endl;
make_channel(socket, chid, ssl_ctx_with_server_certs_, CHAN_SSL_ACCEPTING);
}
}
void advance_plaintext(ChanInfo &chan) {
std::string err;
// If the channel has no outgoing bytes and has been released,
// just close it.
if (chan.released) {
close_channel(chan, "");
return;
}
// Try to write plaintext to the channel.
int nbytes; const char *bytes;
driven_->drv_peek_outgoing(chan.chid, &nbytes, &bytes);
if (nbytes > 0) {
int sbytes = nbytes;
if (sbytes > 65536) sbytes = 65536;
int wbytes = socket_send(chan.socket, bytes, sbytes, err);
// std::cerr << "send.bytes="<< wbytes << ".errno=" << errno << " ";
if (wbytes < 0) {
close_channel(chan, err);
} else {
driven_->drv_sent_outgoing(chan.chid, wbytes);
}
}
// Try to read plaintext from the channel.
// Someday, find a way to avoid this copy.
int nrecv = socket_recv(chan.socket, chbuf.get(), 65536, err);
// std::cerr << "recv.bytes="<< nrecv << ".errno=" << errno << " ";
if (nrecv < 0) {
close_channel(chan, err);
} else {
driven_->drv_recv_incoming(chan.chid, nrecv, chbuf.get());
}
// Update the ready-flags for next time.
chan.ready_on_outgoing = true;
chan.ready_on_pollin = true;
}
void process_ssl_error(ChanInfo &chan, int retval) {
int error = SSL_get_error(chan.ssl, retval);
// std::cerr << "SSL error code = " << error << " ";
if (error == SSL_ERROR_WANT_READ) {
chan.ready_on_pollin = true;
} else if (error == SSL_ERROR_WANT_WRITE) {
chan.ready_on_pollout = true;
} else {
close_channel(chan, err_print_errors_str());
}
}
void advance_ssl_connecting(ChanInfo &chan) {
// std::cerr << "In advance_ssl_connecting" << std::endl;
int retval = SSL_connect(chan.ssl);
if (retval == 1) {
// Connection successful.
chan.state = CHAN_SSL_READWRITE;
chan.ready_now = true;
} else {
// std::cerr << "ssl_connect_error";
process_ssl_error(chan, retval);
}
}
void advance_ssl_accepting(ChanInfo &chan) {
// std::cerr << "In advance_ssl_accepting" << std::endl;
int retval = SSL_accept(chan.ssl);
if (retval == 1) {
// Connection successful.
chan.state = CHAN_SSL_READWRITE;
chan.ready_now = true;
} else {
process_ssl_error(chan, retval);
}
}
void advance_ssl_readwrite(ChanInfo &chan) {
// std::cerr << "In advance_ssl_readwrite" << std::endl;
// Try to read data.
int read_result = SSL_read(chan.ssl, chbuf.get(), 65536);
if (read_result > 0) {
driven_->drv_recv_incoming(chan.chid, read_result, chbuf.get());
chan.ready_now = true;
} else {
process_ssl_error(chan, read_result);
if (chan.state == CHAN_INACTIVE) return;
}
// Try to write data.
int wbytes;
if (chan.last_write_nbytes > 0) {
wbytes = chan.last_write_nbytes;
assert(wbytes < chan.nbytes);
} else {
wbytes = chan.nbytes;
if (wbytes > 65536) wbytes = 65536;
}
if (wbytes > 0) {
int write_result = SSL_write(chan.ssl, chan.bytes, wbytes);
if (write_result > 0) {
driven_->drv_sent_outgoing(chan.chid, write_result);
chan.last_write_nbytes = 0;
chan.ready_on_outgoing = true;
} else {
chan.last_write_nbytes = wbytes;
process_ssl_error(chan, write_result);
if (chan.state == CHAN_INACTIVE) return;
}
} else {
chan.ready_on_outgoing = true;
}
// std::cerr << "rpi=" << chan.ready_on_pollin << ".rpo=" << chan.ready_on_pollout << ".rn=" << chan.ready_now << ".rog=" << chan.ready_on_outgoing << " ";
}
void advance_channel(ChanInfo &chan) {
switch(chan.state) {
case CHAN_PLAINTEXT:
advance_plaintext(chan);
break;
case CHAN_SSL_CONNECTING:
advance_ssl_connecting(chan);
break;
case CHAN_SSL_ACCEPTING:
advance_ssl_accepting(chan);
break;
case CHAN_SSL_READWRITE:
advance_ssl_readwrite(chan);
break;
default:
assert(false);
break;
}
}
void handle_socket_input_output() {
int mstimeout = 1000;
// Peek output buffers and determine channel release flags.
for (ChanInfo &chan : chans_) {
driven_->drv_peek_outgoing(chan.chid, &chan.nbytes, &chan.bytes);
chan.just_released = false;
if ((chan.nbytes == 0)&&(!chan.released)) {
chan.released = driven_->drv_get_channel_released(chan.chid);
chan.just_released = chan.released;
}
}
// Construct the pollfd vector.
int pollsize = listen_sockets_.size() + chans_.size() + 1;
std::vector<struct pollfd> pollvec(pollsize);
int index = 0;
for (const auto &p : listen_sockets_) {
struct pollfd &pfd = pollvec[index++];
pfd.fd = p.second;
pfd.events = POLLIN;
}
for (const ChanInfo &chan : chans_) {
struct pollfd &pfd = pollvec[index++];
assert(chan.socket != INVALID_SOCKET);
pfd.fd = chan.socket;
pfd.events = POLLERR;
if (chan.ready_now) mstimeout = 0;
if (chan.just_released) mstimeout = 0;
if (chan.ready_on_pollin) pfd.events |= POLLIN;
if (chan.ready_on_pollout) pfd.events |= POLLOUT;
if (chan.ready_on_outgoing && (chan.nbytes > 0)) pfd.events |= POLLOUT;
// std::cerr << "evt=" << pfd.events << ".nb=" << chan.nbytes << " ";
}
fill_stdio_pollfd(pollvec, mstimeout, read_console_recently_);
// Do the poll.
int status = poll(&pollvec[0], pollvec.size(), mstimeout);
assert(status >= 0);
// Check listening sockets.
index = 0;
for (auto &p : listen_sockets_) {
struct pollfd &pfd = pollvec[index++];
if (pfd.revents & (POLLIN | POLLERR)) {
accept_connection(p.first, p.second);
}
}
// Advance channels where possible.
for (ChanInfo &chan : chans_) {
struct pollfd &pfd = pollvec[index++];
bool pollin = ((pfd.revents & POLLIN) != 0);
bool pollout = ((pfd.revents & POLLOUT) != 0);
bool pollerr = ((pfd.revents & POLLERR) != 0);
if (chan.ready_now || pollerr || chan.just_released ||
(chan.ready_on_pollin && pollin) ||
(chan.ready_on_pollout && pollout) ||
(chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) {
chan.ready_now = false;
chan.ready_on_pollin = false;
chan.ready_on_pollout = false;
chan.ready_on_outgoing = false;
advance_channel(chan);
}
chan.nbytes = 0;
chan.bytes = 0;
}
// Delete any newly-inactive channels
cleanup_channels();
}
void drive(DrivenEngine *de, int argc, char *argv[]) {
SSL_load_error_strings();
ERR_load_crypto_strings();
enable_tty_raw();
driven_ = de;
read_console_recently_ = false;
chbuf.reset(new char[65536]);
ssl_ctx_with_root_certs_ = new_ssl_context(false, true, "");
ssl_ctx_with_server_certs_ = new_ssl_context(true, false, "");
ssl_ctx_with_no_certs_ = new_ssl_context(false, false, "");
if (ssl_ctx_use_certificate_str(ssl_ctx_with_server_certs_, dummycert::certificate) <= 0) {
ERR_print_errors_fp(stderr);
exit(1);
}
if (ssl_ctx_use_privatekey_str(ssl_ctx_with_server_certs_, dummycert::privatekey) <= 0 ) {
ERR_print_errors_fp(stderr);
exit(1);
}
DrivenEngine::set(de);
driven_->drv_set_lua_source(util::read_lua_source("lua"));
driven_->drv_invoke_event_init(argc, argv);
handle_listen_ports();
while (!de->drv_get_stop_driver()) {
handle_lua_source();
handle_console_output();
handle_new_outgoing_sockets();
handle_socket_input_output();
handle_console_input();
handle_console_output();
de->drv_invoke_event_update(monoclock.get());
}
for (ChanInfo &chan : chans_) {
close_channel(chan, "");
}
SSL_CTX_free(ssl_ctx_with_no_certs_);
SSL_CTX_free(ssl_ctx_with_root_certs_);
SSL_CTX_free(ssl_ctx_with_server_certs_);
DrivenEngine::set(nullptr);
}
};
void driver_drive(DrivenEngine *de, int argc, char *argv[]) {
Driver driver;
driver.drive(de, argc, argv);
}

View File

@@ -25,6 +25,8 @@
#include <openssl/x509.h>
#include <openssl/evp.h>
#include <openssl/err.h>
#include <openssl/bio.h>
#include <openssl/pem.h>
using SOCKET=int;
const int INVALID_SOCKET = -1;
@@ -32,6 +34,11 @@ using PollVector = std::vector<struct pollfd>;
struct termios orig_termios;
static std::string strerror_str(int err) {
char errbuf[256];
return strerror_r(errno, errbuf, 256);
}
void set_nonblocking(int fd) {
int flags = fcntl(fd, F_GETFL, 0);
assert(flags != -1);
@@ -39,14 +46,14 @@ void set_nonblocking(int fd) {
assert(status != -1);
}
static void disableRawMode() {
static void disable_tty_raw() {
tcsetattr(0, TCSAFLUSH, &orig_termios);
}
static void enableRawMode() {
static void enable_tty_raw() {
int status = tcgetattr(0, &orig_termios);
assert(status >= 0);
atexit(disableRawMode);
atexit(disable_tty_raw);
struct termios raw = orig_termios;
raw.c_iflag &= ~(BRKINT | ICRNL | INPCK | ISTRIP | IXON);
raw.c_lflag &= ~(ECHO | ICANON);
@@ -57,12 +64,7 @@ static void enableRawMode() {
assert(status >= 0);
}
static std::string strerror_str(int err) {
char errbuf[256];
return strerror_r(errno, errbuf, 256);
}
SOCKET open_connection(const std::string &target, std::string &err) {
static SOCKET open_connection(const std::string &target, std::string &err) {
struct addrinfo *addrs = nullptr;
struct addrinfo *goodaddr = nullptr;
struct addrinfo hints;
@@ -80,43 +82,44 @@ SOCKET open_connection(const std::string &target, std::string &err) {
int status = getaddrinfo(host.c_str(), port.c_str(), &hints, &addrs);
if (status != 0) {
err = gai_strerror(status);
goto error;
goto error_general;
}
if (addrs == nullptr) {
err = "no such host found";
goto error;
goto error_general;
}
goodaddr = addrs;
assert(goodaddr->ai_family == AF_INET);
assert(goodaddr->ai_socktype == SOCK_STREAM);
assert(goodaddr->ai_protocol == IPPROTO_TCP);
sock = socket(goodaddr->ai_family, goodaddr->ai_socktype, goodaddr->ai_protocol);
assert(sock > 0);
if (sock <= 0) goto error_errno;
set_nonblocking(sock);
status = connect(sock, goodaddr->ai_addr, goodaddr->ai_addrlen);
if ((status != 0) && (errno != EINPROGRESS)) {
goto error_errno;
}
if ((status != 0) && (errno != EINPROGRESS)) goto error_errno;
freeaddrinfo(addrs);
return sock;
error_errno:
err = strerror_str(errno);
error:
error_general:
if (sock != INVALID_SOCKET) close(sock);
if (addrs != nullptr) freeaddrinfo(addrs);
return INVALID_SOCKET;
}
SOCKET listen_on_port(int port, std::string &err) {
int status;
static SOCKET listen_on_port(int port, std::string &err) {
int status, enable;
err.clear();
SOCKET sock = socket(AF_INET, SOCK_STREAM, 0);
assert(sock > 0);
int enable = 1;
SOCKET sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock <= 0) goto error_errno;
enable = 1;
status = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
assert(status == 0);
if (status != 0) goto error_errno;
struct sockaddr_in server;
server.sin_family = AF_INET;
@@ -124,14 +127,21 @@ SOCKET listen_on_port(int port, std::string &err) {
server.sin_port = htons(port);
status = bind(sock, (struct sockaddr *)&server, sizeof(server));
assert(status == 0);
if (status != 0) goto error_errno;
status = listen(sock, 10);
assert(status == 0);
if (status != 0) goto error_errno;
set_nonblocking(sock);
return sock;
error_errno:
err = strerror_str(errno);
if (sock >= 0) close(sock);
return INVALID_SOCKET;
}
SOCKET accept_on_socket(SOCKET listen_socket, std::string &err) {
static SOCKET accept_on_socket(SOCKET listen_socket, std::string &err) {
err.clear();
SOCKET chsock = accept(listen_socket, nullptr, nullptr);
if (chsock >= 0) {
@@ -145,60 +155,66 @@ SOCKET accept_on_socket(SOCKET listen_socket, std::string &err) {
}
}
SSL_CTX *new_ssl_context(bool server_cert, bool root_certs, const std::string &require_cert) {
SSL_CTX *ctx = SSL_CTX_new(TLS_method());
SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE);
SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, nullptr);
// server_cert is not implemented yet.
if (root_certs) {
SSL_CTX_set_default_verify_paths(ctx);
SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, NULL);
// the return values for socket_send and socket_recv are:
//
// positive: sent or received bytes successfully
// zero: would block
// negative: channel closed, possibly cleanly or possibly with error
//
static int socket_send(SOCKET socket, const char *bytes, int nbytes, std::string &err) {
err.clear();
int wbytes = send(socket, bytes, nbytes, 0);
if (wbytes < 0) {
if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) {
return 0;
} else {
err = strerror_str(errno);
return -1;
}
// require_cert is not implemented yet.
return ctx;
}
std::string err_print_errors_str() {
BIO *bio = BIO_new(BIO_s_mem());
ERR_print_errors(bio);
char *buf;
size_t len = BIO_get_mem_data(bio, &buf);
std::string ret(buf, len);
BIO_free(bio);
return ret;
}
#include <openssl/bio.h>
#include <openssl/x509.h>
#include <openssl/pem.h>
int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str) {
BIO *bio = BIO_new(BIO_s_mem());
BIO_puts(bio, str);
X509 *certificate = PEM_read_bio_X509(bio, NULL, NULL, NULL);
BIO_free(bio);
int status = SSL_CTX_use_certificate(ctx, certificate);
X509_free(certificate);
return status;
}
int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str) {
BIO *bio = BIO_new(BIO_s_mem());
BIO_puts(bio, str);
EVP_PKEY *pkey = PEM_read_bio_PrivateKey(bio, NULL, NULL, NULL);
BIO_free(bio);
int status = SSL_CTX_use_PrivateKey(ctx, pkey);
EVP_PKEY_free(pkey);
return status;
}
static void print_error_and_exit(const std::string &str) {
if (!str.empty()) {
std::cerr << "error: " << str << std::endl;
exit(1);
} else {
return wbytes;
}
}
static int socket_recv(SOCKET socket, char *bytes, int nbytes, std::string &err) {
err.clear();
int nrecv = recv(socket, bytes, nbytes, 0);
if (nrecv < 0) {
if ((errno == EWOULDBLOCK) || (errno == EAGAIN)) {
err = strerror_str(errno);
return -1;
} else {
return 0;
}
} else if (nrecv == 0) {
return -1;
} else {
return nrecv;
}
}
static int socket_close(SOCKET socket) {
return close(socket);
}
static int console_write(const char *bytes, int nbytes) {
return write(1, bytes, nbytes);
}
static int console_read(char *bytes, int nbytes) {
return read(0, bytes, nbytes);
}
// The last element in the vector is supposed to be
// for polling stdio. But on windows, you can't poll
// stdio, so on windows, we remove the last element from
// the vector and we reduce mstimeout instead.
static void fill_stdio_pollfd(PollVector &pollvec, int &mstimeout, bool read_console_recently) {
struct pollfd &stdiopoll = pollvec.back();
stdiopoll.fd = 0;
stdiopoll.events = POLLIN;
}
class MonoClock {
private:
struct timespec base_;
@@ -217,442 +233,4 @@ public:
}
};
static MonoClock monoclock;
namespace util {
double profiling_clock() {
return monoclock.get();
}
}
class Driver {
public:
enum ChanState {
CHAN_INACTIVE,
CHAN_PLAINTEXT,
CHAN_SSL_CONNECTING,
CHAN_SSL_ACCEPTING,
CHAN_SSL_READWRITE,
};
struct ChanInfo {
int chid;
SOCKET socket;
SSL *ssl;
ChanState state;
int nbytes;
const char *bytes;
bool released;
bool just_released;
bool ready_now;
bool ready_on_pollin;
bool ready_on_pollout;
bool ready_on_outgoing;
int last_write_nbytes;
};
DrivenEngine *driven_;
std::vector<ChanInfo> chans_;
std::map<int, SOCKET> listen_sockets_;
std::unique_ptr<char[]> chbuf;
SSL_CTX *ssl_ctx_with_root_certs_;
SSL_CTX *ssl_ctx_with_server_certs_;
SSL_CTX *ssl_ctx_with_no_certs_;
void handle_listen_ports() {
std::set<int> listenports;
driven_->drv_get_listen_ports(listenports);
for (int port : listenports) {
if (listen_sockets_.find(port) == listen_sockets_.end()) {
std::string err;
SOCKET sock = listen_on_port(port, err);
print_error_and_exit(err);
assert(sock != INVALID_SOCKET);
listen_sockets_[port] = sock;
}
}
}
void handle_lua_source() {
if (driven_->drv_get_rescan_lua_source()) {
driven_->drv_set_lua_source(util::read_lua_source("lua"));
}
}
void close_channel(ChanInfo &chan, const std::string &err) {
// std::cerr << "Closing channel " << chan.chid << std::endl;
assert(chan.state != CHAN_INACTIVE);
// Close the SSL channel.
if (chan.ssl != nullptr) {
SSL_free(chan.ssl);
chan.ssl = nullptr;
}
// Close the socket.
assert(chan.socket != INVALID_SOCKET);
assert(close(chan.socket) == 0);
chan.socket = INVALID_SOCKET;
// Close everything else.
driven_->drv_notify_close(chan.chid, err);
chan.state = CHAN_INACTIVE;
chan.chid = -1;
chan.nbytes = 0;
chan.bytes = 0;
chan.released = false;
chan.just_released = false;
chan.ready_now = false;
chan.ready_on_pollin = false;
chan.ready_on_pollout = false;
chan.ready_on_outgoing = false;
chan.last_write_nbytes = 0;
}
void cleanup_channels() {
for (int i = 0; i < int(chans_.size()); ) {
if (chans_[i].state == CHAN_INACTIVE) {
chans_[i] = chans_.back();
chans_.pop_back();
} else {
i += 1;
}
}
}
void handle_console_output() {
while (true) {
int nbytes; const char *bytes;
driven_->drv_peek_outgoing(0, &nbytes, &bytes);
if (nbytes == 0) break;
int nwrote = write(1, bytes, nbytes);
assert(nwrote > 0);
driven_->drv_sent_outgoing(0, nwrote);
}
}
void handle_console_input() {
char buffer[256];
while (true) {
int nread = read(0, buffer, 256);
if (nread == 0) break;
assert(nread > 0);
driven_->drv_recv_incoming(0, nread, buffer);
}
}
void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) {
ChanInfo newchan;
newchan.chid = chid;
newchan.socket = sock;
newchan.ssl = SSL_new(ctx);
newchan.state = state;
newchan.nbytes = 0;
newchan.bytes = 0;
newchan.released = false;
newchan.just_released = false;
newchan.ready_now = false;
newchan.ready_on_pollin = false;
newchan.ready_on_pollout = true;
newchan.ready_on_outgoing = false;
newchan.last_write_nbytes = 0;
SSL_set_fd(newchan.ssl, newchan.socket);
// SSL_set_msg_callback(newchan.ssl, SSL_trace);
// SSL_set_msg_callback_arg(newchan.ssl, BIO_new_fp(stderr,0));
chans_.push_back(newchan);
}
void handle_new_outgoing_sockets() {
std::set<int> chans;
driven_->drv_get_new_outgoing(chans);
for (int chid : chans) {
std::string err;
SOCKET sock = open_connection(driven_->drv_get_target(chid), err);
if (sock == INVALID_SOCKET) {
driven_->drv_notify_close(chid, err);
} else {
//std::cerr << "Opening channel " << chid << std::endl;
make_channel(sock, chid, ssl_ctx_with_no_certs_, CHAN_SSL_CONNECTING);
}
}
}
void accept_connections(int port, SOCKET sock) {
std::string err;
SOCKET socket = accept_on_socket(sock, err);
print_error_and_exit(err);
if (socket != INVALID_SOCKET) {
int chid = driven_->drv_notify_accept(port);
// std::cerr << "Accepted channel " << chid << std::endl;
make_channel(socket, chid, ssl_ctx_with_server_certs_, CHAN_SSL_ACCEPTING);
}
}
void advance_plaintext(ChanInfo &chan) {
// If the channel has no outgoing bytes and has been released,
// just close it.
if (chan.released) {
close_channel(chan, "");
return;
}
// Try to write plaintext to the channel.
int nbytes; const char *bytes;
driven_->drv_peek_outgoing(chan.chid, &nbytes, &bytes);
if (nbytes > 0) {
int sbytes = nbytes;
if (sbytes > 65536) sbytes = 65536;
int wbytes = send(chan.socket, bytes, sbytes, 0);
// std::cerr << "send.bytes="<< wbytes << ".errno=" << errno << " ";
if (wbytes < 0) {
if ((errno != EWOULDBLOCK) && (errno != EAGAIN)) {
close_channel(chan, "send failure");
return;
}
} else {
driven_->drv_sent_outgoing(chan.chid, wbytes);
}
}
// Try to read plaintext from the channel.
// Someday, find a way to avoid this copy.
int nrecv = recv(chan.socket, chbuf.get(), 65536, 0);
// std::cerr << "recv.bytes="<< nrecv << ".errno=" << errno << " ";
if (nrecv < 0) {
if ((errno != EWOULDBLOCK) && (errno != EAGAIN)) {
close_channel(chan, "recv failure");
return;
}
} else if (nrecv == 0) {
close_channel(chan, "");
return;
} else {
driven_->drv_recv_incoming(chan.chid, nrecv, chbuf.get());
}
// Update the ready-flags for next time.
chan.ready_on_outgoing = true;
chan.ready_on_pollin = true;
}
void process_ssl_error(ChanInfo &chan, int retval) {
int error = SSL_get_error(chan.ssl, retval);
// std::cerr << "SSL error code = " << error << " ";
if (error == SSL_ERROR_WANT_READ) {
chan.ready_on_pollin = true;
} else if (error == SSL_ERROR_WANT_WRITE) {
chan.ready_on_pollout = true;
} else {
close_channel(chan, err_print_errors_str());
}
}
void advance_ssl_connecting(ChanInfo &chan) {
// std::cerr << "In advance_ssl_connecting" << std::endl;
int retval = SSL_connect(chan.ssl);
if (retval == 1) {
// Connection successful.
chan.state = CHAN_SSL_READWRITE;
chan.ready_now = true;
} else {
// std::cerr << "ssl_connect_error";
process_ssl_error(chan, retval);
}
}
void advance_ssl_accepting(ChanInfo &chan) {
// std::cerr << "In advance_ssl_accepting" << std::endl;
int retval = SSL_accept(chan.ssl);
if (retval == 1) {
// Connection successful.
chan.state = CHAN_SSL_READWRITE;
chan.ready_now = true;
} else {
process_ssl_error(chan, retval);
}
}
void advance_ssl_readwrite(ChanInfo &chan) {
// std::cerr << "In advance_ssl_readwrite" << std::endl;
// Try to read data.
int read_result = SSL_read(chan.ssl, chbuf.get(), 65536);
if (read_result > 0) {
driven_->drv_recv_incoming(chan.chid, read_result, chbuf.get());
chan.ready_now = true;
} else {
process_ssl_error(chan, read_result);
if (chan.state == CHAN_INACTIVE) return;
}
// Try to write data.
int wbytes;
if (chan.last_write_nbytes > 0) {
wbytes = chan.last_write_nbytes;
assert(wbytes < chan.nbytes);
} else {
wbytes = chan.nbytes;
if (wbytes > 65536) wbytes = 65536;
}
if (wbytes > 0) {
int write_result = SSL_write(chan.ssl, chan.bytes, wbytes);
if (write_result > 0) {
driven_->drv_sent_outgoing(chan.chid, write_result);
chan.last_write_nbytes = 0;
chan.ready_on_outgoing = true;
} else {
chan.last_write_nbytes = wbytes;
process_ssl_error(chan, write_result);
if (chan.state == CHAN_INACTIVE) return;
}
} else {
chan.ready_on_outgoing = true;
}
// std::cerr << "rpi=" << chan.ready_on_pollin << ".rpo=" << chan.ready_on_pollout << ".rn=" << chan.ready_now << ".rog=" << chan.ready_on_outgoing << " ";
}
void advance_channel(ChanInfo &chan) {
switch(chan.state) {
case CHAN_PLAINTEXT:
advance_plaintext(chan);
break;
case CHAN_SSL_CONNECTING:
advance_ssl_connecting(chan);
break;
case CHAN_SSL_ACCEPTING:
advance_ssl_accepting(chan);
break;
case CHAN_SSL_READWRITE:
advance_ssl_readwrite(chan);
break;
default:
assert(false);
break;
}
}
void handle_socket_input_output() {
int mstimeout = 1000;
// Peek output buffers and determine channel release flags.
for (ChanInfo &chan : chans_) {
driven_->drv_peek_outgoing(chan.chid, &chan.nbytes, &chan.bytes);
chan.just_released = false;
if ((chan.nbytes == 0)&&(!chan.released)) {
chan.released = driven_->drv_get_channel_released(chan.chid);
chan.just_released = chan.released;
}
}
// Construct the pollfd vector.
std::vector<struct pollfd> pollvec;
pollvec.resize(listen_sockets_.size() + chans_.size() + 1);
int index = 0;
for (const auto &p : listen_sockets_) {
struct pollfd &pfd = pollvec[index++];
pfd.fd = p.second;
pfd.events = POLLIN;
}
for (const ChanInfo &chan : chans_) {
struct pollfd &pfd = pollvec[index++];
assert(chan.socket != INVALID_SOCKET);
pfd.fd = chan.socket;
pfd.events = POLLERR;
if (chan.ready_now) mstimeout = 0;
if (chan.just_released) mstimeout = 0;
if (chan.ready_on_pollin) pfd.events |= POLLIN;
if (chan.ready_on_pollout) pfd.events |= POLLOUT;
if (chan.ready_on_outgoing && (chan.nbytes > 0)) pfd.events |= POLLOUT;
// std::cerr << "evt=" << pfd.events << ".nb=" << chan.nbytes << " ";
}
struct pollfd &stdiopoll = pollvec[index++];
stdiopoll.fd = 0;
stdiopoll.events = POLLIN;
// Do the poll.
int status = poll(&pollvec[0], pollvec.size(), mstimeout);
assert(status >= 0);
// Check listening sockets.
index = 0;
for (auto &p : listen_sockets_) {
struct pollfd &pfd = pollvec[index++];
if (pfd.revents & (POLLIN | POLLERR)) {
accept_connections(p.first, p.second);
}
}
// Advance channels where possible.
for (ChanInfo &chan : chans_) {
struct pollfd &pfd = pollvec[index++];
bool pollin = ((pfd.revents & POLLIN) != 0);
bool pollout = ((pfd.revents & POLLOUT) != 0);
bool pollerr = ((pfd.revents & POLLERR) != 0);
if (chan.ready_now || pollerr || chan.just_released ||
(chan.ready_on_pollin && pollin) ||
(chan.ready_on_pollout && pollout) ||
(chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) {
chan.ready_now = false;
chan.ready_on_pollin = false;
chan.ready_on_pollout = false;
chan.ready_on_outgoing = false;
advance_channel(chan);
chan.nbytes = false;
chan.bytes = 0;
}
}
// Delete any newly-inactive channels
cleanup_channels();
}
void drive(DrivenEngine *de, int argc, char *argv[]) {
SSL_load_error_strings();
ERR_load_crypto_strings();
enableRawMode();
driven_ = de;
chbuf.reset(new char[65536]);
ssl_ctx_with_root_certs_ = new_ssl_context(false, true, "");
ssl_ctx_with_server_certs_ = new_ssl_context(true, false, "");
ssl_ctx_with_no_certs_ = new_ssl_context(false, false, "");
if (ssl_ctx_use_certificate_str(ssl_ctx_with_server_certs_, dummycert::certificate) <= 0) {
ERR_print_errors_fp(stderr);
exit(1);
}
if (ssl_ctx_use_privatekey_str(ssl_ctx_with_server_certs_, dummycert::privatekey) <= 0 ) {
ERR_print_errors_fp(stderr);
exit(1);
}
DrivenEngine::set(de);
driven_->drv_set_lua_source(util::read_lua_source("lua"));
driven_->drv_invoke_event_init(argc, argv);
handle_listen_ports();
while (!de->drv_get_stop_driver()) {
handle_lua_source();
handle_console_output();
handle_new_outgoing_sockets();
handle_socket_input_output();
handle_console_input();
handle_console_output();
de->drv_invoke_event_update(monoclock.get());
}
for (ChanInfo &chan : chans_) {
close_channel(chan, "");
}
SSL_CTX_free(ssl_ctx_with_no_certs_);
SSL_CTX_free(ssl_ctx_with_root_certs_);
SSL_CTX_free(ssl_ctx_with_server_certs_);
DrivenEngine::set(nullptr);
}
};
void driver_drive(DrivenEngine *de, int argc, char *argv[]) {
Driver driver;
driver.drive(de, argc, argv);
}
#include "driver-common.cpp"