Files
integration/luprex/core/cpp/driver-common.cpp
2022-05-06 18:25:15 -04:00

555 lines
18 KiB
C++

// Size in bytes of the shared scratch buffer (chbuf_) used for file reads
// and for socket / SSL reads.
#define CHBUF_SIZE (256 * 1024)
// Number of entries allocated for the pollfd array passed to socket_poll.
#define POLLVEC_SIZE (DrivenEngine::MAX_CHAN + 1)

// File-wide clock; also passed to the engine's update event by Driver::drive.
static MonoClock monoclock;

namespace util {
// Returns the current reading of the file-wide MonoClock.
double profiling_clock() {
    const double now = monoclock.get();
    return now;
}
} // namespace util
// Fatal-error helper: when `str` is non-empty, prints "error: <str>" to
// stderr and terminates the process with exit status 1. An empty `str`
// means "no error" and the call returns normally.
static void if_error_print_and_exit(const std::string &str) {
    if (str.empty()) return;  // nothing to report
    std::cerr << std::endl << "error: " << str << std::endl;
    exit(1);
}
// Reads the whole of file `fn` into the caller-supplied buffer `buf`, which
// holds `bufsize` bytes, and returns a view of the bytes read.
// On success `err` is set to "". On failure `err` describes the problem,
// buf[0] is set to 0 and an empty view is returned. A file of `bufsize`
// bytes or more is reported as "file too large": a completely filled buffer
// cannot be told apart from a truncated read.
// The returned view aliases `buf` and is only valid until `buf` is reused.
static std::string_view read_file(const char *fn, char *buf, int bufsize, std::string &err) {
    assert(bufsize > 0);
    FILE *f = fopen(fn, "r");
    if (f == nullptr) {
        err = std::string("cannot read file: ") + fn;
        buf[0] = 0;
        return std::string_view(buf, 0);
    }
    // fread reports failure through ferror(), never through a negative count.
    size_t nread = fread(buf, 1, size_t(bufsize), f);
    bool failed = (ferror(f) != 0);
    // Always release the handle, whatever the outcome of the read.
    fclose(f);
    if (failed) {
        err = std::string("cannot read file: ") + fn;
        buf[0] = 0;
        return std::string_view(buf, 0);
    }
    if (nread >= size_t(bufsize)) {
        err = std::string("file too large: ") + fn;
        buf[0] = 0;
        return std::string_view(buf, 0);
    }
    err = "";
    return std::string_view(buf, nread);
}
class Driver {
public:
enum ChanState {
CHAN_INACTIVE,
CHAN_PLAINTEXT,
CHAN_SSL_CONNECTING,
CHAN_SSL_ACCEPTING,
CHAN_SSL_READWRITE,
};
struct ChanInfo {
int chid;
SOCKET socket;
SSL *ssl;
ChanState state;
int nbytes;
const char *bytes;
bool released;
bool just_released;
bool ready_now;
bool ready_on_pollin;
bool ready_on_pollout;
bool ready_on_outgoing;
int last_write_nbytes;
};
std::vector<ChanInfo> chans_;
std::map<int, SOCKET> listen_sockets_;
bool read_console_recently_;
std::unique_ptr<char[]> chbuf_;
std::unique_ptr<struct pollfd[]> pollvec_;
drv::ReplayRecorder recorder_;
drvssl::UniqueCTX ssl_server_ctx_;
drvssl::UniqueCTX ssl_client_secure_ctx_;
drvssl::UniqueCTX ssl_client_insecure_ctx_;
void handle_listen_ports() {
const auto &listenports = recorder_.drv_get_listen_ports();
for (int port : listenports) {
if (listen_sockets_.find(port) == listen_sockets_.end()) {
std::string err;
SOCKET sock = listen_on_port(port, err);
if_error_print_and_exit(err);
assert(sock != INVALID_SOCKET);
listen_sockets_[port] = sock;
}
}
}
void handle_lua_source() {
if (recorder_.drv_get_rescan_lua_source()) {
std::string err;
std::string_view ctrl =
read_file("lua/control.lst", chbuf_.get(), CHBUF_SIZE, err);
if_error_print_and_exit(err);
std::vector<std::string> names = drv::parse_control_lst(ctrl);
recorder_.drv_clear_lua_source();
for (const std::string &str : names) {
std::string lfn = std::string("lua/") + str;
std::string_view data =
read_file(lfn.c_str(), chbuf_.get(), CHBUF_SIZE, err);
if_error_print_and_exit(err);
recorder_.drv_add_lua_source(str, data);
}
}
}
void close_channel(ChanInfo &chan, std::string_view err) {
// std::cerr << "Closing channel " << chan.chid << std::endl;
assert(chan.state != CHAN_INACTIVE);
// Close and release the SSL channel.
if (chan.ssl != nullptr) {
SSL_free(chan.ssl);
chan.ssl = nullptr;
}
// Close and release the socket.
assert(chan.socket != INVALID_SOCKET);
assert(socket_close(chan.socket) == 0);
chan.socket = INVALID_SOCKET;
// Close everything else.
recorder_.drv_notify_close(chan.chid, err);
chan.state = CHAN_INACTIVE;
chan.chid = -1;
chan.nbytes = 0;
chan.bytes = 0;
chan.released = false;
chan.just_released = false;
chan.ready_now = false;
chan.ready_on_pollin = false;
chan.ready_on_pollout = false;
chan.ready_on_outgoing = false;
chan.last_write_nbytes = 0;
}
void cleanup_channels() {
for (int i = 0; i < int(chans_.size());) {
if (chans_[i].state == CHAN_INACTIVE) {
chans_[i] = chans_.back();
chans_.pop_back();
} else {
i += 1;
}
}
}
void handle_console_output() {
while (true) {
std::string_view s = recorder_.drv_peek_outgoing(0);
if (s.size() == 0) break;
int nwrote = console_write(s.data(), s.size());
if (nwrote <= 0) break;
recorder_.drv_sent_outgoing(0, nwrote);
}
}
void handle_console_input() {
char buffer[256];
read_console_recently_ = false;
while (true) {
int nread = console_read(buffer, 256);
if (nread <= 0) break;
read_console_recently_ = true;
recorder_.drv_recv_incoming(0, std::string_view(buffer, nread));
}
}
void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) {
ChanInfo newchan;
newchan.chid = chid;
newchan.socket = sock;
newchan.ssl = SSL_new(ctx);
newchan.state = state;
newchan.nbytes = 0;
newchan.bytes = 0;
newchan.released = false;
newchan.just_released = false;
newchan.ready_now = false;
newchan.ready_on_pollin = false;
newchan.ready_on_pollout = true;
newchan.ready_on_outgoing = false;
newchan.last_write_nbytes = 0;
SSL_set_fd(newchan.ssl, newchan.socket);
// SSL_set_msg_callback(newchan.ssl, SSL_trace);
// SSL_set_msg_callback_arg(newchan.ssl, BIO_new_fp(stderr,0));
chans_.push_back(newchan);
}
void handle_new_outgoing_sockets() {
const auto &chans = recorder_.drv_get_new_outgoing();
for (int chid : chans) {
std::string err, cert, host, port;
std::string target(recorder_.drv_get_target(chid));
drv::split_target(target, cert, host, port);
if (cert.empty() || host.empty() || port.empty()) {
recorder_.drv_notify_close(
chid, std::string("invalid target: ") + target);
continue;
}
SSL_CTX *ctx = nullptr;
if (cert == "cert") {
ctx = ssl_client_secure_ctx_.get();
} else if (cert == "nocert") {
ctx = ssl_client_insecure_ctx_.get();
} else {
recorder_.drv_notify_close(
chid, std::string("invalid cert rule: ") + target);
continue;
}
SOCKET sock = open_connection(host.c_str(), port.c_str(), err);
if (sock == INVALID_SOCKET) {
recorder_.drv_notify_close(chid, err);
continue;
}
// std::cerr << "Opening channel " << chid << std::endl;
make_channel(sock, chid, ctx, CHAN_SSL_CONNECTING);
}
if (!chans.empty()) {
recorder_.drv_clear_new_outgoing();
}
}
void accept_connection(int port, SOCKET sock) {
std::string err;
SOCKET socket = accept_on_socket(sock, err);
if_error_print_and_exit(err);
if (socket != INVALID_SOCKET) {
int chid = recorder_.drv_notify_accept(port);
// std::cerr << "Accepted channel " << chid << std::endl;
make_channel(socket, chid, ssl_server_ctx_.get(),
CHAN_SSL_ACCEPTING);
}
}
void advance_plaintext(ChanInfo &chan) {
std::string err;
// If the channel has no outgoing bytes and has been released,
// just close it.
if (chan.released) {
close_channel(chan, "");
return;
}
// Try to write plaintext to the channel.
std::string_view s = recorder_.drv_peek_outgoing(chan.chid);
if (s.size() > 0) {
int sbytes = s.size();
if (sbytes > 65536) sbytes = 65536;
int wbytes = socket_send(chan.socket, s.data(), sbytes, err);
if (wbytes < 0) {
close_channel(chan, err);
} else {
recorder_.drv_sent_outgoing(chan.chid, wbytes);
}
}
// Try to read plaintext from the channel.
// Someday, find a way to avoid this copy.
int nrecv = socket_recv(chan.socket, chbuf_.get(), 65536, err);
if (nrecv < 0) {
close_channel(chan, err);
} else {
recorder_.drv_recv_incoming(chan.chid,
std::string_view(chbuf_.get(), nrecv));
}
// Update the ready-flags for next time.
chan.ready_on_outgoing = true;
chan.ready_on_pollin = true;
}
void process_ssl_error(ChanInfo &chan, int retval) {
int error = SSL_get_error(chan.ssl, retval);
// std::cerr << "SSL error code = " << error << " ";
if (error == SSL_ERROR_WANT_READ) {
chan.ready_on_pollin = true;
} else if (error == SSL_ERROR_WANT_WRITE) {
chan.ready_on_pollout = true;
} else {
std::string error = drvssl::error_string();
if (error == "") error = "unknown error";
close_channel(chan, error);
}
}
void advance_ssl_connecting(ChanInfo &chan) {
// std::cerr << "In advance_ssl_connecting" << std::endl;
int retval = SSL_connect(chan.ssl);
if (retval == 1) {
// Connection successful.
chan.state = CHAN_SSL_READWRITE;
chan.ready_now = true;
} else {
// std::cerr << "ssl_connect_error";
process_ssl_error(chan, retval);
}
}
void advance_ssl_accepting(ChanInfo &chan) {
// std::cerr << "In advance_ssl_accepting" << std::endl;
int retval = SSL_accept(chan.ssl);
if (retval == 1) {
// Connection successful.
chan.state = CHAN_SSL_READWRITE;
chan.ready_now = true;
} else {
process_ssl_error(chan, retval);
}
}
void advance_ssl_readwrite(ChanInfo &chan) {
// std::cerr << "In advance_ssl_readwrite" << std::endl;
// Try to read data.
int read_result = SSL_read(chan.ssl, chbuf_.get(), 65536);
if (read_result > 0) {
recorder_.drv_recv_incoming(
chan.chid, std::string_view(chbuf_.get(), read_result));
chan.ready_now = true;
} else {
process_ssl_error(chan, read_result);
if (chan.state == CHAN_INACTIVE) return;
}
// Try to write data.
int wbytes;
if (chan.last_write_nbytes > 0) {
wbytes = chan.last_write_nbytes;
assert(wbytes < chan.nbytes);
} else {
wbytes = chan.nbytes;
if (wbytes > 65536) wbytes = 65536;
}
if (wbytes > 0) {
int write_result = SSL_write(chan.ssl, chan.bytes, wbytes);
if (write_result > 0) {
recorder_.drv_sent_outgoing(chan.chid, write_result);
chan.last_write_nbytes = 0;
chan.ready_on_outgoing = true;
} else {
chan.last_write_nbytes = wbytes;
process_ssl_error(chan, write_result);
if (chan.state == CHAN_INACTIVE) return;
}
} else {
chan.ready_on_outgoing = true;
}
// std::cerr << "rpi=" << chan.ready_on_pollin << ".rpo=" <<
// chan.ready_on_pollout << ".rn=" << chan.ready_now << ".rog=" <<
// chan.ready_on_outgoing << " ";
}
void advance_channel(ChanInfo &chan) {
drvssl::clear_all_errors();
switch (chan.state) {
case CHAN_PLAINTEXT:
advance_plaintext(chan);
break;
case CHAN_SSL_CONNECTING:
advance_ssl_connecting(chan);
break;
case CHAN_SSL_ACCEPTING:
advance_ssl_accepting(chan);
break;
case CHAN_SSL_READWRITE:
advance_ssl_readwrite(chan);
break;
default:
assert(false);
break;
}
}
void handle_socket_input_output() {
std::string err;
int mstimeout = read_console_recently_ ? 100 : 1000;
// Peek output buffers and determine channel release flags.
for (ChanInfo &chan : chans_) {
std::string_view s = recorder_.drv_peek_outgoing(chan.chid);
chan.nbytes = s.size();
chan.bytes = s.data();
chan.just_released = false;
if ((chan.nbytes == 0) && (!chan.released)) {
chan.released = recorder_.drv_get_channel_released(chan.chid);
chan.just_released = chan.released;
}
}
// Construct the struct pollfd vector.
int pollsize = 0;
for (const auto &p : listen_sockets_) {
struct pollfd &pfd = pollvec_[pollsize++];
pfd.fd = p.second;
pfd.events = POLLIN;
pfd.revents = 0;
}
for (const ChanInfo &chan : chans_) {
struct pollfd &pfd = pollvec_[pollsize++];
assert(chan.socket != INVALID_SOCKET);
pfd.fd = chan.socket;
pfd.events = 0;
pfd.revents = 0;
if (chan.ready_now) mstimeout = 0;
if (chan.just_released) mstimeout = 0;
if (chan.ready_on_pollin) pfd.events |= POLLIN;
if (chan.ready_on_pollout) pfd.events |= POLLOUT;
if (chan.ready_on_outgoing && (chan.nbytes > 0))
pfd.events |= POLLOUT;
// std::cerr << "evt=" << pfd.events << ".nb=" << chan.nbytes << "
// ";
}
// Do the poll.
socket_poll(pollvec_.get(), pollsize, mstimeout, err);
if_error_print_and_exit(err);
// Check listening sockets.
int index = 0;
for (auto &p : listen_sockets_) {
struct pollfd &pfd = pollvec_[index++];
if (pfd.revents & (POLLIN | POLLERR)) {
accept_connection(p.first, p.second);
}
}
// Advance channels where possible.
for (ChanInfo &chan : chans_) {
struct pollfd &pfd = pollvec_[index++];
bool pollin = ((pfd.revents & POLLIN) != 0);
bool pollout = ((pfd.revents & POLLOUT) != 0);
bool pollerr = ((pfd.revents & (POLLERR | POLLHUP)) != 0);
if (chan.ready_now || pollerr || chan.just_released ||
(chan.ready_on_pollin && pollin) ||
(chan.ready_on_pollout && pollout) ||
(chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) {
chan.ready_now = false;
chan.ready_on_pollin = false;
chan.ready_on_pollout = false;
chan.ready_on_outgoing = false;
advance_channel(chan);
}
chan.nbytes = 0;
chan.bytes = 0;
}
// Delete any newly-inactive channels
cleanup_channels();
}
int replay_logfile(const char *fn, bool verbose) {
drv::ReplayPlayer player;
player.open_logfile(fn);
if (verbose) {
player.enable_stdout();
}
while (true) {
drv::ReplayPlayer::Status st = player.step();
if (st != drv::ReplayPlayer::ST_REPLAYING) {
player.print_status(std::cerr);
return (st == drv::ReplayPlayer::ST_CLEAN_EXIT) ? 0 : 1;
}
}
}
int drive(int argc, char *argv[]) {
// Remove the program name from argv.
if (argc < 1) {
DrivenEngine::print_usage(std::cerr, "<unknown>");
exit(1);
}
std::string program = argv[0];
argc -= 1;
argv += 1;
// If argv contains "replay <filename>", do a replay,
// and then skip everything else.
if (argc >= 1) {
std::string cmd(argv[0]);
if ((cmd == "replay") || (cmd == "vreplay")) {
if (argc != 2) {
std::cerr << "usage: " << program << " replay <filename>"
<< std::endl;
return 1;
}
return replay_logfile(argv[1], cmd == "vreplay");
}
}
// If argv contains "record <filename>", start recording,
// and remove the "record <filename>" from argv.
if (argc >= 1) {
std::string cmd = argv[0];
if (cmd == "record") {
if (argc < 2) {
DrivenEngine::print_usage(std::cerr, program);
return 1;
}
bool ok = recorder_.open_logfile(argv[1]);
if (!ok) {
std::cerr << "Could not open logfile: " << argv[1]
<< std::endl;
return 1;
}
argc -= 2;
argv += 2;
}
}
// Create the engine.
if (argc < 1) {
DrivenEngine::print_usage(std::cerr, program);
return 1;
}
bool engine_made = recorder_.create_engine(argv[0]);
if (!engine_made) {
DrivenEngine::print_usage(std::cerr, program);
return 1;
}
read_console_recently_ = false;
chbuf_.reset(new char[CHBUF_SIZE]);
pollvec_.reset(new struct pollfd[POLLVEC_SIZE]);
ssl_server_ctx_.reset(drvssl::new_context(SSL_VERIFY_NONE));
ssl_client_secure_ctx_.reset(drvssl::new_context(SSL_VERIFY_PEER));
ssl_client_insecure_ctx_.reset(drvssl::new_context(SSL_VERIFY_NONE));
ssl_load_certificate_authorities(ssl_client_secure_ctx_.get());
drvssl::ctx_load_dummy_cert(ssl_server_ctx_.get());
handle_lua_source();
recorder_.drv_invoke_event_init(argc, argv);
handle_listen_ports();
while (!recorder_.drv_get_stop_driver()) {
handle_lua_source();
handle_console_output();
handle_new_outgoing_sockets();
handle_socket_input_output();
handle_console_input();
handle_console_output();
recorder_.drv_invoke_event_update(monoclock.get());
}
for (ChanInfo &chan : chans_) {
close_channel(chan, "");
}
DrivenEngine::set(nullptr);
recorder_.clean_exit();
return 0;
}
};