// Linux back end for the driver: TCP sockets, raw-mode console I/O, file
// reading, ASLR control and a monotonic clock.  The platform-independent
// part is pulled in from driver-common.cpp at the end of this file.

#include "driver.hpp"
#include "umm-malloc.hpp"
#include "driver-util.hpp"
#include "util.hpp"
#include "drivenengine.hpp"
#include "dummycert.hpp"

// NOTE(review): the names of the system headers that followed the project
// headers were lost in this copy (the <...> contents were stripped).  The
// list below is reconstructed from what this file itself uses;
// driver-common.cpp may depend on further headers from the original list
// -- confirm against version control.
#include <cassert>
#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <string_view>

#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <poll.h>
#include <sys/personality.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <termios.h>
#include <time.h>
#include <unistd.h>

using SOCKET = int;
const int INVALID_SOCKET = -1;

// Terminal attributes saved by enable_tty_raw() and restored at exit.
struct termios orig_termios;

// Returns the message text for error number `err`.
static UmmString strerror_str(int err) {
    char errbuf[256];
    // The GNU variant of strerror_r (selected by g++'s default _GNU_SOURCE)
    // returns a pointer to the message, which may or may not be errbuf.
    return strerror_r(err, errbuf, sizeof(errbuf));
}

// Adds O_NONBLOCK to the status flags of `fd`, keeping its other flags.
// Failure is treated as a programming error and only checked by assert().
void set_nonblocking(int fd) {
    int flags = fcntl(fd, F_GETFL, 0);
    assert(flags != -1);
    int status = fcntl(fd, F_SETFL, flags | O_NONBLOCK);
    assert(status != -1);
    (void)status;  // unused when NDEBUG is defined
}

// Restores the terminal attributes saved by enable_tty_raw().
static void disable_tty_raw() {
    tcsetattr(0, TCSAFLUSH, &orig_termios);
}

// Switches stdin's terminal to raw-ish mode: no echo, no canonical line
// editing, no CR->NL translation, parity stripping or XON/XOFF, and reads
// that return immediately (VMIN = VTIME = 0).  Output post-processing stays
// enabled.  The saved settings are restored at process exit via atexit().
static void enable_tty_raw() {
    int status = tcgetattr(0, &orig_termios);
    assert(status >= 0);
    atexit(disable_tty_raw);
    struct termios raw = orig_termios;
    raw.c_iflag &= ~(BRKINT | ICRNL | INPCK | ISTRIP | IXON);
    raw.c_lflag &= ~(ECHO | ICANON);
    raw.c_oflag |= OPOST;
    raw.c_cc[VMIN] = 0;
    raw.c_cc[VTIME] = 0;
    status = tcsetattr(0, TCSAFLUSH, &raw);
    assert(status >= 0);
    (void)status;  // unused when NDEBUG is defined
}

// Starts a non-blocking TCP connection to `target`, whose host and port
// are separated by drv::split_host_port.  Only the first IPv4 address
// returned by getaddrinfo is tried.  Returns the socket (the handshake may
// still be in progress) or INVALID_SOCKET with `err` set.
static SOCKET open_connection(std::string_view target, UmmString &err) {
    struct addrinfo *addrs = nullptr;
    struct addrinfo *goodaddr = nullptr;
    struct addrinfo hints;
    SOCKET sock = INVALID_SOCKET;

    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_protocol = IPPROTO_TCP;
    hints.ai_flags = AI_NUMERICSERV;  // the port must be numeric
    err.clear();

    UmmString host, port;
    drv::split_host_port(target, host, port);
    int status = getaddrinfo(host.c_str(), port.c_str(), &hints, &addrs);
    if (status != 0) {
        err = gai_strerror(status);
        goto error_general;
    }
    if (addrs == nullptr) {
        err = "no such host found";
        goto error_general;
    }
    goodaddr = addrs;
    assert(goodaddr->ai_family == AF_INET);
    assert(goodaddr->ai_socktype == SOCK_STREAM);
    assert(goodaddr->ai_protocol == IPPROTO_TCP);

    sock = socket(goodaddr->ai_family, goodaddr->ai_socktype, goodaddr->ai_protocol);
    // socket() reports failure with -1; descriptor 0 is valid.
    if (sock == INVALID_SOCKET) goto error_errno;
    set_nonblocking(sock);
    status = connect(sock, goodaddr->ai_addr, goodaddr->ai_addrlen);
    // On a non-blocking socket EINPROGRESS only means the handshake has not
    // finished yet; it is not a failure.
    if ((status != 0) && (errno != EINPROGRESS)) goto error_errno;
    freeaddrinfo(addrs);
    return sock;

error_errno:
    err = strerror_str(errno);
error_general:
    if (sock != INVALID_SOCKET) close(sock);
    if (addrs != nullptr) freeaddrinfo(addrs);
    return INVALID_SOCKET;
}

// Opens a non-blocking TCP socket listening on `port` on every IPv4
// interface.  Returns the socket, or INVALID_SOCKET with `err` set.
static SOCKET listen_on_port(int port, UmmString &err) {
    int status;
    int enable = 1;
    struct sockaddr_in server;  // no initializer: the gotos below may not skip one
    err.clear();

    SOCKET sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock == INVALID_SOCKET) goto error_errno;
    // Allow re-binding the port while old connections linger in TIME_WAIT.
    status = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable));
    if (status != 0) goto error_errno;

    memset(&server, 0, sizeof(server));  // also clears sin_zero
    server.sin_family = AF_INET;
    server.sin_addr.s_addr = htonl(INADDR_ANY);
    server.sin_port = htons(port);
    status = bind(sock, reinterpret_cast<struct sockaddr *>(&server), sizeof(server));
    if (status != 0) goto error_errno;
    status = listen(sock, 10);
    if (status != 0) goto error_errno;
    set_nonblocking(sock);
    return sock;

error_errno:
    err = strerror_str(errno);
    if (sock != INVALID_SOCKET) close(sock);
    return INVALID_SOCKET;
}

// Accepts one pending connection on `listen_socket` and makes it
// non-blocking.  Returns INVALID_SOCKET with `err` empty when nothing is
// ready (EAGAIN/EWOULDBLOCK) or the connection was aborted first
// (ECONNABORTED), and INVALID_SOCKET with `err` set on any other error.
static SOCKET accept_on_socket(SOCKET listen_socket, UmmString &err) {
    err.clear();
    SOCKET chsock = accept(listen_socket, nullptr, nullptr);
    if (chsock >= 0) {
        set_nonblocking(chsock);
        return chsock;
    }
    if ((errno != EAGAIN) && (errno != EWOULDBLOCK) && (errno != ECONNABORTED)) {
        err = strerror_str(errno);
    }
    return INVALID_SOCKET;
}

// The return values for socket_send and socket_recv are:
//
//   positive: number of bytes sent or received
//   zero:     the operation would block; try again later
//   negative: channel closed, either cleanly (err empty) or with an
//             error (err set)
//
static int socket_send(SOCKET socket, const char *bytes, int nbytes, UmmString &err) {
    err.clear();
    // MSG_NOSIGNAL: if the peer has gone away, fail with EPIPE instead of
    // raising SIGPIPE, whose default action would terminate the process.
    int wbytes = send(socket, bytes, nbytes, MSG_NOSIGNAL);
    if (wbytes >= 0) return wbytes;
    if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) return 0;
    err = strerror_str(errno);
    return -1;
}

static int socket_recv(SOCKET socket, char *bytes, int nbytes, UmmString &err) {
    err.clear();
    int nrecv = recv(socket, bytes, nbytes, 0);
    if (nrecv > 0) return nrecv;
    if (nrecv == 0) return -1;  // orderly shutdown by the peer
    // Would-block is the only non-fatal failure; everything else closes.
    if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) return 0;
    err = strerror_str(errno);
    return -1;
}

// Closes `socket`; returns close()'s result.
static int socket_close(SOCKET socket) {
    return close(socket);
}

// Polls the first `pollcount` entries of `pollvec` (and stdin) for up to
// `mstimeout` milliseconds.  Returns 0 on success, with results in each
// entry's revents, or -1 with `err` set.
static int socket_poll(struct pollfd *pollvec, int pollcount, int mstimeout, UmmString &err) {
    err.clear();
    // socket_poll is implicitly expected to also poll stdin, if the OS
    // allows that.  Linux does, so stdin is appended to the poll vector,
    // which therefore must have at least one free slot past `pollcount`.
    pollvec[pollcount].fd = 0;
    pollvec[pollcount].events = POLLIN;
    pollcount += 1;

    int status = poll(pollvec, pollcount, mstimeout);
    if (status < 0) {
        err = strerror_str(errno);
        return -1;
    }
    return 0;
}

// Writes to stdout; returns write()'s result.
static int console_write(const char *bytes, int nbytes) {
    return write(1, bytes, nbytes);
}

// Reads from stdin; returns read()'s result.
static int console_read(char *bytes, int nbytes) {
    return read(0, bytes, nbytes);
}

// Reads the whole of file `fn` into `buf` (capacity `bufsize`) and
// NUL-terminates it.  Returns a view of the contents with `err` empty.  On
// failure -- including a file that needs bufsize bytes or more, leaving no
// room for the terminator -- returns an empty view, sets buf[0] = 0 and
// sets `err`.  The descriptor is always closed.
static std::string_view read_file(const char *fn, char *buf, int bufsize, UmmString &err) {
    int fd = open(fn, O_RDONLY);
    if (fd < 0) {
        err = strerror_str(errno);
        buf[0] = 0;
        return std::string_view(buf, 0);
    }

    int total = 0;
    while (total < bufsize) {
        ssize_t n = read(fd, buf + total, bufsize - total);
        if (n < 0) {
            if (errno == EINTR) continue;
            int saved_errno = errno;  // close() may overwrite errno
            close(fd);
            err = strerror_str(saved_errno);
            buf[0] = 0;
            return std::string_view(buf, 0);
        }
        if (n == 0) break;  // end of file
        total += static_cast<int>(n);
    }
    close(fd);

    if (total == bufsize) {
        err = "file too large";
        buf[0] = 0;
        return std::string_view(buf, 0);
    }
    buf[total] = 0;
    err = "";
    return std::string_view(buf, total);
}

// Ensures the process runs with address-space randomization disabled.  If
// the ADDR_NO_RANDOMIZE persona flag is not yet set, it is added (keeping
// the other persona flags) and the program re-executes itself so that the
// new layout takes effect.  On any failure the process simply continues
// with randomization enabled.
static void disable_randomization(int argc, char *argv[]) {
    (void)argc;
    // personality(0xffffffff) queries the persona without changing it.
    const int current = personality(0xffffffff);
    if (current == -1) return;
    if (current & ADDR_NO_RANDOMIZE) return;  // already set, e.g. after our own re-exec
    if (personality(current | ADDR_NO_RANDOMIZE) == -1) return;

    // Only re-exec once the flag is confirmed, so this can never loop.
    const int updated = personality(0xffffffff);
    if ((updated != -1) && (updated & ADDR_NO_RANDOMIZE)) {
        // /proc/self/exe works even when argv[0] is a bare name found via
        // PATH, which execv() would not search.
        execv("/proc/self/exe", argv);
        execv(argv[0], argv);
        // execv only returns on failure: carry on with randomization.
    }
}

// One-time platform initialization.
void driver_sysinit(int argc, char *argv[]) {
    disable_randomization(argc, argv);
    enable_tty_raw();
}

// Monotonic stopwatch: get() returns the seconds elapsed since the clock
// was constructed.
class MonoClock {
private:
    struct timespec base_;

public:
    MonoClock() {
        int status = clock_gettime(CLOCK_MONOTONIC, &base_);
        assert(status == 0);
        (void)status;  // unused when NDEBUG is defined
    }

    double get() const {
        struct timespec t;
        int status = clock_gettime(CLOCK_MONOTONIC, &t);
        assert(status == 0);
        (void)status;  // unused when NDEBUG is defined
        double tv_sec = t.tv_sec - base_.tv_sec;
        // May be negative when the nanosecond field wrapped; the sum is
        // still correct.
        double tv_nsec = t.tv_nsec - base_.tv_nsec;
        return tv_sec + (tv_nsec * 1.0E-9);
    }
};

#include "driver-common.cpp"