mirror of
https://github.com/chylex/Minecraft-Phantom-Panel.git
synced 2025-04-23 13:15:46 +02:00
Fix Controller handling messages from unauthorized Agents
This commit is contained in:
parent
2a9bb9e6ac
commit
cd332a6571
Agent/Phantom.Agent.Rpc
Common/Phantom.Common.Messages.Agent/ToController
Controller
Phantom.Controller.Rpc
Phantom.Controller.Services/Rpc
Utils/Phantom.Utils.Rpc/Message
@ -13,14 +13,10 @@ namespace Phantom.Agent.Rpc;
|
||||
|
||||
public sealed class RpcClientRuntime : RpcClientRuntime<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> {
|
||||
public static Task Launch(RpcClientSocket<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> socket, AgentInfo agentInfo, IMessageToAgentListener messageListener, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) {
|
||||
return new RpcClientRuntime(socket, agentInfo.Guid, messageListener, disconnectSemaphore, receiveCancellationToken).Launch();
|
||||
return new RpcClientRuntime(socket, messageListener, disconnectSemaphore, receiveCancellationToken).Launch();
|
||||
}
|
||||
|
||||
private readonly Guid agentGuid;
|
||||
|
||||
private RpcClientRuntime(RpcClientSocket<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> socket, Guid agentGuid, IMessageToAgentListener messageListener, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) : base(socket, messageListener, disconnectSemaphore, receiveCancellationToken) {
|
||||
this.agentGuid = agentGuid;
|
||||
}
|
||||
private RpcClientRuntime(RpcClientSocket<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> socket, IMessageToAgentListener messageListener, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) : base(socket, messageListener, disconnectSemaphore, receiveCancellationToken) {}
|
||||
|
||||
protected override void RunWithConnection(ClientSocket socket, RpcConnectionToServer<IMessageToControllerListener> connection, ILogger logger, TaskManager taskManager) {
|
||||
var keepAliveLoop = new KeepAliveLoop(connection);
|
||||
@ -32,7 +28,7 @@ public sealed class RpcClientRuntime : RpcClientRuntime<IMessageToAgentListener,
|
||||
}
|
||||
|
||||
protected override async Task Disconnect(ClientSocket socket, ILogger logger) {
|
||||
var unregisterMessageBytes = AgentMessageRegistries.ToController.Write(new UnregisterAgentMessage(agentGuid)).ToArray();
|
||||
var unregisterMessageBytes = AgentMessageRegistries.ToController.Write(new UnregisterAgentMessage()).ToArray();
|
||||
try {
|
||||
await socket.SendAsync(unregisterMessageBytes).AsTask().WaitAsync(TimeSpan.FromSeconds(5), CancellationToken.None);
|
||||
} catch (TimeoutException) {
|
||||
|
@ -4,9 +4,7 @@ using Phantom.Utils.Rpc.Message;
|
||||
namespace Phantom.Common.Messages.Agent.ToController;
|
||||
|
||||
[MemoryPackable(GenerateType.VersionTolerant)]
|
||||
public sealed partial record UnregisterAgentMessage(
|
||||
[property: MemoryPackOrder(0)] Guid AgentGuid
|
||||
) : IMessageToController {
|
||||
public sealed partial record UnregisterAgentMessage : IMessageToController {
|
||||
public Task<NoReply> Accept(IMessageToControllerListener listener) {
|
||||
return listener.HandleUnregisterAgent(this);
|
||||
}
|
||||
|
@ -11,6 +11,13 @@ public sealed class RpcConnectionToClient<TListener> {
|
||||
private readonly MessageRegistry<TListener> messageRegistry;
|
||||
private readonly MessageReplyTracker messageReplyTracker;
|
||||
|
||||
private volatile bool isAuthorized;
|
||||
|
||||
public bool IsAuthorized {
|
||||
get => isAuthorized;
|
||||
set => isAuthorized = value;
|
||||
}
|
||||
|
||||
internal event EventHandler<RpcClientConnectionClosedEventArgs>? Closed;
|
||||
private bool isClosed;
|
||||
|
||||
|
@ -42,12 +42,14 @@ internal sealed class RpcRuntime<TClientListener, TServerListener, TReplyMessage
|
||||
var (routingId, data) = socket.Receive(cancellationToken);
|
||||
|
||||
if (data.Length == 0) {
|
||||
LogMessageType(logger, routingId, data);
|
||||
LogMessageType(logger, routingId, data, messageType: null);
|
||||
continue;
|
||||
}
|
||||
|
||||
Type? messageType = messageDefinitions.ToServer.TryGetType(data, out var type) ? type : null;
|
||||
|
||||
if (!clients.TryGetValue(routingId, out var client)) {
|
||||
if (!CheckIsRegistrationMessage(data, logger, routingId)) {
|
||||
if (!CheckIsRegistrationMessage(messageType, logger, routingId)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -58,7 +60,11 @@ internal sealed class RpcRuntime<TClientListener, TServerListener, TReplyMessage
|
||||
clients[routingId] = client;
|
||||
}
|
||||
|
||||
LogMessageType(logger, routingId, data);
|
||||
if (!client.Connection.IsAuthorized && !CheckIsRegistrationMessage(messageType, logger, routingId)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
LogMessageType(logger, routingId, data, messageType);
|
||||
messageDefinitions.ToServer.Handle(data, client);
|
||||
}
|
||||
|
||||
@ -67,25 +73,25 @@ internal sealed class RpcRuntime<TClientListener, TServerListener, TReplyMessage
|
||||
}
|
||||
}
|
||||
|
||||
private void LogMessageType(ILogger logger, uint routingId, ReadOnlyMemory<byte> data) {
|
||||
private void LogMessageType(ILogger logger, uint routingId, ReadOnlyMemory<byte> data, Type? messageType) {
|
||||
if (!logger.IsEnabled(LogEventLevel.Verbose)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (data.Length > 0 && messageDefinitions.ToServer.TryGetType(data, out var type)) {
|
||||
logger.Verbose("Received {MessageType} ({Bytes} B) from {RoutingId}.", type.Name, data.Length, routingId);
|
||||
if (data.Length > 0 && messageType != null) {
|
||||
logger.Verbose("Received {MessageType} ({Bytes} B) from {RoutingId}.", messageType.Name, data.Length, routingId);
|
||||
}
|
||||
else {
|
||||
logger.Verbose("Received {Bytes} B message from {RoutingId}.", data.Length, routingId);
|
||||
}
|
||||
}
|
||||
|
||||
private bool CheckIsRegistrationMessage(ReadOnlyMemory<byte> data, ILogger logger, uint routingId) {
|
||||
if (messageDefinitions.ToServer.TryGetType(data, out var type) && messageDefinitions.IsRegistrationMessage(type)) {
|
||||
private bool CheckIsRegistrationMessage(Type? messageType, ILogger logger, uint routingId) {
|
||||
if (messageType != null && messageDefinitions.IsRegistrationMessage(messageType)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
logger.Warning("Received {MessageType} from {RoutingId} who is not registered.", type?.Name ?? "unknown message", routingId);
|
||||
logger.Warning("Received {MessageType} from {RoutingId} who is not registered.", messageType?.Name ?? "unknown message", routingId);
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -98,7 +104,7 @@ internal sealed class RpcRuntime<TClientListener, TServerListener, TReplyMessage
|
||||
this.Connection = connection;
|
||||
this.messageDefinitions = messageDefinitions;
|
||||
}
|
||||
|
||||
|
||||
protected override Task SendReply(uint sequenceId, byte[] serializedReply) {
|
||||
return Connection.Send(messageDefinitions.CreateReplyMessage(sequenceId, serializedReply));
|
||||
}
|
||||
|
@ -39,8 +39,8 @@ public sealed class AgentMessageListener : IMessageToControllerListener {
|
||||
await connection.Send(new RegisterAgentFailureMessage(RegisterAgentFailure.ConnectionAlreadyHasAnAgent));
|
||||
}
|
||||
else if (await agentManager.RegisterAgent(message.AuthToken, message.AgentInfo, instanceManager, connection)) {
|
||||
var guid = message.AgentInfo.Guid;
|
||||
agentGuidWaiter.SetResult(guid);
|
||||
connection.IsAuthorized = true;
|
||||
agentGuidWaiter.SetResult(message.AgentInfo.Guid);
|
||||
}
|
||||
|
||||
return NoReply.Instance;
|
||||
@ -51,8 +51,11 @@ public sealed class AgentMessageListener : IMessageToControllerListener {
|
||||
}
|
||||
|
||||
public Task<NoReply> HandleUnregisterAgent(UnregisterAgentMessage message) {
|
||||
if (agentManager.UnregisterAgent(message.AgentGuid, connection)) {
|
||||
instanceManager.SetInstanceStatesForAgent(message.AgentGuid, InstanceStatus.Offline);
|
||||
if (agentGuidWaiter.Task.IsCompleted) {
|
||||
var agentGuid = agentGuidWaiter.Task.Result;
|
||||
if (agentManager.UnregisterAgent(agentGuid, connection)) {
|
||||
instanceManager.SetInstanceStatesForAgent(agentGuid, InstanceStatus.Offline);
|
||||
}
|
||||
}
|
||||
|
||||
connection.Close();
|
||||
|
@ -4,14 +4,13 @@ using Serilog;
|
||||
namespace Phantom.Utils.Rpc.Message;
|
||||
|
||||
public abstract class MessageHandler<TListener> {
|
||||
protected TListener Listener { get; }
|
||||
|
||||
private readonly TListener listener;
|
||||
private readonly ILogger logger;
|
||||
private readonly TaskManager taskManager;
|
||||
private readonly CancellationToken cancellationToken;
|
||||
|
||||
protected MessageHandler(TListener listener, ILogger logger, TaskManager taskManager, CancellationToken cancellationToken) {
|
||||
this.Listener = listener;
|
||||
this.listener = listener;
|
||||
this.logger = logger;
|
||||
this.taskManager = taskManager;
|
||||
this.cancellationToken = cancellationToken;
|
||||
@ -29,12 +28,11 @@ public abstract class MessageHandler<TListener> {
|
||||
}
|
||||
|
||||
private async Task Handle<TMessage, TReply>(uint sequenceId, TMessage message) where TMessage : IMessage<TListener, TReply> {
|
||||
TReply reply = await message.Accept(Listener);
|
||||
|
||||
TReply reply = await message.Accept(listener);
|
||||
if (reply is not NoReply) {
|
||||
await SendReply(sequenceId, MessageSerializer.Serialize(reply));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
protected abstract Task SendReply(uint sequenceId, byte[] serializedReply);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user