diff --git a/Agent/Phantom.Agent.Rpc/RpcClientRuntime.cs b/Agent/Phantom.Agent.Rpc/RpcClientRuntime.cs index 6b06123..f558f3a 100644 --- a/Agent/Phantom.Agent.Rpc/RpcClientRuntime.cs +++ b/Agent/Phantom.Agent.Rpc/RpcClientRuntime.cs @@ -13,14 +13,10 @@ namespace Phantom.Agent.Rpc; public sealed class RpcClientRuntime : RpcClientRuntime<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> { public static Task Launch(RpcClientSocket<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> socket, AgentInfo agentInfo, IMessageToAgentListener messageListener, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) { - return new RpcClientRuntime(socket, agentInfo.Guid, messageListener, disconnectSemaphore, receiveCancellationToken).Launch(); + return new RpcClientRuntime(socket, messageListener, disconnectSemaphore, receiveCancellationToken).Launch(); } - private readonly Guid agentGuid; - - private RpcClientRuntime(RpcClientSocket<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> socket, Guid agentGuid, IMessageToAgentListener messageListener, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) : base(socket, messageListener, disconnectSemaphore, receiveCancellationToken) { - this.agentGuid = agentGuid; - } + private RpcClientRuntime(RpcClientSocket<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> socket, IMessageToAgentListener messageListener, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) : base(socket, messageListener, disconnectSemaphore, receiveCancellationToken) {} protected override void RunWithConnection(ClientSocket socket, RpcConnectionToServer<IMessageToControllerListener> connection, ILogger logger, TaskManager taskManager) { var keepAliveLoop = new KeepAliveLoop(connection); @@ -32,7 +28,7 @@ public sealed class RpcClientRuntime : RpcClientRuntime<IMessageToAgentListener, } protected override async Task Disconnect(ClientSocket socket, ILogger logger) { - var unregisterMessageBytes = AgentMessageRegistries.ToController.Write(new UnregisterAgentMessage(agentGuid)).ToArray(); + var unregisterMessageBytes = AgentMessageRegistries.ToController.Write(new UnregisterAgentMessage()).ToArray(); try { await socket.SendAsync(unregisterMessageBytes).AsTask().WaitAsync(TimeSpan.FromSeconds(5), CancellationToken.None); } catch (TimeoutException) { diff --git a/Common/Phantom.Common.Messages.Agent/ToController/UnregisterAgentMessage.cs b/Common/Phantom.Common.Messages.Agent/ToController/UnregisterAgentMessage.cs index 2afc720..bdb8b98 100644 --- a/Common/Phantom.Common.Messages.Agent/ToController/UnregisterAgentMessage.cs +++ b/Common/Phantom.Common.Messages.Agent/ToController/UnregisterAgentMessage.cs @@ -4,9 +4,7 @@ using Phantom.Utils.Rpc.Message; namespace Phantom.Common.Messages.Agent.ToController; [MemoryPackable(GenerateType.VersionTolerant)] -public sealed partial record UnregisterAgentMessage( - [property: MemoryPackOrder(0)] Guid AgentGuid -) : IMessageToController { +public sealed partial record UnregisterAgentMessage : IMessageToController { public Task<NoReply> Accept(IMessageToControllerListener listener) { return listener.HandleUnregisterAgent(this); } diff --git a/Controller/Phantom.Controller.Rpc/RpcConnectionToClient.cs b/Controller/Phantom.Controller.Rpc/RpcConnectionToClient.cs index 30060c9..a56d167 100644 --- a/Controller/Phantom.Controller.Rpc/RpcConnectionToClient.cs +++ b/Controller/Phantom.Controller.Rpc/RpcConnectionToClient.cs @@ -11,6 +11,13 @@ public sealed class RpcConnectionToClient<TListener> { private readonly MessageRegistry<TListener> messageRegistry; private readonly MessageReplyTracker messageReplyTracker; + private volatile bool isAuthorized; + + public bool IsAuthorized { + get => isAuthorized; + set => isAuthorized = value; + } + internal event EventHandler<RpcClientConnectionClosedEventArgs>? Closed; private bool isClosed; diff --git a/Controller/Phantom.Controller.Rpc/RpcRuntime.cs b/Controller/Phantom.Controller.Rpc/RpcRuntime.cs index f2ebe4a..47e45b0 100644 --- a/Controller/Phantom.Controller.Rpc/RpcRuntime.cs +++ b/Controller/Phantom.Controller.Rpc/RpcRuntime.cs @@ -42,12 +42,14 @@ internal sealed class RpcRuntime<TClientListener, TServerListener, TReplyMessage var (routingId, data) = socket.Receive(cancellationToken); if (data.Length == 0) { - LogMessageType(logger, routingId, data); + LogMessageType(logger, routingId, data, messageType: null); continue; } + Type? messageType = messageDefinitions.ToServer.TryGetType(data, out var type) ? type : null; + if (!clients.TryGetValue(routingId, out var client)) { - if (!CheckIsRegistrationMessage(data, logger, routingId)) { + if (!CheckIsRegistrationMessage(messageType, logger, routingId)) { continue; } @@ -58,7 +60,11 @@ internal sealed class RpcRuntime<TClientListener, TServerListener, TReplyMessage clients[routingId] = client; } - LogMessageType(logger, routingId, data); + if (!client.Connection.IsAuthorized && !CheckIsRegistrationMessage(messageType, logger, routingId)) { + continue; + } + + LogMessageType(logger, routingId, data, messageType); messageDefinitions.ToServer.Handle(data, client); } @@ -67,25 +73,25 @@ internal sealed class RpcRuntime<TClientListener, TServerListener, TReplyMessage } } - private void LogMessageType(ILogger logger, uint routingId, ReadOnlyMemory<byte> data) { + private void LogMessageType(ILogger logger, uint routingId, ReadOnlyMemory<byte> data, Type? messageType) { if (!logger.IsEnabled(LogEventLevel.Verbose)) { return; } - if (data.Length > 0 && messageDefinitions.ToServer.TryGetType(data, out var type)) { - logger.Verbose("Received {MessageType} ({Bytes} B) from {RoutingId}.", type.Name, data.Length, routingId); + if (data.Length > 0 && messageType != null) { + logger.Verbose("Received {MessageType} ({Bytes} B) from {RoutingId}.", messageType.Name, data.Length, routingId); } else { logger.Verbose("Received {Bytes} B message from {RoutingId}.", data.Length, routingId); } } - private bool CheckIsRegistrationMessage(ReadOnlyMemory<byte> data, ILogger logger, uint routingId) { - if (messageDefinitions.ToServer.TryGetType(data, out var type) && messageDefinitions.IsRegistrationMessage(type)) { + private bool CheckIsRegistrationMessage(Type? messageType, ILogger logger, uint routingId) { + if (messageType != null && messageDefinitions.IsRegistrationMessage(messageType)) { return true; } - logger.Warning("Received {MessageType} from {RoutingId} who is not registered.", type?.Name ?? "unknown message", routingId); + logger.Warning("Received {MessageType} from {RoutingId} who is not registered.", messageType?.Name ?? "unknown message", routingId); return false; } @@ -98,7 +104,7 @@ internal sealed class RpcRuntime<TClientListener, TServerListener, TReplyMessage this.Connection = connection; this.messageDefinitions = messageDefinitions; } - + protected override Task SendReply(uint sequenceId, byte[] serializedReply) { return Connection.Send(messageDefinitions.CreateReplyMessage(sequenceId, serializedReply)); } diff --git a/Controller/Phantom.Controller.Services/Rpc/AgentMessageListener.cs b/Controller/Phantom.Controller.Services/Rpc/AgentMessageListener.cs index 083cfad..d0b5ec6 100644 --- a/Controller/Phantom.Controller.Services/Rpc/AgentMessageListener.cs +++ b/Controller/Phantom.Controller.Services/Rpc/AgentMessageListener.cs @@ -39,8 +39,8 @@ public sealed class AgentMessageListener : IMessageToControllerListener { await connection.Send(new RegisterAgentFailureMessage(RegisterAgentFailure.ConnectionAlreadyHasAnAgent)); } else if (await agentManager.RegisterAgent(message.AuthToken, message.AgentInfo, instanceManager, connection)) { - var guid = message.AgentInfo.Guid; - agentGuidWaiter.SetResult(guid); + connection.IsAuthorized = true; + agentGuidWaiter.SetResult(message.AgentInfo.Guid); } return NoReply.Instance; @@ -51,8 +51,11 @@ public sealed class AgentMessageListener : IMessageToControllerListener { } public Task<NoReply> HandleUnregisterAgent(UnregisterAgentMessage message) { - if (agentManager.UnregisterAgent(message.AgentGuid, connection)) { - instanceManager.SetInstanceStatesForAgent(message.AgentGuid, InstanceStatus.Offline); + if (agentGuidWaiter.Task.IsCompleted) { + var agentGuid = agentGuidWaiter.Task.Result; + if (agentManager.UnregisterAgent(agentGuid, connection)) { + instanceManager.SetInstanceStatesForAgent(agentGuid, InstanceStatus.Offline); + } } connection.Close(); diff --git a/Utils/Phantom.Utils.Rpc/Message/MessageHandler.cs b/Utils/Phantom.Utils.Rpc/Message/MessageHandler.cs index ba77853..cc770e3 100644 --- a/Utils/Phantom.Utils.Rpc/Message/MessageHandler.cs +++ b/Utils/Phantom.Utils.Rpc/Message/MessageHandler.cs @@ -4,14 +4,13 @@ using Serilog; namespace Phantom.Utils.Rpc.Message; public abstract class MessageHandler<TListener> { - protected TListener Listener { get; } - + private readonly TListener listener; private readonly ILogger logger; private readonly TaskManager taskManager; private readonly CancellationToken cancellationToken; protected MessageHandler(TListener listener, ILogger logger, TaskManager taskManager, CancellationToken cancellationToken) { - this.Listener = listener; + this.listener = listener; this.logger = logger; this.taskManager = taskManager; this.cancellationToken = cancellationToken; @@ -29,12 +28,11 @@ public abstract class MessageHandler<TListener> { } private async Task Handle<TMessage, TReply>(uint sequenceId, TMessage message) where TMessage : IMessage<TListener, TReply> { - TReply reply = await message.Accept(Listener); - + TReply reply = await message.Accept(listener); if (reply is not NoReply) { await SendReply(sequenceId, MessageSerializer.Serialize(reply)); } } - + protected abstract Task SendReply(uint sequenceId, byte[] serializedReply); }