More cleanup in LuprexSockets.

This commit is contained in:
2026-03-01 06:47:23 -05:00
parent f75dff4cbc
commit 3613528ab8

View File

@@ -9,7 +9,6 @@
#include "SocketSubsystem.h"
#include "AddressInfoTypes.h"
#define UI UI_ST
THIRD_PARTY_INCLUDES_START
#include <openssl/ssl.h>
@@ -36,7 +35,9 @@ THIRD_PARTY_INCLUDES_END
#define MAX_BIO_BUFFER (128 * 1024)
static const char* dummy_cert =
namespace {
const char* dummy_cert =
"-----BEGIN CERTIFICATE-----\n"
"MIIDezCCAmOgAwIBAgIUajKmxrLMr9zBMlphrTJU5qKG8FgwDQYJKoZIhvcNAQEL\n"
"BQAwTDELMAkGA1UEBhMCVVMxFTATBgNVBAgMDFBlbm5zeWx2YW5pYTESMBAGA1UE\n"
@@ -59,7 +60,7 @@ static const char* dummy_cert =
"z+DnJGjHrV1J/jHPrnVvVLpigBlGno3C5O/sRw3gcQ==\n"
"-----END CERTIFICATE-----\n";
static const char* dummy_key =
const char* dummy_key =
"-----BEGIN PRIVATE KEY-----\n"
"MIIEwAIBADANBgkqhkiG9w0BAQEFAASCBKowggSmAgEAAoIBAQDk5Yhoqphp7ic/\n"
"G+7kQ/dWKVyMClhwxPj7bKl9CgHo8R6nmXCC4D3b2s2xMqeEawSJABnR5k8Rk3tW\n"
@@ -116,7 +117,8 @@ public:
//
/////////////////////////////////////////////////////////////////
enum EChanState {
enum EChanState
{
CHAN_INACTIVE,
CHAN_SSL_CONNECTING,
CHAN_SSL_ACCEPTING,
@@ -204,20 +206,16 @@ public:
TArray<FlxListener> Listeners;
// Pointer to the socket subsystem.
ISocketSubsystem* Subsys;
ISocketSubsystem* Subsys = nullptr;
BIO* TraceBIO;
SSL_CTX* ServerCTX;
SSL_CTX* ClientSecureCTX;
SSL_CTX* ClientInsecureCTX;
SSL_CTX* ServerCTX = nullptr;
SSL_CTX* ClientSecureCTX = nullptr;
SSL_CTX* ClientInsecureCTX = nullptr;
public:
FlxSocketsI(FlxLockedWrapper &w);
virtual ~FlxSocketsI() override;
// Copy the trace to UE_LOG.
void LogTrace();
// Error handling.
void SetError(const std::string& s);
virtual std::string GetError() override { return FatalError; }
@@ -246,8 +244,7 @@ public:
/////////////////////////////////////////////////////////////////
static FSocket* OpenConnection(ISocketSubsystem *subsys, const std::string& host, const std::string& port, std::string& err)
FSocket* OpenConnection(ISocketSubsystem *subsys, const std::string& host, const std::string& port, std::string& err)
{
std::string hostport = host + ":" + port;
FString fshost(host.size(), (const UTF8CHAR*)host.c_str());
@@ -255,7 +252,8 @@ static FSocket* OpenConnection(ISocketSubsystem *subsys, const std::string& host
FAddressInfoResult air = subsys->GetAddressInfo(*fshost, *fsport, EAddressInfoFlags::Default, NAME_None, ESocketType::SOCKTYPE_Streaming);
if (air.Results.Num() == 0) {
if (air.Results.Num() == 0)
{
err = std::string("DNS Lookup failed for: ") + hostport;
return nullptr;
}
@@ -336,12 +334,14 @@ FSocket* ListenOnPort(ISocketSubsystem* subsys, int port, std::string& err)
//
/////////////////////////////////////////////////////////////////
static void SSLClearErrors() {
void SSLClearErrors()
{
ERR_clear_error();
errno = 0;
}
static std::string SSLFullErrorString() {
std::string SSLFullErrorString()
{
BIO* b = BIO_new(BIO_s_mem());
ERR_print_errors(b);
char* data;
@@ -352,10 +352,12 @@ static std::string SSLFullErrorString() {
return result;
}
static std::string SSLErrorString() {
std::string SSLErrorString()
{
// Get the last code.
int code = 0;
while (true) {
while (true)
{
int icode = ERR_get_error();
if (icode == 0) break;
code = icode;
@@ -365,50 +367,55 @@ static std::string SSLErrorString() {
int terrno = errno;
errno = 0;
if (code != 0) {
if (code != 0)
{
const char* rc = ERR_reason_error_string(code);
if (rc != nullptr) {
if (rc != nullptr)
{
return rc;
}
else {
else
{
return std::system_category().message(ERR_GET_REASON(code));
}
}
else if (terrno != 0) {
else if (terrno != 0)
{
return std::system_category().message(terrno);
}
else {
else
{
return "";
}
}
static SSL_CTX* SSLNewContext(int verify, const SSL_METHOD *method, BIO *tracebio) {
SSL_CTX* SSLNewContext(int verify, const SSL_METHOD *method)
{
check(method != nullptr);
SSL_CTX* ctx = SSL_CTX_new(method);
SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE);
SSL_CTX_set_verify(ctx, verify, nullptr);
SSL_CTX_set_ecdh_auto(ctx, 1);
//if (tracebio != nullptr)
//{
// SSL_CTX_set_msg_callback(ctx, SSL_trace);
// SSL_CTX_set_msg_callback_arg(ctx, tracebio);
//}
return ctx;
}
#ifdef __linux__
static std::string SSLLoadCertificateAuthorities(SSL_CTX* ctx) {
if (SSL_CTX_set_default_verify_paths(ctx) != 1) {
std::string SSLLoadCertificateAuthorities(SSL_CTX* ctx)
{
if (SSL_CTX_set_default_verify_paths(ctx) != 1)
{
return "Could not load default certificate authority paths.";
}
return "";
}
#else
static std::string SSLLoadCertificateAuthorities(SSL_CTX* ctx) {
std::string SSLLoadCertificateAuthorities(SSL_CTX* ctx)
{
HCERTSTORE hStore = CertOpenSystemStoreW(0, L"ROOT");
if (!hStore) {
if (!hStore)
{
return "Could not open system cert store.";
}
@@ -422,7 +429,8 @@ static std::string SSLLoadCertificateAuthorities(SSL_CTX* ctx) {
if (pContext == nullptr) break;
const unsigned char* encoded_cert = pContext->pbCertEncoded;
x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded);
if (x509) {
if (x509)
{
X509_STORE_add_cert(store, x509);
X509_free(x509);
}
@@ -433,7 +441,8 @@ static std::string SSLLoadCertificateAuthorities(SSL_CTX* ctx) {
}
#endif
static std::string SSLUseCertificateString(SSL_CTX* ctx, const char* str) {
std::string SSLUseCertificateString(SSL_CTX* ctx, const char* str)
{
SSLClearErrors();
BIO* bio = BIO_new(BIO_s_mem());
BIO_puts(bio, str);
@@ -452,7 +461,8 @@ static std::string SSLUseCertificateString(SSL_CTX* ctx, const char* str) {
return result;
}
static std::string SSLUsePrivateKeyString(SSL_CTX* ctx, const char* str) {
std::string SSLUsePrivateKeyString(SSL_CTX* ctx, const char* str)
{
SSLClearErrors();
BIO* bio = BIO_new(BIO_s_mem());
BIO_puts(bio, str);
@@ -471,7 +481,7 @@ static std::string SSLUsePrivateKeyString(SSL_CTX* ctx, const char* str) {
return result;
}
static std::string SSLLoadDummyCert(SSL_CTX* ctx)
std::string SSLLoadDummyCert(SSL_CTX* ctx)
{
std::string err1 = SSLUseCertificateString(ctx, dummy_cert);
std::string err2 = SSLUsePrivateKeyString(ctx, dummy_key);
@@ -484,7 +494,8 @@ static std::string SSLLoadDummyCert(SSL_CTX* ctx)
// because MEM BIOs technically have unlimited capacity. We're
// artificially limiting them to a certain size because there's no
// reason to buffer huge amounts of data.
static int BIOSpace(BIO* bio) {
int BIOSpace(BIO* bio)
{
int space = (MAX_BIO_BUFFER)-BIO_pending(bio);
if (space < 0) space = 0;
return space;
@@ -493,8 +504,10 @@ static int BIOSpace(BIO* bio) {
// Discard the first nbytes in buffer.
// This is a terribly inefficient way to discard data that has
// already been processed. There has to be something better.
static void BIODiscard(BIO* b, int nbytes, char* chbuf) {
while (nbytes > 0) {
void BIODiscard(BIO* b, int nbytes, char* chbuf)
{
while (nbytes > 0)
{
int nread = nbytes;
if (nread > DRV_SHORTSTRING_SIZE) nread = DRV_SHORTSTRING_SIZE;
int ndropped = BIO_read(b, chbuf, nread);
@@ -510,7 +523,6 @@ static void BIODiscard(BIO* b, int nbytes, char* chbuf) {
//
/////////////////////////////////////////////////////////////////
#pragma optimize("", off)
FlxChannel::FlxChannel(FlxSocketsI* lsi, FSocket* sock, int chid, SSL_CTX* ctx, EChanState st)
{
LSI = lsi;
@@ -530,10 +542,12 @@ FlxChannel::FlxChannel(FlxSocketsI* lsi, FSocket* sock, int chid, SSL_CTX* ctx,
State = st;
}
void FlxChannel::Close(std::string_view err) {
void FlxChannel::Close(std::string_view err)
{
// Close and release the SSL channel.
// This frees the BIO objects as well.
if (SSLState != nullptr) {
if (SSLState != nullptr)
{
SSL_free(SSLState);
SSLState = nullptr;
}
@@ -567,9 +581,10 @@ void FlxChannel::Close(std::string_view err) {
State = CHAN_INACTIVE;
}
#pragma optimize("", off)
void FlxChannel::TransferSocketToRecvBIO() {
if ((State == CHAN_INACTIVE) || RecvFail) {
void FlxChannel::TransferSocketToRecvBIO()
{
if ((State == CHAN_INACTIVE) || RecvFail)
{
return;
}
@@ -587,8 +602,10 @@ void FlxChannel::TransferSocketToRecvBIO() {
}
}
void FlxChannel::TransferSendBIOToSocket() {
if ((State == CHAN_INACTIVE) || SendFail) {
void FlxChannel::TransferSendBIOToSocket()
{
if ((State == CHAN_INACTIVE) || SendFail)
{
return;
}
@@ -611,7 +628,8 @@ void FlxChannel::TransferSendBIOToSocket() {
}
}
void FlxChannel::CloseChannelIfSSLErrorIsSerious(int retval) {
void FlxChannel::CloseChannelIfSSLErrorIsSerious(int retval)
{
int error = SSL_get_error(SSLState, retval);
// Should never have write errors, because we're
@@ -620,8 +638,10 @@ void FlxChannel::CloseChannelIfSSLErrorIsSerious(int retval) {
// If we get a read error, make sure it's plausible:
// if the recv bio is full, that makes no sense.
if (error == SSL_ERROR_WANT_READ) {
if (BIOSpace(RecvBIO) == 0) {
if (error == SSL_ERROR_WANT_READ)
{
if (BIOSpace(RecvBIO) == 0)
{
Close("ssl waiting for data, but there's tons of data");
}
return;
@@ -646,7 +666,6 @@ void FlxChannel::AdvanceConnecting()
}
}
#pragma optimize("", off)
void FlxChannel::AdvanceAccepting()
{
int retval = SSL_accept(SSLState);
@@ -658,7 +677,6 @@ void FlxChannel::AdvanceAccepting()
{
CloseChannelIfSSLErrorIsSerious( retval);
}
LSI->LogTrace();
}
void FlxChannel::AdvanceReadWrite()
@@ -718,7 +736,6 @@ void FlxChannel::AdvanceReadWrite()
}
}
#pragma optimize("", off)
void FlxChannel::Advance()
{
check(State != CHAN_INACTIVE);
@@ -731,9 +748,12 @@ void FlxChannel::Advance()
// If all outgoing buffers are empty, and Luprex has released
// the channel, close the channel.
if (NBytes == 0) {
if (LSI->Luprex->get_channel_released(LSI->Luprex, ChannelID)) {
if (BIO_pending(SendBIO) == 0) {
if (NBytes == 0)
{
if (LSI->Luprex->get_channel_released(LSI->Luprex, ChannelID))
{
if (BIO_pending(SendBIO) == 0)
{
Close("");
return;
}
@@ -741,7 +761,8 @@ void FlxChannel::Advance()
}
SSLClearErrors();
switch (State) {
switch (State)
{
case CHAN_SSL_CONNECTING:
AdvanceConnecting();
break;
@@ -789,7 +810,6 @@ FlxListener::~FlxListener()
}
}
#pragma optimize("", off)
void FlxListener::AcceptConnection()
{
FSocket* csocket = Socket->Accept(TEXT("Incoming Connection"));
@@ -811,7 +831,8 @@ void FlxListener::AcceptConnection()
void FlxSocketsI::SetError(const std::string& s)
{
if (FatalError.empty()) {
if (FatalError.empty())
{
FatalError = s;
}
}
@@ -822,13 +843,6 @@ FlxSocketsI::FlxSocketsI(FlxLockedWrapper &w)
// We retain this pointer only so long as we have the wrapper lock.
TGuardValue<EngineWrapper*> GuardLuprex(Luprex, w.Get());
// This function is nonreentrant. It's not clear whether
// this is needed - it may be initialized elsewhere in unreal.
// It is also not clear that it's safe to do this in the
// blueprint thread (this constructor runs in the blueprint
// thread).
SSL_library_init();
ServerCTX = nullptr;
ClientSecureCTX = nullptr;
ClientInsecureCTX = nullptr;
@@ -839,11 +853,9 @@ FlxSocketsI::FlxSocketsI(FlxLockedWrapper &w)
SetError("Cannot obtain the socket subsystem");
}
TraceBIO = BIO_new(BIO_s_mem());
ServerCTX = SSLNewContext(SSL_VERIFY_NONE, TLS_server_method(), TraceBIO);
ClientSecureCTX = SSLNewContext(SSL_VERIFY_PEER, TLS_client_method(), TraceBIO);
ClientInsecureCTX = SSLNewContext(SSL_VERIFY_NONE, TLS_client_method(), TraceBIO);
ServerCTX = SSLNewContext(SSL_VERIFY_NONE, TLS_server_method());
ClientSecureCTX = SSLNewContext(SSL_VERIFY_PEER, TLS_client_method());
ClientInsecureCTX = SSLNewContext(SSL_VERIFY_NONE, TLS_client_method());
SetError(SSLLoadCertificateAuthorities(ClientSecureCTX));
SetError(SSLLoadDummyCert(ServerCTX));
@@ -865,7 +877,8 @@ void FlxSocketsI::ForceCloseEverything(FlxLockedWrapper& w)
TGuardValue<EngineWrapper*> GuardLuprex(Luprex, w.Get());
// Close all channels
for (FlxChannel& chan : Channels) {
for (FlxChannel& chan : Channels)
{
chan.Close("Force Close Everything");
}
@@ -883,30 +896,15 @@ FlxSocketsI::~FlxSocketsI()
if (ServerCTX != nullptr)
{
SSL_CTX_free(ServerCTX);
ServerCTX = nullptr;
}
if (ClientSecureCTX != nullptr)
{
SSL_CTX_free(ClientSecureCTX);
ClientSecureCTX = nullptr;
}
if (ClientInsecureCTX != nullptr)
{
SSL_CTX_free(ClientInsecureCTX);
ClientInsecureCTX = nullptr;
}
// TODO: Be more thorough.
}
void FlxSocketsI::LogTrace()
{
char* data;
int ndata = BIO_get_mem_data(TraceBIO, &data);
if (ndata == 0) return;
FString text(ndata, (const UTF8CHAR *)data);
UE_LOG(LogLuprexIntegration, Verbose, TEXT("SSL Trace: %s"), *text);
BIO_reset(TraceBIO);
}
bool FlxSocketsI::ListeningOnPort(int p)
@@ -922,7 +920,8 @@ void FlxSocketsI::HandleListenPorts()
{
uint32_t nports; const uint32_t* ports;
Luprex->get_listen_ports(Luprex, &nports, &ports);
for (uint32_t i = 0; i < nports; i++) {
for (uint32_t i = 0; i < nports; i++)
{
int port = ports[i];
if (!ListeningOnPort(port))
{
@@ -945,32 +944,38 @@ void FlxSocketsI::HandleNewOutgoingSockets()
{
uint32_t nchids; const uint32_t* chids;
Luprex->get_new_outgoing(Luprex, &nchids, &chids);
for (uint32_t i = 0; i < nchids; i++) {
for (uint32_t i = 0; i < nchids; i++)
{
uint32_t chid = chids[i];
std::string err, cert, host, port;
const char* target = Luprex->get_target(Luprex, chid);
drvutil::split_target(target, cert, host, port);
if (cert.empty() || host.empty() || port.empty()) {
if (cert.empty() || host.empty() || port.empty())
{
std::string message = "invalid target: ";
message += target;
Luprex->play_notify_close(Luprex, chid, message.size(), message.c_str());
continue;
}
SSL_CTX* ctx = nullptr;
if (cert == "cert") {
if (cert == "cert")
{
ctx = ClientSecureCTX;
}
else if (cert == "nocert") {
else if (cert == "nocert")
{
ctx = ClientInsecureCTX;
}
else {
else
{
std::string message = "invalid cert rule: ";
message += target;
Luprex->play_notify_close(Luprex, chid, message.size(), message.c_str());
continue;
}
FSocket *sock = OpenConnection(Subsys, host, port, err);
if (sock == nullptr) {
if (sock == nullptr)
{
Luprex->play_notify_close(Luprex, chid, err.size(), err.c_str());
continue;
}
@@ -986,10 +991,12 @@ void FlxSocketsI::RemoveInactiveChannels()
int n = Channels.Num();
while (true)
{
while ((n > 0) && (Channels[n - 1].State == CHAN_INACTIVE)) {
while ((n > 0) && (Channels[n - 1].State == CHAN_INACTIVE))
{
n -= 1;
}
while ((i < n) && (Channels[i].State != CHAN_INACTIVE)) {
while ((i < n) && (Channels[i].State != CHAN_INACTIVE))
{
i += 1;
}
if (i >= n) break;
@@ -1002,7 +1009,6 @@ void FlxSocketsI::RemoveInactiveChannels()
}
void FlxSocketsI::HandleSocketInputOutput()
{
for (FlxListener& listener : Listeners)
@@ -1020,6 +1026,7 @@ void FlxSocketsI::HandleSocketInputOutput()
RemoveInactiveChannels();
}
void FlxSocketsI::Update(FlxLockedWrapper &w)
{
// We retain this pointer only so long as we have the wrapper lock.
@@ -1029,6 +1036,9 @@ void FlxSocketsI::Update(FlxLockedWrapper &w)
HandleSocketInputOutput();
}
} // anonymous namespace
FlxSockets* FlxSockets::Create(FlxLockedWrapper &w)
{
return new FlxSocketsI(w);