diff --git a/Utils/Phantom.Utils.Rpc/Runtime/RpcServerRuntime.cs b/Utils/Phantom.Utils.Rpc/Runtime/RpcServerRuntime.cs index ce5f643..002b58e 100644 --- a/Utils/Phantom.Utils.Rpc/Runtime/RpcServerRuntime.cs +++ b/Utils/Phantom.Utils.Rpc/Runtime/RpcServerRuntime.cs @@ -36,88 +36,90 @@ internal sealed class RpcServerRuntime<TClientListener, TServerListener, TReplyM var clients = new ConcurrentDictionary<ulong, Client>(); void OnConnectionClosed(object? sender, RpcClientConnectionClosedEventArgs e) { - if (!clients.Remove(e.RoutingId, out var client)) { - return; + if (clients.Remove(e.RoutingId, out var client)) { + client.Connection.Closed -= OnConnectionClosed; } - - RuntimeLogger.Debug("Closing connection to {RoutingId}.", e.RoutingId); - client.Connection.Closed -= OnConnectionClosed; - - taskManager.Run("Closing connection to " + e.RoutingId, async () => { - await client.StopReceiving(); - await client.StopProcessing(); - await client.Connection.StopSending(); - RuntimeLogger.Debug("Closed connection to {RoutingId}.", e.RoutingId); - }); } while (!cancellationToken.IsCancellationRequested) { var (routingId, data) = socket.Receive(cancellationToken); if (data.Length == 0) { - LogMessageType(routingId, data, messageType: null); + LogUnknownMessage(routingId, data); continue; } Type? messageType = messageDefinitions.ToServer.TryGetType(data, out var type) ? type : null; + if (messageType == null) { + LogUnknownMessage(routingId, data); + continue; + } if (!clients.TryGetValue(routingId, out var client)) { + if (!messageDefinitions.IsRegistrationMessage(messageType)) { + RuntimeLogger.Warning("Received {MessageType} ({Bytes} B) from unregistered client {RoutingId}.", messageType.Name, data.Length, routingId); + continue; + } + var clientLoggerName = LoggerName + ":" + routingId; var processingQueue = new RpcQueue(taskManager, "Process messages from " + routingId); var connection = new RpcConnectionToClient<TClientListener>(clientLoggerName, socket, routingId, messageDefinitions.ToClient, ReplyTracker); connection.Closed += OnConnectionClosed; - client = new Client(clientLoggerName, connection, processingQueue, messageDefinitions, listenerFactory(connection)); + client = new Client(clientLoggerName, connection, processingQueue, messageDefinitions, listenerFactory(connection), taskManager); clients[routingId] = client; + client.EnqueueRegistrationMessage(messageType, data); + } + else { + client.Enqueue(messageType, data); } - - LogMessageType(routingId, data, messageType); - client.Enqueue(messageType, data); } foreach (var client in clients.Values) { client.Connection.Close(); } - return Task.CompletedTask; + return taskManager.Stop(); + } + + private void LogUnknownMessage(uint routingId, ReadOnlyMemory<byte> data) { + RuntimeLogger.Warning("Received unknown message ({Bytes} B) from {RoutingId}.", data.Length, routingId); } private protected override Task Disconnect(ServerSocket socket) { return Task.CompletedTask; } - private void LogMessageType(uint routingId, ReadOnlyMemory<byte> data, Type? messageType) { - if (!RuntimeLogger.IsEnabled(LogEventLevel.Verbose)) { - return; - } - - if (data.Length > 0 && messageType != null) { - RuntimeLogger.Verbose("Received {MessageType} ({Bytes} B) from {RoutingId}.", messageType.Name, data.Length, routingId); - } - else { - RuntimeLogger.Verbose("Received {Bytes} B message from {RoutingId}.", data.Length, routingId); - } - } - private sealed class Client : MessageHandler<TServerListener> { public RpcConnectionToClient<TClientListener> Connection { get; } private readonly RpcQueue processingQueue; private readonly IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions; - - public Client(string loggerName, RpcConnectionToClient<TClientListener> connection, RpcQueue processingQueue, IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions, TServerListener listener) : base(loggerName, listener) { + private readonly TaskManager taskManager; + + public Client(string loggerName, RpcConnectionToClient<TClientListener> connection, RpcQueue processingQueue, IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions, TServerListener listener, TaskManager taskManager) : base(loggerName, listener) { this.Connection = connection; + this.Connection.Closed += OnConnectionClosed; + this.processingQueue = processingQueue; this.messageDefinitions = messageDefinitions; + this.taskManager = taskManager; } - internal void Enqueue(Type? messageType, ReadOnlyMemory<byte> data) { - if (!Connection.GetAuthorization().IsCompleted && messageType != null && messageDefinitions.IsRegistrationMessage(messageType)) { - processingQueue.Enqueue(() => Handle(data)); - } - else { - processingQueue.Enqueue(() => WaitForAuthorizationAndHandle(data)); + internal void EnqueueRegistrationMessage(Type messageType, ReadOnlyMemory<byte> data) { + LogMessageType(messageType, data); + processingQueue.Enqueue(() => Handle(data)); + } + + internal void Enqueue(Type messageType, ReadOnlyMemory<byte> data) { + LogMessageType(messageType, data); + processingQueue.Enqueue(() => WaitForAuthorizationAndHandle(data)); + } + + private void LogMessageType(Type messageType, ReadOnlyMemory<byte> data) { + if (Logger.IsEnabled(LogEventLevel.Verbose)) { + Logger.Verbose("Received {MessageType} ({Bytes} B).", messageType.Name, data.Length); } } @@ -138,8 +140,17 @@ internal sealed class RpcServerRuntime<TClientListener, TServerListener, TReplyM return Connection.Send(messageDefinitions.CreateReplyMessage(sequenceId, serializedReply)); } - internal Task StopProcessing() { - return processingQueue.Stop(); + private void OnConnectionClosed(object? sender, RpcClientConnectionClosedEventArgs e) { + Connection.Closed -= OnConnectionClosed; + + Logger.Debug("Closing connection..."); + + taskManager.Run("Closing connection to " + e.RoutingId, async () => { + await StopReceiving(); + await processingQueue.Stop(); + await Connection.StopSending(); + Logger.Debug("Connection closed."); + }); } } }