#define WINVER 0x0600 #define _WIN32_WINNT 0x0600 #include "wrap-map.hpp" #include "wrap-string.hpp" #include "wrap-vector.hpp" #include "driver-util.hpp" #include "drivenengine.hpp" #include "dummycert.hpp" #include "util.hpp" #include "source.hpp" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include static void set_nonblocking(SOCKET sock) { u_long mode = 1; // 1 to enable non-blocking socket int status = ioctlsocket(sock, FIONBIO, &mode); assert(status == 0); } static std::string strerror_str(int errcode) { std::ostringstream oss; oss << "error " << errcode; return oss.str(); } static PADDRINFOA find_good_addr(PADDRINFOA addrinfo) { for (PADDRINFOA addr = addrinfo; addr != nullptr; addr = addr->ai_next) { if (addr->ai_family == AF_INET) { return addr; } } return nullptr; } static SOCKET open_connection(const char *host, const char *port, std::string &err) { PADDRINFOA addrs = nullptr; PADDRINFOA goodaddr = nullptr; SOCKET sock = INVALID_SOCKET; err.clear(); int status = getaddrinfo(host, port, nullptr, &addrs); while (status == WSATRY_AGAIN) { status = getaddrinfo(host, port, nullptr, &addrs); } if (status == WSAHOST_NOT_FOUND) { err = "host not found"; goto error; } if (status != 0) { err = "DNS resolution malfunction"; goto error; } goodaddr = find_good_addr(addrs); if (goodaddr == nullptr) { err = "host not an internet host"; goto error; } sock = socket(goodaddr->ai_family, SOCK_STREAM, IPPROTO_TCP); if (sock == INVALID_SOCKET) { err = "could not create a socket"; goto error; } set_nonblocking(sock); status = connect(sock, goodaddr->ai_addr, goodaddr->ai_addrlen); if (status != 0) { int errcode = WSAGetLastError(); if (errcode != WSAEWOULDBLOCK) { err = "connect failure"; goto error; } } freeaddrinfo(addrs); return sock; error: if (sock != INVALID_SOCKET) closesocket(sock); if (addrs != nullptr) freeaddrinfo(addrs); return SOCKET_ERROR; } SOCKET listen_on_port(int port, std::string &err) { int status; err.clear(); SOCKET sock = socket(AF_INET, SOCK_STREAM, 0); if (sock == INVALID_SOCKET) { err = "could not create a socket"; goto error; } struct sockaddr_in server; server.sin_family = AF_INET; server.sin_addr.s_addr = INADDR_ANY; server.sin_port = htons(port); status = bind(sock, (struct sockaddr *)&server, sizeof(server)); if (status < 0) { err = "could not bind port"; goto error; } status = listen(sock, 10); if (status < 0) { err = "could not listen on socket"; goto error; } set_nonblocking(sock); std::cerr << "listening socket is " << sock << std::endl; return sock; error: if (sock != INVALID_SOCKET) closesocket(sock); return SOCKET_ERROR; } static SOCKET accept_on_socket(SOCKET listen_socket, std::string &err) { SOCKET chsock = accept(listen_socket, nullptr, nullptr); if (chsock != INVALID_SOCKET) { set_nonblocking(chsock); return chsock; } else { int errcode = WSAGetLastError(); if ((errcode == WSAEWOULDBLOCK) || (errcode == WSAECONNRESET)) { return INVALID_SOCKET; } else { err = "accept failed"; return INVALID_SOCKET; } } } static int socket_send(SOCKET socket, const char *bytes, int nbytes, std::string &err) { err.clear(); int wbytes = send(socket, bytes, nbytes, 0); if (wbytes == SOCKET_ERROR) { int errcode = WSAGetLastError(); if (errcode == WSAEWOULDBLOCK) { return 0; } else { err = "send failure"; return -1; } } else { assert(wbytes > 0); return wbytes; } } static int socket_recv(SOCKET socket, char *bytes, int nbytes, std::string &err) { err.clear(); int nrecv = recv(socket, bytes, nbytes, 0); if (nrecv < 0) { int errcode = WSAGetLastError(); if (errcode == WSAEWOULDBLOCK) { return 0; } else { err = "recv failure"; return -1; } } else if (nrecv == 0) { return -1; } else { return nrecv; } } static int socket_close(SOCKET socket) { return closesocket(socket); } static int socket_poll(struct pollfd *pollvec, int pollcount, int mstimeout, std::string &err) { if (pollcount == 0) { if (mstimeout > 0) Sleep(mstimeout); return 0; } int status = WSAPoll(pollvec, pollcount, mstimeout); if (status < 0) { err = strerror_str(WSAGetLastError()); return -1; } return status; } static void init_winsock() { WSADATA data; int errcode = WSAStartup(2, &data); if (errcode != 0) { fprintf(stderr, "Winsock didn't initalize, error %d", errcode); exit(1); } } static int console_write(const char *bytes, int nbytes) { if (nbytes == 0) return 0; HANDLE hstdout = GetStdHandle(STD_OUTPUT_HANDLE); assert(hstdout != INVALID_HANDLE_VALUE); DWORD nwrote; if (nbytes > 10000) nbytes = 10000; assert(WriteConsoleA(hstdout, bytes, nbytes, &nwrote, nullptr)); assert(nwrote > 0); return nwrote; } static int console_read(char *bytes, int nbytes) { HANDLE hstdin = GetStdHandle(STD_INPUT_HANDLE); assert(hstdin != INVALID_HANDLE_VALUE); INPUT_RECORD inrecords[512]; DWORD nread, nevents; int nascii = 0; if (GetNumberOfConsoleInputEvents(hstdin, &nevents)) { if (int(nevents) > nbytes) nevents = nbytes; ReadConsoleInputA(hstdin, inrecords, nevents, &nread); for (int i = 0; i < int(nread); i++) { const INPUT_RECORD &inr = inrecords[i]; if (inr.EventType != KEY_EVENT) continue; const KEY_EVENT_RECORD &key = inr.Event.KeyEvent; if (!key.bKeyDown) continue; char c = key.uChar.AsciiChar; bytes[nascii++] = c; } return nascii; } else { return 0; } } static void load_root_certs(SSL_CTX *ctx) { HCERTSTORE hStore = CertOpenSystemStoreW(0, L"ROOT"); PCCERT_CONTEXT pContext = NULL; X509 *x509; X509_STORE *store = SSL_CTX_get_cert_store(ctx); if (!hStore) { fprintf(stderr, "Cannot open system certificate store.\n"); exit(1); } while ((pContext = CertEnumCertificatesInStore(hStore, pContext))) { const unsigned char *encoded_cert = pContext->pbCertEncoded; x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); if (x509) { X509_STORE_add_cert(store, x509); X509_free(x509); } } CertCloseStore(hStore, 0); } static void ssl_ctx_use_dummycert(SSL_CTX *ctx); static SSL_CTX *new_ssl_server_context() { SSL_CTX *ctx = SSL_CTX_new(TLS_method()); SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE); SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, nullptr); ssl_ctx_use_dummycert(ctx); return ctx; } static SSL_CTX *new_ssl_client_context(int verify) { SSL_CTX *ctx = SSL_CTX_new(TLS_method()); SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE); if (verify == SSL_VERIFY_PEER) load_root_certs(ctx); SSL_CTX_set_verify(ctx, verify, nullptr); return ctx; } class MonoClock { public: double freq_; MonoClock() { LARGE_INTEGER x; BOOL status = QueryPerformanceFrequency(&x); assert(status != 0); freq_ = 1.0 / double(x.QuadPart); } double get() { LARGE_INTEGER x; BOOL status = QueryPerformanceCounter(&x); assert(status != 0); return double(x.QuadPart) * freq_; } }; #include "driver-common.cpp" int main(int argc, char **argv) { init_winsock(); OPENSSL_init_ssl(0, NULL); SourceDB::register_lua_builtins(); Driver driver; return driver.drive(argc, argv); }