diff --git a/Agent/Phantom.Agent.Rpc/KeepAliveLoop.cs b/Agent/Phantom.Agent.Rpc/KeepAliveLoop.cs index 462a5d8..a314a46 100644 --- a/Agent/Phantom.Agent.Rpc/KeepAliveLoop.cs +++ b/Agent/Phantom.Agent.Rpc/KeepAliveLoop.cs @@ -26,7 +26,7 @@ sealed class KeepAliveLoop { try { while (true) { await Task.Delay(KeepAliveInterval, cancellationToken); - await connection.Send(new AgentIsAliveMessage()); + await connection.Send(new AgentIsAliveMessage()).WaitAsync(cancellationToken); } } catch (OperationCanceledException) { // Ignore. diff --git a/Agent/Phantom.Agent.Rpc/RpcClientRuntime.cs b/Agent/Phantom.Agent.Rpc/RpcClientRuntime.cs index f56d48f..fe9141a 100644 --- a/Agent/Phantom.Agent.Rpc/RpcClientRuntime.cs +++ b/Agent/Phantom.Agent.Rpc/RpcClientRuntime.cs @@ -1,33 +1,31 @@ using NetMQ; using NetMQ.Sockets; -using Phantom.Common.Data.Agent; using Phantom.Common.Messages.Agent; using Phantom.Common.Messages.Agent.BiDirectional; using Phantom.Common.Messages.Agent.ToController; using Phantom.Utils.Rpc.Runtime; using Phantom.Utils.Rpc.Sockets; -using Phantom.Utils.Tasks; using Serilog; namespace Phantom.Agent.Rpc; public sealed class RpcClientRuntime : RpcClientRuntime<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> { - public static Task Launch(RpcClientSocket<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> socket, AgentInfo agentInfo, IMessageToAgentListener messageListener, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) { + public static Task Launch(RpcClientSocket<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> socket, IMessageToAgentListener messageListener, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) { return new RpcClientRuntime(socket, messageListener, disconnectSemaphore, receiveCancellationToken).Launch(); } private RpcClientRuntime(RpcClientSocket<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> socket, IMessageToAgentListener messageListener, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) : base(socket, messageListener, disconnectSemaphore, receiveCancellationToken) {} - protected override void RunWithConnection(ClientSocket socket, RpcConnectionToServer<IMessageToControllerListener> connection, ILogger logger, TaskManager taskManager) { + protected override async Task RunWithConnection(ClientSocket socket, RpcConnectionToServer<IMessageToControllerListener> connection) { var keepAliveLoop = new KeepAliveLoop(connection); try { - base.RunWithConnection(socket, connection, logger, taskManager); + await base.RunWithConnection(socket, connection); } finally { keepAliveLoop.Cancel(); } } - protected override async Task Disconnect(ClientSocket socket, ILogger logger) { + protected override async Task SendDisconnectMessage(ClientSocket socket, ILogger logger) { var unregisterMessageBytes = AgentMessageRegistries.ToController.Write(new UnregisterAgentMessage()).ToArray(); try { await socket.SendAsync(unregisterMessageBytes).AsTask().WaitAsync(TimeSpan.FromSeconds(5), CancellationToken.None); diff --git a/Agent/Phantom.Agent.Services/Instances/Instance.cs b/Agent/Phantom.Agent.Services/Instances/Instance.cs index 3771d8d..8977b1e 100644 --- a/Agent/Phantom.Agent.Services/Instances/Instance.cs +++ b/Agent/Phantom.Agent.Services/Instances/Instance.cs @@ -21,8 +21,7 @@ sealed class Instance : IAsyncDisposable { private readonly ILogger logger; private IInstanceStatus currentStatus; - private int statusUpdateCounter; - + private IInstanceState currentState; public bool IsRunning => currentState is not InstanceNotRunningState; @@ -38,40 +37,23 @@ sealed class Instance : IAsyncDisposable { this.Configuration = configuration; this.Launcher = launcher; - this.currentState = new InstanceNotRunningState(); this.currentStatus = InstanceStatus.NotRunning; + this.currentState = new InstanceNotRunningState(); this.procedureManager = new InstanceProcedureManager(this, new Context(this), services.TaskManager); } - private void TryUpdateStatus(string taskName, Func<Task> getUpdateTask) { - int myStatusUpdateCounter = Interlocked.Increment(ref statusUpdateCounter); - - Services.TaskManager.Run(taskName, async () => { - if (myStatusUpdateCounter == statusUpdateCounter) { - await getUpdateTask(); - } - }); - } - public void ReportLastStatus() { - TryUpdateStatus("Report last status of instance " + shortName, async () => { - await Services.ControllerConnection.Send(new ReportInstanceStatusMessage(Configuration.InstanceGuid, currentStatus)); - }); + Services.ControllerConnection.Send(new ReportInstanceStatusMessage(Configuration.InstanceGuid, currentStatus)); } private void ReportAndSetStatus(IInstanceStatus status) { - TryUpdateStatus("Report status of instance " + shortName + " as " + status.GetType().Name, async () => { - if (status != currentStatus) { - currentStatus = status; - await Services.ControllerConnection.Send(new ReportInstanceStatusMessage(Configuration.InstanceGuid, status)); - } - }); + currentStatus = status; + Services.ControllerConnection.Send(new ReportInstanceStatusMessage(Configuration.InstanceGuid, status)); } private void ReportEvent(IInstanceEvent instanceEvent) { - var message = new ReportInstanceEventMessage(Guid.NewGuid(), DateTime.UtcNow, Configuration.InstanceGuid, instanceEvent); - Services.TaskManager.Run("Report event for instance " + shortName, async () => await Services.ControllerConnection.Send(message)); + Services.ControllerConnection.Send(new ReportInstanceEventMessage(Guid.NewGuid(), DateTime.UtcNow, Configuration.InstanceGuid, instanceEvent)); } internal void TransitionState(IInstanceState newState) { diff --git a/Agent/Phantom.Agent.Services/Instances/InstanceLogSender.cs b/Agent/Phantom.Agent.Services/Instances/InstanceLogSender.cs index 7e4bad8..3702416 100644 --- a/Agent/Phantom.Agent.Services/Instances/InstanceLogSender.cs +++ b/Agent/Phantom.Agent.Services/Instances/InstanceLogSender.cs @@ -36,14 +36,14 @@ sealed class InstanceLogSender : CancellableBackgroundTask { try { while (await lineReader.WaitToReadAsync(CancellationToken)) { await Task.Delay(SendDelay, CancellationToken); - await SendOutputToServer(ReadLinesFromChannel(lineReader, lineBuilder)); + SendOutputToServer(ReadLinesFromChannel(lineReader, lineBuilder)); } } catch (OperationCanceledException) { // Ignore. } // Flush remaining lines. - await SendOutputToServer(ReadLinesFromChannel(lineReader, lineBuilder)); + SendOutputToServer(ReadLinesFromChannel(lineReader, lineBuilder)); } private ImmutableArray<string> ReadLinesFromChannel(ChannelReader<string> reader, ImmutableArray<string>.Builder builder) { @@ -61,9 +61,9 @@ sealed class InstanceLogSender : CancellableBackgroundTask { return builder.ToImmutable(); } - private async Task SendOutputToServer(ImmutableArray<string> lines) { + private void SendOutputToServer(ImmutableArray<string> lines) { if (!lines.IsEmpty) { - await controllerConnection.Send(new InstanceOutputMessage(instanceGuid, lines)); + controllerConnection.Send(new InstanceOutputMessage(instanceGuid, lines)); } } diff --git a/Agent/Phantom.Agent/Program.cs b/Agent/Phantom.Agent/Program.cs index f7f2cba..8fa74ab 100644 --- a/Agent/Phantom.Agent/Program.cs +++ b/Agent/Phantom.Agent/Program.cs @@ -11,7 +11,6 @@ using Phantom.Utils.Logging; using Phantom.Utils.Rpc; using Phantom.Utils.Rpc.Sockets; using Phantom.Utils.Runtime; -using Phantom.Utils.Tasks; const int ProtocolVersion = 1; @@ -52,7 +51,7 @@ try { PhantomLogger.Root.InformationHeading("Launching Phantom Panel agent..."); - var rpcConfiguration = new RpcConfiguration(PhantomLogger.Create("Rpc"), PhantomLogger.Create<TaskManager>("Rpc"), controllerHost, controllerPort, controllerCertificate); + var rpcConfiguration = new RpcConfiguration("Rpc", controllerHost, controllerPort, controllerCertificate); var rpcSocket = RpcClientSocket.Connect(rpcConfiguration, AgentMessageRegistries.Definitions, new RegisterAgentMessage(agentToken, agentInfo)); var agentServices = new AgentServices(agentInfo, folders, new AgentServiceConfiguration(maxConcurrentBackupCompressionTasks), new ControllerConnection(rpcSocket.Connection)); @@ -60,7 +59,7 @@ try { var rpcDisconnectSemaphore = new SemaphoreSlim(0, 1); var rpcMessageListener = new MessageListener(rpcSocket.Connection, agentServices, shutdownCancellationTokenSource); - var rpcTask = RpcClientRuntime.Launch(rpcSocket, agentInfo, rpcMessageListener, rpcDisconnectSemaphore, shutdownCancellationToken); + var rpcTask = RpcClientRuntime.Launch(rpcSocket, rpcMessageListener, rpcDisconnectSemaphore, shutdownCancellationToken); try { await rpcTask.WaitAsync(shutdownCancellationToken); } finally { diff --git a/Common/Phantom.Common.Messages.Agent/BiDirectional/ReplyMessage.cs b/Common/Phantom.Common.Messages.Agent/BiDirectional/ReplyMessage.cs index 99e782b..b24c9c6 100644 --- a/Common/Phantom.Common.Messages.Agent/BiDirectional/ReplyMessage.cs +++ b/Common/Phantom.Common.Messages.Agent/BiDirectional/ReplyMessage.cs @@ -8,6 +8,11 @@ public sealed partial record ReplyMessage( [property: MemoryPackOrder(0)] uint SequenceId, [property: MemoryPackOrder(1)] byte[] SerializedReply ) : IMessageToController, IMessageToAgent, IReply { + private static readonly MessageQueueKey MessageQueueKey = new ("Reply"); + + [MemoryPackIgnore] + public MessageQueueKey QueueKey => MessageQueueKey; + public Task<NoReply> Accept(IMessageToControllerListener listener) { return listener.HandleReply(this); } diff --git a/Common/Phantom.Common.Messages.Agent/IMessageToAgent.cs b/Common/Phantom.Common.Messages.Agent/IMessageToAgent.cs index b38e13e..07ef707 100644 --- a/Common/Phantom.Common.Messages.Agent/IMessageToAgent.cs +++ b/Common/Phantom.Common.Messages.Agent/IMessageToAgent.cs @@ -2,6 +2,11 @@ namespace Phantom.Common.Messages.Agent; -public interface IMessageToAgent<TReply> : IMessage<IMessageToAgentListener, TReply> {} +public interface IMessageToAgent<TReply> : IMessage<IMessageToAgentListener, TReply> { + MessageQueueKey IMessage<IMessageToAgentListener, TReply>.QueueKey => IMessageToAgent.DefaultQueueKey; +} -public interface IMessageToAgent : IMessageToAgent<NoReply> {} +public interface IMessageToAgent : IMessageToAgent<NoReply> { + internal static readonly MessageQueueKey DefaultQueueKey = new ("Agent.Default"); + MessageQueueKey IMessage<IMessageToAgentListener, NoReply>.QueueKey => DefaultQueueKey; +} diff --git a/Common/Phantom.Common.Messages.Agent/IMessageToController.cs b/Common/Phantom.Common.Messages.Agent/IMessageToController.cs index 75889d1..a9bff1d 100644 --- a/Common/Phantom.Common.Messages.Agent/IMessageToController.cs +++ b/Common/Phantom.Common.Messages.Agent/IMessageToController.cs @@ -2,6 +2,11 @@ namespace Phantom.Common.Messages.Agent; -public interface IMessageToController<TReply> : IMessage<IMessageToControllerListener, TReply> {} +public interface IMessageToController<TReply> : IMessage<IMessageToControllerListener, TReply> { + MessageQueueKey IMessage<IMessageToControllerListener, TReply>.QueueKey => IMessageToController.DefaultQueueKey; +} -public interface IMessageToController : IMessageToController<NoReply> {} +public interface IMessageToController : IMessageToController<NoReply> { + internal static readonly MessageQueueKey DefaultQueueKey = new ("Agent.Default"); + MessageQueueKey IMessage<IMessageToControllerListener, NoReply>.QueueKey => DefaultQueueKey; +} diff --git a/Common/Phantom.Common.Messages.Agent/ToController/InstanceOutputMessage.cs b/Common/Phantom.Common.Messages.Agent/ToController/InstanceOutputMessage.cs index 81be51f..424288b 100644 --- a/Common/Phantom.Common.Messages.Agent/ToController/InstanceOutputMessage.cs +++ b/Common/Phantom.Common.Messages.Agent/ToController/InstanceOutputMessage.cs @@ -9,6 +9,11 @@ public sealed partial record InstanceOutputMessage( [property: MemoryPackOrder(0)] Guid InstanceGuid, [property: MemoryPackOrder(1)] ImmutableArray<string> Lines ) : IMessageToController { + private static readonly MessageQueueKey MessageQueueKey = new ("Agent.InstanceOutput"); + + [MemoryPackIgnore] + public MessageQueueKey QueueKey => MessageQueueKey; + public Task<NoReply> Accept(IMessageToControllerListener listener) { return listener.HandleInstanceOutput(this); } diff --git a/Common/Phantom.Common.Messages.Web/IMessageToController.cs b/Common/Phantom.Common.Messages.Web/IMessageToController.cs index cbefc1a..73ef744 100644 --- a/Common/Phantom.Common.Messages.Web/IMessageToController.cs +++ b/Common/Phantom.Common.Messages.Web/IMessageToController.cs @@ -2,6 +2,11 @@ namespace Phantom.Common.Messages.Web; -public interface IMessageToController<TReply> : IMessage<IMessageToControllerListener, TReply> {} +public interface IMessageToController<TReply> : IMessage<IMessageToControllerListener, TReply> { + MessageQueueKey IMessage<IMessageToControllerListener, TReply>.QueueKey => IMessageToController.DefaultQueueKey; +} -public interface IMessageToController : IMessageToController<NoReply> {} +public interface IMessageToController : IMessageToController<NoReply> { + internal static readonly MessageQueueKey DefaultQueueKey = new ("Web.Default"); + MessageQueueKey IMessage<IMessageToControllerListener, NoReply>.QueueKey => DefaultQueueKey; +} diff --git a/Common/Phantom.Common.Messages.Web/IMessageToWeb.cs b/Common/Phantom.Common.Messages.Web/IMessageToWeb.cs index d26640b..e7b95ce 100644 --- a/Common/Phantom.Common.Messages.Web/IMessageToWeb.cs +++ b/Common/Phantom.Common.Messages.Web/IMessageToWeb.cs @@ -2,6 +2,11 @@ namespace Phantom.Common.Messages.Web; -public interface IMessageToWeb<TReply> : IMessage<IMessageToWebListener, TReply> {} +public interface IMessageToWeb<TReply> : IMessage<IMessageToWebListener, TReply> { + MessageQueueKey IMessage<IMessageToWebListener, TReply>.QueueKey => IMessageToWeb.DefaultQueueKey; +} -public interface IMessageToWeb : IMessageToWeb<NoReply> {} +public interface IMessageToWeb : IMessageToWeb<NoReply> { + internal static readonly MessageQueueKey DefaultQueueKey = new ("Web.Default"); + MessageQueueKey IMessage<IMessageToWebListener, NoReply>.QueueKey => DefaultQueueKey; +} diff --git a/Common/Phantom.Common.Messages.Web/ToWeb/InstanceOutputMessage.cs b/Common/Phantom.Common.Messages.Web/ToWeb/InstanceOutputMessage.cs index 2e7cd19..5a32d72 100644 --- a/Common/Phantom.Common.Messages.Web/ToWeb/InstanceOutputMessage.cs +++ b/Common/Phantom.Common.Messages.Web/ToWeb/InstanceOutputMessage.cs @@ -9,6 +9,11 @@ public sealed partial record InstanceOutputMessage( [property: MemoryPackOrder(0)] Guid InstanceGuid, [property: MemoryPackOrder(1)] ImmutableArray<string> Lines ) : IMessageToWeb { + private static readonly MessageQueueKey MessageQueueKey = new ("Web.InstanceOutput"); + + [MemoryPackIgnore] + public MessageQueueKey QueueKey => MessageQueueKey; + public Task<NoReply> Accept(IMessageToWebListener listener) { return listener.HandleInstanceOutput(this); } diff --git a/Controller/Phantom.Controller.Services/Agents/AgentConnection.cs b/Controller/Phantom.Controller.Services/Agents/AgentConnection.cs index c198313..aabf20e 100644 --- a/Controller/Phantom.Controller.Services/Agents/AgentConnection.cs +++ b/Controller/Phantom.Controller.Services/Agents/AgentConnection.cs @@ -22,7 +22,7 @@ sealed class AgentConnection { return connection.Send(message); } - public Task<TReply?> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken) where TMessage : IMessageToAgent<TReply> where TReply : class { + public Task<TReply> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken) where TMessage : IMessageToAgent<TReply> where TReply : class { return connection.Send<TMessage, TReply>(message, waitForReplyTime, waitForReplyCancellationToken); } } diff --git a/Controller/Phantom.Controller.Services/Agents/AgentManager.cs b/Controller/Phantom.Controller.Services/Agents/AgentManager.cs index 7adb181..f373902 100644 --- a/Controller/Phantom.Controller.Services/Agents/AgentManager.cs +++ b/Controller/Phantom.Controller.Services/Agents/AgentManager.cs @@ -126,12 +126,17 @@ sealed class AgentManager { internal async Task<TReply?> SendMessage<TMessage, TReply>(Guid guid, TMessage message, TimeSpan waitForReplyTime) where TMessage : IMessageToAgent<TReply> where TReply : class { var connection = agents.ByGuid.TryGetValue(guid, out var agent) ? agent.Connection : null; - if (connection == null) { + if (connection == null || agent == null) { // TODO handle missing agent? return null; } - return await connection.Send<TMessage, TReply>(message, waitForReplyTime, cancellationToken); + try { + return await connection.Send<TMessage, TReply>(message, waitForReplyTime, cancellationToken); + } catch (Exception e) { + Logger.Error(e, "Could not send message to agent \"{Name}\" (GUID {Guid}).", agent.Name, agent.Guid); + return null; + } } private sealed class ObservableAgents : ObservableState<ImmutableArray<Agent>> { diff --git a/Controller/Phantom.Controller.Services/Rpc/AgentMessageListener.cs b/Controller/Phantom.Controller.Services/Rpc/AgentMessageListener.cs index 23d639e..43c14d3 100644 --- a/Controller/Phantom.Controller.Services/Rpc/AgentMessageListener.cs +++ b/Controller/Phantom.Controller.Services/Rpc/AgentMessageListener.cs @@ -36,10 +36,11 @@ public sealed class AgentMessageListener : IMessageToControllerListener { public async Task<NoReply> HandleRegisterAgent(RegisterAgentMessage message) { if (agentGuidWaiter.Task.IsCompleted && agentGuidWaiter.Task.Result != message.AgentInfo.Guid) { + connection.SetAuthorizationResult(false); await connection.Send(new RegisterAgentFailureMessage(RegisterAgentFailure.ConnectionAlreadyHasAnAgent)); } else if (await agentManager.RegisterAgent(message.AuthToken, message.AgentInfo, instanceManager, connection)) { - connection.IsAuthorized = true; + connection.SetAuthorizationResult(true); agentGuidWaiter.SetResult(message.AgentInfo.Guid); } diff --git a/Controller/Phantom.Controller.Services/Rpc/WebMessageListener.cs b/Controller/Phantom.Controller.Services/Rpc/WebMessageListener.cs index c7d13d5..d2b3ad9 100644 --- a/Controller/Phantom.Controller.Services/Rpc/WebMessageListener.cs +++ b/Controller/Phantom.Controller.Services/Rpc/WebMessageListener.cs @@ -108,11 +108,12 @@ public sealed class WebMessageListener : IMessageToControllerListener { public async Task<NoReply> HandleRegisterWeb(RegisterWebMessage message) { if (authToken.FixedTimeEquals(message.AuthToken)) { Logger.Information("Web authorized successfully."); - connection.IsAuthorized = true; + connection.SetAuthorizationResult(true); await connection.Send(new RegisterWebResultMessage(true)); } else { Logger.Warning("Web failed to authorize, invalid token."); + connection.SetAuthorizationResult(false); await connection.Send(new RegisterWebResultMessage(false)); } diff --git a/Controller/Phantom.Controller/Program.cs b/Controller/Phantom.Controller/Program.cs index 850ae98..429b8d8 100644 --- a/Controller/Phantom.Controller/Program.cs +++ b/Controller/Phantom.Controller/Program.cs @@ -59,15 +59,17 @@ try { await controllerServices.Initialize(); static RpcConfiguration ConfigureRpc(string serviceName, string host, ushort port, ConnectionKeyData connectionKey) { - return new RpcConfiguration(PhantomLogger.Create("Rpc", serviceName), PhantomLogger.Create<TaskManager>("Rpc", serviceName), host, port, connectionKey.Certificate); + return new RpcConfiguration("Rpc:" + serviceName, host, port, connectionKey.Certificate); } + var rpcTaskManager = new TaskManager(PhantomLogger.Create<TaskManager>("Rpc")); try { await Task.WhenAll( RpcServerRuntime.Launch(ConfigureRpc("Agent", agentRpcServerHost, agentRpcServerPort, agentKeyData), AgentMessageRegistries.Definitions, controllerServices.CreateAgentMessageListener, shutdownCancellationToken), RpcServerRuntime.Launch(ConfigureRpc("Web", webRpcServerHost, webRpcServerPort, webKeyData), WebMessageRegistries.Definitions, controllerServices.CreateWebMessageListener, shutdownCancellationToken) ); } finally { + await rpcTaskManager.Stop(); NetMQConfig.Cleanup(); } diff --git a/Utils/Phantom.Utils.Rpc/Message/IMessage.cs b/Utils/Phantom.Utils.Rpc/Message/IMessage.cs index e8d25e0..8d6532e 100644 --- a/Utils/Phantom.Utils.Rpc/Message/IMessage.cs +++ b/Utils/Phantom.Utils.Rpc/Message/IMessage.cs @@ -1,5 +1,6 @@ namespace Phantom.Utils.Rpc.Message; public interface IMessage<TListener, TReply> { + MessageQueueKey QueueKey { get; } Task<TReply> Accept(TListener listener); } diff --git a/Utils/Phantom.Utils.Rpc/Message/MessageHandler.cs b/Utils/Phantom.Utils.Rpc/Message/MessageHandler.cs index 92e8df7..303ee71 100644 --- a/Utils/Phantom.Utils.Rpc/Message/MessageHandler.cs +++ b/Utils/Phantom.Utils.Rpc/Message/MessageHandler.cs @@ -1,38 +1,41 @@ -using Phantom.Utils.Tasks; +using Phantom.Utils.Logging; using Serilog; namespace Phantom.Utils.Rpc.Message; abstract class MessageHandler<TListener> { + protected ILogger Logger { get; } + private readonly TListener listener; - private readonly ILogger logger; - private readonly TaskManager taskManager; - private readonly CancellationToken cancellationToken; + private readonly MessageQueues messageQueues; - protected MessageHandler(TListener listener, ILogger logger, TaskManager taskManager, CancellationToken cancellationToken) { + protected MessageHandler(string loggerName, TListener listener) { + this.Logger = PhantomLogger.Create("MessageHandler", loggerName); this.listener = listener; - this.logger = logger; - this.taskManager = taskManager; - this.cancellationToken = cancellationToken; + this.messageQueues = new MessageQueues(loggerName + ":Receive"); } internal void Enqueue<TMessage, TReply>(uint sequenceId, TMessage message) where TMessage : IMessage<TListener, TReply> { - cancellationToken.ThrowIfCancellationRequested(); - taskManager.Run("Handle message " + message.GetType().Name, async () => { - try { - await Handle<TMessage, TReply>(sequenceId, message); - } catch (Exception e) { - logger.Error(e, "Failed to handle message {Type}.", message.GetType().Name); - } - }); + messageQueues.Enqueue(message.QueueKey, () => TryHandle<TMessage, TReply>(sequenceId, message)); } - private async Task Handle<TMessage, TReply>(uint sequenceId, TMessage message) where TMessage : IMessage<TListener, TReply> { - TReply reply = await message.Accept(listener); + private async Task TryHandle<TMessage, TReply>(uint sequenceId, TMessage message) where TMessage : IMessage<TListener, TReply> { + TReply reply; + try { + reply = await message.Accept(listener); + } catch (Exception e) { + Logger.Error(e, "Failed to handle message {Type}.", message.GetType().Name); + return; + } + if (reply is not NoReply) { await SendReply(sequenceId, MessageSerializer.Serialize(reply)); } } protected abstract Task SendReply(uint sequenceId, byte[] serializedReply); + + internal Task StopReceiving() { + return messageQueues.Stop(); + } } diff --git a/Utils/Phantom.Utils.Rpc/Message/MessageQueueKey.cs b/Utils/Phantom.Utils.Rpc/Message/MessageQueueKey.cs new file mode 100644 index 0000000..397e5c2 --- /dev/null +++ b/Utils/Phantom.Utils.Rpc/Message/MessageQueueKey.cs @@ -0,0 +1,9 @@ +namespace Phantom.Utils.Rpc.Message; + +public sealed class MessageQueueKey { + public string Name { get; } + + public MessageQueueKey(string name) { + Name = name; + } +} diff --git a/Utils/Phantom.Utils.Rpc/Message/MessageQueues.cs b/Utils/Phantom.Utils.Rpc/Message/MessageQueues.cs new file mode 100644 index 0000000..c43e1fd --- /dev/null +++ b/Utils/Phantom.Utils.Rpc/Message/MessageQueues.cs @@ -0,0 +1,53 @@ +using Phantom.Utils.Logging; +using Phantom.Utils.Tasks; +using Serilog; + +namespace Phantom.Utils.Rpc.Message; + +sealed class MessageQueues { + private readonly ILogger logger; + private readonly TaskManager taskManager; + private readonly Dictionary<MessageQueueKey, RpcQueue> queues = new (); + + private Task? stopTask; + + public MessageQueues(string loggerName) { + this.logger = PhantomLogger.Create<MessageQueues>(loggerName); + this.taskManager = new TaskManager(PhantomLogger.Create<TaskManager>(loggerName)); + } + + private RpcQueue GetOrCreateQueue(MessageQueueKey key) { + if (!queues.TryGetValue(key, out var queue)) { + queues[key] = queue = new RpcQueue(taskManager, "Message queue for " + key.Name); + } + + return queue; + } + + public Task Enqueue(MessageQueueKey key, Func<Task> task) { + lock (this) { + return stopTask == null ? GetOrCreateQueue(key).Enqueue(task) : Task.FromException(new OperationCanceledException()); + } + } + + public Task<T> Enqueue<T>(MessageQueueKey key, Func<Task<T>> task) { + lock (this) { + return stopTask == null ? GetOrCreateQueue(key).Enqueue(task) : Task.FromException<T>(new OperationCanceledException()); + } + } + + internal Task Stop() { + lock (this) { + if (stopTask == null) { + logger.Debug("Stopping " + queues.Count + " message queue(s)..."); + + stopTask = Task.WhenAll(queues.Values.Select(static queue => queue.Stop())) + .ContinueWith(_ => logger.Debug("All queues stopped.")); + + queues.Clear(); + } + + return stopTask; + } + } +} diff --git a/Utils/Phantom.Utils.Rpc/Message/MessageReplyTracker.cs b/Utils/Phantom.Utils.Rpc/Message/MessageReplyTracker.cs index 47e6d6e..bddd6d8 100644 --- a/Utils/Phantom.Utils.Rpc/Message/MessageReplyTracker.cs +++ b/Utils/Phantom.Utils.Rpc/Message/MessageReplyTracker.cs @@ -1,4 +1,5 @@ using System.Collections.Concurrent; +using Phantom.Utils.Logging; using Phantom.Utils.Tasks; using Serilog; @@ -10,8 +11,8 @@ sealed class MessageReplyTracker { private uint lastSequenceId; - internal MessageReplyTracker(ILogger logger) { - this.logger = logger; + internal MessageReplyTracker(string loggerName) { + this.logger = PhantomLogger.Create<MessageReplyTracker>(loggerName); } public uint RegisterReply() { @@ -42,14 +43,6 @@ sealed class MessageReplyTracker { ForgetReply(sequenceId); } } - - public async Task<TReply?> TryWaitForReply<TReply>(uint sequenceId, TimeSpan waitForReplyTime, CancellationToken cancellationToken) where TReply : class { - try { - return await WaitForReply<TReply>(sequenceId, waitForReplyTime, cancellationToken); - } catch (Exception) { - return null; - } - } public void ForgetReply(uint sequenceId) { if (replyTasks.TryRemove(sequenceId, out var task)) { diff --git a/Utils/Phantom.Utils.Rpc/Phantom.Utils.Rpc.csproj b/Utils/Phantom.Utils.Rpc/Phantom.Utils.Rpc.csproj index 523ab4e..5e70e64 100644 --- a/Utils/Phantom.Utils.Rpc/Phantom.Utils.Rpc.csproj +++ b/Utils/Phantom.Utils.Rpc/Phantom.Utils.Rpc.csproj @@ -13,6 +13,7 @@ <ItemGroup> <ProjectReference Include="..\Phantom.Utils\Phantom.Utils.csproj" /> + <ProjectReference Include="..\Phantom.Utils.Logging\Phantom.Utils.Logging.csproj" /> </ItemGroup> </Project> diff --git a/Utils/Phantom.Utils.Rpc/RpcConfiguration.cs b/Utils/Phantom.Utils.Rpc/RpcConfiguration.cs index 5ecda16..f730671 100644 --- a/Utils/Phantom.Utils.Rpc/RpcConfiguration.cs +++ b/Utils/Phantom.Utils.Rpc/RpcConfiguration.cs @@ -1,8 +1,7 @@ using NetMQ; -using Serilog; namespace Phantom.Utils.Rpc; -public sealed record RpcConfiguration(ILogger RuntimeLogger, ILogger TaskManagerLogger, string Host, ushort Port, NetMQCertificate ServerCertificate) { +public sealed record RpcConfiguration(string LoggerName, string Host, ushort Port, NetMQCertificate ServerCertificate) { public string TcpUrl => "tcp://" + Host + ":" + Port; } diff --git a/Utils/Phantom.Utils.Rpc/RpcExtensions.cs b/Utils/Phantom.Utils.Rpc/RpcExtensions.cs index 76eb393..dc8e6e1 100644 --- a/Utils/Phantom.Utils.Rpc/RpcExtensions.cs +++ b/Utils/Phantom.Utils.Rpc/RpcExtensions.cs @@ -3,7 +3,7 @@ using NetMQ.Sockets; namespace Phantom.Utils.Rpc; -public static class RpcExtensions { +static class RpcExtensions { public static ReadOnlyMemory<byte> Receive(this ClientSocket socket, CancellationToken cancellationToken) { var msg = new Msg(); msg.InitEmpty(); diff --git a/Utils/Phantom.Utils.Rpc/RpcQueue.cs b/Utils/Phantom.Utils.Rpc/RpcQueue.cs new file mode 100644 index 0000000..d27c7e7 --- /dev/null +++ b/Utils/Phantom.Utils.Rpc/RpcQueue.cs @@ -0,0 +1,60 @@ +using System.Threading.Channels; +using Phantom.Utils.Tasks; + +namespace Phantom.Utils.Rpc; + +sealed class RpcQueue { + private readonly Channel<Func<Task>> channel = Channel.CreateUnbounded<Func<Task>>(new UnboundedChannelOptions { + SingleReader = true, + SingleWriter = false, + AllowSynchronousContinuations = false + }); + + private readonly Task processingTask; + + public RpcQueue(TaskManager taskManager, string taskName) { + this.processingTask = taskManager.Run(taskName, Process); + } + + public Task Enqueue(Action action) { + return Enqueue(() => { + action(); + return Task.CompletedTask; + }); + } + + public Task Enqueue(Func<Task> task) { + var completionSource = AsyncTasks.CreateCompletionSource(); + + if (!channel.Writer.TryWrite(() => task().ContinueWith(t => completionSource.SetResultFrom(t)))) { + completionSource.SetCanceled(); + } + + return completionSource.Task; + } + + public Task<T> Enqueue<T>(Func<Task<T>> task) { + var completionSource = AsyncTasks.CreateCompletionSource<T>(); + + if (!channel.Writer.TryWrite(() => task().ContinueWith(t => completionSource.SetResultFrom(t)))) { + completionSource.SetCanceled(); + } + + return completionSource.Task; + } + + private async Task Process() { + try { + await foreach (var task in channel.Reader.ReadAllAsync()) { + await task(); + } + } catch (OperationCanceledException) { + // Ignore. + } + } + + public Task Stop() { + channel.Writer.Complete(); + return processingTask; + } +} diff --git a/Utils/Phantom.Utils.Rpc/Runtime/RpcClientRuntime.cs b/Utils/Phantom.Utils.Rpc/Runtime/RpcClientRuntime.cs index 872851e..3de2b11 100644 --- a/Utils/Phantom.Utils.Rpc/Runtime/RpcClientRuntime.cs +++ b/Utils/Phantom.Utils.Rpc/Runtime/RpcClientRuntime.cs @@ -1,7 +1,6 @@ using NetMQ.Sockets; using Phantom.Utils.Rpc.Message; using Phantom.Utils.Rpc.Sockets; -using Phantom.Utils.Tasks; using Serilog; using Serilog.Events; @@ -23,18 +22,18 @@ public abstract class RpcClientRuntime<TClientListener, TServerListener, TReplyM this.receiveCancellationToken = receiveCancellationToken; } - private protected sealed override void Run(ClientSocket socket, ILogger logger, MessageReplyTracker replyTracker, TaskManager taskManager) { - RunWithConnection(socket, connection, logger, taskManager); + private protected sealed override Task Run(ClientSocket socket) { + return RunWithConnection(socket, connection); } - protected virtual void RunWithConnection(ClientSocket socket, RpcConnectionToServer<TServerListener> connection, ILogger logger, TaskManager taskManager) { - var handler = new Handler(connection, messageDefinitions, messageListener, logger, taskManager, receiveCancellationToken); + protected virtual async Task RunWithConnection(ClientSocket socket, RpcConnectionToServer<TServerListener> connection) { + var handler = new Handler(LoggerName, connection, messageDefinitions, messageListener); try { while (!receiveCancellationToken.IsCancellationRequested) { var data = socket.Receive(receiveCancellationToken); - LogMessageType(logger, data); + LogMessageType(RuntimeLogger, data); if (data.Length > 0) { messageDefinitions.ToClient.Handle(data, handler); @@ -43,11 +42,25 @@ public abstract class RpcClientRuntime<TClientListener, TServerListener, TReplyM } catch (OperationCanceledException) { // Ignore. } finally { - logger.Debug("ZeroMQ client stopped receiving messages."); - disconnectSemaphore.Wait(CancellationToken.None); + await handler.StopReceiving(); + RuntimeLogger.Debug("ZeroMQ client stopped receiving messages."); + + await disconnectSemaphore.WaitAsync(CancellationToken.None); } } + private protected sealed override async Task Disconnect(ClientSocket socket) { + try { + await connection.StopSending().WaitAsync(TimeSpan.FromSeconds(10), CancellationToken.None); + } catch (TimeoutException) { + RuntimeLogger.Error("Timed out waiting for message sending queue."); + } + + await SendDisconnectMessage(socket, RuntimeLogger); + } + + protected abstract Task SendDisconnectMessage(ClientSocket socket, ILogger logger); + private void LogMessageType(ILogger logger, ReadOnlyMemory<byte> data) { if (!logger.IsEnabled(LogEventLevel.Verbose)) { return; @@ -65,7 +78,7 @@ public abstract class RpcClientRuntime<TClientListener, TServerListener, TReplyM private readonly RpcConnectionToServer<TServerListener> connection; private readonly IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions; - public Handler(RpcConnectionToServer<TServerListener> connection, IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions, TClientListener listener, ILogger logger, TaskManager taskManager, CancellationToken cancellationToken) : base(listener, logger, taskManager, cancellationToken) { + public Handler(string loggerName, RpcConnectionToServer<TServerListener> connection, IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions, TClientListener listener) : base(loggerName, listener) { this.connection = connection; this.messageDefinitions = messageDefinitions; } diff --git a/Utils/Phantom.Utils.Rpc/Runtime/RpcConnection.cs b/Utils/Phantom.Utils.Rpc/Runtime/RpcConnection.cs new file mode 100644 index 0000000..b4a0fd5 --- /dev/null +++ b/Utils/Phantom.Utils.Rpc/Runtime/RpcConnection.cs @@ -0,0 +1,53 @@ +using Phantom.Utils.Rpc.Message; + +namespace Phantom.Utils.Rpc.Runtime; + +public abstract class RpcConnection<TListener> { + private readonly MessageRegistry<TListener> messageRegistry; + private readonly MessageQueues sendingQueues; + private readonly MessageReplyTracker replyTracker; + + internal RpcConnection(string loggerName, MessageRegistry<TListener> messageRegistry, MessageReplyTracker replyTracker) { + this.messageRegistry = messageRegistry; + this.sendingQueues = new MessageQueues(loggerName + ":Send"); + this.replyTracker = replyTracker; + } + + private protected abstract ValueTask Send(byte[] bytes); + + public Task Send<TMessage>(TMessage message) where TMessage : IMessage<TListener, NoReply> { + return sendingQueues.Enqueue(message.QueueKey, () => SendTask(message)); + } + + public Task<TReply> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken) where TMessage : IMessage<TListener, TReply> { + return sendingQueues.Enqueue(message.QueueKey, () => SendTask<TMessage, TReply>(message, waitForReplyTime, waitForReplyCancellationToken)); + } + + private async Task SendTask<TMessage>(TMessage message) where TMessage : IMessage<TListener, NoReply> { + var bytes = messageRegistry.Write(message).ToArray(); + if (bytes.Length > 0) { + await Send(bytes); + } + } + + private async Task<TReply> SendTask<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken) where TMessage : IMessage<TListener, TReply> { + var sequenceId = replyTracker.RegisterReply(); + + var bytes = messageRegistry.Write<TMessage, TReply>(sequenceId, message).ToArray(); + if (bytes.Length == 0) { + replyTracker.ForgetReply(sequenceId); + throw new ArgumentException("Could not write message.", nameof(message)); + } + + await Send(bytes); + return await replyTracker.WaitForReply<TReply>(sequenceId, waitForReplyTime, waitForReplyCancellationToken); + } + + public void Receive(IReply message) { + replyTracker.ReceiveReply(message.SequenceId, message.SerializedReply); + } + + internal Task StopSending() { + return sendingQueues.Stop(); + } +} diff --git a/Utils/Phantom.Utils.Rpc/Runtime/RpcConnectionToClient.cs b/Utils/Phantom.Utils.Rpc/Runtime/RpcConnectionToClient.cs index 2f9f2f4..ffa6123 100644 --- a/Utils/Phantom.Utils.Rpc/Runtime/RpcConnectionToClient.cs +++ b/Utils/Phantom.Utils.Rpc/Runtime/RpcConnectionToClient.cs @@ -4,28 +4,26 @@ using Phantom.Utils.Rpc.Message; namespace Phantom.Utils.Rpc.Runtime; -public sealed class RpcConnectionToClient<TListener> { +public sealed class RpcConnectionToClient<TListener> : RpcConnection<TListener> { private readonly ServerSocket socket; private readonly uint routingId; - private readonly MessageRegistry<TListener> messageRegistry; - private readonly MessageReplyTracker messageReplyTracker; - - private volatile bool isAuthorized; - - public bool IsAuthorized { - get => isAuthorized; - set => isAuthorized = value; - } + private readonly TaskCompletionSource<bool> authorizationCompletionSource = new (); internal event EventHandler<RpcClientConnectionClosedEventArgs>? Closed; public bool IsClosed { get; private set; } - internal RpcConnectionToClient(ServerSocket socket, uint routingId, MessageRegistry<TListener> messageRegistry, MessageReplyTracker messageReplyTracker) { + internal RpcConnectionToClient(string loggerName, ServerSocket socket, uint routingId, MessageRegistry<TListener> messageRegistry, MessageReplyTracker replyTracker) : base(loggerName, messageRegistry, replyTracker) { this.socket = socket; this.routingId = routingId; - this.messageRegistry = messageRegistry; - this.messageReplyTracker = messageReplyTracker; + } + + internal Task<bool> GetAuthorization() { + return authorizationCompletionSource.Task; + } + + public void SetAuthorizationResult(bool isAuthorized) { + authorizationCompletionSource.SetResult(isAuthorized); } public bool IsSame(RpcConnectionToClient<TListener> other) { @@ -47,35 +45,7 @@ public sealed class RpcConnectionToClient<TListener> { } } - public async Task Send<TMessage>(TMessage message) where TMessage : IMessage<TListener, NoReply> { - if (IsClosed) { - return; - } - - var bytes = messageRegistry.Write(message).ToArray(); - if (bytes.Length > 0) { - await socket.SendAsync(routingId, bytes); - } - } - - public async Task<TReply?> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken) where TMessage : IMessage<TListener, TReply> where TReply : class { - if (IsClosed) { - return null; - } - - var sequenceId = messageReplyTracker.RegisterReply(); - - var bytes = messageRegistry.Write<TMessage, TReply>(sequenceId, message).ToArray(); - if (bytes.Length == 0) { - messageReplyTracker.ForgetReply(sequenceId); - return null; - } - - await socket.SendAsync(routingId, bytes); - return await messageReplyTracker.TryWaitForReply<TReply>(sequenceId, waitForReplyTime, waitForReplyCancellationToken); - } - - public void Receive(IReply message) { - messageReplyTracker.ReceiveReply(message.SequenceId, message.SerializedReply); + private protected override ValueTask Send(byte[] bytes) { + return socket.SendAsync(routingId, bytes); } } diff --git a/Utils/Phantom.Utils.Rpc/Runtime/RpcConnectionToServer.cs b/Utils/Phantom.Utils.Rpc/Runtime/RpcConnectionToServer.cs index 9421047..8a771cf 100644 --- a/Utils/Phantom.Utils.Rpc/Runtime/RpcConnectionToServer.cs +++ b/Utils/Phantom.Utils.Rpc/Runtime/RpcConnectionToServer.cs @@ -4,51 +4,14 @@ using Phantom.Utils.Rpc.Message; namespace Phantom.Utils.Rpc.Runtime; -public sealed class RpcConnectionToServer<TListener> { +public sealed class RpcConnectionToServer<TListener> : RpcConnection<TListener> { private readonly ClientSocket socket; - private readonly MessageRegistry<TListener> messageRegistry; - private readonly MessageReplyTracker replyTracker; - internal RpcConnectionToServer(ClientSocket socket, MessageRegistry<TListener> messageRegistry, MessageReplyTracker replyTracker) { + internal RpcConnectionToServer(string loggerName, ClientSocket socket, MessageRegistry<TListener> messageRegistry, MessageReplyTracker replyTracker) : base(loggerName, messageRegistry, replyTracker) { this.socket = socket; - this.messageRegistry = messageRegistry; - this.replyTracker = replyTracker; } - public async Task Send<TMessage>(TMessage message) where TMessage : IMessage<TListener, NoReply> { - var bytes = messageRegistry.Write(message).ToArray(); - if (bytes.Length > 0) { - await socket.SendAsync(bytes); - } - } - - public async Task<TReply?> TrySend<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken) where TMessage : IMessage<TListener, TReply> where TReply : class { - var sequenceId = replyTracker.RegisterReply(); - - var bytes = messageRegistry.Write<TMessage, TReply>(sequenceId, message).ToArray(); - if (bytes.Length == 0) { - replyTracker.ForgetReply(sequenceId); - return null; - } - - await socket.SendAsync(bytes); - return await replyTracker.TryWaitForReply<TReply>(sequenceId, waitForReplyTime, waitForReplyCancellationToken); - } - - public async Task<TReply> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken) where TMessage : IMessage<TListener, TReply> { - var sequenceId = replyTracker.RegisterReply(); - - var bytes = messageRegistry.Write<TMessage, TReply>(sequenceId, message).ToArray(); - if (bytes.Length == 0) { - replyTracker.ForgetReply(sequenceId); - throw new ArgumentException("Could not write message.", nameof(message)); - } - - await socket.SendAsync(bytes); - return await replyTracker.WaitForReply<TReply>(sequenceId, waitForReplyTime, waitForReplyCancellationToken); - } - - public void Receive(IReply message) { - replyTracker.ReceiveReply(message.SequenceId, message.SerializedReply); + private protected override ValueTask Send(byte[] bytes) { + return socket.SendAsync(bytes); } } diff --git a/Utils/Phantom.Utils.Rpc/Runtime/RpcRuntime.cs b/Utils/Phantom.Utils.Rpc/Runtime/RpcRuntime.cs index 49e73eb..319a215 100644 --- a/Utils/Phantom.Utils.Rpc/Runtime/RpcRuntime.cs +++ b/Utils/Phantom.Utils.Rpc/Runtime/RpcRuntime.cs @@ -1,49 +1,50 @@ -using NetMQ; +using System.Diagnostics.CodeAnalysis; +using NetMQ; +using Phantom.Utils.Logging; using Phantom.Utils.Rpc.Message; using Phantom.Utils.Rpc.Sockets; -using Phantom.Utils.Tasks; using Serilog; namespace Phantom.Utils.Rpc.Runtime; public abstract class RpcRuntime<TSocket> where TSocket : ThreadSafeSocket { private readonly TSocket socket; - private readonly ILogger runtimeLogger; - private readonly MessageReplyTracker replyTracker; - private readonly TaskManager taskManager; + private protected string LoggerName { get; } + private protected ILogger RuntimeLogger { get; } + private protected MessageReplyTracker ReplyTracker { get; } + protected RpcRuntime(RpcSocket<TSocket> socket) { this.socket = socket.Socket; - this.runtimeLogger = socket.Config.RuntimeLogger; - this.replyTracker = socket.ReplyTracker; - this.taskManager = new TaskManager(socket.Config.TaskManagerLogger); + + this.LoggerName = socket.Config.LoggerName; + this.RuntimeLogger = PhantomLogger.Create(LoggerName); + this.ReplyTracker = socket.ReplyTracker; } protected async Task Launch() { - void RunTask() { + [SuppressMessage("ReSharper", "AccessToDisposedClosure")] + async Task RunTask() { try { - Run(socket, runtimeLogger, replyTracker, taskManager); + await Run(socket); } catch (Exception e) { - runtimeLogger.Error(e, "Caught exception in RPC thread."); + RuntimeLogger.Error(e, "Caught exception in RPC thread."); } } try { - await Task.Factory.StartNew(RunTask, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default); + await Task.Factory.StartNew(RunTask, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default).Unwrap(); } catch (OperationCanceledException) { // Ignore. } finally { - await taskManager.Stop(); - await Disconnect(socket, runtimeLogger); + await Disconnect(socket); socket.Dispose(); - runtimeLogger.Information("ZeroMQ runtime stopped."); + RuntimeLogger.Information("ZeroMQ runtime stopped."); } } - private protected abstract void Run(TSocket socket, ILogger logger, MessageReplyTracker replyTracker, TaskManager taskManager); + private protected abstract Task Run(TSocket socket); - protected virtual Task Disconnect(TSocket socket, ILogger logger) { - return Task.CompletedTask; - } + private protected abstract Task Disconnect(TSocket socket); } diff --git a/Utils/Phantom.Utils.Rpc/Runtime/RpcServerRuntime.cs b/Utils/Phantom.Utils.Rpc/Runtime/RpcServerRuntime.cs index e7b4bef..ce5f643 100644 --- a/Utils/Phantom.Utils.Rpc/Runtime/RpcServerRuntime.cs +++ b/Utils/Phantom.Utils.Rpc/Runtime/RpcServerRuntime.cs @@ -1,8 +1,9 @@ -using NetMQ.Sockets; +using System.Collections.Concurrent; +using NetMQ.Sockets; +using Phantom.Utils.Logging; using Phantom.Utils.Rpc.Message; using Phantom.Utils.Rpc.Sockets; using Phantom.Utils.Tasks; -using Serilog; using Serilog.Events; namespace Phantom.Utils.Rpc.Runtime; @@ -21,91 +22,124 @@ internal sealed class RpcServerRuntime<TClientListener, TServerListener, TReplyM private readonly IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions; private readonly Func<RpcConnectionToClient<TClientListener>, TServerListener> listenerFactory; + private readonly TaskManager taskManager; private readonly CancellationToken cancellationToken; private RpcServerRuntime(RpcServerSocket socket, IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions, Func<RpcConnectionToClient<TClientListener>, TServerListener> listenerFactory, CancellationToken cancellationToken) : base(socket) { this.messageDefinitions = messageDefinitions; this.listenerFactory = listenerFactory; + this.taskManager = new TaskManager(PhantomLogger.Create<TaskManager>(socket.Config.LoggerName + ":Runtime")); this.cancellationToken = cancellationToken; } - private protected override void Run(ServerSocket socket, ILogger logger, MessageReplyTracker replyTracker, TaskManager taskManager) { - var clients = new Dictionary<ulong, Client>(); + private protected override Task Run(ServerSocket socket) { + var clients = new ConcurrentDictionary<ulong, Client>(); void OnConnectionClosed(object? sender, RpcClientConnectionClosedEventArgs e) { - clients.Remove(e.RoutingId); - logger.Debug("Closed connection to {RoutingId}.", e.RoutingId); + if (!clients.Remove(e.RoutingId, out var client)) { + return; + } + + RuntimeLogger.Debug("Closing connection to {RoutingId}.", e.RoutingId); + client.Connection.Closed -= OnConnectionClosed; + + taskManager.Run("Closing connection to " + e.RoutingId, async () => { + await client.StopReceiving(); + await client.StopProcessing(); + await client.Connection.StopSending(); + RuntimeLogger.Debug("Closed connection to {RoutingId}.", e.RoutingId); + }); } while (!cancellationToken.IsCancellationRequested) { var (routingId, data) = socket.Receive(cancellationToken); if (data.Length == 0) { - LogMessageType(logger, routingId, data, messageType: null); + LogMessageType(routingId, data, messageType: null); continue; } Type? messageType = messageDefinitions.ToServer.TryGetType(data, out var type) ? type : null; if (!clients.TryGetValue(routingId, out var client)) { - if (!CheckIsRegistrationMessage(messageType, logger, routingId)) { - continue; - } - - var connection = new RpcConnectionToClient<TClientListener>(socket, routingId, messageDefinitions.ToClient, replyTracker); + var clientLoggerName = LoggerName + ":" + routingId; + var processingQueue = new RpcQueue(taskManager, "Process messages from " + routingId); + var connection = new RpcConnectionToClient<TClientListener>(clientLoggerName, socket, routingId, messageDefinitions.ToClient, ReplyTracker); + connection.Closed += OnConnectionClosed; - client = new Client(connection, messageDefinitions, listenerFactory(connection), logger, taskManager, cancellationToken); + client = new Client(clientLoggerName, connection, processingQueue, messageDefinitions, listenerFactory(connection)); clients[routingId] = client; } - if (!client.Connection.IsAuthorized && !CheckIsRegistrationMessage(messageType, logger, routingId)) { - continue; - } - - LogMessageType(logger, routingId, data, messageType); - messageDefinitions.ToServer.Handle(data, client); + LogMessageType(routingId, data, messageType); + client.Enqueue(messageType, data); } foreach (var client in clients.Values) { - client.Connection.Closed -= OnConnectionClosed; + client.Connection.Close(); } + + return Task.CompletedTask; } - private void LogMessageType(ILogger logger, uint routingId, ReadOnlyMemory<byte> data, Type? messageType) { - if (!logger.IsEnabled(LogEventLevel.Verbose)) { + private protected override Task Disconnect(ServerSocket socket) { + return Task.CompletedTask; + } + + private void LogMessageType(uint routingId, ReadOnlyMemory<byte> data, Type? messageType) { + if (!RuntimeLogger.IsEnabled(LogEventLevel.Verbose)) { return; } if (data.Length > 0 && messageType != null) { - logger.Verbose("Received {MessageType} ({Bytes} B) from {RoutingId}.", messageType.Name, data.Length, routingId); + RuntimeLogger.Verbose("Received {MessageType} ({Bytes} B) from {RoutingId}.", messageType.Name, data.Length, routingId); } else { - logger.Verbose("Received {Bytes} B message from {RoutingId}.", data.Length, routingId); + RuntimeLogger.Verbose("Received {Bytes} B message from {RoutingId}.", data.Length, routingId); } } - private bool CheckIsRegistrationMessage(Type? messageType, ILogger logger, uint routingId) { - if (messageType != null && messageDefinitions.IsRegistrationMessage(messageType)) { - return true; - } - - logger.Warning("Received {MessageType} from {RoutingId} who is not registered.", messageType?.Name ?? "unknown message", routingId); - return false; - } - private sealed class Client : MessageHandler<TServerListener> { public RpcConnectionToClient<TClientListener> Connection { get; } + private readonly RpcQueue processingQueue; private readonly IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions; - - public Client(RpcConnectionToClient<TClientListener> connection, IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions, TServerListener listener, ILogger logger, TaskManager taskManager, CancellationToken cancellationToken) : base(listener, logger, taskManager, cancellationToken) { + + public Client(string loggerName, RpcConnectionToClient<TClientListener> connection, RpcQueue processingQueue, IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions, TServerListener listener) : base(loggerName, listener) { this.Connection = connection; + this.processingQueue = processingQueue; this.messageDefinitions = messageDefinitions; } + internal void Enqueue(Type? messageType, ReadOnlyMemory<byte> data) { + if (!Connection.GetAuthorization().IsCompleted && messageType != null && messageDefinitions.IsRegistrationMessage(messageType)) { + processingQueue.Enqueue(() => Handle(data)); + } + else { + processingQueue.Enqueue(() => WaitForAuthorizationAndHandle(data)); + } + } + + private void Handle(ReadOnlyMemory<byte> data) { + messageDefinitions.ToServer.Handle(data, this); + } + + private async Task WaitForAuthorizationAndHandle(ReadOnlyMemory<byte> data) { + if (await Connection.GetAuthorization()) { + Handle(data); + } + else { + Logger.Warning("Dropped message after failed registration."); + } + } + protected override Task SendReply(uint sequenceId, byte[] serializedReply) { return Connection.Send(messageDefinitions.CreateReplyMessage(sequenceId, serializedReply)); } + + internal Task StopProcessing() { + return processingQueue.Stop(); + } } } diff --git a/Utils/Phantom.Utils.Rpc/Sockets/RpcClientSocket.cs b/Utils/Phantom.Utils.Rpc/Sockets/RpcClientSocket.cs index 13e9279..bf25889 100644 --- a/Utils/Phantom.Utils.Rpc/Sockets/RpcClientSocket.cs +++ b/Utils/Phantom.Utils.Rpc/Sockets/RpcClientSocket.cs @@ -1,5 +1,6 @@ using NetMQ; using NetMQ.Sockets; +using Phantom.Utils.Logging; using Phantom.Utils.Rpc.Message; using Phantom.Utils.Rpc.Runtime; @@ -22,20 +23,20 @@ public sealed class RpcClientSocket<TClientListener, TServerListener, TReplyMess RpcSocket.SetDefaultSocketOptions(options); var url = config.TcpUrl; - var logger = config.RuntimeLogger; - + var logger = PhantomLogger.Create(config.LoggerName); + logger.Information("Starting ZeroMQ client and connecting to {Url}...", url); socket.Connect(url); logger.Information("ZeroMQ client ready."); - + return new RpcClientSocket<TClientListener, TServerListener, TReplyMessage>(socket, config, messageDefinitions); } public RpcConnectionToServer<TServerListener> Connection { get; } internal IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> MessageDefinitions { get; } - + private RpcClientSocket(ClientSocket socket, RpcConfiguration config, IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions) : base(socket, config) { MessageDefinitions = messageDefinitions; - Connection = new RpcConnectionToServer<TServerListener>(socket, messageDefinitions.ToServer, ReplyTracker); + Connection = new RpcConnectionToServer<TServerListener>(config.LoggerName, socket, messageDefinitions.ToServer, ReplyTracker); } } diff --git a/Utils/Phantom.Utils.Rpc/Sockets/RpcServerSocket.cs b/Utils/Phantom.Utils.Rpc/Sockets/RpcServerSocket.cs index 6be05db..c4b9ced 100644 --- a/Utils/Phantom.Utils.Rpc/Sockets/RpcServerSocket.cs +++ b/Utils/Phantom.Utils.Rpc/Sockets/RpcServerSocket.cs @@ -1,4 +1,5 @@ using NetMQ.Sockets; +using Phantom.Utils.Logging; namespace Phantom.Utils.Rpc.Sockets; @@ -12,7 +13,7 @@ public sealed class RpcServerSocket : RpcSocket<ServerSocket> { RpcSocket.SetDefaultSocketOptions(options); var url = config.TcpUrl; - var logger = config.RuntimeLogger; + var logger = PhantomLogger.Create(config.LoggerName); logger.Information("Starting ZeroMQ server on {Url}...", url); socket.Bind(url); diff --git a/Utils/Phantom.Utils.Rpc/Sockets/RpcSocket.cs b/Utils/Phantom.Utils.Rpc/Sockets/RpcSocket.cs index 134c42d..1009424 100644 --- a/Utils/Phantom.Utils.Rpc/Sockets/RpcSocket.cs +++ b/Utils/Phantom.Utils.Rpc/Sockets/RpcSocket.cs @@ -20,6 +20,6 @@ public abstract class RpcSocket<TSocket> where TSocket : ThreadSafeSocket { protected RpcSocket(TSocket socket, RpcConfiguration config) { Socket = socket; Config = config; - ReplyTracker = new MessageReplyTracker(config.RuntimeLogger); + ReplyTracker = new MessageReplyTracker(config.LoggerName); } } diff --git a/Utils/Phantom.Utils/Tasks/AsyncTasks.cs b/Utils/Phantom.Utils/Tasks/AsyncTasks.cs index 0414c45..98fb26a 100644 --- a/Utils/Phantom.Utils/Tasks/AsyncTasks.cs +++ b/Utils/Phantom.Utils/Tasks/AsyncTasks.cs @@ -8,4 +8,28 @@ public static class AsyncTasks { public static TaskCompletionSource<T> CreateCompletionSource<T>() { return new TaskCompletionSource<T>(TaskCreationOptions.RunContinuationsAsynchronously); } + + public static void SetResultFrom(this TaskCompletionSource completionSource, Task task) { + if (task.IsFaulted) { + completionSource.SetException(task.Exception.InnerExceptions); + } + else if (task.IsCanceled) { + completionSource.SetCanceled(); + } + else { + completionSource.SetResult(); + } + } + + public static void SetResultFrom<T>(this TaskCompletionSource<T> completionSource, Task<T> task) { + if (task.IsFaulted) { + completionSource.SetException(task.Exception.InnerExceptions); + } + else if (task.IsCanceled) { + completionSource.SetCanceled(); + } + else { + completionSource.SetResult(task.Result); + } + } } diff --git a/Web/Phantom.Web.Services/Rpc/RpcClientRuntime.cs b/Web/Phantom.Web.Services/Rpc/RpcClientRuntime.cs index c35bfd7..87de6a6 100644 --- a/Web/Phantom.Web.Services/Rpc/RpcClientRuntime.cs +++ b/Web/Phantom.Web.Services/Rpc/RpcClientRuntime.cs @@ -15,8 +15,8 @@ public sealed class RpcClientRuntime : RpcClientRuntime<IMessageToWebListener, I } private RpcClientRuntime(RpcClientSocket<IMessageToWebListener, IMessageToControllerListener, ReplyMessage> socket, IMessageToWebListener messageListener, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) : base(socket, messageListener, disconnectSemaphore, receiveCancellationToken) {} - - protected override async Task Disconnect(ClientSocket socket, ILogger logger) { + + protected override async Task SendDisconnectMessage(ClientSocket socket, ILogger logger) { var unregisterMessageBytes = WebMessageRegistries.ToController.Write(new UnregisterWebMessage()).ToArray(); try { await socket.SendAsync(unregisterMessageBytes).AsTask().WaitAsync(TimeSpan.FromSeconds(5), CancellationToken.None); diff --git a/Web/Phantom.Web/Program.cs b/Web/Phantom.Web/Program.cs index 9ab73f6..6662931 100644 --- a/Web/Phantom.Web/Program.cs +++ b/Web/Phantom.Web/Program.cs @@ -48,7 +48,7 @@ try { var (controllerCertificate, webToken) = webKey.Value; - var rpcConfiguration = new RpcConfiguration(PhantomLogger.Create("Rpc"), PhantomLogger.Create<TaskManager>("Rpc"), controllerHost, controllerPort, controllerCertificate); + var rpcConfiguration = new RpcConfiguration("Rpc", controllerHost, controllerPort, controllerCertificate); var rpcSocket = RpcClientSocket.Connect(rpcConfiguration, WebMessageRegistries.Definitions, new RegisterWebMessage(webToken)); var configuration = new Configuration(PhantomLogger.Create("Web"), webServerHost, webServerPort, webBasePath, dataProtectionKeysPath, shutdownCancellationToken);