diff --git a/Source/Integration/LuprexSockets.cpp b/Source/Integration/LuprexSockets.cpp index 779dee5c..6829711c 100644 --- a/Source/Integration/LuprexSockets.cpp +++ b/Source/Integration/LuprexSockets.cpp @@ -9,7 +9,6 @@ #include "SocketSubsystem.h" #include "AddressInfoTypes.h" - #define UI UI_ST THIRD_PARTY_INCLUDES_START #include @@ -36,7 +35,9 @@ THIRD_PARTY_INCLUDES_END #define MAX_BIO_BUFFER (128 * 1024) -static const char* dummy_cert = +namespace { + +const char* dummy_cert = "-----BEGIN CERTIFICATE-----\n" "MIIDezCCAmOgAwIBAgIUajKmxrLMr9zBMlphrTJU5qKG8FgwDQYJKoZIhvcNAQEL\n" "BQAwTDELMAkGA1UEBhMCVVMxFTATBgNVBAgMDFBlbm5zeWx2YW5pYTESMBAGA1UE\n" @@ -59,7 +60,7 @@ static const char* dummy_cert = "z+DnJGjHrV1J/jHPrnVvVLpigBlGno3C5O/sRw3gcQ==\n" "-----END CERTIFICATE-----\n"; -static const char* dummy_key = +const char* dummy_key = "-----BEGIN PRIVATE KEY-----\n" "MIIEwAIBADANBgkqhkiG9w0BAQEFAASCBKowggSmAgEAAoIBAQDk5Yhoqphp7ic/\n" "G+7kQ/dWKVyMClhwxPj7bKl9CgHo8R6nmXCC4D3b2s2xMqeEawSJABnR5k8Rk3tW\n" @@ -116,7 +117,8 @@ public: // ///////////////////////////////////////////////////////////////// -enum EChanState { +enum EChanState +{ CHAN_INACTIVE, CHAN_SSL_CONNECTING, CHAN_SSL_ACCEPTING, @@ -204,20 +206,16 @@ public: TArray Listeners; // Pointer to the socket subsystem. - ISocketSubsystem* Subsys; + ISocketSubsystem* Subsys = nullptr; - BIO* TraceBIO; - - SSL_CTX* ServerCTX; - SSL_CTX* ClientSecureCTX; - SSL_CTX* ClientInsecureCTX; + SSL_CTX* ServerCTX = nullptr; + SSL_CTX* ClientSecureCTX = nullptr; + SSL_CTX* ClientInsecureCTX = nullptr; +public: FlxSocketsI(FlxLockedWrapper &w); virtual ~FlxSocketsI() override; - // Copy the trace to UE_LOG. - void LogTrace(); - // Error handling. void SetError(const std::string& s); virtual std::string GetError() override { return FatalError; } @@ -246,8 +244,7 @@ public: ///////////////////////////////////////////////////////////////// - -static FSocket* OpenConnection(ISocketSubsystem *subsys, const std::string& host, const std::string& port, std::string& err) +FSocket* OpenConnection(ISocketSubsystem *subsys, const std::string& host, const std::string& port, std::string& err) { std::string hostport = host + ":" + port; FString fshost(host.size(), (const UTF8CHAR*)host.c_str()); @@ -255,7 +252,8 @@ static FSocket* OpenConnection(ISocketSubsystem *subsys, const std::string& host FAddressInfoResult air = subsys->GetAddressInfo(*fshost, *fsport, EAddressInfoFlags::Default, NAME_None, ESocketType::SOCKTYPE_Streaming); - if (air.Results.Num() == 0) { + if (air.Results.Num() == 0) + { err = std::string("DNS Lookup failed for: ") + hostport; return nullptr; } @@ -336,12 +334,14 @@ FSocket* ListenOnPort(ISocketSubsystem* subsys, int port, std::string& err) // ///////////////////////////////////////////////////////////////// -static void SSLClearErrors() { +void SSLClearErrors() +{ ERR_clear_error(); errno = 0; } -static std::string SSLFullErrorString() { +std::string SSLFullErrorString() +{ BIO* b = BIO_new(BIO_s_mem()); ERR_print_errors(b); char* data; @@ -352,10 +352,12 @@ static std::string SSLFullErrorString() { return result; } -static std::string SSLErrorString() { +std::string SSLErrorString() +{ // Get the last code. int code = 0; - while (true) { + while (true) + { int icode = ERR_get_error(); if (icode == 0) break; code = icode; @@ -365,50 +367,55 @@ static std::string SSLErrorString() { int terrno = errno; errno = 0; - if (code != 0) { + if (code != 0) + { const char* rc = ERR_reason_error_string(code); - if (rc != nullptr) { + if (rc != nullptr) + { return rc; } - else { + else + { return std::system_category().message(ERR_GET_REASON(code)); } } - else if (terrno != 0) { + else if (terrno != 0) + { return std::system_category().message(terrno); } - else { + else + { return ""; } } -static SSL_CTX* SSLNewContext(int verify, const SSL_METHOD *method, BIO *tracebio) { +SSL_CTX* SSLNewContext(int verify, const SSL_METHOD *method) +{ check(method != nullptr); SSL_CTX* ctx = SSL_CTX_new(method); SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE); SSL_CTX_set_verify(ctx, verify, nullptr); SSL_CTX_set_ecdh_auto(ctx, 1); - //if (tracebio != nullptr) - //{ - // SSL_CTX_set_msg_callback(ctx, SSL_trace); - // SSL_CTX_set_msg_callback_arg(ctx, tracebio); - //} return ctx; } #ifdef __linux__ -static std::string SSLLoadCertificateAuthorities(SSL_CTX* ctx) { - if (SSL_CTX_set_default_verify_paths(ctx) != 1) { +std::string SSLLoadCertificateAuthorities(SSL_CTX* ctx) +{ + if (SSL_CTX_set_default_verify_paths(ctx) != 1) + { return "Could not load default certificate authority paths."; } return ""; } #else -static std::string SSLLoadCertificateAuthorities(SSL_CTX* ctx) { +std::string SSLLoadCertificateAuthorities(SSL_CTX* ctx) +{ HCERTSTORE hStore = CertOpenSystemStoreW(0, L"ROOT"); - if (!hStore) { + if (!hStore) + { return "Could not open system cert store."; } @@ -422,7 +429,8 @@ static std::string SSLLoadCertificateAuthorities(SSL_CTX* ctx) { if (pContext == nullptr) break; const unsigned char* encoded_cert = pContext->pbCertEncoded; x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); - if (x509) { + if (x509) + { X509_STORE_add_cert(store, x509); X509_free(x509); } @@ -433,7 +441,8 @@ static std::string SSLLoadCertificateAuthorities(SSL_CTX* ctx) { } #endif -static std::string SSLUseCertificateString(SSL_CTX* ctx, const char* str) { +std::string SSLUseCertificateString(SSL_CTX* ctx, const char* str) +{ SSLClearErrors(); BIO* bio = BIO_new(BIO_s_mem()); BIO_puts(bio, str); @@ -452,7 +461,8 @@ static std::string SSLUseCertificateString(SSL_CTX* ctx, const char* str) { return result; } -static std::string SSLUsePrivateKeyString(SSL_CTX* ctx, const char* str) { +std::string SSLUsePrivateKeyString(SSL_CTX* ctx, const char* str) +{ SSLClearErrors(); BIO* bio = BIO_new(BIO_s_mem()); BIO_puts(bio, str); @@ -471,7 +481,7 @@ static std::string SSLUsePrivateKeyString(SSL_CTX* ctx, const char* str) { return result; } -static std::string SSLLoadDummyCert(SSL_CTX* ctx) +std::string SSLLoadDummyCert(SSL_CTX* ctx) { std::string err1 = SSLUseCertificateString(ctx, dummy_cert); std::string err2 = SSLUsePrivateKeyString(ctx, dummy_key); @@ -484,7 +494,8 @@ static std::string SSLLoadDummyCert(SSL_CTX* ctx) // because MEM BIOs technically have unlimited capacity. We're // artificially limiting them to a certain size because there's no // reason to buffer huge amounts of data. -static int BIOSpace(BIO* bio) { +int BIOSpace(BIO* bio) +{ int space = (MAX_BIO_BUFFER)-BIO_pending(bio); if (space < 0) space = 0; return space; @@ -493,8 +504,10 @@ static int BIOSpace(BIO* bio) { // Discard the first nbytes in buffer. // This is a terribly inefficient way to discard data that has // already been processed. There has to be something better. -static void BIODiscard(BIO* b, int nbytes, char* chbuf) { - while (nbytes > 0) { +void BIODiscard(BIO* b, int nbytes, char* chbuf) +{ + while (nbytes > 0) + { int nread = nbytes; if (nread > DRV_SHORTSTRING_SIZE) nread = DRV_SHORTSTRING_SIZE; int ndropped = BIO_read(b, chbuf, nread); @@ -510,7 +523,6 @@ static void BIODiscard(BIO* b, int nbytes, char* chbuf) { // ///////////////////////////////////////////////////////////////// -#pragma optimize("", off) FlxChannel::FlxChannel(FlxSocketsI* lsi, FSocket* sock, int chid, SSL_CTX* ctx, EChanState st) { LSI = lsi; @@ -530,10 +542,12 @@ FlxChannel::FlxChannel(FlxSocketsI* lsi, FSocket* sock, int chid, SSL_CTX* ctx, State = st; } -void FlxChannel::Close(std::string_view err) { +void FlxChannel::Close(std::string_view err) +{ // Close and release the SSL channel. // This frees the BIO objects as well. - if (SSLState != nullptr) { + if (SSLState != nullptr) + { SSL_free(SSLState); SSLState = nullptr; } @@ -567,9 +581,10 @@ void FlxChannel::Close(std::string_view err) { State = CHAN_INACTIVE; } -#pragma optimize("", off) -void FlxChannel::TransferSocketToRecvBIO() { - if ((State == CHAN_INACTIVE) || RecvFail) { +void FlxChannel::TransferSocketToRecvBIO() +{ + if ((State == CHAN_INACTIVE) || RecvFail) + { return; } @@ -587,8 +602,10 @@ void FlxChannel::TransferSocketToRecvBIO() { } } -void FlxChannel::TransferSendBIOToSocket() { - if ((State == CHAN_INACTIVE) || SendFail) { +void FlxChannel::TransferSendBIOToSocket() +{ + if ((State == CHAN_INACTIVE) || SendFail) + { return; } @@ -611,7 +628,8 @@ void FlxChannel::TransferSendBIOToSocket() { } } -void FlxChannel::CloseChannelIfSSLErrorIsSerious(int retval) { +void FlxChannel::CloseChannelIfSSLErrorIsSerious(int retval) +{ int error = SSL_get_error(SSLState, retval); // Should never have write errors, because we're @@ -620,8 +638,10 @@ void FlxChannel::CloseChannelIfSSLErrorIsSerious(int retval) { // If we get a read error, make sure it's plausible: // if the recv bio is full, that makes no sense. - if (error == SSL_ERROR_WANT_READ) { - if (BIOSpace(RecvBIO) == 0) { + if (error == SSL_ERROR_WANT_READ) + { + if (BIOSpace(RecvBIO) == 0) + { Close("ssl waiting for data, but there's tons of data"); } return; @@ -646,7 +666,6 @@ void FlxChannel::AdvanceConnecting() } } -#pragma optimize("", off) void FlxChannel::AdvanceAccepting() { int retval = SSL_accept(SSLState); @@ -658,7 +677,6 @@ void FlxChannel::AdvanceAccepting() { CloseChannelIfSSLErrorIsSerious( retval); } - LSI->LogTrace(); } void FlxChannel::AdvanceReadWrite() @@ -718,7 +736,6 @@ void FlxChannel::AdvanceReadWrite() } } -#pragma optimize("", off) void FlxChannel::Advance() { check(State != CHAN_INACTIVE); @@ -731,9 +748,12 @@ void FlxChannel::Advance() // If all outgoing buffers are empty, and Luprex has released // the channel, close the channel. - if (NBytes == 0) { - if (LSI->Luprex->get_channel_released(LSI->Luprex, ChannelID)) { - if (BIO_pending(SendBIO) == 0) { + if (NBytes == 0) + { + if (LSI->Luprex->get_channel_released(LSI->Luprex, ChannelID)) + { + if (BIO_pending(SendBIO) == 0) + { Close(""); return; } @@ -741,7 +761,8 @@ void FlxChannel::Advance() } SSLClearErrors(); - switch (State) { + switch (State) + { case CHAN_SSL_CONNECTING: AdvanceConnecting(); break; @@ -789,7 +810,6 @@ FlxListener::~FlxListener() } } -#pragma optimize("", off) void FlxListener::AcceptConnection() { FSocket* csocket = Socket->Accept(TEXT("Incoming Connection")); @@ -811,7 +831,8 @@ void FlxListener::AcceptConnection() void FlxSocketsI::SetError(const std::string& s) { - if (FatalError.empty()) { + if (FatalError.empty()) + { FatalError = s; } } @@ -822,13 +843,6 @@ FlxSocketsI::FlxSocketsI(FlxLockedWrapper &w) // We retain this pointer only so long as we have the wrapper lock. TGuardValue GuardLuprex(Luprex, w.Get()); - // This function is nonreentrant. It's not clear whether - // this is needed - it may be initialized elsewhere in unreal. - // It is also not clear that it's safe to do this in the - // blueprint thread (this constructor runs in the blueprint - // thread). - SSL_library_init(); - ServerCTX = nullptr; ClientSecureCTX = nullptr; ClientInsecureCTX = nullptr; @@ -839,11 +853,9 @@ FlxSocketsI::FlxSocketsI(FlxLockedWrapper &w) SetError("Cannot obtain the socket subsystem"); } - TraceBIO = BIO_new(BIO_s_mem()); - - ServerCTX = SSLNewContext(SSL_VERIFY_NONE, TLS_server_method(), TraceBIO); - ClientSecureCTX = SSLNewContext(SSL_VERIFY_PEER, TLS_client_method(), TraceBIO); - ClientInsecureCTX = SSLNewContext(SSL_VERIFY_NONE, TLS_client_method(), TraceBIO); + ServerCTX = SSLNewContext(SSL_VERIFY_NONE, TLS_server_method()); + ClientSecureCTX = SSLNewContext(SSL_VERIFY_PEER, TLS_client_method()); + ClientInsecureCTX = SSLNewContext(SSL_VERIFY_NONE, TLS_client_method()); SetError(SSLLoadCertificateAuthorities(ClientSecureCTX)); SetError(SSLLoadDummyCert(ServerCTX)); @@ -865,7 +877,8 @@ void FlxSocketsI::ForceCloseEverything(FlxLockedWrapper& w) TGuardValue GuardLuprex(Luprex, w.Get()); // Close all channels - for (FlxChannel& chan : Channels) { + for (FlxChannel& chan : Channels) + { chan.Close("Force Close Everything"); } @@ -883,30 +896,15 @@ FlxSocketsI::~FlxSocketsI() if (ServerCTX != nullptr) { SSL_CTX_free(ServerCTX); - ServerCTX = nullptr; } if (ClientSecureCTX != nullptr) { SSL_CTX_free(ClientSecureCTX); - ClientSecureCTX = nullptr; } if (ClientInsecureCTX != nullptr) { SSL_CTX_free(ClientInsecureCTX); - ClientInsecureCTX = nullptr; } - - // TODO: Be more thorough. -} - -void FlxSocketsI::LogTrace() -{ - char* data; - int ndata = BIO_get_mem_data(TraceBIO, &data); - if (ndata == 0) return; - FString text(ndata, (const UTF8CHAR *)data); - UE_LOG(LogLuprexIntegration, Verbose, TEXT("SSL Trace: %s"), *text); - BIO_reset(TraceBIO); } bool FlxSocketsI::ListeningOnPort(int p) @@ -922,7 +920,8 @@ void FlxSocketsI::HandleListenPorts() { uint32_t nports; const uint32_t* ports; Luprex->get_listen_ports(Luprex, &nports, &ports); - for (uint32_t i = 0; i < nports; i++) { + for (uint32_t i = 0; i < nports; i++) + { int port = ports[i]; if (!ListeningOnPort(port)) { @@ -945,32 +944,38 @@ void FlxSocketsI::HandleNewOutgoingSockets() { uint32_t nchids; const uint32_t* chids; Luprex->get_new_outgoing(Luprex, &nchids, &chids); - for (uint32_t i = 0; i < nchids; i++) { + for (uint32_t i = 0; i < nchids; i++) + { uint32_t chid = chids[i]; std::string err, cert, host, port; const char* target = Luprex->get_target(Luprex, chid); drvutil::split_target(target, cert, host, port); - if (cert.empty() || host.empty() || port.empty()) { + if (cert.empty() || host.empty() || port.empty()) + { std::string message = "invalid target: "; message += target; Luprex->play_notify_close(Luprex, chid, message.size(), message.c_str()); continue; } SSL_CTX* ctx = nullptr; - if (cert == "cert") { + if (cert == "cert") + { ctx = ClientSecureCTX; } - else if (cert == "nocert") { + else if (cert == "nocert") + { ctx = ClientInsecureCTX; } - else { + else + { std::string message = "invalid cert rule: "; message += target; Luprex->play_notify_close(Luprex, chid, message.size(), message.c_str()); continue; } FSocket *sock = OpenConnection(Subsys, host, port, err); - if (sock == nullptr) { + if (sock == nullptr) + { Luprex->play_notify_close(Luprex, chid, err.size(), err.c_str()); continue; } @@ -986,10 +991,12 @@ void FlxSocketsI::RemoveInactiveChannels() int n = Channels.Num(); while (true) { - while ((n > 0) && (Channels[n - 1].State == CHAN_INACTIVE)) { + while ((n > 0) && (Channels[n - 1].State == CHAN_INACTIVE)) + { n -= 1; } - while ((i < n) && (Channels[i].State != CHAN_INACTIVE)) { + while ((i < n) && (Channels[i].State != CHAN_INACTIVE)) + { i += 1; } if (i >= n) break; @@ -1002,7 +1009,6 @@ void FlxSocketsI::RemoveInactiveChannels() } - void FlxSocketsI::HandleSocketInputOutput() { for (FlxListener& listener : Listeners) @@ -1020,6 +1026,7 @@ void FlxSocketsI::HandleSocketInputOutput() RemoveInactiveChannels(); } + void FlxSocketsI::Update(FlxLockedWrapper &w) { // We retain this pointer only so long as we have the wrapper lock. @@ -1029,6 +1036,9 @@ void FlxSocketsI::Update(FlxLockedWrapper &w) HandleSocketInputOutput(); } + +} // anonymous namespace + FlxSockets* FlxSockets::Create(FlxLockedWrapper &w) { return new FlxSocketsI(w);