Files
integration/luprex/core/cpp/driver-mingw.cpp

278 lines
7.4 KiB
C++

#define WINVER 0x0600
#define _WIN32_WINNT 0x0600
#include "driver.hpp"
#include "util.hpp"
#include "drivenengine.hpp"
#include "dummycert.hpp"
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>
#include <winsock2.h>
#include <ws2tcpip.h>
#include <synchapi.h>
#include <sysinfoapi.h>
#include <windows.h>
#include <openssl/ssl.h>
#include <openssl/rsa.h>
#include <openssl/x509.h>
#include <openssl/evp.h>
#include <openssl/err.h>
#include <openssl/bio.h>
#include <openssl/pem.h>
// Vector of poll descriptors handed to socket_poll() (WSAPoll on this
// platform). The last element is reserved for stdio; on Windows it is
// removed again by fill_stdio_pollfd() because stdio cannot be polled here.
using PollVector = std::vector<struct pollfd>;
// Put `sock` into non-blocking mode with FIONBIO. A failure here is treated
// as a programming error and is only checked in debug builds.
static void set_nonblocking(SOCKET sock) {
    u_long nonblocking = 1;  // any non-zero value enables non-blocking I/O
    const int rc = ioctlsocket(sock, FIONBIO, &nonblocking);
    assert(rc == 0);
    (void)rc;  // keep NDEBUG builds free of unused-variable warnings
}
// Switching the terminal into raw mode has no equivalent here: console
// input is read as individual input records by console_read(), so this is
// intentionally a no-op on Windows.
static void enable_tty_raw() {}
// Format a Winsock error code (as returned by WSAGetLastError()) for use in
// an error message, e.g. 10054 -> "error 10054".
static std::string winsock_error_string(int errcode) {
    // std::to_string formats the integer directly; building an
    // std::ostringstream (whose <sstream> header was never included) just to
    // print one number is unnecessary.
    return "error " + std::to_string(errcode);
}
// Walk the getaddrinfo() result list starting at `addrinfo` and return the
// first IPv4 (AF_INET) entry, or nullptr if the list contains none.
static PADDRINFOA find_good_addr(PADDRINFOA addrinfo) {
    PADDRINFOA cur = addrinfo;
    while (cur != nullptr && cur->ai_family != AF_INET) {
        cur = cur->ai_next;
    }
    return cur;
}
// Start a non-blocking TCP connection to `target`, which is split into host
// and port by util::split_host_port(). Only IPv4 (AF_INET) addresses are
// used; the first one in the getaddrinfo() result list is tried.
//
// Returns the new socket on success. Because the socket is non-blocking, the
// connect may still be in progress (WSAEWOULDBLOCK) when this returns, so
// completion has to be detected later. On failure returns INVALID_SOCKET and
// sets `err` to a short description. `err` is cleared on entry.
static SOCKET open_connection(const std::string &target, std::string &err) {
    // getaddrinfo() reports WSATRY_AGAIN for transient resolver failures.
    // Retry a bounded number of times with a short pause instead of spinning
    // forever on a resolver that never recovers.
    const int kMaxDnsAttempts = 5;
    const DWORD kDnsRetryDelayMs = 50;
    PADDRINFOA addrs = nullptr;
    PADDRINFOA goodaddr = nullptr;
    SOCKET sock = INVALID_SOCKET;
    std::string host, port;
    int status = 0;
    err.clear();
    util::split_host_port(target, host, port);
    for (int attempt = 1; ; attempt++) {
        addrs = nullptr;
        status = getaddrinfo(host.c_str(), port.c_str(), nullptr, &addrs);
        if (status != WSATRY_AGAIN || attempt >= kMaxDnsAttempts) break;
        Sleep(kDnsRetryDelayMs);
    }
    if (status != 0) {
        // A failed lookup owns no result list; never free whatever was left
        // behind in `addrs`.
        addrs = nullptr;
        if (status == WSAHOST_NOT_FOUND) {
            err = "host not found";
        } else if (status == WSATRY_AGAIN) {
            err = "DNS temporarily unavailable";
        } else {
            err = "DNS resolution malfunction";
        }
        goto error;
    }
    goodaddr = find_good_addr(addrs);
    if (goodaddr == nullptr) {
        err = "host not an internet host";
        goto error;
    }
    sock = socket(goodaddr->ai_family, SOCK_STREAM, IPPROTO_TCP);
    if (sock == INVALID_SOCKET) {
        err = "could not create a socket";
        goto error;
    }
    set_nonblocking(sock);
    status = connect(sock, goodaddr->ai_addr, static_cast<int>(goodaddr->ai_addrlen));
    // For a non-blocking socket, WSAEWOULDBLOCK means "connect in progress",
    // which is the normal outcome, not a failure.
    if (status != 0 && WSAGetLastError() != WSAEWOULDBLOCK) {
        err = "connect failure";
        goto error;
    }
    freeaddrinfo(addrs);
    return sock;
error:
    if (sock != INVALID_SOCKET) closesocket(sock);
    if (addrs != nullptr) freeaddrinfo(addrs);
    // Same value the old `return SOCKET_ERROR` produced after conversion to
    // SOCKET, but stated with the type's own sentinel.
    return INVALID_SOCKET;
}
// Create a non-blocking IPv4 TCP socket bound to `port` on all local
// interfaces (INADDR_ANY) and listening with a backlog of 10.
//
// Returns the listening socket, or INVALID_SOCKET with `err` set to a short
// description on failure. `err` is cleared on entry.
SOCKET listen_on_port(int port, std::string &err) {
    int status = 0;
    struct sockaddr_in server;
    SOCKET sock = INVALID_SOCKET;
    err.clear();
    // htons() would silently truncate an out-of-range value to 16 bits and
    // bind some unrelated port.
    if (port < 0 || port > 65535) {
        err = "invalid port number";
        goto error;
    }
    sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock == INVALID_SOCKET) {
        err = "could not create a socket";
        goto error;
    }
    // Zero the whole address so sin_zero is not left holding stack garbage.
    memset(&server, 0, sizeof(server));
    server.sin_family = AF_INET;
    server.sin_addr.s_addr = htonl(INADDR_ANY);
    server.sin_port = htons(static_cast<u_short>(port));
    status = bind(sock, reinterpret_cast<struct sockaddr *>(&server), sizeof(server));
    if (status == SOCKET_ERROR) {
        err = "could not bind port";
        goto error;
    }
    status = listen(sock, 10);
    if (status == SOCKET_ERROR) {
        err = "could not listen on socket";
        goto error;
    }
    set_nonblocking(sock);
    std::cerr << "listening socket is " << sock << '\n';
    return sock;
error:
    if (sock != INVALID_SOCKET) closesocket(sock);
    return INVALID_SOCKET;
}
// Accept one pending connection on the non-blocking `listen_socket`.
//
// Returns the accepted socket (switched to non-blocking mode), or
// INVALID_SOCKET when nothing was accepted. `err` is cleared on entry, like
// the other socket helpers, and stays empty when there was simply nothing
// to accept (WSAEWOULDBLOCK) or the pending connection was reset by the peer
// before it could be accepted (WSAECONNRESET). Any other failure sets `err`.
static SOCKET accept_on_socket(SOCKET listen_socket, std::string &err) {
    err.clear();
    SOCKET chsock = accept(listen_socket, nullptr, nullptr);
    if (chsock == INVALID_SOCKET) {
        const int errcode = WSAGetLastError();
        if (errcode != WSAEWOULDBLOCK && errcode != WSAECONNRESET) {
            err = "accept failed";
        }
        return INVALID_SOCKET;
    }
    set_nonblocking(chsock);
    std::cerr << "accepted socket is " << chsock << '\n';
    return chsock;
}
// Send up to `nbytes` bytes from `bytes` on a non-blocking socket.
//
// Returns the number of bytes actually sent (possibly fewer than requested),
// 0 if nothing could be sent right now (WSAEWOULDBLOCK) or nbytes is 0, or
// -1 with `err` set on any other failure. `err` is cleared on entry.
static int socket_send(SOCKET socket, const char *bytes, int nbytes, std::string &err) {
    err.clear();
    // A zero-length send can legitimately return 0, which would trip the
    // assertion below; there is nothing to send anyway.
    if (nbytes == 0) return 0;
    const int wbytes = send(socket, bytes, nbytes, 0);
    if (wbytes == SOCKET_ERROR) {
        if (WSAGetLastError() == WSAEWOULDBLOCK) return 0;
        err = "send failure";
        return -1;
    }
    assert(wbytes > 0);
    return wbytes;
}
// Receive up to `nbytes` bytes into `bytes` from a non-blocking socket.
//
// Returns the number of bytes received, 0 if no data is available yet
// (WSAEWOULDBLOCK) or nbytes is 0, and -1 either when the peer closed the
// connection (`err` left empty) or on any other failure (`err` set).
// `err` is cleared on entry.
static int socket_recv(SOCKET socket, char *bytes, int nbytes, std::string &err) {
    err.clear();
    // recv() into a zero-length buffer can return 0, which is
    // indistinguishable from an orderly shutdown; don't report a full
    // caller buffer as end-of-stream.
    if (nbytes == 0) return 0;
    const int nrecv = recv(socket, bytes, nbytes, 0);
    if (nrecv < 0) {
        if (WSAGetLastError() == WSAEWOULDBLOCK) return 0;
        err = "recv failure";
        return -1;
    }
    if (nrecv == 0) {
        // Peer performed an orderly shutdown: end of stream, no error text.
        return -1;
    }
    return nrecv;
}
// Close `socket` and pass through closesocket()'s result
// (0 on success, SOCKET_ERROR on failure).
static int socket_close(SOCKET socket) {
    const int result = closesocket(socket);
    return result;
}
// Wait up to `mstimeout` milliseconds (negative = no limit) for events on the
// descriptors in `pollvec`, using WSAPoll; results land in each entry's
// revents.
//
// Returns the number of descriptors with events, 0 on timeout, or -1 with
// `err` set to the Winsock error code on failure. `err` is cleared on entry.
static int socket_poll(PollVector &pollvec, int mstimeout, std::string &err) {
    err.clear();
    // &pollvec[0] is undefined behaviour on an empty vector, and WSAPoll does
    // not act as a plain timer for an empty set (a NULL fdarray is rejected).
    // The vector becomes empty whenever fill_stdio_pollfd() removed the only
    // (stdio) entry, so emulate poll(): wait out the timeout and report it.
    // A negative timeout returns immediately, since with nothing to wait on
    // it could never end.
    if (pollvec.empty()) {
        if (mstimeout > 0) Sleep(static_cast<DWORD>(mstimeout));
        return 0;
    }
    const int status = WSAPoll(pollvec.data(), static_cast<ULONG>(pollvec.size()), mstimeout);
    if (status == SOCKET_ERROR) {
        err = winsock_error_string(WSAGetLastError());
        return -1;
    }
    return status;
}
static void socket_init() {
WSADATA data;
int errcode = WSAStartup(2, &data);
if (errcode != 0) {
fprintf(stderr, "Winsock didn't initalize, error %d", errcode);
exit(1);
}
}
static void socket_uninit() {
// Nothing needed.
}
static int console_write(const char *bytes, int nbytes) {
if (nbytes == 0) return 0;
HANDLE hstdout = GetStdHandle(STD_OUTPUT_HANDLE);
assert(hstdout != INVALID_HANDLE_VALUE);
DWORD nwrote;
if (nbytes > 10000) nbytes = 10000;
assert(WriteConsoleA(hstdout, bytes, nbytes, &nwrote, nullptr));
assert(nwrote > 0);
return nwrote;
}
static int console_read(char *bytes, int nbytes) {
HANDLE hstdin = GetStdHandle(STD_INPUT_HANDLE);
assert(hstdin != INVALID_HANDLE_VALUE);
INPUT_RECORD inrecords[512];
DWORD nread, nevents;
int nascii = 0;
if (GetNumberOfConsoleInputEvents(hstdin, &nevents)) {
if (int(nevents) > nbytes) nevents = nbytes;
ReadConsoleInputA(hstdin, inrecords, nevents, &nread);
for (int i = 0; i < int(nread); i++) {
const INPUT_RECORD &inr = inrecords[i];
if (inr.EventType != KEY_EVENT) continue;
const KEY_EVENT_RECORD &key = inr.Event.KeyEvent;
if (!key.bKeyDown) continue;
char c = key.uChar.AsciiChar;
bytes[nascii++] = c;
}
return nascii;
} else {
return 0;
}
}
// The last element in the vector is supposed to be
// for polling stdio. But on Windows, you can't poll
// stdio, so on Windows, we remove the last element from
// the vector and we reduce mstimeout instead.
//
// Capping the timeout at 100 ms bounds how long the caller waits in
// socket_poll() before it can check the console again. A negative timeout
// means "wait indefinitely" to WSAPoll, so it is capped as well; otherwise,
// with no socket activity, the console would never be checked again.
static void fill_stdio_pollfd(PollVector &pollvec, int &mstimeout, bool /*read_console_recently*/) {
    if (!pollvec.empty()) pollvec.pop_back();
    if (mstimeout < 0 || mstimeout > 100) mstimeout = 100;
}
// Intentionally a no-op on Windows; the command-line arguments are accepted
// but not used.
static void disable_randomization(int /*argc*/, char * /*argv*/[]) {}
class MonoClock {
public:
double freq_;
MonoClock() {
LARGE_INTEGER x;
BOOL status = QueryPerformanceFrequency(&x);
assert(status != 0);
freq_ = 1.0 / double(x.QuadPart);
}
double get() {
LARGE_INTEGER x;
BOOL status = QueryPerformanceCounter(&x);
assert(status != 0);
return double(x.QuadPart) * freq_;
}
};
#include "driver-common.cpp"