lots of work on determinism in the linux driver.

This commit is contained in:
2022-02-18 03:59:21 -05:00
parent 6a6d2c7f75
commit ba1e923b5a
10 changed files with 175 additions and 110 deletions

View File

@@ -33,7 +33,6 @@ Channel::Channel(DrivenEngine *de, int chid, int port, const std::string &target
port_ = port; port_ = port;
closed_ = false; closed_ = false;
target_ = target; target_ = target;
readline_enabled_ = false;
readline_lastc_ = 0; readline_lastc_ = 0;
desired_prompt_ = ""; desired_prompt_ = "";
stop_driver_ = stop; stop_driver_ = stop;
@@ -112,37 +111,25 @@ void Channel::feed_readline(int nbytes, const char *bytes) {
} }
} }
void Channel::peek_outgoing(int *nbytes, const char **bytes) { void Channel::peek_outgoing(int *nbytes, const char **bytes) const {
if (readline_enabled_) { *nbytes = sb_drvout_->fill();
*bytes = sb_drvout_->data();
}
void Channel::pump_readline() {
if (sb_drvout_ != sb_out_) {
if (!sb_out_->empty()) { if (!sb_out_->empty()) {
erase_command(); erase_command();
sb_out_->transfer_into(sb_drvout_.get()); sb_out_->transfer_into(sb_drvout_.get());
} }
echo_command(); echo_command();
} }
*nbytes = sb_drvout_->fill();
*bytes = sb_drvout_->data();
} }
void Channel::sent_outgoing(int nbytes) { void Channel::sent_outgoing(int nbytes) {
sb_drvout_->read_bytes(nbytes); sb_drvout_->read_bytes(nbytes);
} }
void Channel::set_readline(bool e) {
if (e != readline_enabled_) {
readline_enabled_ = e;
if (readline_enabled_) {
sb_drvout_ = std::make_shared<StreamBuffer>();
} else {
sb_out_->transfer_into(sb_drvout_.get());
sb_out_->clear();
sb_drvout_->transfer_into(sb_out_.get());
sb_drvout_ = sb_out_;
}
desired_command_ = "";
}
}
int DrivenEngine::find_unused_chid() { int DrivenEngine::find_unused_chid() {
// Note: channel ID zero is special, it is never reused. // Note: channel ID zero is special, it is never reused.
for (int i = 0; i < MAX_CHAN; i++) { for (int i = 0; i < MAX_CHAN; i++) {
@@ -244,7 +231,7 @@ void DrivenEngine::drv_sent_outgoing(int chid, int nbytes) {
void DrivenEngine::drv_recv_incoming(int chid, int nbytes, const char *bytes) { void DrivenEngine::drv_recv_incoming(int chid, int nbytes, const char *bytes) {
if (nbytes > 0) { if (nbytes > 0) {
Channel *ch = get_chid(chid); Channel *ch = get_chid(chid);
if (ch->readline_enabled_) { if (ch->sb_drvout_ != ch->sb_out_) {
ch->feed_readline(nbytes, bytes); ch->feed_readline(nbytes, bytes);
} else { } else {
ch->sb_in_->write_bytes(bytes, nbytes); ch->sb_in_->write_bytes(bytes, nbytes);
@@ -271,25 +258,22 @@ void DrivenEngine::drv_clear_lua_source() {
rescan_lua_source_ = false; rescan_lua_source_ = false;
} }
void DrivenEngine::drv_add_lua_source(const char *fn, const char *data) { void DrivenEngine::drv_add_lua_source(std::string_view fn, std::string_view data) {
if (lua_source_ == nullptr) { if (lua_source_ == nullptr) {
lua_source_.reset(new util::LuaSourceVec); lua_source_.reset(new util::LuaSourceVec);
} }
lua_source_->emplace_back(std::string(fn), std::string(data)); lua_source_->emplace_back(std::string(fn), std::string(data));
} }
void DrivenEngine::drv_set_lua_source(util::LuaSourcePtr src) {
lua_source_ = std::move(src);
rescan_lua_source_ = false;
}
void DrivenEngine::drv_invoke_event_init(int argc, char *argv[]) { void DrivenEngine::drv_invoke_event_init(int argc, char *argv[]) {
event_init(argc, argv); event_init(argc, argv);
stdio_channel_->pump_readline();
} }
void DrivenEngine::drv_invoke_event_update(double clock) { void DrivenEngine::drv_invoke_event_update(double clock) {
clock_ = clock; clock_ = clock;
event_update(); event_update();
stdio_channel_->pump_readline();
} }
bool DrivenEngine::drv_get_rescan_lua_source() const { bool DrivenEngine::drv_get_rescan_lua_source() const {
@@ -303,7 +287,7 @@ bool DrivenEngine::drv_get_stop_driver() const {
DrivenEngine::DrivenEngine() { DrivenEngine::DrivenEngine() {
next_unused_chid_ = 1; next_unused_chid_ = 1;
stdio_channel_ = std::make_shared<Channel>(this, 0, 0, "", false); stdio_channel_ = std::make_shared<Channel>(this, 0, 0, "", false);
stdio_channel_->set_readline(true); stdio_channel_->sb_drvout_ = std::make_shared<StreamBuffer>();
channels_[0] = stdio_channel_; channels_[0] = stdio_channel_;
rescan_lua_source_ = true; rescan_lua_source_ = true;
clock_ = 0.0; clock_ = 0.0;

View File

@@ -132,20 +132,6 @@ public:
// //
std::string error() const { return error_; } std::string error() const { return error_; }
// True if the channel is in readline mode.
//
// Stdio always starts with this enabled, other channels always start
// with this disabled.
//
bool readline_enabled() const { return readline_enabled_; }
// Put the channel into readline mode.
//
// Caution: the channel better be coming from a raw tty, otherwise,
// this is going to produce weird results.
//
void set_readline(bool enabled);
// Set the prompt for readline mode. // Set the prompt for readline mode.
// //
void set_prompt(const std::string &prompt); void set_prompt(const std::string &prompt);
@@ -164,10 +150,11 @@ private:
// //
void feed_readline(int nbytes, const char *bytes); void feed_readline(int nbytes, const char *bytes);
void peek_outgoing(int *nbytes, const char **bytes); void peek_outgoing(int *nbytes, const char **bytes) const;
void sent_outgoing(int nbytes); void sent_outgoing(int nbytes);
void erase_command(); void erase_command();
void echo_command(); void echo_command();
void pump_readline();
private: private:
static const int READLINE_MAX=512; static const int READLINE_MAX=512;
@@ -177,9 +164,9 @@ private:
std::shared_ptr<StreamBuffer> sb_in_; std::shared_ptr<StreamBuffer> sb_in_;
std::shared_ptr<StreamBuffer> sb_out_; std::shared_ptr<StreamBuffer> sb_out_;
// In readline mode, we inject tty echoes into the output stream. // If this is stdio, we inject tty echoes into the output stream.
// This buffer holds the users output interleaved with the tty echoes. // This buffer holds the users output interleaved with the tty echoes.
// In non-readline mode, this is just another pointer to sb_out. // In any other channel, this is just another pointer to sb_out.
std::shared_ptr<StreamBuffer> sb_drvout_; std::shared_ptr<StreamBuffer> sb_drvout_;
int port_; int port_;
@@ -188,13 +175,12 @@ private:
std::string target_; std::string target_;
bool stop_driver_; bool stop_driver_;
// Readline stuff. // Readline stuff. Only used on channel 0 (stdio).
std::string desired_command_; std::string desired_command_;
std::string current_command_; std::string current_command_;
std::string desired_prompt_; std::string desired_prompt_;
std::string current_prompt_; std::string current_prompt_;
char readline_lastc_; char readline_lastc_;
bool readline_enabled_;
friend class DrivenEngine; friend class DrivenEngine;
}; };
@@ -374,8 +360,7 @@ public:
// Set the lua source code. The driver is expected to read the lua source // Set the lua source code. The driver is expected to read the lua source
// code and store it (using this function) once before invoking // code and store it (using this function) once before invoking
// //
void drv_add_lua_source(const char *fn, const char *data); void drv_add_lua_source(std::string_view fn, std::string_view data);
void drv_set_lua_source(util::LuaSourcePtr source);
// Invoke the init or update event. // Invoke the init or update event.
// //

View File

@@ -2,6 +2,15 @@
#define CHBUF_SIZE (256*1024) #define CHBUF_SIZE (256*1024)
#define POLLVEC_SIZE (DrivenEngine::MAX_CHAN+1) #define POLLVEC_SIZE (DrivenEngine::MAX_CHAN+1)
int mallocstate(int n) {
int64_t result = 0;
for (int i = 0; i < n; i++) {
int64_t n = int64_t(malloc(1));
result = (result * 17) + n;
}
return result & 0x7fffffff;
}
static MonoClock monoclock; static MonoClock monoclock;
namespace util { namespace util {
@@ -17,7 +26,7 @@ static void if_error_print_and_exit(const UmmString &str) {
} }
} }
static SSL_CTX *new_ssl_context(bool server_cert, bool root_certs, const std::string &require_cert) { static SSL_CTX *new_ssl_context(bool server_cert, bool root_certs, std::string_view require_cert) {
SSL_CTX *ctx = SSL_CTX_new(TLS_method()); SSL_CTX *ctx = SSL_CTX_new(TLS_method());
SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE); SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE);
@@ -92,8 +101,8 @@ public:
}; };
DrivenEngine *driven_; DrivenEngine *driven_;
std::vector<ChanInfo> chans_; UmmVector<ChanInfo> chans_;
std::map<int, SOCKET> listen_sockets_; UmmMap<int, SOCKET> listen_sockets_;
bool read_console_recently_; bool read_console_recently_;
SSL_CTX *ssl_ctx_with_root_certs_; SSL_CTX *ssl_ctx_with_root_certs_;
@@ -116,7 +125,17 @@ public:
void handle_lua_source() { void handle_lua_source() {
if (driven_->drv_get_rescan_lua_source()) { if (driven_->drv_get_rescan_lua_source()) {
driven_->drv_set_lua_source(util::read_lua_source("lua")); UmmString err;
std::string_view ctrl = read_file("lua/control.lst", chbuf.get(), CHBUF_SIZE, err);
if_error_print_and_exit(err);
UmmStringVec names = drv::parse_control_lst(ctrl);
driven_->drv_clear_lua_source();
for (const UmmString &str : names) {
UmmString lfn = UmmString("lua/") + str;
std::string_view data = read_file(lfn.c_str(), chbuf.get(), CHBUF_SIZE, err);
if_error_print_and_exit(err);
driven_->drv_add_lua_source(str, data);
}
} }
} }
@@ -458,7 +477,7 @@ public:
} }
DrivenEngine::set(de); DrivenEngine::set(de);
driven_->drv_set_lua_source(util::read_lua_source("lua")); handle_lua_source();
driven_->drv_invoke_event_init(argc, argv); driven_->drv_invoke_event_init(argc, argv);
handle_listen_ports(); handle_listen_ports();
@@ -490,12 +509,17 @@ void driver_drive(int argc, char *argv[]) {
// doesn't break the determinism of the execution during replay. // doesn't break the determinism of the execution during replay.
umm_init_heap(malloc(OPENSSL_HEAP_SIZE), OPENSSL_HEAP_SIZE); umm_init_heap(malloc(OPENSSL_HEAP_SIZE), OPENSSL_HEAP_SIZE);
CRYPTO_set_mem_functions(umm_malloc_ssl, umm_realloc_ssl, umm_free_ssl); CRYPTO_set_mem_functions(umm_malloc_ssl, umm_realloc_ssl, umm_free_ssl);
chbuf.reset(new char[CHBUF_SIZE]); chbuf.reset(new char[CHBUF_SIZE]);
pollvec.reset(new struct pollfd[POLLVEC_SIZE]); pollvec.reset(new struct pollfd[POLLVEC_SIZE]);
ERR_load_crypto_strings(); ERR_load_crypto_strings();
SSL_load_error_strings(); SSL_load_error_strings();
std::cerr << "#2 " << std::hex << mallocstate(1) << std::endl;
Driver driver; Driver driver;
if (argc < 2) { if (argc < 2) {
DrivenEngine::print_usage(std::cerr, argv[0]); DrivenEngine::print_usage(std::cerr, argv[0]);

View File

@@ -225,7 +225,25 @@ static int console_read(char *bytes, int nbytes) {
return read(0, bytes, nbytes); return read(0, bytes, nbytes);
} }
static std::string_view read_file(const char *fn, char *buf, int bufsize, UmmString &err) {
int nread;
int fd = open(fn, O_RDONLY);
if (fd < 0) goto error_errno;
nread = read(fd, buf, bufsize);
if (nread < 0) goto error_errno;
if (nread == bufsize) {
err = "file too large";
goto error;
}
buf[nread] = 0;
err = "";
return std::string_view(buf, nread);
error_errno:
err = strerror_str(errno);
error:
buf[0] = 0;
return std::string_view(buf, 0);
}
static void disable_randomization(int argc, char *argv[]) { static void disable_randomization(int argc, char *argv[]) {
const int old_personality = personality(ADDR_NO_RANDOMIZE); const int old_personality = personality(ADDR_NO_RANDOMIZE);

View File

@@ -1,6 +1,7 @@
#include "driver-util.hpp" #include "driver-util.hpp"
#include "luastack.hpp" #include "luastack.hpp"
#include "util.hpp"
namespace drv { namespace drv {
@@ -16,6 +17,19 @@ void split_host_port(std::string_view target, UmmString &host, UmmString &port)
} }
} }
UmmStringVec parse_control_lst(std::string_view ctrl) {
UmmStringVec result;
while (!ctrl.empty()) {
std::string_view line = util::sv_read_line(ctrl);
std::string_view trimmed = util::sv_trim(line);
if ((trimmed.size() > 0) && (trimmed[0] != '#')) {
result.emplace_back(trimmed);
}
}
return result;
}
} // namespace drv } // namespace drv
LuaDefine(unittests_driverutil, "", "some unit tests") { LuaDefine(unittests_driverutil, "", "some unit tests") {
@@ -24,6 +38,7 @@ LuaDefine(unittests_driverutil, "", "some unit tests") {
drv::split_host_port("stanford.edu:80", host, port); drv::split_host_port("stanford.edu:80", host, port);
LuaAssertStrEq(L, host, "stanford.edu"); LuaAssertStrEq(L, host, "stanford.edu");
LuaAssertStrEq(L, port, "80"); LuaAssertStrEq(L, port, "80");
return 0; return 0;
} }

View File

@@ -4,10 +4,14 @@
#include "umm-malloc.hpp" #include "umm-malloc.hpp"
using UmmStringVec = UmmVector<UmmString>;
namespace drv { namespace drv {
void split_host_port(std::string_view target, UmmString &host, UmmString &port); void split_host_port(std::string_view target, UmmString &host, UmmString &port);
UmmStringVec parse_control_lst(std::string_view ctrl);
} }
#endif // DRIVER_UTIL_HPP #endif // DRIVER_UTIL_HPP

View File

@@ -19,38 +19,6 @@ static void dump_lines(StreamBuffer *in, StreamBuffer *out, int chid) {
} }
} }
// This test allows input on stdin or on port 8085.
// You can type lines and see them echoed.
class DriverListenTest : public DrivenEngine {
public:
std::vector<SharedChannel> channels_;
virtual void event_init(int argc, char *argv[]) {
listen_port(8085);
}
virtual void event_update() {
while (true) {
SharedChannel ch = new_incoming_channel();
if (ch == nullptr) break;
ch->set_readline(true);
channels_.emplace_back(std::move(ch));
}
SharedChannel stdioch = get_stdio_channel();
dump_lines(stdioch->in(), stdioch->out(), 0);
std::vector<SharedChannel> keep;
for (SharedChannel &ch : channels_) {
dump_lines(ch->in(), stdioch->out(), ch->chid());
if (ch->closed()) {
write_closed_message(ch.get(), stdioch->out());
} else {
keep.emplace_back(std::move(ch));
}
}
channels_ = std::move(keep);
}
};
// This test connects to a public webserver and prints // This test connects to a public webserver and prints
// the output from the server. // the output from the server.
class DriverWebServerTest : public DrivenEngine { class DriverWebServerTest : public DrivenEngine {
@@ -104,6 +72,15 @@ public:
} }
}; };
static int64_t mallocstate() {
int64_t result = 0;
for (int i = 0; i < 10; i++) {
int64_t n = int64_t(malloc(1));
result = (result * 17) + n;
}
return result;
}
// This test just prints the time. // This test just prints the time.
class DriverPrintClockTest : public DrivenEngine { class DriverPrintClockTest : public DrivenEngine {
public: public:
@@ -117,11 +94,12 @@ public:
virtual void event_update() { virtual void event_update() {
double clock = get_clock(); double clock = get_clock();
if (clock > last_clock_ + 0.5) { if (clock > last_clock_ + 0.5) {
stdostream() << std::fixed << std::setprecision(2) << clock << " "; int64_t ms = mallocstate();
stdostream() << std::fixed << std::setprecision(2) << clock << " " << std::hex << ms << " ";
count_++; count_++;
last_clock_ = clock; last_clock_ = clock;
} }
if (count_ == 10) { if (count_ == 4) {
stdostream() << std::endl; stdostream() << std::endl;
count_ = 0; count_ = 0;
} }
@@ -143,10 +121,6 @@ private:
}; };
UniqueDrivenEngine make_DriverListenTest() {
return UniqueDrivenEngine(new DriverListenTest);
}
UniqueDrivenEngine make_DriverWebServerTest() { UniqueDrivenEngine make_DriverWebServerTest() {
return UniqueDrivenEngine(new DriverWebServerTest); return UniqueDrivenEngine(new DriverWebServerTest);
} }

View File

@@ -6,17 +6,16 @@
#include "driver.hpp" #include "driver.hpp"
#include "source.hpp" #include "source.hpp"
#include <iostream> #include <iostream>
#include <time.h>
int main(int argc, char **argv) int main(int argc, char **argv)
{ {
driver_sysinit(argc, argv); driver_sysinit(argc, argv);
SourceDB::register_lua_builtins(); SourceDB::register_lua_builtins();
DrivenEngine::register_maker("textgame", make_TextGame); DrivenEngine::register_maker("textgame", make_TextGame);
DrivenEngine::register_maker("lpxclient", make_LpxClient); DrivenEngine::register_maker("lpxclient", make_LpxClient);
DrivenEngine::register_maker("lpxserver", make_LpxServer); DrivenEngine::register_maker("lpxserver", make_LpxServer);
DrivenEngine::register_maker("driverlistentest", make_DriverListenTest);
DrivenEngine::register_maker("driverwebservertest", make_DriverWebServerTest); DrivenEngine::register_maker("driverwebservertest", make_DriverWebServerTest);
DrivenEngine::register_maker("driverdnsfailtest", make_DriverDNSFailTest); DrivenEngine::register_maker("driverdnsfailtest", make_DriverDNSFailTest);
DrivenEngine::register_maker("driverprintclocktest", make_DriverPrintClockTest); DrivenEngine::register_maker("driverprintclocktest", make_DriverPrintClockTest);

View File

@@ -269,22 +269,67 @@ double strtodouble(const std::string &value) {
} }
} }
std::string ltrim(std::string s) { std::string_view sv_ltrim(std::string_view v) {
s.erase(s.begin(), std::find_if(s.begin(), s.end(), const char *b = v.data();
std::not1(std::ptr_fun<int, int>(std::isspace)))); const char *e = v.data() + v.size();
return s; while ((e > b) && (std::isspace(b[0]))) {
b++;
}
return std::string_view(b, e-b);
} }
std::string rtrim(std::string s) { std::string_view sv_rtrim(std::string_view v) {
s.erase(std::find_if(s.rbegin(), s.rend(), const char *b = v.data();
std::not1(std::ptr_fun<int, int>(std::isspace))).base(), s.end()); const char *e = v.data() + v.size();
return s; while ((e > b) && (std::isspace(e[-1]))) {
e--;
}
return std::string_view(b, e-b);
} }
std::string trim(std::string s) { std::string_view sv_trim(std::string_view v) {
return ltrim(rtrim(s)); const char *b = v.data();
const char *e = v.data() + v.size();
while ((e > b) && (std::isspace(b[0]))) {
b++;
}
while ((e > b) && (std::isspace(e[-1]))) {
e--;
}
return std::string_view(b, e-b);
} }
std::string ltrim(std::string_view v) {
return std::string(sv_ltrim(v));
}
std::string rtrim(std::string_view v) {
return std::string(sv_rtrim(v));
}
std::string trim(std::string_view v) {
return std::string(sv_trim(v));
}
std::string_view sv_read_line(std::string_view &source) {
size_t pos = source.find('\n');
std::string_view result;
if (pos == std::string_view::npos) {
result = source;
source = "";
} else {
result = source.substr(0, pos);
source = source.substr(pos + 1);
}
int fsize = result.size();
if ((fsize >= 1) && (result[fsize - 1] == '\r')) {
result.remove_suffix(1);
}
return result;
}
double distance_squared(double x1, double y1, double x2, double y2) { double distance_squared(double x1, double y1, double x2, double y2) {
double dx = x1 - x2; double dx = x1 - x2;
double dy = y1 - y2; double dy = y1 - y2;
@@ -431,6 +476,15 @@ LuaDefine(unittests_util, "", "some unit tests") {
LuaAssert(L, util::trim("foo") == "foo"); LuaAssert(L, util::trim("foo") == "foo");
LuaAssert(L, util::trim("") == ""); LuaAssert(L, util::trim("") == "");
// Test sv_read_line
std::string_view v = "foo\nbar\r\n";
std::string_view v1 = util::sv_read_line(v);
std::string_view v2 = util::sv_read_line(v);
std::string_view v3 = util::sv_read_line(v);
LuaAssertStrEq(L, v1, "foo");
LuaAssertStrEq(L, v2, "bar");
LuaAssertStrEq(L, v3, "");
// Test distance_squared // Test distance_squared
LuaAssert(L, util::distance_squared(1, 1, 5, 4) == 25.0); LuaAssert(L, util::distance_squared(1, 1, 5, 4) == 25.0);
LuaAssert(L, util::distance_squared(5, 4, 1, 1) == 25.0); LuaAssert(L, util::distance_squared(5, 4, 1, 1) == 25.0);

View File

@@ -100,10 +100,18 @@ int64_t strtoint(const std::string &value, int64_t errval);
// String to double. Returns NAN if the number is not parseable. // String to double. Returns NAN if the number is not parseable.
double strtodouble(const std::string &value); double strtodouble(const std::string &value);
// Trim a string_view
std::string_view sv_ltrim(std::string_view v);
std::string_view sv_rtrim(std::string_view v);
std::string_view sv_trim(std::string_view v);
// Trim strings: left end, right end, both ends. // Trim strings: left end, right end, both ends.
std::string ltrim(std::string s); std::string ltrim(std::string_view s);
std::string rtrim(std::string s); std::string rtrim(std::string_view s);
std::string trim(std::string s); std::string trim(std::string_view s);
// Read a line from a string_view
std::string_view sv_read_line(std::string_view &source);
// Calculate distance between two points // Calculate distance between two points
double distance_squared(double x1, double y1, double x2, double y2); double distance_squared(double x1, double y1, double x2, double y2);