1
0
mirror of https://github.com/chylex/Minecraft-Phantom-Panel.git synced 2025-05-08 12:34:03 +02:00

Refactor RPC to use a single long running task

This commit is contained in:
chylex 2022-10-19 15:24:40 +02:00
parent 69f3fbcbf4
commit bcb53528b9
Signed by: chylex
GPG Key ID: 4DE42C8F19A80548
5 changed files with 85 additions and 45 deletions
Agent/Phantom.Agent.Rpc
Server/Phantom.Server.Rpc
Utils/Phantom.Utils.Rpc

View File

@ -5,6 +5,7 @@ using Phantom.Common.Messages;
using Phantom.Common.Messages.ToServer;
using Phantom.Utils.Rpc;
using Phantom.Utils.Runtime;
using Serilog;
using Serilog.Events;
namespace Phantom.Agent.Rpc;
@ -24,7 +25,7 @@ public sealed class RpcLauncher : RpcRuntime<ClientSocket> {
private readonly RpcConfiguration config;
private readonly Guid agentGuid;
private readonly Func<ClientSocket, IMessageToAgentListener> messageListenerFactory;
private readonly SemaphoreSlim disconnectSemaphore;
private readonly CancellationToken receiveCancellationToken;
@ -45,39 +46,46 @@ public sealed class RpcLauncher : RpcRuntime<ClientSocket> {
logger.Information("ZeroMQ client ready.");
}
protected override async Task Run(ClientSocket socket, TaskManager taskManager) {
protected override void Run(ClientSocket socket, TaskManager taskManager) {
var logger = config.Logger;
var listener = messageListenerFactory(socket);
ServerMessaging.SetCurrentSocket(socket);
var keepAliveLoop = new KeepAliveLoop(socket, taskManager);
try {
// TODO optimize msg
await foreach (var bytes in socket.ReceiveBytesAsyncEnumerable(receiveCancellationToken)) {
if (logger.IsEnabled(LogEventLevel.Verbose)) {
if (bytes.Length > 0 && MessageRegistries.ToAgent.TryGetType(bytes, out var type)) {
logger.Verbose("Received {MessageType} ({Bytes} B) from server.", type.Name, bytes.Length);
}
else {
logger.Verbose("Received {Bytes} B message from server.", bytes.Length);
}
}
if (bytes.Length > 0) {
MessageRegistries.ToAgent.Handle(bytes, listener, taskManager, receiveCancellationToken);
try {
while (!receiveCancellationToken.IsCancellationRequested) {
var data = socket.Receive(receiveCancellationToken);
LogMessageType(logger, data);
if (data.Length > 0) {
MessageRegistries.ToAgent.Handle(data, listener, taskManager, receiveCancellationToken);
}
}
} catch (OperationCanceledException) {
// Ignore.
} finally {
logger.Verbose("ZeroMQ client stopped receiving messages.");
await disconnectSemaphore.WaitAsync(CancellationToken.None);
disconnectSemaphore.Wait(CancellationToken.None);
keepAliveLoop.Cancel();
}
}
private static void LogMessageType(ILogger logger, ReadOnlyMemory<byte> data) {
if (!logger.IsEnabled(LogEventLevel.Verbose)) {
return;
}
if (data.Length > 0 && MessageRegistries.ToAgent.TryGetType(data, out var type)) {
logger.Verbose("Received {MessageType} ({Bytes} B) from server.", type.Name, data.Length);
}
else {
logger.Verbose("Received {Bytes} B message from server.", data.Length);
}
}
protected override async Task Disconnect(ClientSocket socket) {
var unregisterTimeoutTask = Task.Delay(TimeSpan.FromSeconds(5), CancellationToken.None);
var finishedTask = await Task.WhenAny(socket.SendMessage(new UnregisterAgentMessage(agentGuid)), unregisterTimeoutTask);

View File

@ -1,5 +1,4 @@
using NetMQ;
using NetMQ.Sockets;
using NetMQ.Sockets;
using Phantom.Common.Messages;
using Phantom.Common.Messages.ToServer;
using Phantom.Utils.Rpc;
@ -39,7 +38,7 @@ public sealed class RpcLauncher : RpcRuntime<ServerSocket> {
logger.Information("ZeroMQ server initialized, listening for agent connections on port {Port}.", config.Port);
}
protected override async Task Run(ServerSocket socket, TaskManager taskManager) {
protected override void Run(ServerSocket socket, TaskManager taskManager) {
var logger = config.Logger;
var clients = new Dictionary<ulong, Client>();
@ -48,15 +47,16 @@ public sealed class RpcLauncher : RpcRuntime<ServerSocket> {
logger.Verbose("Closed connection to {RoutingId}.", e.RoutingId);
}
// TODO optimize msg
await foreach (var (routingId, bytes) in socket.ReceiveBytesAsyncEnumerable(cancellationToken)) {
if (bytes.Length == 0) {
LogMessageType(logger, routingId, bytes);
while (!cancellationToken.IsCancellationRequested) {
var (routingId, data) = socket.Receive(cancellationToken);
if (data.Length == 0) {
LogMessageType(logger, routingId, data);
continue;
}
if (!clients.TryGetValue(routingId, out var client)) {
if (!CheckIsAgentRegistrationMessage(bytes, logger, routingId)) {
if (!CheckIsAgentRegistrationMessage(data, logger, routingId)) {
continue;
}
@ -67,8 +67,8 @@ public sealed class RpcLauncher : RpcRuntime<ServerSocket> {
clients[routingId] = client;
}
LogMessageType(logger, routingId, bytes);
MessageRegistries.ToServer.Handle(bytes, client.Listener, taskManager, cancellationToken);
LogMessageType(logger, routingId, data);
MessageRegistries.ToServer.Handle(data, client.Listener, taskManager, cancellationToken);
if (client.Listener.IsDisposed) {
client.Connection.Close();
@ -80,21 +80,21 @@ public sealed class RpcLauncher : RpcRuntime<ServerSocket> {
}
}
private static void LogMessageType(ILogger logger, uint routingId, byte[] bytes) {
private static void LogMessageType(ILogger logger, uint routingId, ReadOnlyMemory<byte> data) {
if (!logger.IsEnabled(LogEventLevel.Verbose)) {
return;
}
if (bytes.Length > 0 && MessageRegistries.ToServer.TryGetType(bytes, out var type)) {
logger.Verbose("Received {MessageType} ({Bytes} B) from {RoutingId}.", type.Name, bytes.Length, routingId);
if (data.Length > 0 && MessageRegistries.ToServer.TryGetType(data, out var type)) {
logger.Verbose("Received {MessageType} ({Bytes} B) from {RoutingId}.", type.Name, data.Length, routingId);
}
else {
logger.Verbose("Received {Bytes} B message from {RoutingId}.", bytes.Length, routingId);
logger.Verbose("Received {Bytes} B message from {RoutingId}.", data.Length, routingId);
}
}
private static bool CheckIsAgentRegistrationMessage(byte[] bytes, ILogger logger, uint routingId) {
if (MessageRegistries.ToServer.TryGetType(bytes, out var type) && type == typeof(RegisterAgentMessage)) {
private static bool CheckIsAgentRegistrationMessage(ReadOnlyMemory<byte> data, ILogger logger, uint routingId) {
if (MessageRegistries.ToServer.TryGetType(data, out var type) && type == typeof(RegisterAgentMessage)) {
return true;
}

View File

@ -24,11 +24,9 @@ public sealed class MessageRegistry<TListener, TMessageBase> where TMessageBase
codeToDeserializerMapping.Add(code, MessageSerializer.Deserialize<TMessage, TMessageBase, TListener>());
}
public bool TryGetType(byte[] bytes, [NotNullWhen(true)] out Type? type) {
var memory = new ReadOnlyMemory<byte>(bytes);
public bool TryGetType(ReadOnlyMemory<byte> data, [NotNullWhen(true)] out Type? type) {
try {
var code = MessageSerializer.ReadCode(ref memory);
var code = MessageSerializer.ReadCode(ref data);
return codeToTypeMapping.TryGetValue(code, out type);
} catch (Exception) {
type = null;
@ -59,12 +57,10 @@ public sealed class MessageRegistry<TListener, TMessageBase> where TMessageBase
}
}
public void Handle(byte[] bytes, TListener listener, TaskManager taskManager, CancellationToken cancellationToken) {
var memory = new ReadOnlyMemory<byte>(bytes);
public void Handle(ReadOnlyMemory<byte> data, TListener listener, TaskManager taskManager, CancellationToken cancellationToken) {
ushort code;
try {
code = MessageSerializer.ReadCode(ref memory);
code = MessageSerializer.ReadCode(ref data);
} catch (Exception e) {
logger.Error(e, "Failed to deserialize message code.");
return;
@ -77,7 +73,7 @@ public sealed class MessageRegistry<TListener, TMessageBase> where TMessageBase
TMessageBase message;
try {
message = deserialize(memory);
message = deserialize(data);
} catch (Exception e) {
logger.Error(e, "Failed to deserialize message with code {Code}.", code);
return;

View File

@ -0,0 +1,32 @@
using NetMQ;
using NetMQ.Sockets;
namespace Phantom.Utils.Rpc;
public static class RpcExtensions {
public static ReadOnlyMemory<byte> Receive(this ClientSocket socket, CancellationToken cancellationToken) {
var msg = new Msg();
msg.InitEmpty();
try {
socket.Receive(ref msg, cancellationToken);
return msg.SliceAsMemory();
} finally {
// Only releases references, so the returned ReadOnlyMemory is safe.
msg.Close();
}
}
public static (uint, ReadOnlyMemory<byte>) Receive(this ServerSocket socket, CancellationToken cancellationToken) {
var msg = new Msg();
msg.InitEmpty();
try {
socket.Receive(ref msg, cancellationToken);
return (msg.RoutingId, msg.SliceAsMemory());
} finally {
// Only releases references, so the returned ReadOnlyMemory is safe.
msg.Close();
}
}
}

View File

@ -38,9 +38,13 @@ public abstract class RpcRuntime<TSocket> where TSocket : ThreadSafeSocket, new(
protected async Task Launch() {
Connect(socket);
void RunTask() {
Run(socket, taskManager);
}
try {
await Run(socket, taskManager);
await Task.Factory.StartNew(RunTask, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default);
} catch (OperationCanceledException) {
// ignore
} finally {
@ -55,7 +59,7 @@ public abstract class RpcRuntime<TSocket> where TSocket : ThreadSafeSocket, new(
}
protected abstract void Connect(TSocket socket);
protected abstract Task Run(TSocket socket, TaskManager taskManager);
protected abstract void Run(TSocket socket, TaskManager taskManager);
protected virtual Task Disconnect(TSocket socket) {
return Task.CompletedTask;