SSL stuff working on windows again, excepting win CA registry
This commit is contained in:
@@ -1,8 +1,8 @@
|
|||||||
Improve table.findremove to work on tables, not just vectors.
|
Calling out to external servers.
|
||||||
|
|
||||||
Finish documenting all builtins.
|
Support ANSI escape sequences on output.
|
||||||
|
|
||||||
Get rid of source_install_builtins after documenting all builtins.
|
Make math.random do something predictable.
|
||||||
- but don't forget that source_install_builtins sets the string metatable.
|
|
||||||
|
|
||||||
Do something about std::cerr && std::cout once and for all.
|
Do something about std::cerr && std::cout once and for all.
|
||||||
|
|
||||||
|
|||||||
@@ -212,7 +212,7 @@ void DrivenEngine::drv_clear_new_outgoing() {
|
|||||||
new_outgoing_.clear();
|
new_outgoing_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
const eng::string &DrivenEngine::drv_get_target(int chid) const {
|
std::string_view DrivenEngine::drv_get_target(int chid) const {
|
||||||
return get_chid(chid)->target_;
|
return get_chid(chid)->target_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -115,14 +115,15 @@ public:
|
|||||||
|
|
||||||
// The channel ID. These are reused.
|
// The channel ID. These are reused.
|
||||||
//
|
//
|
||||||
int chid() { return chid_; }
|
int chid() const { return chid_; }
|
||||||
|
|
||||||
// If this is a socket connection, the receiver's port number.
|
// If this is a socket connection, the receiver's port number.
|
||||||
//
|
//
|
||||||
int port() { return port_; }
|
int port() const { return port_; }
|
||||||
|
|
||||||
// If this is an outgoing socket connection, get the target host.
|
// If this is an outgoing socket connection, get the target host.
|
||||||
const eng::string &target() { return target_; }
|
//
|
||||||
|
const eng::string &target() const { return target_; }
|
||||||
|
|
||||||
// True if the remote closed the connection, or a failure occurred.
|
// True if the remote closed the connection, or a failure occurred.
|
||||||
//
|
//
|
||||||
@@ -133,7 +134,7 @@ public:
|
|||||||
// If this is an empty string, there is no error. If this is set,
|
// If this is an empty string, there is no error. If this is set,
|
||||||
// then the channel is also closed.
|
// then the channel is also closed.
|
||||||
//
|
//
|
||||||
eng::string error() const { return error_; }
|
const eng::string &error() const { return error_; }
|
||||||
|
|
||||||
// Set the prompt for readline mode.
|
// Set the prompt for readline mode.
|
||||||
//
|
//
|
||||||
@@ -314,11 +315,12 @@ public:
|
|||||||
void drv_clear_new_outgoing();
|
void drv_clear_new_outgoing();
|
||||||
|
|
||||||
// Get the target of a channel. A target is a string like
|
// Get the target of a channel. A target is a string like
|
||||||
// "www.whatever.com:80". It indicates the host and port that the channel
|
// "cert:whatever.com:80" or "nocert:whatever.com:80".
|
||||||
// is supposed to be talking to. Non-socket channels and incoming channels
|
// The first word indicate whether or not a valid SSL certificate
|
||||||
// have empty targets.
|
// is required. The second word is the hostname. The third word is
|
||||||
|
// the port number.
|
||||||
//
|
//
|
||||||
const eng::string &drv_get_target(int chid) const;
|
std::string_view drv_get_target(int chid) const;
|
||||||
|
|
||||||
// Return true if the outgoing buffer is empty.
|
// Return true if the outgoing buffer is empty.
|
||||||
//
|
//
|
||||||
|
|||||||
@@ -1,36 +1,45 @@
|
|||||||
|
|
||||||
#define CHBUF_SIZE (256*1024)
|
#define CHBUF_SIZE (256 * 1024)
|
||||||
#define POLLVEC_SIZE (DrivenEngine::MAX_CHAN+1)
|
#define POLLVEC_SIZE (DrivenEngine::MAX_CHAN + 1)
|
||||||
|
|
||||||
static MonoClock monoclock;
|
static MonoClock monoclock;
|
||||||
|
|
||||||
namespace util {
|
namespace util
|
||||||
double profiling_clock() {
|
{
|
||||||
|
double profiling_clock()
|
||||||
|
{
|
||||||
return monoclock.get();
|
return monoclock.get();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void if_error_print_and_exit(const std::string &str) {
|
static void if_error_print_and_exit(const std::string &str)
|
||||||
if (!str.empty()) {
|
{
|
||||||
std::cerr << std::endl << "error: " << str << std::endl;
|
if (!str.empty())
|
||||||
|
{
|
||||||
|
std::cerr << std::endl
|
||||||
|
<< "error: " << str << std::endl;
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string_view read_file(const char *fn, char *buf, int bufsize, std::string &err) {
|
static std::string_view read_file(const char *fn, char *buf, int bufsize, std::string &err)
|
||||||
|
{
|
||||||
FILE *f = fopen(fn, "r");
|
FILE *f = fopen(fn, "r");
|
||||||
if (f == 0) {
|
if (f == 0)
|
||||||
|
{
|
||||||
err = std::string("cannot read file") + fn;
|
err = std::string("cannot read file") + fn;
|
||||||
buf[0] = 0;
|
buf[0] = 0;
|
||||||
return std::string_view(buf, 0);
|
return std::string_view(buf, 0);
|
||||||
}
|
}
|
||||||
int nread = fread(buf, 1, bufsize, f);
|
int nread = fread(buf, 1, bufsize, f);
|
||||||
if (nread < 0) {
|
if (nread < 0)
|
||||||
|
{
|
||||||
err = std::string("cannot read file: ") + fn;
|
err = std::string("cannot read file: ") + fn;
|
||||||
buf[0] = 0;
|
buf[0] = 0;
|
||||||
return std::string_view(buf, 0);
|
return std::string_view(buf, 0);
|
||||||
}
|
}
|
||||||
if (nread == bufsize) {
|
if (nread == bufsize)
|
||||||
|
{
|
||||||
err = std::string("file too large: ") + fn;
|
err = std::string("file too large: ") + fn;
|
||||||
buf[0] = 0;
|
buf[0] = 0;
|
||||||
return std::string_view(buf, 0);
|
return std::string_view(buf, 0);
|
||||||
@@ -39,42 +48,66 @@ static std::string_view read_file(const char *fn, char *buf, int bufsize, std::s
|
|||||||
return std::string_view(buf, nread);
|
return std::string_view(buf, nread);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SSL_CTX_Deleter {
|
struct SSL_CTX_Deleter
|
||||||
void operator()(SSL_CTX *ctx) {
|
{
|
||||||
|
void operator()(SSL_CTX *ctx)
|
||||||
|
{
|
||||||
SSL_CTX_free(ctx);
|
SSL_CTX_free(ctx);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
using UniqueSSLCTX = std::unique_ptr<SSL_CTX, SSL_CTX_Deleter>;
|
using UniqueSSLCTX = std::unique_ptr<SSL_CTX, SSL_CTX_Deleter>;
|
||||||
|
|
||||||
static UniqueSSLCTX new_ssl_context(bool server_cert, bool root_certs, std::string_view require_cert) {
|
static std::string ssl_errors_string(bool lastonly = true)
|
||||||
SSL_CTX *ctx = SSL_CTX_new(TLS_method());
|
{
|
||||||
SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
|
std::string err;
|
||||||
SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE);
|
const char *file, *data, *func;
|
||||||
SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, nullptr);
|
int line, flags;
|
||||||
// server_cert is not implemented yet.
|
|
||||||
if (root_certs) {
|
while (true)
|
||||||
SSL_CTX_set_default_verify_paths(ctx);
|
{
|
||||||
SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, NULL);
|
unsigned long code = ERR_get_error_all(&file, &line, &func, &data, &flags);
|
||||||
|
if (code == 0)
|
||||||
|
break;
|
||||||
|
std::string reason;
|
||||||
|
if (ERR_SYSTEM_ERROR(code))
|
||||||
|
{
|
||||||
|
reason = strerror_str(ERR_GET_REASON(code));
|
||||||
}
|
}
|
||||||
// require_cert is not implemented yet.
|
else
|
||||||
return UniqueSSLCTX(ctx);
|
{
|
||||||
|
const char *rc = ERR_reason_error_string(code);
|
||||||
|
reason = (rc == nullptr) ? "unknown" : rc;
|
||||||
|
}
|
||||||
|
if (err.empty() || lastonly)
|
||||||
|
{
|
||||||
|
err = reason;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
err = err + ", " + reason;
|
||||||
|
}
|
||||||
|
if (data != nullptr)
|
||||||
|
{
|
||||||
|
err = err + " " + data;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void assert_ssl_errors_empty()
|
||||||
|
{
|
||||||
|
int code = ERR_peek_error();
|
||||||
static std::string err_print_errors_str() {
|
if (code != 0)
|
||||||
BIO *bio = BIO_new(BIO_s_mem());
|
{
|
||||||
ERR_print_errors(bio);
|
std::cerr << "SSL should not have errors at this point." << std::endl;
|
||||||
char *buf;
|
ERR_print_errors_fp(stderr);
|
||||||
size_t len = BIO_get_mem_data(bio, &buf);
|
exit(1);
|
||||||
std::string ret(buf, len);
|
}
|
||||||
BIO_free(bio);
|
|
||||||
return ret;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str) {
|
static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str)
|
||||||
|
{
|
||||||
BIO *bio = BIO_new(BIO_s_mem());
|
BIO *bio = BIO_new(BIO_s_mem());
|
||||||
BIO_puts(bio, str);
|
BIO_puts(bio, str);
|
||||||
X509 *certificate = PEM_read_bio_X509(bio, NULL, NULL, NULL);
|
X509 *certificate = PEM_read_bio_X509(bio, NULL, NULL, NULL);
|
||||||
@@ -84,7 +117,8 @@ static int ssl_ctx_use_certificate_str(SSL_CTX *ctx, const char *str) {
|
|||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str) {
|
static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str)
|
||||||
|
{
|
||||||
BIO *bio = BIO_new(BIO_s_mem());
|
BIO *bio = BIO_new(BIO_s_mem());
|
||||||
BIO_puts(bio, str);
|
BIO_puts(bio, str);
|
||||||
EVP_PKEY *pkey = PEM_read_bio_PrivateKey(bio, NULL, NULL, NULL);
|
EVP_PKEY *pkey = PEM_read_bio_PrivateKey(bio, NULL, NULL, NULL);
|
||||||
@@ -94,18 +128,33 @@ static int ssl_ctx_use_privatekey_str(SSL_CTX *ctx, const char *str) {
|
|||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ssl_ctx_use_dummycert(SSL_CTX *ctx)
|
||||||
|
{
|
||||||
|
if (ssl_ctx_use_certificate_str(ctx, dummycert::certificate) <= 0)
|
||||||
|
{
|
||||||
|
ERR_print_errors_fp(stderr);
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
if (ssl_ctx_use_privatekey_str(ctx, dummycert::privatekey) <= 0)
|
||||||
|
{
|
||||||
|
ERR_print_errors_fp(stderr);
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class Driver {
|
class Driver
|
||||||
|
{
|
||||||
public:
|
public:
|
||||||
|
enum ChanState
|
||||||
enum ChanState {
|
{
|
||||||
CHAN_INACTIVE,
|
CHAN_INACTIVE,
|
||||||
CHAN_PLAINTEXT,
|
CHAN_PLAINTEXT,
|
||||||
CHAN_SSL_CONNECTING,
|
CHAN_SSL_CONNECTING,
|
||||||
CHAN_SSL_ACCEPTING,
|
CHAN_SSL_ACCEPTING,
|
||||||
CHAN_SSL_READWRITE,
|
CHAN_SSL_READWRITE,
|
||||||
};
|
};
|
||||||
struct ChanInfo {
|
struct ChanInfo
|
||||||
|
{
|
||||||
int chid;
|
int chid;
|
||||||
SOCKET socket;
|
SOCKET socket;
|
||||||
SSL *ssl;
|
SSL *ssl;
|
||||||
@@ -129,15 +178,17 @@ public:
|
|||||||
std::unique_ptr<struct pollfd[]> pollvec_;
|
std::unique_ptr<struct pollfd[]> pollvec_;
|
||||||
drv::ReplayRecorder recorder_;
|
drv::ReplayRecorder recorder_;
|
||||||
|
|
||||||
UniqueSSLCTX ssl_ctx_with_root_certs_;
|
UniqueSSLCTX ssl_server_ctx_;
|
||||||
UniqueSSLCTX ssl_ctx_with_server_certs_;
|
UniqueSSLCTX ssl_client_secure_ctx_;
|
||||||
UniqueSSLCTX ssl_ctx_with_no_certs_;
|
UniqueSSLCTX ssl_client_insecure_ctx_;
|
||||||
|
|
||||||
|
void handle_listen_ports()
|
||||||
void handle_listen_ports() {
|
{
|
||||||
const auto &listenports = recorder_.drv_get_listen_ports();
|
const auto &listenports = recorder_.drv_get_listen_ports();
|
||||||
for (int port : listenports) {
|
for (int port : listenports)
|
||||||
if (listen_sockets_.find(port) == listen_sockets_.end()) {
|
{
|
||||||
|
if (listen_sockets_.find(port) == listen_sockets_.end())
|
||||||
|
{
|
||||||
std::string err;
|
std::string err;
|
||||||
SOCKET sock = listen_on_port(port, err);
|
SOCKET sock = listen_on_port(port, err);
|
||||||
if_error_print_and_exit(err);
|
if_error_print_and_exit(err);
|
||||||
@@ -147,14 +198,17 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void handle_lua_source() {
|
void handle_lua_source()
|
||||||
if (recorder_.drv_get_rescan_lua_source()) {
|
{
|
||||||
|
if (recorder_.drv_get_rescan_lua_source())
|
||||||
|
{
|
||||||
std::string err;
|
std::string err;
|
||||||
std::string_view ctrl = read_file("lua/control.lst", chbuf_.get(), CHBUF_SIZE, err);
|
std::string_view ctrl = read_file("lua/control.lst", chbuf_.get(), CHBUF_SIZE, err);
|
||||||
if_error_print_and_exit(err);
|
if_error_print_and_exit(err);
|
||||||
std::vector<std::string> names = drv::parse_control_lst(ctrl);
|
std::vector<std::string> names = drv::parse_control_lst(ctrl);
|
||||||
recorder_.drv_clear_lua_source();
|
recorder_.drv_clear_lua_source();
|
||||||
for (const std::string &str : names) {
|
for (const std::string &str : names)
|
||||||
|
{
|
||||||
std::string lfn = std::string("lua/") + str;
|
std::string lfn = std::string("lua/") + str;
|
||||||
std::string_view data = read_file(lfn.c_str(), chbuf_.get(), CHBUF_SIZE, err);
|
std::string_view data = read_file(lfn.c_str(), chbuf_.get(), CHBUF_SIZE, err);
|
||||||
if_error_print_and_exit(err);
|
if_error_print_and_exit(err);
|
||||||
@@ -163,11 +217,13 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void close_channel(ChanInfo &chan, std::string_view err) {
|
void close_channel(ChanInfo &chan, std::string_view err)
|
||||||
|
{
|
||||||
// std::cerr << "Closing channel " << chan.chid << std::endl;
|
// std::cerr << "Closing channel " << chan.chid << std::endl;
|
||||||
assert(chan.state != CHAN_INACTIVE);
|
assert(chan.state != CHAN_INACTIVE);
|
||||||
// Close and release the SSL channel.
|
// Close and release the SSL channel.
|
||||||
if (chan.ssl != nullptr) {
|
if (chan.ssl != nullptr)
|
||||||
|
{
|
||||||
SSL_free(chan.ssl);
|
SSL_free(chan.ssl);
|
||||||
chan.ssl = nullptr;
|
chan.ssl = nullptr;
|
||||||
}
|
}
|
||||||
@@ -190,39 +246,52 @@ public:
|
|||||||
chan.last_write_nbytes = 0;
|
chan.last_write_nbytes = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
void cleanup_channels() {
|
void cleanup_channels()
|
||||||
for (int i = 0; i < int(chans_.size()); ) {
|
{
|
||||||
if (chans_[i].state == CHAN_INACTIVE) {
|
for (int i = 0; i < int(chans_.size());)
|
||||||
|
{
|
||||||
|
if (chans_[i].state == CHAN_INACTIVE)
|
||||||
|
{
|
||||||
chans_[i] = chans_.back();
|
chans_[i] = chans_.back();
|
||||||
chans_.pop_back();
|
chans_.pop_back();
|
||||||
} else {
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
i += 1;
|
i += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void handle_console_output() {
|
void handle_console_output()
|
||||||
while (true) {
|
{
|
||||||
|
while (true)
|
||||||
|
{
|
||||||
std::string_view s = recorder_.drv_peek_outgoing(0);
|
std::string_view s = recorder_.drv_peek_outgoing(0);
|
||||||
if (s.size() == 0) break;
|
if (s.size() == 0)
|
||||||
|
break;
|
||||||
int nwrote = console_write(s.data(), s.size());
|
int nwrote = console_write(s.data(), s.size());
|
||||||
if (nwrote <= 0) break;
|
if (nwrote <= 0)
|
||||||
|
break;
|
||||||
recorder_.drv_sent_outgoing(0, nwrote);
|
recorder_.drv_sent_outgoing(0, nwrote);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void handle_console_input() {
|
void handle_console_input()
|
||||||
|
{
|
||||||
char buffer[256];
|
char buffer[256];
|
||||||
read_console_recently_ = false;
|
read_console_recently_ = false;
|
||||||
while (true) {
|
while (true)
|
||||||
|
{
|
||||||
int nread = console_read(buffer, 256);
|
int nread = console_read(buffer, 256);
|
||||||
if (nread <= 0) break;
|
if (nread <= 0)
|
||||||
|
break;
|
||||||
read_console_recently_ = true;
|
read_console_recently_ = true;
|
||||||
recorder_.drv_recv_incoming(0, std::string_view(buffer, nread));
|
recorder_.drv_recv_incoming(0, std::string_view(buffer, nread));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) {
|
void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state)
|
||||||
|
{
|
||||||
ChanInfo newchan;
|
ChanInfo newchan;
|
||||||
newchan.chid = chid;
|
newchan.chid = chid;
|
||||||
newchan.socket = sock;
|
newchan.socket = sock;
|
||||||
@@ -243,53 +312,81 @@ public:
|
|||||||
chans_.push_back(newchan);
|
chans_.push_back(newchan);
|
||||||
}
|
}
|
||||||
|
|
||||||
void handle_new_outgoing_sockets() {
|
void handle_new_outgoing_sockets()
|
||||||
|
{
|
||||||
const auto &chans = recorder_.drv_get_new_outgoing();
|
const auto &chans = recorder_.drv_get_new_outgoing();
|
||||||
for (int chid : chans) {
|
for (int chid : chans)
|
||||||
std::string err;
|
{
|
||||||
SOCKET sock = open_connection(recorder_.drv_get_target(chid), err);
|
std::string err, cert, host, port;
|
||||||
if (sock == INVALID_SOCKET) {
|
std::string target(recorder_.drv_get_target(chid));
|
||||||
recorder_.drv_notify_close(chid, err);
|
drv::split_target(target, cert, host, port);
|
||||||
|
if (cert.empty() || host.empty() || port.empty()) {
|
||||||
|
recorder_.drv_notify_close(chid, std::string("invalid target: ") + target);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
SSL_CTX *ctx = nullptr;
|
||||||
|
if (cert == "cert") {
|
||||||
|
ctx = ssl_client_secure_ctx_.get();
|
||||||
|
} else if (cert == "nocert") {
|
||||||
|
ctx = ssl_client_insecure_ctx_.get();
|
||||||
} else {
|
} else {
|
||||||
//std::cerr << "Opening channel " << chid << std::endl;
|
recorder_.drv_notify_close(chid, std::string("invalid cert rule: ") + target);
|
||||||
make_channel(sock, chid, ssl_ctx_with_no_certs_.get(), CHAN_SSL_CONNECTING);
|
continue;
|
||||||
}
|
}
|
||||||
|
SOCKET sock = open_connection(host.c_str(), port.c_str(), err);
|
||||||
|
if (sock == INVALID_SOCKET)
|
||||||
|
{
|
||||||
|
recorder_.drv_notify_close(chid, err);
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
if (!chans.empty()) {
|
// std::cerr << "Opening channel " << chid << std::endl;
|
||||||
|
make_channel(sock, chid, ctx, CHAN_SSL_CONNECTING);
|
||||||
|
}
|
||||||
|
if (!chans.empty())
|
||||||
|
{
|
||||||
recorder_.drv_clear_new_outgoing();
|
recorder_.drv_clear_new_outgoing();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void accept_connection(int port, SOCKET sock) {
|
void accept_connection(int port, SOCKET sock)
|
||||||
|
{
|
||||||
std::string err;
|
std::string err;
|
||||||
SOCKET socket = accept_on_socket(sock, err);
|
SOCKET socket = accept_on_socket(sock, err);
|
||||||
if_error_print_and_exit(err);
|
if_error_print_and_exit(err);
|
||||||
if (socket != INVALID_SOCKET) {
|
if (socket != INVALID_SOCKET)
|
||||||
|
{
|
||||||
int chid = recorder_.drv_notify_accept(port);
|
int chid = recorder_.drv_notify_accept(port);
|
||||||
// std::cerr << "Accepted channel " << chid << std::endl;
|
// std::cerr << "Accepted channel " << chid << std::endl;
|
||||||
make_channel(socket, chid, ssl_ctx_with_server_certs_.get(), CHAN_SSL_ACCEPTING);
|
make_channel(socket, chid, ssl_server_ctx_.get(), CHAN_SSL_ACCEPTING);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void advance_plaintext(ChanInfo &chan) {
|
void advance_plaintext(ChanInfo &chan)
|
||||||
|
{
|
||||||
std::string err;
|
std::string err;
|
||||||
|
|
||||||
// If the channel has no outgoing bytes and has been released,
|
// If the channel has no outgoing bytes and has been released,
|
||||||
// just close it.
|
// just close it.
|
||||||
if (chan.released) {
|
if (chan.released)
|
||||||
|
{
|
||||||
close_channel(chan, "");
|
close_channel(chan, "");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to write plaintext to the channel.
|
// Try to write plaintext to the channel.
|
||||||
std::string_view s = recorder_.drv_peek_outgoing(chan.chid);
|
std::string_view s = recorder_.drv_peek_outgoing(chan.chid);
|
||||||
if (s.size() > 0) {
|
if (s.size() > 0)
|
||||||
|
{
|
||||||
int sbytes = s.size();
|
int sbytes = s.size();
|
||||||
if (sbytes > 65536) sbytes = 65536;
|
if (sbytes > 65536)
|
||||||
|
sbytes = 65536;
|
||||||
int wbytes = socket_send(chan.socket, s.data(), sbytes, err);
|
int wbytes = socket_send(chan.socket, s.data(), sbytes, err);
|
||||||
if (wbytes < 0) {
|
if (wbytes < 0)
|
||||||
|
{
|
||||||
close_channel(chan, err);
|
close_channel(chan, err);
|
||||||
} else {
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
recorder_.drv_sent_outgoing(chan.chid, wbytes);
|
recorder_.drv_sent_outgoing(chan.chid, wbytes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -297,9 +394,12 @@ public:
|
|||||||
// Try to read plaintext from the channel.
|
// Try to read plaintext from the channel.
|
||||||
// Someday, find a way to avoid this copy.
|
// Someday, find a way to avoid this copy.
|
||||||
int nrecv = socket_recv(chan.socket, chbuf_.get(), 65536, err);
|
int nrecv = socket_recv(chan.socket, chbuf_.get(), 65536, err);
|
||||||
if (nrecv < 0) {
|
if (nrecv < 0)
|
||||||
|
{
|
||||||
close_channel(chan, err);
|
close_channel(chan, err);
|
||||||
} else {
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
recorder_.drv_recv_incoming(chan.chid, std::string_view(chbuf_.get(), nrecv));
|
recorder_.drv_recv_incoming(chan.chid, std::string_view(chbuf_.get(), nrecv));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -308,84 +408,116 @@ public:
|
|||||||
chan.ready_on_pollin = true;
|
chan.ready_on_pollin = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void process_ssl_error(ChanInfo &chan, int retval) {
|
void process_ssl_error(ChanInfo &chan, int retval)
|
||||||
|
{
|
||||||
int error = SSL_get_error(chan.ssl, retval);
|
int error = SSL_get_error(chan.ssl, retval);
|
||||||
// std::cerr << "SSL error code = " << error << " ";
|
// std::cerr << "SSL error code = " << error << " ";
|
||||||
if (error == SSL_ERROR_WANT_READ) {
|
if (error == SSL_ERROR_WANT_READ)
|
||||||
|
{
|
||||||
chan.ready_on_pollin = true;
|
chan.ready_on_pollin = true;
|
||||||
} else if (error == SSL_ERROR_WANT_WRITE) {
|
}
|
||||||
|
else if (error == SSL_ERROR_WANT_WRITE)
|
||||||
|
{
|
||||||
chan.ready_on_pollout = true;
|
chan.ready_on_pollout = true;
|
||||||
} else {
|
}
|
||||||
close_channel(chan, err_print_errors_str());
|
else
|
||||||
|
{
|
||||||
|
close_channel(chan, ssl_errors_string());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void advance_ssl_connecting(ChanInfo &chan) {
|
void advance_ssl_connecting(ChanInfo &chan)
|
||||||
|
{
|
||||||
// std::cerr << "In advance_ssl_connecting" << std::endl;
|
// std::cerr << "In advance_ssl_connecting" << std::endl;
|
||||||
int retval = SSL_connect(chan.ssl);
|
int retval = SSL_connect(chan.ssl);
|
||||||
if (retval == 1) {
|
if (retval == 1)
|
||||||
|
{
|
||||||
// Connection successful.
|
// Connection successful.
|
||||||
chan.state = CHAN_SSL_READWRITE;
|
chan.state = CHAN_SSL_READWRITE;
|
||||||
chan.ready_now = true;
|
chan.ready_now = true;
|
||||||
} else {
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
// std::cerr << "ssl_connect_error";
|
// std::cerr << "ssl_connect_error";
|
||||||
process_ssl_error(chan, retval);
|
process_ssl_error(chan, retval);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void advance_ssl_accepting(ChanInfo &chan) {
|
void advance_ssl_accepting(ChanInfo &chan)
|
||||||
|
{
|
||||||
// std::cerr << "In advance_ssl_accepting" << std::endl;
|
// std::cerr << "In advance_ssl_accepting" << std::endl;
|
||||||
int retval = SSL_accept(chan.ssl);
|
int retval = SSL_accept(chan.ssl);
|
||||||
if (retval == 1) {
|
if (retval == 1)
|
||||||
|
{
|
||||||
// Connection successful.
|
// Connection successful.
|
||||||
chan.state = CHAN_SSL_READWRITE;
|
chan.state = CHAN_SSL_READWRITE;
|
||||||
chan.ready_now = true;
|
chan.ready_now = true;
|
||||||
} else {
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
process_ssl_error(chan, retval);
|
process_ssl_error(chan, retval);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void advance_ssl_readwrite(ChanInfo &chan) {
|
void advance_ssl_readwrite(ChanInfo &chan)
|
||||||
|
{
|
||||||
// std::cerr << "In advance_ssl_readwrite" << std::endl;
|
// std::cerr << "In advance_ssl_readwrite" << std::endl;
|
||||||
// Try to read data.
|
// Try to read data.
|
||||||
int read_result = SSL_read(chan.ssl, chbuf_.get(), 65536);
|
int read_result = SSL_read(chan.ssl, chbuf_.get(), 65536);
|
||||||
if (read_result > 0) {
|
if (read_result > 0)
|
||||||
|
{
|
||||||
recorder_.drv_recv_incoming(chan.chid, std::string_view(chbuf_.get(), read_result));
|
recorder_.drv_recv_incoming(chan.chid, std::string_view(chbuf_.get(), read_result));
|
||||||
chan.ready_now = true;
|
chan.ready_now = true;
|
||||||
} else {
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
process_ssl_error(chan, read_result);
|
process_ssl_error(chan, read_result);
|
||||||
if (chan.state == CHAN_INACTIVE) return;
|
if (chan.state == CHAN_INACTIVE)
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to write data.
|
// Try to write data.
|
||||||
int wbytes;
|
int wbytes;
|
||||||
if (chan.last_write_nbytes > 0) {
|
if (chan.last_write_nbytes > 0)
|
||||||
|
{
|
||||||
wbytes = chan.last_write_nbytes;
|
wbytes = chan.last_write_nbytes;
|
||||||
assert(wbytes < chan.nbytes);
|
assert(wbytes < chan.nbytes);
|
||||||
} else {
|
|
||||||
wbytes = chan.nbytes;
|
|
||||||
if (wbytes > 65536) wbytes = 65536;
|
|
||||||
}
|
}
|
||||||
if (wbytes > 0) {
|
else
|
||||||
|
{
|
||||||
|
wbytes = chan.nbytes;
|
||||||
|
if (wbytes > 65536)
|
||||||
|
wbytes = 65536;
|
||||||
|
}
|
||||||
|
if (wbytes > 0)
|
||||||
|
{
|
||||||
int write_result = SSL_write(chan.ssl, chan.bytes, wbytes);
|
int write_result = SSL_write(chan.ssl, chan.bytes, wbytes);
|
||||||
if (write_result > 0) {
|
if (write_result > 0)
|
||||||
|
{
|
||||||
recorder_.drv_sent_outgoing(chan.chid, write_result);
|
recorder_.drv_sent_outgoing(chan.chid, write_result);
|
||||||
chan.last_write_nbytes = 0;
|
chan.last_write_nbytes = 0;
|
||||||
chan.ready_on_outgoing = true;
|
chan.ready_on_outgoing = true;
|
||||||
} else {
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
chan.last_write_nbytes = wbytes;
|
chan.last_write_nbytes = wbytes;
|
||||||
process_ssl_error(chan, write_result);
|
process_ssl_error(chan, write_result);
|
||||||
if (chan.state == CHAN_INACTIVE) return;
|
if (chan.state == CHAN_INACTIVE)
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
} else {
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
chan.ready_on_outgoing = true;
|
chan.ready_on_outgoing = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// std::cerr << "rpi=" << chan.ready_on_pollin << ".rpo=" << chan.ready_on_pollout << ".rn=" << chan.ready_now << ".rog=" << chan.ready_on_outgoing << " ";
|
// std::cerr << "rpi=" << chan.ready_on_pollin << ".rpo=" << chan.ready_on_pollout << ".rn=" << chan.ready_now << ".rog=" << chan.ready_on_outgoing << " ";
|
||||||
}
|
}
|
||||||
|
|
||||||
void advance_channel(ChanInfo &chan) {
|
void advance_channel(ChanInfo &chan)
|
||||||
switch(chan.state) {
|
{
|
||||||
|
assert_ssl_errors_empty();
|
||||||
|
switch (chan.state)
|
||||||
|
{
|
||||||
case CHAN_PLAINTEXT:
|
case CHAN_PLAINTEXT:
|
||||||
advance_plaintext(chan);
|
advance_plaintext(chan);
|
||||||
break;
|
break;
|
||||||
@@ -402,20 +534,23 @@ public:
|
|||||||
assert(false);
|
assert(false);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
assert_ssl_errors_empty();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void handle_socket_input_output()
|
||||||
void handle_socket_input_output() {
|
{
|
||||||
std::string err;
|
std::string err;
|
||||||
int mstimeout = read_console_recently_ ? 100 : 1000;
|
int mstimeout = read_console_recently_ ? 100 : 1000;
|
||||||
|
|
||||||
// Peek output buffers and determine channel release flags.
|
// Peek output buffers and determine channel release flags.
|
||||||
for (ChanInfo &chan : chans_) {
|
for (ChanInfo &chan : chans_)
|
||||||
|
{
|
||||||
std::string_view s = recorder_.drv_peek_outgoing(chan.chid);
|
std::string_view s = recorder_.drv_peek_outgoing(chan.chid);
|
||||||
chan.nbytes = s.size();
|
chan.nbytes = s.size();
|
||||||
chan.bytes = s.data();
|
chan.bytes = s.data();
|
||||||
chan.just_released = false;
|
chan.just_released = false;
|
||||||
if ((chan.nbytes == 0)&&(!chan.released)) {
|
if ((chan.nbytes == 0) && (!chan.released))
|
||||||
|
{
|
||||||
chan.released = recorder_.drv_get_channel_released(chan.chid);
|
chan.released = recorder_.drv_get_channel_released(chan.chid);
|
||||||
chan.just_released = chan.released;
|
chan.just_released = chan.released;
|
||||||
}
|
}
|
||||||
@@ -423,23 +558,30 @@ public:
|
|||||||
|
|
||||||
// Construct the struct pollfd vector.
|
// Construct the struct pollfd vector.
|
||||||
int pollsize = 0;
|
int pollsize = 0;
|
||||||
for (const auto &p : listen_sockets_) {
|
for (const auto &p : listen_sockets_)
|
||||||
|
{
|
||||||
struct pollfd &pfd = pollvec_[pollsize++];
|
struct pollfd &pfd = pollvec_[pollsize++];
|
||||||
pfd.fd = p.second;
|
pfd.fd = p.second;
|
||||||
pfd.events = POLLIN;
|
pfd.events = POLLIN;
|
||||||
pfd.revents = 0;
|
pfd.revents = 0;
|
||||||
}
|
}
|
||||||
for (const ChanInfo &chan : chans_) {
|
for (const ChanInfo &chan : chans_)
|
||||||
|
{
|
||||||
struct pollfd &pfd = pollvec_[pollsize++];
|
struct pollfd &pfd = pollvec_[pollsize++];
|
||||||
assert(chan.socket != INVALID_SOCKET);
|
assert(chan.socket != INVALID_SOCKET);
|
||||||
pfd.fd = chan.socket;
|
pfd.fd = chan.socket;
|
||||||
pfd.events = 0;
|
pfd.events = 0;
|
||||||
pfd.revents = 0;
|
pfd.revents = 0;
|
||||||
if (chan.ready_now) mstimeout = 0;
|
if (chan.ready_now)
|
||||||
if (chan.just_released) mstimeout = 0;
|
mstimeout = 0;
|
||||||
if (chan.ready_on_pollin) pfd.events |= POLLIN;
|
if (chan.just_released)
|
||||||
if (chan.ready_on_pollout) pfd.events |= POLLOUT;
|
mstimeout = 0;
|
||||||
if (chan.ready_on_outgoing && (chan.nbytes > 0)) pfd.events |= POLLOUT;
|
if (chan.ready_on_pollin)
|
||||||
|
pfd.events |= POLLIN;
|
||||||
|
if (chan.ready_on_pollout)
|
||||||
|
pfd.events |= POLLOUT;
|
||||||
|
if (chan.ready_on_outgoing && (chan.nbytes > 0))
|
||||||
|
pfd.events |= POLLOUT;
|
||||||
// std::cerr << "evt=" << pfd.events << ".nb=" << chan.nbytes << " ";
|
// std::cerr << "evt=" << pfd.events << ".nb=" << chan.nbytes << " ";
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -449,15 +591,18 @@ public:
|
|||||||
|
|
||||||
// Check listening sockets.
|
// Check listening sockets.
|
||||||
int index = 0;
|
int index = 0;
|
||||||
for (auto &p : listen_sockets_) {
|
for (auto &p : listen_sockets_)
|
||||||
|
{
|
||||||
struct pollfd &pfd = pollvec_[index++];
|
struct pollfd &pfd = pollvec_[index++];
|
||||||
if (pfd.revents & (POLLIN | POLLERR)) {
|
if (pfd.revents & (POLLIN | POLLERR))
|
||||||
|
{
|
||||||
accept_connection(p.first, p.second);
|
accept_connection(p.first, p.second);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Advance channels where possible.
|
// Advance channels where possible.
|
||||||
for (ChanInfo &chan : chans_) {
|
for (ChanInfo &chan : chans_)
|
||||||
|
{
|
||||||
struct pollfd &pfd = pollvec_[index++];
|
struct pollfd &pfd = pollvec_[index++];
|
||||||
bool pollin = ((pfd.revents & POLLIN) != 0);
|
bool pollin = ((pfd.revents & POLLIN) != 0);
|
||||||
bool pollout = ((pfd.revents & POLLOUT) != 0);
|
bool pollout = ((pfd.revents & POLLOUT) != 0);
|
||||||
@@ -465,7 +610,8 @@ public:
|
|||||||
if (chan.ready_now || pollerr || chan.just_released ||
|
if (chan.ready_now || pollerr || chan.just_released ||
|
||||||
(chan.ready_on_pollin && pollin) ||
|
(chan.ready_on_pollin && pollin) ||
|
||||||
(chan.ready_on_pollout && pollout) ||
|
(chan.ready_on_pollout && pollout) ||
|
||||||
(chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) {
|
(chan.ready_on_outgoing && (chan.nbytes > 0) && pollout))
|
||||||
|
{
|
||||||
chan.ready_now = false;
|
chan.ready_now = false;
|
||||||
chan.ready_on_pollin = false;
|
chan.ready_on_pollin = false;
|
||||||
chan.ready_on_pollout = false;
|
chan.ready_on_pollout = false;
|
||||||
@@ -480,36 +626,46 @@ public:
|
|||||||
cleanup_channels();
|
cleanup_channels();
|
||||||
}
|
}
|
||||||
|
|
||||||
int replay_logfile(const char *fn, bool verbose) {
|
int replay_logfile(const char *fn, bool verbose)
|
||||||
|
{
|
||||||
drv::ReplayPlayer player;
|
drv::ReplayPlayer player;
|
||||||
player.open_logfile(fn);
|
player.open_logfile(fn);
|
||||||
if (verbose) {
|
if (verbose)
|
||||||
|
{
|
||||||
player.enable_stdout();
|
player.enable_stdout();
|
||||||
}
|
}
|
||||||
while (true) {
|
while (true)
|
||||||
|
{
|
||||||
drv::ReplayPlayer::Status st = player.step();
|
drv::ReplayPlayer::Status st = player.step();
|
||||||
if (st != drv::ReplayPlayer::ST_REPLAYING) {
|
if (st != drv::ReplayPlayer::ST_REPLAYING)
|
||||||
|
{
|
||||||
player.print_status(std::cerr);
|
player.print_status(std::cerr);
|
||||||
return (st == drv::ReplayPlayer::ST_CLEAN_EXIT) ? 0 : 1;
|
return (st == drv::ReplayPlayer::ST_CLEAN_EXIT) ? 0 : 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int drive(int argc, char *argv[]) {
|
int drive(int argc, char *argv[])
|
||||||
|
{
|
||||||
// Remove the program name from argv.
|
// Remove the program name from argv.
|
||||||
if (argc < 1) {
|
if (argc < 1)
|
||||||
|
{
|
||||||
DrivenEngine::print_usage(std::cerr, "<unknown>");
|
DrivenEngine::print_usage(std::cerr, "<unknown>");
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
std::string program = argv[0];
|
std::string program = argv[0];
|
||||||
argc -= 1; argv += 1;
|
argc -= 1;
|
||||||
|
argv += 1;
|
||||||
|
|
||||||
// If argv contains "replay <filename>", do a replay,
|
// If argv contains "replay <filename>", do a replay,
|
||||||
// and then skip everything else.
|
// and then skip everything else.
|
||||||
if (argc >= 1) {
|
if (argc >= 1)
|
||||||
|
{
|
||||||
std::string cmd(argv[0]);
|
std::string cmd(argv[0]);
|
||||||
if ((cmd == "replay") || (cmd == "vreplay")) {
|
if ((cmd == "replay") || (cmd == "vreplay"))
|
||||||
if (argc != 2) {
|
{
|
||||||
|
if (argc != 2)
|
||||||
|
{
|
||||||
std::cerr << "usage: " << program << " replay <filename>" << std::endl;
|
std::cerr << "usage: " << program << " replay <filename>" << std::endl;
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
@@ -519,29 +675,36 @@ public:
|
|||||||
|
|
||||||
// If argv contains "record <filename>", start recording,
|
// If argv contains "record <filename>", start recording,
|
||||||
// and remove the "record <filename>" from argv.
|
// and remove the "record <filename>" from argv.
|
||||||
if (argc >= 1) {
|
if (argc >= 1)
|
||||||
|
{
|
||||||
std::string cmd = argv[0];
|
std::string cmd = argv[0];
|
||||||
if (cmd == "record") {
|
if (cmd == "record")
|
||||||
if (argc < 2) {
|
{
|
||||||
|
if (argc < 2)
|
||||||
|
{
|
||||||
DrivenEngine::print_usage(std::cerr, program);
|
DrivenEngine::print_usage(std::cerr, program);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
bool ok = recorder_.open_logfile(argv[1]);
|
bool ok = recorder_.open_logfile(argv[1]);
|
||||||
if (!ok) {
|
if (!ok)
|
||||||
|
{
|
||||||
std::cerr << "Could not open logfile: " << argv[1] << std::endl;
|
std::cerr << "Could not open logfile: " << argv[1] << std::endl;
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
argc -= 2; argv += 2;
|
argc -= 2;
|
||||||
|
argv += 2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the engine.
|
// Create the engine.
|
||||||
if (argc < 1) {
|
if (argc < 1)
|
||||||
|
{
|
||||||
DrivenEngine::print_usage(std::cerr, program);
|
DrivenEngine::print_usage(std::cerr, program);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
bool engine_made = recorder_.create_engine(argv[0]);
|
bool engine_made = recorder_.create_engine(argv[0]);
|
||||||
if (!engine_made) {
|
if (!engine_made)
|
||||||
|
{
|
||||||
DrivenEngine::print_usage(std::cerr, program);
|
DrivenEngine::print_usage(std::cerr, program);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
@@ -551,25 +714,17 @@ public:
|
|||||||
chbuf_.reset(new char[CHBUF_SIZE]);
|
chbuf_.reset(new char[CHBUF_SIZE]);
|
||||||
pollvec_.reset(new struct pollfd[POLLVEC_SIZE]);
|
pollvec_.reset(new struct pollfd[POLLVEC_SIZE]);
|
||||||
|
|
||||||
ssl_ctx_with_root_certs_ = new_ssl_context(false, true, "");
|
ssl_server_ctx_.reset(new_ssl_server_context());
|
||||||
ssl_ctx_with_server_certs_ = new_ssl_context(true, false, "");
|
ssl_client_secure_ctx_.reset(new_ssl_client_context(SSL_VERIFY_PEER));
|
||||||
ssl_ctx_with_no_certs_ = new_ssl_context(false, false, "");
|
ssl_client_insecure_ctx_.reset(new_ssl_client_context(SSL_VERIFY_NONE));
|
||||||
|
assert_ssl_errors_empty();
|
||||||
if (ssl_ctx_use_certificate_str(ssl_ctx_with_server_certs_.get(), dummycert::certificate) <= 0) {
|
|
||||||
ERR_print_errors_fp(stderr);
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ssl_ctx_use_privatekey_str(ssl_ctx_with_server_certs_.get(), dummycert::privatekey) <= 0 ) {
|
|
||||||
ERR_print_errors_fp(stderr);
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
handle_lua_source();
|
handle_lua_source();
|
||||||
recorder_.drv_invoke_event_init(argc, argv);
|
recorder_.drv_invoke_event_init(argc, argv);
|
||||||
handle_listen_ports();
|
handle_listen_ports();
|
||||||
|
|
||||||
while (!recorder_.drv_get_stop_driver()) {
|
while (!recorder_.drv_get_stop_driver())
|
||||||
|
{
|
||||||
handle_lua_source();
|
handle_lua_source();
|
||||||
handle_console_output();
|
handle_console_output();
|
||||||
handle_new_outgoing_sockets();
|
handle_new_outgoing_sockets();
|
||||||
@@ -579,7 +734,8 @@ public:
|
|||||||
recorder_.drv_invoke_event_update(monoclock.get());
|
recorder_.drv_invoke_event_update(monoclock.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
for (ChanInfo &chan : chans_) {
|
for (ChanInfo &chan : chans_)
|
||||||
|
{
|
||||||
close_channel(chan, "");
|
close_channel(chan, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -588,5 +744,3 @@ public:
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ struct termios orig_termios;
|
|||||||
|
|
||||||
static std::string strerror_str(int err) {
|
static std::string strerror_str(int err) {
|
||||||
char errbuf[256];
|
char errbuf[256];
|
||||||
return strerror_r(errno, errbuf, 256);
|
return strerror_r(err, errbuf, 256);
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_nonblocking(int fd) {
|
void set_nonblocking(int fd) {
|
||||||
@@ -69,7 +69,7 @@ static void enable_tty_raw() {
|
|||||||
assert(status >= 0);
|
assert(status >= 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
static SOCKET open_connection(std::string_view target, std::string &err) {
|
static SOCKET open_connection(const char *host, const char *port, std::string &err) {
|
||||||
struct addrinfo *addrs = nullptr;
|
struct addrinfo *addrs = nullptr;
|
||||||
struct addrinfo *goodaddr = nullptr;
|
struct addrinfo *goodaddr = nullptr;
|
||||||
struct addrinfo hints;
|
struct addrinfo hints;
|
||||||
@@ -82,9 +82,7 @@ static SOCKET open_connection(std::string_view target, std::string &err) {
|
|||||||
hints.ai_flags = AI_NUMERICSERV;
|
hints.ai_flags = AI_NUMERICSERV;
|
||||||
|
|
||||||
err.clear();
|
err.clear();
|
||||||
std::string host, port;
|
int status = getaddrinfo(host, port, &hints, &addrs);
|
||||||
drv::split_host_port(target, host, port);
|
|
||||||
int status = getaddrinfo(host.c_str(), port.c_str(), &hints, &addrs);
|
|
||||||
if (status != 0) {
|
if (status != 0) {
|
||||||
err = gai_strerror(status);
|
err = gai_strerror(status);
|
||||||
goto error_general;
|
goto error_general;
|
||||||
@@ -228,6 +226,25 @@ static int console_read(char *bytes, int nbytes) {
|
|||||||
return read(0, bytes, nbytes);
|
return read(0, bytes, nbytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ssl_ctx_use_dummycert(SSL_CTX *ctx);
|
||||||
|
|
||||||
|
static SSL_CTX *new_ssl_server_context() {
|
||||||
|
SSL_CTX *ctx = SSL_CTX_new(TLS_method());
|
||||||
|
SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
|
||||||
|
SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE);
|
||||||
|
SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, nullptr);
|
||||||
|
ssl_ctx_use_dummycert(ctx);
|
||||||
|
return ctx;
|
||||||
|
}
|
||||||
|
|
||||||
|
static SSL_CTX *new_ssl_client_context(int verify) {
|
||||||
|
SSL_CTX *ctx = SSL_CTX_new(TLS_method());
|
||||||
|
SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
|
||||||
|
SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE);
|
||||||
|
SSL_CTX_set_default_verify_paths(ctx);
|
||||||
|
SSL_CTX_set_verify(ctx, verify, nullptr);
|
||||||
|
return ctx;
|
||||||
|
}
|
||||||
|
|
||||||
static void disable_randomization(int argc, char *argv[]) {
|
static void disable_randomization(int argc, char *argv[]) {
|
||||||
const int old_personality = personality(ADDR_NO_RANDOMIZE);
|
const int old_personality = personality(ADDR_NO_RANDOMIZE);
|
||||||
|
|||||||
@@ -28,19 +28,13 @@
|
|||||||
#include <openssl/bio.h>
|
#include <openssl/bio.h>
|
||||||
#include <openssl/pem.h>
|
#include <openssl/pem.h>
|
||||||
|
|
||||||
#define CHBUF_SIZE (256*1024)
|
|
||||||
#define POLLVEC_SIZE (DrivenEngine::MAX_CHAN+1)
|
|
||||||
|
|
||||||
static std::unique_ptr<char[]> chbuf;
|
|
||||||
static std::unique_ptr<struct pollfd[]> pollvec;
|
|
||||||
|
|
||||||
static void set_nonblocking(SOCKET sock) {
|
static void set_nonblocking(SOCKET sock) {
|
||||||
u_long mode = 1; // 1 to enable non-blocking socket
|
u_long mode = 1; // 1 to enable non-blocking socket
|
||||||
int status = ioctlsocket(sock, FIONBIO, &mode);
|
int status = ioctlsocket(sock, FIONBIO, &mode);
|
||||||
assert(status == 0);
|
assert(status == 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string winsock_error_string(int errcode) {
|
static std::string strerror_str(int errcode) {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
oss << "error " << errcode;
|
oss << "error " << errcode;
|
||||||
return oss.str();
|
return oss.str();
|
||||||
@@ -55,17 +49,15 @@ static PADDRINFOA find_good_addr(PADDRINFOA addrinfo) {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
static SOCKET open_connection(std::string_view target, std::string &err) {
|
static SOCKET open_connection(const char *host, const char *port, std::string &err) {
|
||||||
PADDRINFOA addrs = nullptr;
|
PADDRINFOA addrs = nullptr;
|
||||||
PADDRINFOA goodaddr = nullptr;
|
PADDRINFOA goodaddr = nullptr;
|
||||||
SOCKET sock = INVALID_SOCKET;
|
SOCKET sock = INVALID_SOCKET;
|
||||||
std::string host, port;
|
|
||||||
|
|
||||||
err.clear();
|
err.clear();
|
||||||
drv::split_host_port(target, host, port);
|
int status = getaddrinfo(host, port, nullptr, &addrs);
|
||||||
int status = getaddrinfo(host.data(), port.data(), nullptr, &addrs);
|
|
||||||
while (status == WSATRY_AGAIN) {
|
while (status == WSATRY_AGAIN) {
|
||||||
status = getaddrinfo(host.data(), port.data(), nullptr, &addrs);
|
status = getaddrinfo(host, port, nullptr, &addrs);
|
||||||
}
|
}
|
||||||
if (status == WSAHOST_NOT_FOUND) {
|
if (status == WSAHOST_NOT_FOUND) {
|
||||||
err = "host not found";
|
err = "host not found";
|
||||||
@@ -194,7 +186,7 @@ static int socket_close(SOCKET socket) {
|
|||||||
static int socket_poll(struct pollfd *pollvec, int pollcount, int mstimeout, std::string &err) {
|
static int socket_poll(struct pollfd *pollvec, int pollcount, int mstimeout, std::string &err) {
|
||||||
int status = WSAPoll(pollvec, pollcount, mstimeout);
|
int status = WSAPoll(pollvec, pollcount, mstimeout);
|
||||||
if (status < 0) {
|
if (status < 0) {
|
||||||
err = winsock_error_string(WSAGetLastError());
|
err = strerror_str(WSAGetLastError());
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
return status;
|
return status;
|
||||||
@@ -243,7 +235,23 @@ static int console_read(char *bytes, int nbytes) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void driver_sysinit(int argc, char *argv[]) {
|
static void ssl_ctx_use_dummycert(SSL_CTX *ctx);
|
||||||
|
|
||||||
|
static SSL_CTX *new_ssl_server_context() {
|
||||||
|
SSL_CTX *ctx = SSL_CTX_new(TLS_method());
|
||||||
|
SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
|
||||||
|
SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE);
|
||||||
|
SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, nullptr);
|
||||||
|
ssl_ctx_use_dummycert(ctx);
|
||||||
|
return ctx;
|
||||||
|
}
|
||||||
|
|
||||||
|
static SSL_CTX *new_ssl_client_context(int verify) {
|
||||||
|
SSL_CTX *ctx = SSL_CTX_new(TLS_method());
|
||||||
|
SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
|
||||||
|
SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE);
|
||||||
|
SSL_CTX_set_verify(ctx, verify, nullptr);
|
||||||
|
return ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
class MonoClock {
|
class MonoClock {
|
||||||
|
|||||||
@@ -17,16 +17,31 @@
|
|||||||
|
|
||||||
namespace drv {
|
namespace drv {
|
||||||
|
|
||||||
void split_host_port(std::string_view target, std::string &host, std::string &port) {
|
std::vector<std::string_view> split_view(std::string_view v, char sep) {
|
||||||
size_t lastcolon = target.rfind(':');
|
std::vector<std::string_view> result;
|
||||||
if (lastcolon == std::string_view::npos) {
|
while (true) {
|
||||||
host = ""; port = ""; return;
|
size_t pos = v.find(sep);
|
||||||
|
if (pos == std::string_view::npos) break;
|
||||||
|
result.push_back(v.substr(0, pos));
|
||||||
|
v = v.substr(pos + 1);
|
||||||
}
|
}
|
||||||
host = target.substr(0, lastcolon);
|
result.push_back(v);
|
||||||
port = target.substr(lastcolon + 1);
|
return result;
|
||||||
if ((host == "") || (port == "")) {
|
}
|
||||||
host = ""; port = ""; return;
|
|
||||||
|
void split_target(std::string_view target, std::string &cert, std::string &host, std::string &port) {
|
||||||
|
std::vector<std::string_view> split = split_view(target, ':');
|
||||||
|
if (split.size() != 3) {
|
||||||
|
cert.clear(); host.clear(); port.clear();
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
if (split[0].empty() || split[1].empty() || split[2].empty()) {
|
||||||
|
cert.clear(); host.clear(); port.clear();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
cert = std::string(split[0]);
|
||||||
|
host = std::string(split[1]);
|
||||||
|
port = std::string(split[2]);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> parse_control_lst(std::string_view ctrl) {
|
std::vector<std::string> parse_control_lst(std::string_view ctrl) {
|
||||||
@@ -502,9 +517,10 @@ void ReplayPlayer::drv_invoke_event_update() {
|
|||||||
} // namespace drv
|
} // namespace drv
|
||||||
|
|
||||||
LuaDefine(unittests_driverutil, "", "some unit tests") {
|
LuaDefine(unittests_driverutil, "", "some unit tests") {
|
||||||
// Test split_host_port
|
// Test split_target
|
||||||
std::string host, port;
|
std::string cert, host, port;
|
||||||
drv::split_host_port("stanford.edu:80", host, port);
|
drv::split_target("cert:stanford.edu:80", cert, host, port);
|
||||||
|
LuaAssertStrEq(L, cert, "cert");
|
||||||
LuaAssertStrEq(L, host, "stanford.edu");
|
LuaAssertStrEq(L, host, "stanford.edu");
|
||||||
LuaAssertStrEq(L, port, "80");
|
LuaAssertStrEq(L, port, "80");
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,8 @@
|
|||||||
|
|
||||||
namespace drv {
|
namespace drv {
|
||||||
|
|
||||||
void split_host_port(std::string_view target, std::string &host, std::string &port);
|
void split_target(std::string_view target, std::string &cert, std::string &host, std::string &port);
|
||||||
|
|
||||||
|
|
||||||
std::vector<std::string> parse_control_lst(std::string_view ctrl);
|
std::vector<std::string> parse_control_lst(std::string_view ctrl);
|
||||||
|
|
||||||
@@ -58,7 +59,7 @@ public:
|
|||||||
//
|
//
|
||||||
const eng::vector<int> &drv_get_listen_ports() const { return e_->drv_get_listen_ports(); }
|
const eng::vector<int> &drv_get_listen_ports() const { return e_->drv_get_listen_ports(); }
|
||||||
const eng::vector<int> &drv_get_new_outgoing() const { return e_->drv_get_new_outgoing(); }
|
const eng::vector<int> &drv_get_new_outgoing() const { return e_->drv_get_new_outgoing(); }
|
||||||
const eng::string &drv_get_target(int chid) const { return e_->drv_get_target(chid); }
|
std::string_view drv_get_target(int chid) const { return e_->drv_get_target(chid); }
|
||||||
bool drv_outgoing_empty(int chid) const { return e_->drv_outgoing_empty(chid); }
|
bool drv_outgoing_empty(int chid) const { return e_->drv_outgoing_empty(chid); }
|
||||||
bool drv_get_channel_released(int chid) const { return e_->drv_get_channel_released(chid); }
|
bool drv_get_channel_released(int chid) const { return e_->drv_get_channel_released(chid); }
|
||||||
std::string_view drv_peek_outgoing(int chid) const { return e_->drv_peek_outgoing(chid); }
|
std::string_view drv_peek_outgoing(int chid) const { return e_->drv_peek_outgoing(chid); }
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ public:
|
|||||||
set_initial_state();
|
set_initial_state();
|
||||||
|
|
||||||
// Establish a connection to the server.
|
// Establish a connection to the server.
|
||||||
channel_ = new_outgoing_channel("localhost:8085");
|
channel_ = new_outgoing_channel("nocert:localhost:8085");
|
||||||
|
|
||||||
// Set the console prompt
|
// Set the console prompt
|
||||||
get_stdio_channel()->set_prompt(console_.get_prompt());
|
get_stdio_channel()->set_prompt(console_.get_prompt());
|
||||||
@@ -262,7 +262,7 @@ public:
|
|||||||
// Check for communication from server..
|
// Check for communication from server..
|
||||||
if (channel_ != nullptr) {
|
if (channel_ != nullptr) {
|
||||||
if (channel_->closed()) {
|
if (channel_->closed()) {
|
||||||
stdostream() << "Server closed connection " << channel_->error() << std::endl;
|
stdostream() << "server closed connection: " << channel_->error() << std::endl;
|
||||||
abandon_server();
|
abandon_server();
|
||||||
} else {
|
} else {
|
||||||
while (true) {
|
while (true) {
|
||||||
|
|||||||
@@ -7,11 +7,12 @@ LuaSpecial LuaRegistry(LUA_REGISTRYINDEX);
|
|||||||
LuaNilMarker LuaNil;
|
LuaNilMarker LuaNil;
|
||||||
LuaNewTableMarker LuaNewTable;
|
LuaNewTableMarker LuaNewTable;
|
||||||
|
|
||||||
LuaFunctionReg::LuaFunctionReg(const char *n, const char *a, const char *d, lua_CFunction f) {
|
LuaFunctionReg::LuaFunctionReg(const char *n, const char *a, const char *d, bool s, lua_CFunction f) {
|
||||||
name_ = n;
|
name_ = n;
|
||||||
args_ = a;
|
args_ = a;
|
||||||
docs_ = d;
|
docs_ = d;
|
||||||
func_ = f;
|
func_ = f;
|
||||||
|
sandbox_ = s;
|
||||||
next_ = All;
|
next_ = All;
|
||||||
All = this;
|
All = this;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -475,18 +475,20 @@ private:
|
|||||||
const char *name_;
|
const char *name_;
|
||||||
const char *args_;
|
const char *args_;
|
||||||
const char *docs_;
|
const char *docs_;
|
||||||
|
bool sandbox_;
|
||||||
lua_CFunction func_;
|
lua_CFunction func_;
|
||||||
LuaFunctionReg *next_;
|
LuaFunctionReg *next_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static LuaFunctionReg *All;
|
static LuaFunctionReg *All;
|
||||||
LuaFunctionReg(const char *name, const char *args, const char *docs, lua_CFunction f);
|
LuaFunctionReg(const char *name, const char *args, const char *docs, bool sand, lua_CFunction f);
|
||||||
static const LuaFunctionReg *lookup(lua_CFunction fn);
|
static const LuaFunctionReg *lookup(lua_CFunction fn);
|
||||||
|
|
||||||
const char *get_name() const { return name_; }
|
const char *get_name() const { return name_; }
|
||||||
const char *get_args() const { return args_; }
|
const char *get_args() const { return args_; }
|
||||||
const char *get_docs() const { return docs_; }
|
const char *get_docs() const { return docs_; }
|
||||||
lua_CFunction get_func() const { return func_; }
|
lua_CFunction get_func() const { return func_; }
|
||||||
|
bool get_sandbox() const { return sandbox_; }
|
||||||
LuaFunctionReg *next() const { return next_; }
|
LuaFunctionReg *next() const { return next_; }
|
||||||
void set_func(lua_CFunction fn) { func_ = fn; }
|
void set_func(lua_CFunction fn) { func_ = fn; }
|
||||||
};
|
};
|
||||||
@@ -494,13 +496,19 @@ public:
|
|||||||
|
|
||||||
#define LuaDefine(name, args, docs) \
|
#define LuaDefine(name, args, docs) \
|
||||||
int lfn_##name(lua_State *L); \
|
int lfn_##name(lua_State *L); \
|
||||||
LuaFunctionReg reg_##name(#name, args, docs, lfn_##name); \
|
LuaFunctionReg reg_##name(#name, args, docs, false, lfn_##name); \
|
||||||
int lfn_##name(lua_State *L)
|
int lfn_##name(lua_State *L)
|
||||||
|
|
||||||
|
#define LuaSandbox(name, args, docs) \
|
||||||
|
int lfn_##name(lua_State *L); \
|
||||||
|
LuaFunctionReg reg_##name(#name, args, docs, true, lfn_##name); \
|
||||||
|
int lfn_##name(lua_State *L)
|
||||||
|
|
||||||
#define LuaDefineBuiltin(name, args, docs) \
|
#define LuaDefineBuiltin(name, args, docs) \
|
||||||
LuaFunctionReg reg_##name(#name, args, docs, nullptr);
|
LuaFunctionReg reg_##name(#name, args, docs, false, nullptr);
|
||||||
|
|
||||||
|
#define LuaSandboxBuiltin(name, args, docs) \
|
||||||
|
LuaFunctionReg reg_##name(#name, args, docs, true, nullptr);
|
||||||
|
|
||||||
#define LuaStringify(x) #x
|
#define LuaStringify(x) #x
|
||||||
#define LuaAssert(L, x) if (!(x)) { luaL_error((L), "Assert failed: %s (file %s line %d)", LuaStringify(x), __FILE__, __LINE__); }
|
#define LuaAssert(L, x) if (!(x)) { luaL_error((L), "Assert failed: %s (file %s line %d)", LuaStringify(x), __FILE__, __LINE__); }
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
|
|
||||||
|
#define _USE_MATH_DEFINES
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
#include "wrap-string.hpp"
|
#include "wrap-string.hpp"
|
||||||
#include "wrap-vector.hpp"
|
#include "wrap-vector.hpp"
|
||||||
#include "wrap-map.hpp"
|
#include "wrap-map.hpp"
|
||||||
@@ -66,70 +69,6 @@ static void get_reg_name(const LuaFunctionReg *reg, std::string_view &classname,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// the 'luaopen' function creates a new table.
|
|
||||||
//
|
|
||||||
// We don't want to create new tables during reload, so we call luaopen,
|
|
||||||
// then we copy the contents of the new table into the existing 'makeclass'
|
|
||||||
// table.
|
|
||||||
//
|
|
||||||
static void load_builtin_class(lua_State *L, const char *name, lua_CFunction func) {
|
|
||||||
LuaVar sourcetab, classtab, key, value;
|
|
||||||
LuaStack LS(L, sourcetab, classtab, key, value);
|
|
||||||
LS.makeclass(classtab, name);
|
|
||||||
func(L);
|
|
||||||
lua_replace(L, sourcetab.index());
|
|
||||||
LS.set(key, LuaNil);
|
|
||||||
while (LS.next(sourcetab, key, value) != 0) {
|
|
||||||
LS.rawset(classtab, key, value);
|
|
||||||
}
|
|
||||||
LS.result();
|
|
||||||
}
|
|
||||||
|
|
||||||
static void erase_builtin(LuaStack &LS, LuaSlot globtab, const eng::string &classname, const eng::string &funcname) {
|
|
||||||
if (classname.empty()) {
|
|
||||||
LS.rawset(globtab, funcname, LuaNil);
|
|
||||||
} else {
|
|
||||||
LuaVar classtab;
|
|
||||||
LuaStack LSX(LS.state(), classtab);
|
|
||||||
LS.rawget(classtab, globtab, classname);
|
|
||||||
if (LS.istable(classtab)) {
|
|
||||||
LS.rawset(classtab, funcname, LuaNil);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void source_install_builtins(lua_State *L) {
|
|
||||||
LuaVar nullstring, stringclass, globtab;
|
|
||||||
LuaStack LS(L, nullstring, stringclass, globtab);
|
|
||||||
luaopen_base(L);
|
|
||||||
load_builtin_class(L, "table", luaopen_table);
|
|
||||||
load_builtin_class(L, "string", luaopen_string);
|
|
||||||
load_builtin_class(L, "math", luaopen_math);
|
|
||||||
load_builtin_class(L, "debug", luaopen_debug);
|
|
||||||
load_builtin_class(L, "coroutine", luaopen_coroutine);
|
|
||||||
|
|
||||||
// Nuke a few of the builtin functions for sandboxing reasons.
|
|
||||||
LS.getglobaltable(globtab);
|
|
||||||
erase_builtin(LS, globtab, "", "dofile");
|
|
||||||
erase_builtin(LS, globtab, "", "collectgarbage");
|
|
||||||
erase_builtin(LS, globtab, "", "loadfile");
|
|
||||||
erase_builtin(LS, globtab, "", "load");
|
|
||||||
erase_builtin(LS, globtab, "", "loadstring");
|
|
||||||
erase_builtin(LS, globtab, "", "print");
|
|
||||||
erase_builtin(LS, globtab, "", "xpcall");
|
|
||||||
erase_builtin(LS, globtab, "string", "dump");
|
|
||||||
|
|
||||||
// Set the metatable for strings.
|
|
||||||
// Normally, this would be done by luaopen_string, but we're
|
|
||||||
// messing with the tables so we have to redo it.
|
|
||||||
LS.makeclass(stringclass, "string");
|
|
||||||
LS.set(nullstring, "");
|
|
||||||
LS.setmetatable(nullstring, stringclass);
|
|
||||||
|
|
||||||
LS.result();
|
|
||||||
}
|
|
||||||
|
|
||||||
static void get_info_table(LuaStack &LS, LuaSlot db, LuaSlot info, const eng::string &fn) {
|
static void get_info_table(LuaStack &LS, LuaSlot db, LuaSlot info, const eng::string &fn) {
|
||||||
LS.rawget(info, db, fn);
|
LS.rawget(info, db, fn);
|
||||||
if (!LS.istable(info)) {
|
if (!LS.istable(info)) {
|
||||||
@@ -339,6 +278,7 @@ static void source_load_cfunctions(lua_State *L) {
|
|||||||
LuaStack LS(L, classobj);
|
LuaStack LS(L, classobj);
|
||||||
for (auto r = LuaFunctionReg::All; r != nullptr; r=r->next()) {
|
for (auto r = LuaFunctionReg::All; r != nullptr; r=r->next()) {
|
||||||
lua_CFunction func = r->get_func();
|
lua_CFunction func = r->get_func();
|
||||||
|
if ((func != nullptr) && (!r->get_sandbox())) {
|
||||||
std::string_view classname;
|
std::string_view classname;
|
||||||
std::string_view funcname;
|
std::string_view funcname;
|
||||||
get_reg_name(r, classname, funcname);
|
get_reg_name(r, classname, funcname);
|
||||||
@@ -350,6 +290,7 @@ static void source_load_cfunctions(lua_State *L) {
|
|||||||
LS.rawset(classobj, funcname, func);
|
LS.rawset(classobj, funcname, func);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
LS.result();
|
LS.result();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -400,10 +341,18 @@ static eng::string source_load_lfunctions(lua_State *L) {
|
|||||||
|
|
||||||
eng::string SourceDB::rebuild() {
|
eng::string SourceDB::rebuild() {
|
||||||
lua_State *L = lua_state_;
|
lua_State *L = lua_state_;
|
||||||
|
LuaVar mathclass;
|
||||||
|
LuaStack LS(L, mathclass);
|
||||||
source_clear_globals(L);
|
source_clear_globals(L);
|
||||||
// source_install_builtins(L);
|
|
||||||
source_load_cfunctions(L);
|
source_load_cfunctions(L);
|
||||||
eng::string errs = source_load_lfunctions(L);
|
eng::string errs = source_load_lfunctions(L);
|
||||||
|
|
||||||
|
// A few builtin constants. These are hardwired.
|
||||||
|
LS.makeclass(mathclass, "math");
|
||||||
|
LS.rawset(mathclass, "pi", M_PI);
|
||||||
|
LS.rawset(mathclass, "huge", HUGE_VAL);
|
||||||
|
|
||||||
|
LS.result();
|
||||||
return errs;
|
return errs;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -446,17 +395,22 @@ void SourceDB::run_unittests() {
|
|||||||
|
|
||||||
void SourceDB::init(lua_State *L) {
|
void SourceDB::init(lua_State *L) {
|
||||||
lua_state_ = L;
|
lua_state_ = L;
|
||||||
LuaVar globtab, persist, unpersist, classname, classtab, funcname, funcp, rawfunc;
|
LuaVar globtab, persist, unpersist, classname, classtab, funcname, funcp, rawfunc, nullstring;
|
||||||
LuaStack LS(L, globtab, persist, unpersist, classname, classtab, funcname, funcp, rawfunc);
|
LuaStack LS(L, globtab, persist, unpersist, classname, classtab, funcname, funcp, rawfunc, nullstring);
|
||||||
|
LS.getglobaltable(globtab);
|
||||||
|
LS.rawset(LuaRegistry, "sourcedb", LuaNewTable);
|
||||||
|
|
||||||
|
// Set the metatable for strings.
|
||||||
|
LS.makeclass(classtab, "string");
|
||||||
|
LS.set(nullstring, "");
|
||||||
|
LS.setmetatable(nullstring, classtab);
|
||||||
|
|
||||||
|
// Rebuild the global environment.
|
||||||
|
rebuild();
|
||||||
|
|
||||||
// We need to register all C functions with the eris permanents tables.
|
// We need to register all C functions with the eris permanents tables.
|
||||||
source_clear_globals(L);
|
|
||||||
source_install_builtins(L);
|
|
||||||
source_load_cfunctions(L);
|
|
||||||
LS.getglobaltable(globtab);
|
|
||||||
LS.rawget(persist, LuaRegistry, "persist");
|
LS.rawget(persist, LuaRegistry, "persist");
|
||||||
LS.rawget(unpersist, LuaRegistry, "unpersist");
|
LS.rawget(unpersist, LuaRegistry, "unpersist");
|
||||||
LS.rawset(LuaRegistry, "sourcedb", LuaNewTable);
|
|
||||||
LS.set(classname, LuaNil);
|
LS.set(classname, LuaNil);
|
||||||
while (LS.next(globtab, classname, classtab) != 0) {
|
while (LS.next(globtab, classname, classtab) != 0) {
|
||||||
if (LS.isstring(classname) && LS.istable(classtab)) {
|
if (LS.isstring(classname) && LS.istable(classtab)) {
|
||||||
@@ -500,30 +454,57 @@ void SourceDB::deserialize_source(util::LuaSourceVec *sv, StreamBuffer *sb) {
|
|||||||
void SourceDB::register_lua_builtins() {
|
void SourceDB::register_lua_builtins() {
|
||||||
lua_State *L = LuaStack::newstate(nullptr);
|
lua_State *L = LuaStack::newstate(nullptr);
|
||||||
luaL_openlibs(L);
|
luaL_openlibs(L);
|
||||||
LuaVar globals,classtab,func;
|
LuaVar globals, lclassname, lfuncname, classtab, func;
|
||||||
LuaStack LS(L, globals, classtab, func);
|
LuaStack LS(L, globals, lclassname, lfuncname, classtab, func);
|
||||||
LS.getglobaltable(globals);
|
LS.getglobaltable(globals);
|
||||||
|
|
||||||
|
// Iterate over the function registry, copying function pointers from
|
||||||
|
// the prototype lua state into the registry, then remove the closure
|
||||||
|
// from the prototype.
|
||||||
for (auto reg = LuaFunctionReg::All; reg != nullptr; reg=reg->next()) {
|
for (auto reg = LuaFunctionReg::All; reg != nullptr; reg=reg->next()) {
|
||||||
if (reg->get_func() == nullptr) {
|
|
||||||
std::string_view funcname;
|
std::string_view funcname;
|
||||||
std::string_view classname;
|
std::string_view classname;
|
||||||
get_reg_name(reg, classname, funcname);
|
get_reg_name(reg, classname, funcname);
|
||||||
if (classname.empty()) {
|
if (classname.empty()) {
|
||||||
LS.rawget(func, globals, funcname);
|
LS.getglobaltable(classtab);
|
||||||
if (LS.iscfunction(func)) {
|
|
||||||
reg->set_func(lua_tocfunction(L, func.index()));
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
LS.rawget(classtab, globals, classname);
|
LS.rawget(classtab, globals, classname);
|
||||||
|
}
|
||||||
|
lua_CFunction builtin = nullptr;
|
||||||
if (LS.istable(classtab)) {
|
if (LS.istable(classtab)) {
|
||||||
LS.rawget(func, classtab, funcname);
|
LS.rawget(func, classtab, funcname);
|
||||||
if (LS.iscfunction(func)) {
|
if (LS.iscfunction(func)) {
|
||||||
reg->set_func(lua_tocfunction(L, func.index()));
|
builtin = lua_tocfunction(L, func.index());
|
||||||
|
LS.rawset(classtab, funcname, LuaNil);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (reg->get_func() == nullptr) {
|
||||||
|
if (builtin == nullptr) {
|
||||||
|
if ((!reg->get_sandbox()) || (reg->get_args() != nullptr)) {
|
||||||
|
std::cerr << "No such builtin function: " << classname << " " << funcname << std::endl;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
reg->set_func(builtin);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterate over the prototype. All cfunctions should have been removed.
|
||||||
|
LS.set(lclassname, LuaNil);
|
||||||
|
while (LS.next(globals, lclassname, classtab)) {
|
||||||
|
if (LS.isstring(lclassname)) {
|
||||||
|
if (LS.istable(classtab)) {
|
||||||
|
LS.set(lfuncname, LuaNil);
|
||||||
|
while (LS.next(classtab, lfuncname, func)) {
|
||||||
|
if (LS.iscfunction(func)) {
|
||||||
|
std::cerr << "Failed to declare builtin: " << LS.ckstring(lclassname) << "."
|
||||||
|
<< LS.ckstring(lfuncname) << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
lua_close(L);
|
lua_close(L);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -703,6 +684,7 @@ LuaDefineBuiltin(table_remove, "vector, pos", "remove an element from a vector")
|
|||||||
LuaDefineBuiltin(table_sort, "vector [,comparefn]", "sort a vector");
|
LuaDefineBuiltin(table_sort, "vector [,comparefn]", "sort a vector");
|
||||||
LuaDefineBuiltin(table_pack, "v1, v2, v3...", "turn a sequence of arguments into a vector");
|
LuaDefineBuiltin(table_pack, "v1, v2, v3...", "turn a sequence of arguments into a vector");
|
||||||
LuaDefineBuiltin(table_unpack, "vector", "turn a vector into a sequence of return values");
|
LuaDefineBuiltin(table_unpack, "vector", "turn a vector into a sequence of return values");
|
||||||
|
LuaSandboxBuiltin(table_maxn, "", "");
|
||||||
|
|
||||||
LuaDefineBuiltin(string_byte, "str [,index]", "get a single byte from a string");
|
LuaDefineBuiltin(string_byte, "str [,index]", "get a single byte from a string");
|
||||||
LuaDefineBuiltin(string_char, "byte, byte,...", "convert sequence of bytes to a string");
|
LuaDefineBuiltin(string_char, "byte, byte,...", "convert sequence of bytes to a string");
|
||||||
@@ -717,6 +699,49 @@ LuaDefineBuiltin(string_gmatch, "str, pattern", "iterate over pattern-matched su
|
|||||||
LuaDefineBuiltin(string_gsub, "str, pattern, replace", "global replace pattern in string");
|
LuaDefineBuiltin(string_gsub, "str, pattern, replace", "global replace pattern in string");
|
||||||
LuaDefineBuiltin(string_match, "str, pattern", "return start and end of pattern in string");
|
LuaDefineBuiltin(string_match, "str, pattern", "return start and end of pattern in string");
|
||||||
LuaDefineBuiltin(string_sub, "str, pos1, pos2", "return substring of str from pos1 to pos2");
|
LuaDefineBuiltin(string_sub, "str, pos1, pos2", "return substring of str from pos1 to pos2");
|
||||||
|
LuaSandboxBuiltin(string_dump, "func", "convert a function to a string");
|
||||||
|
|
||||||
|
LuaDefineBuiltin(bit32_arshift, "n, shift", "shift 32-bit number to the right, keeping high bit unchanged");
|
||||||
|
LuaDefineBuiltin(bit32_band, "n, n, n...", "return the bitwise and of all 32-bit numbers");
|
||||||
|
LuaDefineBuiltin(bit32_bnot, "n", "return the bitwise negation of n, ie, (-1 - n) % 2^32");
|
||||||
|
LuaDefineBuiltin(bit32_bor, "n, n, n...", "return the bitwise or of all 32-bit arguments");
|
||||||
|
LuaDefineBuiltin(bit32_bxor, "n, n, n...", "return the bitwise exclusive or of all 32-bit arguments");
|
||||||
|
LuaDefineBuiltin(bit32_btest, "n, n, n...", "compute bitwise and of all 32-bit arguments, return true if nonzero");
|
||||||
|
LuaDefineBuiltin(bit32_extract, "n, field, width", "return value from extracted bitfield of 32-bit number");
|
||||||
|
LuaDefineBuiltin(bit32_lrotate, "n, shift", "rotate 32-bit number to the left");
|
||||||
|
LuaDefineBuiltin(bit32_lshift, "n, shift", "shift 32-bit number to the left, padding with zeros");
|
||||||
|
LuaDefineBuiltin(bit32_replace, "n, v, field, width", "change value of extracted bitfield in 32-bit number");
|
||||||
|
LuaDefineBuiltin(bit32_rrotate, "n, shift", "rotate 32-bit number to the right");
|
||||||
|
LuaDefineBuiltin(bit32_rshift, "n, shift", "shift 32-bit number to the right, padding with zeros");
|
||||||
|
|
||||||
|
LuaDefineBuiltin(math_abs, "x", "return absolute value of x");
|
||||||
|
LuaDefineBuiltin(math_acos, "x", "return arc-cosine of x in radians");
|
||||||
|
LuaDefineBuiltin(math_asin, "x", "return arc-sine of x in radians");
|
||||||
|
LuaDefineBuiltin(math_atan, "x", "return the arc-tangent of x in radians");
|
||||||
|
LuaDefineBuiltin(math_atan2, "y, x", "return arc-tangent of y/x, using signs to compute quadrant");
|
||||||
|
LuaDefineBuiltin(math_ceil, "x", "return the smallest integer larger than or equal to x");
|
||||||
|
LuaDefineBuiltin(math_cos, "x", "return the cosine of x in radians");
|
||||||
|
LuaDefineBuiltin(math_cosh, "x", "return the hyperbolic cosine of x in radians");
|
||||||
|
LuaDefineBuiltin(math_deg, "rad", "convert radians to degrees");
|
||||||
|
LuaDefineBuiltin(math_exp, "x", "returns the value e^x");
|
||||||
|
LuaDefineBuiltin(math_floor, "x", "returns the smallest integer less than or equal to x");
|
||||||
|
LuaDefineBuiltin(math_fmod, "x, y", "return the remainder of x/y that rounds the quotient towards zero");
|
||||||
|
LuaDefineBuiltin(math_frexp, "x", "given x, returns mantissa and exponent");
|
||||||
|
LuaDefineBuiltin(math_ldexp, "x", "return unnormalized mantissa and exponent");
|
||||||
|
LuaDefineBuiltin(math_log, "x [, base]", "return the log of x in base, default base is e");
|
||||||
|
LuaDefineBuiltin(math_max, "x, x, x...", "return the largest argument");
|
||||||
|
LuaDefineBuiltin(math_min, "x, x, x...", "return the smallest argument");
|
||||||
|
LuaDefineBuiltin(math_modf, "x", "returns the integral and fractional part of x");
|
||||||
|
LuaDefineBuiltin(math_pow, "x, y", "returns x ^ y, equivalent to the operator");
|
||||||
|
LuaDefineBuiltin(math_rad, "deg", "convert degrees to radians");
|
||||||
|
LuaDefineBuiltin(math_random, "[m [, n]]", "return random [0.0-1.0), or [1-m], or [m-n].");
|
||||||
|
LuaDefineBuiltin(math_randomseed, "x", "set x as the seed for random numbers");
|
||||||
|
LuaDefineBuiltin(math_sin, "x", "return the sine of x in radians");
|
||||||
|
LuaDefineBuiltin(math_sinh, "x", "return the hyperbolic sine of x in radians");
|
||||||
|
LuaDefineBuiltin(math_sqrt, "x", "return the square root of x");
|
||||||
|
LuaDefineBuiltin(math_tan, "x", "return the tangent of x in radians");
|
||||||
|
LuaDefineBuiltin(math_tanh, "x", "return the hyperbolic tangent of x in radians");
|
||||||
|
LuaSandboxBuiltin(math_log10, "", "");
|
||||||
|
|
||||||
LuaDefineBuiltin(assert, "flag [,message]", "assert that flag is true, if not, raise error");
|
LuaDefineBuiltin(assert, "flag [,message]", "assert that flag is true, if not, raise error");
|
||||||
LuaDefineBuiltin(error, "message", "raise an error");
|
LuaDefineBuiltin(error, "message", "raise an error");
|
||||||
@@ -734,4 +759,73 @@ LuaDefineBuiltin(select, "n, arg1, arg2, ...", "return the nth argument");
|
|||||||
LuaDefineBuiltin(setmetatable, "table, meta", "set the metatable of the specified table");
|
LuaDefineBuiltin(setmetatable, "table, meta", "set the metatable of the specified table");
|
||||||
LuaDefineBuiltin(tonumber, "str", "convert a string to a number");
|
LuaDefineBuiltin(tonumber, "str", "convert a string to a number");
|
||||||
LuaDefineBuiltin(type, "obj", "return the type of obj as a string");
|
LuaDefineBuiltin(type, "obj", "return the type of obj as a string");
|
||||||
|
// print is redefined in world.cpp (because it prints into the world model)
|
||||||
|
// tostring is redefined in pprint.cpp
|
||||||
|
|
||||||
|
LuaSandboxBuiltin(collectgarbage, "", "");
|
||||||
|
LuaSandboxBuiltin(dofile, "", "");
|
||||||
|
LuaSandboxBuiltin(xpcall, "", "");
|
||||||
|
LuaSandboxBuiltin(loadfile, "", "");
|
||||||
|
LuaSandboxBuiltin(load, "", "");
|
||||||
|
LuaSandboxBuiltin(require, "", "");
|
||||||
|
LuaSandboxBuiltin(module, "", "");
|
||||||
|
LuaSandboxBuiltin(loadstring, "", "");
|
||||||
|
LuaSandboxBuiltin(unpack, "", "");
|
||||||
|
|
||||||
|
|
||||||
|
LuaSandboxBuiltin(debug_debug, "", "");
|
||||||
|
LuaSandboxBuiltin(debug_getuservalue, "", "");
|
||||||
|
LuaSandboxBuiltin(debug_gethook, "", "");
|
||||||
|
LuaSandboxBuiltin(debug_getinfo, "", "");
|
||||||
|
LuaSandboxBuiltin(debug_getlocal, "", "");
|
||||||
|
LuaSandboxBuiltin(debug_getregistry, "", "");
|
||||||
|
LuaSandboxBuiltin(debug_getmetatable, "", "");
|
||||||
|
LuaSandboxBuiltin(debug_getupvalue, "", "");
|
||||||
|
LuaSandboxBuiltin(debug_upvaluejoin, "", "");
|
||||||
|
LuaSandboxBuiltin(debug_upvalueid, "", "");
|
||||||
|
LuaSandboxBuiltin(debug_setuservalue, "", "");
|
||||||
|
LuaSandboxBuiltin(debug_sethook, "", "");
|
||||||
|
LuaSandboxBuiltin(debug_setlocal, "", "");
|
||||||
|
LuaSandboxBuiltin(debug_setmetatable, "", "");
|
||||||
|
LuaSandboxBuiltin(debug_setupvalue, "", "");
|
||||||
|
LuaSandboxBuiltin(debug_traceback, "", "");
|
||||||
|
|
||||||
|
LuaSandboxBuiltin(eris_persist, "", "");
|
||||||
|
LuaSandboxBuiltin(eris_unpersist, "", "");
|
||||||
|
LuaSandboxBuiltin(eris_settings, "", "");
|
||||||
|
|
||||||
|
LuaSandboxBuiltin(package_loadlib, "", "");
|
||||||
|
LuaSandboxBuiltin(package_searchpath, "", "");
|
||||||
|
LuaSandboxBuiltin(package_seeall, "", "");
|
||||||
|
|
||||||
|
LuaSandboxBuiltin(coroutine_create, "", "");
|
||||||
|
LuaSandboxBuiltin(coroutine_resume, "", "");
|
||||||
|
LuaSandboxBuiltin(coroutine_running, "", "");
|
||||||
|
LuaSandboxBuiltin(coroutine_status, "", "");
|
||||||
|
LuaSandboxBuiltin(coroutine_wrap, "", "");
|
||||||
|
LuaSandboxBuiltin(coroutine_yield, "", "");
|
||||||
|
|
||||||
|
LuaSandboxBuiltin(io_close, "", "");
|
||||||
|
LuaSandboxBuiltin(io_flush, "", "");
|
||||||
|
LuaSandboxBuiltin(io_input, "", "");
|
||||||
|
LuaSandboxBuiltin(io_lines, "", "");
|
||||||
|
LuaSandboxBuiltin(io_open, "", "");
|
||||||
|
LuaSandboxBuiltin(io_output, "", "");
|
||||||
|
LuaSandboxBuiltin(io_popen, "", "");
|
||||||
|
LuaSandboxBuiltin(io_read, "", "");
|
||||||
|
LuaSandboxBuiltin(io_tmpfile, "", "");
|
||||||
|
LuaSandboxBuiltin(io_type, "", "");
|
||||||
|
LuaSandboxBuiltin(io_write, "", "");
|
||||||
|
|
||||||
|
LuaSandboxBuiltin(os_clock, "", "");
|
||||||
|
LuaSandboxBuiltin(os_date, "", "");
|
||||||
|
LuaSandboxBuiltin(os_difftime, "", "");
|
||||||
|
LuaSandboxBuiltin(os_execute, "", "");
|
||||||
|
LuaSandboxBuiltin(os_exit, "", "");
|
||||||
|
LuaSandboxBuiltin(os_getenv, "", "");
|
||||||
|
LuaSandboxBuiltin(os_remove, "", "");
|
||||||
|
LuaSandboxBuiltin(os_rename, "", "");
|
||||||
|
LuaSandboxBuiltin(os_setlocale, "", "");
|
||||||
|
LuaSandboxBuiltin(os_time, "", "");
|
||||||
|
LuaSandboxBuiltin(os_tmpname, "", "");
|
||||||
|
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ void quote_string(const eng::string &s, std::ostream *os) {
|
|||||||
case '\t': (*os) << "\\t"; break;
|
case '\t': (*os) << "\\t"; break;
|
||||||
case '\r': (*os) << "\\r"; break;
|
case '\r': (*os) << "\\r"; break;
|
||||||
default:
|
default:
|
||||||
(*os) << "\\" << std::setw(3) << int(c);
|
(*os) << "\\" << std::setfill('0') << std::setw(3) << int(c);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user