From 569b8aef450a957bb013027e1a25f267ec569d9b Mon Sep 17 00:00:00 2001 From: jyelon Date: Tue, 14 Feb 2023 13:14:18 -0500 Subject: [PATCH] Refactor for DLL --- luprex/core/Makefile | 200 +++--- luprex/core/cpp/drivenengine.cpp | 913 ++++++++++++++++++++++--- luprex/core/cpp/drivenengine.hpp | 219 ++---- luprex/core/cpp/enginewrapper.hpp | 217 ++++++ luprex/core/cpp/lpxclient.cpp | 3 - luprex/core/cpp/util.hpp | 3 - luprex/core/cpp/world-core.cpp | 9 + luprex/core/drv/driver-common.cpp | 488 +++++++++++++ luprex/core/drv/driver-linux.cpp | 259 +++++++ luprex/core/drv/driver-mingw.cpp | 264 +++++++ luprex/core/drv/drvutil.cpp | 269 ++++++++ luprex/core/drv/drvutil.hpp | 99 +++ luprex/core/drv/sslutil.cpp | 199 ++++++ luprex/core/drv/sslutil.hpp | 61 ++ luprex/core/lua/control.lst | 1 - luprex/core/{lobj => obj/cpp}/.gitkeep | 0 luprex/core/obj/drv/.gitkeep | 0 luprex/core/obj/lua/.gitkeep | 0 18 files changed, 2821 insertions(+), 383 deletions(-) create mode 100644 luprex/core/cpp/enginewrapper.hpp create mode 100644 luprex/core/drv/driver-common.cpp create mode 100644 luprex/core/drv/driver-linux.cpp create mode 100644 luprex/core/drv/driver-mingw.cpp create mode 100644 luprex/core/drv/drvutil.cpp create mode 100644 luprex/core/drv/drvutil.hpp create mode 100644 luprex/core/drv/sslutil.cpp create mode 100644 luprex/core/drv/sslutil.hpp rename luprex/core/{lobj => obj/cpp}/.gitkeep (100%) create mode 100644 luprex/core/obj/drv/.gitkeep create mode 100644 luprex/core/obj/lua/.gitkeep diff --git a/luprex/core/Makefile b/luprex/core/Makefile index 035e3cd9..cc9f1b52 100644 --- a/luprex/core/Makefile +++ b/luprex/core/Makefile @@ -1,117 +1,107 @@ - -ifeq ($(OS),mingw) - EXE=main.exe - LIBS=-L../mingwlib -lssl -lcrypto -lws2_32 -lcrypt32 -lcryptui - INCS=-I../mingwlib - LUAFLAGS=-DLUA_COMPAT_ALL - OPT=-g -O0 - DRIVER=driver-mingw -else ifeq ($(OS),linux) - EXE=main - LIBS=-L../linuxlib -lssl -lcrypto - INCS=-I../linuxlib - LUAFLAGS=-DLUA_USE_POSIX - OPT=-g -O0 - DRIVER=driver-linux -else - # In this case, any attempt to build luprex will trigger an error, - # But making 'clean' will still work. - ERROR=$(error You must specify OS=linux or OS=mingw) - EXE=main - LIBS=$(ERROR) - INCS=$(ERROR) - LUAFLAGS=$(ERROR) - OPT=$(ERROR) - DRIVER=driver-xxx -endif - +all: main LUA_OBJ_FILES=\ - lobj/lapi.o \ - lobj/lcode.o \ - lobj/lctype.o \ - lobj/ldebug.o \ - lobj/ldo.o \ - lobj/ldump.o \ - lobj/lfunc.o \ - lobj/lgc.o \ - lobj/llex.o \ - lobj/lmem.o \ - lobj/lobject.o \ - lobj/lopcodes.o \ - lobj/lparser.o \ - lobj/lstate.o \ - lobj/lstring.o \ - lobj/ltable.o \ - lobj/ltm.o \ - lobj/lundump.o \ - lobj/lvm.o \ - lobj/lzio.o \ - lobj/lauxlib.o \ - lobj/lbaselib.o \ - lobj/lbitlib.o \ - lobj/lcorolib.o \ - lobj/ldblib.o \ - lobj/liolib.o \ - lobj/lmathlib.o \ - lobj/loslib.o \ - lobj/lstrlib.o \ - lobj/ltablib.o \ - lobj/loadlib.o \ - lobj/linit.o \ - lobj/eris.o \ + obj/lua/lapi.o \ + obj/lua/lcode.o \ + obj/lua/lctype.o \ + obj/lua/ldebug.o \ + obj/lua/ldo.o \ + obj/lua/ldump.o \ + obj/lua/lfunc.o \ + obj/lua/lgc.o \ + obj/lua/llex.o \ + obj/lua/lmem.o \ + obj/lua/lobject.o \ + obj/lua/lopcodes.o \ + obj/lua/lparser.o \ + obj/lua/lstate.o \ + obj/lua/lstring.o \ + obj/lua/ltable.o \ + obj/lua/ltm.o \ + obj/lua/lundump.o \ + obj/lua/lvm.o \ + obj/lua/lzio.o \ + obj/lua/lauxlib.o \ + obj/lua/lbaselib.o \ + obj/lua/lbitlib.o \ + obj/lua/lcorolib.o \ + obj/lua/ldblib.o \ + obj/lua/liolib.o \ + obj/lua/lmathlib.o \ + obj/lua/loslib.o \ + obj/lua/lstrlib.o \ + obj/lua/ltablib.o \ + obj/lua/loadlib.o \ + obj/lua/linit.o \ + obj/lua/eris.o \ CORE_OBJ_FILES=\ - obj/invocation.o\ - obj/spookyv2.o\ - obj/eng-malloc.o\ - obj/debugcollector.o\ - obj/drivenengine.o\ - obj/util.o\ - obj/luastack.o\ - obj/traceback.o\ - obj/planemap.o\ - obj/pprint.o\ - obj/luaconsole.o\ - obj/idalloc.o\ - obj/globaldb.o\ - obj/sched.o\ - obj/http.o\ - obj/json.o\ - obj/table.o\ - obj/gui.o\ - obj/luasnap.o\ - obj/animqueue.o\ - obj/streambuffer.o\ - obj/source.o\ - obj/world-core.o\ - obj/world-accessor.o\ - obj/world-difftab.o\ - obj/world-diffxmit.o\ - obj/world-pairtab.o\ - obj/world-testing.o\ - obj/textgame.o\ - obj/lpxserver.o\ - obj/lpxclient.o\ - obj/eng-tests.o\ - obj/printbuffer.o\ - obj/driver-util.o\ - obj/driver-ssl.o\ - obj/$(DRIVER).o\ + obj/cpp/invocation.o\ + obj/cpp/spookyv2.o\ + obj/cpp/eng-malloc.o\ + obj/cpp/debugcollector.o\ + obj/cpp/drivenengine.o\ + obj/cpp/util.o\ + obj/cpp/luastack.o\ + obj/cpp/traceback.o\ + obj/cpp/planemap.o\ + obj/cpp/pprint.o\ + obj/cpp/luaconsole.o\ + obj/cpp/idalloc.o\ + obj/cpp/globaldb.o\ + obj/cpp/sched.o\ + obj/cpp/http.o\ + obj/cpp/json.o\ + obj/cpp/table.o\ + obj/cpp/gui.o\ + obj/cpp/luasnap.o\ + obj/cpp/animqueue.o\ + obj/cpp/streambuffer.o\ + obj/cpp/source.o\ + obj/cpp/world-core.o\ + obj/cpp/world-accessor.o\ + obj/cpp/world-difftab.o\ + obj/cpp/world-diffxmit.o\ + obj/cpp/world-pairtab.o\ + obj/cpp/world-testing.o\ + obj/cpp/textgame.o\ + obj/cpp/lpxserver.o\ + obj/cpp/lpxclient.o\ + obj/cpp/eng-tests.o\ + obj/cpp/printbuffer.o\ -lobj/%.o: ../eris-master/src/%.c - gcc -Wall $(OPT) -DLUA_USE_APICHECK $(LUAFLAGS) -c -MMD $< -o $@ +DRV_OBJ_FILES=\ + objdrv/drvutil.o\ + objdrv/sslutil.o\ -obj/%.o: cpp/%.cpp - g++ -std=c++17 -Wall $(OPT) -I../eris-master/src -Iwrap -Icpp $(INCS) -c -MMD $< -o $@ -$(EXE): $(CORE_OBJ_FILES) $(LUA_OBJ_FILES) - g++ -std=c++17 -Wall $(OPT) -o $@ $(CORE_OBJ_FILES) $(LUA_OBJ_FILES) $(LIBS) +-include $(LUA_OBJ_FILES:%.o=%.d) +-include $(CORE_OBJ_FILES:%.o=%.d) +-include $(DRV_OBJ_FILES:%.o=%.d) + + +ifeq ($(OS),linux) + +OPT=-g -O0 + +main: $(DRV_OBJ_FILES) $(CORE_OBJ_FILES) $(LUA_OBJ_FILES) objdrv/driver-linux.o + g++ -std=c++17 -export-dynamic -Wall $(OPT) -o $@ $(DRV_OBJ_FILES) $(CORE_OBJ_FILES) $(LUA_OBJ_FILES) objdrv/driver-linux.o -L../linuxlib -lssl -lcrypto -ldl + +obj/lua/%.o: ../eris-master/src/%.c + gcc -Wall -fvisibility=hidden $(OPT) -DLUA_USE_APICHECK -DLUA_USE_POSIX -c -MMD $< -o $@ + +obj/cpp/%.o: cpp/%.cpp + g++ -Wall -fvisibility=hidden $(OPT) -std=c++17 -I../linuxlib -I../eris-master/src -Iwrap -Icpp -c -MMD $< -o $@ + +objdrv/%.o: drv/%.cpp + g++ -Wall -fvisibility=hidden $(OPT) -std=c++17 -I../linuxlib -Idrv -c -MMD $< -o $@ + +endif clean: - rm -f main.exe main obj/* lobj/* + rm -f main.exe main obj/cpp/*.* objdrv/*.* obj/lua/*.* + + --include $(CORE_OBJ_FILES:%.o=%.d) --include $(LUA_OBJ_FILES:%.o=%.d) diff --git a/luprex/core/cpp/drivenengine.cpp b/luprex/core/cpp/drivenengine.cpp index b7ce9569..717d49a5 100644 --- a/luprex/core/cpp/drivenengine.cpp +++ b/luprex/core/cpp/drivenengine.cpp @@ -3,9 +3,13 @@ #include "util.hpp" #include "drivenengine.hpp" +#include #include #include #include +#include +#include +#include DrivenEngineReg *DrivenEngineReg::All; @@ -16,23 +20,57 @@ DrivenEngineReg::DrivenEngineReg(const char *n, DrivenEngineMaker fn) { All = this; } -void DrivenEngine::print_usage(std::ostream &strm, std::string_view progname) { - strm << "Usage: " << progname << " " << std::endl; - for (auto reg = DrivenEngineReg::All; reg != nullptr; reg=reg->next) { - strm << " Mode can be: " << reg->name << std::endl; - } +DrivenEngineInitializer DrivenEngineInitializerReg::func; + +DrivenEngineInitializerReg::DrivenEngineInitializerReg(DrivenEngineInitializer fn) { + assert(func == nullptr); + func = fn; } -UniqueDrivenEngine DrivenEngine::make(std::string_view kind) { +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// +// +// DrivenEngine private methods +// +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +int DrivenEngine::find_unused_chid() { + // Note: channel ID zero is special, it is never reused. + for (int i = 0; i < DRV_MAX_CHAN; i++) { + int id = next_unused_chid_++; + if (next_unused_chid_ == DRV_MAX_CHAN) next_unused_chid_ = 1; + if (channels_[id] == nullptr) return id; + } + assert(false); + return 0; +} + +Channel *DrivenEngine::get_chid(int chid) const { + assert(unsigned(chid) < DRV_MAX_CHAN); + assert(channels_[chid].get() != nullptr); + return channels_[chid].get(); +} + +static DrivenEngine *make_engine(std::string_view kind) { for (auto reg = DrivenEngineReg::All; reg != nullptr; reg=reg->next) { - if (kind == std::string_view(reg->name)) { + if (kind == reg->name) { UniqueDrivenEngine result = reg->maker(); - return result; + return result.release(); } } return nullptr; } +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// +// +// Class Channel +// +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + + Channel::Channel(DrivenEngine *de, int chid, int port, const eng::string &target, bool stop) { chid_ = chid; port_ = port; @@ -136,24 +174,16 @@ void Channel::sent_outgoing(int nbytes) { sb_drvout_->read_bytes(nbytes); } -int DrivenEngine::find_unused_chid() { - // Note: channel ID zero is special, it is never reused. - for (int i = 0; i < MAX_CHAN; i++) { - int id = next_unused_chid_++; - if (next_unused_chid_ == MAX_CHAN) next_unused_chid_ = 1; - if (channels_[id] == nullptr) return id; - } - assert(false); - return 0; -} - -Channel *DrivenEngine::get_chid(int chid) const { - assert(unsigned(chid) < MAX_CHAN); - assert(channels_[chid].get() != nullptr); - return channels_[chid].get(); -} +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// +// +// DrivenEngine Client-Side API +// +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// void DrivenEngine::listen_port(int port) { + assert(listen_ports_.size() < DRV_MAX_LISTEN_PORTS); listen_ports_.push_back(port); } @@ -193,92 +223,255 @@ void DrivenEngine::rescan_lua_source() { void DrivenEngine::stop_driver() { stop_driver_ = true; - for (int i = 0; i < MAX_CHAN; i++) { + for (int i = 0; i < DRV_MAX_CHAN; i++) { if (channels_[i] != nullptr) { channels_[i]->stop_driver_ = true; } } } -const eng::vector &DrivenEngine::drv_get_listen_ports() const { - return listen_ports_; +DrivenEngine::DrivenEngine() { + next_unused_chid_ = 1; + stdio_channel_ = eng::make_shared(this, 0, 0, "", false); + stdio_channel_->sb_drvout_ = eng::make_shared(); + channels_[0] = stdio_channel_; + rescan_lua_source_ = false; + clock_ = 0.0; + stop_driver_ = false; } -const eng::vector &DrivenEngine::drv_get_new_outgoing() const { - return new_outgoing_; +DrivenEngine::~DrivenEngine() {} + +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// +// +// LOGFILE EVENT IDS. +// +// There's one event ID for each mutator, plus one for 'release'. +// +// There are no event IDs for getters, these aren't considered loggable events. +// +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +enum DrvAction { + PLAY_INITIALIZE, + PLAY_CLEAR_NEW_OUTGOING, + PLAY_SENT_OUTGOING, + PLAY_RECV_INCOMING, + PLAY_NOTIFY_CLOSE, + PLAY_NOTIFY_ACCEPT, + PLAY_INVOKE_EVENT_UPDATE, + PLAY_SET_LUA_SOURCE, + PLAY_RELEASE, +}; + +inline static const char *action_string(DrvAction act) { + switch(act) { + case PLAY_INITIALIZE: return "PLAY_INITIALIZE"; + case PLAY_CLEAR_NEW_OUTGOING: return "PLAY_CLEAR_NEW_OUTGOING"; + case PLAY_SENT_OUTGOING: return "PLAY_SENT_OUTGOING"; + case PLAY_RECV_INCOMING: return "PLAY_RECV_INCOMING"; + case PLAY_NOTIFY_CLOSE: return "PLAY_NOTIFY_CLOSE"; + case PLAY_NOTIFY_ACCEPT: return "PLAY_NOTIFY_ACCEPT"; + case PLAY_SET_LUA_SOURCE: return "PLAY_SET_LUA_SOURCE"; + case PLAY_INVOKE_EVENT_UPDATE: return "PLAY_INVOKE_EVENT_UPDATE"; + case PLAY_RELEASE: return "PLAY_RELEASE"; + default: return "unknown"; + } +} +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// +// +// RLOG and WLOG, functions to read and write binary data to logfiles. +// +// After doing an rlog operation, you should check the stream +// for "good" to find out if there was any error. +// +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +class PlayLogfile : public std::ofstream { using std::ofstream::ofstream; }; +class ReplayLogfile : public std::ifstream { using std::ifstream::ifstream; }; + +static uint8_t rlog_uint8(EngineWrapper *w) { + uint8_t result; + w->rlog->read((char *)&result, 1); + if (!w->rlog->good()) return 0; + return result; } -void DrivenEngine::drv_clear_new_outgoing() { - new_outgoing_.clear(); +static uint32_t rlog_uint32(EngineWrapper *w) { + uint32_t result; + w->rlog->read((char *)&result, 4); + if (!w->rlog->good()) return 0; + return result; } -std::string_view DrivenEngine::drv_get_target(int chid) const { - return get_chid(chid)->target_; +static uint64_t rlog_uint64(EngineWrapper *w) { + uint64_t result; + w->rlog->read((char *)&result, 8); + if (!w->rlog->good()) return 0; + return result; } -bool DrivenEngine::drv_outgoing_empty(int chid) const { - std::string_view view = drv_peek_outgoing(chid); - return (view.size() == 0); +static double rlog_double(EngineWrapper *w) { + double result; + w->rlog->read((char *)&result, 8); + if (!w->rlog->good()) return 0.0; + return result; } -bool DrivenEngine::drv_get_channel_released(int chid) const { +std::string_view rlog_short_string(EngineWrapper *w) { + uint32_t len = rlog_uint8(w); + if (len == 255) { + len = rlog_uint32(w); + } + assert (len <= DRV_SHORTSTRING_SIZE); + if (len > 0) w->rlog->read(w->databuffer, len); + if (!w->rlog->good()) return std::string_view(); + return std::string_view(w->databuffer, len); +} + +std::string rlog_string(EngineWrapper *w) { + uint32_t len = rlog_uint8(w); + if (len == 255) { + len = rlog_uint32(w); + } + std::string result(len, ' '); + if (len > 0) w->rlog->read(&result[0], len); + if (!w->rlog->good()) return ""; + return result; +} + +static void wlog_uint8(EngineWrapper *w, uint8_t v) { + w->wlog->put((char)v); +} + +static void wlog_uint32(EngineWrapper *w, uint32_t v) { + w->wlog->write((const char *)&v, 4); +} + +static void wlog_uint64(EngineWrapper *w, uint64_t v) { + w->wlog->write((const char *)&v, 8); +} + +static void wlog_double(EngineWrapper *w, double v) { + w->wlog->write((const char *)&v, 8); +} + +static void wlog_short_string(EngineWrapper *w, std::string_view v) { + assert (v.size() <= DRV_SHORTSTRING_SIZE); + if (v.size() >= 255) { + wlog_uint8(w, 0xFF); + wlog_uint32(w, v.size()); + } else { + wlog_uint8(w, v.size()); + } + w->wlog->write(v.data(), v.size()); +} + +static void wlog_string(EngineWrapper *w, std::string_view v) { + if (v.size() >= 255) { + wlog_uint8(w, 0xFF); + wlog_uint32(w, v.size()); + } else { + wlog_uint8(w, v.size()); + } + w->wlog->write(v.data(), v.size()); +} + +static void wlog_cmd_hash(EngineWrapper *w, DrvAction act, uint32_t hash) { + wlog_uint8(w, act); + wlog_uint32(w, hash); +} + +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// +// +// reset_wrapper +// +// Shut down a EngineWrapper, store an optional error message. +// +// release +// +// Shut down an EngineWrapper cleanly, with no error message, and +// log the step if the logfile is open. +// +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +static void reset_wrapper(EngineWrapper *w, const char *format, ...) { + va_list argp; + va_start(argp, format); + memset(w->error, 0, DRV_ERRMSG_SIZE); + vsnprintf(w->error, DRV_ERRMSG_SIZE, format, argp); + w->error[DRV_ERRMSG_SIZE - 1] = 0; + + if (w->wlog != nullptr) { + w->wlog->close(); + delete w->wlog; + w->wlog = nullptr; + } + + if (w->rlog != nullptr) { + w->rlog->close(); + delete w->rlog; + w->rlog = nullptr; + } + + if (w->engine != nullptr) { + delete w->engine; + w->engine = nullptr; + } +} + +static void release(EngineWrapper *w) { + if (w->wlog != nullptr) { + wlog_cmd_hash(w, PLAY_RELEASE, eng::memhash()); + } + reset_wrapper(w, ""); +}; + +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// +// +// DRIVER Methods: Getters +// +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +void DrivenEngine::drv_get_listen_ports(uint32_t *nports, const uint32_t **ports) const { + *nports = listen_ports_.size(); + *ports = &listen_ports_[0]; +} + +void DrivenEngine::drv_get_new_outgoing(uint32_t *nchids, const uint32_t **chids) const { + *nchids = new_outgoing_.size(); + *chids = &new_outgoing_[0]; +} + +const char *DrivenEngine::drv_get_target(uint32_t chid) const { + return get_chid(chid)->target_.c_str(); +} + +bool DrivenEngine::drv_get_channel_released(uint32_t chid) const { return channels_[chid].use_count() == 1; } -std::string_view DrivenEngine::drv_peek_outgoing(int chid) const { - return get_chid(chid)->peek_outgoing(); +void DrivenEngine::drv_get_outgoing(uint32_t chid, uint32_t *len, const char **data) const { + std::string_view v = get_chid(chid)->peek_outgoing(); + *len = v.size(); + *data = v.data(); } -void DrivenEngine::drv_sent_outgoing(int chid, int nbytes) { - return get_chid(chid)->sent_outgoing(nbytes); +bool DrivenEngine::drv_get_outgoing_empty(uint32_t chid) const { + std::string_view v = get_chid(chid)->peek_outgoing(); + return (v.size() == 0); } -void DrivenEngine::drv_recv_incoming(int chid, std::string_view data) { - if (data.size() > 0) { - Channel *ch = get_chid(chid); - if (ch->sb_drvout_ != ch->sb_out_) { - ch->feed_readline(data); - } else { - ch->sb_in_->write_bytes(data); - } - } -} - -void DrivenEngine::drv_notify_close(int chid, std::string_view err) { - Channel *ch = get_chid(chid); - ch->closed_ = true; - ch->error_ = err; - channels_[chid].reset(); -} - -int DrivenEngine::drv_notify_accept(int port) { - int chid = find_unused_chid(); - channels_[chid] = eng::make_shared(this, chid, port, "", stop_driver_); - accepted_channels_.push_back(channels_[chid]); - return chid; -} - -void DrivenEngine::drv_clear_lua_source() { - lua_source_.reset(); - rescan_lua_source_ = false; -} - -void DrivenEngine::drv_add_lua_source(std::string_view fn, std::string_view data) { - if (lua_source_ == nullptr) { - lua_source_.reset(new util::LuaSourceVec); - } - lua_source_->emplace_back(eng::string(fn), eng::string(data)); -} - -void DrivenEngine::drv_invoke_event_init(int argc, char *argv[]) { - event_init(argc, argv); - stdio_channel_->pump_readline(); -} - -void DrivenEngine::drv_invoke_event_update(double clock) { - clock_ = clock; - event_update(); - stdio_channel_->pump_readline(); +double DrivenEngine::drv_get_clock() const { + return clock_; } bool DrivenEngine::drv_get_rescan_lua_source() const { @@ -289,25 +482,539 @@ bool DrivenEngine::drv_get_stop_driver() const { return stop_driver_; } -DrivenEngine::DrivenEngine() { - next_unused_chid_ = 1; - stdio_channel_ = eng::make_shared(this, 0, 0, "", false); - stdio_channel_->sb_drvout_ = eng::make_shared(); - channels_[0] = stdio_channel_; - rescan_lua_source_ = true; - clock_ = 0.0; - stop_driver_ = false; +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// +// +// DRIVER Methods: Mutators +// +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +void DrivenEngine::drv_initialize(uint32_t srcpklen, const char *srcpk, int argc, char **argv) { + drv_set_lua_source(srcpklen, srcpk); + event_init(argc, argv); + stdio_channel_->pump_readline(); } -DrivenEngine::~DrivenEngine() {} - -static DrivenEngine *engine_; - -void DrivenEngine::set(DrivenEngine *de) { - engine_ = de; +void DrivenEngine::drv_clear_new_outgoing() { + new_outgoing_.clear(); } -DrivenEngine *DrivenEngine::get() { - return engine_; +void DrivenEngine::drv_sent_outgoing(uint32_t chid, uint32_t nbytes) { + return get_chid(chid)->sent_outgoing(nbytes); } +void DrivenEngine::drv_recv_incoming(uint32_t chid, uint32_t nbytes, const char *bytes) { + if (nbytes > 0) { + Channel *ch = get_chid(chid); + if (ch->sb_drvout_ != ch->sb_out_) { + ch->feed_readline(bytes); + } else { + ch->sb_in_->write_bytes(bytes); + } + } +} + +void DrivenEngine::drv_notify_close(uint32_t chid, uint32_t len, const char *data) { + Channel *ch = get_chid(chid); + ch->closed_ = true; + ch->error_ = std::string(data, len); + channels_[chid].reset(); +} + +uint32_t DrivenEngine::drv_notify_accept(uint32_t port) { + int chid = find_unused_chid(); + channels_[chid] = eng::make_shared(this, chid, port, "", stop_driver_); + accepted_channels_.push_back(channels_[chid]); + return chid; +} + +void DrivenEngine::drv_invoke_event_update(double clock) { + clock_ = clock; + event_update(); + stdio_channel_->pump_readline(); +} + +void DrivenEngine::drv_set_lua_source(uint32_t srcpklen, const char *srcpk) { + StreamBuffer sb(srcpk, srcpklen); + uint32_t nfiles = sb.read_uint32(); + lua_source_.reset(new util::LuaSourceVec); + lua_source_->resize(nfiles); + for (uint32_t i = 0; i < nfiles; i++) { + (*lua_source_)[i].first = sb.read_string(); + } + for (uint32_t i = 0; i < nfiles; i++) { + (*lua_source_)[i].second = sb.read_string(); + } + rescan_lua_source_ = false; +} + +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// +// +// C Wrappers: Getters +// +// These wrappers make it possible to call the drv_get routines using C +// functions instead of methods. This is important if the engine is compiled +// with one C++ compiler, but the driver is compiled with a different C++ +// compiler. +// +// Some of these take parameter 'EngineWrapper', some take 'EngineWrapper', +// and some come in two versions. This all depends on whether they are used +// during play, during replay, or both. +// +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +static void drv_get_listen_ports(EngineWrapper *w, uint32_t *nports, const uint32_t **ports) { + return w->engine->drv_get_listen_ports(nports, ports); +} + +static void drv_get_new_outgoing(EngineWrapper *w, uint32_t *nchanids, const uint32_t **chanids) { + return w->engine->drv_get_new_outgoing(nchanids, chanids); +} + +static const char *drv_get_target(EngineWrapper *w, uint32_t chid) { + return w->engine->drv_get_target(chid); +} + +static bool drv_get_channel_released(EngineWrapper *w, uint32_t chid) { + return w->engine->drv_get_channel_released(chid); +} + +static void drv_get_outgoing(EngineWrapper *w, uint32_t chid, uint32_t *len, const char **data) { + return w->engine->drv_get_outgoing(chid, len, data); +} + +static bool drv_get_outgoing_empty(EngineWrapper *w, uint32_t chid) { + return w->engine->drv_get_outgoing_empty(chid); +} + +static double drv_get_clock(EngineWrapper *w) { + return w->engine->drv_get_clock(); +} + +static bool drv_get_rescan_lua_source(EngineWrapper *w) { + return w->engine->drv_get_rescan_lua_source(); +} + +static bool drv_get_stop_driver(EngineWrapper *w) { + return w->engine->drv_get_stop_driver(); +} + +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// +// +// C Wrappers: Mutators +// +// The wrapper for a mutator consists of two parts: the wrapper which is used at +// 'play' time, and the wrapper which is used at 'replay' time. +// +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + + +static void play_initialize(EngineWrapper *w, uint32_t argc, char **argv, uint32_t srcpklen, const char *srcpk, const char *logfn) { + if (w->engine != nullptr) { + return reset_wrapper(w, "Cannot initialize wrapper, it's already initialized."); + } + + // Clear the error message. + memset(w->error, 0, DRV_ERRMSG_SIZE); + + // Open the logfile, if any is specified. + if ((logfn != nullptr) && (logfn[0] != 0)) { + w->wlog = new PlayLogfile(logfn, std::ios_base::out | std::ios_base::binary | std::ios_base::trunc); + if (!w->wlog->good()) { + return reset_wrapper(w, "Could not open replay log for writing: %s", logfn); + } + } else { + w->wlog = nullptr; + } + + // If we have a logfile, then log this initialization. + if (w->wlog != nullptr) { + wlog_cmd_hash(w, PLAY_INITIALIZE, eng::memhash()); + wlog_uint32(w, argc); + for (uint32_t i = 0; i < argc; i++) { + wlog_string(w, argv[i]); + } + wlog_string(w, std::string_view(srcpk, srcpklen)); + w->wlog->flush(); + } + + // Create the engine of the appropriate type. + if (argc < 1) { + std::ostringstream oss; + oss << "Must pass an engine type on the command line. Known types:\n"; + for (auto reg = DrivenEngineReg::All; reg != nullptr; reg=reg->next) { + oss << " " << reg->name << std::endl; + } + std::string err = oss.str(); + return reset_wrapper(w, err.c_str()); + } + w->engine = make_engine(argv[0]); + if (w->engine == nullptr) { + return reset_wrapper(w, "No such driven engine type: %s", argv[0]); + } + + // Call the engine initialization sequence. + w->engine->drv_initialize(srcpklen, srcpk, argc - 1, argv + 1); +} + + +static void replay_initialize(EngineWrapper *w) { + assert(w->rlog != nullptr); + std::vector argvstr; + uint32_t argc = rlog_uint32(w); + for (uint32_t i = 0; i < argc; i++) { + argvstr.push_back(rlog_string(w)); + } + std::string srcpk = rlog_string(w); + + if (!w->rlog->good()) { + return reset_wrapper(w, "replay log corrupt in replay_initialize"); + } + + // We need to convert the argument vector from an array + // of C++ strings into the canonical argc, argv format. + std::vector argvec; + for (uint32_t i = 0; i < argc; i++) { + argvec.push_back(&argvstr[i][0]); + } + char **argv = &argvec[0]; + + // Create the engine. + w->engine = make_engine(argv[0]); + if (w->engine == nullptr) { + return reset_wrapper(w, "No such driven engine type: %s", argvstr[0]); + } + + + w->engine->drv_initialize(srcpk.size(), srcpk.c_str(), argc - 1, argv + 1); +} + + +//////////////////////// + + +static void play_clear_new_outgoing(EngineWrapper *w) { + assert(w->rlog == nullptr); + if (w->wlog != nullptr) { + wlog_cmd_hash(w, PLAY_CLEAR_NEW_OUTGOING, eng::memhash()); + w->wlog->flush(); + } + w->engine->drv_clear_new_outgoing(); +} + +static void replay_clear_new_outgoing(EngineWrapper *w) { + w->engine->drv_clear_new_outgoing(); +} + + +//////////////////////// + + +static void play_sent_outgoing(EngineWrapper *w, uint32_t chid, uint32_t nbytes) { + assert(w->rlog == nullptr); + if (w->wlog != nullptr) { + uint32_t ndata; const char *data; + w->engine->drv_get_outgoing(chid, &ndata, &data); + assert(nbytes <= ndata); + wlog_cmd_hash(w, PLAY_SENT_OUTGOING, eng::memhash()); + wlog_uint32(w, chid); + wlog_uint32(w, nbytes); + wlog_uint64(w, SpookyHash::QkHash64(data, nbytes)); + w->wlog->flush(); + } + w->engine->drv_sent_outgoing(chid, nbytes); +} + +static void replay_sent_outgoing(EngineWrapper *w) { + uint32_t chid = rlog_uint32(w); + uint32_t nbytes = rlog_uint32(w); + uint64_t hash = rlog_uint64(w); + + if (!w->rlog->good()) { + return reset_wrapper(w, "replay log corrupt in replay_sent_outgoing"); + } + + uint32_t ndata; const char *data; + w->engine->drv_get_outgoing(chid, &ndata, &data); + if ((nbytes > ndata) || (hash != SpookyHash::QkHash64(data, nbytes))) { + return reset_wrapper(w, "nondeterministic in replay_sent_outgoing"); + } + w->engine->drv_sent_outgoing(chid, nbytes); +} + + +//////////////////////// + + +static void play_recv_incoming(EngineWrapper *w, uint32_t chid, uint32_t len, const char *data) { + assert(w->rlog == nullptr); + if (w->wlog != nullptr) { + wlog_cmd_hash(w, PLAY_RECV_INCOMING, eng::memhash()); + wlog_uint32(w, chid); + wlog_short_string(w, std::string_view(data, len)); + w->wlog->flush(); + } + w->engine->drv_recv_incoming(chid, len, data); +} + +static void replay_recv_incoming(EngineWrapper *w) { + uint32_t chid = rlog_uint32(w); + std::string_view data = rlog_short_string(w); + + if (!w->rlog->good()) { + return reset_wrapper(w, "replay log corrupt in replay_recv_incoming"); + } + + w->engine->drv_recv_incoming(chid, data.size(), data.data()); +} + + +//////////////////////// + + +static void play_notify_close(EngineWrapper *w, uint32_t chid, uint32_t len, const char *data) { + assert(w->rlog == nullptr); + if (w->wlog != nullptr) { + wlog_cmd_hash(w, PLAY_NOTIFY_CLOSE, eng::memhash()); + wlog_uint32(w, chid); + wlog_string(w, std::string_view(data, len)); + w->wlog->flush(); + } + + w->engine->drv_notify_close(chid, len, data); +} + +static void replay_notify_close(EngineWrapper *w) { + uint32_t chid = rlog_uint32(w); + std::string message = rlog_string(w); + + if (!w->rlog->good()) { + return reset_wrapper(w, "replay log corrupt in replay_notify_close"); + } + + w->engine->drv_notify_close(chid, message.size(), message.c_str()); +} + + +//////////////////////// + + +static uint32_t play_notify_accept(EngineWrapper *w, uint32_t port) { + assert(w->rlog == nullptr); + if (w->wlog != nullptr) { + wlog_cmd_hash(w, PLAY_NOTIFY_ACCEPT, eng::memhash()); + wlog_uint32(w, port); + w->wlog->flush(); + } + + return w->engine->drv_notify_accept(port); +} + +static void replay_notify_accept(EngineWrapper *w) { + uint32_t port = rlog_uint32(w); + + if (!w->rlog->good()) { + return reset_wrapper(w, "replay log corrupt in replay_notify_accept"); + } + + w->engine->drv_notify_accept(port); +} + + +//////////////////////// + + +static void play_invoke_event_update(EngineWrapper *w, double clock) { + assert(w->rlog == nullptr); + if (w->wlog != nullptr) { + wlog_cmd_hash(w, PLAY_INVOKE_EVENT_UPDATE, eng::memhash()); + wlog_double(w, clock); + w->wlog->flush(); + } + + w->engine->drv_invoke_event_update(clock); +} + +static void replay_invoke_event_update(EngineWrapper *w) { + double clock = rlog_double(w); + + if (!w->rlog->good()) { + return reset_wrapper(w, "replay log corrupt in replay_event_update"); + } + + w->engine->drv_invoke_event_update(clock); +} + + +//////////////////////// + + +void play_set_lua_source(EngineWrapper *w, uint32_t srcpklen, const char *srcpk) { + assert(w->rlog == nullptr); + if (w->wlog != nullptr) { + wlog_cmd_hash(w, PLAY_SET_LUA_SOURCE, eng::memhash()); + wlog_string(w, std::string_view(srcpk, srcpklen)); + w->wlog->flush(); + } + + w->engine->drv_set_lua_source(srcpklen, srcpk); +} + +void replay_set_lua_source(EngineWrapper *w) { + std::string srcpack = rlog_string(w); + + if (!w->rlog->good()) { + return reset_wrapper(w, "replay log corrupt in replay_set_lua_source"); + } + + w->engine->drv_set_lua_source(srcpack.size(), srcpack.c_str()); +} + +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// +// +// Replay Core +// +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + + +static void replaycore_initialize(EngineWrapper *w, const char *logfn) { + std::cerr << "Memhash before replaycore_initialize: " << eng::memhash() << std::endl; + if (w->engine != nullptr) { + return reset_wrapper(w, "Cannot initialize wrapper, it's already initialized."); + return; + } + + // Clear the error message. + memset(w->error, 0, DRV_ERRMSG_SIZE); + + // Open the logfile. + w->rlog = new ReplayLogfile(logfn, std::ios_base::in | std::ios_base::binary); + if (!w->rlog->good()) { + return reset_wrapper(w, "Could not open replay log for reading: %s", logfn); + } + + // Read one step from the logfile, and make sure it's an initialize step. + uint8_t code = rlog_uint8(w); + int hash = rlog_uint32(w); + if (!w->rlog->good()) { + return reset_wrapper(w, "logfile corrupt"); + } + if (hash != eng::memhash()) { + return reset_wrapper(w, "nondeterminism detected in initial step"); + } + if (code != PLAY_INITIALIZE) { + return reset_wrapper(w, "replay log doesn't begin with initialize step"); + } + + // Replay the initialize step from the logfile. + // Doing this immediately, rather than waiting for the driver + // to call 'step', enforces the invariant that after calling + // initialize, there's an engine. + replay_initialize(w); +} + +static void replaycore_step(EngineWrapper *w) { + if (w->rlog == nullptr) { + return; + } + + uint8_t code = rlog_uint8(w); + int hash = rlog_uint32(w); + if (!w->rlog->good()) { + return reset_wrapper(w, "logfile corrupt"); + } + if (hash != eng::memhash()) { + return reset_wrapper(w, "nondeterminism detected"); + } + switch (code) { + case PLAY_CLEAR_NEW_OUTGOING: replay_clear_new_outgoing(w); return; + case PLAY_SENT_OUTGOING: replay_sent_outgoing(w); return; + case PLAY_RECV_INCOMING: replay_recv_incoming(w); return; + case PLAY_NOTIFY_CLOSE: replay_notify_close(w); return; + case PLAY_NOTIFY_ACCEPT: replay_notify_accept(w); return; + case PLAY_SET_LUA_SOURCE: replay_set_lua_source(w); return; + case PLAY_INVOKE_EVENT_UPDATE: replay_invoke_event_update(w); return; + case PLAY_RELEASE: release(w); return; + default: return reset_wrapper(w, "Replay log corrupt in command dispatcher"); + } +} + +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// +// +// General Mutators +// +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + + + +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// +// +// Wrapper Initialization +// +// To access the engine across a DLL boundary, you first use +// GetProcAddress or dlsym to fetch the addresses of 'init_play_engine' +// and 'init_replay_engine'. Then, you use those two functions to +// initialize a EngineWrapper or a EngineWrapper, which contain the addresses +// of all the other functions you need. These are the only two functions +// marked 'DLLEXPORT', all other functions are exported from the DLL +// indirectly. +// +////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +#if defined(__linux__) + #define DLLEXPORT __attribute__((visibility("default"))) +#elif defined(_WIN32) + #define DLLEXPORT __declspec(dllexport) +#endif + +static void init_engine_wrapper_helper(EngineWrapper *w) { + static bool called_initializer; + assert(DrivenEngineInitializerReg::func != nullptr); + if (!called_initializer) { + DrivenEngineInitializerReg::func(); + called_initializer = true; + } + + memset(w, 0, sizeof(EngineWrapper)); + + w->get_listen_ports = drv_get_listen_ports; + w->get_new_outgoing = drv_get_new_outgoing; + w->get_target = drv_get_target; + w->get_channel_released = drv_get_channel_released; + w->get_outgoing = drv_get_outgoing; + w->get_outgoing_empty = drv_get_outgoing_empty; + w->get_clock = drv_get_clock; + w->get_rescan_lua_source = drv_get_rescan_lua_source; + w->get_stop_driver = drv_get_stop_driver; + + w->play_initialize = play_initialize; + w->play_clear_new_outgoing = play_clear_new_outgoing; + w->play_sent_outgoing = play_sent_outgoing; + w->play_recv_incoming = play_recv_incoming; + w->play_notify_close = play_notify_close; + w->play_notify_accept = play_notify_accept; + w->play_invoke_event_update = play_invoke_event_update; + w->play_set_lua_source = play_set_lua_source; + + w->replay_initialize = replaycore_initialize; + w->replay_step = replaycore_step; + + w->release = release; +}; + +extern "C" { +DLLEXPORT void init_engine_wrapper(EngineWrapper *w) { + init_engine_wrapper_helper(w); +} +} diff --git a/luprex/core/cpp/drivenengine.hpp b/luprex/core/cpp/drivenengine.hpp index 9d7a58cc..a24ec9c6 100644 --- a/luprex/core/cpp/drivenengine.hpp +++ b/luprex/core/cpp/drivenengine.hpp @@ -2,12 +2,9 @@ // // DrivenEngine // -// This module embodies the idea of an "event-driven game engine." We want the -// engine to be event-driven because an event-driven engine is a deterministic -// state machine. That, in turn, makes it possible to do replay logging. -// -// The DrivenEngine module provides two APIs: the 'engine-side' API, and the -// 'driver-side' API. +// This module embodies the idea of an "event-driven game engine." The +// DrivenEngine module provides two APIs: the engine-side API, and the +// driver-side API. // // The engine-side API looks like a typical collection of I/O primitives. It // includes methods to open sockets, read and write sockets, read lua source, @@ -41,53 +38,6 @@ // machine, free of all OS-specific code. // ////////////////////////////////////////////////////////////// -// -// Here are the rules for what the driver must do: -// -// * Before doing anything else, the driver must select one of the three -// logmodes. -// -// * If 'logmode_replay' is selected, then the driver must proceed to invoke -// 'drv_step_logfile' over and over until it returns false. In replay mode, -// the driver should not do anything else. -// -// * If 'logmode_write' or 'logmode_none' is selected, the driver must proceed -// to drive the application. Follow the remainder of these steps. -// -// * Open a hardwired list of ports for listening. -// -// * Repeat the following steps over and over: -// -// - If the engine asked that the lua source be refreshed, read the source -// from disk and call 'drv_set_lua_source'. -// -// - Get a list of recently-closed channels using drv_get_closed_channels. -// Close any socket associated with these channels and free all resources. -// -// - Get a list of recently-opened channels using drv_get_opened_channels. -// Open new outgoing connections for these channels. -// -// - Do an OS 'poll'. The poll should include the sockets for all channels -// in the channel list, all listening ports, and stdio. -// -// - If the poll indicates that a listening port has acceptable -// connections, accept and call drv_notify_accept. Associate the -// accepted socket with the channel. -// -// - If the poll indicates that a connection can accept outgoing data, use -// drv_peek_outgoing to fetch some data to write, and write it. Use -// drv_sent_outgoing_bytes to indicate that the data was sent. -// -// - If the poll indicates that a connection has incoming data, read the -// data then push it into the channel using drv_recv_incoming. -// -// - If the poll indicates that STDIO can be read/written, use -// drv_peek_outgoing, drv_sent_outgoing, and drv_recv_incoming in the -// same manner as you would for a socket. -// -// - Use 'drv_invoke_event_update' to invoke the engine's update callback. -// -////////////////////////////////////////////////////////////// #ifndef DRIVENENGINE_HPP #define DRIVENENGINE_HPP @@ -101,10 +51,12 @@ #include "util.hpp" #include "streambuffer.hpp" +#include "enginewrapper.hpp" class DrivenEngine; using UniqueDrivenEngine = std::unique_ptr; using DrivenEngineMaker = UniqueDrivenEngine (*)(); +using DrivenEngineInitializer = void (*)(); class Channel : public eng::opnew { public: @@ -209,10 +161,14 @@ public: // ////////////////////////////////////////////////////////////// + // The init callback. You may override this in a subclass. + // This will be called once at program initialization. + // + virtual void event_init(int argc, char *argv[]) {} + // The update callback. You may override this in a subclass. // This will be called whenever anything changes. // - virtual void event_init(int argc, char *argv[]) {} virtual void event_update() {} // Specify the set of listening ports. @@ -288,111 +244,6 @@ public: // void stop_driver(); - ////////////////////////////////////////////////////////////// - // - // The following methods are the 'driver' side of the pipe. - // - ////////////////////////////////////////////////////////////// - - // The maximum channel ID plus one. - // - static const int MAX_CHAN = 256; - - // Get a list of all the listening ports. The driver is expected - // to fetch this set shortly after the event_init callback is invoked. - // - const eng::vector &drv_get_listen_ports() const; - - // Get a list of all recently-opened channels that were created using - // drv_new_outgoing_channel. The driver should initiate outgoing - // connections for these channels. - // - const eng::vector &drv_get_new_outgoing() const; - - // Clear the list of recently-opened channels that were created using - // drv_new_outgoing_channel. - // - void drv_clear_new_outgoing(); - - // Get the target of a channel. A target is a string like - // "cert:whatever.com:80" or "nocert:whatever.com:80". - // The first word indicate whether or not a valid SSL certificate - // is required. The second word is the hostname. The third word is - // the port number. - // - std::string_view drv_get_target(int chid) const; - - // Return true if the outgoing buffer is empty. - // - bool drv_outgoing_empty(int chid) const; - - // Return true if the user has released all references to this channel. - // In this case, the driver should initiate shutdown of the channel, - // and the driver should eventually call drv_notify_close. - // - bool drv_get_channel_released(int chid) const; - - // Get a pointer to the bytes in the outgoing buffer. The pointer returned - // here is naturally only valid until the buffer is changed. This function - // is used for all channels, including sockets and stdio. - // - std::string_view drv_peek_outgoing(int chid) const; - - // Notifies the channel that some bytes were transmitted. This causes those - // bytes to be removed from the outgoing buffer. This function is used for - // all channels, including sockets and stdio. - // - void drv_sent_outgoing(int chid, int nbytes); - - // Notifies the channel that some bytes were received. This causes those - // bytes to be appended to the incoming buffer. This function is used for - // all channels, including sockets and stdio. - // - void drv_recv_incoming(int chid, std::string_view data); - - // Notify the channel that the connection was closed. This includes all - // sorts of closes, including friendly termination, all the way to network - // failure. Closing the channel doesn't delete it. The engine is - // responsible for noticing that the channel closed and the engine must - // delete it. Closing a channel prevents it from showing up in - // 'drv_list_channels'. - // - void drv_notify_close(int chid, std::string_view err); - - // Notify the DrivenEngine that somebody connected to an incoming port. - // This will cause the DrivenEngine to allocate a new channel and put the - // new channel into the incoming channels queue. Returns the new channel - // ID. The new incoming channel appears in the 'drv_list_channels' list, - // even before the engine pops the channel from the incoming channels queue. - // - int drv_notify_accept(int port); - - - // Clear the lua source code. - // - void drv_clear_lua_source(); - - // Set the lua source code. The driver is expected to read the lua source - // code and store it (using this function) once before invoking - // - void drv_add_lua_source(std::string_view fn, std::string_view data); - - // Invoke the init or update event. - // - void drv_invoke_event_init(int argc, char *argv[]); - void drv_invoke_event_update(double clock); - - // Check the 'rescan_lua_source' flag. If this flag is set, it means - // that the engine wants the driver to rescan the lua source code. - // When the driver sees this flag, it should rescan the source and call - // drv_set_source. - // - bool drv_get_rescan_lua_source() const; - - // If true, the engine is done. Stop the driver. - // - bool drv_get_stop_driver() const; - ////////////////////////////////////////////////////////////// // // Creation and Destruction. @@ -413,15 +264,38 @@ public: // virtual ~DrivenEngine(); - // Set/Get Global Pointer. + ////////////////////////////////////////////////////////////// // - // Normally, there is a single global "DrivenEngine" instance. - // We provide a global pointer to store this instance. This is - // a raw pointer, you must manually delete the DrivenEngine. + // The following accessors are for use by PlayWrapper and ReplayWrapper. // - static void set(DrivenEngine *de); - static DrivenEngine *get(); + // The PlayWrapper and ReplayWrapper use C stubs to access + // the engine. The C stubs, in turn, call these C++ methods. + // + // The stubs for the getters are trivial, one-line stubs. + // + // The stubs for the mutators add logging. + // + ////////////////////////////////////////////////////////////// + void drv_get_listen_ports(uint32_t *nports, const uint32_t **ports) const; + void drv_get_new_outgoing(uint32_t *nchids, const uint32_t **chids) const; + const char *drv_get_target(uint32_t chid) const; + bool drv_get_channel_released(uint32_t chid) const; + void drv_get_outgoing(uint32_t chid, uint32_t *len, const char **data) const; + bool drv_get_outgoing_empty(uint32_t chid) const; + double drv_get_clock() const; + bool drv_get_rescan_lua_source() const; + bool drv_get_stop_driver() const; + + void drv_initialize(uint32_t srcpklen, const char *srcpk, int argc, char **argv); + void drv_clear_new_outgoing(); + void drv_sent_outgoing(uint32_t chid, uint32_t nbytes); + void drv_recv_incoming(uint32_t chid, uint32_t nbytes, const char *bytes); + void drv_notify_close(uint32_t chid, uint32_t len, const char *data); + uint32_t drv_notify_accept(uint32_t port); + void drv_invoke_event_update(double clock); + void drv_set_lua_source(uint32_t srcpklen, const char *srcpk); + private: // Find a currently-unused channel ID. Channel IDs // are small integers that are reused. @@ -431,19 +305,23 @@ private: Channel *get_chid(int chid) const; private: - SharedChannel channels_[MAX_CHAN]; + SharedChannel channels_[DRV_MAX_CHAN]; int next_unused_chid_; SharedChannel stdio_channel_; eng::vector accepted_channels_; - eng::vector new_outgoing_; + eng::vector new_outgoing_; util::LuaSourcePtr lua_source_; - eng::vector listen_ports_; + eng::vector listen_ports_; bool rescan_lua_source_; double clock_; bool stop_driver_; friend class Channel; }; + +////////////////////////////////////////////////////////////////////////////////// + + struct DrivenEngineReg { const char *name; DrivenEngineMaker maker; @@ -458,5 +336,10 @@ struct DrivenEngineReg { } \ DrivenEngineReg dengreg_##cname(name, dengmake_##cname); +struct DrivenEngineInitializerReg { + static DrivenEngineInitializer func; + DrivenEngineInitializerReg(DrivenEngineInitializer f); +}; + #endif // DRIVENENGINE_HPP diff --git a/luprex/core/cpp/enginewrapper.hpp b/luprex/core/cpp/enginewrapper.hpp new file mode 100644 index 00000000..e8323df8 --- /dev/null +++ b/luprex/core/cpp/enginewrapper.hpp @@ -0,0 +1,217 @@ +//////////////////////////////////////////////////////////////////////////////// +// +// enginewrapper.hpp +// +// This header file contains driver's interface to class DrivenEngine. +// This is meant to be used across a DLL boundary. Since the DLL may have +// been compiled by a different compiler than the driver, we use only simple +// POD types and we only use C calling conventions. +// +// When calling a wrapper function, you must always pass in the wrapper as +// the first parameter. +// +//////////////////////////////////////////////////////////////////////////////// + +#ifndef ENGINEWRAPPER_H +#define ENGINEWRAPPER_H + +#define DRV_MAX_CHAN 256 +#define DRV_MAX_LISTEN_PORTS 256 +#define DRV_ERRMSG_SIZE 8192 +#define DRV_SHORTSTRING_SIZE 65536 + +class DrivenEngine; +class PlayLogfile; +class ReplayLogfile; + +struct EngineWrapper { + char error[DRV_ERRMSG_SIZE]; + char databuffer[DRV_SHORTSTRING_SIZE]; + DrivenEngine *engine; + PlayLogfile *wlog; + ReplayLogfile *rlog; + + ////////////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////////////// + // + // CONSTRUCTION + // + ////////////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////////////// + + // Of course, there's no constructor, since this is a C struct. + // To initialize it, you use 'dlsym' or 'GetProcAddress' to get the + // address of the function 'init_engine_wrapper'. Then, you call + // the function init_engine_wrapper(&wrapper). + + ////////////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////////////// + // + // GETTERS + // + ////////////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////////////// + + // Get a list of all the listening ports. The driver is expected + // to fetch this set shortly after the event_init callback is invoked. + // + void (*get_listen_ports)(EngineWrapper *w, uint32_t *nports, const uint32_t **ports); + + // Get a list of all recently-opened channels that were created using + // new_outgoing_channel. The driver should initiate outgoing + // connections for these channels. + // + void (*get_new_outgoing)(EngineWrapper *w, uint32_t *nchanids, const uint32_t **chanids); + + // Get a string_view of the target of a channel. A target is a string like + // "cert:whatever.com:80" or "nocert:whatever.com:80". + // The first word indicate whether or not a valid SSL certificate + // is required. The second word is the hostname. The third word is + // the port number. The char string returned here is valid until + // the channel is closed. + // + const char *(*get_target)(EngineWrapper *w, uint32_t chid); + + // Return true if the user has released all references to this channel. + // In this case, the driver should initiate shutdown of the channel, + // and the driver should eventually call notify_close. + // + bool (*get_channel_released)(EngineWrapper *w, uint32_t chid); + + // Get a pointer to the bytes in the outgoing buffer. The char pointer + // returned here is naturally only valid until the buffer is changed. + // This function is used for all channels, including sockets and stdio. + // + void (*get_outgoing)(EngineWrapper *w, uint32_t chid, uint32_t *len, const char **data); + + // Return true if the outgoing buffer is empty. + // + bool (*get_outgoing_empty)(EngineWrapper *w, uint32_t chid); + + // Get the clock. + // + // Get the current time. This is equal to the last value passed + // in by invoke_event_update. + // + double (*get_clock)(EngineWrapper *w); + + // Check the 'rescan_lua_source' flag. If this flag is set, it means + // that the engine wants the driver to rescan the lua source code. + // When the driver sees this flag, it should rescan the source and call + // set_lua_source. + // + bool (*get_rescan_lua_source)(EngineWrapper *w); + + // If true, the engine is done. Stop the driver. + // + bool (*get_stop_driver)(EngineWrapper *w); + + ////////////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////////////// + // + // MUTATORS USED ONLY IN PLAY MODE + // + ////////////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////////////// + + // Create the driven engine. argc and argv allow you to specify what + // kind of engine you want. You must pass in the initial state of the lua + // source, if you have any. You may optionally also specify a replay log. + // If you don't want to create a replay log, pass a null pointer. + // + // Check to see if the error buffer contains a message after calling + // this function. + // + void (*play_initialize)(EngineWrapper *w, uint32_t argc, char **argv, uint32_t srcpklen, const char *srcpk, const char *logfn); + + // Clear the list of recently-opened channels. You are meant to fetch + // new outgoing channels using get_new_outgoing, then you call + // clear_new_outgoing after you've opened those channels. + // + void (*play_clear_new_outgoing)(EngineWrapper *w); + + // Notifies the channel that some bytes were transmitted. This causes those + // bytes to be removed from the outgoing buffer. This function is used for + // all channels, including sockets and stdio. + // + void (*play_sent_outgoing)(EngineWrapper *w, uint32_t chid, uint32_t nbytes); + + // Notifies the channel that some bytes were received. This causes those + // bytes to be appended to the incoming buffer. This function is used for + // all channels, including sockets and stdio. + // + void (*play_recv_incoming)(EngineWrapper *w, uint32_t chid, uint32_t len, const char *data); + + // Notify the channel that the connection was closed. This includes all + // sorts of closes, including friendly termination, all the way to network + // failure. Closing the channel doesn't delete it. The engine is + // responsible for noticing that the channel closed and the engine must + // delete it. Closing a channel prevents it from showing up in + // 'list_channels'. + // + void (*play_notify_close)(EngineWrapper *w, uint32_t chid, uint32_t len, const char *data); + + // Notify the DrivenEngine that somebody connected to an incoming port. + // This will cause the DrivenEngine to allocate a new channel and put the + // new channel into the incoming channels queue. Returns the new channel + // ID. The new incoming channel appears in the 'list_channels' list, + // even before the engine pops the channel from the incoming channels queue. + // + uint32_t (*play_notify_accept)(EngineWrapper *w, uint32_t port); + + // Invoke the update event. + // + // The clock value must absolutely be monotonically increasing, + // and it should roughly be equal to the number of seconds since + // the program started. + // + void (*play_invoke_event_update)(EngineWrapper *w, double clock); + + // Store the lua source code. + // + void (*play_set_lua_source)(EngineWrapper *w, uint32_t srcpklen, const char *srcpk); + + ////////////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////////////// + // + // MUTATORS USED ONLY IN REPLAY MODE + // + ////////////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////////////// + + // Begin a replay. + // + // Opens the logfile and prepares to replay the log. + // If an error occurs, the error buffer contains a message, + // and the done flag is set to true. + // + void (*replay_initialize)(EngineWrapper *w, const char *logfn); + + // Execute a single step from the replay log. + // + // Calling this when 'done' is true is a no-op. + // + void (*replay_step)(EngineWrapper *w); + + ////////////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////////////// + // + // FUNCTIONS THAT CAN BE USED AT ANY TIME + // + ////////////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////////////// + + // Restore the wrapper to its initial blank state. + // + // Note that the wrapper must have already been initialized using + // init_engine_wrapper. Otherwise, the 'release' function pointer would not + // be initialized. If writing a logfile, this stores a 'clean exit' marker + // in the logfile, indicating that the engine exited cleanly, as opposed to + // crashing. + // + // If the wrapper is already in its clear state, this is a no-op. + // + void (*release)(EngineWrapper *w); +}; + +#endif // ENGINEWRAPPER_HPP \ No newline at end of file diff --git a/luprex/core/cpp/lpxclient.cpp b/luprex/core/cpp/lpxclient.cpp index 8e113806..811d97eb 100644 --- a/luprex/core/cpp/lpxclient.cpp +++ b/luprex/core/cpp/lpxclient.cpp @@ -151,13 +151,10 @@ public: void do_work_command(const util::StringVec &words) { int reps = 10000; - int64_t t1 = util::profiling_clock(); for (int i = 0; i < reps; i++) { world_to_synchronous(); world_to_asynchronous(); } - int64_t t2 = util::profiling_clock(); - stdostream() << "Snapshot/rollback took " << ((t2-t1)/reps) << " nanosec." << std::endl; } void do_quit_command(const util::StringVec &words) { diff --git a/luprex/core/cpp/util.hpp b/luprex/core/cpp/util.hpp index 430a9401..ab0c6d74 100644 --- a/luprex/core/cpp/util.hpp +++ b/luprex/core/cpp/util.hpp @@ -224,9 +224,6 @@ using IdVector = eng::vector; eng::string ascii_tolower(std::string_view c); eng::string ascii_toupper(std::string_view c); -// Return seconds elapsed, for profiling purposes. -double profiling_clock(); - // Output a string to a stream using Lua string escaping and quoting. void quote_string(const eng::string &str, std::ostream *os); diff --git a/luprex/core/cpp/world-core.cpp b/luprex/core/cpp/world-core.cpp index 7b0c7f1b..64122f91 100644 --- a/luprex/core/cpp/world-core.cpp +++ b/luprex/core/cpp/world-core.cpp @@ -1033,3 +1033,12 @@ void World::rollback() { assert(snapshot_.empty()); } +// This is the main routine for the DLL. We have to use a registration device +// to register this main routine with DrivenEngine. DrivenEngine will then call +// it exactly once the first time that the driver initializes an EngineWrapper. +// +void engine_initialization() { + SourceDB::register_lua_builtins(); +} + +static DrivenEngineInitializerReg eireg(engine_initialization); diff --git a/luprex/core/drv/driver-common.cpp b/luprex/core/drv/driver-common.cpp new file mode 100644 index 00000000..8166fd5c --- /dev/null +++ b/luprex/core/drv/driver-common.cpp @@ -0,0 +1,488 @@ +#define POLLVEC_SIZE (DRV_MAX_CHAN + 1) + + +static void if_error_print_and_exit(const std::string_view str) { + if (!str.empty()) { + std::cerr << std::endl << "error: " << str << std::endl; + exit(1); + } +} + +class Driver { + public: + enum ChanState { + CHAN_INACTIVE, + CHAN_PLAINTEXT, + CHAN_SSL_CONNECTING, + CHAN_SSL_ACCEPTING, + CHAN_SSL_READWRITE, + }; + struct ChanInfo { + int chid; + SOCKET socket; + SSL *ssl; + + ChanState state; + uint32_t nbytes; + const char *bytes; + bool ready_now; + bool ready_on_pollin; + bool ready_on_pollout; + bool ready_on_outgoing; + uint32_t last_write_nbytes; + + bool marked_for_deletion() const { return state == CHAN_INACTIVE; } + }; + + EngineWrapper engw; + std::vector chans_; + std::map listen_sockets_; + bool read_console_recently_; + std::unique_ptr pollvec_; + std::unique_ptr chbuf_; + + sslutil::UniqueCTX ssl_server_ctx_; + sslutil::UniqueCTX ssl_client_secure_ctx_; + sslutil::UniqueCTX ssl_client_insecure_ctx_; + + void handle_listen_ports() { + uint32_t nports; const uint32_t *ports; + engw.get_listen_ports(&engw, &nports, &ports); + for (uint32_t i = 0; i < nports; i++) { + int port = ports[i]; + if (listen_sockets_.find(port) == listen_sockets_.end()) { + std::string err; + SOCKET sock = listen_on_port(port, err); + if_error_print_and_exit(err); + assert(sock != INVALID_SOCKET); + listen_sockets_[port] = sock; + } + } + } + + void handle_lua_source() { + if (engw.get_rescan_lua_source(&engw)) { + drvutil::ostringstream oss; + std::string err = drvutil::package_lua_source(".", &oss); + if_error_print_and_exit(err); + engw.play_set_lua_source(&engw, oss.size(), oss.c_str()); + } + } + + void close_channel(ChanInfo &chan, std::string_view err) { + // std::cerr << "Closing channel " << chan.chid << std::endl; + assert(chan.state != CHAN_INACTIVE); + // Close and release the SSL channel. + if (chan.ssl != nullptr) { + SSL_free(chan.ssl); + chan.ssl = nullptr; + } + // Close and release the socket. + assert(chan.socket != INVALID_SOCKET); + assert(socket_close(chan.socket) == 0); + chan.socket = INVALID_SOCKET; + // Close everything else. + engw.play_notify_close(&engw, chan.chid, err.size(), err.data()); + chan.state = CHAN_INACTIVE; + chan.chid = -1; + chan.nbytes = 0; + chan.bytes = 0; + chan.ready_now = false; + chan.ready_on_pollin = false; + chan.ready_on_pollout = false; + chan.ready_on_outgoing = false; + chan.last_write_nbytes = 0; + } + + void handle_console_output() { + while (true) { + uint32_t ndata; const char *data; + engw.get_outgoing(&engw, 0, &ndata, &data); + if (ndata == 0) break; + if (ndata > DRV_SHORTSTRING_SIZE) ndata = DRV_SHORTSTRING_SIZE; + int nwrote = console_write(data, ndata); + if (nwrote <= 0) break; + engw.play_sent_outgoing(&engw, 0, nwrote); + } + } + + void handle_console_input() { + char buffer[256]; + read_console_recently_ = false; + while (true) { + int nread = console_read(buffer, 256); + if (nread <= 0) break; + read_console_recently_ = true; + engw.play_recv_incoming(&engw, 0, nread, buffer); + } + } + + void make_channel(SOCKET sock, int chid, SSL_CTX *ctx, ChanState state) { + ChanInfo newchan; + newchan.chid = chid; + newchan.socket = sock; + newchan.ssl = SSL_new(ctx); + newchan.state = state; + newchan.nbytes = 0; + newchan.bytes = 0; + newchan.ready_now = false; + newchan.ready_on_pollin = false; + newchan.ready_on_pollout = true; + newchan.ready_on_outgoing = false; + newchan.last_write_nbytes = 0; + SSL_set_fd(newchan.ssl, newchan.socket); + // SSL_set_msg_callback(newchan.ssl, SSL_trace); + // SSL_set_msg_callback_arg(newchan.ssl, BIO_new_fp(stderr,0)); + chans_.push_back(newchan); + } + + void handle_new_outgoing_sockets() { + uint32_t nchids; const uint32_t *chids; + engw.get_new_outgoing(&engw, &nchids, &chids); + for (uint32_t i = 0; i < nchids; i++) { + uint32_t chid = chids[i]; + std::string err, cert, host, port; + const char *target = engw.get_target(&engw, chid); + drvutil::split_target(target, cert, host, port); + if (cert.empty() || host.empty() || port.empty()) { + std::string message = "invalid target: "; + message += target; + engw.play_notify_close(&engw, chid, message.size(), message.c_str()); + continue; + } + SSL_CTX *ctx = nullptr; + if (cert == "cert") { + ctx = ssl_client_secure_ctx_.get(); + } else if (cert == "nocert") { + ctx = ssl_client_insecure_ctx_.get(); + } else { + std::string message = "invalid cert rule: "; + message += target; + engw.play_notify_close(&engw, chid, message.size(), message.c_str()); + continue; + } + SOCKET sock = open_connection(host.c_str(), port.c_str(), err); + if (sock == INVALID_SOCKET) { + engw.play_notify_close(&engw, chid, err.size(), err.c_str()); + continue; + } + // std::cerr << "Opening channel " << chid << std::endl; + make_channel(sock, chid, ctx, CHAN_SSL_CONNECTING); + } + engw.play_clear_new_outgoing(&engw); + } + + void accept_connection(int port, SOCKET sock) { + std::string err; + SOCKET socket = accept_on_socket(sock, err); + if_error_print_and_exit(err); + if (socket != INVALID_SOCKET) { + uint32_t chid = engw.play_notify_accept(&engw, port); + // std::cerr << "Accepted channel " << chid << std::endl; + make_channel(socket, chid, ssl_server_ctx_.get(), CHAN_SSL_ACCEPTING); + } + } + + void advance_plaintext(ChanInfo &chan) { + std::string err; + + // Try to write plaintext to the channel. + uint32_t ndata; const char *data; + engw.get_outgoing(&engw, chan.chid, &ndata, &data); + if (ndata > 0) { + int sbytes = ndata; + if (sbytes > DRV_SHORTSTRING_SIZE) sbytes = DRV_SHORTSTRING_SIZE; + int wbytes = socket_send(chan.socket, data, sbytes, err); + if (wbytes < 0) { + close_channel(chan, err.c_str()); + } else { + engw.play_sent_outgoing(&engw, chan.chid, wbytes); + } + } + + // Try to read plaintext from the channel. + // Someday, find a way to avoid this copy. + int nrecv = socket_recv(chan.socket, chbuf_.get(), DRV_SHORTSTRING_SIZE, err); + if (nrecv < 0) { + close_channel(chan, err.c_str()); + } else { + engw.play_recv_incoming(&engw, chan.chid, nrecv, chbuf_.get()); + } + + // Update the ready-flags for next time. + chan.ready_on_outgoing = true; + chan.ready_on_pollin = true; + } + + void process_ssl_error(ChanInfo &chan, int retval) { + int error = SSL_get_error(chan.ssl, retval); + // std::cerr << "SSL error code = " << error << " "; + if (error == SSL_ERROR_WANT_READ) { + chan.ready_on_pollin = true; + } else if (error == SSL_ERROR_WANT_WRITE) { + chan.ready_on_pollout = true; + } else { + std::string error = sslutil::error_string(); + if (error == "") error = "unknown error"; + close_channel(chan, error); + } + } + + void advance_ssl_connecting(ChanInfo &chan) { + // std::cerr << "In advance_ssl_connecting" << std::endl; + int retval = SSL_connect(chan.ssl); + if (retval == 1) { + // Connection successful. + chan.state = CHAN_SSL_READWRITE; + chan.ready_now = true; + } else { + // std::cerr << "ssl_connect_error"; + process_ssl_error(chan, retval); + } + } + + void advance_ssl_accepting(ChanInfo &chan) { + // std::cerr << "In advance_ssl_accepting" << std::endl; + int retval = SSL_accept(chan.ssl); + if (retval == 1) { + // Connection successful. + chan.state = CHAN_SSL_READWRITE; + chan.ready_now = true; + } else { + process_ssl_error(chan, retval); + } + } + + void advance_ssl_readwrite(ChanInfo &chan) { + // std::cerr << "In advance_ssl_readwrite" << std::endl; + // Try to read data. + int read_result = SSL_read(chan.ssl, chbuf_.get(), DRV_SHORTSTRING_SIZE); + if (read_result > 0) { + engw.play_recv_incoming(&engw, chan.chid, read_result, chbuf_.get()); + chan.ready_now = true; + } else { + process_ssl_error(chan, read_result); + if (chan.state == CHAN_INACTIVE) return; + } + + // Try to write data. + uint32_t wbytes; + if (chan.last_write_nbytes > 0) { + wbytes = chan.last_write_nbytes; + assert(wbytes < chan.nbytes); + } else { + wbytes = chan.nbytes; + if (wbytes > 65536) wbytes = 65536; + } + if (wbytes > 0) { + int write_result = SSL_write(chan.ssl, chan.bytes, wbytes); + if (write_result > 0) { + engw.play_sent_outgoing(&engw, chan.chid, write_result); + chan.last_write_nbytes = 0; + chan.ready_on_outgoing = true; + } else { + chan.last_write_nbytes = wbytes; + process_ssl_error(chan, write_result); + if (chan.state == CHAN_INACTIVE) return; + } + } else { + chan.ready_on_outgoing = true; + } + // std::cerr << "rpi=" << chan.ready_on_pollin << ".rpo=" << + // chan.ready_on_pollout << ".rn=" << chan.ready_now << ".rog=" << + // chan.ready_on_outgoing << " "; + } + + void advance_channel(ChanInfo &chan) { + sslutil::clear_all_errors(); + switch (chan.state) { + case CHAN_PLAINTEXT: + advance_plaintext(chan); + break; + case CHAN_SSL_CONNECTING: + advance_ssl_connecting(chan); + break; + case CHAN_SSL_ACCEPTING: + advance_ssl_accepting(chan); + break; + case CHAN_SSL_READWRITE: + advance_ssl_readwrite(chan); + break; + default: + assert(false); + break; + } + } + + void handle_socket_input_output() { + std::string err; + int mstimeout = read_console_recently_ ? 100 : 1000; + + // Peek output buffers and determine channel release flags. + bool any_released = false; + for (ChanInfo &chan : chans_) { + engw.get_outgoing(&engw, chan.chid, &chan.nbytes, &chan.bytes); + if (chan.nbytes == 0) { + if (engw.get_channel_released(&engw, chan.chid)) { + close_channel(chan, ""); + any_released = true; + } + } + } + + // Delete any released channels + if (any_released) { + drvutil::remove_marked_items(chans_); + } + + // Construct the struct pollfd vector. + int pollsize = 0; + for (const auto &p : listen_sockets_) { + struct pollfd &pfd = pollvec_[pollsize++]; + pfd.fd = p.second; + pfd.events = POLLIN; + pfd.revents = 0; + } + for (const ChanInfo &chan : chans_) { + struct pollfd &pfd = pollvec_[pollsize++]; + assert(chan.socket != INVALID_SOCKET); + pfd.fd = chan.socket; + pfd.events = 0; + pfd.revents = 0; + if (chan.ready_now) mstimeout = 0; + if (chan.ready_on_pollin) pfd.events |= POLLIN; + if (chan.ready_on_pollout) pfd.events |= POLLOUT; + if (chan.ready_on_outgoing && (chan.nbytes > 0)) + pfd.events |= POLLOUT; + // std::cerr << "evt=" << pfd.events << ".nb=" << chan.nbytes << + // std::endl; + } + + // Do the poll. + socket_poll(pollvec_.get(), pollsize, mstimeout, err); + if_error_print_and_exit(err); + + // Check listening sockets. + int index = 0; + for (auto &p : listen_sockets_) { + struct pollfd &pfd = pollvec_[index++]; + if (pfd.revents & (POLLIN | POLLERR)) { + accept_connection(p.first, p.second); + } + } + + // Advance channels where possible. + for (ChanInfo &chan : chans_) { + struct pollfd &pfd = pollvec_[index++]; + bool pollin = ((pfd.revents & POLLIN) != 0); + bool pollout = ((pfd.revents & POLLOUT) != 0); + bool pollerr = ((pfd.revents & (POLLERR | POLLHUP)) != 0); + if (chan.ready_now || pollerr || + (chan.ready_on_pollin && pollin) || + (chan.ready_on_pollout && pollout) || + (chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) { + chan.ready_now = false; + chan.ready_on_pollin = false; + chan.ready_on_pollout = false; + chan.ready_on_outgoing = false; + advance_channel(chan); + } + chan.nbytes = 0; + chan.bytes = 0; + } + + // Delete any newly-inactive channels + drvutil::remove_marked_items(chans_); + } + + int replay_logfile(const char *fn, bool verbose) { + engw.replay_initialize(&engw, fn); + if_error_print_and_exit(engw.error); + while (engw.rlog) { + engw.replay_step(&engw); + } + if_error_print_and_exit(engw.error); + return 0; + } + + int drive(int argc, char *argv[]) { + // Remove the program name from argv. + std::string program = argv[0]; + argc -= 1; + argv += 1; + + // Load the DLL and gain access to its functions. + call_init_engine_wrapper(&engw); + + // If argv contains "replay ", do a replay, + // and then skip everything else. + if (argc >= 1) { + std::string cmd(argv[0]); + if ((cmd == "replay") || (cmd == "vreplay")) { + if (argc != 2) { + std::cerr << "usage: " << program << " replay " + << std::endl; + return 1; + } + return replay_logfile(argv[1], cmd == "vreplay"); + } + } + + // If argv contains "record ", start recording, + // and remove the "record " from argv. + std::string replaylogfn; + if (argc >= 1) { + std::string cmd = argv[0]; + if (cmd == "record") { + if (argc < 2) { + std::cerr << "The 'record' command must be followed by a filename" << std::endl; + return 1; + } + replaylogfn = argv[1]; + argc -= 2; + argv += 2; + } + } + + // Initialize state variables. + read_console_recently_ = false; + chbuf_.reset(new char[DRV_SHORTSTRING_SIZE]); + pollvec_.reset(new struct pollfd[POLLVEC_SIZE]); + + ssl_server_ctx_.reset(sslutil::new_context(SSL_VERIFY_NONE)); + ssl_client_secure_ctx_.reset(sslutil::new_context(SSL_VERIFY_PEER)); + ssl_client_insecure_ctx_.reset(sslutil::new_context(SSL_VERIFY_NONE)); + ssl_load_certificate_authorities(ssl_client_secure_ctx_.get()); + sslutil::ctx_load_dummy_cert(ssl_server_ctx_.get()); + + // Read the initial lua source code. + drvutil::ostringstream srcpak; + std::string srcpakerr = drvutil::package_lua_source(".", &srcpak); + if_error_print_and_exit(srcpakerr); + + // Initialize the engine. + engw.play_initialize(&engw, argc, argv, srcpak.size(), srcpak.c_str(), replaylogfn.c_str()); + if_error_print_and_exit(engw.error); + + // Set up listening ports. + handle_listen_ports(); + + // Main loop. + while (!engw.get_stop_driver(&engw)) { + handle_lua_source(); + handle_console_output(); + handle_new_outgoing_sockets(); + handle_socket_input_output(); + handle_console_input(); + handle_console_output(); + engw.play_invoke_event_update(&engw, drvutil::get_monotonic_clock()); + } + + // Cleanup + engw.release(&engw); + for (ChanInfo &chan : chans_) { + close_channel(chan, ""); + } + return 0; + } +}; diff --git a/luprex/core/drv/driver-linux.cpp b/luprex/core/drv/driver-linux.cpp new file mode 100644 index 00000000..2dc4e384 --- /dev/null +++ b/luprex/core/drv/driver-linux.cpp @@ -0,0 +1,259 @@ + + +#include "drvutil.hpp" +#include "sslutil.hpp" +#include "../cpp/enginewrapper.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using SOCKET=int; +const int INVALID_SOCKET = -1; + +struct termios orig_termios; + +void set_nonblocking(int fd) { + int flags = fcntl(fd, F_GETFL, 0); + assert(flags != -1); + int status = fcntl(fd, F_SETFL, flags | O_NONBLOCK); + assert(status != -1); +} + +static void disable_tty_raw() { + tcsetattr(0, TCSAFLUSH, &orig_termios); +} + +static void enable_tty_raw() { + int status = tcgetattr(0, &orig_termios); + assert(status >= 0); + atexit(disable_tty_raw); + struct termios raw = orig_termios; + raw.c_iflag &= ~(BRKINT | ICRNL | INPCK | ISTRIP | IXON); + raw.c_lflag &= ~(ECHO | ICANON); + raw.c_oflag |= OPOST; + raw.c_cc[VMIN] = 0; + raw.c_cc[VTIME] = 0; + status = tcsetattr(0, TCSAFLUSH, &raw); + assert(status >= 0); +} + +static SOCKET open_connection(const char *host, const char *port, std::string &err) { + struct addrinfo *addrs = nullptr; + struct addrinfo *goodaddr = nullptr; + struct addrinfo hints; + SOCKET sock = INVALID_SOCKET; + + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + hints.ai_flags = AI_NUMERICSERV; + + err.clear(); + int status = getaddrinfo(host, port, &hints, &addrs); + if (status != 0) { + err = gai_strerror(status); + goto error_general; + } + if (addrs == nullptr) { + err = "no such host found"; + goto error_general; + } + goodaddr = addrs; + assert(goodaddr->ai_family == AF_INET); + assert(goodaddr->ai_socktype == SOCK_STREAM); + assert(goodaddr->ai_protocol == IPPROTO_TCP); + sock = socket(goodaddr->ai_family, goodaddr->ai_socktype, goodaddr->ai_protocol); + if (sock <= 0) goto error_errno; + + set_nonblocking(sock); + + status = connect(sock, goodaddr->ai_addr, goodaddr->ai_addrlen); + if ((status != 0) && (errno != EINPROGRESS)) goto error_errno; + + freeaddrinfo(addrs); + return sock; + +error_errno: + err = drvutil::strerror_str(errno); +error_general: + if (sock != INVALID_SOCKET) close(sock); + if (addrs != nullptr) freeaddrinfo(addrs); + return INVALID_SOCKET; +} + +static SOCKET listen_on_port(int port, std::string &err) { + int status, enable; + err.clear(); + + SOCKET sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock <= 0) goto error_errno; + + enable = 1; + status = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)); + if (status != 0) goto error_errno; + + struct sockaddr_in server; + server.sin_family = AF_INET; + server.sin_addr.s_addr = INADDR_ANY; + server.sin_port = htons(port); + + status = bind(sock, (struct sockaddr *)&server, sizeof(server)); + if (status != 0) goto error_errno; + + status = listen(sock, 10); + if (status != 0) goto error_errno; + + set_nonblocking(sock); + return sock; + +error_errno: + err = drvutil::strerror_str(errno); + if (sock >= 0) close(sock); + return INVALID_SOCKET; +} + +static SOCKET accept_on_socket(SOCKET listen_socket, std::string &err) { + err.clear(); + SOCKET chsock = accept(listen_socket, nullptr, nullptr); + if (chsock >= 0) { + set_nonblocking(chsock); + return chsock; + } else { + if ((errno != EAGAIN) && (errno != EWOULDBLOCK) && (errno != ECONNABORTED)) { + err = drvutil::strerror_str(errno); + } + return INVALID_SOCKET; + } +} + +// the return values for socket_send and socket_recv are: +// +// positive: sent or received bytes successfully +// zero: would block +// negative: channel closed, possibly cleanly or possibly with error +// +static int socket_send(SOCKET socket, const char *bytes, int nbytes, std::string &err) { + err.clear(); + int wbytes = send(socket, bytes, nbytes, 0); + if (wbytes < 0) { + if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) { + return 0; + } else { + err = drvutil::strerror_str(errno); + return -1; + } + } else { + return wbytes; + } +} + +static int socket_recv(SOCKET socket, char *bytes, int nbytes, std::string &err) { + err.clear(); + int nrecv = recv(socket, bytes, nbytes, 0); + if (nrecv < 0) { + if ((errno == EWOULDBLOCK) || (errno == EAGAIN)) { + err = drvutil::strerror_str(errno); + return -1; + } else { + return 0; + } + } else if (nrecv == 0) { + return -1; + } else { + return nrecv; + } +} + +static int socket_close(SOCKET socket) { + return close(socket); +} + +static int socket_poll(struct pollfd *pollvec, int pollcount, int mstimeout, std::string &err) { + // socket_poll is implicitly expected to also poll stdin, + // if the OS allows that. Linux does, so we add stdin to the + // poll vector. The poll vector is required to have at + // least one free space in order to do this. + pollvec[pollcount].fd = 0; + pollvec[pollcount].events = POLLIN; + pollcount += 1; + + // Do the poll. + int status = poll(pollvec, pollcount, mstimeout); + if (status < 0) { + err = drvutil::strerror_str(errno); + return -1; + } + return 0; +} + +static int console_write(const char *bytes, int nbytes) { + return write(1, bytes, nbytes); +} + +static int console_read(char *bytes, int nbytes) { + return read(0, bytes, nbytes); +} + +// Load the DLL if it's not already loaded. Stores +// the handle in a global variable. +static void load_engine_dll() { + // Not actually implemented yet. Currently, the engine + // is linked right into the executable. +} + +static void call_init_engine_wrapper(EngineWrapper *w) { + load_engine_dll(); + using InitFn = void (*)(EngineWrapper *); + InitFn initfn = (InitFn)dlsym(RTLD_DEFAULT, "init_engine_wrapper"); + assert(initfn != nullptr); + initfn(w); +} + +static void ssl_load_certificate_authorities(SSL_CTX *ctx) { + assert(SSL_CTX_set_default_verify_paths(ctx) == 1); +} + +static void disable_randomization(int argc, char *argv[]) { + const int old_personality = personality(ADDR_NO_RANDOMIZE); + if (!(old_personality & ADDR_NO_RANDOMIZE)) { + const int new_personality = personality(ADDR_NO_RANDOMIZE); + if (new_personality & ADDR_NO_RANDOMIZE) { + execv(argv[0], argv); + } + } +} + +#include "driver-common.cpp" + + +int main(int argc, char **argv) +{ + disable_randomization(argc, argv); + enable_tty_raw(); + assert(OPENSSL_init_ssl(0, NULL) == 1); + sslutil::clear_all_errors(); + Driver driver; + return driver.drive(argc, argv); +} + diff --git a/luprex/core/drv/driver-mingw.cpp b/luprex/core/drv/driver-mingw.cpp new file mode 100644 index 00000000..0741e8fc --- /dev/null +++ b/luprex/core/drv/driver-mingw.cpp @@ -0,0 +1,264 @@ +#define WINVER 0x0600 +#define _WIN32_WINNT 0x0600 + +#include "drvutil.hpp" +#include "sslutil.hpp" +#include "../cpp/enginewrapper.hpp" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static void set_nonblocking(SOCKET sock) { + u_long mode = 1; // 1 to enable non-blocking socket + int status = ioctlsocket(sock, FIONBIO, &mode); + assert(status == 0); +} + +static PADDRINFOA find_good_addr(PADDRINFOA addrinfo) { + for (PADDRINFOA addr = addrinfo; addr != nullptr; addr = addr->ai_next) { + if (addr->ai_family == AF_INET) { + return addr; + } + } + return nullptr; +} + +static SOCKET open_connection(const char *host, const char *port, std::string &err) { + PADDRINFOA addrs = nullptr; + PADDRINFOA goodaddr = nullptr; + SOCKET sock = INVALID_SOCKET; + + err.clear(); + int status = getaddrinfo(host, port, nullptr, &addrs); + while (status == WSATRY_AGAIN) { + status = getaddrinfo(host, port, nullptr, &addrs); + } + if (status == WSAHOST_NOT_FOUND) { + err = "host not found"; + goto error; + } + if (status != 0) { + err = "DNS resolution malfunction"; + goto error; + } + goodaddr = find_good_addr(addrs); + if (goodaddr == nullptr) { + err = "host not an internet host"; + goto error; + } + sock = socket(goodaddr->ai_family, SOCK_STREAM, IPPROTO_TCP); + if (sock == INVALID_SOCKET) { + err = "could not create a socket"; + goto error; + } + set_nonblocking(sock); + status = connect(sock, goodaddr->ai_addr, goodaddr->ai_addrlen); + if (status != 0) { + int errcode = WSAGetLastError(); + if (errcode != WSAEWOULDBLOCK) { + err = "connect failure"; + goto error; + } + } + freeaddrinfo(addrs); + return sock; + +error: + if (sock != INVALID_SOCKET) closesocket(sock); + if (addrs != nullptr) freeaddrinfo(addrs); + return SOCKET_ERROR; +} + +SOCKET listen_on_port(int port, std::string &err) { + int status; + err.clear(); + SOCKET sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock == INVALID_SOCKET) { + err = "could not create a socket"; + goto error; + } + + struct sockaddr_in server; + server.sin_family = AF_INET; + server.sin_addr.s_addr = INADDR_ANY; + server.sin_port = htons(port); + + status = bind(sock, (struct sockaddr *)&server, sizeof(server)); + if (status < 0) { + err = "could not bind port"; + goto error; + } + status = listen(sock, 10); + if (status < 0) { + err = "could not listen on socket"; + goto error; + } + set_nonblocking(sock); + std::cerr << "listening socket is " << sock << std::endl; + return sock; + +error: + if (sock != INVALID_SOCKET) closesocket(sock); + return SOCKET_ERROR; +} + +static SOCKET accept_on_socket(SOCKET listen_socket, std::string &err) { + SOCKET chsock = accept(listen_socket, nullptr, nullptr); + if (chsock != INVALID_SOCKET) { + set_nonblocking(chsock); + return chsock; + } else { + int errcode = WSAGetLastError(); + if ((errcode == WSAEWOULDBLOCK) || (errcode == WSAECONNRESET)) { + return INVALID_SOCKET; + } else { + err = "accept failed"; + return INVALID_SOCKET; + } + } +} + +static int socket_send(SOCKET socket, const char *bytes, int nbytes, std::string &err) { + err.clear(); + int wbytes = send(socket, bytes, nbytes, 0); + if (wbytes == SOCKET_ERROR) { + int errcode = WSAGetLastError(); + if (errcode == WSAEWOULDBLOCK) { + return 0; + } else { + err = "send failure"; + return -1; + } + } else { + assert(wbytes > 0); + return wbytes; + } +} + +static int socket_recv(SOCKET socket, char *bytes, int nbytes, std::string &err) { + err.clear(); + int nrecv = recv(socket, bytes, nbytes, 0); + if (nrecv < 0) { + int errcode = WSAGetLastError(); + if (errcode == WSAEWOULDBLOCK) { + return 0; + } else { + err = "recv failure"; + return -1; + } + } else if (nrecv == 0) { + return -1; + } else { + return nrecv; + } +} + +static int socket_close(SOCKET socket) { + return closesocket(socket); +} + +static int socket_poll(struct pollfd *pollvec, int pollcount, int mstimeout, std::string &err) { + if (pollcount == 0) { + if (mstimeout > 0) Sleep(mstimeout); + return 0; + } + int status = WSAPoll(pollvec, pollcount, mstimeout); + if (status < 0) { + err = strerror_str(WSAGetLastError()); + return -1; + } + return status; +} + +static void init_winsock() { + WSADATA data; + int errcode = WSAStartup(2, &data); + if (errcode != 0) { + fprintf(stderr, "Winsock didn't initalize, error %d", errcode); + exit(1); + } +} + +static int console_write(const char *bytes, int nbytes) { + if (nbytes == 0) return 0; + HANDLE hstdout = GetStdHandle(STD_OUTPUT_HANDLE); + assert(hstdout != INVALID_HANDLE_VALUE); + DWORD nwrote; + if (nbytes > 10000) nbytes = 10000; + assert(WriteConsoleA(hstdout, bytes, nbytes, &nwrote, nullptr)); + assert(nwrote > 0); + return nwrote; +} + +static int console_read(char *bytes, int nbytes) { + HANDLE hstdin = GetStdHandle(STD_INPUT_HANDLE); + assert(hstdin != INVALID_HANDLE_VALUE); + INPUT_RECORD inrecords[512]; + DWORD nread, nevents; + int nascii = 0; + if (GetNumberOfConsoleInputEvents(hstdin, &nevents)) { + if (int(nevents) > nbytes) nevents = nbytes; + ReadConsoleInputA(hstdin, inrecords, nevents, &nread); + for (int i = 0; i < int(nread); i++) { + const INPUT_RECORD &inr = inrecords[i]; + if (inr.EventType != KEY_EVENT) continue; + const KEY_EVENT_RECORD &key = inr.Event.KeyEvent; + if (!key.bKeyDown) continue; + char c = key.uChar.AsciiChar; + bytes[nascii++] = c; + } + return nascii; + } else { + return 0; + } +} + +static void ssl_load_certificate_authorities(SSL_CTX *ctx) { + HCERTSTORE hStore = CertOpenSystemStoreW(0, L"ROOT"); + PCCERT_CONTEXT pContext = NULL; + X509 *x509; + X509_STORE *store = SSL_CTX_get_cert_store(ctx); + + if (!hStore) { + fprintf(stderr, "Cannot open system certificate store.\n"); + exit(1); + } + + while ((pContext = CertEnumCertificatesInStore(hStore, pContext))) { + const unsigned char *encoded_cert = pContext->pbCertEncoded; + x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); + if (x509) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + } + } + + CertCloseStore(hStore, 0); +} + +#include "driver-common.cpp" + +int main(int argc, char **argv) +{ + init_winsock(); + OPENSSL_init_ssl(0, NULL); + SourceDB::register_lua_builtins(); + Driver driver; + return driver.drive(argc, argv); +} + diff --git a/luprex/core/drv/drvutil.cpp b/luprex/core/drv/drvutil.cpp new file mode 100644 index 00000000..3b56abd7 --- /dev/null +++ b/luprex/core/drv/drvutil.cpp @@ -0,0 +1,269 @@ + +#include "drvutil.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace drvutil { + + +inline static bool ascii_isspace(char c) { + return (c==' ')||(c=='\t')||(c=='\r')||(c=='\n')||(c=='\f')||(c=='\v'); +} + +std::string_view trim(std::string_view v) { + while ((!v.empty()) && (ascii_isspace(v.front()))) { + v.remove_prefix(1); + } + while ((!v.empty()) && (ascii_isspace(v.back()))) { + v.remove_suffix(1); + } + return v; +} + +static std::string_view read_to_line(std::string_view &source) { + size_t pos = source.find('\n'); + std::string_view result; + if (pos == std::string_view::npos) { + result = source; + source = std::string_view(); + } else { + result = source.substr(0, pos); + source = source.substr(pos + 1); + } + if ((!result.empty()) && (result.back() == '\r')) { + result.remove_suffix(1); + } + return result; +} + +std::vector split_view(std::string_view v, char sep) { + std::vector result; + while (true) { + size_t pos = v.find(sep); + if (pos == std::string_view::npos) break; + result.push_back(v.substr(0, pos)); + v = v.substr(pos + 1); + } + result.push_back(v); + return result; +} + +void split_target(std::string_view target, std::string &cert, std::string &host, std::string &port) { + std::vector split = split_view(target, ':'); + if (split.size() != 3) { + cert.clear(); host.clear(); port.clear(); + return; + } + if (split[0].empty() || split[1].empty() || split[2].empty()) { + cert.clear(); host.clear(); port.clear(); + return; + } + cert = std::string(split[0]); + host = std::string(split[1]); + port = std::string(split[2]); +} + + +static std::vector parse_control_lst(std::string_view ctrl) { + std::vector result; + while (!ctrl.empty()) { + std::string_view line = read_to_line(ctrl); + std::string_view trimmed = trim(line); + if ((trimmed.size() > 0) && (trimmed[0] != '#')) { + result.emplace_back(trimmed); + } + } + return result; +} + +// Read a source file into a string. +// +static std::string read_file(const char *fn, std::string &err) { + std::ifstream t(fn); + if (t.fail()) { + err = std::string("Could not open ") + fn; + return ""; + } + t.seekg(0, std::ios::end); + size_t size = t.tellg(); + std::string result(size, ' '); + t.seekg(0); + t.read(&result[0], size); + if ((t.fail()) || (size_t(t.tellg()) != size)) { + err = std::string("Could not read ") + fn; + return ""; + } + err = ""; + return result; +} + +// This encoding can be read by StreamBuffer::read_uint32. +// +static void sbwrite_uint32(std::ostream *s, uint32_t v) { + s->write((const char *)&v, 4); +} + +// This encoding can be read by StreamBuffer::read_uint64. +// +static void sbwrite_uint64(std::ostream *s, uint64_t v) { + s->write((const char *)&v, 8); +} + +// This encoding can be read by StreamBuffer::read_string. +// +static void sbwrite_string(std::ostream *s, std::string_view sv) { + s->put(0xFF); + sbwrite_uint64(s, sv.size()); + s->write(sv.data(), sv.size()); +} + +// This encoding can be read by StreamBuffer::read_string. +// +static bool sbwrite_file(std::ostream *s, const char *fn) { + s->put(0xFF); + uint64_t pos1 = s->tellp(); + sbwrite_uint64(s, 0); + uint64_t pos2 = s->tellp(); + std::ifstream t(fn); + if (t.fail()) { + return false; + } + *s << t.rdbuf(); + if (t.fail()) { + return false; + } + uint64_t pos3 = s->tellp(); + s->seekp(pos1); + sbwrite_uint64(s, pos3 - pos2); + s->seekp(pos3); + return true; +} + +std::string package_lua_source(const std::string &base, std::ostream *s) { + std::string err; + std::string cfn = base + "/lua/control.lst"; + std::string ctrl = read_file(cfn.c_str(), err); + if (!err.empty()) { + return err; + } + + std::vector names = parse_control_lst(ctrl); + sbwrite_uint32(s, names.size()); + for (int i = 0; i < int(names.size()); i++) { + sbwrite_string(s, names[i]); + } + for (int i = 0; i < int(names.size()); i++) { + std::string lfn = base + "/lua/" + names[i]; + if (!sbwrite_file(s, lfn.c_str())) { + return std::string("Cannot read source file: ") + lfn; + } + } + return ""; +} + +// strerror has to be the most overcomplicated function imaginable. The simple +// version, 'strerror', is not thread-safe, and the improved versions are all +// incompatible from OS to OS. Even different versions of linux aren't +// compatible. A lot of conditional compilation is needed. + +#if defined(__linux__) + +inline static void strerror_helper(int status, int errnum, char errbuf[256]) { + if (status != 0) { + snprintf(errbuf, 256, "unknown errno %d", errnum); + } +} + +inline static void strerror_helper(const char *result, int errnum, char errbuf[256]) { + if (result != errbuf) { + snprintf(errbuf, 256, "%s", result); + } +} + +void strerror_safe(int errnum, char errbuf[256]) { + auto rval = strerror_r(errnum, errbuf, 256); + strerror_helper(rval, errnum, errbuf); +} + +#elif defined(_WIN32) + +void strerror_safe(int errnum, char errbuf[256]) { + int status = strerror_s(errbuf, 256, errnum); + if (status != 0) { + snprintf(errbuf, 256, "unknown errno %d", errnum); + } +); + +#endif + +std::string strerror_str(int errnum) { + char buf[256]; + strerror_safe(errnum, buf); + return buf; +} + +// The monotonic clock is required to start at zero at initialization time, +// advance steadily, and never go backwards. It is okay, however, if it is a +// little inaccurate, or if it drifts a little over time. + +#if defined(__linux__) + + class MonoClock { + private: + struct timespec base_; + public: + MonoClock() { + int status = clock_gettime(CLOCK_MONOTONIC, &base_); + assert(status == 0); + } + double get() { + struct timespec t; + int status = clock_gettime(CLOCK_MONOTONIC, &t); + assert(status == 0); + double tv_sec = t.tv_sec - base_.tv_sec; + double tv_nsec = t.tv_nsec - base_.tv_nsec; + return tv_sec + (tv_nsec * 1.0E-9); + } + }; + +#elif defined(_WIN32) + + class MonoClock { + public: + double freq_; + LONGLONG base_; + inline LONGLONG qpc() { + LARGE_INTEGER x; + BOOL status = QueryPerformanceCounter(&x); + assert(status != 0); + return x.QuadPart; + } + MonoClock() { + LARGE_INTEGER x; + BOOL status = QueryPerformanceFrequency(&x); + assert(status != 0); + freq_ = 1.0 / double(x.QuadPart); + base_ = qpc(); + } + double get() { + return (qpc() - base) * freq_; + } + }; + +#else + #error "Only support __linux__ or _WIN32" +#endif + + +static MonoClock monoclock; +double get_monotonic_clock() { + return monoclock.get(); +} + +} // namespace drv \ No newline at end of file diff --git a/luprex/core/drv/drvutil.hpp b/luprex/core/drv/drvutil.hpp new file mode 100644 index 00000000..852d1216 --- /dev/null +++ b/luprex/core/drv/drvutil.hpp @@ -0,0 +1,99 @@ +//////////////////////////////////////////////////////////////////////////////// +// +// DRIVER_UTIL +// +//////////////////////////////////////////////////////////////////////////////// + + +#ifndef DRVUTIL_HPP +#define DRVUTIL_HPP + +#include +#include +#include +#include +#include +#include +#include + +namespace drvutil { + +// Read the lua source from disk into an ostringstream. +// +// To pass the lua source into the DLL, here is what you do: Construct an +// ostringstream. Use package_lua_source to package all the lua source into +// the ostringstream. Fetch the packaged source code using ostringstream::str. +// Pass the packaged source code into drv_set_lua_source. +// +// The DLL must then decode the source package. Here is how it does that: +// It creates a StreamBuffer from the packaged up source. Then it must +// call these StreamBuffer methods: +// +// - read the number of source files using read_uint32. +// - for each file, read the filename using read_string. +// - for each file, read the contents using read_string. +// +// If package_lua_source encounters an error reading the source code, then it +// returns an error message. In this case, the ostream contains garbage. If +// there is no error, returns the empty string. +// +std::string package_lua_source(const std::string &base, std::ostream *oss); + +// Parse a target designation. +// +// A target consists of 'cert::host::port'. +// +void split_target(std::string_view target, std::string &cert, std::string &host, std::string &port); + +// Get a system error message, in an OS-independent manner. +// +// These versions of strerror is thread-safe, and it never fails +// to put a message into the buffer. +// +void strerror_safe(int errnum, char result[256]); +std::string strerror_str(int errnum); + +// Get the amount of time elapsed since program start. +// +// This is guaranteed to be monotonically increasing. It is not +// guaranteed to be accurate. Error could gradually accumulate over +// time. +// +double get_monotonic_clock(); + +// drvutil::ostringstream +// +// This is a variant of ostringstream in which it is possible +// to get the contents without copying. To get the contents +// without copying, use oss.size() and oss.c_str() +// +class ostringstream : public std::ostringstream { + class rstringbuf : public std::stringbuf { + public: + char *eback() { return std::streambuf::eback(); } + }; + rstringbuf rsbuf_; +public: + ostringstream() { + std::basic_ostream::rdbuf(&rsbuf_); + } + size_t size() { + return tellp(); + } + const char *c_str() { + return rsbuf_.eback(); + } +}; + +// Remove items from a vector that are marked for deletion. +// +template +void remove_marked_items(T &vec) { + auto iter = std::partition(vec.begin(), vec.end(), [] (const auto &x) { return !x.marked_for_deletion(); }); + vec.erase(iter, vec.end()); +} + + +} // namespace drvutil + +#endif // DRVUTIL_HPP diff --git a/luprex/core/drv/sslutil.cpp b/luprex/core/drv/sslutil.cpp new file mode 100644 index 00000000..33878c86 --- /dev/null +++ b/luprex/core/drv/sslutil.cpp @@ -0,0 +1,199 @@ +#include "drvutil.hpp" +#include "sslutil.hpp" +#include +#include +#include +#include + +namespace sslutil { + +const char *dummy_cert = + "-----BEGIN CERTIFICATE-----\n" + "MIIDezCCAmOgAwIBAgIUajKmxrLMr9zBMlphrTJU5qKG8FgwDQYJKoZIhvcNAQEL\n" + "BQAwTDELMAkGA1UEBhMCVVMxFTATBgNVBAgMDFBlbm5zeWx2YW5pYTESMBAGA1UE\n" + "CgwJbG9jYWxob3N0MRIwEAYDVQQDDAlsb2NhbGhvc3QwIBcNMjIwMzIyMTczMzA4\n" + "WhgPMjEyMjAyMjYxNzMzMDhaMEwxCzAJBgNVBAYTAlVTMRUwEwYDVQQIDAxQZW5u\n" + "c3lsdmFuaWExEjAQBgNVBAoMCWxvY2FsaG9zdDESMBAGA1UEAwwJbG9jYWxob3N0\n" + "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA5OWIaKqYae4nPxvu5EP3\n" + "VilcjApYcMT4+2ypfQoB6PEep5lwguA929rNsTKnhGsEiQAZ0eZPEZN7VhUwf/hz\n" + "26jIyTT43ELkt6k97wwSZSXuT65RpSiemwEs6g2mMwzpgP6nv+yam4HjE9AKiHGN\n" + "YeTV72Nw1EN70t6IjIf4jsJRXqDJkUx5sSSD6j0WBTOhzozIDgZHTDwiLhatE66m\n" + "SNoD8oWC0PscbUgOJkFpbaCAS8RJmpsdgkTFae2rzL9cOFLGw6OgV/BV1J1s0ks8\n" + "+veoMMtIO6fese+OZ+DyQbuGaoaltZUXzY6QjD5l34m2mGplelT7BrpcqJTBHwmh\n" + "CwIDAQABo1MwUTAdBgNVHQ4EFgQUXQM5TVfJ9gpUXg8fZ8yfuUVcBP8wHwYDVR0j\n" + "BBgwFoAUXQM5TVfJ9gpUXg8fZ8yfuUVcBP8wDwYDVR0TAQH/BAUwAwEB/zANBgkq\n" + "hkiG9w0BAQsFAAOCAQEAqYX/ZGv0Qh/xdXppjnqojm8mH0giDW4tvwMqHcW3YRa3\n" + "9J2yYot+rHjU5g4n6HEmWDBE0eqLz9n3Y3fkFzT8RWZwBaST965CgsfGofyuA2hC\n" + "Ddn4Am3B5tTPmi8WWRZg8amhpGVD/mwkoVFIK0M337b1aZUJYPE+Kc9WetSL2KqB\n" + "EhqSQpkAWhVadzP85dq2T9EDjAvhlFTFlDEBx1GDUcc8M0KQ9NEvLT7LgoUcbMiT\n" + "PerlSZQTB0crchXTRSERgiwu80r7D6STn/RcPL9Fg5PkA94/d87jGbmV4sxSRsvM\n" + "z+DnJGjHrV1J/jHPrnVvVLpigBlGno3C5O/sRw3gcQ==\n" + "-----END CERTIFICATE-----\n"; + +const char *dummy_key = + "-----BEGIN PRIVATE KEY-----\n" + "MIIEwAIBADANBgkqhkiG9w0BAQEFAASCBKowggSmAgEAAoIBAQDk5Yhoqphp7ic/\n" + "G+7kQ/dWKVyMClhwxPj7bKl9CgHo8R6nmXCC4D3b2s2xMqeEawSJABnR5k8Rk3tW\n" + "FTB/+HPbqMjJNPjcQuS3qT3vDBJlJe5PrlGlKJ6bASzqDaYzDOmA/qe/7JqbgeMT\n" + "0AqIcY1h5NXvY3DUQ3vS3oiMh/iOwlFeoMmRTHmxJIPqPRYFM6HOjMgOBkdMPCIu\n" + "Fq0TrqZI2gPyhYLQ+xxtSA4mQWltoIBLxEmamx2CRMVp7avMv1w4UsbDo6BX8FXU\n" + "nWzSSzz696gwy0g7p96x745n4PJBu4ZqhqW1lRfNjpCMPmXfibaYamV6VPsGulyo\n" + "lMEfCaELAgMBAAECggEBAJa1AiFX4U4tva1xqNKmZV1XklWqIhzts7lnDBkF08gZ\n" + "qcNT5Z5mIpR09eVropwvEidZ56Yp63l5D0XYYbyAS1gfQ0QnGot7h7fdOKgB3MK4\n" + "PLY94gfKPNN17KqWHg2SvNNv1+cn04v78xUCb0zy5tHDp5Acexdm70ohtupARElJ\n" + "LSHdS7ebsqZUFXbbM3BpPEsQLi3PrzNs1DrKkZ3rR6eMGrsDqExXx8/foi9aZKsd\n" + "BGM2/kcTJ5aY6NhSv5iqO1oK46sbMrjVW/bYNsOyl0eFjwTRahn+Zhp/JMewZYeu\n" + "715g6kzbZNwEzBLgrhNPF6E2ycEr/C6z5bE78g5QCkECgYEA8s07UUY25bjYiWWy\n" + "W38pT7d/OXBSyKnq16N6MjVahl29r7nezFiDeLhLC0QiwXu/+qyxVZkB95MMGZXS\n" + "AsaKFNis3AJ6eR4SYyhpSScYKNvlKIiW37TtR4FDcy7y5LL6tFpiDDIGH3LuyWNo\n" + "d76142MBpv5aStnLGYU3pcZj43sCgYEA8VbNM4nqgSCQcbnHYjvsgphEMNSaoVie\n" + "xob2uigXdV6Te0ayoUFBnVNKVsRhk+sswuTV4k1pK/On+USVl2tQ16tcaVMjTfSD\n" + "HLYTJLmt6s4DcywWj5dfkbDoe5PulGXNZE960qXmOC62Lf0VMRwJ5x4FBRvGTjKC\n" + "zvekI2/kO7ECgYEAhBGeclb/BXXGUvY+TgadMf9d9KBkZ0IFu8Xwcd8TnoLe6vbv\n" + "ebery75zE228egIWKwREcYsIxuH1cvVLhrb35N73J7UxaTAyUD1rB598RL1XqPSj\n" + "HIwNhReK2NxwwnWYaQHA02FiczjRKjooWPojdcwk2fEArDZLg1YzLrj7HIECgYEA\n" + "htdx1Y8ESFtyeShMv5UtoxYCW6oeL3H9XH0CE6bc3IYYLvOkULbOO2HTEkGtJ2Fp\n" + "5AbJfiS0U4tS2dI5Jp4eUDH9cxexjRfFvd/5ODbKdnver5X9kQMJsbQ/YPSZg66R\n" + "oK9Lt7Bbvh5TScSy93psCgba1SzckspkDdGNkwMsaTECgYEAnFWaxormLUpXQRLs\n" + "tKzMMHgVnHlsHiqXH432zmT2fpGZHYoWbsGuQjjrHGnSiu3QbDhnzM6y/T2GRs6z\n" + "zHteIo/tzIyxg4MvJGJ9qANA7HoiKBdQ7G/I/NLJIyWAjj+e7/hgzKFcf+dpjpDq\n" + "HcKc9a4WXhC7yu79e5BnKWltHXY=\n" + "-----END PRIVATE KEY-----\n"; + +std::string error_string() { + // Get the last code. + int code = 0; + while (true) { + int icode = ERR_get_error(); + if (icode == 0) break; + code = icode; + } + + // Fetch and clear errno. + int terrno = errno; + errno = 0; + + if (code != 0) { + const char *rc = ERR_reason_error_string(code); + if (rc != nullptr) { + return rc; + } else { + return drvutil::strerror_str(ERR_GET_REASON(code)); + } + } else if (terrno != 0) { + return drvutil::strerror_str(terrno); + } else { + return ""; + } +} + +void clear_all_errors() { + ERR_clear_error(); + errno = 0; +} + +SSL_CTX *new_context(int verify) { + SSL_CTX *ctx = SSL_CTX_new(TLS_method()); + SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); + SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE); + SSL_CTX_set_verify(ctx, verify, nullptr); + return ctx; +} + +static int ctx_use_certificate_str(SSL_CTX *ctx, const char *str) { + UniqueBIO bio(BIO_new(BIO_s_mem())); + BIO_puts(bio.get(), str); + UniqueX509 certificate(PEM_read_bio_X509(bio.get(), NULL, NULL, NULL)); + return SSL_CTX_use_certificate(ctx, certificate.get()); +} + +static int ctx_use_privatekey_str(SSL_CTX *ctx, const char *str) { + UniqueBIO bio(BIO_new(BIO_s_mem())); + BIO_puts(bio.get(), str); + UniquePKEY pkey(PEM_read_bio_PrivateKey(bio.get(), NULL, NULL, NULL)); + return SSL_CTX_use_PrivateKey(ctx, pkey.get()); +} + +void ctx_load_dummy_cert(SSL_CTX *ctx) { + ERR_clear_error(); + if (ctx_use_certificate_str(ctx, dummy_cert) <= 0) { + ERR_print_errors_fp(stderr); + exit(1); + } + if (ctx_use_privatekey_str(ctx, dummy_key) <= 0) { + ERR_print_errors_fp(stderr); + exit(1); + } +} + +static int count_certificates(const char *fn) { + static char null_passwd; + ErrClearErrorOnExit ece; + UniqueBIO bio(BIO_new(BIO_s_file())); + assert(bio != nullptr); + if (BIO_read_filename(bio.get(), fn) <= 0) { + std::cerr << "Cannot open file: " << fn << std::endl; + exit(1); + } + int total = 0; + while (true) { + UniqueX509 x(PEM_read_bio_X509_AUX(bio.get(), nullptr, nullptr, &null_passwd)); + if (x == nullptr) break; + total += 1; + } + return total; +} + +static bool contains_privatekey(const char *fn) { + static char null_passwd; + ErrClearErrorOnExit ece; + UniqueBIO bio(BIO_new(BIO_s_file())); + assert(bio != nullptr); + if (BIO_read_filename(bio.get(), fn) <= 0) { + std::cerr << "Cannot open file: " << fn << std::endl; + exit(1); + } + UniquePKEY k(PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, &null_passwd)); + return k != nullptr; +} + +void ctx_load_cert_from_directory(SSL_CTX *ctx, const std::string &dir) { + std::vector key_paths; + std::vector cert_paths; + + for (const auto & entry : std::filesystem::directory_iterator(dir)) { + std::string fn = entry.path(); + if (count_certificates(fn.c_str()) >= 1) { + cert_paths.push_back(fn); + } + if (contains_privatekey(fn.c_str())) { + key_paths.push_back(fn); + } + } + + if (cert_paths.size() > 1) { + std::cerr << "Directory contains multiple certs: " << dir << std::endl; + exit(1); + } + if (key_paths.size() > 1) { + std::cerr << "Directory contains multiple keys: " << dir << std::endl; + exit(1); + } + if (cert_paths.empty()) { + std::cerr << "Directory doesn't contain a cert: " << dir << std::endl; + exit(1); + } + if (key_paths.empty()) { + std::cerr << "Directory doesn't contain a key: " << dir << std::endl; + exit(1); + } + + int status; + status = SSL_CTX_use_PrivateKey_file(ctx, key_paths[0].c_str(), SSL_FILETYPE_PEM); + assert(status == 1); + status = SSL_CTX_use_certificate_chain_file(ctx, cert_paths[0].c_str()); + assert(status == 1); +} + +} // namespace sslutil + diff --git a/luprex/core/drv/sslutil.hpp b/luprex/core/drv/sslutil.hpp new file mode 100644 index 00000000..a59bacb2 --- /dev/null +++ b/luprex/core/drv/sslutil.hpp @@ -0,0 +1,61 @@ +#ifndef SSLUTIL_HPP +#define SSLUTIL_HPP + +#include "drvutil.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace sslutil { + +struct SSL_Deleter { + void operator()(SSL *ssl) { SSL_free(ssl); } +}; + +struct CTX_Deleter { + void operator()(SSL_CTX *ctx) { SSL_CTX_free(ctx); } +}; + +struct BIO_Deleter { + void operator()(BIO *bio) { BIO_free(bio); } +}; + +struct X509_Deleter { + void operator()(X509 *x) { X509_free(x); } +}; + +struct PKEY_Deleter { + void operator()(EVP_PKEY *p) { EVP_PKEY_free(p); } +}; + +using UniqueSSL = std::unique_ptr; +using UniqueCTX = std::unique_ptr; +using UniqueBIO = std::unique_ptr; +using UniqueX509 = std::unique_ptr; +using UniquePKEY = std::unique_ptr; + +struct ErrClearErrorOnExit { + ~ErrClearErrorOnExit() { + ERR_clear_error(); + } +}; + +// Return the OpenSSL error as a string. +std::string error_string(); +void clear_all_errors(); +SSL_CTX *new_context(int verify); +void ctx_load_dummy_cert(SSL_CTX *ctx); +void ctx_load_cert_from_directory(SSL_CTX *ctx, const std::string &dir); + +} // namespace sslutil + +#endif // SSLUTIL_HPP + diff --git a/luprex/core/lua/control.lst b/luprex/core/lua/control.lst index 66b051de..f41943a2 100644 --- a/luprex/core/lua/control.lst +++ b/luprex/core/lua/control.lst @@ -9,4 +9,3 @@ ut-tablecmp.lua basics.lua uglyglobals.lua login.lua -spectra.lua diff --git a/luprex/core/lobj/.gitkeep b/luprex/core/obj/cpp/.gitkeep similarity index 100% rename from luprex/core/lobj/.gitkeep rename to luprex/core/obj/cpp/.gitkeep diff --git a/luprex/core/obj/drv/.gitkeep b/luprex/core/obj/drv/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/luprex/core/obj/lua/.gitkeep b/luprex/core/obj/lua/.gitkeep new file mode 100644 index 00000000..e69de29b