#define WINVER 0x0600 #define _WIN32_WINNT 0x0600 #include "driver.hpp" #include "util.hpp" #include "drivenengine.hpp" #include "dummycert.hpp" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include using PollVector = std::vector; static void set_nonblocking(SOCKET sock) { u_long mode = 1; // 1 to enable non-blocking socket int status = ioctlsocket(sock, FIONBIO, &mode); assert(status == 0); } static void enable_tty_raw() { // Do nothing on windows. } static PADDRINFOA find_good_addr(PADDRINFOA addrinfo) { for (PADDRINFOA addr = addrinfo; addr != nullptr; addr = addr->ai_next) { if (addr->ai_family == AF_INET) { return addr; } } return nullptr; } static SOCKET open_connection(const std::string &target, std::string &err) { PADDRINFOA addrs = nullptr; PADDRINFOA goodaddr = nullptr; SOCKET sock = INVALID_SOCKET; std::string host, port; err.clear(); util::split_host_port(target, host, port); int status = getaddrinfo(host.c_str(), port.c_str(), nullptr, &addrs); while (status == WSATRY_AGAIN) { status = getaddrinfo(host.c_str(), port.c_str(), nullptr, &addrs); } if (status == WSAHOST_NOT_FOUND) { err = "host not found"; goto error; } if (status != 0) { err = "DNS resolution malfunction"; goto error; } goodaddr = find_good_addr(addrs); if (goodaddr == nullptr) { err = "host not an internet host"; goto error; } sock = socket(goodaddr->ai_family, SOCK_STREAM, IPPROTO_TCP); if (sock == INVALID_SOCKET) { err = "could not create a socket"; goto error; } set_nonblocking(sock); status = connect(sock, goodaddr->ai_addr, goodaddr->ai_addrlen); if (status != 0) { int errcode = WSAGetLastError(); if (errcode != WSAEWOULDBLOCK) { err = "connect failure"; goto error; } } freeaddrinfo(addrs); return sock; error: if (sock != INVALID_SOCKET) closesocket(sock); if (addrs != nullptr) freeaddrinfo(addrs); return SOCKET_ERROR; } SOCKET listen_on_port(int port, std::string &err) { int status; err.clear(); SOCKET sock = socket(AF_INET, SOCK_STREAM, 0); if (sock == INVALID_SOCKET) { err = "could not create a socket"; goto error; } struct sockaddr_in server; server.sin_family = AF_INET; server.sin_addr.s_addr = INADDR_ANY; server.sin_port = htons(port); status = bind(sock, (struct sockaddr *)&server, sizeof(server)); if (status < 0) { err = "could not bind port"; goto error; } status = listen(sock, 10); if (status < 0) { err = "could not listen on socket"; goto error; } set_nonblocking(sock); return sock; error: if (sock != INVALID_SOCKET) closesocket(sock); return SOCKET_ERROR; } static SOCKET accept_on_socket(SOCKET listen_socket, std::string &err) { SOCKET chsock = accept(listen_socket, nullptr, nullptr); if (chsock != INVALID_SOCKET) { set_nonblocking(chsock); return chsock; } else { int errcode = WSAGetLastError(); if ((errcode == WSAEWOULDBLOCK) || (errcode == WSAECONNRESET)) { return INVALID_SOCKET; } else { err = "accept failed"; return INVALID_SOCKET; } } } static int socket_send(SOCKET socket, const char *bytes, int nbytes, std::string &err) { err.clear(); int wbytes = send(socket, bytes, nbytes, 0); if (wbytes == SOCKET_ERROR) { int errcode = WSAGetLastError(); if (errcode == WSAEWOULDBLOCK) { return 0; } else { err = "send failure"; return -1; } } else { assert(wbytes > 0); return wbytes; } } static int socket_recv(SOCKET socket, char *bytes, int nbytes, std::string &err) { err.clear(); int nrecv = recv(socket, bytes, nbytes, 0); if (nrecv < 0) { int errcode = WSAGetLastError(); if (errcode == WSAEWOULDBLOCK) { return 0; } else { err = "recv failure"; return -1; } } else if (nrecv == 0) { return -1; } else { return nrecv; } } static int socket_close(SOCKET socket) { return closesocket(socket); } static int socket_poll(PollVector &pollvec, int mstimeout, std::string &err) { int status = WSAPoll(&pollvec[0], pollvec.size(), mstimeout); if (status < 0) { WSAGetLastError(); err = "poll failed"; return -1; } return status; } static int console_write(const char *bytes, int nbytes) { if (nbytes == 0) return 0; HANDLE hstdout = GetStdHandle(STD_OUTPUT_HANDLE); assert(hstdout != INVALID_HANDLE_VALUE); DWORD nwrote; if (nbytes > 10000) nbytes = 10000; assert(WriteConsoleA(hstdout, bytes, nbytes, &nwrote, nullptr)); assert(nwrote > 0); return nwrote; } static int console_read(char *bytes, int nbytes) { HANDLE hstdin = GetStdHandle(STD_INPUT_HANDLE); assert(hstdin != INVALID_HANDLE_VALUE); INPUT_RECORD inrecords[512]; DWORD nread, nevents; int nascii = 0; if (GetNumberOfConsoleInputEvents(hstdin, &nevents)) { if (int(nevents) > nbytes) nevents = nbytes; ReadConsoleInputA(hstdin, inrecords, nevents, &nread); for (int i = 0; i < int(nread); i++) { const INPUT_RECORD &inr = inrecords[i]; if (inr.EventType != KEY_EVENT) continue; const KEY_EVENT_RECORD &key = inr.Event.KeyEvent; if (!key.bKeyDown) continue; char c = key.uChar.AsciiChar; bytes[nascii++] = c; } return nascii; } else { return 0; } } // The last element in the vector is supposed to be // for polling stdio. But on windows, you can't poll // stdio, so on windows, we remove the last element from // the vector and we reduce mstimeout instead. static void fill_stdio_pollfd(PollVector &pollvec, int &mstimeout, bool read_console_recently) { pollvec.pop_back(); if (mstimeout > 100) mstimeout = 100; } class MonoClock { public: double freq_; MonoClock() { LARGE_INTEGER x; BOOL status = QueryPerformanceFrequency(&x); assert(status != 0); freq_ = 1.0 / double(x.QuadPart); } double get() { LARGE_INTEGER x; BOOL status = QueryPerformanceCounter(&x); assert(status != 0); return double(x.QuadPart) * freq_; } }; #include "driver-common.cpp"