// integration/luprex/cpp/drv/driver-windows.cpp
// (C++, 329 lines, 9.1 KiB; text taken from a repository blame view —
//  the following lines were last changed 2023-02-14 13:14:18 -05:00)
#define WINVER 0x0600
#define _WIN32_WINNT 0x0600
#include "drvutil.hpp"
#include "osdrvutil.hpp"
// blame: 2023-02-14 13:14:18 -05:00
#include "sslutil.hpp"
#include "readline.hpp"
#include "../core/enginewrapper.hpp"
// blame: 2023-02-14 13:14:18 -05:00
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cassert>
#include <map>
#include <vector>
#include <string>
#include <filesystem>
// blame: 2023-05-10 15:24:47 -04:00
#include <fstream>
// blame: 2023-02-14 13:14:18 -05:00
#include <winsock2.h>
#include <ws2tcpip.h>
#include <synchapi.h>
#include <sysinfoapi.h>
#include <libloaderapi.h>
// blame: 2023-02-14 13:14:18 -05:00
#include <windows.h>
#include <openssl/ssl.h>
#include <openssl/rsa.h>
#include <openssl/x509.h>
#include <openssl/evp.h>
#include <openssl/err.h>
#include <openssl/bio.h>
#include <openssl/pem.h>
// OpenSSL requires plain ASCII pathnames, so a (wide) filesystem path is
// narrowed character-by-character.
//
// Returns the path as a std::string when every character is in 1..127.
// Returns an empty string if any character falls outside that range, i.e.
// the path cannot be represented as plain ASCII.
std::string path_to_plain_ascii(const std::filesystem::path &path) {
  // wstring() is the native string on Windows (a plain copy, no conversion).
  const std::wstring wide = path.wstring();
  std::string ascii;
  ascii.reserve(wide.size());
  for (wchar_t c : wide) {
    if ((c < 1) || (c > 127)) return "";
    ascii.push_back(static_cast<char>(c));
  }
  return ascii;
}
2023-05-10 15:24:47 -04:00
std::filesystem::path get_exe_path() {
WCHAR exepath[MAX_PATH];
DWORD status = GetModuleFileNameW( NULL, exepath, MAX_PATH );
assert(status != 0);
return std::filesystem::path(exepath);
}
2023-02-14 13:14:18 -05:00
static void set_nonblocking(SOCKET sock) {
u_long mode = 1; // 1 to enable non-blocking socket
int status = ioctlsocket(sock, FIONBIO, &mode);
assert(status == 0);
}
// Scans a getaddrinfo() result list and returns the first IPv4 (AF_INET)
// entry, or nullptr if the list has none.
static PADDRINFOA find_good_addr(PADDRINFOA addrinfo) {
  PADDRINFOA cur = addrinfo;
  while (cur != nullptr && cur->ai_family != AF_INET) {
    cur = cur->ai_next;
  }
  return cur;
}
// Resolves `host`:`port` and starts a non-blocking TCP connect to the first
// IPv4 address found.
//
// Returns the socket on success. Because the socket is non-blocking, the
// connect may still be in progress (WSAEWOULDBLOCK) when this returns.
// On failure, returns INVALID_SOCKET and sets `err` to a short message.
//
// Transient DNS failures (WSATRY_AGAIN) are retried a bounded number of
// times with a short pause. Retrying them without limit or delay could spin
// a CPU core forever while the network is down.
static SOCKET open_connection(const char *host, const char *port, std::string &err) {
  const int kMaxDnsAttempts = 5;
  const DWORD kDnsRetryDelayMs = 100;
  PADDRINFOA addrs = nullptr;
  PADDRINFOA goodaddr = nullptr;
  SOCKET sock = INVALID_SOCKET;
  int status = 0;
  err.clear();
  for (int attempt = 1;; attempt++) {
    status = getaddrinfo(host, port, nullptr, &addrs);
    if (status != WSATRY_AGAIN || attempt >= kMaxDnsAttempts) break;
    Sleep(kDnsRetryDelayMs);
  }
  if (status == WSAHOST_NOT_FOUND) {
    err = "host not found";
    goto error;
  }
  if (status != 0) {
    err = "DNS resolution malfunction";
    goto error;
  }
  goodaddr = find_good_addr(addrs);
  if (goodaddr == nullptr) {
    err = "host not an internet host";
    goto error;
  }
  sock = socket(goodaddr->ai_family, SOCK_STREAM, IPPROTO_TCP);
  if (sock == INVALID_SOCKET) {
    err = "could not create a socket";
    goto error;
  }
  set_nonblocking(sock);
  status = connect(sock, goodaddr->ai_addr, static_cast<int>(goodaddr->ai_addrlen));
  // A non-blocking connect normally "fails" with WSAEWOULDBLOCK; only other
  // errors are real failures.
  if (status != 0 && WSAGetLastError() != WSAEWOULDBLOCK) {
    err = "connect failure";
    goto error;
  }
  freeaddrinfo(addrs);
  return sock;
error:
  if (sock != INVALID_SOCKET) closesocket(sock);
  if (addrs != nullptr) freeaddrinfo(addrs);
  // Same value the original returned (SOCKET_ERROR converted to SOCKET).
  return INVALID_SOCKET;
}
// Creates a non-blocking TCP socket listening on `port` on all IPv4
// interfaces (INADDR_ANY), with a backlog of 10.
//
// Returns the listening socket. On failure, returns INVALID_SOCKET (the same
// value as SOCKET_ERROR converted to SOCKET) and sets `err` to a short
// message. Ports outside 0..65535 are rejected rather than being silently
// truncated by htons().
SOCKET listen_on_port(int port, std::string &err) {
  err.clear();
  if (port < 0 || port > 65535) {
    err = "invalid port number";
    return INVALID_SOCKET;
  }
  SOCKET sock = socket(AF_INET, SOCK_STREAM, 0);
  if (sock == INVALID_SOCKET) {
    err = "could not create a socket";
    return INVALID_SOCKET;
  }
  // Zero-initialize so sin_zero is not left holding garbage.
  sockaddr_in server{};
  server.sin_family = AF_INET;
  server.sin_addr.s_addr = INADDR_ANY;
  server.sin_port = htons(static_cast<u_short>(port));
  if (bind(sock, reinterpret_cast<struct sockaddr *>(&server), sizeof(server)) == SOCKET_ERROR) {
    err = "could not bind port";
    closesocket(sock);
    return INVALID_SOCKET;
  }
  if (listen(sock, 10) == SOCKET_ERROR) {
    err = "could not listen on socket";
    closesocket(sock);
    return INVALID_SOCKET;
  }
  set_nonblocking(sock);
  return sock;
}
// Accepts one pending connection on a non-blocking listening socket.
//
// Returns the new socket, already switched to non-blocking mode, or
// INVALID_SOCKET when nothing was accepted. Two errors are treated as benign
// and leave `err` untouched:
//   - WSAEWOULDBLOCK: nothing is pending.
//   - WSAECONNRESET: the peer dropped the connection before it was accepted.
// Any other failure sets `err` to "accept failed". `err` is not cleared on
// success.
static SOCKET accept_on_socket(SOCKET listen_socket, std::string &err) {
  const SOCKET client = accept(listen_socket, nullptr, nullptr);
  if (client == INVALID_SOCKET) {
    const int code = WSAGetLastError();
    const bool benign = (code == WSAEWOULDBLOCK) || (code == WSAECONNRESET);
    if (!benign) err = "accept failed";
    return INVALID_SOCKET;
  }
  set_nonblocking(client);
  return client;
}
// the return values for socket_send:
//
// positive: sent bytes successfully
// negative: error.
// If the error message is empty, then it's "would block"
// Any other error generates an error message.
//
2023-02-14 13:14:18 -05:00
static int socket_send(SOCKET socket, const char *bytes, int nbytes, std::string &err) {
int wbytes = send(socket, bytes, nbytes, 0);
if (wbytes == SOCKET_ERROR) {
int errcode = WSAGetLastError();
if (errcode == WSAEWOULDBLOCK) {
err.clear();
2023-02-14 13:14:18 -05:00
} else {
err = "send failure";
}
return -1;
2023-02-14 13:14:18 -05:00
} else {
assert(wbytes > 0);
err.clear();
2023-02-14 13:14:18 -05:00
return wbytes;
}
}
// Reads up to `nbytes` bytes from `socket` into `bytes`.
//
// Return value:
//   positive - number of bytes received; `err` is cleared.
//   0        - recv() returned 0 (orderly close by the peer); `err` is cleared.
//   -1       - nothing received. An empty `err` means WSAEWOULDBLOCK (no data
//              yet). Any other error sets `err` to "recv failure".
static int socket_recv(SOCKET socket, char *bytes, int nbytes, std::string &err) {
  const int got = recv(socket, bytes, nbytes, 0);
  if (got >= 0) {
    err.clear();
    return got;
  }
  err = (WSAGetLastError() == WSAEWOULDBLOCK) ? std::string() : std::string("recv failure");
  return -1;
}
// Closes `socket` and passes through closesocket()'s result
// (0 on success, SOCKET_ERROR on failure).
static int socket_close(SOCKET socket) {
  const int result = closesocket(socket);
  return result;
}
// Waits for readiness events on the `pollcount` entries of `pollvec`.
//
// With no entries, WSAPoll is not called. The function just sleeps for
// `mstimeout` ms when it is positive, then returns 0; a zero or negative
// timeout returns 0 immediately.
// Otherwise it returns WSAPoll's (non-negative) result. On failure it
// returns -1 and sets `err` from drvutil::strerror_str(WSAGetLastError()).
// `err` is not modified on success.
static int socket_poll(struct pollfd *pollvec, int pollcount, int mstimeout, std::string &err) {
  if (pollcount != 0) {
    const int ready = WSAPoll(pollvec, pollcount, mstimeout);
    if (ready >= 0) return ready;
    err = drvutil::strerror_str(WSAGetLastError());
    return -1;
  }
  if (mstimeout > 0) Sleep(mstimeout);
  return 0;
}
static void init_winsock() {
WSADATA data;
int errcode = WSAStartup(2, &data);
if (errcode != 0) {
fprintf(stderr, "Winsock didn't initalize, error %d", errcode);
exit(1);
}
}
2023-05-19 00:23:23 -04:00
static void console_write(const std::u32string &cps) {
if (cps.size() == 0) return;
2023-05-19 00:23:23 -04:00
// Convert to wstring. Any character not representable as a single wchar_t
// is replaced with a box. It's not ideal, but it's pretty good.
std::wstring ws(cps.size(), 0);
for (int i = 0; i < int(cps.size()); i++) {
char32_t c = cps[i];
2023-05-19 00:23:23 -04:00
if (drvutil::is_single_wchar_t(c)) ws[i] = (wchar_t)c;
else ws[i] = 0x2610;
}
2023-02-14 13:14:18 -05:00
HANDLE hstdout = GetStdHandle(STD_OUTPUT_HANDLE);
assert(hstdout != INVALID_HANDLE_VALUE);
DWORD nwrote;
std::wstring_view v(ws);
while (v.size() > 0) {
int nwrite = v.size();
if (nwrite > 10000) nwrite = 10000;
assert(WriteConsoleW(hstdout, v.data(), nwrite, &nwrote, nullptr));
assert(nwrote > 0);
v.remove_prefix(nwrote);
}
2023-02-14 13:14:18 -05:00
}
2023-05-19 00:23:23 -04:00
static std::u32string console_read() {
2023-02-14 13:14:18 -05:00
HANDLE hstdin = GetStdHandle(STD_INPUT_HANDLE);
assert(hstdin != INVALID_HANDLE_VALUE);
INPUT_RECORD inrecords[512];
DWORD nread, nevents;
if (GetNumberOfConsoleInputEvents(hstdin, &nevents)) {
if (int(nevents) > 0) {
if (int(nevents) > 512) nevents = 512;
ReadConsoleInputW(hstdin, inrecords, nevents, &nread);
2023-05-19 00:23:23 -04:00
std::u32string result(nread, 0);
int len = 0;
for (int i = 0; i < int(nread); i++) {
const INPUT_RECORD &inr = inrecords[i];
if (inr.EventType != KEY_EVENT) continue;
const KEY_EVENT_RECORD &key = inr.Event.KeyEvent;
if (!key.bKeyDown) continue;
result[len++] = key.uChar.UnicodeChar;
}
return result.substr(0, len);
2023-02-14 13:14:18 -05:00
}
}
2023-05-19 00:23:23 -04:00
return std::u32string();
2023-02-14 13:14:18 -05:00
}
// Copies every certificate from the Windows "ROOT" system certificate store
// into the X509 store of `ctx`.
// Certificates that fail to DER-decode are skipped, and the result of
// X509_STORE_add_cert is ignored. Exits with status 1 if the system store
// cannot be opened.
static void ssl_load_certificate_authorities(SSL_CTX *ctx) {
  HCERTSTORE hStore = CertOpenSystemStoreW(0, L"ROOT");
  X509_STORE *store = SSL_CTX_get_cert_store(ctx);
  if (hStore == nullptr) {
    fprintf(stderr, "Cannot open system certificate store.\n");
    exit(1);
  }
  for (PCCERT_CONTEXT cert = CertEnumCertificatesInStore(hStore, nullptr);
       cert != nullptr;
       cert = CertEnumCertificatesInStore(hStore, cert)) {
    const unsigned char *der = cert->pbCertEncoded;
    X509 *x509 = d2i_X509(nullptr, &der, cert->cbCertEncoded);
    if (x509 == nullptr) continue;
    X509_STORE_add_cert(store, x509);
    // The store keeps its own reference; release ours.
    X509_free(x509);
  }
  CertCloseStore(hStore, 0);
}
2023-05-10 15:24:47 -04:00
static void call_init_engine_wrapper(const std::filesystem::path &luprexroot, EngineWrapper *w) {
2023-05-09 22:12:17 -04:00
HMODULE exe = GetModuleHandleA(NULL);
using InitFn = void (*)(EngineWrapper *);
2023-05-09 22:12:17 -04:00
InitFn initfn = (InitFn)GetProcAddress(exe, "init_engine_wrapper");
if (initfn == nullptr) {
2023-05-10 15:24:47 -04:00
#if defined(_MSC_VER)
std::wstring path = luprexroot / "build/visual/luprexlib.dll";
#elif defined(__GNUC__)
std::wstring path = luprexroot / "build/mingw/luprexlib.dll";
#else
#error "Cannot detect OS type"
#endif
HMODULE dll = LoadLibraryW(path.c_str());
2023-05-09 22:12:17 -04:00
assert(dll != nullptr);
initfn = (InitFn)GetProcAddress(dll, "init_engine_wrapper");
}
assert(initfn != nullptr);
initfn(w);
}
2023-05-09 18:43:40 -04:00
void os_initialize(int argc, char **argv) {
2023-02-14 13:14:18 -05:00
init_winsock();
}