diff --git a/Agent/Phantom.Agent.Rpc/ControllerConnection.cs b/Agent/Phantom.Agent.Rpc/ControllerConnection.cs new file mode 100644 index 0000000..e0c453c --- /dev/null +++ b/Agent/Phantom.Agent.Rpc/ControllerConnection.cs @@ -0,0 +1,25 @@ +using Phantom.Common.Logging; +using Phantom.Common.Messages.Agent; +using Phantom.Utils.Rpc; +using Serilog; + +namespace Phantom.Agent.Rpc; + +public sealed class ControllerConnection { + private static readonly ILogger Logger = PhantomLogger.Create(nameof(ControllerConnection)); + + private readonly RpcConnectionToServer<IMessageToControllerListener> connection; + + public ControllerConnection(RpcConnectionToServer<IMessageToControllerListener> connection) { + this.connection = connection; + Logger.Information("Connection ready."); + } + + public Task Send<TMessage>(TMessage message) where TMessage : IMessageToController { + return connection.Send(message); + } + + public Task<TReply?> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken) where TMessage : IMessageToController<TReply> where TReply : class { + return connection.Send<TMessage, TReply>(message, waitForReplyTime, waitForReplyCancellationToken); + } +} diff --git a/Agent/Phantom.Agent.Rpc/KeepAliveLoop.cs b/Agent/Phantom.Agent.Rpc/KeepAliveLoop.cs index 350a05a..bcb2c56 100644 --- a/Agent/Phantom.Agent.Rpc/KeepAliveLoop.cs +++ b/Agent/Phantom.Agent.Rpc/KeepAliveLoop.cs @@ -1,5 +1,7 @@ using Phantom.Common.Logging; +using Phantom.Common.Messages.Agent; using Phantom.Common.Messages.Agent.ToController; +using Phantom.Utils.Rpc; using Serilog; namespace Phantom.Agent.Rpc; @@ -9,10 +11,10 @@ sealed class KeepAliveLoop { private static readonly TimeSpan KeepAliveInterval = TimeSpan.FromSeconds(10); - private readonly RpcServerConnection connection; + private readonly RpcConnectionToServer<IMessageToControllerListener> connection; private readonly CancellationTokenSource cancellationTokenSource = new (); - public KeepAliveLoop(RpcServerConnection connection) { + public KeepAliveLoop(RpcConnectionToServer<IMessageToControllerListener> connection) { this.connection = connection; Task.Run(Run); } diff --git a/Agent/Phantom.Agent.Rpc/RpcClientRuntime.cs b/Agent/Phantom.Agent.Rpc/RpcClientRuntime.cs new file mode 100644 index 0000000..6b06123 --- /dev/null +++ b/Agent/Phantom.Agent.Rpc/RpcClientRuntime.cs @@ -0,0 +1,42 @@ +using NetMQ; +using NetMQ.Sockets; +using Phantom.Common.Data.Agent; +using Phantom.Common.Messages.Agent; +using Phantom.Common.Messages.Agent.BiDirectional; +using Phantom.Common.Messages.Agent.ToController; +using Phantom.Utils.Rpc; +using Phantom.Utils.Rpc.Sockets; +using Phantom.Utils.Tasks; +using Serilog; + +namespace Phantom.Agent.Rpc; + +public sealed class RpcClientRuntime : RpcClientRuntime<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> { + public static Task Launch(RpcClientSocket<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> socket, AgentInfo agentInfo, IMessageToAgentListener messageListener, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) { + return new RpcClientRuntime(socket, agentInfo.Guid, messageListener, disconnectSemaphore, receiveCancellationToken).Launch(); + } + + private readonly Guid agentGuid; + + private RpcClientRuntime(RpcClientSocket<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> socket, Guid agentGuid, IMessageToAgentListener messageListener, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) : base(socket, messageListener, disconnectSemaphore, receiveCancellationToken) { + this.agentGuid = agentGuid; + } + + protected override void RunWithConnection(ClientSocket socket, RpcConnectionToServer<IMessageToControllerListener> connection, ILogger logger, TaskManager taskManager) { + var keepAliveLoop = new KeepAliveLoop(connection); + try { + base.RunWithConnection(socket, connection, logger, taskManager); + } finally { + keepAliveLoop.Cancel(); + } + } + + protected override async Task Disconnect(ClientSocket socket, ILogger logger) { + var unregisterMessageBytes = AgentMessageRegistries.ToController.Write(new UnregisterAgentMessage(agentGuid)).ToArray(); + try { + await socket.SendAsync(unregisterMessageBytes).AsTask().WaitAsync(TimeSpan.FromSeconds(5), CancellationToken.None); + } catch (TimeoutException) { + logger.Error("Timed out communicating agent shutdown with the controller."); + } + } +} diff --git a/Agent/Phantom.Agent.Rpc/RpcLauncher.cs b/Agent/Phantom.Agent.Rpc/RpcLauncher.cs deleted file mode 100644 index 4e17a67..0000000 --- a/Agent/Phantom.Agent.Rpc/RpcLauncher.cs +++ /dev/null @@ -1,107 +0,0 @@ -using NetMQ; -using NetMQ.Sockets; -using Phantom.Common.Data.Agent; -using Phantom.Common.Messages.Agent; -using Phantom.Common.Messages.Agent.BiDirectional; -using Phantom.Common.Messages.Agent.ToController; -using Phantom.Utils.Rpc; -using Phantom.Utils.Rpc.Message; -using Phantom.Utils.Tasks; -using Serilog; -using Serilog.Events; - -namespace Phantom.Agent.Rpc; - -public sealed class RpcLauncher : RpcRuntime<ClientSocket> { - public static Task Launch(RpcConfiguration config, AuthToken authToken, AgentInfo agentInfo, Func<RpcServerConnection, IMessageToAgentListener> listenerFactory, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) { - var socket = new ClientSocket(); - var options = socket.Options; - - options.CurveServerCertificate = config.ServerCertificate; - options.CurveCertificate = new NetMQCertificate(); - options.HelloMessage = AgentMessageRegistries.ToController.Write(new RegisterAgentMessage(authToken, agentInfo)).ToArray(); - - return new RpcLauncher(config, socket, agentInfo.Guid, listenerFactory, disconnectSemaphore, receiveCancellationToken).Launch(); - } - - private readonly RpcConfiguration config; - private readonly Guid agentGuid; - private readonly Func<RpcServerConnection, IMessageToAgentListener> messageListenerFactory; - - private readonly SemaphoreSlim disconnectSemaphore; - private readonly CancellationToken receiveCancellationToken; - - private RpcLauncher(RpcConfiguration config, ClientSocket socket, Guid agentGuid, Func<RpcServerConnection, IMessageToAgentListener> messageListenerFactory, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) : base(config, socket) { - this.config = config; - this.agentGuid = agentGuid; - this.messageListenerFactory = messageListenerFactory; - this.disconnectSemaphore = disconnectSemaphore; - this.receiveCancellationToken = receiveCancellationToken; - } - - protected override void Connect(ClientSocket socket) { - var logger = config.RuntimeLogger; - var url = config.TcpUrl; - - logger.Information("Starting ZeroMQ client and connecting to {Url}...", url); - socket.Connect(url); - logger.Information("ZeroMQ client ready."); - } - - protected override void Run(ClientSocket socket, MessageReplyTracker replyTracker, TaskManager taskManager) { - var connection = new RpcServerConnection(socket, replyTracker); - ServerMessaging.SetCurrentConnection(connection); - - var logger = config.RuntimeLogger; - var handler = new MessageToAgentHandler(messageListenerFactory(connection), logger, taskManager, receiveCancellationToken); - var keepAliveLoop = new KeepAliveLoop(connection); - - try { - while (!receiveCancellationToken.IsCancellationRequested) { - var data = socket.Receive(receiveCancellationToken); - - LogMessageType(logger, data); - - if (data.Length > 0) { - AgentMessageRegistries.ToAgent.Handle(data, handler); - } - } - } catch (OperationCanceledException) { - // Ignore. - } finally { - logger.Debug("ZeroMQ client stopped receiving messages."); - - disconnectSemaphore.Wait(CancellationToken.None); - keepAliveLoop.Cancel(); - } - } - - private static void LogMessageType(ILogger logger, ReadOnlyMemory<byte> data) { - if (!logger.IsEnabled(LogEventLevel.Verbose)) { - return; - } - - if (data.Length > 0 && AgentMessageRegistries.ToAgent.TryGetType(data, out var type)) { - logger.Verbose("Received {MessageType} ({Bytes} B) from controller.", type.Name, data.Length); - } - else { - logger.Verbose("Received {Bytes} B message from controller.", data.Length); - } - } - - protected override async Task Disconnect() { - var unregisterTimeoutTask = Task.Delay(TimeSpan.FromSeconds(5), CancellationToken.None); - var finishedTask = await Task.WhenAny(ServerMessaging.Send(new UnregisterAgentMessage(agentGuid)), unregisterTimeoutTask); - if (finishedTask == unregisterTimeoutTask) { - config.RuntimeLogger.Error("Timed out communicating agent shutdown with the controller."); - } - } - - private sealed class MessageToAgentHandler : MessageHandler<IMessageToAgentListener> { - public MessageToAgentHandler(IMessageToAgentListener listener, ILogger logger, TaskManager taskManager, CancellationToken cancellationToken) : base(listener, logger, taskManager, cancellationToken) {} - - protected override Task SendReply(uint sequenceId, byte[] serializedReply) { - return ServerMessaging.Send(new ReplyMessage(sequenceId, serializedReply)); - } - } -} diff --git a/Agent/Phantom.Agent.Rpc/RpcServerConnection.cs b/Agent/Phantom.Agent.Rpc/RpcServerConnection.cs deleted file mode 100644 index 62e6e45..0000000 --- a/Agent/Phantom.Agent.Rpc/RpcServerConnection.cs +++ /dev/null @@ -1,41 +0,0 @@ -using NetMQ; -using NetMQ.Sockets; -using Phantom.Common.Messages.Agent; -using Phantom.Common.Messages.Agent.BiDirectional; -using Phantom.Utils.Rpc.Message; - -namespace Phantom.Agent.Rpc; - -public sealed class RpcServerConnection { - private readonly ClientSocket socket; - private readonly MessageReplyTracker replyTracker; - - internal RpcServerConnection(ClientSocket socket, MessageReplyTracker replyTracker) { - this.socket = socket; - this.replyTracker = replyTracker; - } - - internal async Task Send<TMessage>(TMessage message) where TMessage : IMessageToController { - var bytes = AgentMessageRegistries.ToController.Write(message).ToArray(); - if (bytes.Length > 0) { - await socket.SendAsync(bytes); - } - } - - internal async Task<TReply?> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken) where TMessage : IMessageToController<TReply> where TReply : class { - var sequenceId = replyTracker.RegisterReply(); - - var bytes = AgentMessageRegistries.ToController.Write<TMessage, TReply>(sequenceId, message).ToArray(); - if (bytes.Length == 0) { - replyTracker.ForgetReply(sequenceId); - return null; - } - - await socket.SendAsync(bytes); - return await replyTracker.WaitForReply<TReply>(sequenceId, waitForReplyTime, waitForReplyCancellationToken); - } - - public void Receive(ReplyMessage message) { - replyTracker.ReceiveReply(message.SequenceId, message.SerializedReply); - } -} diff --git a/Agent/Phantom.Agent.Rpc/ServerMessaging.cs b/Agent/Phantom.Agent.Rpc/ServerMessaging.cs deleted file mode 100644 index 7466e39..0000000 --- a/Agent/Phantom.Agent.Rpc/ServerMessaging.cs +++ /dev/null @@ -1,34 +0,0 @@ -using Phantom.Common.Logging; -using Phantom.Common.Messages.Agent; -using Serilog; - -namespace Phantom.Agent.Rpc; - -public static class ServerMessaging { - private static readonly ILogger Logger = PhantomLogger.Create(nameof(ServerMessaging)); - - private static RpcServerConnection? CurrentConnection { get; set; } - private static RpcServerConnection CurrentConnectionOrThrow => CurrentConnection ?? throw new InvalidOperationException("Server connection not ready."); - - private static readonly object SetCurrentConnectionLock = new (); - - internal static void SetCurrentConnection(RpcServerConnection connection) { - lock (SetCurrentConnectionLock) { - if (CurrentConnection != null) { - throw new InvalidOperationException("Server connection can only be set once."); - } - - CurrentConnection = connection; - } - - Logger.Information("Server connection ready."); - } - - public static Task Send<TMessage>(TMessage message) where TMessage : IMessageToController { - return CurrentConnectionOrThrow.Send(message); - } - - public static Task<TReply?> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken) where TMessage : IMessageToController<TReply> where TReply : class { - return CurrentConnectionOrThrow.Send<TMessage, TReply>(message, waitForReplyTime, waitForReplyCancellationToken); - } -} diff --git a/Agent/Phantom.Agent.Services/AgentServices.cs b/Agent/Phantom.Agent.Services/AgentServices.cs index c3ba7b3..3025436 100644 --- a/Agent/Phantom.Agent.Services/AgentServices.cs +++ b/Agent/Phantom.Agent.Services/AgentServices.cs @@ -1,4 +1,5 @@ using Phantom.Agent.Minecraft.Java; +using Phantom.Agent.Rpc; using Phantom.Agent.Services.Backups; using Phantom.Agent.Services.Instances; using Phantom.Common.Data.Agent; @@ -18,12 +19,12 @@ public sealed class AgentServices { internal JavaRuntimeRepository JavaRuntimeRepository { get; } internal InstanceSessionManager InstanceSessionManager { get; } - public AgentServices(AgentInfo agentInfo, AgentFolders agentFolders, AgentServiceConfiguration serviceConfiguration) { + public AgentServices(AgentInfo agentInfo, AgentFolders agentFolders, AgentServiceConfiguration serviceConfiguration, ControllerConnection controllerConnection) { this.AgentFolders = agentFolders; this.TaskManager = new TaskManager(PhantomLogger.Create<TaskManager, AgentServices>()); this.BackupManager = new BackupManager(agentFolders, serviceConfiguration.MaxConcurrentCompressionTasks); this.JavaRuntimeRepository = new JavaRuntimeRepository(); - this.InstanceSessionManager = new InstanceSessionManager(agentInfo, agentFolders, JavaRuntimeRepository, TaskManager, BackupManager); + this.InstanceSessionManager = new InstanceSessionManager(controllerConnection, agentInfo, agentFolders, JavaRuntimeRepository, TaskManager, BackupManager); } public async Task Initialize() { diff --git a/Agent/Phantom.Agent.Services/Instances/Instance.cs b/Agent/Phantom.Agent.Services/Instances/Instance.cs index 5aa2b78..863a904 100644 --- a/Agent/Phantom.Agent.Services/Instances/Instance.cs +++ b/Agent/Phantom.Agent.Services/Instances/Instance.cs @@ -1,5 +1,4 @@ using Phantom.Agent.Minecraft.Launcher; -using Phantom.Agent.Rpc; using Phantom.Agent.Services.Instances.Procedures; using Phantom.Agent.Services.Instances.States; using Phantom.Common.Data.Instance; @@ -57,7 +56,7 @@ sealed class Instance : IAsyncDisposable { public void ReportLastStatus() { TryUpdateStatus("Report last status of instance " + shortName, async () => { - await ServerMessaging.Send(new ReportInstanceStatusMessage(Configuration.InstanceGuid, currentStatus)); + await Services.ControllerConnection.Send(new ReportInstanceStatusMessage(Configuration.InstanceGuid, currentStatus)); }); } @@ -65,14 +64,14 @@ sealed class Instance : IAsyncDisposable { TryUpdateStatus("Report status of instance " + shortName + " as " + status.GetType().Name, async () => { if (status != currentStatus) { currentStatus = status; - await ServerMessaging.Send(new ReportInstanceStatusMessage(Configuration.InstanceGuid, status)); + await Services.ControllerConnection.Send(new ReportInstanceStatusMessage(Configuration.InstanceGuid, status)); } }); } private void ReportEvent(IInstanceEvent instanceEvent) { var message = new ReportInstanceEventMessage(Guid.NewGuid(), DateTime.UtcNow, Configuration.InstanceGuid, instanceEvent); - Services.TaskManager.Run("Report event for instance " + shortName, async () => await ServerMessaging.Send(message)); + Services.TaskManager.Run("Report event for instance " + shortName, async () => await Services.ControllerConnection.Send(message)); } internal void TransitionState(IInstanceState newState) { diff --git a/Agent/Phantom.Agent.Services/Instances/InstanceLogSender.cs b/Agent/Phantom.Agent.Services/Instances/InstanceLogSender.cs index 2d7b781..cace3e3 100644 --- a/Agent/Phantom.Agent.Services/Instances/InstanceLogSender.cs +++ b/Agent/Phantom.Agent.Services/Instances/InstanceLogSender.cs @@ -16,12 +16,14 @@ sealed class InstanceLogSender : CancellableBackgroundTask { private static readonly TimeSpan SendDelay = TimeSpan.FromMilliseconds(200); + private readonly ControllerConnection controllerConnection; private readonly Guid instanceGuid; private readonly Channel<string> outputChannel; private int droppedLinesSinceLastSend; - public InstanceLogSender(TaskManager taskManager, Guid instanceGuid, string loggerName) : base(PhantomLogger.Create<InstanceLogSender>(loggerName), taskManager, "Instance log sender for " + loggerName) { + public InstanceLogSender(ControllerConnection controllerConnection, TaskManager taskManager, Guid instanceGuid, string loggerName) : base(PhantomLogger.Create<InstanceLogSender>(loggerName), taskManager, "Instance log sender for " + loggerName) { + this.controllerConnection = controllerConnection; this.instanceGuid = instanceGuid; this.outputChannel = Channel.CreateBounded<string>(BufferOptions, OnLineDropped); Start(); @@ -61,7 +63,7 @@ sealed class InstanceLogSender : CancellableBackgroundTask { private async Task SendOutputToServer(ImmutableArray<string> lines) { if (!lines.IsEmpty) { - await ServerMessaging.Send(new InstanceOutputMessage(instanceGuid, lines)); + await controllerConnection.Send(new InstanceOutputMessage(instanceGuid, lines)); } } diff --git a/Agent/Phantom.Agent.Services/Instances/InstanceServices.cs b/Agent/Phantom.Agent.Services/Instances/InstanceServices.cs index e063d38..c2d3f5c 100644 --- a/Agent/Phantom.Agent.Services/Instances/InstanceServices.cs +++ b/Agent/Phantom.Agent.Services/Instances/InstanceServices.cs @@ -1,7 +1,8 @@ using Phantom.Agent.Minecraft.Launcher; +using Phantom.Agent.Rpc; using Phantom.Agent.Services.Backups; using Phantom.Utils.Tasks; namespace Phantom.Agent.Services.Instances; -sealed record InstanceServices(TaskManager TaskManager, PortManager PortManager, BackupManager BackupManager, LaunchServices LaunchServices); +sealed record InstanceServices(ControllerConnection ControllerConnection, TaskManager TaskManager, PortManager PortManager, BackupManager BackupManager, LaunchServices LaunchServices); diff --git a/Agent/Phantom.Agent.Services/Instances/InstanceSessionManager.cs b/Agent/Phantom.Agent.Services/Instances/InstanceSessionManager.cs index ad0a507..a5460c7 100644 --- a/Agent/Phantom.Agent.Services/Instances/InstanceSessionManager.cs +++ b/Agent/Phantom.Agent.Services/Instances/InstanceSessionManager.cs @@ -24,6 +24,7 @@ namespace Phantom.Agent.Services.Instances; sealed class InstanceSessionManager : IAsyncDisposable { private static readonly ILogger Logger = PhantomLogger.Create<InstanceSessionManager>(); + private readonly ControllerConnection controllerConnection; private readonly AgentInfo agentInfo; private readonly string basePath; @@ -36,7 +37,8 @@ sealed class InstanceSessionManager : IAsyncDisposable { private uint instanceLoggerSequenceId = 0; - public InstanceSessionManager(AgentInfo agentInfo, AgentFolders agentFolders, JavaRuntimeRepository javaRuntimeRepository, TaskManager taskManager, BackupManager backupManager) { + public InstanceSessionManager(ControllerConnection controllerConnection, AgentInfo agentInfo, AgentFolders agentFolders, JavaRuntimeRepository javaRuntimeRepository, TaskManager taskManager, BackupManager backupManager) { + this.controllerConnection = controllerConnection; this.agentInfo = agentInfo; this.basePath = agentFolders.InstancesFolderPath; this.shutdownCancellationToken = shutdownCancellationTokenSource.Token; @@ -45,7 +47,7 @@ sealed class InstanceSessionManager : IAsyncDisposable { var launchServices = new LaunchServices(minecraftServerExecutables, javaRuntimeRepository); var portManager = new PortManager(agentInfo.AllowedServerPorts, agentInfo.AllowedRconPorts); - this.instanceServices = new InstanceServices(taskManager, portManager, backupManager, launchServices); + this.instanceServices = new InstanceServices(controllerConnection, taskManager, portManager, backupManager, launchServices); } private async Task<InstanceActionResult<T>> AcquireSemaphoreAndRun<T>(Func<Task<InstanceActionResult<T>>> func) { @@ -146,7 +148,7 @@ sealed class InstanceSessionManager : IAsyncDisposable { var runningInstances = GetRunningInstancesInternal(); var runningInstanceCount = runningInstances.Length; var runningInstanceMemory = runningInstances.Aggregate(RamAllocationUnits.Zero, static (total, instance) => total + instance.Configuration.MemoryAllocation); - await ServerMessaging.Send(new ReportAgentStatusMessage(runningInstanceCount, runningInstanceMemory)); + await controllerConnection.Send(new ReportAgentStatusMessage(runningInstanceCount, runningInstanceMemory)); } finally { semaphore.Release(); } diff --git a/Agent/Phantom.Agent.Services/Instances/States/InstanceRunningState.cs b/Agent/Phantom.Agent.Services/Instances/States/InstanceRunningState.cs index 5fbf02b..92d36ee 100644 --- a/Agent/Phantom.Agent.Services/Instances/States/InstanceRunningState.cs +++ b/Agent/Phantom.Agent.Services/Instances/States/InstanceRunningState.cs @@ -27,7 +27,7 @@ sealed class InstanceRunningState : IInstanceState, IDisposable { this.context = context; this.Process = process; - this.logSender = new InstanceLogSender(context.Services.TaskManager, configuration.InstanceGuid, context.ShortName); + this.logSender = new InstanceLogSender(context.Services.ControllerConnection, context.Services.TaskManager, configuration.InstanceGuid, context.ShortName); this.backupScheduler = new BackupScheduler(context.Services.TaskManager, context.Services.BackupManager, process, context, configuration.ServerPort); this.backupScheduler.BackupCompleted += OnScheduledBackupCompleted; diff --git a/Agent/Phantom.Agent.Services/Rpc/MessageListener.cs b/Agent/Phantom.Agent.Services/Rpc/MessageListener.cs index a83ae77..45d6810 100644 --- a/Agent/Phantom.Agent.Services/Rpc/MessageListener.cs +++ b/Agent/Phantom.Agent.Services/Rpc/MessageListener.cs @@ -1,11 +1,11 @@ -using Phantom.Agent.Rpc; -using Phantom.Common.Data.Instance; +using Phantom.Common.Data.Instance; using Phantom.Common.Data.Replies; using Phantom.Common.Logging; using Phantom.Common.Messages.Agent; using Phantom.Common.Messages.Agent.BiDirectional; using Phantom.Common.Messages.Agent.ToAgent; using Phantom.Common.Messages.Agent.ToController; +using Phantom.Utils.Rpc; using Phantom.Utils.Rpc.Message; using Serilog; @@ -14,11 +14,11 @@ namespace Phantom.Agent.Services.Rpc; public sealed class MessageListener : IMessageToAgentListener { private static ILogger Logger { get; } = PhantomLogger.Create<MessageListener>(); - private readonly RpcServerConnection connection; + private readonly RpcConnectionToServer<IMessageToControllerListener> connection; private readonly AgentServices agent; private readonly CancellationTokenSource shutdownTokenSource; - public MessageListener(RpcServerConnection connection, AgentServices agent, CancellationTokenSource shutdownTokenSource) { + public MessageListener(RpcConnectionToServer<IMessageToControllerListener> connection, AgentServices agent, CancellationTokenSource shutdownTokenSource) { this.connection = connection; this.agent = agent; this.shutdownTokenSource = shutdownTokenSource; @@ -40,7 +40,7 @@ public sealed class MessageListener : IMessageToAgentListener { } } - await ServerMessaging.Send(new AdvertiseJavaRuntimesMessage(agent.JavaRuntimeRepository.All)); + await connection.Send(new AdvertiseJavaRuntimesMessage(agent.JavaRuntimeRepository.All)); await agent.InstanceSessionManager.RefreshAgentStatus(); return NoReply.Instance; diff --git a/Agent/Phantom.Agent/Program.cs b/Agent/Phantom.Agent/Program.cs index 988ea7e..7805279 100644 --- a/Agent/Phantom.Agent/Program.cs +++ b/Agent/Phantom.Agent/Program.cs @@ -1,11 +1,15 @@ using System.Reflection; +using NetMQ; using Phantom.Agent; using Phantom.Agent.Rpc; using Phantom.Agent.Services; using Phantom.Agent.Services.Rpc; using Phantom.Common.Data.Agent; using Phantom.Common.Logging; +using Phantom.Common.Messages.Agent; +using Phantom.Common.Messages.Agent.ToController; using Phantom.Utils.Rpc; +using Phantom.Utils.Rpc.Sockets; using Phantom.Utils.Runtime; using Phantom.Utils.Tasks; @@ -45,19 +49,18 @@ try { var (controllerCertificate, agentToken) = agentKey.Value; var agentInfo = new AgentInfo(agentGuid.Value, agentName, ProtocolVersion, fullVersion, maxInstances, maxMemory, allowedServerPorts, allowedRconPorts); - var agentServices = new AgentServices(agentInfo, folders, new AgentServiceConfiguration(maxConcurrentBackupCompressionTasks)); - - MessageListener MessageListenerFactory(RpcServerConnection connection) { - return new MessageListener(connection, agentServices, shutdownCancellationTokenSource); - } - + PhantomLogger.Root.InformationHeading("Launching Phantom Panel agent..."); + + var rpcConfiguration = new RpcConfiguration(PhantomLogger.Create("Rpc"), PhantomLogger.Create<TaskManager>("Rpc"), controllerHost, controllerPort, controllerCertificate); + var rpcSocket = RpcClientSocket.Connect(rpcConfiguration, AgentMessageRegistries.Definitions, new RegisterAgentMessage(agentToken, agentInfo)); + var agentServices = new AgentServices(agentInfo, folders, new AgentServiceConfiguration(maxConcurrentBackupCompressionTasks), new ControllerConnection(rpcSocket.Connection)); await agentServices.Initialize(); var rpcDisconnectSemaphore = new SemaphoreSlim(0, 1); - var rpcConfiguration = new RpcConfiguration(PhantomLogger.Create("Rpc"), PhantomLogger.Create<TaskManager>("Rpc"), controllerHost, controllerPort, controllerCertificate); - var rpcTask = RpcLauncher.Launch(rpcConfiguration, agentToken, agentInfo, MessageListenerFactory, rpcDisconnectSemaphore, shutdownCancellationToken); + var rpcMessageListener = new MessageListener(rpcSocket.Connection, agentServices, shutdownCancellationTokenSource); + var rpcTask = RpcClientRuntime.Launch(rpcSocket, agentInfo, rpcMessageListener, rpcDisconnectSemaphore, shutdownCancellationToken); try { await rpcTask.WaitAsync(shutdownCancellationToken); } finally { @@ -67,6 +70,8 @@ try { rpcDisconnectSemaphore.Release(); await rpcTask; rpcDisconnectSemaphore.Dispose(); + + NetMQConfig.Cleanup(); } return 0; diff --git a/Common/Phantom.Common.Messages.Agent/AgentMessageRegistries.cs b/Common/Phantom.Common.Messages.Agent/AgentMessageRegistries.cs index 3023d71..341b46a 100644 --- a/Common/Phantom.Common.Messages.Agent/AgentMessageRegistries.cs +++ b/Common/Phantom.Common.Messages.Agent/AgentMessageRegistries.cs @@ -34,14 +34,14 @@ public static class AgentMessageRegistries { } private sealed class MessageDefinitions : IMessageDefinitions<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> { - public MessageRegistry<IMessageToAgentListener> Outgoing => ToAgent; - public MessageRegistry<IMessageToControllerListener> Incoming => ToController; + public MessageRegistry<IMessageToAgentListener> ToClient => ToAgent; + public MessageRegistry<IMessageToControllerListener> ToServer => ToController; public bool IsRegistrationMessage(Type messageType) { return messageType == typeof(RegisterAgentMessage); } - public ReplyMessage CreateReplyMessage( uint sequenceId, byte[] serializedReply) { + public ReplyMessage CreateReplyMessage(uint sequenceId, byte[] serializedReply) { return new ReplyMessage(sequenceId, serializedReply); } } diff --git a/Common/Phantom.Common.Messages.Web/WebMessageRegistries.cs b/Common/Phantom.Common.Messages.Web/WebMessageRegistries.cs index b94d236..a3b82d5 100644 --- a/Common/Phantom.Common.Messages.Web/WebMessageRegistries.cs +++ b/Common/Phantom.Common.Messages.Web/WebMessageRegistries.cs @@ -17,14 +17,14 @@ public static class WebMessageRegistries { } private sealed class MessageDefinitions : IMessageDefinitions<IMessageToWebListener, IMessageToControllerListener, ReplyMessage> { - public MessageRegistry<IMessageToWebListener> Outgoing => ToWeb; - public MessageRegistry<IMessageToControllerListener> Incoming => ToController; + public MessageRegistry<IMessageToWebListener> ToClient => ToWeb; + public MessageRegistry<IMessageToControllerListener> ToServer => ToController; public bool IsRegistrationMessage(Type messageType) { return false; } - public ReplyMessage CreateReplyMessage( uint sequenceId, byte[] serializedReply) { + public ReplyMessage CreateReplyMessage(uint sequenceId, byte[] serializedReply) { return new ReplyMessage(sequenceId, serializedReply); } } diff --git a/Controller/Phantom.Controller.Rpc/RpcClientConnection.cs b/Controller/Phantom.Controller.Rpc/RpcConnectionToClient.cs similarity index 87% rename from Controller/Phantom.Controller.Rpc/RpcClientConnection.cs rename to Controller/Phantom.Controller.Rpc/RpcConnectionToClient.cs index 1307481..30060c9 100644 --- a/Controller/Phantom.Controller.Rpc/RpcClientConnection.cs +++ b/Controller/Phantom.Controller.Rpc/RpcConnectionToClient.cs @@ -4,7 +4,7 @@ using Phantom.Utils.Rpc.Message; namespace Phantom.Controller.Rpc; -public sealed class RpcClientConnection<TListener> { +public sealed class RpcConnectionToClient<TListener> { private readonly ServerSocket socket; private readonly uint routingId; @@ -14,14 +14,14 @@ public sealed class RpcClientConnection<TListener> { internal event EventHandler<RpcClientConnectionClosedEventArgs>? Closed; private bool isClosed; - internal RpcClientConnection(ServerSocket socket, uint routingId, MessageRegistry<TListener> messageRegistry, MessageReplyTracker messageReplyTracker) { + internal RpcConnectionToClient(ServerSocket socket, uint routingId, MessageRegistry<TListener> messageRegistry, MessageReplyTracker messageReplyTracker) { this.socket = socket; this.routingId = routingId; this.messageRegistry = messageRegistry; this.messageReplyTracker = messageReplyTracker; } - public bool IsSame(RpcClientConnection<TListener> other) { + public bool IsSame(RpcConnectionToClient<TListener> other) { return this.routingId == other.routingId && this.socket == other.socket; } diff --git a/Controller/Phantom.Controller.Rpc/RpcRuntime.cs b/Controller/Phantom.Controller.Rpc/RpcRuntime.cs index 1638bd8..f2ebe4a 100644 --- a/Controller/Phantom.Controller.Rpc/RpcRuntime.cs +++ b/Controller/Phantom.Controller.Rpc/RpcRuntime.cs @@ -1,6 +1,7 @@ using NetMQ.Sockets; using Phantom.Utils.Rpc; using Phantom.Utils.Rpc.Message; +using Phantom.Utils.Rpc.Sockets; using Phantom.Utils.Tasks; using Serilog; using Serilog.Events; @@ -8,49 +9,28 @@ using Serilog.Events; namespace Phantom.Controller.Rpc; public static class RpcRuntime { - public static Task Launch<TOutgoingListener, TIncomingListener, TReplyMessage>(RpcConfiguration config, IMessageDefinitions<TOutgoingListener, TIncomingListener, TReplyMessage> messageDefinitions, Func<RpcClientConnection<TOutgoingListener>, TIncomingListener> listenerFactory, CancellationToken cancellationToken) where TReplyMessage : IMessage<TOutgoingListener, NoReply>, IMessage<TIncomingListener, NoReply> { - return RpcRuntime<TOutgoingListener, TIncomingListener, TReplyMessage>.Launch(config, messageDefinitions, listenerFactory, cancellationToken); + public static Task Launch<TClientListener, TServerListener, TReplyMessage>(RpcConfiguration config, IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions, Func<RpcConnectionToClient<TClientListener>, TServerListener> listenerFactory, CancellationToken cancellationToken) where TReplyMessage : IMessage<TClientListener, NoReply>, IMessage<TServerListener, NoReply> { + return RpcRuntime<TClientListener, TServerListener, TReplyMessage>.Launch(config, messageDefinitions, listenerFactory, cancellationToken); } } -internal sealed class RpcRuntime<TOutgoingListener, TIncomingListener, TReplyMessage> : RpcRuntime<ServerSocket> where TReplyMessage : IMessage<TOutgoingListener, NoReply>, IMessage<TIncomingListener, NoReply> { - internal static Task Launch(RpcConfiguration config, IMessageDefinitions<TOutgoingListener, TIncomingListener, TReplyMessage> messageDefinitions, Func<RpcClientConnection<TOutgoingListener>, TIncomingListener> listenerFactory, CancellationToken cancellationToken) { - return new RpcRuntime<TOutgoingListener, TIncomingListener, TReplyMessage>(config, messageDefinitions, listenerFactory, cancellationToken).Launch(); +internal sealed class RpcRuntime<TClientListener, TServerListener, TReplyMessage> : RpcRuntime<ServerSocket> where TReplyMessage : IMessage<TClientListener, NoReply>, IMessage<TServerListener, NoReply> { + internal static Task Launch(RpcConfiguration config, IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions, Func<RpcConnectionToClient<TClientListener>, TServerListener> listenerFactory, CancellationToken cancellationToken) { + var socket = RpcServerSocket.Connect(config); + return new RpcRuntime<TClientListener, TServerListener, TReplyMessage>(socket, messageDefinitions, listenerFactory, cancellationToken).Launch(); } - private static ServerSocket CreateSocket(RpcConfiguration config) { - var socket = new ServerSocket(); - var options = socket.Options; - - options.CurveServer = true; - options.CurveCertificate = config.ServerCertificate; - - return socket; - } - - private readonly RpcConfiguration config; - private readonly IMessageDefinitions<TOutgoingListener, TIncomingListener, TReplyMessage> messageDefinitions; - private readonly Func<RpcClientConnection<TOutgoingListener>, TIncomingListener> listenerFactory; + private readonly IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions; + private readonly Func<RpcConnectionToClient<TClientListener>, TServerListener> listenerFactory; private readonly CancellationToken cancellationToken; - private RpcRuntime(RpcConfiguration config, IMessageDefinitions<TOutgoingListener, TIncomingListener, TReplyMessage> messageDefinitions, Func<RpcClientConnection<TOutgoingListener>, TIncomingListener> listenerFactory, CancellationToken cancellationToken) : base(config, CreateSocket(config)) { - this.config = config; + private RpcRuntime(RpcServerSocket socket, IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions, Func<RpcConnectionToClient<TClientListener>, TServerListener> listenerFactory, CancellationToken cancellationToken) : base(socket) { this.messageDefinitions = messageDefinitions; this.listenerFactory = listenerFactory; this.cancellationToken = cancellationToken; } - protected override void Connect(ServerSocket socket) { - var logger = config.RuntimeLogger; - var url = config.TcpUrl; - - logger.Information("Starting ZeroMQ server on {Url}...", url); - socket.Bind(url); - logger.Information("ZeroMQ server initialized, listening for connections on port {Port}.", config.Port); - } - - protected override void Run(ServerSocket socket, MessageReplyTracker replyTracker, TaskManager taskManager) { - var logger = config.RuntimeLogger; + protected override void Run(ServerSocket socket, ILogger logger, MessageReplyTracker replyTracker, TaskManager taskManager) { var clients = new Dictionary<ulong, Client>(); void OnConnectionClosed(object? sender, RpcClientConnectionClosedEventArgs e) { @@ -71,7 +51,7 @@ internal sealed class RpcRuntime<TOutgoingListener, TIncomingListener, TReplyMes continue; } - var connection = new RpcClientConnection<TOutgoingListener>(socket, routingId, messageDefinitions.Outgoing, replyTracker); + var connection = new RpcConnectionToClient<TClientListener>(socket, routingId, messageDefinitions.ToClient, replyTracker); connection.Closed += OnConnectionClosed; client = new Client(connection, messageDefinitions, listenerFactory(connection), logger, taskManager, cancellationToken); @@ -79,7 +59,7 @@ internal sealed class RpcRuntime<TOutgoingListener, TIncomingListener, TReplyMes } LogMessageType(logger, routingId, data); - messageDefinitions.Incoming.Handle(data, client); + messageDefinitions.ToServer.Handle(data, client); } foreach (var client in clients.Values) { @@ -92,7 +72,7 @@ internal sealed class RpcRuntime<TOutgoingListener, TIncomingListener, TReplyMes return; } - if (data.Length > 0 && messageDefinitions.Incoming.TryGetType(data, out var type)) { + if (data.Length > 0 && messageDefinitions.ToServer.TryGetType(data, out var type)) { logger.Verbose("Received {MessageType} ({Bytes} B) from {RoutingId}.", type.Name, data.Length, routingId); } else { @@ -101,7 +81,7 @@ internal sealed class RpcRuntime<TOutgoingListener, TIncomingListener, TReplyMes } private bool CheckIsRegistrationMessage(ReadOnlyMemory<byte> data, ILogger logger, uint routingId) { - if (messageDefinitions.Incoming.TryGetType(data, out var type) && messageDefinitions.IsRegistrationMessage(type)) { + if (messageDefinitions.ToServer.TryGetType(data, out var type) && messageDefinitions.IsRegistrationMessage(type)) { return true; } @@ -109,12 +89,12 @@ internal sealed class RpcRuntime<TOutgoingListener, TIncomingListener, TReplyMes return false; } - private sealed class Client : MessageHandler<TIncomingListener> { - public RpcClientConnection<TOutgoingListener> Connection { get; } + private sealed class Client : MessageHandler<TServerListener> { + public RpcConnectionToClient<TClientListener> Connection { get; } - private readonly IMessageDefinitions<TOutgoingListener, TIncomingListener, TReplyMessage> messageDefinitions; + private readonly IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions; - public Client(RpcClientConnection<TOutgoingListener> connection, IMessageDefinitions<TOutgoingListener, TIncomingListener, TReplyMessage> messageDefinitions, TIncomingListener listener, ILogger logger, TaskManager taskManager, CancellationToken cancellationToken) : base(listener, logger, taskManager, cancellationToken) { + public Client(RpcConnectionToClient<TClientListener> connection, IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions, TServerListener listener, ILogger logger, TaskManager taskManager, CancellationToken cancellationToken) : base(listener, logger, taskManager, cancellationToken) { this.Connection = connection; this.messageDefinitions = messageDefinitions; } diff --git a/Controller/Phantom.Controller.Services/Agents/AgentConnection.cs b/Controller/Phantom.Controller.Services/Agents/AgentConnection.cs index 30d04ce..5e5f460 100644 --- a/Controller/Phantom.Controller.Services/Agents/AgentConnection.cs +++ b/Controller/Phantom.Controller.Services/Agents/AgentConnection.cs @@ -4,13 +4,13 @@ using Phantom.Controller.Rpc; namespace Phantom.Controller.Services.Agents; sealed class AgentConnection { - private readonly RpcClientConnection<IMessageToAgentListener> connection; + private readonly RpcConnectionToClient<IMessageToAgentListener> connection; - internal AgentConnection(RpcClientConnection<IMessageToAgentListener> connection) { + internal AgentConnection(RpcConnectionToClient<IMessageToAgentListener> connection) { this.connection = connection; } - public bool IsSame(RpcClientConnection<IMessageToAgentListener> connection) { + public bool IsSame(RpcConnectionToClient<IMessageToAgentListener> connection) { return this.connection.IsSame(connection); } diff --git a/Controller/Phantom.Controller.Services/Agents/AgentManager.cs b/Controller/Phantom.Controller.Services/Agents/AgentManager.cs index e4bf9bc..111c877 100644 --- a/Controller/Phantom.Controller.Services/Agents/AgentManager.cs +++ b/Controller/Phantom.Controller.Services/Agents/AgentManager.cs @@ -52,7 +52,7 @@ public sealed class AgentManager { return agents.ByGuid.ToImmutable(); } - internal async Task<bool> RegisterAgent(AuthToken authToken, AgentInfo agentInfo, InstanceManager instanceManager, RpcClientConnection<IMessageToAgentListener> connection) { + internal async Task<bool> RegisterAgent(AuthToken authToken, AgentInfo agentInfo, InstanceManager instanceManager, RpcConnectionToClient<IMessageToAgentListener> connection) { if (!this.authToken.FixedTimeEquals(authToken)) { await connection.Send(new RegisterAgentFailureMessage(RegisterAgentFailure.InvalidToken)); return false; @@ -88,7 +88,7 @@ public sealed class AgentManager { return true; } - internal bool UnregisterAgent(Guid agentGuid, RpcClientConnection<IMessageToAgentListener> connection) { + internal bool UnregisterAgent(Guid agentGuid, RpcConnectionToClient<IMessageToAgentListener> connection) { if (agents.ByGuid.TryReplaceIf(agentGuid, static oldAgent => oldAgent.AsOffline(), oldAgent => oldAgent.Connection?.IsSame(connection) == true)) { Logger.Information("Unregistered agent with GUID {Guid}.", agentGuid); return true; diff --git a/Controller/Phantom.Controller.Services/ControllerServices.cs b/Controller/Phantom.Controller.Services/ControllerServices.cs index e3a4ebe..20e1d62 100644 --- a/Controller/Phantom.Controller.Services/ControllerServices.cs +++ b/Controller/Phantom.Controller.Services/ControllerServices.cs @@ -53,11 +53,11 @@ public sealed class ControllerServices { this.cancellationToken = shutdownCancellationToken; } - public AgentMessageListener CreateAgentMessageListener(RpcClientConnection<IMessageToAgentListener> connection) { + public AgentMessageListener CreateAgentMessageListener(RpcConnectionToClient<IMessageToAgentListener> connection) { return new AgentMessageListener(connection, AgentManager, AgentJavaRuntimesManager, InstanceManager, InstanceLogManager, EventLog, cancellationToken); } - public WebMessageListener CreateWebMessageListener(RpcClientConnection<IMessageToWebListener> connection) { + public WebMessageListener CreateWebMessageListener(RpcConnectionToClient<IMessageToWebListener> connection) { return new WebMessageListener(connection); } diff --git a/Controller/Phantom.Controller.Services/Rpc/AgentMessageListener.cs b/Controller/Phantom.Controller.Services/Rpc/AgentMessageListener.cs index 3d8347a..083cfad 100644 --- a/Controller/Phantom.Controller.Services/Rpc/AgentMessageListener.cs +++ b/Controller/Phantom.Controller.Services/Rpc/AgentMessageListener.cs @@ -14,7 +14,7 @@ using Phantom.Utils.Tasks; namespace Phantom.Controller.Services.Rpc; public sealed class AgentMessageListener : IMessageToControllerListener { - private readonly RpcClientConnection<IMessageToAgentListener> connection; + private readonly RpcConnectionToClient<IMessageToAgentListener> connection; private readonly AgentManager agentManager; private readonly AgentJavaRuntimesManager agentJavaRuntimesManager; private readonly InstanceManager instanceManager; @@ -24,7 +24,7 @@ public sealed class AgentMessageListener : IMessageToControllerListener { private readonly TaskCompletionSource<Guid> agentGuidWaiter = AsyncTasks.CreateCompletionSource<Guid>(); - internal AgentMessageListener(RpcClientConnection<IMessageToAgentListener> connection, AgentManager agentManager, AgentJavaRuntimesManager agentJavaRuntimesManager, InstanceManager instanceManager, InstanceLogManager instanceLogManager, EventLog eventLog, CancellationToken cancellationToken) { + internal AgentMessageListener(RpcConnectionToClient<IMessageToAgentListener> connection, AgentManager agentManager, AgentJavaRuntimesManager agentJavaRuntimesManager, InstanceManager instanceManager, InstanceLogManager instanceLogManager, EventLog eventLog, CancellationToken cancellationToken) { this.connection = connection; this.agentManager = agentManager; this.agentJavaRuntimesManager = agentJavaRuntimesManager; diff --git a/Controller/Phantom.Controller.Services/Rpc/WebMessageListener.cs b/Controller/Phantom.Controller.Services/Rpc/WebMessageListener.cs index 6b1a5eb..0b109cf 100644 --- a/Controller/Phantom.Controller.Services/Rpc/WebMessageListener.cs +++ b/Controller/Phantom.Controller.Services/Rpc/WebMessageListener.cs @@ -6,9 +6,9 @@ using Phantom.Utils.Rpc.Message; namespace Phantom.Controller.Services.Rpc; public sealed class WebMessageListener : IMessageToControllerListener { - private readonly RpcClientConnection<IMessageToWebListener> connection; + private readonly RpcConnectionToClient<IMessageToWebListener> connection; - internal WebMessageListener(RpcClientConnection<IMessageToWebListener> connection) { + internal WebMessageListener(RpcConnectionToClient<IMessageToWebListener> connection) { this.connection = connection; } diff --git a/Controller/Phantom.Controller/Program.cs b/Controller/Phantom.Controller/Program.cs index 1cc2086..c09f980 100644 --- a/Controller/Phantom.Controller/Program.cs +++ b/Controller/Phantom.Controller/Program.cs @@ -1,4 +1,5 @@ using System.Reflection; +using NetMQ; using Phantom.Common.Logging; using Phantom.Common.Messages.Agent; using Phantom.Common.Messages.Web; @@ -61,10 +62,14 @@ try { return new RpcConfiguration(PhantomLogger.Create("Rpc", serviceName), PhantomLogger.Create<TaskManager>("Rpc", serviceName), host, port, connectionKey.Certificate); } - await Task.WhenAll( - RpcRuntime.Launch(ConfigureRpc("Agent", agentRpcServerHost, agentRpcServerPort, agentKeyData), AgentMessageRegistries.Definitions, controllerServices.CreateAgentMessageListener, shutdownCancellationToken), - RpcRuntime.Launch(ConfigureRpc("Web", webRpcServerHost, webRpcServerPort, webKeyData), WebMessageRegistries.Definitions, controllerServices.CreateWebMessageListener, shutdownCancellationToken) - ); + try { + await Task.WhenAll( + RpcRuntime.Launch(ConfigureRpc("Agent", agentRpcServerHost, agentRpcServerPort, agentKeyData), AgentMessageRegistries.Definitions, controllerServices.CreateAgentMessageListener, shutdownCancellationToken), + RpcRuntime.Launch(ConfigureRpc("Web", webRpcServerHost, webRpcServerPort, webKeyData), WebMessageRegistries.Definitions, controllerServices.CreateWebMessageListener, shutdownCancellationToken) + ); + } finally { + NetMQConfig.Cleanup(); + } return 0; } catch (OperationCanceledException) { diff --git a/Utils/Phantom.Utils.Rpc/Message/IMessageDefinitions.cs b/Utils/Phantom.Utils.Rpc/Message/IMessageDefinitions.cs index e006f4e..e215b6e 100644 --- a/Utils/Phantom.Utils.Rpc/Message/IMessageDefinitions.cs +++ b/Utils/Phantom.Utils.Rpc/Message/IMessageDefinitions.cs @@ -1,8 +1,8 @@ namespace Phantom.Utils.Rpc.Message; -public interface IMessageDefinitions<TOutgoingListener, TIncomingListener, TReplyMessage> where TReplyMessage : IMessage<TOutgoingListener, NoReply>, IMessage<TIncomingListener, NoReply> { - MessageRegistry<TOutgoingListener> Outgoing { get; } - MessageRegistry<TIncomingListener> Incoming { get; } +public interface IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> where TReplyMessage : IMessage<TClientListener, NoReply>, IMessage<TServerListener, NoReply> { + MessageRegistry<TClientListener> ToClient { get; } + MessageRegistry<TServerListener> ToServer { get; } bool IsRegistrationMessage(Type messageType); TReplyMessage CreateReplyMessage(uint sequenceId, byte[] serializedReply); diff --git a/Utils/Phantom.Utils.Rpc/RpcClientRuntime.cs b/Utils/Phantom.Utils.Rpc/RpcClientRuntime.cs new file mode 100644 index 0000000..f234759 --- /dev/null +++ b/Utils/Phantom.Utils.Rpc/RpcClientRuntime.cs @@ -0,0 +1,77 @@ +using NetMQ.Sockets; +using Phantom.Utils.Rpc.Message; +using Phantom.Utils.Rpc.Sockets; +using Phantom.Utils.Tasks; +using Serilog; +using Serilog.Events; + +namespace Phantom.Utils.Rpc; + +public abstract class RpcClientRuntime<TClientListener, TServerListener, TReplyMessage> : RpcRuntime<ClientSocket> where TReplyMessage : IMessage<TClientListener, NoReply>, IMessage<TServerListener, NoReply> { + private readonly RpcConnectionToServer<TServerListener> connection; + private readonly IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions; + private readonly TClientListener messageListener; + + private readonly SemaphoreSlim disconnectSemaphore; + private readonly CancellationToken receiveCancellationToken; + + protected RpcClientRuntime(RpcClientSocket<TClientListener, TServerListener, TReplyMessage> socket, TClientListener messageListener, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) : base(socket) { + this.connection = socket.Connection; + this.messageDefinitions = socket.MessageDefinitions; + this.messageListener = messageListener; + this.disconnectSemaphore = disconnectSemaphore; + this.receiveCancellationToken = receiveCancellationToken; + } + + protected sealed override void Run(ClientSocket socket, ILogger logger, MessageReplyTracker replyTracker, TaskManager taskManager) { + RunWithConnection(socket, connection, logger, taskManager); + } + + protected virtual void RunWithConnection(ClientSocket socket, RpcConnectionToServer<TServerListener> connection, ILogger logger, TaskManager taskManager) { + var handler = new Handler(connection, messageDefinitions, messageListener, logger, taskManager, receiveCancellationToken); + + try { + while (!receiveCancellationToken.IsCancellationRequested) { + var data = socket.Receive(receiveCancellationToken); + + LogMessageType(logger, data); + + if (data.Length > 0) { + messageDefinitions.ToClient.Handle(data, handler); + } + } + } catch (OperationCanceledException) { + // Ignore. + } finally { + logger.Debug("ZeroMQ client stopped receiving messages."); + disconnectSemaphore.Wait(CancellationToken.None); + } + } + + private void LogMessageType(ILogger logger, ReadOnlyMemory<byte> data) { + if (!logger.IsEnabled(LogEventLevel.Verbose)) { + return; + } + + if (data.Length > 0 && messageDefinitions.ToClient.TryGetType(data, out var type)) { + logger.Verbose("Received {MessageType} ({Bytes} B).", type.Name, data.Length); + } + else { + logger.Verbose("Received {Bytes} B message.", data.Length); + } + } + + private sealed class Handler : MessageHandler<TClientListener> { + private readonly RpcConnectionToServer<TServerListener> connection; + private readonly IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions; + + public Handler(RpcConnectionToServer<TServerListener> connection, IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions, TClientListener listener, ILogger logger, TaskManager taskManager, CancellationToken cancellationToken) : base(listener, logger, taskManager, cancellationToken) { + this.connection = connection; + this.messageDefinitions = messageDefinitions; + } + + protected override Task SendReply(uint sequenceId, byte[] serializedReply) { + return connection.Send(messageDefinitions.CreateReplyMessage(sequenceId, serializedReply)); + } + } +} diff --git a/Utils/Phantom.Utils.Rpc/RpcConnectionToServer.cs b/Utils/Phantom.Utils.Rpc/RpcConnectionToServer.cs new file mode 100644 index 0000000..9ac135c --- /dev/null +++ b/Utils/Phantom.Utils.Rpc/RpcConnectionToServer.cs @@ -0,0 +1,41 @@ +using NetMQ; +using NetMQ.Sockets; +using Phantom.Utils.Rpc.Message; + +namespace Phantom.Utils.Rpc; + +public sealed class RpcConnectionToServer<TListener> { + private readonly ClientSocket socket; + private readonly MessageRegistry<TListener> messageRegistry; + private readonly MessageReplyTracker replyTracker; + + internal RpcConnectionToServer(ClientSocket socket, MessageRegistry<TListener> messageRegistry, MessageReplyTracker replyTracker) { + this.socket = socket; + this.messageRegistry = messageRegistry; + this.replyTracker = replyTracker; + } + + public async Task Send<TMessage>(TMessage message) where TMessage : IMessage<TListener, NoReply> { + var bytes = messageRegistry.Write(message).ToArray(); + if (bytes.Length > 0) { + await socket.SendAsync(bytes); + } + } + + public async Task<TReply?> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken) where TMessage : IMessage<TListener, TReply> where TReply : class { + var sequenceId = replyTracker.RegisterReply(); + + var bytes = messageRegistry.Write<TMessage, TReply>(sequenceId, message).ToArray(); + if (bytes.Length == 0) { + replyTracker.ForgetReply(sequenceId); + return null; + } + + await socket.SendAsync(bytes); + return await replyTracker.WaitForReply<TReply>(sequenceId, waitForReplyTime, waitForReplyCancellationToken); + } + + public void Receive(IReply message) { + replyTracker.ReceiveReply(message.SequenceId, message.SerializedReply); + } +} diff --git a/Utils/Phantom.Utils.Rpc/RpcRuntime.cs b/Utils/Phantom.Utils.Rpc/RpcRuntime.cs index 3f252a5..0474f97 100644 --- a/Utils/Phantom.Utils.Rpc/RpcRuntime.cs +++ b/Utils/Phantom.Utils.Rpc/RpcRuntime.cs @@ -1,39 +1,28 @@ using NetMQ; using Phantom.Utils.Rpc.Message; +using Phantom.Utils.Rpc.Sockets; using Phantom.Utils.Tasks; using Serilog; namespace Phantom.Utils.Rpc; -static class RpcRuntime { - internal static void SetDefaultSocketOptions(ThreadSafeSocketOptions options) { - // TODO test behavior when either agent or server are offline for a very long time - options.DelayAttachOnConnect = true; - options.ReceiveHighWatermark = 10_000; - options.SendHighWatermark = 10_000; - } -} - public abstract class RpcRuntime<TSocket> where TSocket : ThreadSafeSocket { private readonly TSocket socket; private readonly ILogger runtimeLogger; private readonly MessageReplyTracker replyTracker; private readonly TaskManager taskManager; - protected RpcRuntime(RpcConfiguration configuration, TSocket socket) { - RpcRuntime.SetDefaultSocketOptions(socket.Options); - this.socket = socket; - this.runtimeLogger = configuration.RuntimeLogger; - this.replyTracker = new MessageReplyTracker(runtimeLogger); - this.taskManager = new TaskManager(configuration.TaskManagerLogger); + protected RpcRuntime(RpcSocket<TSocket> socket) { + this.socket = socket.Socket; + this.runtimeLogger = socket.Config.RuntimeLogger; + this.replyTracker = socket.ReplyTracker; + this.taskManager = new TaskManager(socket.Config.TaskManagerLogger); } protected async Task Launch() { - Connect(socket); - void RunTask() { try { - Run(socket, replyTracker, taskManager); + Run(socket, runtimeLogger, replyTracker, taskManager); } catch (Exception e) { runtimeLogger.Error(e, "Caught exception in RPC thread."); } @@ -42,21 +31,19 @@ public abstract class RpcRuntime<TSocket> where TSocket : ThreadSafeSocket { try { await Task.Factory.StartNew(RunTask, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default); } catch (OperationCanceledException) { - // ignore + // Ignore. } finally { await taskManager.Stop(); - await Disconnect(); + await Disconnect(socket, runtimeLogger); socket.Dispose(); - NetMQConfig.Cleanup(); - runtimeLogger.Information("ZeroMQ client stopped."); + runtimeLogger.Information("ZeroMQ runtime stopped."); } } - protected abstract void Connect(TSocket socket); - protected abstract void Run(TSocket socket, MessageReplyTracker replyTracker, TaskManager taskManager); + protected abstract void Run(TSocket socket, ILogger logger, MessageReplyTracker replyTracker, TaskManager taskManager); - protected virtual Task Disconnect() { + protected virtual Task Disconnect(TSocket socket, ILogger logger) { return Task.CompletedTask; } } diff --git a/Utils/Phantom.Utils.Rpc/Sockets/RpcClientSocket.cs b/Utils/Phantom.Utils.Rpc/Sockets/RpcClientSocket.cs new file mode 100644 index 0000000..f26bf95 --- /dev/null +++ b/Utils/Phantom.Utils.Rpc/Sockets/RpcClientSocket.cs @@ -0,0 +1,40 @@ +using NetMQ; +using NetMQ.Sockets; +using Phantom.Utils.Rpc.Message; + +namespace Phantom.Utils.Rpc.Sockets; + +public static class RpcClientSocket { + public static RpcClientSocket<TClientListener, TServerListener, TReplyMessage> Connect<TClientListener, TServerListener, TReplyMessage, THelloMessage>(RpcConfiguration config, IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions, THelloMessage helloMessage) where THelloMessage : IMessage<TServerListener, NoReply> where TReplyMessage : IMessage<TClientListener, NoReply>, IMessage<TServerListener, NoReply> { + return RpcClientSocket<TClientListener, TServerListener, TReplyMessage>.Connect(config, messageDefinitions, helloMessage); + } +} + +public sealed class RpcClientSocket<TClientListener, TServerListener, TReplyMessage> : RpcSocket<ClientSocket> where TReplyMessage : IMessage<TClientListener, NoReply>, IMessage<TServerListener, NoReply> { + internal static RpcClientSocket<TClientListener, TServerListener, TReplyMessage> Connect<THelloMessage>(RpcConfiguration config, IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions, THelloMessage helloMessage) where THelloMessage : IMessage<TServerListener, NoReply> { + var socket = new ClientSocket(); + var options = socket.Options; + + options.CurveServerCertificate = config.ServerCertificate; + options.CurveCertificate = new NetMQCertificate(); + options.HelloMessage = messageDefinitions.ToServer.Write(helloMessage).ToArray(); + RpcSocket.SetDefaultSocketOptions(options); + + var url = config.TcpUrl; + var logger = config.RuntimeLogger; + + logger.Information("Starting ZeroMQ client and connecting to {Url}...", url); + socket.Connect(url); + logger.Information("ZeroMQ client ready."); + + return new RpcClientSocket<TClientListener, TServerListener, TReplyMessage>(socket, config, messageDefinitions); + } + + public RpcConnectionToServer<TServerListener> Connection { get; } + internal IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> MessageDefinitions { get; } + + private RpcClientSocket(ClientSocket socket, RpcConfiguration config, IMessageDefinitions<TClientListener, TServerListener, TReplyMessage> messageDefinitions) : base(socket, config) { + MessageDefinitions = messageDefinitions; + Connection = new RpcConnectionToServer<TServerListener>(socket, messageDefinitions.ToServer, ReplyTracker); + } +} diff --git a/Utils/Phantom.Utils.Rpc/Sockets/RpcServerSocket.cs b/Utils/Phantom.Utils.Rpc/Sockets/RpcServerSocket.cs new file mode 100644 index 0000000..6be05db --- /dev/null +++ b/Utils/Phantom.Utils.Rpc/Sockets/RpcServerSocket.cs @@ -0,0 +1,25 @@ +using NetMQ.Sockets; + +namespace Phantom.Utils.Rpc.Sockets; + +public sealed class RpcServerSocket : RpcSocket<ServerSocket> { + public static RpcServerSocket Connect(RpcConfiguration config) { + var socket = new ServerSocket(); + var options = socket.Options; + + options.CurveServer = true; + options.CurveCertificate = config.ServerCertificate; + RpcSocket.SetDefaultSocketOptions(options); + + var url = config.TcpUrl; + var logger = config.RuntimeLogger; + + logger.Information("Starting ZeroMQ server on {Url}...", url); + socket.Bind(url); + logger.Information("ZeroMQ server initialized, listening for connections on port {Port}.", config.Port); + + return new RpcServerSocket(socket, config); + } + + private RpcServerSocket(ServerSocket socket, RpcConfiguration config) : base(socket, config) {} +} diff --git a/Utils/Phantom.Utils.Rpc/Sockets/RpcSocket.cs b/Utils/Phantom.Utils.Rpc/Sockets/RpcSocket.cs new file mode 100644 index 0000000..134c42d --- /dev/null +++ b/Utils/Phantom.Utils.Rpc/Sockets/RpcSocket.cs @@ -0,0 +1,25 @@ +using NetMQ; +using Phantom.Utils.Rpc.Message; + +namespace Phantom.Utils.Rpc.Sockets; + +static class RpcSocket { + internal static void SetDefaultSocketOptions(ThreadSafeSocketOptions options) { + // TODO test behavior when either agent or server are offline for a very long time + options.DelayAttachOnConnect = true; + options.ReceiveHighWatermark = 10_000; + options.SendHighWatermark = 10_000; + } +} + +public abstract class RpcSocket<TSocket> where TSocket : ThreadSafeSocket { + internal TSocket Socket { get; } + internal RpcConfiguration Config { get; } + internal MessageReplyTracker ReplyTracker { get; } + + protected RpcSocket(TSocket socket, RpcConfiguration config) { + Socket = socket; + Config = config; + ReplyTracker = new MessageReplyTracker(config.RuntimeLogger); + } +}