Change directory structure

This commit is contained in:
2023-02-14 14:05:45 -05:00
parent acad4291b6
commit def6387ca3
323 changed files with 161 additions and 19581 deletions

View File

@@ -0,0 +1,488 @@
#define POLLVEC_SIZE (DRV_MAX_CHAN + 1)
static void if_error_print_and_exit(const std::string_view str) {
if (!str.empty()) {
std::cerr << std::endl << "error: " << str << std::endl;
exit(1);
}
}
class Driver {
public:
enum ChanState {
CHAN_INACTIVE,
CHAN_PLAINTEXT,
CHAN_SSL_CONNECTING,
CHAN_SSL_ACCEPTING,
CHAN_SSL_READWRITE,
};
struct ChanInfo {
int chid;
SOCKET socket;
SSL *ssl;
ChanState state;
uint32_t nbytes;
const char *bytes;
bool ready_now;
bool ready_on_pollin;
bool ready_on_pollout;
bool ready_on_outgoing;
uint32_t last_write_nbytes;
bool marked_for_deletion() const { return state == CHAN_INACTIVE; }
};
EngineWrapper engw;
std::vector<ChanInfo> chans_;
std::map<int, SOCKET> listen_sockets_;
bool read_console_recently_;
std::unique_ptr<struct pollfd[]> pollvec_;
std::unique_ptr<char[]> chbuf_;
sslutil::UniqueCTX ssl_server_ctx_;
sslutil::UniqueCTX ssl_client_secure_ctx_;
sslutil::UniqueCTX ssl_client_insecure_ctx_;
void handle_listen_ports() {
uint32_t nports; const uint32_t *ports;
engw.get_listen_ports(&engw, &nports, &ports);
for (uint32_t i = 0; i < nports; i++) {
int port = ports[i];
if (listen_sockets_.find(port) == listen_sockets_.end()) {
std::string err;
SOCKET sock = listen_on_port(port, err);
if_error_print_and_exit(err);
assert(sock != INVALID_SOCKET);
listen_sockets_[port] = sock;
}
}
}
void handle_lua_source() {
if (engw.get_rescan_lua_source(&engw)) {
drvutil::ostringstream oss;
std::string err = drvutil::package_lua_source(".", &oss);
if_error_print_and_exit(err);
engw.play_set_lua_source(&engw, oss.size(), oss.c_str());
}
}
void close_channel(ChanInfo &chan, std::string_view err) {
// std::cerr << "Closing channel " << chan.chid << std::endl;
assert(chan.state != CHAN_INACTIVE);
// Close and release the SSL channel.
if (chan.ssl != nullptr) {
SSL_free(chan.ssl);
chan.ssl = nullptr;
}
// Close and release the socket.
assert(chan.socket != INVALID_SOCKET);
assert(socket_close(chan.socket) == 0);
chan.socket = INVALID_SOCKET;
// Close everything else.
engw.play_notify_close(&engw, chan.chid, err.size(), err.data());
chan.state = CHAN_INACTIVE;
chan.chid = -1;
chan.nbytes = 0;
chan.bytes = 0;
chan.ready_now = false;
chan.ready_on_pollin = false;
chan.ready_on_pollout = false;
chan.ready_on_outgoing = false;
chan.last_write_nbytes = 0;
}
void handle_console_output() {
while (true) {
uint32_t ndata; const char *data;
engw.get_outgoing(&engw, 0, &ndata, &data);
if (ndata == 0) break;
if (ndata > DRV_SHORTSTRING_SIZE) ndata = DRV_SHORTSTRING_SIZE;
int nwrote = console_write(data, ndata);
if (nwrote <= 0) break;
engw.play_sent_outgoing(&engw, 0, nwrote);
}
}
void handle_console_input() {
char buffer[256];
read_console_recently_ = false;
while (true) {
int nread = console_read(buffer, 256);
if (nread <= 0) break;
read_console_recently_ = true;
engw.play_recv_incoming(&engw, 0, nread, buffer);
}
}
void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) {
ChanInfo newchan;
newchan.chid = chid;
newchan.socket = sock;
newchan.ssl = SSL_new(ctx);
newchan.state = state;
newchan.nbytes = 0;
newchan.bytes = 0;
newchan.ready_now = false;
newchan.ready_on_pollin = false;
newchan.ready_on_pollout = true;
newchan.ready_on_outgoing = false;
newchan.last_write_nbytes = 0;
SSL_set_fd(newchan.ssl, newchan.socket);
// SSL_set_msg_callback(newchan.ssl, SSL_trace);
// SSL_set_msg_callback_arg(newchan.ssl, BIO_new_fp(stderr,0));
chans_.push_back(newchan);
}
void handle_new_outgoing_sockets() {
uint32_t nchids; const uint32_t *chids;
engw.get_new_outgoing(&engw, &nchids, &chids);
for (uint32_t i = 0; i < nchids; i++) {
uint32_t chid = chids[i];
std::string err, cert, host, port;
const char *target = engw.get_target(&engw, chid);
drvutil::split_target(target, cert, host, port);
if (cert.empty() || host.empty() || port.empty()) {
std::string message = "invalid target: ";
message += target;
engw.play_notify_close(&engw, chid, message.size(), message.c_str());
continue;
}
SSL_CTX *ctx = nullptr;
if (cert == "cert") {
ctx = ssl_client_secure_ctx_.get();
} else if (cert == "nocert") {
ctx = ssl_client_insecure_ctx_.get();
} else {
std::string message = "invalid cert rule: ";
message += target;
engw.play_notify_close(&engw, chid, message.size(), message.c_str());
continue;
}
SOCKET sock = open_connection(host.c_str(), port.c_str(), err);
if (sock == INVALID_SOCKET) {
engw.play_notify_close(&engw, chid, err.size(), err.c_str());
continue;
}
// std::cerr << "Opening channel " << chid << std::endl;
make_channel(sock, chid, ctx, CHAN_SSL_CONNECTING);
}
engw.play_clear_new_outgoing(&engw);
}
void accept_connection(int port, SOCKET sock) {
std::string err;
SOCKET socket = accept_on_socket(sock, err);
if_error_print_and_exit(err);
if (socket != INVALID_SOCKET) {
uint32_t chid = engw.play_notify_accept(&engw, port);
// std::cerr << "Accepted channel " << chid << std::endl;
make_channel(socket, chid, ssl_server_ctx_.get(), CHAN_SSL_ACCEPTING);
}
}
void advance_plaintext(ChanInfo &chan) {
std::string err;
// Try to write plaintext to the channel.
uint32_t ndata; const char *data;
engw.get_outgoing(&engw, chan.chid, &ndata, &data);
if (ndata > 0) {
int sbytes = ndata;
if (sbytes > DRV_SHORTSTRING_SIZE) sbytes = DRV_SHORTSTRING_SIZE;
int wbytes = socket_send(chan.socket, data, sbytes, err);
if (wbytes < 0) {
close_channel(chan, err.c_str());
} else {
engw.play_sent_outgoing(&engw, chan.chid, wbytes);
}
}
// Try to read plaintext from the channel.
// Someday, find a way to avoid this copy.
int nrecv = socket_recv(chan.socket, chbuf_.get(), DRV_SHORTSTRING_SIZE, err);
if (nrecv < 0) {
close_channel(chan, err.c_str());
} else {
engw.play_recv_incoming(&engw, chan.chid, nrecv, chbuf_.get());
}
// Update the ready-flags for next time.
chan.ready_on_outgoing = true;
chan.ready_on_pollin = true;
}
void process_ssl_error(ChanInfo &chan, int retval) {
int error = SSL_get_error(chan.ssl, retval);
// std::cerr << "SSL error code = " << error << " ";
if (error == SSL_ERROR_WANT_READ) {
chan.ready_on_pollin = true;
} else if (error == SSL_ERROR_WANT_WRITE) {
chan.ready_on_pollout = true;
} else {
std::string error = sslutil::error_string();
if (error == "") error = "unknown error";
close_channel(chan, error);
}
}
void advance_ssl_connecting(ChanInfo &chan) {
// std::cerr << "In advance_ssl_connecting" << std::endl;
int retval = SSL_connect(chan.ssl);
if (retval == 1) {
// Connection successful.
chan.state = CHAN_SSL_READWRITE;
chan.ready_now = true;
} else {
// std::cerr << "ssl_connect_error";
process_ssl_error(chan, retval);
}
}
void advance_ssl_accepting(ChanInfo &chan) {
// std::cerr << "In advance_ssl_accepting" << std::endl;
int retval = SSL_accept(chan.ssl);
if (retval == 1) {
// Connection successful.
chan.state = CHAN_SSL_READWRITE;
chan.ready_now = true;
} else {
process_ssl_error(chan, retval);
}
}
void advance_ssl_readwrite(ChanInfo &chan) {
// std::cerr << "In advance_ssl_readwrite" << std::endl;
// Try to read data.
int read_result = SSL_read(chan.ssl, chbuf_.get(), DRV_SHORTSTRING_SIZE);
if (read_result > 0) {
engw.play_recv_incoming(&engw, chan.chid, read_result, chbuf_.get());
chan.ready_now = true;
} else {
process_ssl_error(chan, read_result);
if (chan.state == CHAN_INACTIVE) return;
}
// Try to write data.
uint32_t wbytes;
if (chan.last_write_nbytes > 0) {
wbytes = chan.last_write_nbytes;
assert(wbytes < chan.nbytes);
} else {
wbytes = chan.nbytes;
if (wbytes > 65536) wbytes = 65536;
}
if (wbytes > 0) {
int write_result = SSL_write(chan.ssl, chan.bytes, wbytes);
if (write_result > 0) {
engw.play_sent_outgoing(&engw, chan.chid, write_result);
chan.last_write_nbytes = 0;
chan.ready_on_outgoing = true;
} else {
chan.last_write_nbytes = wbytes;
process_ssl_error(chan, write_result);
if (chan.state == CHAN_INACTIVE) return;
}
} else {
chan.ready_on_outgoing = true;
}
// std::cerr << "rpi=" << chan.ready_on_pollin << ".rpo=" <<
// chan.ready_on_pollout << ".rn=" << chan.ready_now << ".rog=" <<
// chan.ready_on_outgoing << " ";
}
void advance_channel(ChanInfo &chan) {
sslutil::clear_all_errors();
switch (chan.state) {
case CHAN_PLAINTEXT:
advance_plaintext(chan);
break;
case CHAN_SSL_CONNECTING:
advance_ssl_connecting(chan);
break;
case CHAN_SSL_ACCEPTING:
advance_ssl_accepting(chan);
break;
case CHAN_SSL_READWRITE:
advance_ssl_readwrite(chan);
break;
default:
assert(false);
break;
}
}
void handle_socket_input_output() {
std::string err;
int mstimeout = read_console_recently_ ? 100 : 1000;
// Peek output buffers and determine channel release flags.
bool any_released = false;
for (ChanInfo &chan : chans_) {
engw.get_outgoing(&engw, chan.chid, &chan.nbytes, &chan.bytes);
if (chan.nbytes == 0) {
if (engw.get_channel_released(&engw, chan.chid)) {
close_channel(chan, "");
any_released = true;
}
}
}
// Delete any released channels
if (any_released) {
drvutil::remove_marked_items(chans_);
}
// Construct the struct pollfd vector.
int pollsize = 0;
for (const auto &p : listen_sockets_) {
struct pollfd &pfd = pollvec_[pollsize++];
pfd.fd = p.second;
pfd.events = POLLIN;
pfd.revents = 0;
}
for (const ChanInfo &chan : chans_) {
struct pollfd &pfd = pollvec_[pollsize++];
assert(chan.socket != INVALID_SOCKET);
pfd.fd = chan.socket;
pfd.events = 0;
pfd.revents = 0;
if (chan.ready_now) mstimeout = 0;
if (chan.ready_on_pollin) pfd.events |= POLLIN;
if (chan.ready_on_pollout) pfd.events |= POLLOUT;
if (chan.ready_on_outgoing && (chan.nbytes > 0))
pfd.events |= POLLOUT;
// std::cerr << "evt=" << pfd.events << ".nb=" << chan.nbytes <<
// std::endl;
}
// Do the poll.
socket_poll(pollvec_.get(), pollsize, mstimeout, err);
if_error_print_and_exit(err);
// Check listening sockets.
int index = 0;
for (auto &p : listen_sockets_) {
struct pollfd &pfd = pollvec_[index++];
if (pfd.revents & (POLLIN | POLLERR)) {
accept_connection(p.first, p.second);
}
}
// Advance channels where possible.
for (ChanInfo &chan : chans_) {
struct pollfd &pfd = pollvec_[index++];
bool pollin = ((pfd.revents & POLLIN) != 0);
bool pollout = ((pfd.revents & POLLOUT) != 0);
bool pollerr = ((pfd.revents & (POLLERR | POLLHUP)) != 0);
if (chan.ready_now || pollerr ||
(chan.ready_on_pollin && pollin) ||
(chan.ready_on_pollout && pollout) ||
(chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) {
chan.ready_now = false;
chan.ready_on_pollin = false;
chan.ready_on_pollout = false;
chan.ready_on_outgoing = false;
advance_channel(chan);
}
chan.nbytes = 0;
chan.bytes = 0;
}
// Delete any newly-inactive channels
drvutil::remove_marked_items(chans_);
}
int replay_logfile(const char *fn, bool verbose) {
engw.replay_initialize(&engw, fn);
if_error_print_and_exit(engw.error);
while (engw.rlog) {
engw.replay_step(&engw);
}
if_error_print_and_exit(engw.error);
return 0;
}
int drive(int argc, char *argv[]) {
// Remove the program name from argv.
std::string program = argv[0];
argc -= 1;
argv += 1;
// Load the DLL and gain access to its functions.
call_init_engine_wrapper(&engw);
// If argv contains "replay <filename>", do a replay,
// and then skip everything else.
if (argc >= 1) {
std::string cmd(argv[0]);
if ((cmd == "replay") || (cmd == "vreplay")) {
if (argc != 2) {
std::cerr << "usage: " << program << " replay <filename>"
<< std::endl;
return 1;
}
return replay_logfile(argv[1], cmd == "vreplay");
}
}
// If argv contains "record <filename>", start recording,
// and remove the "record <filename>" from argv.
std::string replaylogfn;
if (argc >= 1) {
std::string cmd = argv[0];
if (cmd == "record") {
if (argc < 2) {
std::cerr << "The 'record' command must be followed by a filename" << std::endl;
return 1;
}
replaylogfn = argv[1];
argc -= 2;
argv += 2;
}
}
// Initialize state variables.
read_console_recently_ = false;
chbuf_.reset(new char[DRV_SHORTSTRING_SIZE]);
pollvec_.reset(new struct pollfd[POLLVEC_SIZE]);
ssl_server_ctx_.reset(sslutil::new_context(SSL_VERIFY_NONE));
ssl_client_secure_ctx_.reset(sslutil::new_context(SSL_VERIFY_PEER));
ssl_client_insecure_ctx_.reset(sslutil::new_context(SSL_VERIFY_NONE));
ssl_load_certificate_authorities(ssl_client_secure_ctx_.get());
sslutil::ctx_load_dummy_cert(ssl_server_ctx_.get());
// Read the initial lua source code.
drvutil::ostringstream srcpak;
std::string srcpakerr = drvutil::package_lua_source(".", &srcpak);
if_error_print_and_exit(srcpakerr);
// Initialize the engine.
engw.play_initialize(&engw, argc, argv, srcpak.size(), srcpak.c_str(), replaylogfn.c_str());
if_error_print_and_exit(engw.error);
// Set up listening ports.
handle_listen_ports();
// Main loop.
while (!engw.get_stop_driver(&engw)) {
handle_lua_source();
handle_console_output();
handle_new_outgoing_sockets();
handle_socket_input_output();
handle_console_input();
handle_console_output();
engw.play_invoke_event_update(&engw, drvutil::get_monotonic_clock());
}
// Cleanup
engw.release(&engw);
for (ChanInfo &chan : chans_) {
close_channel(chan, "");
}
return 0;
}
};

View File

@@ -0,0 +1,260 @@
#include "drvutil.hpp"
#include "sslutil.hpp"
#include "../core/enginewrapper.hpp"
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cassert>
#include <map>
#include <vector>
#include <string>
#include <poll.h>
#include <sys/time.h>
#include <fcntl.h>
#include <termios.h>
#include <unistd.h>
#include <sys/select.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <sys/types.h>
#include <sys/personality.h>
#include <netdb.h>
#include <malloc.h>
#include <dlfcn.h>
using SOCKET=int;
const int INVALID_SOCKET = -1;
struct termios orig_termios;
void set_nonblocking(int fd) {
int flags = fcntl(fd, F_GETFL, 0);
assert(flags != -1);
int status = fcntl(fd, F_SETFL, flags | O_NONBLOCK);
assert(status != -1);
}
static void disable_tty_raw() {
tcsetattr(0, TCSAFLUSH, &orig_termios);
}
static void enable_tty_raw() {
int status = tcgetattr(0, &orig_termios);
assert(status >= 0);
atexit(disable_tty_raw);
struct termios raw = orig_termios;
raw.c_iflag &= ~(BRKINT | ICRNL | INPCK | ISTRIP | IXON);
raw.c_lflag &= ~(ECHO | ICANON);
raw.c_oflag |= OPOST;
raw.c_cc[VMIN] = 0;
raw.c_cc[VTIME] = 0;
status = tcsetattr(0, TCSAFLUSH, &raw);
assert(status >= 0);
}
static SOCKET open_connection(const char *host, const char *port, std::string &err) {
struct addrinfo *addrs = nullptr;
struct addrinfo *goodaddr = nullptr;
struct addrinfo hints;
SOCKET sock = INVALID_SOCKET;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_INET;
hints.ai_socktype = SOCK_STREAM;
hints.ai_protocol = IPPROTO_TCP;
hints.ai_flags = AI_NUMERICSERV;
err.clear();
int status = getaddrinfo(host, port, &hints, &addrs);
if (status != 0) {
err = gai_strerror(status);
goto error_general;
}
if (addrs == nullptr) {
err = "no such host found";
goto error_general;
}
goodaddr = addrs;
assert(goodaddr->ai_family == AF_INET);
assert(goodaddr->ai_socktype == SOCK_STREAM);
assert(goodaddr->ai_protocol == IPPROTO_TCP);
sock = socket(goodaddr->ai_family, goodaddr->ai_socktype, goodaddr->ai_protocol);
if (sock <= 0) goto error_errno;
set_nonblocking(sock);
status = connect(sock, goodaddr->ai_addr, goodaddr->ai_addrlen);
if ((status != 0) && (errno != EINPROGRESS)) goto error_errno;
freeaddrinfo(addrs);
return sock;
error_errno:
err = drvutil::strerror_str(errno);
error_general:
if (sock != INVALID_SOCKET) close(sock);
if (addrs != nullptr) freeaddrinfo(addrs);
return INVALID_SOCKET;
}
static SOCKET listen_on_port(int port, std::string &err) {
int status, enable;
err.clear();
SOCKET sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock <= 0) goto error_errno;
enable = 1;
status = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (status != 0) goto error_errno;
struct sockaddr_in server;
server.sin_family = AF_INET;
server.sin_addr.s_addr = INADDR_ANY;
server.sin_port = htons(port);
status = bind(sock, (struct sockaddr *)&server, sizeof(server));
if (status != 0) goto error_errno;
status = listen(sock, 10);
if (status != 0) goto error_errno;
set_nonblocking(sock);
return sock;
error_errno:
err = drvutil::strerror_str(errno);
if (sock >= 0) close(sock);
return INVALID_SOCKET;
}
static SOCKET accept_on_socket(SOCKET listen_socket, std::string &err) {
err.clear();
SOCKET chsock = accept(listen_socket, nullptr, nullptr);
if (chsock >= 0) {
set_nonblocking(chsock);
return chsock;
} else {
if ((errno != EAGAIN) && (errno != EWOULDBLOCK) && (errno != ECONNABORTED)) {
err = drvutil::strerror_str(errno);
}
return INVALID_SOCKET;
}
}
// the return values for socket_send and socket_recv are:
//
// positive: sent or received bytes successfully
// zero: would block
// negative: channel closed, possibly cleanly or possibly with error
//
static int socket_send(SOCKET socket, const char *bytes, int nbytes, std::string &err) {
err.clear();
int wbytes = send(socket, bytes, nbytes, 0);
if (wbytes < 0) {
if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) {
return 0;
} else {
err = drvutil::strerror_str(errno);
return -1;
}
} else {
return wbytes;
}
}
static int socket_recv(SOCKET socket, char *bytes, int nbytes, std::string &err) {
err.clear();
int nrecv = recv(socket, bytes, nbytes, 0);
if (nrecv < 0) {
if ((errno == EWOULDBLOCK) || (errno == EAGAIN)) {
err = drvutil::strerror_str(errno);
return -1;
} else {
return 0;
}
} else if (nrecv == 0) {
return -1;
} else {
return nrecv;
}
}
static int socket_close(SOCKET socket) {
return close(socket);
}
static int socket_poll(struct pollfd *pollvec, int pollcount, int mstimeout, std::string &err) {
// socket_poll is implicitly expected to also poll stdin,
// if the OS allows that. Linux does, so we add stdin to the
// poll vector. The poll vector is required to have at
// least one free space in order to do this.
pollvec[pollcount].fd = 0;
pollvec[pollcount].events = POLLIN;
pollcount += 1;
// Do the poll.
int status = poll(pollvec, pollcount, mstimeout);
if (status < 0) {
err = drvutil::strerror_str(errno);
return -1;
}
return 0;
}
static int console_write(const char *bytes, int nbytes) {
return write(1, bytes, nbytes);
}
static int console_read(char *bytes, int nbytes) {
return read(0, bytes, nbytes);
}
// Load the DLL if it's not already loaded. Stores
// the handle in a global variable.
static void load_engine_dll() {
// Not actually implemented yet. Currently, the engine
// is linked right into the executable.
}
static void call_init_engine_wrapper(EngineWrapper *w) {
load_engine_dll();
using InitFn = void (*)(EngineWrapper *);
InitFn initfn = (InitFn)dlsym(RTLD_DEFAULT, "init_engine_wrapper");
assert(initfn != nullptr);
initfn(w);
}
static void ssl_load_certificate_authorities(SSL_CTX *ctx) {
assert(SSL_CTX_set_default_verify_paths(ctx) == 1);
}
static void disable_randomization(int argc, char *argv[]) {
const int old_personality = personality(ADDR_NO_RANDOMIZE);
if (!(old_personality & ADDR_NO_RANDOMIZE)) {
const int new_personality = personality(ADDR_NO_RANDOMIZE);
if (new_personality & ADDR_NO_RANDOMIZE) {
execv(argv[0], argv);
}
}
}
#include "driver-common.cpp"
int main(int argc, char **argv)
{
disable_randomization(argc, argv);
enable_tty_raw();
assert(OPENSSL_init_ssl(0, NULL) == 1);
sslutil::clear_all_errors();
Driver driver;
return driver.drive(argc, argv);
}

View File

@@ -0,0 +1,279 @@
#define WINVER 0x0600
#define _WIN32_WINNT 0x0600
#include "drvutil.hpp"
#include "sslutil.hpp"
#include "../cpp/enginewrapper.hpp"
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cassert>
#include <filesystem>
#include <winsock2.h>
#include <ws2tcpip.h>
#include <synchapi.h>
#include <sysinfoapi.h>
#include <windows.h>
#include <openssl/ssl.h>
#include <openssl/rsa.h>
#include <openssl/x509.h>
#include <openssl/evp.h>
#include <openssl/err.h>
#include <openssl/bio.h>
#include <openssl/pem.h>
// OpenSSL requires plain ascii pathnames. Returns empty string
// if the path cannot be converted to plain ascii.
std::string path_to_plain_ascii(const std::filesystem::path &path) {
std::wstring s = path.native();
for (wchar_t c : s) {
if ((c < 1) || (c > 127)) return "";
}
std::ostringstream oss;
for (wchar_t c : s) {
oss << ((char)c);
}
return oss.str();
}
static void set_nonblocking(SOCKET sock) {
u_long mode = 1; // 1 to enable non-blocking socket
int status = ioctlsocket(sock, FIONBIO, &mode);
assert(status == 0);
}
static PADDRINFOA find_good_addr(PADDRINFOA addrinfo) {
for (PADDRINFOA addr = addrinfo; addr != nullptr; addr = addr->ai_next) {
if (addr->ai_family == AF_INET) {
return addr;
}
}
return nullptr;
}
static SOCKET open_connection(const char *host, const char *port, std::string &err) {
PADDRINFOA addrs = nullptr;
PADDRINFOA goodaddr = nullptr;
SOCKET sock = INVALID_SOCKET;
err.clear();
int status = getaddrinfo(host, port, nullptr, &addrs);
while (status == WSATRY_AGAIN) {
status = getaddrinfo(host, port, nullptr, &addrs);
}
if (status == WSAHOST_NOT_FOUND) {
err = "host not found";
goto error;
}
if (status != 0) {
err = "DNS resolution malfunction";
goto error;
}
goodaddr = find_good_addr(addrs);
if (goodaddr == nullptr) {
err = "host not an internet host";
goto error;
}
sock = socket(goodaddr->ai_family, SOCK_STREAM, IPPROTO_TCP);
if (sock == INVALID_SOCKET) {
err = "could not create a socket";
goto error;
}
set_nonblocking(sock);
status = connect(sock, goodaddr->ai_addr, goodaddr->ai_addrlen);
if (status != 0) {
int errcode = WSAGetLastError();
if (errcode != WSAEWOULDBLOCK) {
err = "connect failure";
goto error;
}
}
freeaddrinfo(addrs);
return sock;
error:
if (sock != INVALID_SOCKET) closesocket(sock);
if (addrs != nullptr) freeaddrinfo(addrs);
return SOCKET_ERROR;
}
SOCKET listen_on_port(int port, std::string &err) {
int status;
err.clear();
SOCKET sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock == INVALID_SOCKET) {
err = "could not create a socket";
goto error;
}
struct sockaddr_in server;
server.sin_family = AF_INET;
server.sin_addr.s_addr = INADDR_ANY;
server.sin_port = htons(port);
status = bind(sock, (struct sockaddr *)&server, sizeof(server));
if (status < 0) {
err = "could not bind port";
goto error;
}
status = listen(sock, 10);
if (status < 0) {
err = "could not listen on socket";
goto error;
}
set_nonblocking(sock);
std::cerr << "listening socket is " << sock << std::endl;
return sock;
error:
if (sock != INVALID_SOCKET) closesocket(sock);
return SOCKET_ERROR;
}
static SOCKET accept_on_socket(SOCKET listen_socket, std::string &err) {
SOCKET chsock = accept(listen_socket, nullptr, nullptr);
if (chsock != INVALID_SOCKET) {
set_nonblocking(chsock);
return chsock;
} else {
int errcode = WSAGetLastError();
if ((errcode == WSAEWOULDBLOCK) || (errcode == WSAECONNRESET)) {
return INVALID_SOCKET;
} else {
err = "accept failed";
return INVALID_SOCKET;
}
}
}
static int socket_send(SOCKET socket, const char *bytes, int nbytes, std::string &err) {
err.clear();
int wbytes = send(socket, bytes, nbytes, 0);
if (wbytes == SOCKET_ERROR) {
int errcode = WSAGetLastError();
if (errcode == WSAEWOULDBLOCK) {
return 0;
} else {
err = "send failure";
return -1;
}
} else {
assert(wbytes > 0);
return wbytes;
}
}
static int socket_recv(SOCKET socket, char *bytes, int nbytes, std::string &err) {
err.clear();
int nrecv = recv(socket, bytes, nbytes, 0);
if (nrecv < 0) {
int errcode = WSAGetLastError();
if (errcode == WSAEWOULDBLOCK) {
return 0;
} else {
err = "recv failure";
return -1;
}
} else if (nrecv == 0) {
return -1;
} else {
return nrecv;
}
}
static int socket_close(SOCKET socket) {
return closesocket(socket);
}
static int socket_poll(struct pollfd *pollvec, int pollcount, int mstimeout, std::string &err) {
if (pollcount == 0) {
if (mstimeout > 0) Sleep(mstimeout);
return 0;
}
int status = WSAPoll(pollvec, pollcount, mstimeout);
if (status < 0) {
err = strerror_str(WSAGetLastError());
return -1;
}
return status;
}
static void init_winsock() {
WSADATA data;
int errcode = WSAStartup(2, &data);
if (errcode != 0) {
fprintf(stderr, "Winsock didn't initalize, error %d", errcode);
exit(1);
}
}
static int console_write(const char *bytes, int nbytes) {
if (nbytes == 0) return 0;
HANDLE hstdout = GetStdHandle(STD_OUTPUT_HANDLE);
assert(hstdout != INVALID_HANDLE_VALUE);
DWORD nwrote;
if (nbytes > 10000) nbytes = 10000;
assert(WriteConsoleA(hstdout, bytes, nbytes, &nwrote, nullptr));
assert(nwrote > 0);
return nwrote;
}
static int console_read(char *bytes, int nbytes) {
HANDLE hstdin = GetStdHandle(STD_INPUT_HANDLE);
assert(hstdin != INVALID_HANDLE_VALUE);
INPUT_RECORD inrecords[512];
DWORD nread, nevents;
int nascii = 0;
if (GetNumberOfConsoleInputEvents(hstdin, &nevents)) {
if (int(nevents) > nbytes) nevents = nbytes;
ReadConsoleInputA(hstdin, inrecords, nevents, &nread);
for (int i = 0; i < int(nread); i++) {
const INPUT_RECORD &inr = inrecords[i];
if (inr.EventType != KEY_EVENT) continue;
const KEY_EVENT_RECORD &key = inr.Event.KeyEvent;
if (!key.bKeyDown) continue;
char c = key.uChar.AsciiChar;
bytes[nascii++] = c;
}
return nascii;
} else {
return 0;
}
}
static void ssl_load_certificate_authorities(SSL_CTX *ctx) {
HCERTSTORE hStore = CertOpenSystemStoreW(0, L"ROOT");
PCCERT_CONTEXT pContext = NULL;
X509 *x509;
X509_STORE *store = SSL_CTX_get_cert_store(ctx);
if (!hStore) {
fprintf(stderr, "Cannot open system certificate store.\n");
exit(1);
}
while ((pContext = CertEnumCertificatesInStore(hStore, pContext))) {
const unsigned char *encoded_cert = pContext->pbCertEncoded;
x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded);
if (x509) {
X509_STORE_add_cert(store, x509);
X509_free(x509);
}
}
CertCloseStore(hStore, 0);
}
#include "driver-common.cpp"
int main(int argc, char **argv)
{
init_winsock();
OPENSSL_init_ssl(0, NULL);
SourceDB::register_lua_builtins();
Driver driver;
return driver.drive(argc, argv);
}

269
luprex/cpp/drv/drvutil.cpp Normal file
View File

@@ -0,0 +1,269 @@
#include "drvutil.hpp"
#include <string_view>
#include <vector>
#include <cassert>
#include <sstream>
#include <fstream>
#include <string.h>
#include <iostream>
namespace drvutil {
inline static bool ascii_isspace(char c) {
return (c==' ')||(c=='\t')||(c=='\r')||(c=='\n')||(c=='\f')||(c=='\v');
}
std::string_view trim(std::string_view v) {
while ((!v.empty()) && (ascii_isspace(v.front()))) {
v.remove_prefix(1);
}
while ((!v.empty()) && (ascii_isspace(v.back()))) {
v.remove_suffix(1);
}
return v;
}
static std::string_view read_to_line(std::string_view &source) {
size_t pos = source.find('\n');
std::string_view result;
if (pos == std::string_view::npos) {
result = source;
source = std::string_view();
} else {
result = source.substr(0, pos);
source = source.substr(pos + 1);
}
if ((!result.empty()) && (result.back() == '\r')) {
result.remove_suffix(1);
}
return result;
}
std::vector<std::string_view> split_view(std::string_view v, char sep) {
std::vector<std::string_view> result;
while (true) {
size_t pos = v.find(sep);
if (pos == std::string_view::npos) break;
result.push_back(v.substr(0, pos));
v = v.substr(pos + 1);
}
result.push_back(v);
return result;
}
void split_target(std::string_view target, std::string &cert, std::string &host, std::string &port) {
std::vector<std::string_view> split = split_view(target, ':');
if (split.size() != 3) {
cert.clear(); host.clear(); port.clear();
return;
}
if (split[0].empty() || split[1].empty() || split[2].empty()) {
cert.clear(); host.clear(); port.clear();
return;
}
cert = std::string(split[0]);
host = std::string(split[1]);
port = std::string(split[2]);
}
static std::vector<std::string> parse_control_lst(std::string_view ctrl) {
std::vector<std::string> result;
while (!ctrl.empty()) {
std::string_view line = read_to_line(ctrl);
std::string_view trimmed = trim(line);
if ((trimmed.size() > 0) && (trimmed[0] != '#')) {
result.emplace_back(trimmed);
}
}
return result;
}
// Read a source file into a string.
//
static std::string read_file(const char *fn, std::string &err) {
std::ifstream t(fn);
if (t.fail()) {
err = std::string("Could not open ") + fn;
return "";
}
t.seekg(0, std::ios::end);
size_t size = t.tellg();
std::string result(size, ' ');
t.seekg(0);
t.read(&result[0], size);
if ((t.fail()) || (size_t(t.tellg()) != size)) {
err = std::string("Could not read ") + fn;
return "";
}
err = "";
return result;
}
// This encoding can be read by StreamBuffer::read_uint32.
//
static void sbwrite_uint32(std::ostream *s, uint32_t v) {
s->write((const char *)&v, 4);
}
// This encoding can be read by StreamBuffer::read_uint64.
//
static void sbwrite_uint64(std::ostream *s, uint64_t v) {
s->write((const char *)&v, 8);
}
// This encoding can be read by StreamBuffer::read_string.
//
static void sbwrite_string(std::ostream *s, std::string_view sv) {
s->put(0xFF);
sbwrite_uint64(s, sv.size());
s->write(sv.data(), sv.size());
}
// This encoding can be read by StreamBuffer::read_string.
//
static bool sbwrite_file(std::ostream *s, const char *fn) {
s->put(0xFF);
uint64_t pos1 = s->tellp();
sbwrite_uint64(s, 0);
uint64_t pos2 = s->tellp();
std::ifstream t(fn);
if (t.fail()) {
return false;
}
*s << t.rdbuf();
if (t.fail()) {
return false;
}
uint64_t pos3 = s->tellp();
s->seekp(pos1);
sbwrite_uint64(s, pos3 - pos2);
s->seekp(pos3);
return true;
}
std::string package_lua_source(const std::string &base, std::ostream *s) {
std::string err;
std::string cfn = base + "/lua/control.lst";
std::string ctrl = read_file(cfn.c_str(), err);
if (!err.empty()) {
return err;
}
std::vector<std::string> names = parse_control_lst(ctrl);
sbwrite_uint32(s, names.size());
for (int i = 0; i < int(names.size()); i++) {
sbwrite_string(s, names[i]);
}
for (int i = 0; i < int(names.size()); i++) {
std::string lfn = base + "/lua/" + names[i];
if (!sbwrite_file(s, lfn.c_str())) {
return std::string("Cannot read source file: ") + lfn;
}
}
return "";
}
// strerror has to be the most overcomplicated function imaginable. The simple
// version, 'strerror', is not thread-safe, and the improved versions are all
// incompatible from OS to OS. Even different versions of linux aren't
// compatible. A lot of conditional compilation is needed.
#if defined(__linux__)
inline static void strerror_helper(int status, int errnum, char errbuf[256]) {
if (status != 0) {
snprintf(errbuf, 256, "unknown errno %d", errnum);
}
}
inline static void strerror_helper(const char *result, int errnum, char errbuf[256]) {
if (result != errbuf) {
snprintf(errbuf, 256, "%s", result);
}
}
void strerror_safe(int errnum, char errbuf[256]) {
auto rval = strerror_r(errnum, errbuf, 256);
strerror_helper(rval, errnum, errbuf);
}
#elif defined(_WIN32)
void strerror_safe(int errnum, char errbuf[256]) {
int status = strerror_s(errbuf, 256, errnum);
if (status != 0) {
snprintf(errbuf, 256, "unknown errno %d", errnum);
}
);
#endif
std::string strerror_str(int errnum) {
char buf[256];
strerror_safe(errnum, buf);
return buf;
}
// The monotonic clock is required to start at zero at initialization time,
// advance steadily, and never go backwards. It is okay, however, if it is a
// little inaccurate, or if it drifts a little over time.
#if defined(__linux__)
class MonoClock {
private:
struct timespec base_;
public:
MonoClock() {
int status = clock_gettime(CLOCK_MONOTONIC, &base_);
assert(status == 0);
}
double get() {
struct timespec t;
int status = clock_gettime(CLOCK_MONOTONIC, &t);
assert(status == 0);
double tv_sec = t.tv_sec - base_.tv_sec;
double tv_nsec = t.tv_nsec - base_.tv_nsec;
return tv_sec + (tv_nsec * 1.0E-9);
}
};
#elif defined(_WIN32)
class MonoClock {
public:
double freq_;
LONGLONG base_;
inline LONGLONG qpc() {
LARGE_INTEGER x;
BOOL status = QueryPerformanceCounter(&x);
assert(status != 0);
return x.QuadPart;
}
MonoClock() {
LARGE_INTEGER x;
BOOL status = QueryPerformanceFrequency(&x);
assert(status != 0);
freq_ = 1.0 / double(x.QuadPart);
base_ = qpc();
}
double get() {
return (qpc() - base) * freq_;
}
};
#else
#error "Only support __linux__ or _WIN32"
#endif
static MonoClock monoclock;
double get_monotonic_clock() {
return monoclock.get();
}
} // namespace drv

View File

@@ -0,0 +1,99 @@
////////////////////////////////////////////////////////////////////////////////
//
// DRIVER_UTIL
//
////////////////////////////////////////////////////////////////////////////////
#ifndef DRVUTIL_HPP
#define DRVUTIL_HPP
#include <vector>
#include <string>
#include <memory>
#include <string_view>
#include <ostream>
#include <sstream>
#include <algorithm>
namespace drvutil {
// Read the lua source from disk into an ostringstream.
//
// To pass the lua source into the DLL, here is what you do: Construct an
// ostringstream. Use package_lua_source to package all the lua source into
// the ostringstream. Fetch the packaged source code using ostringstream::str.
// Pass the packaged source code into drv_set_lua_source.
//
// The DLL must then decode the source package. Here is how it does that:
// It creates a StreamBuffer from the packaged up source. Then it must
// call these StreamBuffer methods:
//
// - read the number of source files using read_uint32.
// - for each file, read the filename using read_string.
// - for each file, read the contents using read_string.
//
// If package_lua_source encounters an error reading the source code, then it
// returns an error message. In this case, the ostream contains garbage. If
// there is no error, returns the empty string.
//
std::string package_lua_source(const std::string &base, std::ostream *oss);
// Parse a target designation.
//
// A target consists of 'cert::host::port'.
//
void split_target(std::string_view target, std::string &cert, std::string &host, std::string &port);
// Get a system error message, in an OS-independent manner.
//
// These versions of strerror is thread-safe, and it never fails
// to put a message into the buffer.
//
void strerror_safe(int errnum, char result[256]);
std::string strerror_str(int errnum);
// Get the amount of time elapsed since program start.
//
// This is guaranteed to be monotonically increasing. It is not
// guaranteed to be accurate. Error could gradually accumulate over
// time.
//
double get_monotonic_clock();
// drvutil::ostringstream
//
// This is a variant of ostringstream in which it is possible
// to get the contents without copying. To get the contents
// without copying, use oss.size() and oss.c_str()
//
class ostringstream : public std::ostringstream {
class rstringbuf : public std::stringbuf {
public:
char *eback() { return std::streambuf::eback(); }
};
rstringbuf rsbuf_;
public:
ostringstream() {
std::basic_ostream<char>::rdbuf(&rsbuf_);
}
size_t size() {
return tellp();
}
const char *c_str() {
return rsbuf_.eback();
}
};
// Remove items from a vector that are marked for deletion.
//
template<class T>
void remove_marked_items(T &vec) {
auto iter = std::partition(vec.begin(), vec.end(), [] (const auto &x) { return !x.marked_for_deletion(); });
vec.erase(iter, vec.end());
}
} // namespace drvutil
#endif // DRVUTIL_HPP

211
luprex/cpp/drv/sslutil.cpp Normal file
View File

@@ -0,0 +1,211 @@
#include "drvutil.hpp"
#include "sslutil.hpp"
#include <iostream>
#include <cassert>
#include <vector>
#include <filesystem>
namespace sslutil {
const char *dummy_cert =
"-----BEGIN CERTIFICATE-----\n"
"MIIDezCCAmOgAwIBAgIUajKmxrLMr9zBMlphrTJU5qKG8FgwDQYJKoZIhvcNAQEL\n"
"BQAwTDELMAkGA1UEBhMCVVMxFTATBgNVBAgMDFBlbm5zeWx2YW5pYTESMBAGA1UE\n"
"CgwJbG9jYWxob3N0MRIwEAYDVQQDDAlsb2NhbGhvc3QwIBcNMjIwMzIyMTczMzA4\n"
"WhgPMjEyMjAyMjYxNzMzMDhaMEwxCzAJBgNVBAYTAlVTMRUwEwYDVQQIDAxQZW5u\n"
"c3lsdmFuaWExEjAQBgNVBAoMCWxvY2FsaG9zdDESMBAGA1UEAwwJbG9jYWxob3N0\n"
"MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA5OWIaKqYae4nPxvu5EP3\n"
"VilcjApYcMT4+2ypfQoB6PEep5lwguA929rNsTKnhGsEiQAZ0eZPEZN7VhUwf/hz\n"
"26jIyTT43ELkt6k97wwSZSXuT65RpSiemwEs6g2mMwzpgP6nv+yam4HjE9AKiHGN\n"
"YeTV72Nw1EN70t6IjIf4jsJRXqDJkUx5sSSD6j0WBTOhzozIDgZHTDwiLhatE66m\n"
"SNoD8oWC0PscbUgOJkFpbaCAS8RJmpsdgkTFae2rzL9cOFLGw6OgV/BV1J1s0ks8\n"
"+veoMMtIO6fese+OZ+DyQbuGaoaltZUXzY6QjD5l34m2mGplelT7BrpcqJTBHwmh\n"
"CwIDAQABo1MwUTAdBgNVHQ4EFgQUXQM5TVfJ9gpUXg8fZ8yfuUVcBP8wHwYDVR0j\n"
"BBgwFoAUXQM5TVfJ9gpUXg8fZ8yfuUVcBP8wDwYDVR0TAQH/BAUwAwEB/zANBgkq\n"
"hkiG9w0BAQsFAAOCAQEAqYX/ZGv0Qh/xdXppjnqojm8mH0giDW4tvwMqHcW3YRa3\n"
"9J2yYot+rHjU5g4n6HEmWDBE0eqLz9n3Y3fkFzT8RWZwBaST965CgsfGofyuA2hC\n"
"Ddn4Am3B5tTPmi8WWRZg8amhpGVD/mwkoVFIK0M337b1aZUJYPE+Kc9WetSL2KqB\n"
"EhqSQpkAWhVadzP85dq2T9EDjAvhlFTFlDEBx1GDUcc8M0KQ9NEvLT7LgoUcbMiT\n"
"PerlSZQTB0crchXTRSERgiwu80r7D6STn/RcPL9Fg5PkA94/d87jGbmV4sxSRsvM\n"
"z+DnJGjHrV1J/jHPrnVvVLpigBlGno3C5O/sRw3gcQ==\n"
"-----END CERTIFICATE-----\n";
const char *dummy_key =
"-----BEGIN PRIVATE KEY-----\n"
"MIIEwAIBADANBgkqhkiG9w0BAQEFAASCBKowggSmAgEAAoIBAQDk5Yhoqphp7ic/\n"
"G+7kQ/dWKVyMClhwxPj7bKl9CgHo8R6nmXCC4D3b2s2xMqeEawSJABnR5k8Rk3tW\n"
"FTB/+HPbqMjJNPjcQuS3qT3vDBJlJe5PrlGlKJ6bASzqDaYzDOmA/qe/7JqbgeMT\n"
"0AqIcY1h5NXvY3DUQ3vS3oiMh/iOwlFeoMmRTHmxJIPqPRYFM6HOjMgOBkdMPCIu\n"
"Fq0TrqZI2gPyhYLQ+xxtSA4mQWltoIBLxEmamx2CRMVp7avMv1w4UsbDo6BX8FXU\n"
"nWzSSzz696gwy0g7p96x745n4PJBu4ZqhqW1lRfNjpCMPmXfibaYamV6VPsGulyo\n"
"lMEfCaELAgMBAAECggEBAJa1AiFX4U4tva1xqNKmZV1XklWqIhzts7lnDBkF08gZ\n"
"qcNT5Z5mIpR09eVropwvEidZ56Yp63l5D0XYYbyAS1gfQ0QnGot7h7fdOKgB3MK4\n"
"PLY94gfKPNN17KqWHg2SvNNv1+cn04v78xUCb0zy5tHDp5Acexdm70ohtupARElJ\n"
"LSHdS7ebsqZUFXbbM3BpPEsQLi3PrzNs1DrKkZ3rR6eMGrsDqExXx8/foi9aZKsd\n"
"BGM2/kcTJ5aY6NhSv5iqO1oK46sbMrjVW/bYNsOyl0eFjwTRahn+Zhp/JMewZYeu\n"
"715g6kzbZNwEzBLgrhNPF6E2ycEr/C6z5bE78g5QCkECgYEA8s07UUY25bjYiWWy\n"
"W38pT7d/OXBSyKnq16N6MjVahl29r7nezFiDeLhLC0QiwXu/+qyxVZkB95MMGZXS\n"
"AsaKFNis3AJ6eR4SYyhpSScYKNvlKIiW37TtR4FDcy7y5LL6tFpiDDIGH3LuyWNo\n"
"d76142MBpv5aStnLGYU3pcZj43sCgYEA8VbNM4nqgSCQcbnHYjvsgphEMNSaoVie\n"
"xob2uigXdV6Te0ayoUFBnVNKVsRhk+sswuTV4k1pK/On+USVl2tQ16tcaVMjTfSD\n"
"HLYTJLmt6s4DcywWj5dfkbDoe5PulGXNZE960qXmOC62Lf0VMRwJ5x4FBRvGTjKC\n"
"zvekI2/kO7ECgYEAhBGeclb/BXXGUvY+TgadMf9d9KBkZ0IFu8Xwcd8TnoLe6vbv\n"
"ebery75zE228egIWKwREcYsIxuH1cvVLhrb35N73J7UxaTAyUD1rB598RL1XqPSj\n"
"HIwNhReK2NxwwnWYaQHA02FiczjRKjooWPojdcwk2fEArDZLg1YzLrj7HIECgYEA\n"
"htdx1Y8ESFtyeShMv5UtoxYCW6oeL3H9XH0CE6bc3IYYLvOkULbOO2HTEkGtJ2Fp\n"
"5AbJfiS0U4tS2dI5Jp4eUDH9cxexjRfFvd/5ODbKdnver5X9kQMJsbQ/YPSZg66R\n"
"oK9Lt7Bbvh5TScSy93psCgba1SzckspkDdGNkwMsaTECgYEAnFWaxormLUpXQRLs\n"
"tKzMMHgVnHlsHiqXH432zmT2fpGZHYoWbsGuQjjrHGnSiu3QbDhnzM6y/T2GRs6z\n"
"zHteIo/tzIyxg4MvJGJ9qANA7HoiKBdQ7G/I/NLJIyWAjj+e7/hgzKFcf+dpjpDq\n"
"HcKc9a4WXhC7yu79e5BnKWltHXY=\n"
"-----END PRIVATE KEY-----\n";
std::string error_string() {
// Get the last code.
int code = 0;
while (true) {
int icode = ERR_get_error();
if (icode == 0) break;
code = icode;
}
// Fetch and clear errno.
int terrno = errno;
errno = 0;
if (code != 0) {
const char *rc = ERR_reason_error_string(code);
if (rc != nullptr) {
return rc;
} else {
return drvutil::strerror_str(ERR_GET_REASON(code));
}
} else if (terrno != 0) {
return drvutil::strerror_str(terrno);
} else {
return "";
}
}
std::string path_to_plain_ascii(const std::filesystem::path &path) {
std::string s = path.native();
for (char c : s) {
if ((c < 1) || (c > 127)) return "";
}
return s;
}
void clear_all_errors() {
ERR_clear_error();
errno = 0;
}
SSL_CTX *new_context(int verify) {
SSL_CTX *ctx = SSL_CTX_new(TLS_method());
SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE);
SSL_CTX_set_verify(ctx, verify, nullptr);
return ctx;
}
static int ctx_use_certificate_str(SSL_CTX *ctx, const char *str) {
UniqueBIO bio(BIO_new(BIO_s_mem()));
BIO_puts(bio.get(), str);
UniqueX509 certificate(PEM_read_bio_X509(bio.get(), NULL, NULL, NULL));
return SSL_CTX_use_certificate(ctx, certificate.get());
}
static int ctx_use_privatekey_str(SSL_CTX *ctx, const char *str) {
UniqueBIO bio(BIO_new(BIO_s_mem()));
BIO_puts(bio.get(), str);
UniquePKEY pkey(PEM_read_bio_PrivateKey(bio.get(), NULL, NULL, NULL));
return SSL_CTX_use_PrivateKey(ctx, pkey.get());
}
void ctx_load_dummy_cert(SSL_CTX *ctx) {
ERR_clear_error();
if (ctx_use_certificate_str(ctx, dummy_cert) <= 0) {
ERR_print_errors_fp(stderr);
exit(1);
}
if (ctx_use_privatekey_str(ctx, dummy_key) <= 0) {
ERR_print_errors_fp(stderr);
exit(1);
}
}
static int count_certificates(const char *fn) {
static char null_passwd;
ErrClearErrorOnExit ece;
UniqueBIO bio(BIO_new(BIO_s_file()));
assert(bio != nullptr);
if (BIO_read_filename(bio.get(), fn) <= 0) {
std::cerr << "Cannot open file: " << fn << std::endl;
exit(1);
}
int total = 0;
while (true) {
UniqueX509 x(PEM_read_bio_X509_AUX(bio.get(), nullptr, nullptr, &null_passwd));
if (x == nullptr) break;
total += 1;
}
return total;
}
static bool contains_privatekey(const char *fn) {
static char null_passwd;
ErrClearErrorOnExit ece;
UniqueBIO bio(BIO_new(BIO_s_file()));
assert(bio != nullptr);
if (BIO_read_filename(bio.get(), fn) <= 0) {
std::cerr << "Cannot open file: " << fn << std::endl;
exit(1);
}
UniquePKEY k(PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, &null_passwd));
return k != nullptr;
}
void ctx_load_cert_from_directory(SSL_CTX *ctx, const std::string &dir) {
std::vector<std::string> key_paths;
std::vector<std::string> cert_paths;
for (const auto & entry : std::filesystem::directory_iterator(dir)) {
std::string fn = path_to_plain_ascii(entry.path());
if (fn.empty()) {
std::cerr << "Ignoring file with non-ascii filename: " << entry.path() << std::endl;
} else {
if (count_certificates(fn.c_str()) >= 1) {
cert_paths.push_back(fn);
}
if (contains_privatekey(fn.c_str())) {
key_paths.push_back(fn);
}
}
}
if (cert_paths.size() > 1) {
std::cerr << "Directory contains multiple certs: " << dir << std::endl;
exit(1);
}
if (key_paths.size() > 1) {
std::cerr << "Directory contains multiple keys: " << dir << std::endl;
exit(1);
}
if (cert_paths.empty()) {
std::cerr << "Directory doesn't contain a cert: " << dir << std::endl;
exit(1);
}
if (key_paths.empty()) {
std::cerr << "Directory doesn't contain a key: " << dir << std::endl;
exit(1);
}
int status;
status = SSL_CTX_use_PrivateKey_file(ctx, key_paths[0].c_str(), SSL_FILETYPE_PEM);
assert(status == 1);
status = SSL_CTX_use_certificate_chain_file(ctx, cert_paths[0].c_str());
assert(status == 1);
}
} // namespace sslutil

View File

@@ -0,0 +1,61 @@
#ifndef SSLUTIL_HPP
#define SSLUTIL_HPP
#include "drvutil.hpp"
#include <openssl/ssl.h>
#include <openssl/rsa.h>
#include <openssl/x509.h>
#include <openssl/evp.h>
#include <openssl/err.h>
#include <openssl/bio.h>
#include <openssl/pem.h>
#include <openssl/conf.h>
#include <memory>
namespace sslutil {
struct SSL_Deleter {
void operator()(SSL *ssl) { SSL_free(ssl); }
};
struct CTX_Deleter {
void operator()(SSL_CTX *ctx) { SSL_CTX_free(ctx); }
};
struct BIO_Deleter {
void operator()(BIO *bio) { BIO_free(bio); }
};
struct X509_Deleter {
void operator()(X509 *x) { X509_free(x); }
};
struct PKEY_Deleter {
void operator()(EVP_PKEY *p) { EVP_PKEY_free(p); }
};
using UniqueSSL = std::unique_ptr<SSL, SSL_Deleter>;
using UniqueCTX = std::unique_ptr<SSL_CTX, CTX_Deleter>;
using UniqueBIO = std::unique_ptr<BIO, BIO_Deleter>;
using UniqueX509 = std::unique_ptr<X509, X509_Deleter>;
using UniquePKEY = std::unique_ptr<EVP_PKEY, PKEY_Deleter>;
struct ErrClearErrorOnExit {
~ErrClearErrorOnExit() {
ERR_clear_error();
}
};
// Return the OpenSSL error as a string.
std::string error_string();
void clear_all_errors();
SSL_CTX *new_context(int verify);
void ctx_load_dummy_cert(SSL_CTX *ctx);
void ctx_load_cert_from_directory(SSL_CTX *ctx, const std::string &dir);
} // namespace sslutil
#endif // SSLUTIL_HPP