fix accidental reformatting of driver-common.cpp
This commit is contained in:
@@ -4,42 +4,33 @@
|
|||||||
|
|
||||||
static MonoClock monoclock;
|
static MonoClock monoclock;
|
||||||
|
|
||||||
namespace util
|
namespace util {
|
||||||
{
|
|
||||||
double profiling_clock()
|
|
||||||
{
|
|
||||||
return monoclock.get();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void if_error_print_and_exit(const std::string &str)
|
double profiling_clock() { return monoclock.get(); }
|
||||||
{
|
|
||||||
if (!str.empty())
|
} // namespace util
|
||||||
{
|
|
||||||
std::cerr << std::endl
|
static void if_error_print_and_exit(const std::string &str) {
|
||||||
<< "error: " << str << std::endl;
|
if (!str.empty()) {
|
||||||
|
std::cerr << std::endl << "error: " << str << std::endl;
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string_view read_file(const char *fn, char *buf, int bufsize, std::string &err)
|
static std::string_view read_file(const char *fn, char *buf, int bufsize, std::string &err) {
|
||||||
{
|
|
||||||
FILE *f = fopen(fn, "r");
|
FILE *f = fopen(fn, "r");
|
||||||
if (f == 0)
|
if (f == 0) {
|
||||||
{
|
|
||||||
err = std::string("cannot read file") + fn;
|
err = std::string("cannot read file") + fn;
|
||||||
buf[0] = 0;
|
buf[0] = 0;
|
||||||
return std::string_view(buf, 0);
|
return std::string_view(buf, 0);
|
||||||
}
|
}
|
||||||
int nread = fread(buf, 1, bufsize, f);
|
int nread = fread(buf, 1, bufsize, f);
|
||||||
if (nread < 0)
|
if (nread < 0) {
|
||||||
{
|
|
||||||
err = std::string("cannot read file: ") + fn;
|
err = std::string("cannot read file: ") + fn;
|
||||||
buf[0] = 0;
|
buf[0] = 0;
|
||||||
return std::string_view(buf, 0);
|
return std::string_view(buf, 0);
|
||||||
}
|
}
|
||||||
if (nread == bufsize)
|
if (nread == bufsize) {
|
||||||
{
|
|
||||||
err = std::string("file too large: ") + fn;
|
err = std::string("file too large: ") + fn;
|
||||||
buf[0] = 0;
|
buf[0] = 0;
|
||||||
return std::string_view(buf, 0);
|
return std::string_view(buf, 0);
|
||||||
@@ -48,66 +39,50 @@ static std::string_view read_file(const char *fn, char *buf, int bufsize, std::s
|
|||||||
return std::string_view(buf, nread);
|
return std::string_view(buf, nread);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SSL_CTX_Deleter
|
struct SSL_CTX_Deleter {
|
||||||
{
|
void operator()(SSL_CTX *ctx) { SSL_CTX_free(ctx); }
|
||||||
void operator()(SSL_CTX *ctx)
|
|
||||||
{
|
|
||||||
SSL_CTX_free(ctx);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
using UniqueSSLCTX = std::unique_ptr<SSL_CTX, SSL_CTX_Deleter>;
|
using UniqueSSLCTX = std::unique_ptr<SSL_CTX, SSL_CTX_Deleter>;
|
||||||
|
|
||||||
static std::string ssl_errors_string(bool lastonly = true)
|
static std::string ssl_errors_string(bool lastonly = true) {
|
||||||
{
|
|
||||||
std::string err;
|
std::string err;
|
||||||
const char *file, *data, *func;
|
const char *file, *data, *func;
|
||||||
int line, flags;
|
int line, flags;
|
||||||
|
|
||||||
while (true)
|
while (true) {
|
||||||
{
|
unsigned long code =
|
||||||
unsigned long code = ERR_get_error_all(&file, &line, &func, &data, &flags);
|
ERR_get_error_all(&file, &line, &func, &data, &flags);
|
||||||
if (code == 0)
|
if (code == 0) break;
|
||||||
break;
|
|
||||||
std::string reason;
|
std::string reason;
|
||||||
if (ERR_SYSTEM_ERROR(code))
|
if (ERR_SYSTEM_ERROR(code)) {
|
||||||
{
|
|
||||||
reason = strerror_str(ERR_GET_REASON(code));
|
reason = strerror_str(ERR_GET_REASON(code));
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
const char *rc = ERR_reason_error_string(code);
|
const char *rc = ERR_reason_error_string(code);
|
||||||
reason = (rc == nullptr) ? "unknown" : rc;
|
reason = (rc == nullptr) ? "unknown" : rc;
|
||||||
}
|
}
|
||||||
if (err.empty() || lastonly)
|
if (err.empty() || lastonly) {
|
||||||
{
|
|
||||||
err = reason;
|
err = reason;
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
err = err + ", " + reason;
|
err = err + ", " + reason;
|
||||||
}
|
}
|
||||||
if (data != nullptr)
|
if (data != nullptr) {
|
||||||
{
|
|
||||||
err = err + " " + data;
|
err = err + " " + data;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return err;
|
return err;
|
||||||
}
|
}
|
||||||
|
|
||||||
void assert_ssl_errors_empty()
|
void assert_ssl_errors_empty() {
|
||||||
{
|
|
||||||
int code = ERR_peek_error();
|
int code = ERR_peek_error();
|
||||||
if (code != 0)
|
if (code != 0) {
|
||||||
{
|
|
||||||
std::cerr << "SSL should not have errors at this point." << std::endl;
|
std::cerr << "SSL should not have errors at this point." << std::endl;
|
||||||
ERR_print_errors_fp(stderr);
|
ERR_print_errors_fp(stderr);
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str)
|
static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str) {
|
||||||
{
|
|
||||||
BIO *bio = BIO_new(BIO_s_mem());
|
BIO *bio = BIO_new(BIO_s_mem());
|
||||||
BIO_puts(bio, str);
|
BIO_puts(bio, str);
|
||||||
X509 *certificate = PEM_read_bio_X509(bio, NULL, NULL, NULL);
|
X509 *certificate = PEM_read_bio_X509(bio, NULL, NULL, NULL);
|
||||||
@@ -117,8 +92,7 @@ static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str)
|
|||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str)
|
static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str) {
|
||||||
{
|
|
||||||
BIO *bio = BIO_new(BIO_s_mem());
|
BIO *bio = BIO_new(BIO_s_mem());
|
||||||
BIO_puts(bio, str);
|
BIO_puts(bio, str);
|
||||||
EVP_PKEY *pkey = PEM_read_bio_PrivateKey(bio, NULL, NULL, NULL);
|
EVP_PKEY *pkey = PEM_read_bio_PrivateKey(bio, NULL, NULL, NULL);
|
||||||
@@ -128,33 +102,27 @@ static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str)
|
|||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ssl_ctx_use_dummycert(SSL_CTX *ctx)
|
static void ssl_ctx_use_dummycert(SSL_CTX *ctx) {
|
||||||
{
|
if (ssl_ctx_use_certificate_str(ctx, dummycert::certificate) <= 0) {
|
||||||
if (ssl_ctx_use_certificate_str(ctx, dummycert::certificate) <= 0)
|
|
||||||
{
|
|
||||||
ERR_print_errors_fp(stderr);
|
ERR_print_errors_fp(stderr);
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
if (ssl_ctx_use_privatekey_str(ctx, dummycert::privatekey) <= 0)
|
if (ssl_ctx_use_privatekey_str(ctx, dummycert::privatekey) <= 0) {
|
||||||
{
|
|
||||||
ERR_print_errors_fp(stderr);
|
ERR_print_errors_fp(stderr);
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class Driver
|
class Driver {
|
||||||
{
|
public:
|
||||||
public:
|
enum ChanState {
|
||||||
enum ChanState
|
|
||||||
{
|
|
||||||
CHAN_INACTIVE,
|
CHAN_INACTIVE,
|
||||||
CHAN_PLAINTEXT,
|
CHAN_PLAINTEXT,
|
||||||
CHAN_SSL_CONNECTING,
|
CHAN_SSL_CONNECTING,
|
||||||
CHAN_SSL_ACCEPTING,
|
CHAN_SSL_ACCEPTING,
|
||||||
CHAN_SSL_READWRITE,
|
CHAN_SSL_READWRITE,
|
||||||
};
|
};
|
||||||
struct ChanInfo
|
struct ChanInfo {
|
||||||
{
|
|
||||||
int chid;
|
int chid;
|
||||||
SOCKET socket;
|
SOCKET socket;
|
||||||
SSL *ssl;
|
SSL *ssl;
|
||||||
@@ -182,13 +150,10 @@ public:
|
|||||||
UniqueSSLCTX ssl_client_secure_ctx_;
|
UniqueSSLCTX ssl_client_secure_ctx_;
|
||||||
UniqueSSLCTX ssl_client_insecure_ctx_;
|
UniqueSSLCTX ssl_client_insecure_ctx_;
|
||||||
|
|
||||||
void handle_listen_ports()
|
void handle_listen_ports() {
|
||||||
{
|
|
||||||
const auto &listenports = recorder_.drv_get_listen_ports();
|
const auto &listenports = recorder_.drv_get_listen_ports();
|
||||||
for (int port : listenports)
|
for (int port : listenports) {
|
||||||
{
|
if (listen_sockets_.find(port) == listen_sockets_.end()) {
|
||||||
if (listen_sockets_.find(port) == listen_sockets_.end())
|
|
||||||
{
|
|
||||||
std::string err;
|
std::string err;
|
||||||
SOCKET sock = listen_on_port(port, err);
|
SOCKET sock = listen_on_port(port, err);
|
||||||
if_error_print_and_exit(err);
|
if_error_print_and_exit(err);
|
||||||
@@ -198,32 +163,29 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void handle_lua_source()
|
void handle_lua_source() {
|
||||||
{
|
if (recorder_.drv_get_rescan_lua_source()) {
|
||||||
if (recorder_.drv_get_rescan_lua_source())
|
|
||||||
{
|
|
||||||
std::string err;
|
std::string err;
|
||||||
std::string_view ctrl = read_file("lua/control.lst", chbuf_.get(), CHBUF_SIZE, err);
|
std::string_view ctrl =
|
||||||
|
read_file("lua/control.lst", chbuf_.get(), CHBUF_SIZE, err);
|
||||||
if_error_print_and_exit(err);
|
if_error_print_and_exit(err);
|
||||||
std::vector<std::string> names = drv::parse_control_lst(ctrl);
|
std::vector<std::string> names = drv::parse_control_lst(ctrl);
|
||||||
recorder_.drv_clear_lua_source();
|
recorder_.drv_clear_lua_source();
|
||||||
for (const std::string &str : names)
|
for (const std::string &str : names) {
|
||||||
{
|
|
||||||
std::string lfn = std::string("lua/") + str;
|
std::string lfn = std::string("lua/") + str;
|
||||||
std::string_view data = read_file(lfn.c_str(), chbuf_.get(), CHBUF_SIZE, err);
|
std::string_view data =
|
||||||
|
read_file(lfn.c_str(), chbuf_.get(), CHBUF_SIZE, err);
|
||||||
if_error_print_and_exit(err);
|
if_error_print_and_exit(err);
|
||||||
recorder_.drv_add_lua_source(str, data);
|
recorder_.drv_add_lua_source(str, data);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void close_channel(ChanInfo &chan, std::string_view err)
|
void close_channel(ChanInfo &chan, std::string_view err) {
|
||||||
{
|
|
||||||
// std::cerr << "Closing channel " << chan.chid << std::endl;
|
// std::cerr << "Closing channel " << chan.chid << std::endl;
|
||||||
assert(chan.state != CHAN_INACTIVE);
|
assert(chan.state != CHAN_INACTIVE);
|
||||||
// Close and release the SSL channel.
|
// Close and release the SSL channel.
|
||||||
if (chan.ssl != nullptr)
|
if (chan.ssl != nullptr) {
|
||||||
{
|
|
||||||
SSL_free(chan.ssl);
|
SSL_free(chan.ssl);
|
||||||
chan.ssl = nullptr;
|
chan.ssl = nullptr;
|
||||||
}
|
}
|
||||||
@@ -246,52 +208,39 @@ public:
|
|||||||
chan.last_write_nbytes = 0;
|
chan.last_write_nbytes = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
void cleanup_channels()
|
void cleanup_channels() {
|
||||||
{
|
for (int i = 0; i < int(chans_.size());) {
|
||||||
for (int i = 0; i < int(chans_.size());)
|
if (chans_[i].state == CHAN_INACTIVE) {
|
||||||
{
|
|
||||||
if (chans_[i].state == CHAN_INACTIVE)
|
|
||||||
{
|
|
||||||
chans_[i] = chans_.back();
|
chans_[i] = chans_.back();
|
||||||
chans_.pop_back();
|
chans_.pop_back();
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
i += 1;
|
i += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void handle_console_output()
|
void handle_console_output() {
|
||||||
{
|
while (true) {
|
||||||
while (true)
|
|
||||||
{
|
|
||||||
std::string_view s = recorder_.drv_peek_outgoing(0);
|
std::string_view s = recorder_.drv_peek_outgoing(0);
|
||||||
if (s.size() == 0)
|
if (s.size() == 0) break;
|
||||||
break;
|
|
||||||
int nwrote = console_write(s.data(), s.size());
|
int nwrote = console_write(s.data(), s.size());
|
||||||
if (nwrote <= 0)
|
if (nwrote <= 0) break;
|
||||||
break;
|
|
||||||
recorder_.drv_sent_outgoing(0, nwrote);
|
recorder_.drv_sent_outgoing(0, nwrote);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void handle_console_input()
|
void handle_console_input() {
|
||||||
{
|
|
||||||
char buffer[256];
|
char buffer[256];
|
||||||
read_console_recently_ = false;
|
read_console_recently_ = false;
|
||||||
while (true)
|
while (true) {
|
||||||
{
|
|
||||||
int nread = console_read(buffer, 256);
|
int nread = console_read(buffer, 256);
|
||||||
if (nread <= 0)
|
if (nread <= 0) break;
|
||||||
break;
|
|
||||||
read_console_recently_ = true;
|
read_console_recently_ = true;
|
||||||
recorder_.drv_recv_incoming(0, std::string_view(buffer, nread));
|
recorder_.drv_recv_incoming(0, std::string_view(buffer, nread));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state)
|
void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) {
|
||||||
{
|
|
||||||
ChanInfo newchan;
|
ChanInfo newchan;
|
||||||
newchan.chid = chid;
|
newchan.chid = chid;
|
||||||
newchan.socket = sock;
|
newchan.socket = sock;
|
||||||
@@ -312,16 +261,15 @@ public:
|
|||||||
chans_.push_back(newchan);
|
chans_.push_back(newchan);
|
||||||
}
|
}
|
||||||
|
|
||||||
void handle_new_outgoing_sockets()
|
void handle_new_outgoing_sockets() {
|
||||||
{
|
|
||||||
const auto &chans = recorder_.drv_get_new_outgoing();
|
const auto &chans = recorder_.drv_get_new_outgoing();
|
||||||
for (int chid : chans)
|
for (int chid : chans) {
|
||||||
{
|
|
||||||
std::string err, cert, host, port;
|
std::string err, cert, host, port;
|
||||||
std::string target(recorder_.drv_get_target(chid));
|
std::string target(recorder_.drv_get_target(chid));
|
||||||
drv::split_target(target, cert, host, port);
|
drv::split_target(target, cert, host, port);
|
||||||
if (cert.empty() || host.empty() || port.empty()) {
|
if (cert.empty() || host.empty() || port.empty()) {
|
||||||
recorder_.drv_notify_close(chid, std::string("invalid target: ") + target);
|
recorder_.drv_notify_close(
|
||||||
|
chid, std::string("invalid target: ") + target);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
SSL_CTX *ctx = nullptr;
|
SSL_CTX *ctx = nullptr;
|
||||||
@@ -330,63 +278,54 @@ public:
|
|||||||
} else if (cert == "nocert") {
|
} else if (cert == "nocert") {
|
||||||
ctx = ssl_client_insecure_ctx_.get();
|
ctx = ssl_client_insecure_ctx_.get();
|
||||||
} else {
|
} else {
|
||||||
recorder_.drv_notify_close(chid, std::string("invalid cert rule: ") + target);
|
recorder_.drv_notify_close(
|
||||||
|
chid, std::string("invalid cert rule: ") + target);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
SOCKET sock = open_connection(host.c_str(), port.c_str(), err);
|
SOCKET sock = open_connection(host.c_str(), port.c_str(), err);
|
||||||
if (sock == INVALID_SOCKET)
|
if (sock == INVALID_SOCKET) {
|
||||||
{
|
|
||||||
recorder_.drv_notify_close(chid, err);
|
recorder_.drv_notify_close(chid, err);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// std::cerr << "Opening channel " << chid << std::endl;
|
// std::cerr << "Opening channel " << chid << std::endl;
|
||||||
make_channel(sock, chid, ctx, CHAN_SSL_CONNECTING);
|
make_channel(sock, chid, ctx, CHAN_SSL_CONNECTING);
|
||||||
}
|
}
|
||||||
if (!chans.empty())
|
if (!chans.empty()) {
|
||||||
{
|
|
||||||
recorder_.drv_clear_new_outgoing();
|
recorder_.drv_clear_new_outgoing();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void accept_connection(int port, SOCKET sock)
|
void accept_connection(int port, SOCKET sock) {
|
||||||
{
|
|
||||||
std::string err;
|
std::string err;
|
||||||
SOCKET socket = accept_on_socket(sock, err);
|
SOCKET socket = accept_on_socket(sock, err);
|
||||||
if_error_print_and_exit(err);
|
if_error_print_and_exit(err);
|
||||||
if (socket != INVALID_SOCKET)
|
if (socket != INVALID_SOCKET) {
|
||||||
{
|
|
||||||
int chid = recorder_.drv_notify_accept(port);
|
int chid = recorder_.drv_notify_accept(port);
|
||||||
// std::cerr << "Accepted channel " << chid << std::endl;
|
// std::cerr << "Accepted channel " << chid << std::endl;
|
||||||
make_channel(socket, chid, ssl_server_ctx_.get(), CHAN_SSL_ACCEPTING);
|
make_channel(socket, chid, ssl_server_ctx_.get(),
|
||||||
|
CHAN_SSL_ACCEPTING);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void advance_plaintext(ChanInfo &chan)
|
void advance_plaintext(ChanInfo &chan) {
|
||||||
{
|
|
||||||
std::string err;
|
std::string err;
|
||||||
|
|
||||||
// If the channel has no outgoing bytes and has been released,
|
// If the channel has no outgoing bytes and has been released,
|
||||||
// just close it.
|
// just close it.
|
||||||
if (chan.released)
|
if (chan.released) {
|
||||||
{
|
|
||||||
close_channel(chan, "");
|
close_channel(chan, "");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to write plaintext to the channel.
|
// Try to write plaintext to the channel.
|
||||||
std::string_view s = recorder_.drv_peek_outgoing(chan.chid);
|
std::string_view s = recorder_.drv_peek_outgoing(chan.chid);
|
||||||
if (s.size() > 0)
|
if (s.size() > 0) {
|
||||||
{
|
|
||||||
int sbytes = s.size();
|
int sbytes = s.size();
|
||||||
if (sbytes > 65536)
|
if (sbytes > 65536) sbytes = 65536;
|
||||||
sbytes = 65536;
|
|
||||||
int wbytes = socket_send(chan.socket, s.data(), sbytes, err);
|
int wbytes = socket_send(chan.socket, s.data(), sbytes, err);
|
||||||
if (wbytes < 0)
|
if (wbytes < 0) {
|
||||||
{
|
|
||||||
close_channel(chan, err);
|
close_channel(chan, err);
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
recorder_.drv_sent_outgoing(chan.chid, wbytes);
|
recorder_.drv_sent_outgoing(chan.chid, wbytes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -394,13 +333,11 @@ public:
|
|||||||
// Try to read plaintext from the channel.
|
// Try to read plaintext from the channel.
|
||||||
// Someday, find a way to avoid this copy.
|
// Someday, find a way to avoid this copy.
|
||||||
int nrecv = socket_recv(chan.socket, chbuf_.get(), 65536, err);
|
int nrecv = socket_recv(chan.socket, chbuf_.get(), 65536, err);
|
||||||
if (nrecv < 0)
|
if (nrecv < 0) {
|
||||||
{
|
|
||||||
close_channel(chan, err);
|
close_channel(chan, err);
|
||||||
}
|
} else {
|
||||||
else
|
recorder_.drv_recv_incoming(chan.chid,
|
||||||
{
|
std::string_view(chbuf_.get(), nrecv));
|
||||||
recorder_.drv_recv_incoming(chan.chid, std::string_view(chbuf_.get(), nrecv));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the ready-flags for next time.
|
// Update the ready-flags for next time.
|
||||||
@@ -408,149 +345,117 @@ public:
|
|||||||
chan.ready_on_pollin = true;
|
chan.ready_on_pollin = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void process_ssl_error(ChanInfo &chan, int retval)
|
void process_ssl_error(ChanInfo &chan, int retval) {
|
||||||
{
|
|
||||||
int error = SSL_get_error(chan.ssl, retval);
|
int error = SSL_get_error(chan.ssl, retval);
|
||||||
// std::cerr << "SSL error code = " << error << " ";
|
// std::cerr << "SSL error code = " << error << " ";
|
||||||
if (error == SSL_ERROR_WANT_READ)
|
if (error == SSL_ERROR_WANT_READ) {
|
||||||
{
|
|
||||||
chan.ready_on_pollin = true;
|
chan.ready_on_pollin = true;
|
||||||
}
|
} else if (error == SSL_ERROR_WANT_WRITE) {
|
||||||
else if (error == SSL_ERROR_WANT_WRITE)
|
|
||||||
{
|
|
||||||
chan.ready_on_pollout = true;
|
chan.ready_on_pollout = true;
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
close_channel(chan, ssl_errors_string());
|
close_channel(chan, ssl_errors_string());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void advance_ssl_connecting(ChanInfo &chan)
|
void advance_ssl_connecting(ChanInfo &chan) {
|
||||||
{
|
|
||||||
// std::cerr << "In advance_ssl_connecting" << std::endl;
|
// std::cerr << "In advance_ssl_connecting" << std::endl;
|
||||||
int retval = SSL_connect(chan.ssl);
|
int retval = SSL_connect(chan.ssl);
|
||||||
if (retval == 1)
|
if (retval == 1) {
|
||||||
{
|
|
||||||
// Connection successful.
|
// Connection successful.
|
||||||
chan.state = CHAN_SSL_READWRITE;
|
chan.state = CHAN_SSL_READWRITE;
|
||||||
chan.ready_now = true;
|
chan.ready_now = true;
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
// std::cerr << "ssl_connect_error";
|
// std::cerr << "ssl_connect_error";
|
||||||
process_ssl_error(chan, retval);
|
process_ssl_error(chan, retval);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void advance_ssl_accepting(ChanInfo &chan)
|
void advance_ssl_accepting(ChanInfo &chan) {
|
||||||
{
|
|
||||||
// std::cerr << "In advance_ssl_accepting" << std::endl;
|
// std::cerr << "In advance_ssl_accepting" << std::endl;
|
||||||
int retval = SSL_accept(chan.ssl);
|
int retval = SSL_accept(chan.ssl);
|
||||||
if (retval == 1)
|
if (retval == 1) {
|
||||||
{
|
|
||||||
// Connection successful.
|
// Connection successful.
|
||||||
chan.state = CHAN_SSL_READWRITE;
|
chan.state = CHAN_SSL_READWRITE;
|
||||||
chan.ready_now = true;
|
chan.ready_now = true;
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
process_ssl_error(chan, retval);
|
process_ssl_error(chan, retval);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void advance_ssl_readwrite(ChanInfo &chan)
|
void advance_ssl_readwrite(ChanInfo &chan) {
|
||||||
{
|
|
||||||
// std::cerr << "In advance_ssl_readwrite" << std::endl;
|
// std::cerr << "In advance_ssl_readwrite" << std::endl;
|
||||||
// Try to read data.
|
// Try to read data.
|
||||||
int read_result = SSL_read(chan.ssl, chbuf_.get(), 65536);
|
int read_result = SSL_read(chan.ssl, chbuf_.get(), 65536);
|
||||||
if (read_result > 0)
|
if (read_result > 0) {
|
||||||
{
|
recorder_.drv_recv_incoming(
|
||||||
recorder_.drv_recv_incoming(chan.chid, std::string_view(chbuf_.get(), read_result));
|
chan.chid, std::string_view(chbuf_.get(), read_result));
|
||||||
chan.ready_now = true;
|
chan.ready_now = true;
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
process_ssl_error(chan, read_result);
|
process_ssl_error(chan, read_result);
|
||||||
if (chan.state == CHAN_INACTIVE)
|
if (chan.state == CHAN_INACTIVE) return;
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to write data.
|
// Try to write data.
|
||||||
int wbytes;
|
int wbytes;
|
||||||
if (chan.last_write_nbytes > 0)
|
if (chan.last_write_nbytes > 0) {
|
||||||
{
|
|
||||||
wbytes = chan.last_write_nbytes;
|
wbytes = chan.last_write_nbytes;
|
||||||
assert(wbytes < chan.nbytes);
|
assert(wbytes < chan.nbytes);
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
wbytes = chan.nbytes;
|
wbytes = chan.nbytes;
|
||||||
if (wbytes > 65536)
|
if (wbytes > 65536) wbytes = 65536;
|
||||||
wbytes = 65536;
|
|
||||||
}
|
}
|
||||||
if (wbytes > 0)
|
if (wbytes > 0) {
|
||||||
{
|
|
||||||
int write_result = SSL_write(chan.ssl, chan.bytes, wbytes);
|
int write_result = SSL_write(chan.ssl, chan.bytes, wbytes);
|
||||||
if (write_result > 0)
|
if (write_result > 0) {
|
||||||
{
|
|
||||||
recorder_.drv_sent_outgoing(chan.chid, write_result);
|
recorder_.drv_sent_outgoing(chan.chid, write_result);
|
||||||
chan.last_write_nbytes = 0;
|
chan.last_write_nbytes = 0;
|
||||||
chan.ready_on_outgoing = true;
|
chan.ready_on_outgoing = true;
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
chan.last_write_nbytes = wbytes;
|
chan.last_write_nbytes = wbytes;
|
||||||
process_ssl_error(chan, write_result);
|
process_ssl_error(chan, write_result);
|
||||||
if (chan.state == CHAN_INACTIVE)
|
if (chan.state == CHAN_INACTIVE) return;
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
chan.ready_on_outgoing = true;
|
chan.ready_on_outgoing = true;
|
||||||
}
|
}
|
||||||
// std::cerr << "rpi=" << chan.ready_on_pollin << ".rpo=" << chan.ready_on_pollout << ".rn=" << chan.ready_now << ".rog=" << chan.ready_on_outgoing << " ";
|
// std::cerr << "rpi=" << chan.ready_on_pollin << ".rpo=" <<
|
||||||
|
// chan.ready_on_pollout << ".rn=" << chan.ready_now << ".rog=" <<
|
||||||
|
// chan.ready_on_outgoing << " ";
|
||||||
}
|
}
|
||||||
|
|
||||||
void advance_channel(ChanInfo &chan)
|
void advance_channel(ChanInfo &chan) {
|
||||||
{
|
|
||||||
assert_ssl_errors_empty();
|
assert_ssl_errors_empty();
|
||||||
switch (chan.state)
|
switch (chan.state) {
|
||||||
{
|
case CHAN_PLAINTEXT:
|
||||||
case CHAN_PLAINTEXT:
|
advance_plaintext(chan);
|
||||||
advance_plaintext(chan);
|
break;
|
||||||
break;
|
case CHAN_SSL_CONNECTING:
|
||||||
case CHAN_SSL_CONNECTING:
|
advance_ssl_connecting(chan);
|
||||||
advance_ssl_connecting(chan);
|
break;
|
||||||
break;
|
case CHAN_SSL_ACCEPTING:
|
||||||
case CHAN_SSL_ACCEPTING:
|
advance_ssl_accepting(chan);
|
||||||
advance_ssl_accepting(chan);
|
break;
|
||||||
break;
|
case CHAN_SSL_READWRITE:
|
||||||
case CHAN_SSL_READWRITE:
|
advance_ssl_readwrite(chan);
|
||||||
advance_ssl_readwrite(chan);
|
break;
|
||||||
break;
|
default:
|
||||||
default:
|
assert(false);
|
||||||
assert(false);
|
break;
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
assert_ssl_errors_empty();
|
assert_ssl_errors_empty();
|
||||||
}
|
}
|
||||||
|
|
||||||
void handle_socket_input_output()
|
void handle_socket_input_output() {
|
||||||
{
|
|
||||||
std::string err;
|
std::string err;
|
||||||
int mstimeout = read_console_recently_ ? 100 : 1000;
|
int mstimeout = read_console_recently_ ? 100 : 1000;
|
||||||
|
|
||||||
// Peek output buffers and determine channel release flags.
|
// Peek output buffers and determine channel release flags.
|
||||||
for (ChanInfo &chan : chans_)
|
for (ChanInfo &chan : chans_) {
|
||||||
{
|
|
||||||
std::string_view s = recorder_.drv_peek_outgoing(chan.chid);
|
std::string_view s = recorder_.drv_peek_outgoing(chan.chid);
|
||||||
chan.nbytes = s.size();
|
chan.nbytes = s.size();
|
||||||
chan.bytes = s.data();
|
chan.bytes = s.data();
|
||||||
chan.just_released = false;
|
chan.just_released = false;
|
||||||
if ((chan.nbytes == 0) && (!chan.released))
|
if ((chan.nbytes == 0) && (!chan.released)) {
|
||||||
{
|
|
||||||
chan.released = recorder_.drv_get_channel_released(chan.chid);
|
chan.released = recorder_.drv_get_channel_released(chan.chid);
|
||||||
chan.just_released = chan.released;
|
chan.just_released = chan.released;
|
||||||
}
|
}
|
||||||
@@ -558,31 +463,26 @@ public:
|
|||||||
|
|
||||||
// Construct the struct pollfd vector.
|
// Construct the struct pollfd vector.
|
||||||
int pollsize = 0;
|
int pollsize = 0;
|
||||||
for (const auto &p : listen_sockets_)
|
for (const auto &p : listen_sockets_) {
|
||||||
{
|
|
||||||
struct pollfd &pfd = pollvec_[pollsize++];
|
struct pollfd &pfd = pollvec_[pollsize++];
|
||||||
pfd.fd = p.second;
|
pfd.fd = p.second;
|
||||||
pfd.events = POLLIN;
|
pfd.events = POLLIN;
|
||||||
pfd.revents = 0;
|
pfd.revents = 0;
|
||||||
}
|
}
|
||||||
for (const ChanInfo &chan : chans_)
|
for (const ChanInfo &chan : chans_) {
|
||||||
{
|
|
||||||
struct pollfd &pfd = pollvec_[pollsize++];
|
struct pollfd &pfd = pollvec_[pollsize++];
|
||||||
assert(chan.socket != INVALID_SOCKET);
|
assert(chan.socket != INVALID_SOCKET);
|
||||||
pfd.fd = chan.socket;
|
pfd.fd = chan.socket;
|
||||||
pfd.events = 0;
|
pfd.events = 0;
|
||||||
pfd.revents = 0;
|
pfd.revents = 0;
|
||||||
if (chan.ready_now)
|
if (chan.ready_now) mstimeout = 0;
|
||||||
mstimeout = 0;
|
if (chan.just_released) mstimeout = 0;
|
||||||
if (chan.just_released)
|
if (chan.ready_on_pollin) pfd.events |= POLLIN;
|
||||||
mstimeout = 0;
|
if (chan.ready_on_pollout) pfd.events |= POLLOUT;
|
||||||
if (chan.ready_on_pollin)
|
|
||||||
pfd.events |= POLLIN;
|
|
||||||
if (chan.ready_on_pollout)
|
|
||||||
pfd.events |= POLLOUT;
|
|
||||||
if (chan.ready_on_outgoing && (chan.nbytes > 0))
|
if (chan.ready_on_outgoing && (chan.nbytes > 0))
|
||||||
pfd.events |= POLLOUT;
|
pfd.events |= POLLOUT;
|
||||||
// std::cerr << "evt=" << pfd.events << ".nb=" << chan.nbytes << " ";
|
// std::cerr << "evt=" << pfd.events << ".nb=" << chan.nbytes << "
|
||||||
|
// ";
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do the poll.
|
// Do the poll.
|
||||||
@@ -591,18 +491,15 @@ public:
|
|||||||
|
|
||||||
// Check listening sockets.
|
// Check listening sockets.
|
||||||
int index = 0;
|
int index = 0;
|
||||||
for (auto &p : listen_sockets_)
|
for (auto &p : listen_sockets_) {
|
||||||
{
|
|
||||||
struct pollfd &pfd = pollvec_[index++];
|
struct pollfd &pfd = pollvec_[index++];
|
||||||
if (pfd.revents & (POLLIN | POLLERR))
|
if (pfd.revents & (POLLIN | POLLERR)) {
|
||||||
{
|
|
||||||
accept_connection(p.first, p.second);
|
accept_connection(p.first, p.second);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Advance channels where possible.
|
// Advance channels where possible.
|
||||||
for (ChanInfo &chan : chans_)
|
for (ChanInfo &chan : chans_) {
|
||||||
{
|
|
||||||
struct pollfd &pfd = pollvec_[index++];
|
struct pollfd &pfd = pollvec_[index++];
|
||||||
bool pollin = ((pfd.revents & POLLIN) != 0);
|
bool pollin = ((pfd.revents & POLLIN) != 0);
|
||||||
bool pollout = ((pfd.revents & POLLOUT) != 0);
|
bool pollout = ((pfd.revents & POLLOUT) != 0);
|
||||||
@@ -610,8 +507,7 @@ public:
|
|||||||
if (chan.ready_now || pollerr || chan.just_released ||
|
if (chan.ready_now || pollerr || chan.just_released ||
|
||||||
(chan.ready_on_pollin && pollin) ||
|
(chan.ready_on_pollin && pollin) ||
|
||||||
(chan.ready_on_pollout && pollout) ||
|
(chan.ready_on_pollout && pollout) ||
|
||||||
(chan.ready_on_outgoing && (chan.nbytes > 0) && pollout))
|
(chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) {
|
||||||
{
|
|
||||||
chan.ready_now = false;
|
chan.ready_now = false;
|
||||||
chan.ready_on_pollin = false;
|
chan.ready_on_pollin = false;
|
||||||
chan.ready_on_pollout = false;
|
chan.ready_on_pollout = false;
|
||||||
@@ -626,30 +522,24 @@ public:
|
|||||||
cleanup_channels();
|
cleanup_channels();
|
||||||
}
|
}
|
||||||
|
|
||||||
int replay_logfile(const char *fn, bool verbose)
|
int replay_logfile(const char *fn, bool verbose) {
|
||||||
{
|
|
||||||
drv::ReplayPlayer player;
|
drv::ReplayPlayer player;
|
||||||
player.open_logfile(fn);
|
player.open_logfile(fn);
|
||||||
if (verbose)
|
if (verbose) {
|
||||||
{
|
|
||||||
player.enable_stdout();
|
player.enable_stdout();
|
||||||
}
|
}
|
||||||
while (true)
|
while (true) {
|
||||||
{
|
|
||||||
drv::ReplayPlayer::Status st = player.step();
|
drv::ReplayPlayer::Status st = player.step();
|
||||||
if (st != drv::ReplayPlayer::ST_REPLAYING)
|
if (st != drv::ReplayPlayer::ST_REPLAYING) {
|
||||||
{
|
|
||||||
player.print_status(std::cerr);
|
player.print_status(std::cerr);
|
||||||
return (st == drv::ReplayPlayer::ST_CLEAN_EXIT) ? 0 : 1;
|
return (st == drv::ReplayPlayer::ST_CLEAN_EXIT) ? 0 : 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int drive(int argc, char *argv[])
|
int drive(int argc, char *argv[]) {
|
||||||
{
|
|
||||||
// Remove the program name from argv.
|
// Remove the program name from argv.
|
||||||
if (argc < 1)
|
if (argc < 1) {
|
||||||
{
|
|
||||||
DrivenEngine::print_usage(std::cerr, "<unknown>");
|
DrivenEngine::print_usage(std::cerr, "<unknown>");
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
@@ -659,14 +549,12 @@ public:
|
|||||||
|
|
||||||
// If argv contains "replay <filename>", do a replay,
|
// If argv contains "replay <filename>", do a replay,
|
||||||
// and then skip everything else.
|
// and then skip everything else.
|
||||||
if (argc >= 1)
|
if (argc >= 1) {
|
||||||
{
|
|
||||||
std::string cmd(argv[0]);
|
std::string cmd(argv[0]);
|
||||||
if ((cmd == "replay") || (cmd == "vreplay"))
|
if ((cmd == "replay") || (cmd == "vreplay")) {
|
||||||
{
|
if (argc != 2) {
|
||||||
if (argc != 2)
|
std::cerr << "usage: " << program << " replay <filename>"
|
||||||
{
|
<< std::endl;
|
||||||
std::cerr << "usage: " << program << " replay <filename>" << std::endl;
|
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
return replay_logfile(argv[1], cmd == "vreplay");
|
return replay_logfile(argv[1], cmd == "vreplay");
|
||||||
@@ -675,20 +563,17 @@ public:
|
|||||||
|
|
||||||
// If argv contains "record <filename>", start recording,
|
// If argv contains "record <filename>", start recording,
|
||||||
// and remove the "record <filename>" from argv.
|
// and remove the "record <filename>" from argv.
|
||||||
if (argc >= 1)
|
if (argc >= 1) {
|
||||||
{
|
|
||||||
std::string cmd = argv[0];
|
std::string cmd = argv[0];
|
||||||
if (cmd == "record")
|
if (cmd == "record") {
|
||||||
{
|
if (argc < 2) {
|
||||||
if (argc < 2)
|
|
||||||
{
|
|
||||||
DrivenEngine::print_usage(std::cerr, program);
|
DrivenEngine::print_usage(std::cerr, program);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
bool ok = recorder_.open_logfile(argv[1]);
|
bool ok = recorder_.open_logfile(argv[1]);
|
||||||
if (!ok)
|
if (!ok) {
|
||||||
{
|
std::cerr << "Could not open logfile: " << argv[1]
|
||||||
std::cerr << "Could not open logfile: " << argv[1] << std::endl;
|
<< std::endl;
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
argc -= 2;
|
argc -= 2;
|
||||||
@@ -697,14 +582,12 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create the engine.
|
// Create the engine.
|
||||||
if (argc < 1)
|
if (argc < 1) {
|
||||||
{
|
|
||||||
DrivenEngine::print_usage(std::cerr, program);
|
DrivenEngine::print_usage(std::cerr, program);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
bool engine_made = recorder_.create_engine(argv[0]);
|
bool engine_made = recorder_.create_engine(argv[0]);
|
||||||
if (!engine_made)
|
if (!engine_made) {
|
||||||
{
|
|
||||||
DrivenEngine::print_usage(std::cerr, program);
|
DrivenEngine::print_usage(std::cerr, program);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
@@ -723,8 +606,7 @@ public:
|
|||||||
recorder_.drv_invoke_event_init(argc, argv);
|
recorder_.drv_invoke_event_init(argc, argv);
|
||||||
handle_listen_ports();
|
handle_listen_ports();
|
||||||
|
|
||||||
while (!recorder_.drv_get_stop_driver())
|
while (!recorder_.drv_get_stop_driver()) {
|
||||||
{
|
|
||||||
handle_lua_source();
|
handle_lua_source();
|
||||||
handle_console_output();
|
handle_console_output();
|
||||||
handle_new_outgoing_sockets();
|
handle_new_outgoing_sockets();
|
||||||
@@ -734,8 +616,7 @@ public:
|
|||||||
recorder_.drv_invoke_event_update(monoclock.get());
|
recorder_.drv_invoke_event_update(monoclock.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
for (ChanInfo &chan : chans_)
|
for (ChanInfo &chan : chans_) {
|
||||||
{
|
|
||||||
close_channel(chan, "");
|
close_channel(chan, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user