fix accidental reformatting of driver-common.cpp

This commit is contained in:
2022-03-18 18:16:21 -04:00
parent 8e900e37be
commit 1e45aa425b

View File

@@ -4,42 +4,33 @@
static MonoClock monoclock; static MonoClock monoclock;
namespace util namespace util {
{
double profiling_clock()
{
return monoclock.get();
}
}
static void if_error_print_and_exit(const std::string &str) double profiling_clock() { return monoclock.get(); }
{
if (!str.empty()) } // namespace util
{
std::cerr << std::endl static void if_error_print_and_exit(const std::string &str) {
<< "error: " << str << std::endl; if (!str.empty()) {
std::cerr << std::endl << "error: " << str << std::endl;
exit(1); exit(1);
} }
} }
static std::string_view read_file(const char *fn, char *buf, int bufsize, std::string &err) static std::string_view read_file(const char *fn, char *buf, int bufsize, std::string &err) {
{
FILE *f = fopen(fn, "r"); FILE *f = fopen(fn, "r");
if (f == 0) if (f == 0) {
{
err = std::string("cannot read file") + fn; err = std::string("cannot read file") + fn;
buf[0] = 0; buf[0] = 0;
return std::string_view(buf, 0); return std::string_view(buf, 0);
} }
int nread = fread(buf, 1, bufsize, f); int nread = fread(buf, 1, bufsize, f);
if (nread < 0) if (nread < 0) {
{
err = std::string("cannot read file: ") + fn; err = std::string("cannot read file: ") + fn;
buf[0] = 0; buf[0] = 0;
return std::string_view(buf, 0); return std::string_view(buf, 0);
} }
if (nread == bufsize) if (nread == bufsize) {
{
err = std::string("file too large: ") + fn; err = std::string("file too large: ") + fn;
buf[0] = 0; buf[0] = 0;
return std::string_view(buf, 0); return std::string_view(buf, 0);
@@ -48,66 +39,50 @@ static std::string_view read_file(const char *fn, char *buf, int bufsize, std::s
return std::string_view(buf, nread); return std::string_view(buf, nread);
} }
struct SSL_CTX_Deleter struct SSL_CTX_Deleter {
{ void operator()(SSL_CTX *ctx) { SSL_CTX_free(ctx); }
void operator()(SSL_CTX *ctx)
{
SSL_CTX_free(ctx);
}
}; };
using UniqueSSLCTX = std::unique_ptr<SSL_CTX, SSL_CTX_Deleter>; using UniqueSSLCTX = std::unique_ptr<SSL_CTX, SSL_CTX_Deleter>;
static std::string ssl_errors_string(bool lastonly = true) static std::string ssl_errors_string(bool lastonly = true) {
{
std::string err; std::string err;
const char *file, *data, *func; const char *file, *data, *func;
int line, flags; int line, flags;
while (true) while (true) {
{ unsigned long code =
unsigned long code = ERR_get_error_all(&file, &line, &func, &data, &flags); ERR_get_error_all(&file, &line, &func, &data, &flags);
if (code == 0) if (code == 0) break;
break;
std::string reason; std::string reason;
if (ERR_SYSTEM_ERROR(code)) if (ERR_SYSTEM_ERROR(code)) {
{
reason = strerror_str(ERR_GET_REASON(code)); reason = strerror_str(ERR_GET_REASON(code));
} } else {
else
{
const char *rc = ERR_reason_error_string(code); const char *rc = ERR_reason_error_string(code);
reason = (rc == nullptr) ? "unknown" : rc; reason = (rc == nullptr) ? "unknown" : rc;
} }
if (err.empty() || lastonly) if (err.empty() || lastonly) {
{
err = reason; err = reason;
} } else {
else
{
err = err + ", " + reason; err = err + ", " + reason;
} }
if (data != nullptr) if (data != nullptr) {
{
err = err + " " + data; err = err + " " + data;
} }
} }
return err; return err;
} }
void assert_ssl_errors_empty() void assert_ssl_errors_empty() {
{
int code = ERR_peek_error(); int code = ERR_peek_error();
if (code != 0) if (code != 0) {
{
std::cerr << "SSL should not have errors at this point." << std::endl; std::cerr << "SSL should not have errors at this point." << std::endl;
ERR_print_errors_fp(stderr); ERR_print_errors_fp(stderr);
exit(1); exit(1);
} }
} }
static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str) static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str) {
{
BIO *bio = BIO_new(BIO_s_mem()); BIO *bio = BIO_new(BIO_s_mem());
BIO_puts(bio, str); BIO_puts(bio, str);
X509 *certificate = PEM_read_bio_X509(bio, NULL, NULL, NULL); X509 *certificate = PEM_read_bio_X509(bio, NULL, NULL, NULL);
@@ -117,8 +92,7 @@ static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str)
return status; return status;
} }
static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str) static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str) {
{
BIO *bio = BIO_new(BIO_s_mem()); BIO *bio = BIO_new(BIO_s_mem());
BIO_puts(bio, str); BIO_puts(bio, str);
EVP_PKEY *pkey = PEM_read_bio_PrivateKey(bio, NULL, NULL, NULL); EVP_PKEY *pkey = PEM_read_bio_PrivateKey(bio, NULL, NULL, NULL);
@@ -128,33 +102,27 @@ static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str)
return status; return status;
} }
static void ssl_ctx_use_dummycert(SSL_CTX *ctx) static void ssl_ctx_use_dummycert(SSL_CTX *ctx) {
{ if (ssl_ctx_use_certificate_str(ctx, dummycert::certificate) <= 0) {
if (ssl_ctx_use_certificate_str(ctx, dummycert::certificate) <= 0)
{
ERR_print_errors_fp(stderr); ERR_print_errors_fp(stderr);
exit(1); exit(1);
} }
if (ssl_ctx_use_privatekey_str(ctx, dummycert::privatekey) <= 0) if (ssl_ctx_use_privatekey_str(ctx, dummycert::privatekey) <= 0) {
{
ERR_print_errors_fp(stderr); ERR_print_errors_fp(stderr);
exit(1); exit(1);
} }
} }
class Driver class Driver {
{ public:
public: enum ChanState {
enum ChanState
{
CHAN_INACTIVE, CHAN_INACTIVE,
CHAN_PLAINTEXT, CHAN_PLAINTEXT,
CHAN_SSL_CONNECTING, CHAN_SSL_CONNECTING,
CHAN_SSL_ACCEPTING, CHAN_SSL_ACCEPTING,
CHAN_SSL_READWRITE, CHAN_SSL_READWRITE,
}; };
struct ChanInfo struct ChanInfo {
{
int chid; int chid;
SOCKET socket; SOCKET socket;
SSL *ssl; SSL *ssl;
@@ -182,13 +150,10 @@ public:
UniqueSSLCTX ssl_client_secure_ctx_; UniqueSSLCTX ssl_client_secure_ctx_;
UniqueSSLCTX ssl_client_insecure_ctx_; UniqueSSLCTX ssl_client_insecure_ctx_;
void handle_listen_ports() void handle_listen_ports() {
{
const auto &listenports = recorder_.drv_get_listen_ports(); const auto &listenports = recorder_.drv_get_listen_ports();
for (int port : listenports) for (int port : listenports) {
{ if (listen_sockets_.find(port) == listen_sockets_.end()) {
if (listen_sockets_.find(port) == listen_sockets_.end())
{
std::string err; std::string err;
SOCKET sock = listen_on_port(port, err); SOCKET sock = listen_on_port(port, err);
if_error_print_and_exit(err); if_error_print_and_exit(err);
@@ -198,32 +163,29 @@ public:
} }
} }
void handle_lua_source() void handle_lua_source() {
{ if (recorder_.drv_get_rescan_lua_source()) {
if (recorder_.drv_get_rescan_lua_source())
{
std::string err; std::string err;
std::string_view ctrl = read_file("lua/control.lst", chbuf_.get(), CHBUF_SIZE, err); std::string_view ctrl =
read_file("lua/control.lst", chbuf_.get(), CHBUF_SIZE, err);
if_error_print_and_exit(err); if_error_print_and_exit(err);
std::vector<std::string> names = drv::parse_control_lst(ctrl); std::vector<std::string> names = drv::parse_control_lst(ctrl);
recorder_.drv_clear_lua_source(); recorder_.drv_clear_lua_source();
for (const std::string &str : names) for (const std::string &str : names) {
{
std::string lfn = std::string("lua/") + str; std::string lfn = std::string("lua/") + str;
std::string_view data = read_file(lfn.c_str(), chbuf_.get(), CHBUF_SIZE, err); std::string_view data =
read_file(lfn.c_str(), chbuf_.get(), CHBUF_SIZE, err);
if_error_print_and_exit(err); if_error_print_and_exit(err);
recorder_.drv_add_lua_source(str, data); recorder_.drv_add_lua_source(str, data);
} }
} }
} }
void close_channel(ChanInfo &chan, std::string_view err) void close_channel(ChanInfo &chan, std::string_view err) {
{
// std::cerr << "Closing channel " << chan.chid << std::endl; // std::cerr << "Closing channel " << chan.chid << std::endl;
assert(chan.state != CHAN_INACTIVE); assert(chan.state != CHAN_INACTIVE);
// Close and release the SSL channel. // Close and release the SSL channel.
if (chan.ssl != nullptr) if (chan.ssl != nullptr) {
{
SSL_free(chan.ssl); SSL_free(chan.ssl);
chan.ssl = nullptr; chan.ssl = nullptr;
} }
@@ -246,52 +208,39 @@ public:
chan.last_write_nbytes = 0; chan.last_write_nbytes = 0;
} }
void cleanup_channels() void cleanup_channels() {
{ for (int i = 0; i < int(chans_.size());) {
for (int i = 0; i < int(chans_.size());) if (chans_[i].state == CHAN_INACTIVE) {
{
if (chans_[i].state == CHAN_INACTIVE)
{
chans_[i] = chans_.back(); chans_[i] = chans_.back();
chans_.pop_back(); chans_.pop_back();
} } else {
else
{
i += 1; i += 1;
} }
} }
} }
void handle_console_output() void handle_console_output() {
{ while (true) {
while (true)
{
std::string_view s = recorder_.drv_peek_outgoing(0); std::string_view s = recorder_.drv_peek_outgoing(0);
if (s.size() == 0) if (s.size() == 0) break;
break;
int nwrote = console_write(s.data(), s.size()); int nwrote = console_write(s.data(), s.size());
if (nwrote <= 0) if (nwrote <= 0) break;
break;
recorder_.drv_sent_outgoing(0, nwrote); recorder_.drv_sent_outgoing(0, nwrote);
} }
} }
void handle_console_input() void handle_console_input() {
{
char buffer[256]; char buffer[256];
read_console_recently_ = false; read_console_recently_ = false;
while (true) while (true) {
{
int nread = console_read(buffer, 256); int nread = console_read(buffer, 256);
if (nread <= 0) if (nread <= 0) break;
break;
read_console_recently_ = true; read_console_recently_ = true;
recorder_.drv_recv_incoming(0, std::string_view(buffer, nread)); recorder_.drv_recv_incoming(0, std::string_view(buffer, nread));
} }
} }
void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) {
{
ChanInfo newchan; ChanInfo newchan;
newchan.chid = chid; newchan.chid = chid;
newchan.socket = sock; newchan.socket = sock;
@@ -312,16 +261,15 @@ public:
chans_.push_back(newchan); chans_.push_back(newchan);
} }
void handle_new_outgoing_sockets() void handle_new_outgoing_sockets() {
{
const auto &chans = recorder_.drv_get_new_outgoing(); const auto &chans = recorder_.drv_get_new_outgoing();
for (int chid : chans) for (int chid : chans) {
{
std::string err, cert, host, port; std::string err, cert, host, port;
std::string target(recorder_.drv_get_target(chid)); std::string target(recorder_.drv_get_target(chid));
drv::split_target(target, cert, host, port); drv::split_target(target, cert, host, port);
if (cert.empty() || host.empty() || port.empty()) { if (cert.empty() || host.empty() || port.empty()) {
recorder_.drv_notify_close(chid, std::string("invalid target: ") + target); recorder_.drv_notify_close(
chid, std::string("invalid target: ") + target);
continue; continue;
} }
SSL_CTX *ctx = nullptr; SSL_CTX *ctx = nullptr;
@@ -330,63 +278,54 @@ public:
} else if (cert == "nocert") { } else if (cert == "nocert") {
ctx = ssl_client_insecure_ctx_.get(); ctx = ssl_client_insecure_ctx_.get();
} else { } else {
recorder_.drv_notify_close(chid, std::string("invalid cert rule: ") + target); recorder_.drv_notify_close(
chid, std::string("invalid cert rule: ") + target);
continue; continue;
} }
SOCKET sock = open_connection(host.c_str(), port.c_str(), err); SOCKET sock = open_connection(host.c_str(), port.c_str(), err);
if (sock == INVALID_SOCKET) if (sock == INVALID_SOCKET) {
{
recorder_.drv_notify_close(chid, err); recorder_.drv_notify_close(chid, err);
continue; continue;
} }
// std::cerr << "Opening channel " << chid << std::endl; // std::cerr << "Opening channel " << chid << std::endl;
make_channel(sock, chid, ctx, CHAN_SSL_CONNECTING); make_channel(sock, chid, ctx, CHAN_SSL_CONNECTING);
} }
if (!chans.empty()) if (!chans.empty()) {
{
recorder_.drv_clear_new_outgoing(); recorder_.drv_clear_new_outgoing();
} }
} }
void accept_connection(int port, SOCKET sock) void accept_connection(int port, SOCKET sock) {
{
std::string err; std::string err;
SOCKET socket = accept_on_socket(sock, err); SOCKET socket = accept_on_socket(sock, err);
if_error_print_and_exit(err); if_error_print_and_exit(err);
if (socket != INVALID_SOCKET) if (socket != INVALID_SOCKET) {
{
int chid = recorder_.drv_notify_accept(port); int chid = recorder_.drv_notify_accept(port);
// std::cerr << "Accepted channel " << chid << std::endl; // std::cerr << "Accepted channel " << chid << std::endl;
make_channel(socket, chid, ssl_server_ctx_.get(), CHAN_SSL_ACCEPTING); make_channel(socket, chid, ssl_server_ctx_.get(),
CHAN_SSL_ACCEPTING);
} }
} }
void advance_plaintext(ChanInfo &chan) void advance_plaintext(ChanInfo &chan) {
{
std::string err; std::string err;
// If the channel has no outgoing bytes and has been released, // If the channel has no outgoing bytes and has been released,
// just close it. // just close it.
if (chan.released) if (chan.released) {
{
close_channel(chan, ""); close_channel(chan, "");
return; return;
} }
// Try to write plaintext to the channel. // Try to write plaintext to the channel.
std::string_view s = recorder_.drv_peek_outgoing(chan.chid); std::string_view s = recorder_.drv_peek_outgoing(chan.chid);
if (s.size() > 0) if (s.size() > 0) {
{
int sbytes = s.size(); int sbytes = s.size();
if (sbytes > 65536) if (sbytes > 65536) sbytes = 65536;
sbytes = 65536;
int wbytes = socket_send(chan.socket, s.data(), sbytes, err); int wbytes = socket_send(chan.socket, s.data(), sbytes, err);
if (wbytes < 0) if (wbytes < 0) {
{
close_channel(chan, err); close_channel(chan, err);
} } else {
else
{
recorder_.drv_sent_outgoing(chan.chid, wbytes); recorder_.drv_sent_outgoing(chan.chid, wbytes);
} }
} }
@@ -394,13 +333,11 @@ public:
// Try to read plaintext from the channel. // Try to read plaintext from the channel.
// Someday, find a way to avoid this copy. // Someday, find a way to avoid this copy.
int nrecv = socket_recv(chan.socket, chbuf_.get(), 65536, err); int nrecv = socket_recv(chan.socket, chbuf_.get(), 65536, err);
if (nrecv < 0) if (nrecv < 0) {
{
close_channel(chan, err); close_channel(chan, err);
} } else {
else recorder_.drv_recv_incoming(chan.chid,
{ std::string_view(chbuf_.get(), nrecv));
recorder_.drv_recv_incoming(chan.chid, std::string_view(chbuf_.get(), nrecv));
} }
// Update the ready-flags for next time. // Update the ready-flags for next time.
@@ -408,149 +345,117 @@ public:
chan.ready_on_pollin = true; chan.ready_on_pollin = true;
} }
void process_ssl_error(ChanInfo &chan, int retval) void process_ssl_error(ChanInfo &chan, int retval) {
{
int error = SSL_get_error(chan.ssl, retval); int error = SSL_get_error(chan.ssl, retval);
// std::cerr << "SSL error code = " << error << " "; // std::cerr << "SSL error code = " << error << " ";
if (error == SSL_ERROR_WANT_READ) if (error == SSL_ERROR_WANT_READ) {
{
chan.ready_on_pollin = true; chan.ready_on_pollin = true;
} } else if (error == SSL_ERROR_WANT_WRITE) {
else if (error == SSL_ERROR_WANT_WRITE)
{
chan.ready_on_pollout = true; chan.ready_on_pollout = true;
} } else {
else
{
close_channel(chan, ssl_errors_string()); close_channel(chan, ssl_errors_string());
} }
} }
void advance_ssl_connecting(ChanInfo &chan) void advance_ssl_connecting(ChanInfo &chan) {
{
// std::cerr << "In advance_ssl_connecting" << std::endl; // std::cerr << "In advance_ssl_connecting" << std::endl;
int retval = SSL_connect(chan.ssl); int retval = SSL_connect(chan.ssl);
if (retval == 1) if (retval == 1) {
{
// Connection successful. // Connection successful.
chan.state = CHAN_SSL_READWRITE; chan.state = CHAN_SSL_READWRITE;
chan.ready_now = true; chan.ready_now = true;
} } else {
else
{
// std::cerr << "ssl_connect_error"; // std::cerr << "ssl_connect_error";
process_ssl_error(chan, retval); process_ssl_error(chan, retval);
} }
} }
void advance_ssl_accepting(ChanInfo &chan) void advance_ssl_accepting(ChanInfo &chan) {
{
// std::cerr << "In advance_ssl_accepting" << std::endl; // std::cerr << "In advance_ssl_accepting" << std::endl;
int retval = SSL_accept(chan.ssl); int retval = SSL_accept(chan.ssl);
if (retval == 1) if (retval == 1) {
{
// Connection successful. // Connection successful.
chan.state = CHAN_SSL_READWRITE; chan.state = CHAN_SSL_READWRITE;
chan.ready_now = true; chan.ready_now = true;
} } else {
else
{
process_ssl_error(chan, retval); process_ssl_error(chan, retval);
} }
} }
void advance_ssl_readwrite(ChanInfo &chan) void advance_ssl_readwrite(ChanInfo &chan) {
{
// std::cerr << "In advance_ssl_readwrite" << std::endl; // std::cerr << "In advance_ssl_readwrite" << std::endl;
// Try to read data. // Try to read data.
int read_result = SSL_read(chan.ssl, chbuf_.get(), 65536); int read_result = SSL_read(chan.ssl, chbuf_.get(), 65536);
if (read_result > 0) if (read_result > 0) {
{ recorder_.drv_recv_incoming(
recorder_.drv_recv_incoming(chan.chid, std::string_view(chbuf_.get(), read_result)); chan.chid, std::string_view(chbuf_.get(), read_result));
chan.ready_now = true; chan.ready_now = true;
} } else {
else
{
process_ssl_error(chan, read_result); process_ssl_error(chan, read_result);
if (chan.state == CHAN_INACTIVE) if (chan.state == CHAN_INACTIVE) return;
return;
} }
// Try to write data. // Try to write data.
int wbytes; int wbytes;
if (chan.last_write_nbytes > 0) if (chan.last_write_nbytes > 0) {
{
wbytes = chan.last_write_nbytes; wbytes = chan.last_write_nbytes;
assert(wbytes < chan.nbytes); assert(wbytes < chan.nbytes);
} } else {
else
{
wbytes = chan.nbytes; wbytes = chan.nbytes;
if (wbytes > 65536) if (wbytes > 65536) wbytes = 65536;
wbytes = 65536;
} }
if (wbytes > 0) if (wbytes > 0) {
{
int write_result = SSL_write(chan.ssl, chan.bytes, wbytes); int write_result = SSL_write(chan.ssl, chan.bytes, wbytes);
if (write_result > 0) if (write_result > 0) {
{
recorder_.drv_sent_outgoing(chan.chid, write_result); recorder_.drv_sent_outgoing(chan.chid, write_result);
chan.last_write_nbytes = 0; chan.last_write_nbytes = 0;
chan.ready_on_outgoing = true; chan.ready_on_outgoing = true;
} } else {
else
{
chan.last_write_nbytes = wbytes; chan.last_write_nbytes = wbytes;
process_ssl_error(chan, write_result); process_ssl_error(chan, write_result);
if (chan.state == CHAN_INACTIVE) if (chan.state == CHAN_INACTIVE) return;
return;
} }
} } else {
else
{
chan.ready_on_outgoing = true; chan.ready_on_outgoing = true;
} }
// std::cerr << "rpi=" << chan.ready_on_pollin << ".rpo=" << chan.ready_on_pollout << ".rn=" << chan.ready_now << ".rog=" << chan.ready_on_outgoing << " "; // std::cerr << "rpi=" << chan.ready_on_pollin << ".rpo=" <<
// chan.ready_on_pollout << ".rn=" << chan.ready_now << ".rog=" <<
// chan.ready_on_outgoing << " ";
} }
void advance_channel(ChanInfo &chan) void advance_channel(ChanInfo &chan) {
{
assert_ssl_errors_empty(); assert_ssl_errors_empty();
switch (chan.state) switch (chan.state) {
{ case CHAN_PLAINTEXT:
case CHAN_PLAINTEXT: advance_plaintext(chan);
advance_plaintext(chan); break;
break; case CHAN_SSL_CONNECTING:
case CHAN_SSL_CONNECTING: advance_ssl_connecting(chan);
advance_ssl_connecting(chan); break;
break; case CHAN_SSL_ACCEPTING:
case CHAN_SSL_ACCEPTING: advance_ssl_accepting(chan);
advance_ssl_accepting(chan); break;
break; case CHAN_SSL_READWRITE:
case CHAN_SSL_READWRITE: advance_ssl_readwrite(chan);
advance_ssl_readwrite(chan); break;
break; default:
default: assert(false);
assert(false); break;
break;
} }
assert_ssl_errors_empty(); assert_ssl_errors_empty();
} }
void handle_socket_input_output() void handle_socket_input_output() {
{
std::string err; std::string err;
int mstimeout = read_console_recently_ ? 100 : 1000; int mstimeout = read_console_recently_ ? 100 : 1000;
// Peek output buffers and determine channel release flags. // Peek output buffers and determine channel release flags.
for (ChanInfo &chan : chans_) for (ChanInfo &chan : chans_) {
{
std::string_view s = recorder_.drv_peek_outgoing(chan.chid); std::string_view s = recorder_.drv_peek_outgoing(chan.chid);
chan.nbytes = s.size(); chan.nbytes = s.size();
chan.bytes = s.data(); chan.bytes = s.data();
chan.just_released = false; chan.just_released = false;
if ((chan.nbytes == 0) && (!chan.released)) if ((chan.nbytes == 0) && (!chan.released)) {
{
chan.released = recorder_.drv_get_channel_released(chan.chid); chan.released = recorder_.drv_get_channel_released(chan.chid);
chan.just_released = chan.released; chan.just_released = chan.released;
} }
@@ -558,31 +463,26 @@ public:
// Construct the struct pollfd vector. // Construct the struct pollfd vector.
int pollsize = 0; int pollsize = 0;
for (const auto &p : listen_sockets_) for (const auto &p : listen_sockets_) {
{
struct pollfd &pfd = pollvec_[pollsize++]; struct pollfd &pfd = pollvec_[pollsize++];
pfd.fd = p.second; pfd.fd = p.second;
pfd.events = POLLIN; pfd.events = POLLIN;
pfd.revents = 0; pfd.revents = 0;
} }
for (const ChanInfo &chan : chans_) for (const ChanInfo &chan : chans_) {
{
struct pollfd &pfd = pollvec_[pollsize++]; struct pollfd &pfd = pollvec_[pollsize++];
assert(chan.socket != INVALID_SOCKET); assert(chan.socket != INVALID_SOCKET);
pfd.fd = chan.socket; pfd.fd = chan.socket;
pfd.events = 0; pfd.events = 0;
pfd.revents = 0; pfd.revents = 0;
if (chan.ready_now) if (chan.ready_now) mstimeout = 0;
mstimeout = 0; if (chan.just_released) mstimeout = 0;
if (chan.just_released) if (chan.ready_on_pollin) pfd.events |= POLLIN;
mstimeout = 0; if (chan.ready_on_pollout) pfd.events |= POLLOUT;
if (chan.ready_on_pollin)
pfd.events |= POLLIN;
if (chan.ready_on_pollout)
pfd.events |= POLLOUT;
if (chan.ready_on_outgoing && (chan.nbytes > 0)) if (chan.ready_on_outgoing && (chan.nbytes > 0))
pfd.events |= POLLOUT; pfd.events |= POLLOUT;
// std::cerr << "evt=" << pfd.events << ".nb=" << chan.nbytes << " "; // std::cerr << "evt=" << pfd.events << ".nb=" << chan.nbytes << "
// ";
} }
// Do the poll. // Do the poll.
@@ -591,18 +491,15 @@ public:
// Check listening sockets. // Check listening sockets.
int index = 0; int index = 0;
for (auto &p : listen_sockets_) for (auto &p : listen_sockets_) {
{
struct pollfd &pfd = pollvec_[index++]; struct pollfd &pfd = pollvec_[index++];
if (pfd.revents & (POLLIN | POLLERR)) if (pfd.revents & (POLLIN | POLLERR)) {
{
accept_connection(p.first, p.second); accept_connection(p.first, p.second);
} }
} }
// Advance channels where possible. // Advance channels where possible.
for (ChanInfo &chan : chans_) for (ChanInfo &chan : chans_) {
{
struct pollfd &pfd = pollvec_[index++]; struct pollfd &pfd = pollvec_[index++];
bool pollin = ((pfd.revents & POLLIN) != 0); bool pollin = ((pfd.revents & POLLIN) != 0);
bool pollout = ((pfd.revents & POLLOUT) != 0); bool pollout = ((pfd.revents & POLLOUT) != 0);
@@ -610,8 +507,7 @@ public:
if (chan.ready_now || pollerr || chan.just_released || if (chan.ready_now || pollerr || chan.just_released ||
(chan.ready_on_pollin && pollin) || (chan.ready_on_pollin && pollin) ||
(chan.ready_on_pollout && pollout) || (chan.ready_on_pollout && pollout) ||
(chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) (chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) {
{
chan.ready_now = false; chan.ready_now = false;
chan.ready_on_pollin = false; chan.ready_on_pollin = false;
chan.ready_on_pollout = false; chan.ready_on_pollout = false;
@@ -626,30 +522,24 @@ public:
cleanup_channels(); cleanup_channels();
} }
int replay_logfile(const char *fn, bool verbose) int replay_logfile(const char *fn, bool verbose) {
{
drv::ReplayPlayer player; drv::ReplayPlayer player;
player.open_logfile(fn); player.open_logfile(fn);
if (verbose) if (verbose) {
{
player.enable_stdout(); player.enable_stdout();
} }
while (true) while (true) {
{
drv::ReplayPlayer::Status st = player.step(); drv::ReplayPlayer::Status st = player.step();
if (st != drv::ReplayPlayer::ST_REPLAYING) if (st != drv::ReplayPlayer::ST_REPLAYING) {
{
player.print_status(std::cerr); player.print_status(std::cerr);
return (st == drv::ReplayPlayer::ST_CLEAN_EXIT) ? 0 : 1; return (st == drv::ReplayPlayer::ST_CLEAN_EXIT) ? 0 : 1;
} }
} }
} }
int drive(int argc, char *argv[]) int drive(int argc, char *argv[]) {
{
// Remove the program name from argv. // Remove the program name from argv.
if (argc < 1) if (argc < 1) {
{
DrivenEngine::print_usage(std::cerr, "<unknown>"); DrivenEngine::print_usage(std::cerr, "<unknown>");
exit(1); exit(1);
} }
@@ -659,14 +549,12 @@ public:
// If argv contains "replay <filename>", do a replay, // If argv contains "replay <filename>", do a replay,
// and then skip everything else. // and then skip everything else.
if (argc >= 1) if (argc >= 1) {
{
std::string cmd(argv[0]); std::string cmd(argv[0]);
if ((cmd == "replay") || (cmd == "vreplay")) if ((cmd == "replay") || (cmd == "vreplay")) {
{ if (argc != 2) {
if (argc != 2) std::cerr << "usage: " << program << " replay <filename>"
{ << std::endl;
std::cerr << "usage: " << program << " replay <filename>" << std::endl;
return 1; return 1;
} }
return replay_logfile(argv[1], cmd == "vreplay"); return replay_logfile(argv[1], cmd == "vreplay");
@@ -675,20 +563,17 @@ public:
// If argv contains "record <filename>", start recording, // If argv contains "record <filename>", start recording,
// and remove the "record <filename>" from argv. // and remove the "record <filename>" from argv.
if (argc >= 1) if (argc >= 1) {
{
std::string cmd = argv[0]; std::string cmd = argv[0];
if (cmd == "record") if (cmd == "record") {
{ if (argc < 2) {
if (argc < 2)
{
DrivenEngine::print_usage(std::cerr, program); DrivenEngine::print_usage(std::cerr, program);
return 1; return 1;
} }
bool ok = recorder_.open_logfile(argv[1]); bool ok = recorder_.open_logfile(argv[1]);
if (!ok) if (!ok) {
{ std::cerr << "Could not open logfile: " << argv[1]
std::cerr << "Could not open logfile: " << argv[1] << std::endl; << std::endl;
return 1; return 1;
} }
argc -= 2; argc -= 2;
@@ -697,14 +582,12 @@ public:
} }
// Create the engine. // Create the engine.
if (argc < 1) if (argc < 1) {
{
DrivenEngine::print_usage(std::cerr, program); DrivenEngine::print_usage(std::cerr, program);
return 1; return 1;
} }
bool engine_made = recorder_.create_engine(argv[0]); bool engine_made = recorder_.create_engine(argv[0]);
if (!engine_made) if (!engine_made) {
{
DrivenEngine::print_usage(std::cerr, program); DrivenEngine::print_usage(std::cerr, program);
return 1; return 1;
} }
@@ -723,8 +606,7 @@ public:
recorder_.drv_invoke_event_init(argc, argv); recorder_.drv_invoke_event_init(argc, argv);
handle_listen_ports(); handle_listen_ports();
while (!recorder_.drv_get_stop_driver()) while (!recorder_.drv_get_stop_driver()) {
{
handle_lua_source(); handle_lua_source();
handle_console_output(); handle_console_output();
handle_new_outgoing_sockets(); handle_new_outgoing_sockets();
@@ -734,8 +616,7 @@ public:
recorder_.drv_invoke_event_update(monoclock.get()); recorder_.drv_invoke_event_update(monoclock.get());
} }
for (ChanInfo &chan : chans_) for (ChanInfo &chan : chans_) {
{
close_channel(chan, ""); close_channel(chan, "");
} }