// Files
// integration/luprex/cpp/drv/driver-linux.cpp
//
// 272 lines
// 7.3 KiB
// C++

#include "drvutil.hpp"
#include "osdrvutil.hpp"
#include "sslutil.hpp"
#include "readline.hpp"
#include "../core/enginewrapper.hpp"
#include <iostream>
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cassert>
#include <map>
#include <vector>
#include <string>
#include <filesystem>
#include <fstream>
#include <poll.h>
#include <sys/time.h>
#include <fcntl.h>
#include <termios.h>
#include <unistd.h>
#include <sys/select.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <sys/types.h>
#include <sys/personality.h>
#include <netdb.h>
#include <malloc.h>
#include <dlfcn.h>
// On Linux a socket is just a file descriptor; these aliases let the
// OS-independent driver code use the same names on every platform.
typedef int SOCKET;
// Sentinel returned by the socket helpers below when no socket could be produced.
constexpr SOCKET INVALID_SOCKET = -1;
// Terminal attributes saved by enable_tty_raw() and restored by disable_tty_raw().
struct termios orig_termios;
// Returns the absolute path of the running executable, as reported by the
// /proc/self/exe symlink. Returns an empty path if the link cannot be read.
//
// read_symlink() sizes its buffer to the link target, so (unlike a fixed
// PATH_MAX buffer) a long path is never silently truncated.
std::filesystem::path get_exe_path() {
    std::error_code ec;
    std::filesystem::path exe = std::filesystem::read_symlink("/proc/self/exe", ec);
    if (ec) {
        return std::filesystem::path();
    }
    return exe;
}
// Puts file descriptor fd into non-blocking mode (O_NONBLOCK), keeping its
// other status flags unchanged.
//
// Failures are programming errors and trip an assert in debug builds. In
// NDEBUG builds a failed F_GETFL must not fall through to F_SETFL:
// OR-ing O_NONBLOCK into -1 would set every status flag.
void set_nonblocking(int fd) {
    int flags = fcntl(fd, F_GETFL, 0);
    assert(flags != -1);
    if (flags == -1) {
        return;
    }
    if (flags & O_NONBLOCK) {
        return;  // already non-blocking; nothing to change
    }
    int status = fcntl(fd, F_SETFL, flags | O_NONBLOCK);
    assert(status != -1);
    (void)status;  // only inspected by the assert
}
// Writes the terminal attributes saved in orig_termios back to stdin's
// terminal, discarding pending input (TCSAFLUSH). Registered with atexit()
// by enable_tty_raw(). The tcsetattr() result is deliberately ignored.
static void disable_tty_raw() {
    (void)tcsetattr(STDIN_FILENO, TCSAFLUSH, &orig_termios);
}
// Switches the terminal on stdin into a raw-ish mode for the interactive
// console: no echo, no line buffering, and non-waiting reads. The previous
// settings are saved in orig_termios and restored at exit via disable_tty_raw().
// ISIG is not cleared, so terminal-generated signals (e.g. Ctrl-C) still work.
static void enable_tty_raw() {
    const int got = tcgetattr(STDIN_FILENO, &orig_termios);
    assert(got >= 0);
    atexit(disable_tty_raw);

    struct termios raw = orig_termios;

    // Input: no SIGINT on break, no CR->NL translation, no parity check,
    // keep the 8th bit, and no Ctrl-S/Ctrl-Q flow control.
    const tcflag_t input_off = BRKINT | ICRNL | INPCK | ISTRIP | IXON;
    raw.c_iflag &= ~input_off;

    // Local: no echo, and no canonical (line-at-a-time) input.
    const tcflag_t local_off = ECHO | ICANON;
    raw.c_lflag &= ~local_off;

    // Output post-processing stays enabled.
    raw.c_oflag |= OPOST;

    // VMIN = VTIME = 0: read() returns at once with whatever is available (possibly nothing).
    raw.c_cc[VMIN] = 0;
    raw.c_cc[VTIME] = 0;

    const int applied = tcsetattr(STDIN_FILENO, TCSAFLUSH, &raw);
    assert(applied >= 0);
    (void)got;
    (void)applied;
}
// Starts a non-blocking IPv4 TCP connection to host:port (port must be numeric,
// AI_NUMERICSERV). Only the first address returned by getaddrinfo() is tried.
//
// Returns the socket on success. Because the socket is non-blocking, connect()
// may still be in progress (EINPROGRESS) when this returns; presumably the
// caller learns the outcome later (e.g. via poll) — confirm against callers.
// On failure returns INVALID_SOCKET and stores a human-readable reason in err.
static SOCKET open_connection(const char *host, const char *port, std::string &err) {
    err.clear();

    struct addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_protocol = IPPROTO_TCP;
    hints.ai_flags = AI_NUMERICSERV;

    struct addrinfo *addrs = nullptr;
    const int gai_status = getaddrinfo(host, port, &hints, &addrs);
    if (gai_status != 0) {
        // EAI_SYSTEM means the real cause is in errno; gai_strerror() would
        // only say "System error".
        if (gai_status == EAI_SYSTEM) {
            err = drvutil::strerror_str(errno);
        } else {
            err = gai_strerror(gai_status);
        }
        return INVALID_SOCKET;
    }
    if (addrs == nullptr) {
        err = "no such host found";
        return INVALID_SOCKET;
    }

    const struct addrinfo *goodaddr = addrs;
    assert(goodaddr->ai_family == AF_INET);
    assert(goodaddr->ai_socktype == SOCK_STREAM);
    assert(goodaddr->ai_protocol == IPPROTO_TCP);

    SOCKET sock = socket(goodaddr->ai_family, goodaddr->ai_socktype, goodaddr->ai_protocol);
    // socket() reports failure with -1; descriptor 0 is a legitimate result.
    if (sock < 0) {
        err = drvutil::strerror_str(errno);
        freeaddrinfo(addrs);
        return INVALID_SOCKET;
    }
    set_nonblocking(sock);

    const int status = connect(sock, goodaddr->ai_addr, goodaddr->ai_addrlen);
    if ((status != 0) && (errno != EINPROGRESS)) {
        err = drvutil::strerror_str(errno);  // capture errno before close()/freeaddrinfo()
        close(sock);
        freeaddrinfo(addrs);
        return INVALID_SOCKET;
    }
    freeaddrinfo(addrs);
    return sock;
}
// Opens a non-blocking IPv4 TCP listening socket on all interfaces at `port`
// (SO_REUSEADDR set, backlog 10).
//
// Returns the listening socket, or INVALID_SOCKET with a human-readable reason
// in err. A port outside 0..65535 is rejected instead of being silently
// truncated by htons().
static SOCKET listen_on_port(int port, std::string &err) {
    err.clear();
    if ((port < 0) || (port > 65535)) {
        err = "port out of range";
        return INVALID_SOCKET;
    }

    SOCKET sock = socket(AF_INET, SOCK_STREAM, 0);
    // socket() reports failure with -1; descriptor 0 is a legitimate result.
    if (sock < 0) {
        err = drvutil::strerror_str(errno);
        return INVALID_SOCKET;
    }

    // Zero the whole address so sin_zero does not carry stack garbage into bind().
    struct sockaddr_in server;
    memset(&server, 0, sizeof(server));
    server.sin_family = AF_INET;
    server.sin_addr.s_addr = htonl(INADDR_ANY);
    server.sin_port = htons(static_cast<uint16_t>(port));

    const int enable = 1;
    if ((setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable)) != 0) ||
        (bind(sock, reinterpret_cast<struct sockaddr *>(&server), sizeof(server)) != 0) ||
        (listen(sock, 10) != 0)) {
        err = drvutil::strerror_str(errno);  // capture errno before close()
        close(sock);
        return INVALID_SOCKET;
    }
    set_nonblocking(sock);
    return sock;
}
// Accepts one pending connection on a (non-blocking) listening socket.
//
// Returns the accepted socket, already switched to non-blocking mode, or
// INVALID_SOCKET. An INVALID_SOCKET with an empty err means nothing could be
// accepted right now and the caller may simply try again later. A non-empty
// err describes a real failure.
static SOCKET accept_on_socket(SOCKET listen_socket, std::string &err) {
    err.clear();
    SOCKET chsock = accept(listen_socket, nullptr, nullptr);
    if (chsock >= 0) {
        set_nonblocking(chsock);
        return chsock;
    }
    const int code = errno;
    // Besides "nothing pending" (EAGAIN/EWOULDBLOCK), a connection aborted
    // before it was accepted (ECONNABORTED) and signal interruption (EINTR),
    // Linux accept() also passes network errors already pending on the *new*
    // connection back to the caller. accept(2) recommends treating the TCP/IP
    // ones like EAGAIN, so none of these counts as a listening-socket failure.
    const bool transient =
        (code == EAGAIN) || (code == EWOULDBLOCK) || (code == ECONNABORTED) ||
        (code == EINTR) || (code == ENETDOWN) || (code == EPROTO) ||
        (code == ENOPROTOOPT) || (code == EHOSTDOWN) || (code == ENONET) ||
        (code == EHOSTUNREACH) || (code == EOPNOTSUPP) || (code == ENETUNREACH);
    if (!transient) {
        err = drvutil::strerror_str(code);
    }
    return INVALID_SOCKET;
}
// Sends up to nbytes bytes from `bytes` on `socket`.
//
// The return values for socket_send:
//
//   non-negative: number of bytes actually sent (may be fewer than nbytes).
//   negative: error.
//     If the error message is empty, then it's "would block".
//     Any other error generates an error message.
//
// MSG_NOSIGNAL: when the peer has already closed the connection, a plain send()
// raises SIGPIPE, whose default action terminates the process. With the flag
// the call fails with EPIPE instead, reported through err like any other error.
static int socket_send(SOCKET socket, const char *bytes, int nbytes, std::string &err) {
    int wbytes = send(socket, bytes, nbytes, MSG_NOSIGNAL);
    if (wbytes < 0) {
        if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) {
            err.clear();
        } else {
            err = drvutil::strerror_str(errno);
        }
        return -1;
    }
    err.clear();
    return wbytes;
}
// Receives up to nbytes bytes from `socket` into `bytes`.
//
// Returns:
//   > 0 : number of bytes received
//     0 : end of stream (peer closed), or nbytes was 0
//    -1 : error; err is empty when the call would have blocked,
//         otherwise it holds a description of the failure.
// err is cleared on every non-error return.
static int socket_recv(SOCKET socket, char *bytes, int nbytes, std::string &err) {
    const int received = recv(socket, bytes, nbytes, 0);
    const int code = errno;  // capture before anything else can disturb it
    err.clear();
    if (received >= 0) {
        return received;
    }
    const bool would_block = (code == EAGAIN) || (code == EWOULDBLOCK);
    if (!would_block) {
        err = drvutil::strerror_str(code);
    }
    return -1;
}
// Closes a socket descriptor. Returns close()'s result: 0 on success,
// -1 on failure (errno set by close()).
static int socket_close(SOCKET socket) {
    const int rc = ::close(socket);
    return rc;
}
// Waits up to mstimeout milliseconds (poll() semantics; negative waits
// indefinitely) for events on the first pollcount entries of pollvec.
//
// socket_poll is implicitly expected to also poll stdin, if the OS allows
// that. Linux does, so stdin is appended to the poll vector. pollvec must
// therefore have room for at least pollcount + 1 entries. After the call,
// pollvec[pollcount] holds the stdin result.
//
// Returns 0 on success, including a timeout. It also returns 0 when the wait is
// interrupted by a signal (EINTR); poll() leaves revents untouched in that case,
// so every revents is zeroed to avoid stale events. Returns -1 with err set if
// poll() fails for any other reason.
static int socket_poll(struct pollfd *pollvec, int pollcount, int mstimeout, std::string &err) {
    pollvec[pollcount].fd = 0;
    pollvec[pollcount].events = POLLIN;
    pollvec[pollcount].revents = 0;
    pollcount += 1;

    int status = poll(pollvec, pollcount, mstimeout);
    if (status < 0) {
        if (errno == EINTR) {
            for (int i = 0; i < pollcount; i++) {
                pollvec[i].revents = 0;
            }
            err.clear();
            return 0;
        }
        err = drvutil::strerror_str(errno);
        return -1;
    }
    err.clear();
    return 0;
}
// Write unicode onto the console: encodes cps as UTF-8 and writes it to
// stdout (fd 1).
//
// write() may accept only part of the buffer, be interrupted by a signal, or
// report EAGAIN if the descriptor is non-blocking. Loop until everything is
// written, waiting for writability on EAGAIN. On any other error (or a 0-byte
// write) give up: console output is best effort.
static void console_write(const std::u32string &cps) {
    std::string utf8 = drvutil::to_utf8(cps);
    const char *next = utf8.data();
    size_t remaining = utf8.size();
    while (remaining > 0) {
        ssize_t written = write(1, next, remaining);
        if (written > 0) {
            next += written;
            remaining -= static_cast<size_t>(written);
            continue;
        }
        if ((written < 0) && (errno == EINTR)) {
            continue;
        }
        if ((written < 0) && ((errno == EAGAIN) || (errno == EWOULDBLOCK))) {
            struct pollfd out;
            out.fd = 1;
            out.events = POLLOUT;
            out.revents = 0;
            if ((poll(&out, 1, -1) < 0) && (errno != EINTR)) {
                break;
            }
            continue;
        }
        break;
    }
}
// Reads whatever is currently available on stdin (at most 512 bytes) and
// decodes it from UTF-8. Returns an empty string if nothing was read or read()
// failed.
// NOTE(review): a multi-byte UTF-8 sequence split across two reads is decoded
// piecewise here. How from_utf8 treats the fragment is not visible from this
// file; confirm.
static std::u32string console_read() {
    char chunk[512];
    const int got = read(0, chunk, sizeof(chunk));
    std::u32string decoded;
    if (got <= 0) {
        return decoded;
    }
    const std::string_view bytes(chunk, got);
    decoded = drvutil::from_utf8(bytes, nullptr);
    return decoded;
}
static void call_init_engine_wrapper(const std::filesystem::path &luprexroot, EngineWrapper *w) {
using InitFn = void (*)(EngineWrapper *);
InitFn initfn = (InitFn)dlsym(nullptr, "init_engine_wrapper");
if (initfn == nullptr) {
std::string path = luprexroot / "build/linux/luprexlib.so";
void *dll_handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
assert(dll_handle != nullptr);
initfn = (InitFn)dlsym(dll_handle, "init_engine_wrapper");
}
assert(initfn != nullptr);
initfn(w);
}
// Loads the system's default CA certificate locations into ctx so that peer
// certificates can be verified.
//
// The call must not sit inside assert(): under NDEBUG the whole expression is
// compiled out, and the CA paths would silently never be loaded.
static void ssl_load_certificate_authorities(SSL_CTX *ctx) {
    const int loaded = SSL_CTX_set_default_verify_paths(ctx);
    assert(loaded == 1);
    (void)loaded;  // only inspected by the assert
}
// Re-executes the current program with address-space layout randomization
// disabled (ADDR_NO_RANDOMIZE), unless it is already disabled.
//
// personality(0xffffffff) only queries the current persona, so the existing
// persona bits are kept and ADDR_NO_RANDOMIZE is OR-ed in. Setting the persona
// directly would wipe the other bits. The flag affects the address layout
// chosen at exec time, hence the execv() of this same program with the same
// argv. This is best effort: on any failure the function returns and the
// program continues with randomization enabled.
//
// argc is unused; it is kept for interface compatibility.
//
// NOTE(review): for set-user-ID/set-group-ID executables the kernel clears
// ADDR_NO_RANDOMIZE on exec, which would make this re-exec repeatedly.
static void disable_randomization(int argc, char *argv[]) {
    (void)argc;
    const int current = personality(0xffffffff);
    if ((current == -1) || (current & ADDR_NO_RANDOMIZE)) {
        return;
    }
    if (personality(static_cast<unsigned long>(current) | ADDR_NO_RANDOMIZE) == -1) {
        return;
    }
    const int updated = personality(0xffffffff);
    if ((updated == -1) || !(updated & ADDR_NO_RANDOMIZE)) {
        return;
    }
    // /proc/self/exe names this binary even when argv[0] is a bare command name
    // that was found through PATH; execv() itself does not search PATH.
    // Fall back to argv[0] if that fails. execv() only returns on failure.
    execv("/proc/self/exe", argv);
    execv(argv[0], argv);
}
// Linux-specific start-up work, run once with the program's arguments.
// disable_randomization() may replace the process via execv(), so it runs
// before anything that changes terminal state.
static void os_initialize(int arg_count, char **arg_vector) {
    disable_randomization(arg_count, arg_vector);
    enable_tty_raw();  // raw console input; the original settings are restored at exit
}