// Single-threaded event loop that drives a DrivenEngine: it puts the terminal
// into raw mode, owns the listening and outgoing TCP sockets plus per-channel
// state, and shuttles bytes between those file descriptors and the engine.
// Channel 0 is the console (stdin/stdout); socket channels use engine-assigned ids.
#include "driver.hpp"

#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <poll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <termios.h>
#include <time.h>
#include <unistd.h>

#include <cassert>
#include <cerrno>
#include <cstdlib>
#include <cstring>

#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include <openssl/err.h>
#include <openssl/ssl.h>

using SOCKET = int;
const int INVALID_SOCKET = -1;
using SocketVector = std::vector<SOCKET>;
using PollVector = std::vector<struct pollfd>;

// Size of the shared receive scratch buffer, and the cap on a single send().
static const int kChannelBufferSize = 65536;

// Flags passed to every send(). MSG_NOSIGNAL stops a write to a peer that has
// already hung up from raising SIGPIPE (whose default action kills the
// process); the failure is then reported through errno and handled as a
// "send failure" like any other.
#ifdef MSG_NOSIGNAL
static const int kSendFlags = MSG_NOSIGNAL;
#else
static const int kSendFlags = 0;
#endif

// Terminal settings captured by enableRawMode() and restored at exit.
struct termios orig_termios;

// Put fd into non-blocking mode. Failures trip an assert in debug builds.
void set_nonblocking(int fd) {
  int flags = fcntl(fd, F_GETFL, 0);
  assert(flags != -1);
  int status = fcntl(fd, F_SETFL, flags | O_NONBLOCK);
  assert(status != -1);
  (void)status;
}

// Restore the terminal settings saved by enableRawMode(). Registered with atexit().
static void disableRawMode() {
  tcsetattr(0, TCSAFLUSH, &orig_termios);
}

// Switch stdin (fd 0) to raw-ish mode: no echo, no canonical line buffering,
// no CR->NL translation / parity stripping / XON-XOFF, and reads that return
// immediately (VMIN = VTIME = 0). Output post-processing (OPOST) stays on.
static void enableRawMode() {
  int status = tcgetattr(0, &orig_termios);
  assert(status >= 0);
  atexit(disableRawMode);
  struct termios raw = orig_termios;
  raw.c_iflag &= ~(BRKINT | ICRNL | INPCK | ISTRIP | IXON);
  raw.c_lflag &= ~(ECHO | ICANON);
  raw.c_oflag |= OPOST;
  raw.c_cc[VMIN] = 0;
  raw.c_cc[VTIME] = 0;
  status = tcsetattr(0, TCSAFLUSH, &raw);
  assert(status >= 0);
  (void)status;
}

// Begin a non-blocking IPv4 TCP connect to `target`, which is split into host
// and port by util::split_host_port (the port must be numeric: AI_NUMERICSERV).
// Only the first address returned by getaddrinfo is tried.
// Returns the socket (its connect may still be in progress) or INVALID_SOCKET,
// in which case `err` describes the failure; on success `err` is "".
SOCKET open_connection(const std::string &target, std::string &err) {
  struct addrinfo hints;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_INET;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_protocol = IPPROTO_TCP;
  hints.ai_flags = AI_NUMERICSERV;

  err = "";
  std::string host, port;
  util::split_host_port(target, host, port);

  struct addrinfo *addrs = nullptr;
  int status = getaddrinfo(host.c_str(), port.c_str(), &hints, &addrs);
  if (status != 0) {
    err = gai_strerror(status);
    return INVALID_SOCKET;
  }
  if (addrs == nullptr) {
    err = "no such host found";
    return INVALID_SOCKET;
  }

  const struct addrinfo *goodaddr = addrs;
  assert(goodaddr->ai_family == AF_INET);
  assert(goodaddr->ai_socktype == SOCK_STREAM);
  assert(goodaddr->ai_protocol == IPPROTO_TCP);

  SOCKET sock = socket(goodaddr->ai_family, goodaddr->ai_socktype, goodaddr->ai_protocol);
  if (sock < 0) {
    err = strerror(errno);
    freeaddrinfo(addrs);
    return INVALID_SOCKET;
  }
  set_nonblocking(sock);

  // EINPROGRESS is the normal result of connect() on a non-blocking socket.
  status = connect(sock, goodaddr->ai_addr, goodaddr->ai_addrlen);
  if ((status != 0) && (errno != EINPROGRESS)) {
    // Read errno before close()/freeaddrinfo() get a chance to overwrite it.
    err = strerror(errno);
    close(sock);
    freeaddrinfo(addrs);
    return INVALID_SOCKET;
  }
  freeaddrinfo(addrs);
  return sock;
}

// Open a non-blocking IPv4 TCP socket listening on `port` on all interfaces
// (SO_REUSEADDR set, backlog 10). Returns the socket, or INVALID_SOCKET with
// a description in `err` if socket/bind/listen fails; on success `err` is "".
SOCKET listen_on_port(int port, std::string &err) {
  err = "";
  SOCKET sock = socket(AF_INET, SOCK_STREAM, 0);
  if (sock < 0) {
    err = strerror(errno);
    return INVALID_SOCKET;
  }

  int enable = 1;
  int status = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable));
  assert(status == 0);
  (void)status;

  struct sockaddr_in server;
  memset(&server, 0, sizeof(server));  // also clears sin_zero
  server.sin_family = AF_INET;
  server.sin_addr.s_addr = htonl(INADDR_ANY);
  server.sin_port = htons(port);

  if ((bind(sock, (struct sockaddr *)&server, sizeof(server)) != 0) ||
      (listen(sock, 10) != 0)) {
    // Read errno before close() can overwrite it.
    err = strerror(errno);
    close(sock);
    return INVALID_SOCKET;
  }
  set_nonblocking(sock);
  return sock;
}

// Accept every connection currently pending on the non-blocking
// `listen_socket` and return the new sockets, each made non-blocking.
SocketVector accept_on_socket(SOCKET listen_socket) {
  SocketVector result;
  while (true) {
    SOCKET chsock = accept(listen_socket, nullptr, nullptr);
    if (chsock >= 0) {
      set_nonblocking(chsock);
      result.push_back(chsock);
    } else if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) {
      // Normal completion - we're out of incoming sockets.
      return result;
    } else if ((errno == ECONNABORTED) || (errno == EINTR)) {
      // The remote disconnected before we had a chance to accept, or a signal
      // interrupted the call. Just pretend it never happened and keep going.
    } else {
      // Unexpected error (e.g. out of descriptors). Returning, rather than
      // retrying, keeps release builds from spinning on a persistent error.
      assert(false);
      return result;
    }
  }
}

// Create a TLS SSL_CTX with partial writes and moving write buffers enabled.
// With `root_certs`, the default CA paths are loaded and peers are verified;
// otherwise no verification is done.
// `server_cert` and `require_cert` are accepted but not implemented yet.
SSL_CTX *new_ssl_context(bool server_cert, bool root_certs, const std::string &require_cert) {
  (void)server_cert;   // Not implemented yet.
  (void)require_cert;  // Not implemented yet.
  SSL_CTX *ctx = SSL_CTX_new(TLS_method());
  assert(ctx != nullptr);
  SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
  SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE);
  SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, nullptr);
  if (root_certs) {
    SSL_CTX_set_default_verify_paths(ctx);
    SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, nullptr);
  }
  return ctx;
}

// Monotonic clock reporting seconds elapsed since the object was constructed.
class MonoClock {
private:
  struct timespec base_;

public:
  MonoClock() {
    int status = clock_gettime(CLOCK_MONOTONIC, &base_);
    assert(status == 0);
    (void)status;
  }

  double get() {
    struct timespec t;
    int status = clock_gettime(CLOCK_MONOTONIC, &t);
    assert(status == 0);
    (void)status;
    double tv_sec = t.tv_sec - base_.tv_sec;
    double tv_nsec = t.tv_nsec - base_.tv_nsec;
    return tv_sec + (tv_nsec * 1.0E-9);
  }
};

static MonoClock monoclock;

namespace util {
// Seconds since program start (static initialization of `monoclock`).
double profiling_clock() {
  return monoclock.get();
}
}  // namespace util

class Driver {
public:
  enum ChanState {
    CHAN_INACTIVE,
    CHAN_PLAINTEXT,
    CHAN_SSL_CONNECTING,
    CHAN_SSL_ACCEPTING,
    CHAN_SSL_READWRITE,
    CHAN_SSL_SHUTDOWN
  };

  // Per-socket channel state.
  struct ChanInfo {
    int chid;               // engine channel id (-1 once closed)
    SOCKET socket;
    SSL_CTX *ssl_ctx;       // owned; freed in close_channel
    SSL *ssl;               // owned; freed in close_channel
    ChanState state;
    int nbytes;             // outgoing bytes peeked this pass
    const char *bytes;      // pointer to those bytes (engine-owned)
    bool released;          // engine released the channel with nothing left to send
    bool just_released;     // `released` became true during this pass
    bool ready_now;         // advance on the next pass regardless of poll results
    bool ready_on_pollin;   // advance when the socket becomes readable
    bool ready_on_pollout;  // advance when the socket becomes writable
    bool ready_on_outgoing; // advance when writable AND there is data to send
  };

  DrivenEngine *driven_ = nullptr;
  std::vector<ChanInfo> chans_;
  bool any_inactive_ = false;
  std::map<int, SOCKET> listen_sockets_;  // port -> listening socket
  std::unique_ptr<char[]> chbuf;          // kChannelBufferSize recv scratch buffer
  SSL_CTX *ssl_ctx_with_root_certs_ = nullptr;
  SSL_CTX *ssl_ctx_with_server_certs_ = nullptr;
  SSL_CTX *ssl_ctx_with_no_certs_ = nullptr;

  // Open a listening socket for each port the engine asks for that is not
  // already open. Ports that cannot be opened are reported on stderr.
  void handle_listen_ports() {
    std::set<int> listenports;
    driven_->drv_get_listen_ports(listenports);
    for (int port : listenports) {
      if (listen_sockets_.find(port) == listen_sockets_.end()) {
        std::string err;
        SOCKET sock = listen_on_port(port, err);
        if (sock != INVALID_SOCKET) {
          listen_sockets_[port] = sock;
        } else {
          std::cerr << "cannot listen on port " << port << ": " << err << "\n";
        }
      }
    }
  }

  // Reload the Lua sources from the "lua" directory when the engine requests a rescan.
  void handle_lua_source() {
    if (driven_->drv_get_rescan_lua_source()) {
      driven_->drv_set_lua_source(util::read_lua_source("lua"));
    }
  }

  // Free the channel's SSL objects, close its socket, notify the engine with
  // `err` ("" for a normal close), and reset the slot to CHAN_INACTIVE.
  // The slot itself is removed later by cleanup_channels().
  void close_channel(ChanInfo &chan, const std::string &err) {
    assert(chan.state != CHAN_INACTIVE);

    // Close the SSL channel.
    if (chan.ssl != nullptr) {
      SSL_free(chan.ssl);
      chan.ssl = nullptr;
    }

    // Close the SSL_CTX.
    if (chan.ssl_ctx != nullptr) {
      SSL_CTX_free(chan.ssl_ctx);
      chan.ssl_ctx = nullptr;
    }

    // Close the socket. close() must stay outside assert(): under NDEBUG the
    // whole assert expression is compiled out, which would leak the fd.
    assert(chan.socket != INVALID_SOCKET);
    int status = close(chan.socket);
    assert(status == 0);
    (void)status;
    chan.socket = INVALID_SOCKET;

    // Close everything else.
    driven_->drv_notify_close(chan.chid, err);
    chan.state = CHAN_INACTIVE;
    chan.chid = -1;
    chan.nbytes = 0;
    chan.bytes = nullptr;
    chan.released = false;
    chan.just_released = false;
    chan.ready_now = false;
    chan.ready_on_pollin = false;
    chan.ready_on_pollout = false;
    chan.ready_on_outgoing = false;

    // Set global variables.
    any_inactive_ = true;
  }

  // Remove CHAN_INACTIVE slots by swapping in the last element (order not preserved).
  void cleanup_channels() {
    if (!any_inactive_) return;
    for (int i = 0; i < int(chans_.size());) {
      if (chans_[i].state == CHAN_INACTIVE) {
        chans_[i] = chans_.back();
        chans_.pop_back();
      } else {
        i += 1;
      }
    }
    any_inactive_ = false;
  }

  // Write all of channel 0's pending outgoing bytes to stdout.
  void handle_console_output() {
    while (true) {
      int nbytes;
      const char *bytes;
      driven_->drv_peek_outgoing(0, &nbytes, &bytes);
      if (nbytes == 0) break;
      ssize_t nwrote = write(1, bytes, nbytes);
      assert(nwrote > 0);
      // Never report a negative/zero count to the engine (possible under NDEBUG).
      if (nwrote <= 0) break;
      driven_->drv_sent_outgoing(0, int(nwrote));
    }
  }

  // Feed whatever is currently readable on stdin to channel 0.
  void handle_console_input() {
    char buffer[256];
    while (true) {
      ssize_t nread = read(0, buffer, sizeof(buffer));
      if (nread == 0) break;
      if (nread < 0) {
        // Nothing available right now, or interrupted by a signal.
        if ((errno == EAGAIN) || (errno == EWOULDBLOCK) || (errno == EINTR)) break;
        assert(false);
        break;
      }
      driven_->drv_recv_incoming(0, int(nread), buffer);
    }
  }

  // Build a fresh channel record; new channels start with ready_now set so
  // they are advanced on the next pass.
  ChanInfo make_channel(SOCKET sock, int chid, SSL_CTX *ctx, SSL *ssl, ChanState state) {
    ChanInfo newchan;
    newchan.chid = chid;
    newchan.socket = sock;
    newchan.ssl_ctx = ctx;
    newchan.ssl = ssl;
    newchan.state = state;
    newchan.nbytes = 0;
    newchan.bytes = nullptr;
    newchan.released = false;
    newchan.just_released = false;
    newchan.ready_now = true;
    newchan.ready_on_pollin = false;
    newchan.ready_on_pollout = false;
    newchan.ready_on_outgoing = false;
    return newchan;
  }

  // Start a connection for each new outgoing channel the engine created; on
  // failure the engine is notified immediately with the error text.
  void handle_new_outgoing_sockets() {
    std::set<int> chans;
    driven_->drv_get_new_outgoing(chans);
    for (int chid : chans) {
      std::string err;
      SOCKET sock = open_connection(driven_->drv_get_target(chid), err);
      if (sock == INVALID_SOCKET) {
        driven_->drv_notify_close(chid, err);
      } else {
        SSL_CTX *ctx = nullptr;
        SSL *ssl = SSL_new(ssl_ctx_with_no_certs_);
        chans_.push_back(make_channel(sock, chid, ctx, ssl, CHAN_PLAINTEXT));
      }
    }
  }

  // Accept all pending connections on `listen_sock` (bound to `port`), asking
  // the engine for a channel id for each one.
  void accept_connections(int port, SOCKET listen_sock) {
    SocketVector accepted = accept_on_socket(listen_sock);
    for (SOCKET sock : accepted) {
      int chid = driven_->drv_notify_accept(port);
      SSL_CTX *ctx = nullptr;
      SSL *ssl = SSL_new(ssl_ctx_with_server_certs_);
      chans_.push_back(make_channel(sock, chid, ctx, ssl, CHAN_PLAINTEXT));
    }
  }

  // One non-blocking step of a plaintext channel: close if released, send up
  // to kChannelBufferSize pending bytes, receive up to kChannelBufferSize bytes,
  // then wait for readability / writability-with-data next time.
  void advance_plaintext(ChanInfo &chan) {
    // `released` is only ever set while the outgoing buffer was empty (see
    // handle_socket_input_output), so the channel can be closed right away.
    if (chan.released) {
      close_channel(chan, "");
      return;
    }

    // Try to write plaintext to the channel.
    int nbytes;
    const char *bytes;
    driven_->drv_peek_outgoing(chan.chid, &nbytes, &bytes);
    if (nbytes > 0) {
      int sbytes = (nbytes > kChannelBufferSize) ? kChannelBufferSize : nbytes;
      ssize_t wbytes = send(chan.socket, bytes, sbytes, kSendFlags);
      if (wbytes < 0) {
        if ((errno != EWOULDBLOCK) && (errno != EAGAIN) && (errno != EINTR)) {
          close_channel(chan, "send failure");
          return;
        }
      } else {
        driven_->drv_sent_outgoing(chan.chid, int(wbytes));
      }
    }

    // Try to read plaintext from the channel.
    // Someday, find a way to avoid this copy.
    ssize_t nrecv = recv(chan.socket, chbuf.get(), kChannelBufferSize, 0);
    if (nrecv < 0) {
      if ((errno != EWOULDBLOCK) && (errno != EAGAIN) && (errno != EINTR)) {
        close_channel(chan, "recv failure");
        return;
      }
    } else if (nrecv == 0) {
      // Orderly shutdown by the peer.
      close_channel(chan, "");
      return;
    } else {
      driven_->drv_recv_incoming(chan.chid, int(nrecv), chbuf.get());
    }

    // Update the ready-flags for next time.
    chan.ready_on_outgoing = true;
    chan.ready_on_pollin = true;
  }

  // SSL states are not implemented yet; reaching them is a programming error.
  void advance_ssl_connecting(ChanInfo &chan) { (void)chan; assert(false); }
  void advance_ssl_accepting(ChanInfo &chan) { (void)chan; assert(false); }
  void advance_ssl_readwrite(ChanInfo &chan) { (void)chan; assert(false); }
  void advance_ssl_shutdown(ChanInfo &chan) { (void)chan; assert(false); }

  // Dispatch one step of the channel's state machine.
  void advance_channel(ChanInfo &chan) {
    switch (chan.state) {
    case CHAN_PLAINTEXT: advance_plaintext(chan); break;
    case CHAN_SSL_CONNECTING: advance_ssl_connecting(chan); break;
    case CHAN_SSL_ACCEPTING: advance_ssl_accepting(chan); break;
    case CHAN_SSL_READWRITE: advance_ssl_readwrite(chan); break;
    case CHAN_SSL_SHUTDOWN: advance_ssl_shutdown(chan); break;
    default: assert(false); break;
    }
  }

  // One poll() round: refresh outgoing/release state, poll listen sockets,
  // channels and stdin (up to 1 s, or 0 ms if a channel is ready now), accept
  // new connections, advance channels whose wake-up condition fired, and drop
  // closed channels.
  void handle_socket_input_output() {
    int mstimeout = 1000;

    // Peek output buffers and determine channel release flags.
    for (ChanInfo &chan : chans_) {
      driven_->drv_peek_outgoing(chan.chid, &chan.nbytes, &chan.bytes);
      chan.just_released = false;
      if ((chan.nbytes == 0) && (!chan.released)) {
        chan.released = driven_->drv_get_channel_released(chan.chid);
        chan.just_released = chan.released;
      }
    }

    // Construct the pollfd vector: listen sockets, then channels, then stdin.
    // Value-initialization leaves every revents at zero.
    PollVector pollvec(listen_sockets_.size() + chans_.size() + 1);
    size_t index = 0;
    for (const auto &p : listen_sockets_) {
      struct pollfd &pfd = pollvec[index++];
      pfd.fd = p.second;
      pfd.events = POLLIN;
    }
    for (const ChanInfo &chan : chans_) {
      struct pollfd &pfd = pollvec[index++];
      assert(chan.socket != INVALID_SOCKET);
      pfd.fd = chan.socket;
      pfd.events = POLLERR;
      if (chan.ready_now || chan.just_released) mstimeout = 0;
      if (chan.ready_on_pollin) pfd.events |= POLLIN;
      if (chan.ready_on_pollout) pfd.events |= POLLOUT;
      if (chan.ready_on_outgoing && (chan.nbytes > 0)) pfd.events |= POLLOUT;
    }
    struct pollfd &stdiopoll = pollvec[index++];
    stdiopoll.fd = 0;
    stdiopoll.events = POLLIN;

    // Only these channels have a pollfd entry; accept_connections() below
    // may append more.
    const size_t polled_chans = chans_.size();

    // Do the poll. A signal interruption (EINTR) is treated as "no events".
    int status = poll(pollvec.data(), pollvec.size(), mstimeout);
    assert((status >= 0) || (errno == EINTR));
    (void)status;

    // Check listening sockets.
    index = 0;
    for (const auto &p : listen_sockets_) {
      const struct pollfd &pfd = pollvec[index++];
      if (pfd.revents & (POLLIN | POLLERR)) {
        accept_connections(p.first, p.second);
      }
    }

    // Advance channels where possible. Newly accepted channels (index >=
    // polled_chans) have no pollfd; they start with ready_now set and are
    // advanced on the next round, which then polls with a zero timeout.
    for (size_t i = 0; i < polled_chans; ++i) {
      ChanInfo &chan = chans_[i];
      const struct pollfd &pfd = pollvec[index++];
      bool pollin = ((pfd.revents & POLLIN) != 0);
      bool pollout = ((pfd.revents & POLLOUT) != 0);
      bool pollerr = ((pfd.revents & POLLERR) != 0);
      if (chan.ready_now || pollerr || chan.just_released ||
          (chan.ready_on_pollin && pollin) ||
          (chan.ready_on_pollout && pollout) ||
          (chan.ready_on_outgoing && (chan.nbytes > 0) && pollout)) {
        chan.ready_now = false;
        chan.ready_on_pollin = false;
        chan.ready_on_pollout = false;
        chan.ready_on_outgoing = false;
        advance_channel(chan);
        chan.nbytes = 0;
        chan.bytes = nullptr;
      }
    }

    // Delete any newly-inactive channels.
    cleanup_channels();
  }

  // Run the engine until drv_get_stop_driver() returns true, then close every
  // channel and listening socket and release the SSL contexts.
  void drive(DrivenEngine *de, int argc, char *argv[]) {
    enableRawMode();
    driven_ = de;
    any_inactive_ = false;
    chbuf.reset(new char[kChannelBufferSize]);
    ssl_ctx_with_root_certs_ = new_ssl_context(false, true, "");
    ssl_ctx_with_server_certs_ = new_ssl_context(true, false, "");
    ssl_ctx_with_no_certs_ = new_ssl_context(false, false, "");
    DrivenEngine::set(de);
    driven_->drv_set_lua_source(util::read_lua_source("lua"));
    driven_->drv_invoke_event_init(argc, argv);
    handle_listen_ports();

    while (!de->drv_get_stop_driver()) {
      handle_lua_source();
      handle_console_output();
      handle_new_outgoing_sockets();
      handle_socket_input_output();
      handle_console_input();
      handle_console_output();
      de->drv_invoke_event_update(monoclock.get());
    }

    for (ChanInfo &chan : chans_) {
      close_channel(chan, "");
    }
    chans_.clear();
    any_inactive_ = false;

    for (const auto &p : listen_sockets_) {
      close(p.second);
    }
    listen_sockets_.clear();

    SSL_CTX_free(ssl_ctx_with_no_certs_);
    SSL_CTX_free(ssl_ctx_with_root_certs_);
    SSL_CTX_free(ssl_ctx_with_server_certs_);
    ssl_ctx_with_no_certs_ = nullptr;
    ssl_ctx_with_root_certs_ = nullptr;
    ssl_ctx_with_server_certs_ = nullptr;
    DrivenEngine::set(nullptr);
  }
};

// Public entry point: run a Driver event loop for `de` until it asks to stop.
void driver_drive(DrivenEngine *de, int argc, char *argv[]) {
  Driver driver;
  driver.drive(de, argc, argv);
}