diff --git a/luprex/core/cpp/drivenengine.cpp b/luprex/core/cpp/drivenengine.cpp index ab1a9182..0ed021e5 100644 --- a/luprex/core/cpp/drivenengine.cpp +++ b/luprex/core/cpp/drivenengine.cpp @@ -37,21 +37,25 @@ Channel *DrivenEngine::get_chid(int chid) { return channels_[chid]; } +void DrivenEngine::listen_port(int port) { + listen_ports_.insert(port); +} + double DrivenEngine::get_clock() { return clock_; } -std::unique_ptr DrivenEngine::new_outgoing_channel(const std::string &target) { +UniqueChannel DrivenEngine::new_outgoing_channel(const std::string &target) { int chid = find_unused_chid(); new_outgoing_.insert(chid); - return std::unique_ptr(new Channel(this, chid, 0, target)); + return UniqueChannel(new Channel(this, chid, 0, target)); } -std::unique_ptr DrivenEngine::new_incoming_channel() { +UniqueChannel DrivenEngine::new_incoming_channel() { if (accepted_channels_.empty()) { return nullptr; } else { - std::unique_ptr result = std::move(accepted_channels_.back()); + UniqueChannel result = std::move(accepted_channels_.back()); accepted_channels_.pop_back(); return std::move(result); } @@ -73,6 +77,10 @@ void DrivenEngine::stop_driver() { stop_driver_ = true; } +void DrivenEngine::drv_get_listen_ports(std::set &ports) { + ports = listen_ports_; +} + void DrivenEngine::drv_get_new_closed(std::set &channels) { channels = std::move(new_closed_); new_closed_.clear(); diff --git a/luprex/core/cpp/drivenengine.hpp b/luprex/core/cpp/drivenengine.hpp index 95608847..7af8e3da 100644 --- a/luprex/core/cpp/drivenengine.hpp +++ b/luprex/core/cpp/drivenengine.hpp @@ -108,6 +108,10 @@ public: StreamBuffer *out() { return sb_out_.get(); } StreamBuffer *in() { return sb_in_.get(); } + // The channel ID. These are reused. + // + int chid() { return chid_; } + // If this is a socket connection, the receiver's port number. // int port() { return port_; } @@ -140,6 +144,7 @@ private: std::string target_; friend class DrivenEngine; }; +using UniqueChannel = std::unique_ptr; class DrivenEngine { public: @@ -155,6 +160,11 @@ public: virtual void event_init() {} virtual void event_update() {} + // Specify the set of listening ports. + // This can only be used during the init routine. + // + void listen_port(int port); + // Get the current time. // // DRIVER: This returns the time most recently stored by the driver @@ -169,7 +179,7 @@ public: // actually opening the connection and relaying data into the channel using // drv_get_target, drv_peek_outgoing, drv_sent_outgoing, drv_recv_incoming. // - std::unique_ptr new_outgoing_channel(const std::string &target); + UniqueChannel new_outgoing_channel(const std::string &target); // Create a new channel from any pending incoming connection. If there is no // incoming connection, returns nullptr. @@ -182,7 +192,7 @@ public: // using drv_get_target, drv_peek_outgoing, drv_sent_outgoing, // drv_recv_incoming. // - std::unique_ptr new_incoming_channel(); + UniqueChannel new_incoming_channel(); // Obtain the stdio channel. There is only one stdio channel. It is owned // by the DrivenEngine. It is an error to delete the stdio channel. @@ -223,6 +233,11 @@ public: // static const int MAX_CHAN = 256; + // Get a list of all the listening ports. The driver is expected + // to fetch this set shortly after the event_init callback is invoked. + // + void drv_get_listen_ports(std::set &ports); + // Get a list of all recently-closed channels. The driver should // discard all socket information associated with these channels. // Caution: this may contain channels that the driver has never @@ -348,11 +363,12 @@ private: private: Channel *channels_[MAX_CHAN]; int next_unused_chid_; - std::unique_ptr stdio_channel_; - std::vector> accepted_channels_; + UniqueChannel stdio_channel_; + std::vector accepted_channels_; std::set new_closed_; std::set new_outgoing_; util::LuaSourcePtr lua_source_; + std::set listen_ports_; bool rescan_lua_source_; double clock_; bool stop_driver_; diff --git a/luprex/core/cpp/driver-mingw.cpp b/luprex/core/cpp/driver-mingw.cpp index a9e43ec3..03136a9a 100644 --- a/luprex/core/cpp/driver-mingw.cpp +++ b/luprex/core/cpp/driver-mingw.cpp @@ -22,6 +22,7 @@ public: bool engine_wakeup_; char console_line_[CONSOLE_MAX + 1]; int console_len_; + std::map listen_sockets_; std::unique_ptr chbuf; static PADDRINFOA find_good_addr(PADDRINFOA addrinfo) { @@ -33,12 +34,17 @@ public: return nullptr; } + void set_nonblocking(SOCKET sock) { + u_long mode = 1; // 1 to enable non-blocking socket + int status = ioctlsocket(sock, FIONBIO, &mode); + assert(status == 0); + } + SOCKET open_connection(const std::string &target, std::string &err) { PADDRINFOA addrs = nullptr; PADDRINFOA goodaddr = nullptr; SOCKET sock = INVALID_SOCKET; std::string host, port; - u_long mode = 1; // 1 to enable non-blocking socket err = ""; util::split_host_port(target, host, port); @@ -62,8 +68,7 @@ public: } sock = socket(goodaddr->ai_family, SOCK_STREAM, IPPROTO_TCP); assert(sock != INVALID_SOCKET); - status = ioctlsocket(sock, FIONBIO, &mode); - assert(status == 0); + set_nonblocking(sock); status = connect(sock, goodaddr->ai_addr, goodaddr->ai_addrlen); if (status != 0) { int errcode = WSAGetLastError(); @@ -83,6 +88,25 @@ public: return SOCKET_ERROR; } + SOCKET listen_on_port(int port, std::string &err) { + int status; + err = ""; + SOCKET sock = socket(AF_INET, SOCK_STREAM, 0); + assert(sock != INVALID_SOCKET); + + struct sockaddr_in server; + server.sin_family = AF_INET; + server.sin_addr.s_addr = INADDR_ANY; + server.sin_port = htons(port); + + status = bind(sock, (struct sockaddr *)&server, sizeof(server)); + assert(status == 0); + status = listen(sock, 10); + assert(status == 0); + set_nonblocking(sock); + return sock; + } + void init(DrivenEngine *de) { driven_ = de; for (int i = 0; i < MAX_CHAN; i++) { @@ -94,6 +118,20 @@ public: chbuf.reset(new char[65536]); } + void handle_listen_ports() { + std::set listenports; + driven_->drv_get_listen_ports(listenports); + for (int port : listenports) { + if (listen_sockets_.find(port) == listen_sockets_.end()) { + std::string err; + SOCKET sock = listen_on_port(port, err); + if (sock != INVALID_SOCKET) { + listen_sockets_[port] = sock; + } + } + } + } + void handle_lua_source() { if (driven_->drv_get_rescan_lua_source()) { driven_->drv_set_lua_source(util::read_lua_source("lua")); @@ -190,11 +228,24 @@ public: void handle_clock() { } + void close_socket(int chid) { + assert(socket_[chid] != INVALID_SOCKET); + assert(closesocket(socket_[chid]) == 0); + driven_->drv_notify_close(chid); + socket_[chid] = INVALID_SOCKET; + connected_[chid] = false; + engine_wakeup_ = true; + } + bool calc_select_sets(fd_set &rfds, fd_set &wfds, fd_set &efds) const { FD_ZERO(&rfds); FD_ZERO(&wfds); FD_ZERO(&efds); bool any = false; + for (const auto &p : listen_sockets_) { + FD_SET(p.second, &rfds); + any = true; + } for (int chid = 1; chid < MAX_CHAN; chid++) { SOCKET sock = socket_[chid]; if (sock == INVALID_SOCKET) continue; @@ -207,13 +258,30 @@ public: return any; } - void close_socket(int chid) { - assert(socket_[chid] != INVALID_SOCKET); - assert(closesocket(socket_[chid]) == 0); - driven_->drv_notify_close(chid); - socket_[chid] = INVALID_SOCKET; - connected_[chid] = false; - engine_wakeup_ = true; + void accept_connections(int port, SOCKET sock) { + while (true) { + SOCKET chsock = accept(sock, nullptr, nullptr); + if (chsock != INVALID_SOCKET) { + int chid = driven_->drv_notify_accept(port); + socket_[chid] = chsock; + connected_[chid] = true; + engine_wakeup_ = true; + continue; + } + int errcode = WSAGetLastError(); + if (errcode == WSAEWOULDBLOCK) { + return; + } + if (errcode == WSAECONNRESET) { + // The remote disconnected before we had a chance to accept. + // Just pretend it never happened. + continue; + } + // If a listening port fails in a non-transient way, + // we don't really have any good way of handling + // that. + assert(false); + } } void handle_socket_input_output(int mstimeout) { @@ -229,6 +297,13 @@ public: timeout.tv_usec = (mstimeout - (timeout.tv_sec*1000)) * 1000; int status = select(1, &rfds, &wfds, &efds, &timeout); assert(status != SOCKET_ERROR); + + for (auto &p : listen_sockets_) { + if (FD_ISSET(p.second, &rfds)) { + accept_connections(p.first, p.second); + } + } + for (int chid = 1; chid < MAX_CHAN; chid++) { SOCKET sock = socket_[chid]; if (sock == INVALID_SOCKET) continue; @@ -267,6 +342,7 @@ public: DrivenEngine::set(de); driven_->drv_set_lua_source(util::read_lua_source("lua")); driven_->drv_invoke_event_init(); + handle_listen_ports(); while (!de->drv_get_stop_driver()) { engine_wakeup_ = false; handle_lua_source(); diff --git a/luprex/core/cpp/main.cpp b/luprex/core/cpp/main.cpp index 027fbef4..beb6d1c5 100644 --- a/luprex/core/cpp/main.cpp +++ b/luprex/core/cpp/main.cpp @@ -5,26 +5,45 @@ class TNTest : public DrivenEngine { public: - std::unique_ptr chan_; + std::vector channels_; virtual void event_init() { - chan_ = new_outgoing_channel("stanford.edu:80"); - chan_->out()->write_bytes("GET /index.html HTTP/1.1\n\n"); + // UniqueChannel ch = new_outgoing_channel("stanford.edu:80"); + // ch->out()->write_bytes("GET /index.html HTTP/1.1\n\n"); + // channels_.emplace_back(std::move(ch)); + listen_port(8085); } - virtual void event_update() { - std::string input = get_stdio_channel()->in()->read_entire_contents(); - if (input != "") { - get_stdio_channel()->out()->write_bytes("stdin: "); - get_stdio_channel()->out()->write_bytes(input); + + void dump_lines(StreamBuffer *in, StreamBuffer *out, int chid) { + while (true) { + std::string l = in->readline(); + if (l == "") break; + std::ostringstream oss; + oss << "Chan " << chid << ": " << l; + out->write_bytes(oss.str()); } - if (chan_ != nullptr) { - if (chan_->closed()) { - get_stdio_channel()->out()->write_bytes("Connection closed.\n"); - chan_.reset(); + } + + virtual void event_update() { + while (true) { + UniqueChannel ch = new_incoming_channel(); + if (ch == nullptr) break; + channels_.emplace_back(std::move(ch)); + } + + Channel *stdioch = get_stdio_channel(); + dump_lines(stdioch->in(), stdioch->out(), 0); + std::vector keep; + for (UniqueChannel &ch : channels_) { + dump_lines(ch->in(), stdioch->out(), ch->chid()); + if (ch->closed()) { + std::ostringstream oss; + oss << "Chan " << ch->chid() << " closed.\n"; + stdioch->out()->write_bytes(oss.str()); } else { - chan_->in()->copy_into(get_stdio_channel()->out()); - chan_->in()->clear(); + keep.emplace_back(std::move(ch)); } } + channels_ = std::move(keep); } };