1
0
mirror of https://github.com/chylex/Minecraft-Phantom-Panel.git synced 2025-05-02 03:34:06 +02:00

Fix Controller handling messages from unauthorized Agents

This commit is contained in:
chylex 2023-10-23 13:19:28 +02:00
parent 2a9bb9e6ac
commit cd332a6571
Signed by: chylex
GPG Key ID: 4DE42C8F19A80548
6 changed files with 38 additions and 30 deletions
Agent/Phantom.Agent.Rpc
Common/Phantom.Common.Messages.Agent/ToController
Controller
Phantom.Controller.Rpc
Phantom.Controller.Services/Rpc
Utils/Phantom.Utils.Rpc/Message

View File

@ -13,14 +13,10 @@ namespace Phantom.Agent.Rpc;
public sealed class RpcClientRuntime : RpcClientRuntime<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> { public sealed class RpcClientRuntime : RpcClientRuntime<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> {
public static Task Launch(RpcClientSocket<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> socket, AgentInfo agentInfo, IMessageToAgentListener messageListener, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) { public static Task Launch(RpcClientSocket<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> socket, AgentInfo agentInfo, IMessageToAgentListener messageListener, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) {
return new RpcClientRuntime(socket, agentInfo.Guid, messageListener, disconnectSemaphore, receiveCancellationToken).Launch(); return new RpcClientRuntime(socket, messageListener, disconnectSemaphore, receiveCancellationToken).Launch();
} }
private readonly Guid agentGuid; private RpcClientRuntime(RpcClientSocket<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> socket, IMessageToAgentListener messageListener, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) : base(socket, messageListener, disconnectSemaphore, receiveCancellationToken) {}
private RpcClientRuntime(RpcClientSocket<IMessageToAgentListener, IMessageToControllerListener, ReplyMessage> socket, Guid agentGuid, IMessageToAgentListener messageListener, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) : base(socket, messageListener, disconnectSemaphore, receiveCancellationToken) {
this.agentGuid = agentGuid;
}
protected override void RunWithConnection(ClientSocket socket, RpcConnectionToServer<IMessageToControllerListener> connection, ILogger logger, TaskManager taskManager) { protected override void RunWithConnection(ClientSocket socket, RpcConnectionToServer<IMessageToControllerListener> connection, ILogger logger, TaskManager taskManager) {
var keepAliveLoop = new KeepAliveLoop(connection); var keepAliveLoop = new KeepAliveLoop(connection);
@ -32,7 +28,7 @@ public sealed class RpcClientRuntime : RpcClientRuntime<IMessageToAgentListener,
} }
protected override async Task Disconnect(ClientSocket socket, ILogger logger) { protected override async Task Disconnect(ClientSocket socket, ILogger logger) {
var unregisterMessageBytes = AgentMessageRegistries.ToController.Write(new UnregisterAgentMessage(agentGuid)).ToArray(); var unregisterMessageBytes = AgentMessageRegistries.ToController.Write(new UnregisterAgentMessage()).ToArray();
try { try {
await socket.SendAsync(unregisterMessageBytes).AsTask().WaitAsync(TimeSpan.FromSeconds(5), CancellationToken.None); await socket.SendAsync(unregisterMessageBytes).AsTask().WaitAsync(TimeSpan.FromSeconds(5), CancellationToken.None);
} catch (TimeoutException) { } catch (TimeoutException) {

View File

@ -4,9 +4,7 @@ using Phantom.Utils.Rpc.Message;
namespace Phantom.Common.Messages.Agent.ToController; namespace Phantom.Common.Messages.Agent.ToController;
[MemoryPackable(GenerateType.VersionTolerant)] [MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record UnregisterAgentMessage( public sealed partial record UnregisterAgentMessage : IMessageToController {
[property: MemoryPackOrder(0)] Guid AgentGuid
) : IMessageToController {
public Task<NoReply> Accept(IMessageToControllerListener listener) { public Task<NoReply> Accept(IMessageToControllerListener listener) {
return listener.HandleUnregisterAgent(this); return listener.HandleUnregisterAgent(this);
} }

View File

@ -11,6 +11,13 @@ public sealed class RpcConnectionToClient<TListener> {
private readonly MessageRegistry<TListener> messageRegistry; private readonly MessageRegistry<TListener> messageRegistry;
private readonly MessageReplyTracker messageReplyTracker; private readonly MessageReplyTracker messageReplyTracker;
private volatile bool isAuthorized;
public bool IsAuthorized {
get => isAuthorized;
set => isAuthorized = value;
}
internal event EventHandler<RpcClientConnectionClosedEventArgs>? Closed; internal event EventHandler<RpcClientConnectionClosedEventArgs>? Closed;
private bool isClosed; private bool isClosed;

View File

@ -42,12 +42,14 @@ internal sealed class RpcRuntime<TClientListener, TServerListener, TReplyMessage
var (routingId, data) = socket.Receive(cancellationToken); var (routingId, data) = socket.Receive(cancellationToken);
if (data.Length == 0) { if (data.Length == 0) {
LogMessageType(logger, routingId, data); LogMessageType(logger, routingId, data, messageType: null);
continue; continue;
} }
Type? messageType = messageDefinitions.ToServer.TryGetType(data, out var type) ? type : null;
if (!clients.TryGetValue(routingId, out var client)) { if (!clients.TryGetValue(routingId, out var client)) {
if (!CheckIsRegistrationMessage(data, logger, routingId)) { if (!CheckIsRegistrationMessage(messageType, logger, routingId)) {
continue; continue;
} }
@ -58,7 +60,11 @@ internal sealed class RpcRuntime<TClientListener, TServerListener, TReplyMessage
clients[routingId] = client; clients[routingId] = client;
} }
LogMessageType(logger, routingId, data); if (!client.Connection.IsAuthorized && !CheckIsRegistrationMessage(messageType, logger, routingId)) {
continue;
}
LogMessageType(logger, routingId, data, messageType);
messageDefinitions.ToServer.Handle(data, client); messageDefinitions.ToServer.Handle(data, client);
} }
@ -67,25 +73,25 @@ internal sealed class RpcRuntime<TClientListener, TServerListener, TReplyMessage
} }
} }
private void LogMessageType(ILogger logger, uint routingId, ReadOnlyMemory<byte> data) { private void LogMessageType(ILogger logger, uint routingId, ReadOnlyMemory<byte> data, Type? messageType) {
if (!logger.IsEnabled(LogEventLevel.Verbose)) { if (!logger.IsEnabled(LogEventLevel.Verbose)) {
return; return;
} }
if (data.Length > 0 && messageDefinitions.ToServer.TryGetType(data, out var type)) { if (data.Length > 0 && messageType != null) {
logger.Verbose("Received {MessageType} ({Bytes} B) from {RoutingId}.", type.Name, data.Length, routingId); logger.Verbose("Received {MessageType} ({Bytes} B) from {RoutingId}.", messageType.Name, data.Length, routingId);
} }
else { else {
logger.Verbose("Received {Bytes} B message from {RoutingId}.", data.Length, routingId); logger.Verbose("Received {Bytes} B message from {RoutingId}.", data.Length, routingId);
} }
} }
private bool CheckIsRegistrationMessage(ReadOnlyMemory<byte> data, ILogger logger, uint routingId) { private bool CheckIsRegistrationMessage(Type? messageType, ILogger logger, uint routingId) {
if (messageDefinitions.ToServer.TryGetType(data, out var type) && messageDefinitions.IsRegistrationMessage(type)) { if (messageType != null && messageDefinitions.IsRegistrationMessage(messageType)) {
return true; return true;
} }
logger.Warning("Received {MessageType} from {RoutingId} who is not registered.", type?.Name ?? "unknown message", routingId); logger.Warning("Received {MessageType} from {RoutingId} who is not registered.", messageType?.Name ?? "unknown message", routingId);
return false; return false;
} }
@ -98,7 +104,7 @@ internal sealed class RpcRuntime<TClientListener, TServerListener, TReplyMessage
this.Connection = connection; this.Connection = connection;
this.messageDefinitions = messageDefinitions; this.messageDefinitions = messageDefinitions;
} }
protected override Task SendReply(uint sequenceId, byte[] serializedReply) { protected override Task SendReply(uint sequenceId, byte[] serializedReply) {
return Connection.Send(messageDefinitions.CreateReplyMessage(sequenceId, serializedReply)); return Connection.Send(messageDefinitions.CreateReplyMessage(sequenceId, serializedReply));
} }

View File

@ -39,8 +39,8 @@ public sealed class AgentMessageListener : IMessageToControllerListener {
await connection.Send(new RegisterAgentFailureMessage(RegisterAgentFailure.ConnectionAlreadyHasAnAgent)); await connection.Send(new RegisterAgentFailureMessage(RegisterAgentFailure.ConnectionAlreadyHasAnAgent));
} }
else if (await agentManager.RegisterAgent(message.AuthToken, message.AgentInfo, instanceManager, connection)) { else if (await agentManager.RegisterAgent(message.AuthToken, message.AgentInfo, instanceManager, connection)) {
var guid = message.AgentInfo.Guid; connection.IsAuthorized = true;
agentGuidWaiter.SetResult(guid); agentGuidWaiter.SetResult(message.AgentInfo.Guid);
} }
return NoReply.Instance; return NoReply.Instance;
@ -51,8 +51,11 @@ public sealed class AgentMessageListener : IMessageToControllerListener {
} }
public Task<NoReply> HandleUnregisterAgent(UnregisterAgentMessage message) { public Task<NoReply> HandleUnregisterAgent(UnregisterAgentMessage message) {
if (agentManager.UnregisterAgent(message.AgentGuid, connection)) { if (agentGuidWaiter.Task.IsCompleted) {
instanceManager.SetInstanceStatesForAgent(message.AgentGuid, InstanceStatus.Offline); var agentGuid = agentGuidWaiter.Task.Result;
if (agentManager.UnregisterAgent(agentGuid, connection)) {
instanceManager.SetInstanceStatesForAgent(agentGuid, InstanceStatus.Offline);
}
} }
connection.Close(); connection.Close();

View File

@ -4,14 +4,13 @@ using Serilog;
namespace Phantom.Utils.Rpc.Message; namespace Phantom.Utils.Rpc.Message;
public abstract class MessageHandler<TListener> { public abstract class MessageHandler<TListener> {
protected TListener Listener { get; } private readonly TListener listener;
private readonly ILogger logger; private readonly ILogger logger;
private readonly TaskManager taskManager; private readonly TaskManager taskManager;
private readonly CancellationToken cancellationToken; private readonly CancellationToken cancellationToken;
protected MessageHandler(TListener listener, ILogger logger, TaskManager taskManager, CancellationToken cancellationToken) { protected MessageHandler(TListener listener, ILogger logger, TaskManager taskManager, CancellationToken cancellationToken) {
this.Listener = listener; this.listener = listener;
this.logger = logger; this.logger = logger;
this.taskManager = taskManager; this.taskManager = taskManager;
this.cancellationToken = cancellationToken; this.cancellationToken = cancellationToken;
@ -29,12 +28,11 @@ public abstract class MessageHandler<TListener> {
} }
private async Task Handle<TMessage, TReply>(uint sequenceId, TMessage message) where TMessage : IMessage<TListener, TReply> { private async Task Handle<TMessage, TReply>(uint sequenceId, TMessage message) where TMessage : IMessage<TListener, TReply> {
TReply reply = await message.Accept(Listener); TReply reply = await message.Accept(listener);
if (reply is not NoReply) { if (reply is not NoReply) {
await SendReply(sequenceId, MessageSerializer.Serialize(reply)); await SendReply(sequenceId, MessageSerializer.Serialize(reply));
} }
} }
protected abstract Task SendReply(uint sequenceId, byte[] serializedReply); protected abstract Task SendReply(uint sequenceId, byte[] serializedReply);
} }