#include "driver.hpp"

#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <poll.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <termios.h>
#include <time.h>
#include <unistd.h>

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

// Winsock-style socket vocabulary used throughout this file.
using SOCKET = int;
const int INVALID_SOCKET = -1;
using SocketVector = std::vector<SOCKET>;
using PollVector = std::vector<struct pollfd>;

// Size of the scratch buffer used for every channel recv() call.
static const int kChanBufSize = 65536;

// Flags passed to send(). Where MSG_NOSIGNAL exists, writing to a peer that
// has already closed fails with EPIPE instead of raising SIGPIPE (whose
// default action terminates the process).
#ifdef MSG_NOSIGNAL
static const int kSendFlags = MSG_NOSIGNAL;
#else
static const int kSendFlags = 0;
#endif

// Terminal settings captured by enableRawMode() and restored at exit.
struct termios orig_termios;

// Puts fd into non-blocking mode, preserving its other file-status flags.
void set_nonblocking(int fd) {
  int flags = fcntl(fd, F_GETFL, 0);
  assert(flags != -1);
  int status = fcntl(fd, F_SETFL, flags | O_NONBLOCK);
  assert(status != -1);
  (void)status;
}

// True for errno values meaning "nothing could be done right now, try again
// later" rather than a broken descriptor.
static bool is_transient_error(int e) {
  return (e == EAGAIN) || (e == EWOULDBLOCK) || (e == EINTR);
}

// strerror_r() has two incompatible signatures: the XSI one returns int and
// fills the buffer, the GNU one returns a char* that may point elsewhere.
// Overloading on the result type lets errno_string() compile against either.
static inline std::string strerror_r_result(int rc, const char *buf) {
  return (rc == 0) ? std::string(buf) : std::string("unknown error");
}
static inline std::string strerror_r_result(const char *msg, const char *) {
  return std::string(msg);
}

// Returns the human-readable message for the errno value `errnum`.
static std::string errno_string(int errnum) {
  char buf[1024];
  buf[0] = '\0';
  return strerror_r_result(strerror_r(errnum, buf, sizeof(buf)), buf);
}

// Restores the terminal settings saved by enableRawMode().
static void disableRawMode() { tcsetattr(0, TCSAFLUSH, &orig_termios); }

// Switches stdin to no-echo, non-canonical mode with VMIN = VTIME = 0, so
// read() on it returns immediately even when no key is pending. The previous
// settings are restored at process exit via atexit().
static void enableRawMode() {
  int status = tcgetattr(0, &orig_termios);
  assert(status >= 0);
  atexit(disableRawMode);
  struct termios raw = orig_termios;
  raw.c_iflag &= ~(BRKINT | ICRNL | INPCK | ISTRIP | IXON);
  raw.c_lflag &= ~(ECHO | ICANON);
  raw.c_oflag |= OPOST;
  raw.c_cc[VMIN] = 0;
  raw.c_cc[VTIME] = 0;
  status = tcsetattr(0, TCSAFLUSH, &raw);
  assert(status >= 0);
  (void)status;
}

// Starts a non-blocking IPv4 TCP connection to `target`, which is split into
// host and port by util::split_host_port(). Only the first address returned
// by getaddrinfo() is used. The connect may still be in progress on return;
// its outcome surfaces later through poll().
// Returns the socket, or INVALID_SOCKET with `err` describing the failure.
SOCKET open_connection(const std::string &target, std::string &err) {
  err.clear();

  std::string host, port;
  util::split_host_port(target, host, port);

  struct addrinfo hints;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_INET;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_protocol = IPPROTO_TCP;
  hints.ai_flags = AI_NUMERICSERV;

  struct addrinfo *addrs = nullptr;
  int status = getaddrinfo(host.c_str(), port.c_str(), &hints, &addrs);
  if (status != 0) {
    err = gai_strerror(status);
    return INVALID_SOCKET;
  }
  if (addrs == nullptr) {
    err = "no such host found";
    return INVALID_SOCKET;
  }
  // Releases the address list on every return path below.
  std::unique_ptr<struct addrinfo, void (*)(struct addrinfo *)> addrs_guard(
      addrs, &freeaddrinfo);

  const struct addrinfo *ai = addrs;
  assert(ai->ai_family == AF_INET);
  assert(ai->ai_socktype == SOCK_STREAM);
  assert(ai->ai_protocol == IPPROTO_TCP);

  SOCKET sock = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
  if (sock < 0) {
    err = errno_string(errno);
    return INVALID_SOCKET;
  }
  set_nonblocking(sock);

  status = connect(sock, ai->ai_addr, ai->ai_addrlen);
  if ((status != 0) && (errno != EINPROGRESS)) {
    // Capture errno before close() can overwrite it.
    err = errno_string(errno);
    close(sock);
    return INVALID_SOCKET;
  }
  return sock;
}

// Opens a non-blocking IPv4 TCP socket listening on `port` on all
// interfaces. Returns the socket, or INVALID_SOCKET with `err` set when any
// step fails (for example when the port is already in use).
SOCKET listen_on_port(int port, std::string &err) {
  err.clear();
  SOCKET sock = socket(AF_INET, SOCK_STREAM, 0);
  if (sock < 0) {
    err = errno_string(errno);
    return INVALID_SOCKET;
  }
  // Records the failing call's errno, then releases the socket.
  auto fail = [&]() -> SOCKET {
    err = errno_string(errno);
    close(sock);
    return INVALID_SOCKET;
  };

  int enable = 1;
  if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable)) != 0) {
    return fail();
  }

  struct sockaddr_in server;
  memset(&server, 0, sizeof(server));
  server.sin_family = AF_INET;
  server.sin_addr.s_addr = htonl(INADDR_ANY);
  server.sin_port = htons(port);
  if (bind(sock, reinterpret_cast<struct sockaddr *>(&server), sizeof(server)) != 0) {
    return fail();
  }
  if (listen(sock, 10) != 0) {
    return fail();
  }
  set_nonblocking(sock);
  return sock;
}

// Accepts every pending connection on the non-blocking `listen_socket` and
// returns the new sockets, each switched to non-blocking mode.
SocketVector accept_on_socket(SOCKET listen_socket) {
  SocketVector result;
  while (true) {
    SOCKET chsock = accept(listen_socket, nullptr, nullptr);
    if (chsock >= 0) {
      // On Linux an accepted socket does not inherit O_NONBLOCK from the
      // listener, and the driver loop must never block in send()/recv().
      set_nonblocking(chsock);
      result.push_back(chsock);
      continue;
    }
    if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) {
      // Normal completion - we're out of incoming sockets.
      return result;
    }
    if ((errno == ECONNABORTED) || (errno == EINTR)) {
      // The remote disconnected before we could accept, or a signal
      // interrupted the call. Just try the next one.
      continue;
    }
    // Unexpected error (e.g. out of descriptors). Stop here rather than
    // retrying forever, which is what would happen once NDEBUG removes
    // the assert.
    assert(false);
    return result;
  }
}

// Seconds elapsed since construction, measured on CLOCK_MONOTONIC.
class MonoClock {
 private:
  struct timespec base_;

 public:
  MonoClock() {
    int status = clock_gettime(CLOCK_MONOTONIC, &base_);
    assert(status == 0);
    (void)status;
  }

  double get() {
    struct timespec t;
    int status = clock_gettime(CLOCK_MONOTONIC, &t);
    assert(status == 0);
    (void)status;
    double tv_sec = t.tv_sec - base_.tv_sec;
    double tv_nsec = t.tv_nsec - base_.tv_nsec;
    return tv_sec + (tv_nsec * 1.0E-9);
  }
};

static MonoClock monoclock;

namespace util {
double profiling_clock() { return monoclock.get(); }
}  // namespace util

// Runs the I/O loop for a DrivenEngine. It handles console I/O on fds 0/1,
// the engine's listening ports, and outgoing or accepted TCP channels, all
// multiplexed with poll().
class Driver {
 public:
  enum ChanState {
    CHAN_INACTIVE,    // closed; the entry is dropped by cleanup_channels()
    CHAN_CONNECTING,  // outgoing connect() issued, not yet seen writable
    CHAN_OPEN,
  };

  struct ChanInfo {
    int chid;
    ChanState state;
    SOCKET socket;
  };

  DrivenEngine *driven_ = nullptr;
  std::vector<ChanInfo> chans_;
  // Some channel became CHAN_INACTIVE and chans_ needs compacting.
  bool any_inactive_ = false;
  // When set, the next poll() uses a 0 ms timeout instead of 100 ms.
  bool short_sleep_ = false;
  std::map<int, SOCKET> listen_sockets_;  // port -> listening socket
  std::unique_ptr<char[]> chbuf;          // recv() scratch buffer

  void init(DrivenEngine *de) {
    driven_ = de;
    any_inactive_ = false;
    short_sleep_ = false;
    chbuf.reset(new char[kChanBufSize]);
  }

  // Opens a listening socket for every port the engine requests that is not
  // already open. Ports that cannot be opened are reported on stderr.
  void handle_listen_ports() {
    std::set<int> listenports;
    driven_->drv_get_listen_ports(listenports);
    for (int port : listenports) {
      if (listen_sockets_.count(port) != 0) continue;
      std::string err;
      SOCKET sock = listen_on_port(port, err);
      if (sock == INVALID_SOCKET) {
        fprintf(stderr, "cannot listen on port %d: %s\n", port, err.c_str());
        continue;
      }
      listen_sockets_[port] = sock;
    }
  }

  // Reloads the Lua source from the "lua" directory when the engine asks.
  void handle_lua_source() {
    if (driven_->drv_get_rescan_lua_source()) {
      driven_->drv_set_lua_source(util::read_lua_source("lua"));
      short_sleep_ = true;
    }
  }

  // Closes chan's socket, reports `err` to the engine, and marks the entry
  // inactive. The vector entry itself is removed later by cleanup_channels().
  void close_channel(ChanInfo &chan, const std::string &err) {
    assert(chan.state != CHAN_INACTIVE);
    // close() must run even when NDEBUG strips the assert.
    int status = close(chan.socket);
    assert(status == 0);
    (void)status;
    driven_->drv_notify_close(chan.chid, err);
    chan.state = CHAN_INACTIVE;
    chan.socket = INVALID_SOCKET;
    chan.chid = -1;
    any_inactive_ = true;
    short_sleep_ = true;
  }

  // Drops every CHAN_INACTIVE entry from chans_.
  void cleanup_channels() {
    if (!any_inactive_) return;
    chans_.erase(std::remove_if(chans_.begin(), chans_.end(),
                                [](const ChanInfo &c) { return c.state == CHAN_INACTIVE; }),
                 chans_.end());
    any_inactive_ = false;
  }

  // Closes every channel the engine has released.
  void handle_released_channels() {
    for (ChanInfo &chan : chans_) {
      if (driven_->drv_get_channel_released(chan.chid)) {
        close_channel(chan, "");
      }
    }
    cleanup_channels();
  }

  // Starts connections for channels the engine has newly requested. Failures
  // are reported to the engine immediately via drv_notify_close().
  void handle_new_outgoing_sockets() {
    std::set<int> chans;
    driven_->drv_get_new_outgoing(chans);
    for (int chid : chans) {
      std::string err;
      SOCKET sock = open_connection(driven_->drv_get_target(chid), err);
      if (sock == INVALID_SOCKET) {
        driven_->drv_notify_close(chid, err);
        short_sleep_ = true;
        continue;
      }
      ChanInfo newchan;
      newchan.chid = chid;
      newchan.state = CHAN_CONNECTING;
      newchan.socket = sock;
      chans_.push_back(newchan);
    }
  }

  // Writes the engine's pending console output (channel 0) to stdout.
  void handle_console_output() {
    while (true) {
      int nbytes;
      const char *bytes;
      driven_->drv_peek_outgoing(0, &nbytes, &bytes);
      if (nbytes == 0) break;
      ssize_t nwrote = write(1, bytes, nbytes);
      if (nwrote <= 0) {
        // Interrupted or would block: retry on the next loop iteration.
        assert((nwrote < 0) && is_transient_error(errno));
        break;
      }
      driven_->drv_sent_outgoing(0, int(nwrote));
    }
  }

  // Feeds every byte currently readable on stdin to the engine as channel 0.
  void handle_console_input() {
    char buffer[256];
    while (true) {
      ssize_t nread = read(0, buffer, sizeof(buffer));
      if (nread > 0) {
        driven_->drv_recv_incoming(0, int(nread), buffer);
        continue;
      }
      // 0: no pending keystrokes (VMIN = VTIME = 0) or end of input.
      assert((nread == 0) || is_transient_error(errno));
      break;
    }
  }

  // Accepts pending connections on `listen_sock` (bound to `port`) and
  // registers each as an open channel announced via drv_notify_accept().
  void accept_connections(int port, SOCKET listen_sock) {
    for (SOCKET sock : accept_on_socket(listen_sock)) {
      ChanInfo newchan;
      newchan.chid = driven_->drv_notify_accept(port);
      newchan.state = CHAN_OPEN;
      newchan.socket = sock;
      chans_.push_back(newchan);
      short_sleep_ = true;
    }
  }

  // Fills select() descriptor sets for all sockets and returns the nfds
  // argument. Not called anywhere in this file; the loop uses poll() instead.
  int calc_select_sets(fd_set &rfds, fd_set &wfds, fd_set &efds) const {
    FD_ZERO(&rfds);
    FD_ZERO(&wfds);
    FD_ZERO(&efds);
    int largest = -1;
    for (const auto &p : listen_sockets_) {
      FD_SET(p.second, &rfds);
      FD_SET(p.second, &efds);
      if (p.second > largest) largest = p.second;
    }
    for (const ChanInfo &chan : chans_) {
      SOCKET sock = chan.socket;
      if (sock == INVALID_SOCKET) continue;
      FD_SET(sock, &rfds);
      FD_SET(sock, &efds);
      if (!driven_->drv_outgoing_empty(chan.chid)) {
        FD_SET(sock, &wfds);
      }
      if (sock > largest) largest = sock;
    }
    return largest + 1;
  }

  // Performs the send()/recv() work that poll() reported as possible for a
  // single channel, closing the channel on a non-transient failure.
  void transfer_channel_bytes(ChanInfo &chan, short revents) {
    SOCKET sock = chan.socket;
    if (sock == INVALID_SOCKET) return;

    if (revents & POLLOUT) {
      // Writability also signals that a non-blocking connect() has finished.
      chan.state = CHAN_OPEN;
      int nbytes;
      const char *bytes;
      driven_->drv_peek_outgoing(chan.chid, &nbytes, &bytes);
      if (nbytes > 0) {
        ssize_t wbytes = send(sock, bytes, nbytes, kSendFlags);
        if (wbytes >= 0) {
          driven_->drv_sent_outgoing(chan.chid, int(wbytes));
        } else if (!is_transient_error(errno)) {
          close_channel(chan, "send failure");
          return;
        }
      }
    }

    if (revents & (POLLIN | POLLERR | POLLHUP)) {
      // Someday, find a way to avoid this copy.
      ssize_t nrecv = recv(sock, chbuf.get(), kChanBufSize, 0);
      if (nrecv > 0) {
        driven_->drv_recv_incoming(chan.chid, int(nrecv), chbuf.get());
        short_sleep_ = true;
      } else if ((nrecv == 0) || !is_transient_error(errno)) {
        // 0 means the peer closed the connection.
        close_channel(chan, "recv failure");
      }
    }
  }

  // Waits up to `mstimeout` ms for activity on listening sockets, channels,
  // or stdin, then accepts new connections and moves channel bytes.
  void handle_socket_input_output(int mstimeout) {
    // Only channels that exist now get a pollfd slot. Channels accepted
    // below are appended to chans_ and must not be matched against
    // pollvec entries this round.
    const size_t npolled = chans_.size();

    PollVector pollvec(listen_sockets_.size() + npolled + 1);
    size_t index = 0;
    for (const auto &p : listen_sockets_) {
      struct pollfd &pfd = pollvec[index++];
      pfd.fd = p.second;
      pfd.events = POLLIN;
    }
    for (const ChanInfo &chan : chans_) {
      struct pollfd &pfd = pollvec[index++];
      pfd.fd = chan.socket;
      pfd.events = POLLIN;
      if (!driven_->drv_outgoing_empty(chan.chid)) {
        pfd.events |= POLLOUT;
      }
    }
    // stdin is polled only so that keystrokes end the wait early; it is
    // read by handle_console_input().
    struct pollfd &stdiopoll = pollvec[index++];
    stdiopoll.fd = 0;
    stdiopoll.events = POLLIN;

    int status = poll(pollvec.data(), pollvec.size(), mstimeout);
    if (status < 0) {
      // A signal interrupted the wait; treat it like a timeout.
      assert(errno == EINTR);
      return;
    }

    // Check listening sockets.
    index = 0;
    for (const auto &p : listen_sockets_) {
      const struct pollfd &pfd = pollvec[index++];
      if (pfd.revents & (POLLIN | POLLERR)) {
        accept_connections(p.first, p.second);
      }
    }

    // Transfer bytes on the channels that were polled. Index into chans_
    // rather than holding references, because accepting may have grown it.
    for (size_t i = 0; i < npolled; ++i) {
      const struct pollfd &pfd = pollvec[index++];
      transfer_channel_bytes(chans_[i], pfd.revents);
    }

    // Delete any newly-inactive channels.
    cleanup_channels();
  }

  // Closes every descriptor this driver opened. Channels are not reported
  // through drv_notify_close() because the driver loop is shutting down.
  void close_all_sockets() {
    for (ChanInfo &chan : chans_) {
      if (chan.socket != INVALID_SOCKET) close(chan.socket);
    }
    chans_.clear();
    for (const auto &p : listen_sockets_) {
      close(p.second);
    }
    listen_sockets_.clear();
  }

  // Main loop: runs until the engine reports drv_get_stop_driver().
  void drive(DrivenEngine *de, int argc, char *argv[]) {
    enableRawMode();
    init(de);
    DrivenEngine::set(de);
    driven_->drv_set_lua_source(util::read_lua_source("lua"));
    driven_->drv_invoke_event_init(argc, argv);
    handle_listen_ports();
    while (!driven_->drv_get_stop_driver()) {
      short_sleep_ = false;
      handle_lua_source();
      handle_console_output();
      handle_console_input();
      handle_console_output();
      handle_released_channels();
      handle_new_outgoing_sockets();
      int mstimeout = short_sleep_ ? 0 : 100;
      handle_socket_input_output(mstimeout);
      driven_->drv_set_clock(monoclock.get());
      driven_->drv_invoke_event_update();
    }
    close_all_sockets();
    DrivenEngine::set(nullptr);
  }
};

// Entry point: runs the driver loop for `de` until it asks to stop.
void driver_drive(DrivenEngine *de, int argc, char *argv[]) {
  Driver driver;
  driver.drive(de, argc, argv);
}