using NetMQ.Sockets;
using Phantom.Utils.Rpc;
using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Sockets;
using Phantom.Utils.Tasks;
using Serilog;
using Serilog.Events;

namespace Phantom.Controller.Rpc;

/// <summary>
/// Public entry point for starting the server-side RPC runtime.
/// </summary>
public static class RpcRuntime {
    /// <summary>
    /// Forwards all arguments to the internal generic runtime's <c>Launch</c> method and returns its task.
    /// </summary>
    /// <param name="config">RPC configuration used to create the server socket.</param>
    /// <param name="messageDefinitions">Message definitions for both directions and for reply messages.</param>
    /// <param name="listenerFactory">Creates the server-side listener for each newly seen client connection.</param>
    /// <param name="cancellationToken">Token that stops the receive loop.</param>
    /// <returns>The task returned by the runtime's launch call.</returns>
    public static Task Launch<TClientListener, TServerListener, TReplyMessage>(RpcConfiguration config, IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions, Func<RpcConnectionToClient<TClientListener>, TServerListener> listenerFactory, CancellationToken cancellationToken) where TReplyMessage : IMessage<TClientListener, NoReply>, IMessage<TServerListener, NoReply> {
        return RpcRuntime<TClientListener, TServerListener, TReplyMessage>.Launch(config, messageDefinitions, listenerFactory, cancellationToken);
    }
}

/// <summary>
/// Server-side RPC runtime that receives messages from a <see cref="ServerSocket"/>, tracks one
/// <see cref="Client"/> per routing id, and dispatches messages to the listener created for that client.
/// </summary>
internal sealed class RpcRuntime<TClientListener, TServerListener, TReplyMessage> : RpcRuntime<ServerSocket> where TReplyMessage : IMessage<TClientListener, NoReply>, IMessage<TServerListener, NoReply> {
    /// <summary>
    /// Creates the server socket from <paramref name="config"/>, constructs the runtime around it, and
    /// calls the runtime's parameterless <c>Launch()</c> (not defined in this file; presumably inherited
    /// from the base class).
    /// </summary>
    internal static Task Launch(RpcConfiguration config, IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions, Func<RpcConnectionToClient<TClientListener>, TServerListener> listenerFactory, CancellationToken cancellationToken) {
        var socket = RpcServerSocket.Connect(config);
        return new RpcRuntime<TClientListener, TServerListener, TReplyMessage>(socket, messageDefinitions, listenerFactory, cancellationToken).Launch();
    }

    private readonly IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions;
    private readonly Func<RpcConnectionToClient<TClientListener>, TServerListener> listenerFactory;
    private readonly CancellationToken cancellationToken;

    private RpcRuntime(RpcServerSocket socket, IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions, Func<RpcConnectionToClient<TClientListener>, TServerListener> listenerFactory, CancellationToken cancellationToken) : base(socket) {
        this.messageDefinitions = messageDefinitions;
        this.listenerFactory = listenerFactory;
        this.cancellationToken = cancellationToken;
    }

    /// <summary>
    /// Receive loop. Runs until the cancellation token is signalled or an exception leaves the loop.
    /// Empty messages are only logged. A message from an unknown routing id creates a new client, but
    /// only if it is a registration message. Clients whose connection is not authorized may only send
    /// registration messages. Everything else is handed to the message definitions for dispatch.
    /// </summary>
    protected override void Run(ServerSocket socket, ILogger logger, MessageReplyTracker replyTracker, TaskManager taskManager) {
        // NOTE(review): keys are ulong while routing ids received below are uint; left unchanged
        // because the type of RpcClientConnectionClosedEventArgs.RoutingId is not visible here.
        var clients = new Dictionary<ulong, Client>();

        void OnConnectionClosed(object? sender, RpcClientConnectionClosedEventArgs e) {
            clients.Remove(e.RoutingId);
            logger.Debug("Closed connection to {RoutingId}.", e.RoutingId);
        }

        try {
            while (!cancellationToken.IsCancellationRequested) {
                var (routingId, data) = socket.Receive(cancellationToken);

                if (data.Length == 0) {
                    LogMessageType(logger, routingId, data, messageType: null);
                    continue;
                }

                Type? messageType = messageDefinitions.ToServer.TryGetType(data, out var type) ? type : null;

                if (!clients.TryGetValue(routingId, out var client)) {
                    // Unknown sender: only a registration message may open a new connection.
                    if (!CheckIsRegistrationMessage(messageType, logger, routingId)) {
                        continue;
                    }

                    var connection = new RpcConnectionToClient<TClientListener>(socket, routingId, messageDefinitions.ToClient, replyTracker);
                    connection.Closed += OnConnectionClosed;

                    client = new Client(connection, messageDefinitions, listenerFactory(connection), logger, taskManager, cancellationToken);
                    clients[routingId] = client;
                }

                // Known but unauthorized sender: still restricted to registration messages.
                if (!client.Connection.IsAuthorized && !CheckIsRegistrationMessage(messageType, logger, routingId)) {
                    continue;
                }

                LogMessageType(logger, routingId, data, messageType);
                messageDefinitions.ToServer.Handle(data, client);
            }
        } finally {
            // Detach the handler even when the loop is left through an exception (for example, one
            // presumably thrown by Receive when the token is cancelled), so that the connections do
            // not keep this method's captured state alive.
            foreach (var client in clients.Values) {
                client.Connection.Closed -= OnConnectionClosed;
            }
        }
    }

    /// <summary>
    /// Writes a verbose log entry describing a received message. Does nothing when verbose logging is
    /// disabled. The message type name is included only when the message is non-empty and its type is known.
    /// </summary>
    private static void LogMessageType(ILogger logger, uint routingId, ReadOnlyMemory<byte> data, Type? messageType) {
        if (!logger.IsEnabled(LogEventLevel.Verbose)) {
            return;
        }

        if (data.Length > 0 && messageType != null) {
            logger.Verbose("Received {MessageType} ({Bytes} B) from {RoutingId}.", messageType.Name, data.Length, routingId);
        }
        else {
            logger.Verbose("Received {Bytes} B message from {RoutingId}.", data.Length, routingId);
        }
    }

    /// <summary>
    /// Returns <c>true</c> when <paramref name="messageType"/> is known and the message definitions
    /// classify it as a registration message. Otherwise logs a warning and returns <c>false</c>.
    /// </summary>
    private bool CheckIsRegistrationMessage(Type? messageType, ILogger logger, uint routingId) {
        if (messageType != null && messageDefinitions.IsRegistrationMessage(messageType)) {
            return true;
        }

        logger.Warning("Received {MessageType} from {RoutingId} who is not registered.", messageType?.Name ?? "unknown message", routingId);
        return false;
    }

    /// <summary>
    /// Message handler bound to a single client connection; replies are sent back over that connection.
    /// </summary>
    private sealed class Client : MessageHandler<TServerListener> {
        /// <summary>Connection to the client this handler serves.</summary>
        public RpcConnectionToClient<TClientListener> Connection { get; }

        private readonly IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions;

        public Client(RpcConnectionToClient<TClientListener> connection, IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions, TServerListener listener, ILogger logger, TaskManager taskManager, CancellationToken cancellationToken) : base(listener, logger, taskManager, cancellationToken) {
            this.Connection = connection;
            this.messageDefinitions = messageDefinitions;
        }

        /// <summary>
        /// Wraps the serialized reply in a reply message created by the message definitions and sends
        /// it over <see cref="Connection"/>.
        /// </summary>
        protected override Task SendReply(uint sequenceId, byte[] serializedReply) {
            return Connection.Send(messageDefinitions.CreateReplyMessage(sequenceId, serializedReply));
        }
    }
}