Can now control SSL cert verfication from in-engine

This commit is contained in:
2022-03-18 16:25:20 -04:00
parent 2e7b793110
commit 2e3bef79b3
8 changed files with 392 additions and 204 deletions

View File

@@ -212,7 +212,7 @@ void DrivenEngine::drv_clear_new_outgoing() {
new_outgoing_.clear();
}
const eng::string &DrivenEngine::drv_get_target(int chid) const {
std::string_view DrivenEngine::drv_get_target(int chid) const {
return get_chid(chid)->target_;
}

View File

@@ -115,14 +115,15 @@ public:
// The channel ID. These are reused.
//
int chid() { return chid_; }
int chid() const { return chid_; }
// If this is a socket connection, the receiver's port number.
//
int port() { return port_; }
int port() const { return port_; }
// If this is an outgoing socket connection, get the target host.
const eng::string &target() { return target_; }
//
const eng::string &target() const { return target_; }
// True if the remote closed the connection, or a failure occurred.
//
@@ -133,7 +134,7 @@ public:
// If this is an empty string, there is no error. If this is set,
// then the channel is also closed.
//
eng::string error() const { return error_; }
const eng::string &error() const { return error_; }
// Set the prompt for readline mode.
//
@@ -314,11 +315,12 @@ public:
void drv_clear_new_outgoing();
// Get the target of a channel. A target is a string like
// "www.whatever.com:80". It indicates the host and port that the channel
// is supposed to be talking to. Non-socket channels and incoming channels
// have empty targets.
// "cert:whatever.com:80" or "nocert:whatever.com:80".
// The first word indicate whether or not a valid SSL certificate
// is required. The second word is the hostname. The third word is
// the port number.
//
const eng::string &drv_get_target(int chid) const;
std::string_view drv_get_target(int chid) const;
// Return true if the outgoing buffer is empty.
//

View File

@@ -4,33 +4,42 @@
static MonoClock monoclock;
namespace util {
double profiling_clock() {
namespace util
{
double profiling_clock()
{
return monoclock.get();
}
}
static void if_error_print_and_exit(const std::string &str) {
if (!str.empty()) {
std::cerr << std::endl << "error: " << str << std::endl;
static void if_error_print_and_exit(const std::string &str)
{
if (!str.empty())
{
std::cerr << std::endl
<< "error: " << str << std::endl;
exit(1);
}
}
static std::string_view read_file(const char *fn, char *buf, int bufsize, std::string &err) {
static std::string_view read_file(const char *fn, char *buf, int bufsize, std::string &err)
{
FILE *f = fopen(fn, "r");
if (f == 0) {
if (f == 0)
{
err = std::string("cannot read file") + fn;
buf[0] = 0;
return std::string_view(buf, 0);
}
int nread = fread(buf, 1, bufsize, f);
if (nread < 0) {
if (nread < 0)
{
err = std::string("cannot read file: ") + fn;
buf[0] = 0;
return std::string_view(buf, 0);
}
if (nread == bufsize) {
if (nread == bufsize)
{
err = std::string("file too large: ") + fn;
buf[0] = 0;
return std::string_view(buf, 0);
@@ -39,42 +48,66 @@ static std::string_view read_file(const char *fn, char *buf, int bufsize, std::s
return std::string_view(buf, nread);
}
struct SSL_CTX_Deleter {
void operator()(SSL_CTX *ctx) {
struct SSL_CTX_Deleter
{
void operator()(SSL_CTX *ctx)
{
SSL_CTX_free(ctx);
}
};
using UniqueSSLCTX = std::unique_ptr<SSL_CTX, SSL_CTX_Deleter>;
static UniqueSSLCTX new_ssl_context(bool server_cert, bool root_certs, std::string_view require_cert) {
SSL_CTX *ctx = SSL_CTX_new(TLS_method());
SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE);
SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, nullptr);
// server_cert is not implemented yet.
if (root_certs) {
SSL_CTX_set_default_verify_paths(ctx);
SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, NULL);
static std::string ssl_errors_string(bool lastonly = true)
{
std::string err;
const char *file, *data, *func;
int line, flags;
while (true)
{
unsigned long code = ERR_get_error_all(&file, &line, &func, &data, &flags);
if (code == 0)
break;
std::string reason;
if (ERR_SYSTEM_ERROR(code))
{
reason = strerror_str(ERR_GET_REASON(code));
}
// require_cert is not implemented yet.
return UniqueSSLCTX(ctx);
else
{
const char *rc = ERR_reason_error_string(code);
reason = (rc == nullptr) ? "unknown" : rc;
}
if (err.empty() || lastonly)
{
err = reason;
}
else
{
err = err + ", " + reason;
}
if (data != nullptr)
{
err = err + " " + data;
}
}
return err;
}
static std::string err_print_errors_str() {
BIO *bio = BIO_new(BIO_s_mem());
ERR_print_errors(bio);
char *buf;
size_t len = BIO_get_mem_data(bio, &buf);
std::string ret(buf, len);
BIO_free(bio);
return ret;
void assert_ssl_errors_empty()
{
int code = ERR_peek_error();
if (code != 0)
{
std::cerr << "SSL should not have errors at this point." << std::endl;
ERR_print_errors_fp(stderr);
exit(1);
}
}
static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str) {
static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str)
{
BIO *bio = BIO_new(BIO_s_mem());
BIO_puts(bio, str);
X509 *certificate = PEM_read_bio_X509(bio, NULL, NULL, NULL);
@@ -84,7 +117,8 @@ static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str) {
return status;
}
static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str) {
static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str)
{
BIO *bio = BIO_new(BIO_s_mem());
BIO_puts(bio, str);
EVP_PKEY *pkey = PEM_read_bio_PrivateKey(bio, NULL, NULL, NULL);
@@ -94,18 +128,33 @@ static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str) {
return status;
}
static void ssl_ctx_use_dummycert(SSL_CTX *ctx)
{
if (ssl_ctx_use_certificate_str(ctx, dummycert::certificate) <= 0)
{
ERR_print_errors_fp(stderr);
exit(1);
}
if (ssl_ctx_use_privatekey_str(ctx, dummycert::privatekey) <= 0)
{
ERR_print_errors_fp(stderr);
exit(1);
}
}
class Driver {
class Driver
{
public:
enum ChanState {
enum ChanState
{
CHAN_INACTIVE,
CHAN_PLAINTEXT,
CHAN_SSL_CONNECTING,
CHAN_SSL_ACCEPTING,
CHAN_SSL_READWRITE,
};
struct ChanInfo {
struct ChanInfo
{
int chid;
SOCKET socket;
SSL *ssl;
@@ -129,15 +178,17 @@ public:
std::unique_ptr<struct pollfd[]> pollvec_;
drv::ReplayRecorder recorder_;
UniqueSSLCTX ssl_ctx_with_root_certs_;
UniqueSSLCTX ssl_ctx_with_server_certs_;
UniqueSSLCTX ssl_ctx_with_no_certs_;
UniqueSSLCTX ssl_server_ctx_;
UniqueSSLCTX ssl_client_secure_ctx_;
UniqueSSLCTX ssl_client_insecure_ctx_;
void handle_listen_ports() {
void handle_listen_ports()
{
const auto &listenports = recorder_.drv_get_listen_ports();
for (int port : listenports) {
if (listen_sockets_.find(port) == listen_sockets_.end()) {
for (int port : listenports)
{
if (listen_sockets_.find(port) == listen_sockets_.end())
{
std::string err;
SOCKET sock = listen_on_port(port, err);
if_error_print_and_exit(err);
@@ -147,14 +198,17 @@ public:
}
}
void handle_lua_source() {
if (recorder_.drv_get_rescan_lua_source()) {
void handle_lua_source()
{
if (recorder_.drv_get_rescan_lua_source())
{
std::string err;
std::string_view ctrl = read_file("lua/control.lst", chbuf_.get(), CHBUF_SIZE, err);
if_error_print_and_exit(err);
std::vector<std::string> names = drv::parse_control_lst(ctrl);
recorder_.drv_clear_lua_source();
for (const std::string &str : names) {
for (const std::string &str : names)
{
std::string lfn = std::string("lua/") + str;
std::string_view data = read_file(lfn.c_str(), chbuf_.get(), CHBUF_SIZE, err);
if_error_print_and_exit(err);
@@ -163,11 +217,13 @@ public:
}
}
void close_channel(ChanInfo &chan, std::string_view err) {
void close_channel(ChanInfo &chan, std::string_view err)
{
// std::cerr << "Closing channel " << chan.chid << std::endl;
assert(chan.state != CHAN_INACTIVE);
// Close and release the SSL channel.
if (chan.ssl != nullptr) {
if (chan.ssl != nullptr)
{
SSL_free(chan.ssl);
chan.ssl = nullptr;
}
@@ -190,39 +246,52 @@ public:
chan.last_write_nbytes = 0;
}
void cleanup_channels() {
for (int i = 0; i < int(chans_.size()); ) {
if (chans_[i].state == CHAN_INACTIVE) {
void cleanup_channels()
{
for (int i = 0; i < int(chans_.size());)
{
if (chans_[i].state == CHAN_INACTIVE)
{
chans_[i] = chans_.back();
chans_.pop_back();
} else {
}
else
{
i += 1;
}
}
}
void handle_console_output() {
while (true) {
void handle_console_output()
{
while (true)
{
std::string_view s = recorder_.drv_peek_outgoing(0);
if (s.size() == 0) break;
if (s.size() == 0)
break;
int nwrote = console_write(s.data(), s.size());
if (nwrote <= 0) break;
if (nwrote <= 0)
break;
recorder_.drv_sent_outgoing(0, nwrote);
}
}
void handle_console_input() {
void handle_console_input()
{
char buffer[256];
read_console_recently_ = false;
while (true) {
while (true)
{
int nread = console_read(buffer, 256);
if (nread <= 0) break;
if (nread <= 0)
break;
read_console_recently_ = true;
recorder_.drv_recv_incoming(0, std::string_view(buffer, nread));
}
}
void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) {
void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state)
{
ChanInfo newchan;
newchan.chid = chid;
newchan.socket = sock;
@@ -243,53 +312,81 @@ public:
chans_.push_back(newchan);
}
void handle_new_outgoing_sockets() {
void handle_new_outgoing_sockets()
{
const auto &chans = recorder_.drv_get_new_outgoing();
for (int chid : chans) {
std::string err;
SOCKET sock = open_connection(recorder_.drv_get_target(chid), err);
if (sock == INVALID_SOCKET) {
recorder_.drv_notify_close(chid, err);
for (int chid : chans)
{
std::string err, cert, host, port;
std::string target(recorder_.drv_get_target(chid));
drv::split_target(target, cert, host, port);
if (cert.empty() || host.empty() || port.empty()) {
recorder_.drv_notify_close(chid, std::string("invalid target: ") + target);
continue;
}
SSL_CTX *ctx = nullptr;
if (cert == "cert") {
ctx = ssl_client_secure_ctx_.get();
} else if (cert == "nocert") {
ctx = ssl_client_insecure_ctx_.get();
} else {
recorder_.drv_notify_close(chid, std::string("invalid cert rule: ") + target);
continue;
}
SOCKET sock = open_connection(host.c_str(), port.c_str(), err);
if (sock == INVALID_SOCKET)
{
recorder_.drv_notify_close(chid, err);
continue;
}
// std::cerr << "Opening channel " << chid << std::endl;
make_channel(sock, chid, ssl_ctx_with_no_certs_.get(), CHAN_SSL_CONNECTING);
make_channel(sock, chid, ctx, CHAN_SSL_CONNECTING);
}
}
if (!chans.empty()) {
if (!chans.empty())
{
recorder_.drv_clear_new_outgoing();
}
}
void accept_connection(int port, SOCKET sock) {
void accept_connection(int port, SOCKET sock)
{
std::string err;
SOCKET socket = accept_on_socket(sock, err);
if_error_print_and_exit(err);
if (socket != INVALID_SOCKET) {
if (socket != INVALID_SOCKET)
{
int chid = recorder_.drv_notify_accept(port);
// std::cerr << "Accepted channel " << chid << std::endl;
make_channel(socket, chid, ssl_ctx_with_server_certs_.get(), CHAN_SSL_ACCEPTING);
make_channel(socket, chid, ssl_server_ctx_.get(), CHAN_SSL_ACCEPTING);
}
}
void advance_plaintext(ChanInfo &chan) {
void advance_plaintext(ChanInfo &chan)
{
std::string err;
// If the channel has no outgoing bytes and has been released,
// just close it.
if (chan.released) {
if (chan.released)
{
close_channel(chan, "");
return;
}
// Try to write plaintext to the channel.
std::string_view s = recorder_.drv_peek_outgoing(chan.chid);
if (s.size() > 0) {
if (s.size() > 0)
{
int sbytes = s.size();
if (sbytes > 65536) sbytes = 65536;
if (sbytes > 65536)
sbytes = 65536;
int wbytes = socket_send(chan.socket, s.data(), sbytes, err);
if (wbytes < 0) {
if (wbytes < 0)
{
close_channel(chan, err);
} else {
}
else
{
recorder_.drv_sent_outgoing(chan.chid, wbytes);
}
}
@@ -297,9 +394,12 @@ public:
// Try to read plaintext from the channel.
// Someday, find a way to avoid this copy.
int nrecv = socket_recv(chan.socket, chbuf_.get(), 65536, err);
if (nrecv < 0) {
if (nrecv < 0)
{
close_channel(chan, err);
} else {
}
else
{
recorder_.drv_recv_incoming(chan.chid, std::string_view(chbuf_.get(), nrecv));
}
@@ -308,84 +408,116 @@ public:
chan.ready_on_pollin = true;
}
void process_ssl_error(ChanInfo &chan, int retval) {
void process_ssl_error(ChanInfo &chan, int retval)
{
int error = SSL_get_error(chan.ssl, retval);
// std::cerr << "SSL error code = " << error << " ";
if (error == SSL_ERROR_WANT_READ) {
if (error == SSL_ERROR_WANT_READ)
{
chan.ready_on_pollin = true;
} else if (error == SSL_ERROR_WANT_WRITE) {
}
else if (error == SSL_ERROR_WANT_WRITE)
{
chan.ready_on_pollout = true;
} else {
close_channel(chan, err_print_errors_str());
}
else
{
close_channel(chan, ssl_errors_string());
}
}
void advance_ssl_connecting(ChanInfo &chan) {
void advance_ssl_connecting(ChanInfo &chan)
{
// std::cerr << "In advance_ssl_connecting" << std::endl;
int retval = SSL_connect(chan.ssl);
if (retval == 1) {
if (retval == 1)
{
// Connection successful.
chan.state = CHAN_SSL_READWRITE;
chan.ready_now = true;
} else {
}
else
{
// std::cerr << "ssl_connect_error";
process_ssl_error(chan, retval);
}
}
void advance_ssl_accepting(ChanInfo &chan) {
void advance_ssl_accepting(ChanInfo &chan)
{
// std::cerr << "In advance_ssl_accepting" << std::endl;
int retval = SSL_accept(chan.ssl);
if (retval == 1) {
if (retval == 1)
{
// Connection successful.
chan.state = CHAN_SSL_READWRITE;
chan.ready_now = true;
} else {
}
else
{
process_ssl_error(chan, retval);
}
}
void advance_ssl_readwrite(ChanInfo &chan) {
void advance_ssl_readwrite(ChanInfo &chan)
{
// std::cerr << "In advance_ssl_readwrite" << std::endl;
// Try to read data.
int read_result = SSL_read(chan.ssl, chbuf_.get(), 65536);
if (read_result > 0) {
if (read_result > 0)
{
recorder_.drv_recv_incoming(chan.chid, std::string_view(chbuf_.get(), read_result));
chan.ready_now = true;
} else {
}
else
{
process_ssl_error(chan, read_result);
if (chan.state == CHAN_INACTIVE) return;
if (chan.state == CHAN_INACTIVE)
return;
}
// Try to write data.
int wbytes;
if (chan.last_write_nbytes > 0) {
if (chan.last_write_nbytes > 0)
{
wbytes = chan.last_write_nbytes;
assert(wbytes < chan.nbytes);
} else {
wbytes = chan.nbytes;
if (wbytes > 65536) wbytes = 65536;
}
if (wbytes > 0) {
else
{
wbytes = chan.nbytes;
if (wbytes > 65536)
wbytes = 65536;
}
if (wbytes > 0)
{
int write_result = SSL_write(chan.ssl, chan.bytes, wbytes);
if (write_result > 0) {
if (write_result > 0)
{
recorder_.drv_sent_outgoing(chan.chid, write_result);
chan.last_write_nbytes = 0;
chan.ready_on_outgoing = true;
} else {
}
else
{
chan.last_write_nbytes = wbytes;
process_ssl_error(chan, write_result);
if (chan.state == CHAN_INACTIVE) return;
if (chan.state == CHAN_INACTIVE)
return;
}
} else {
}
else
{
chan.ready_on_outgoing = true;
}
// std::cerr << "rpi=" << chan.ready_on_pollin << ".rpo=" << chan.ready_on_pollout << ".rn=" << chan.ready_now << ".rog=" << chan.ready_on_outgoing << " ";
}
void advance_channel(ChanInfo &chan) {
switch(chan.state) {
void advance_channel(ChanInfo &chan)
{
assert_ssl_errors_empty();
switch (chan.state)
{
case CHAN_PLAINTEXT:
advance_plaintext(chan);
break;
@@ -402,20 +534,23 @@ public:
assert(false);
break;
}
assert_ssl_errors_empty();
}
void handle_socket_input_output() {
void handle_socket_input_output()
{
std::string err;
int mstimeout = read_console_recently_ ? 100 : 1000;
// Peek output buffers and determine channel release flags.
for (ChanInfo &chan : chans_) {
for (ChanInfo &chan : chans_)
{
std::string_view s = recorder_.drv_peek_outgoing(chan.chid);
chan.nbytes = s.size();
chan.bytes = s.data();
chan.just_released = false;
if ((chan.nbytes == 0)&&(!chan.released)) {
if ((chan.nbytes == 0) && (!chan.released))
{
chan.released = recorder_.drv_get_channel_released(chan.chid);
chan.just_released = chan.released;
}
@@ -423,23 +558,30 @@ public:
// Construct the struct pollfd vector.
int pollsize = 0;
for (const auto &p : listen_sockets_) {
for (const auto &p : listen_sockets_)
{
struct pollfd &pfd = pollvec_[pollsize++];
pfd.fd = p.second;
pfd.events = POLLIN;
pfd.revents = 0;
}
for (const ChanInfo &chan : chans_) {
for (const ChanInfo &chan : chans_)
{
struct pollfd &pfd = pollvec_[pollsize++];
assert(chan.socket != INVALID_SOCKET);
pfd.fd = chan.socket;
pfd.events = 0;
pfd.revents = 0;
if (chan.ready_now) mstimeout = 0;
if (chan.just_released) mstimeout = 0;
if (chan.ready_on_pollin) pfd.events |= POLLIN;
if (chan.ready_on_pollout) pfd.events |= POLLOUT;
if (chan.ready_on_outgoing && (chan.nbytes > 0)) pfd.events |= POLLOUT;
if (chan.ready_now)
mstimeout = 0;
if (chan.just_released)
mstimeout = 0;
if (chan.ready_on_pollin)
pfd.events |= POLLIN;
if (chan.ready_on_pollout)
pfd.events |= POLLOUT;
if (chan.ready_on_outgoing && (chan.nbytes > 0))
pfd.events |= POLLOUT;
// std::cerr << "evt=" << pfd.events << ".nb=" << chan.nbytes << " ";
}
@@ -449,15 +591,18 @@ public:
// Check listening sockets.
int index = 0;
for (auto &p : listen_sockets_) {
for (auto &p : listen_sockets_)
{
struct pollfd &pfd = pollvec_[index++];
if (pfd.revents & (POLLIN | POLLERR)) {
if (pfd.revents & (POLLIN | POLLERR))
{
accept_connection(p.first, p.second);
}
}
// Advance channels where possible.
for (ChanInfo &chan : chans_) {
for (ChanInfo &chan : chans_)
{
struct pollfd &pfd = pollvec_[index++];
bool pollin = ((pfd.revents & POLLIN) != 0);
bool pollout = ((pfd.revents & POLLOUT) != 0);
@@ -465,7 +610,8 @@ public:
if (chan.ready_now || pollerr || chan.just_released ||
(chan.ready_on_pollin && pollin) ||
(chan.ready_on_pollout && pollout) ||
(chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) {
(chan.ready_on_outgoing && (chan.nbytes > 0) && pollout))
{
chan.ready_now = false;
chan.ready_on_pollin = false;
chan.ready_on_pollout = false;
@@ -480,36 +626,46 @@ public:
cleanup_channels();
}
int replay_logfile(const char *fn, bool verbose) {
int replay_logfile(const char *fn, bool verbose)
{
drv::ReplayPlayer player;
player.open_logfile(fn);
if (verbose) {
if (verbose)
{
player.enable_stdout();
}
while (true) {
while (true)
{
drv::ReplayPlayer::Status st = player.step();
if (st != drv::ReplayPlayer::ST_REPLAYING) {
if (st != drv::ReplayPlayer::ST_REPLAYING)
{
player.print_status(std::cerr);
return (st == drv::ReplayPlayer::ST_CLEAN_EXIT) ? 0 : 1;
}
}
}
int drive(int argc, char *argv[]) {
int drive(int argc, char *argv[])
{
// Remove the program name from argv.
if (argc < 1) {
if (argc < 1)
{
DrivenEngine::print_usage(std::cerr, "<unknown>");
exit(1);
}
std::string program = argv[0];
argc -= 1; argv += 1;
argc -= 1;
argv += 1;
// If argv contains "replay <filename>", do a replay,
// and then skip everything else.
if (argc >= 1) {
if (argc >= 1)
{
std::string cmd(argv[0]);
if ((cmd == "replay") || (cmd == "vreplay")) {
if (argc != 2) {
if ((cmd == "replay") || (cmd == "vreplay"))
{
if (argc != 2)
{
std::cerr << "usage: " << program << " replay <filename>" << std::endl;
return 1;
}
@@ -519,29 +675,36 @@ public:
// If argv contains "record <filename>", start recording,
// and remove the "record <filename>" from argv.
if (argc >= 1) {
if (argc >= 1)
{
std::string cmd = argv[0];
if (cmd == "record") {
if (argc < 2) {
if (cmd == "record")
{
if (argc < 2)
{
DrivenEngine::print_usage(std::cerr, program);
return 1;
}
bool ok = recorder_.open_logfile(argv[1]);
if (!ok) {
if (!ok)
{
std::cerr << "Could not open logfile: " << argv[1] << std::endl;
return 1;
}
argc -= 2; argv += 2;
argc -= 2;
argv += 2;
}
}
// Create the engine.
if (argc < 1) {
if (argc < 1)
{
DrivenEngine::print_usage(std::cerr, program);
return 1;
}
bool engine_made = recorder_.create_engine(argv[0]);
if (!engine_made) {
if (!engine_made)
{
DrivenEngine::print_usage(std::cerr, program);
return 1;
}
@@ -551,25 +714,17 @@ public:
chbuf_.reset(new char[CHBUF_SIZE]);
pollvec_.reset(new struct pollfd[POLLVEC_SIZE]);
ssl_ctx_with_root_certs_ = new_ssl_context(false, true, "");
ssl_ctx_with_server_certs_ = new_ssl_context(true, false, "");
ssl_ctx_with_no_certs_ = new_ssl_context(false, false, "");
if (ssl_ctx_use_certificate_str(ssl_ctx_with_server_certs_.get(), dummycert::certificate) <= 0) {
ERR_print_errors_fp(stderr);
return 1;
}
if (ssl_ctx_use_privatekey_str(ssl_ctx_with_server_certs_.get(), dummycert::privatekey) <= 0 ) {
ERR_print_errors_fp(stderr);
return 1;
}
ssl_server_ctx_.reset(new_ssl_server_context());
ssl_client_secure_ctx_.reset(new_ssl_client_context(SSL_VERIFY_PEER));
ssl_client_insecure_ctx_.reset(new_ssl_client_context(SSL_VERIFY_NONE));
assert_ssl_errors_empty();
handle_lua_source();
recorder_.drv_invoke_event_init(argc, argv);
handle_listen_ports();
while (!recorder_.drv_get_stop_driver()) {
while (!recorder_.drv_get_stop_driver())
{
handle_lua_source();
handle_console_output();
handle_new_outgoing_sockets();
@@ -579,7 +734,8 @@ public:
recorder_.drv_invoke_event_update(monoclock.get());
}
for (ChanInfo &chan : chans_) {
for (ChanInfo &chan : chans_)
{
close_channel(chan, "");
}
@@ -588,5 +744,3 @@ public:
return 0;
}
};

View File

@@ -41,7 +41,7 @@ struct termios orig_termios;
static std::string strerror_str(int err) {
char errbuf[256];
return strerror_r(errno, errbuf, 256);
return strerror_r(err, errbuf, 256);
}
void set_nonblocking(int fd) {
@@ -69,7 +69,7 @@ static void enable_tty_raw() {
assert(status >= 0);
}
static SOCKET open_connection(std::string_view target, std::string &err) {
static SOCKET open_connection(const char *host, const char *port, std::string &err) {
struct addrinfo *addrs = nullptr;
struct addrinfo *goodaddr = nullptr;
struct addrinfo hints;
@@ -82,9 +82,7 @@ static SOCKET open_connection(std::string_view target, std::string &err) {
hints.ai_flags = AI_NUMERICSERV;
err.clear();
std::string host, port;
drv::split_host_port(target, host, port);
int status = getaddrinfo(host.c_str(), port.c_str(), &hints, &addrs);
int status = getaddrinfo(host, port, &hints, &addrs);
if (status != 0) {
err = gai_strerror(status);
goto error_general;
@@ -228,6 +226,25 @@ static int console_read(char *bytes, int nbytes) {
return read(0, bytes, nbytes);
}
static void ssl_ctx_use_dummycert(SSL_CTX *ctx);
static SSL_CTX *new_ssl_server_context() {
SSL_CTX *ctx = SSL_CTX_new(TLS_method());
SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE);
SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, nullptr);
ssl_ctx_use_dummycert(ctx);
return ctx;
}
static SSL_CTX *new_ssl_client_context(int verify) {
SSL_CTX *ctx = SSL_CTX_new(TLS_method());
SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE);
SSL_CTX_set_default_verify_paths(ctx);
SSL_CTX_set_verify(ctx, verify, nullptr);
return ctx;
}
static void disable_randomization(int argc, char *argv[]) {
const int old_personality = personality(ADDR_NO_RANDOMIZE);

View File

@@ -55,17 +55,15 @@ static PADDRINFOA find_good_addr(PADDRINFOA addrinfo) {
return nullptr;
}
static SOCKET open_connection(std::string_view target, std::string &err) {
static SOCKET open_connection(const char *host, const char *port, std::string &err) {
PADDRINFOA addrs = nullptr;
PADDRINFOA goodaddr = nullptr;
SOCKET sock = INVALID_SOCKET;
std::string_view host, port;
err.clear();
util::split_host_port(target, host, port);
int status = getaddrinfo(host.data(), port.data(), nullptr, &addrs);
int status = getaddrinfo(host, port, nullptr, &addrs);
while (status == WSATRY_AGAIN) {
status = getaddrinfo(host.data(), port.data(), nullptr, &addrs);
status = getaddrinfo(host, port, nullptr, &addrs);
}
if (status == WSAHOST_NOT_FOUND) {
err = "host not found";

View File

@@ -17,16 +17,31 @@
namespace drv {
void split_host_port(std::string_view target, std::string &host, std::string &port) {
size_t lastcolon = target.rfind(':');
if (lastcolon == std::string_view::npos) {
host = ""; port = ""; return;
std::vector<std::string_view> split_view(std::string_view v, char sep) {
std::vector<std::string_view> result;
while (true) {
size_t pos = v.find(sep);
if (pos == std::string_view::npos) break;
result.push_back(v.substr(0, pos));
v = v.substr(pos + 1);
}
host = target.substr(0, lastcolon);
port = target.substr(lastcolon + 1);
if ((host == "") || (port == "")) {
host = ""; port = ""; return;
result.push_back(v);
return result;
}
void split_target(std::string_view target, std::string &cert, std::string &host, std::string &port) {
std::vector<std::string_view> split = split_view(target, ':');
if (split.size() != 3) {
cert.clear(); host.clear(); port.clear();
return;
}
if (split[0].empty() || split[1].empty() || split[2].empty()) {
cert.clear(); host.clear(); port.clear();
return;
}
cert = std::string(split[0]);
host = std::string(split[1]);
port = std::string(split[2]);
}
std::vector<std::string> parse_control_lst(std::string_view ctrl) {
@@ -502,9 +517,10 @@ void ReplayPlayer::drv_invoke_event_update() {
} // namespace drv
LuaDefine(unittests_driverutil, "", "some unit tests") {
// Test split_host_port
std::string host, port;
drv::split_host_port("stanford.edu:80", host, port);
// Test split_target
std::string cert, host, port;
drv::split_target("cert:stanford.edu:80", cert, host, port);
LuaAssertStrEq(L, cert, "cert");
LuaAssertStrEq(L, host, "stanford.edu");
LuaAssertStrEq(L, port, "80");

View File

@@ -11,7 +11,8 @@
namespace drv {
void split_host_port(std::string_view target, std::string &host, std::string &port);
void split_target(std::string_view target, std::string &cert, std::string &host, std::string &port);
std::vector<std::string> parse_control_lst(std::string_view ctrl);
@@ -58,7 +59,7 @@ public:
//
const eng::vector<int> &drv_get_listen_ports() const { return e_->drv_get_listen_ports(); }
const eng::vector<int> &drv_get_new_outgoing() const { return e_->drv_get_new_outgoing(); }
const eng::string &drv_get_target(int chid) const { return e_->drv_get_target(chid); }
std::string_view drv_get_target(int chid) const { return e_->drv_get_target(chid); }
bool drv_outgoing_empty(int chid) const { return e_->drv_outgoing_empty(chid); }
bool drv_get_channel_released(int chid) const { return e_->drv_get_channel_released(chid); }
std::string_view drv_peek_outgoing(int chid) const { return e_->drv_peek_outgoing(chid); }

View File

@@ -74,7 +74,7 @@ public:
set_initial_state();
// Establish a connection to the server.
channel_ = new_outgoing_channel("localhost:8085");
channel_ = new_outgoing_channel("cert:localhost:8085");
// Set the console prompt
get_stdio_channel()->set_prompt(console_.get_prompt());
@@ -262,7 +262,7 @@ public:
// Check for communication from server..
if (channel_ != nullptr) {
if (channel_->closed()) {
stdostream() << "Server closed connection " << channel_->error() << std::endl;
stdostream() << "server closed connection: " << channel_->error() << std::endl;
abandon_server();
} else {
while (true) {