// integration/Plugins/UEWingman/Source/UEWingman/Private/WingServer.cpp

#include "WingServer.h"
#include "WingProperty.h"
#include "WingUtils.h"
#include "UObject/StrongObjectPtr.h"
#include "AssetRegistry/AssetRegistryModule.h"
#include "AssetRegistry/IAssetRegistry.h"
#include "Misc/OutputDeviceRedirector.h"
#include "Serialization/JsonReader.h"
#include "Serialization/JsonSerializer.h"
#include "SocketSubsystem.h"
#include "Sockets.h"
#include "Async/Async.h"
// Process-wide pointer to the live server subsystem. Initialize() sets it and
// Deinitialize() clears it on its full shutdown path. It is read by the static
// TickServer() hook, by AddHandler() during registration, and by client threads
// in ProcessRequestOnGameThread() when they enqueue requests.
UWingServer* UWingServer::GWingServer = nullptr;
// ============================================================
// Initialization and Shutdown
// ============================================================
void UWingServer::Initialize(FSubsystemCollectionBase& Collection)
{
	Super::Initialize(Collection);
	GWingServer = this;

	ISocketSubsystem* const Sockets = ISocketSubsystem::Get(PLATFORM_SOCKETSUBSYSTEM);

	// TCP listener. It is switched to non-blocking below so that Tick() can
	// poll it for new connections without stalling the frame.
	ListenSocket = Sockets->CreateSocket(NAME_Stream, TEXT("WingServer"), false);
	if (ListenSocket == nullptr)
	{
		UE_LOG(LogTemp, Error, TEXT("UEWingman: Failed to create listen socket"));
		return;
	}

	// Releases the partially configured listener after a setup failure.
	auto DiscardListenSocket = [this, Sockets]()
	{
		Sockets->DestroySocket(ListenSocket);
		ListenSocket = nullptr;
	};

	ListenSocket->SetReuseAddr(true);
	ListenSocket->SetNonBlocking(true);

	// Bind to the loopback interface only.
	bool bAddressOk = false;
	TSharedRef<FInternetAddr> LoopbackAddr = Sockets->CreateInternetAddr();
	LoopbackAddr->SetIp(TEXT("127.0.0.1"), bAddressOk);
	LoopbackAddr->SetPort(Port);

	const bool bBound = ListenSocket->Bind(*LoopbackAddr);
	if (!bBound)
	{
		UE_LOG(LogTemp, Error, TEXT("UEWingman: Failed to bind to port %d"), Port);
		DiscardListenSocket();
		return;
	}

	constexpr int32 MaxBacklog = 4;
	if (!ListenSocket->Listen(MaxBacklog))
	{
		UE_LOG(LogTemp, Error, TEXT("UEWingman: Failed to listen on port %d"), Port);
		DiscardListenSocket();
		return;
	}

	// Handler registry, rebuilt whenever the set of loaded modules changes.
	BuildWingHandlerRegistry();
	ModulesChangedHandle = FModuleManager::Get().OnModulesChanged().AddUObject(this, &UWingServer::OnModulesChanged);

	// The log capture device stays attached to GLog. HandleJsonRequest() only
	// flips bEnabled on around each command.
	LogCapture.bEnabled = false;
	GLog->AddOutputDevice(&LogCapture);

	bRunning = true;
	UE_LOG(LogTemp, Display, TEXT("UEWingman: MCP server listening on tcp://localhost:%d"), Port);
}
void UWingServer::Deinitialize()
{
FModuleManager::Get().OnModulesChanged().Remove(ModulesChangedHandle);
if (!bRunning)
{
Super::Deinitialize();
return;
}
ISocketSubsystem* SocketSub = ISocketSubsystem::Get(PLATFORM_SOCKETSUBSYSTEM);
// Set shutdown flag and drain pending messages under lock
{
FScopeLock Lock(&Mutex);
bShuttingDown = true;
for (auto& Msg : PendingMessages)
{
Msg->Response.SetValue(FString());
}
PendingMessages.Empty();
}
// Close all client sockets (unblocks their blocking reads)
for (auto& Client : Clients)
{
if (Client->Socket)
{
Client->Socket->Close();
}
}
// Wait for client threads to exit
for (auto& Client : Clients)
{
Client->ThreadFuture.Wait();
if (Client->Socket)
{
SocketSub->DestroySocket(Client->Socket);
}
}
Clients.Empty();
// Close listen socket
if (ListenSocket)
{
ListenSocket->Close();
SocketSub->DestroySocket(ListenSocket);
ListenSocket = nullptr;
}
GLog->RemoveOutputDevice(&LogCapture);
bRunning = false;
bShuttingDown = false;
GWingServer = nullptr;
UE_LOG(LogTemp, Display, TEXT("UEWingman: Server stopped."));
Super::Deinitialize();
}
// ============================================================
// FTickableEditorObject interface
// ============================================================
void UWingServer::Tick(float DeltaTime)
{
if (!bRunning) return;
// Accept new connections (non-blocking)
AcceptNewConnections();
// Clean up finished client threads
CleanupFinishedClients();
// Dequeue one pending message
TSharedPtr<FPendingMessage> Request;
{
FScopeLock Lock(&Mutex);
if (PendingMessages.Num() > 0)
{
Request = PendingMessages[0];
PendingMessages.RemoveAt(0);
}
}
// If we have a request, process it.
if (Request.IsValid())
{
FString Response = HandleRequest(Request->Line);
Request->Response.SetValue(Response);
}
}
void UWingServer::TickServer(float DeltaTime)
{
if (GWingServer) GWingServer->Tick(DeltaTime);
}
// Cycle-stat id for this tickable object, declared in STATGROUP_Tickables.
TStatId UWingServer::GetStatId() const
{
	RETURN_QUICK_DECLARE_CYCLE_STAT(UWingServer, STATGROUP_Tickables);
}
// ============================================================
// HandleRequest — Given a command, execute it.
// ============================================================
FString UWingServer::HandleRequest(const FString& Line)
{
	// A request line must parse as JSON and must be a JSON object.
	TSharedPtr<FJsonValue> Parsed;
	TSharedRef<TJsonReader<>> JsonReader = TJsonReaderFactory<>::Create(Line);
	if (!FJsonSerializer::Deserialize(JsonReader, Parsed))
	{
		return PackageResponses({TEXT("Invalid Json")});
	}

	const TSharedPtr<FJsonObject>* ObjectPtr = nullptr;
	if (!Parsed->TryGetObject(ObjectPtr))
	{
		return PackageResponses({TEXT("Json must be an object")});
	}
	TSharedPtr<FJsonObject> Root = *ObjectPtr;

	FString CommandName;
	Root->TryGetStringField(TEXT("command"), CommandName);

	// Anything other than "Sequence" is a single command yielding one block.
	if (CommandName != TEXT("Sequence"))
	{
		return PackageResponses({HandleJsonRequest(Root)});
	}

	// "Sequence": run each subcommand in order, one response block per entry.
	const TArray<TSharedPtr<FJsonValue>>* SubcommandList = nullptr;
	if (!Root->TryGetArrayField(TEXT("subcommands"), SubcommandList))
	{
		return PackageResponses({TEXT("Sequence requires a 'subcommands' array.")});
	}

	TArray<FString> Results;
	Results.Reserve(SubcommandList->Num());
	for (const TSharedPtr<FJsonValue>& Entry : *SubcommandList)
	{
		const TSharedPtr<FJsonObject>* EntryObject = nullptr;
		Results.Add(Entry->TryGetObject(EntryObject)
			? HandleJsonRequest(*EntryObject)
			: FString(TEXT("Subcommand must be a JSON object.")));
	}
	return PackageResponses(Results);
}
FString UWingServer::PackageResponses(const TArray<FString>& Responses)
{
	TArray<TSharedPtr<FJsonValue>> ContentBlocks;
	ContentBlocks.Reserve(Responses.Num());

	for (const FString& Text : Responses)
	{
		// Unreal's JSON writer terminates string serialization at the first
		// embedded null character rather than escaping it, which would silently
		// truncate output. Replace null characters with spaces first.
		FString Clean(Text);
		const int32 CleanLen = Clean.Len();
		for (int32 Pos = 0; Pos < CleanLen; ++Pos)
		{
			TCHAR& Ch = Clean[Pos];
			if (Ch == TEXT('\0'))
			{
				Ch = TEXT(' ');
			}
		}

		// Each response becomes one {"type": "text", "text": ...} block.
		TSharedPtr<FJsonObject> ContentBlock = MakeShared<FJsonObject>();
		ContentBlock->SetStringField(TEXT("type"), TEXT("text"));
		ContentBlock->SetStringField(TEXT("text"), Clean);
		ContentBlocks.Add(MakeShared<FJsonValueObject>(ContentBlock));
	}

	// Serialize the blocks as a JSON array.
	FString Serialized;
	TSharedRef<TJsonWriter<>> JsonWriter = TJsonWriterFactory<>::Create(&Serialized);
	FJsonSerializer::Serialize(ContentBlocks, JsonWriter);
	return Serialized;
}
FString UWingServer::HandleJsonRequest(TSharedPtr<FJsonObject> Request)
{
	// Reset per-request state: captured log errors, the stdout buffer, and
	// any help suggestions left over from the previous command.
	LogCapture.CapturedErrors.Empty();
	LogCapture.bEnabled = true;
	WingOut::StdoutBuffer.Reset();
	SuggestedManualSections.Empty();
	bSuggestHandlerHelp = false;
	LastHandler = nullptr;

	// Run the command with log capture enabled, then send notifications.
	TryCallHandler(Request);
	Notifier.SendNotifications();
	LogCapture.bEnabled = false;

	// Append any log errors captured while the command ran.
	for (const FString& Captured : LogCapture.CapturedErrors)
	{
		WingOut::Stdout.Printf(TEXT("UE_LOG: %s\n"), *Captured);
	}
	LogCapture.CapturedErrors.Empty();

	// If the command asked for help (or suggested manual sections), append
	// pointers to the relevant documentation.
	if (bSuggestHandlerHelp || !SuggestedManualSections.IsEmpty())
	{
		if (LastHandler != nullptr)
		{
			WingManual::PrintHandlerHelp(*LastHandler);
		}
		const bool bAlreadyInManual =
			(LastHandler != nullptr) && (LastHandler->Name == TEXT("Documentation_Manual"));
		if (!bAlreadyInManual)
		{
			WingOut::Stdout.Print(TEXT("To see manual: command=Documentation_Manual\n"));
		}
		if (!SuggestedManualSections.IsEmpty())
		{
			WingManual::PrintSectionNames(TEXT("Suggested manual sections: "),
				SuggestedManualSections, WingOut::Stdout);
		}
	}

	// Hand back everything written to stdout and leave the buffer empty.
	FString Output = WingOut::StdoutBuffer.ToString();
	WingOut::StdoutBuffer.Reset();
	return Output;
}
void UWingServer::TryCallHandler(TSharedPtr<FJsonObject> Request)
{
	// Extract the command from the request. Each message ends with a newline
	// so the two hint lines do not run together in the reply.
	FString Command;
	if (!Request->TryGetStringField(TEXT("command"), Command))
	{
		WingOut::Stdout.Print(TEXT("Request does not contain 'command' parameter\n"));
		WingOut::Stdout.Print(TEXT("We recommend sending command='Documentation_Manual'.\n"));
		return;
	}
	// Every remaining field is treated as a handler parameter.
	Request->RemoveField(TEXT("command"));

	// Find the handler for the specified command.
	FWingHandlerConfig* Found = FindHandler(Command);
	if (!Found)
	{
		WingOut::Stdout.Printf(TEXT("Unknown command: %s\n"), *Command);
		UWingServer::SuggestManual(GET_FUNCTION_NAME_CHECKED(UWingManualSections, ImportantCommands));
		return;
	}
	LastHandler = Found;

	// Make a fresh object of the handler class. The strong pointer keeps it
	// alive until this function returns.
	TStrongObjectPtr<UObject> HandlerObj(NewObject<UObject>(GetTransientPackage(), Found->HandlerClass.Get()));
	UWingHandler* Handler = Cast<UWingHandler>(HandlerObj.Get());
	if (Handler == nullptr)
	{
		// AddHandler() records the class of any UObject, so guard the cast
		// instead of dereferencing a null handler below.
		WingOut::Stdout.Printf(TEXT("Command %s is not implemented by a UWingHandler class.\n"), *Command);
		return;
	}
	Handler->Configuration = Found;

	// Populate the handler object with the request parameters.
	TArray<FWingProperty> Props = FWingProperty::GetVisible(Handler, true);
	if (!FWingProperty::PopulateFromJson(Props, *Request, false, WingOut::Stdout))
	{
		UWingServer::SuggestHandlerHelp();
		return;
	}

	// MCP handlers must not run inside an undo transaction. Report the
	// condition instead of check()-failing, which would bring the entire editor
	// down for a remotely issued request.
	// NOTE(review): assumes an editor transaction can still be open when Tick()
	// runs (e.g. one spanning a multi-frame interaction) — confirm.
	if (GUndo != nullptr)
	{
		WingOut::Stdout.Print(TEXT("Cannot run a command while an editor undo transaction is in progress; please retry.\n"));
		return;
	}

	// Invoke the handler.
	Handler->Handle();
}
// ============================================================
// Connection Maintenance
// ============================================================
void UWingServer::AcceptNewConnections()
{
if (!ListenSocket) return;
bool bHasPending = false;
if (!ListenSocket->HasPendingConnection(bHasPending) || !bHasPending) return;
FSocket* ClientSocket = ListenSocket->Accept(TEXT("MCPClient"));
if (!ClientSocket) return;
ClientSocket->SetNonBlocking(false); // client threads use blocking I/O
TSharedPtr<FClientConnection> Client = MakeShared<FClientConnection>();
Client->Socket = ClientSocket;
Client->ThreadFuture = Async(EAsyncExecution::Thread, [this, Client]() { ClientThreadFunc(this, Client); });
Clients.Add(Client);
}
void UWingServer::CleanupFinishedClients()
{
ISocketSubsystem* SocketSub = ISocketSubsystem::Get(PLATFORM_SOCKETSUBSYSTEM);
for (int32 i = Clients.Num() - 1; i >= 0; --i)
{
if (!Clients[i]->bDone) continue;
Clients[i]->ThreadFuture.Wait();
if (Clients[i]->Socket)
{
SocketSub->DestroySocket(Clients[i]->Socket);
}
Clients.RemoveAt(i);
}
}
// ============================================================
// Stuff Performed on the Client Thread
// ============================================================
void UWingServer::ClientThreadFunc(UWingServer* Server, TSharedPtr<FClientConnection> Client)
{
	// Initial receive-buffer size. ReceiveMoreBytesIntoBuffer() owns the growth
	// policy and the maximum size.
	constexpr int32 MinUnusedRecvSpace = 4096;

	FSocket* Socket = Client->Socket;
	TArray<uint8> RecvBuf;
	RecvBuf.SetNumUninitialized(MinUnusedRecvSpace);
	int32 RecvLen = 0;

	// Hold off on serving requests until the asset registry stops loading.
	WaitForAssetRegistry();

	// Serve null-terminated requests until one of these happens: the server
	// begins shutting down, a send fails, or ReceiveMoreBytesIntoBuffer()
	// reports failure (peer closed, receive error, or buffer limit).
	while (true)
	{
		FString Request;
		if (ExtractRequestFromBuffer(RecvBuf, RecvLen, Request))
		{
			FString Response;
			if (!ProcessRequestOnGameThread(Request, Response))
			{
				break; // server is shutting down
			}
			// Write the response back as UTF-8, including its null terminator,
			// which frames the message (blocking).
			FTCHARToUTF8 Utf8(*Response);
			if (!SendAll(Socket, reinterpret_cast<const uint8*>(Utf8.Get()),
				Utf8.Length() + 1))
			{
				break;
			}
			continue;
		}
		if (!ReceiveMoreBytesIntoBuffer(Socket, RecvBuf, RecvLen))
		{
			break;
		}
	}

	// Single exit point: let the game thread know this connection can be reaped.
	Client->bDone = true;
}
bool UWingServer::ExtractRequestFromBuffer(
	TArray<uint8>& RecvBuf, int32& RecvLen, FString& OutRequest)
{
	uint8* const Start = RecvBuf.GetData();

	// A request is complete only once its null terminator has been received.
	const void* Terminator = memchr(Start, '\0', RecvLen);
	if (!Terminator)
	{
		return false;
	}
	const int32 RequestBytes =
		static_cast<int32>(static_cast<const uint8*>(Terminator) - Start);

	// Decode the UTF-8 payload, excluding the terminator.
	OutRequest = FString::ConstructFromPtrSize(
		reinterpret_cast<const UTF8CHAR*>(Start), RequestBytes);

	// Move any bytes that arrived after the terminator to the front of the buffer.
	const int32 ConsumedBytes = RequestBytes + 1;
	const int32 Leftover = RecvLen - ConsumedBytes;
	if (Leftover > 0)
	{
		FMemory::Memmove(Start, Start + ConsumedBytes, Leftover);
	}
	RecvLen = Leftover;
	return true;
}
bool UWingServer::ReceiveMoreBytesIntoBuffer(
	FSocket* Socket, TArray<uint8>& RecvBuf, int32& RecvLen)
{
	// Growth policy: double the buffer whenever fewer than MinUnusedRecvSpace
	// bytes are free, but refuse to grow once it reaches MaxRecvBufBytes.
	constexpr int32 MaxRecvBufBytes = 1024 * 1024;
	constexpr int32 MinUnusedRecvSpace = 4096;

	if (RecvBuf.Num() - RecvLen < MinUnusedRecvSpace)
	{
		if (RecvBuf.Num() >= MaxRecvBufBytes)
		{
			return false; // too large; the caller drops the connection
		}
		RecvBuf.SetNumUninitialized(RecvBuf.Num() * 2);
	}
	const int32 FreeBytes = RecvBuf.Num() - RecvLen;

	// A failed read, or a read of zero bytes, ends the connection.
	int32 Received = 0;
	const bool bOk = Socket->Recv(RecvBuf.GetData() + RecvLen, FreeBytes, Received);
	if (!bOk || Received <= 0)
	{
		return false;
	}
	RecvLen += Received;
	return true;
}
bool UWingServer::SendAll(FSocket* Socket, const uint8* Data, int32 BytesToSend)
{
	// Keep calling Send() until every byte is written, in case it reports a
	// partial write. An error or a zero-byte write aborts with false.
	const uint8* Cursor = Data;
	int32 Remaining = BytesToSend;
	while (Remaining > 0)
	{
		int32 Written = 0;
		const bool bSent = Socket->Send(Cursor, Remaining, Written);
		if (!bSent || Written <= 0)
		{
			return false;
		}
		Cursor += Written;
		Remaining -= Written;
	}
	return true;
}
bool UWingServer::ProcessRequestOnGameThread(
const FString& Request, FString& Response)
{
// Enqueue the message for game-thread processing.
TSharedPtr<UWingServer::FPendingMessage> Msg =
MakeShared<UWingServer::FPendingMessage>();
Msg->Line = Request;
TFuture<FString> Future = Msg->Response.GetFuture();
{
FScopeLock Lock(&GWingServer->Mutex);
if (GWingServer->bShuttingDown)
{
return false;
}
GWingServer->PendingMessages.Add(Msg);
}
// Block until the game thread processes this message.
Response = Future.Get();
return true;
}
void UWingServer::WaitForAssetRegistry()
{
IAssetRegistry& AR =
FModuleManager::LoadModuleChecked<FAssetRegistryModule>(
"AssetRegistry").Get();
while (AR.IsLoadingAssets()) FPlatformProcess::Sleep(0.25f);
}
// ============================================================
// BuildWingHandlerRegistry
// ============================================================
void UWingServer::AddHandler(UObject* Obj, const FString& Documentation)
{
	// Convenience overload: derive the command name from Obj's class and
	// register it as a Normal handler with no config object or factory class.
	const FString DerivedName = WingUtils::GetHandlerName(Obj->GetClass());
	AddHandler(Obj, DerivedName, nullptr, EWingHandlerKind::Normal, nullptr, Documentation);
}
void UWingServer::AddHandler(UObject* Obj, const FString& Name, UObject* Config, EWingHandlerKind Kind, UClass* FactoryClass, const FString& Documentation)
{
	// Only Obj's class is recorded; TryCallHandler() instantiates a fresh
	// object of it per request. Strong pointers keep the recorded classes and
	// the config object referenced for as long as the entry exists.
	FWingHandlerConfig Entry;
	Entry.Name = Name;
	Entry.Kind = Kind;
	Entry.Documentation = Documentation;
	Entry.HandlerClass = TStrongObjectPtr<UClass>(Obj->GetClass());
	Entry.FactoryClass = TStrongObjectPtr<UClass>(FactoryClass);
	Entry.Config = TStrongObjectPtr<UObject>(Config);
	GWingServer->WingHandlerRegistry.Add(MoveTemp(Entry));
}
void UWingServer::BuildWingHandlerRegistry()
{
WingHandlerRegistry.Empty();
for (UClass* Class : WingUtils::CollectHandlerClasses())
{
UWingHandler* CDO = Cast<UWingHandler>(Class->GetDefaultObject());
CDO->Register();
}
WingHandlerRegistry.Sort([](const FWingHandlerConfig& A, const FWingHandlerConfig& B) { return A.Name < B.Name; });
}
void UWingServer::OnModulesChanged(FName ModuleName, EModuleChangeReason Reason)
{
	// Rebuild the whole registry on every module change, whichever module
	// changed and for whatever reason.
	// NOTE(review): the rebuild empties WingHandlerRegistry, which invalidates
	// any FWingHandlerConfig* still held (LastHandler, a handler's
	// Configuration). Confirm this cannot fire while a request is being handled.
	this->BuildWingHandlerRegistry();
}
FWingHandlerConfig* UWingServer::FindHandler(const FString& Name)
{
	// BuildWingHandlerRegistry() keeps the registry sorted by Name. A lower
	// bound therefore lands on the only candidate that can match.
	// The projection returns a reference, so no FString is copied (and
	// allocated) on each binary-search probe.
	const int32 Index = Algo::LowerBoundBy(WingHandlerRegistry, Name,
		[](const FWingHandlerConfig& H) -> const FString& { return H.Name; });
	if (Index < WingHandlerRegistry.Num() && WingHandlerRegistry[Index].Name == Name)
	{
		return &WingHandlerRegistry[Index];
	}
	return nullptr;
}
// Backing buffer for WingOut::Stdout. HandleJsonRequest() resets it before each
// command and returns its contents as that command's response text.
TStringBuilder<65536> WingOut::StdoutBuffer;
// Output sink bound to StdoutBuffer.
WingOut WingOut::Stdout(&WingOut::StdoutBuffer);
// Output sink constructed with no buffer (presumably discards output — confirm in WingOut).
WingOut WingOut::None(nullptr);