1
0
mirror of https://github.com/chylex/Minecraft-Phantom-Panel.git synced 2025-10-17 18:39:35 +02:00

4 Commits

41 changed files with 719 additions and 400 deletions

View File

@@ -1,19 +1,19 @@
using Phantom.Common.Messages.Agent; using Phantom.Common.Messages.Agent;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime; using Phantom.Utils.Rpc.Message;
namespace Phantom.Agent.Services.Rpc; namespace Phantom.Agent.Services.Rpc;
public sealed class ControllerConnection(RpcSendChannel<IMessageToController> sendChannel) { public sealed class ControllerConnection(MessageSender<IMessageToController> sender) {
internal bool TrySend<TMessage>(TMessage message) where TMessage : IMessageToController { internal bool TrySend<TMessage>(TMessage message) where TMessage : IMessageToController {
return sendChannel.TrySendMessage(message); return sender.TrySend(message);
} }
internal ValueTask Send<TMessage>(TMessage message, CancellationToken cancellationToken) where TMessage : IMessageToController { internal ValueTask Send<TMessage>(TMessage message, CancellationToken cancellationToken) where TMessage : IMessageToController {
return sendChannel.SendMessage(message, cancellationToken); return sender.Send(message, cancellationToken);
} }
internal Task<TReply> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken cancellationToken) where TMessage : IMessageToController, ICanReply<TReply> { internal Task<TReply> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken cancellationToken) where TMessage : IMessageToController, ICanReply<TReply> {
return sendChannel.SendMessage<TMessage, TReply>(message, waitForReplyTime, cancellationToken); return sender.Send<TMessage, TReply>(message, waitForReplyTime, cancellationToken);
} }
} }

View File

@@ -10,6 +10,7 @@ using Phantom.Common.Messages.Agent.ToController;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Message; using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Runtime;
using Phantom.Utils.Rpc.Runtime.Client; using Phantom.Utils.Rpc.Runtime.Client;
using Phantom.Utils.Runtime; using Phantom.Utils.Runtime;
using Phantom.Utils.Threading; using Phantom.Utils.Threading;
@@ -60,8 +61,12 @@ try {
CertificateThumbprint: agentKey.Value.CertificateThumbprint, CertificateThumbprint: agentKey.Value.CertificateThumbprint,
AuthToken: agentKey.Value.AuthToken, AuthToken: agentKey.Value.AuthToken,
Handshake: controllerHandshake, Handshake: controllerHandshake,
SendQueueCapacity: 500, CommonParameters: new RpcCommonConnectionParameters(
PingInterval: TimeSpan.FromSeconds(10) MessageQueueCapacity: 250,
FrameQueueCapacity: 500,
MaxConcurrentlyHandledMessages: 50,
PingInterval: TimeSpan.FromSeconds(10)
)
); );
using var rpcClient = await RpcClient<IMessageToController, IMessageToAgent>.Connect("Controller", rpcClientConnectionParameters, AgentMessageRegistries.Definitions, shutdownCancellationToken); using var rpcClient = await RpcClient<IMessageToController, IMessageToAgent>.Connect("Controller", rpcClientConnectionParameters, AgentMessageRegistries.Definitions, shutdownCancellationToken);
@@ -74,7 +79,7 @@ try {
try { try {
PhantomLogger.Root.InformationHeading("Launching Phantom Panel agent..."); PhantomLogger.Root.InformationHeading("Launching Phantom Panel agent...");
var agentServices = new AgentServices(agentInfo, folders, new AgentServiceConfiguration(maxConcurrentBackupCompressionTasks), new ControllerConnection(rpcClient.SendChannel), javaRuntimeRepository); var agentServices = new AgentServices(agentInfo, folders, new AgentServiceConfiguration(maxConcurrentBackupCompressionTasks), new ControllerConnection(rpcClient.MessageSender), javaRuntimeRepository);
var rpcMessageHandlerInit = new ControllerMessageHandlerActor.Init(agentServices); var rpcMessageHandlerInit = new ControllerMessageHandlerActor.Init(agentServices);
var rpcMessageHandlerActor = agentServices.ActorSystem.ActorOf(ControllerMessageHandlerActor.Factory(rpcMessageHandlerInit), "ControllerMessageHandler"); var rpcMessageHandlerActor = agentServices.ActorSystem.ActorOf(ControllerMessageHandlerActor.Factory(rpcMessageHandlerInit), "ControllerMessageHandler");
@@ -91,7 +96,7 @@ try {
PhantomLogger.Root.Information("Unregistering agent..."); PhantomLogger.Root.Information("Unregistering agent...");
try { try {
using var unregisterCancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(10)); using var unregisterCancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(10));
await rpcClient.SendChannel.SendMessage(new UnregisterAgentMessage(), unregisterCancellationTokenSource.Token); await rpcClient.MessageSender.Send(new UnregisterAgentMessage(), unregisterCancellationTokenSource.Token);
} catch (OperationCanceledException) { } catch (OperationCanceledException) {
PhantomLogger.Root.Warning("Could not unregister agent after shutdown."); PhantomLogger.Root.Warning("Could not unregister agent after shutdown.");
} catch (Exception e) { } catch (Exception e) {

View File

@@ -41,7 +41,7 @@ sealed class AgentConnection(Guid agentGuid, string agentName) {
public ValueTask Send<TMessage>(TMessage message) where TMessage : IMessageToAgent { public ValueTask Send<TMessage>(TMessage message) where TMessage : IMessageToAgent {
lock (this) { lock (this) {
if (connection != null) { if (connection != null) {
return connection.SendChannel.SendMessage(message); return connection.MessageSender.Send(message);
} }
} }
@@ -52,7 +52,7 @@ sealed class AgentConnection(Guid agentGuid, string agentName) {
public Task<TReply?> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken) where TMessage : IMessageToAgent, ICanReply<TReply> where TReply : class { public Task<TReply?> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken) where TMessage : IMessageToAgent, ICanReply<TReply> where TReply : class {
lock (this) { lock (this) {
if (connection != null) { if (connection != null) {
return connection.SendChannel.SendMessage<TMessage, TReply>(message, waitForReplyTime, waitForReplyCancellationToken)!; return connection.MessageSender.Send<TMessage, TReply>(message, waitForReplyTime, waitForReplyCancellationToken)!;
} }
} }

View File

@@ -5,24 +5,24 @@ using Phantom.Common.Messages.Web;
using Phantom.Common.Messages.Web.ToWeb; using Phantom.Common.Messages.Web.ToWeb;
using Phantom.Controller.Services.Instances; using Phantom.Controller.Services.Instances;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime; using Phantom.Utils.Rpc.Message;
namespace Phantom.Controller.Services.Rpc; namespace Phantom.Controller.Services.Rpc;
sealed class WebMessageDataUpdateSenderActor : ReceiveActor<WebMessageDataUpdateSenderActor.ICommand> { sealed class WebMessageDataUpdateSenderActor : ReceiveActor<WebMessageDataUpdateSenderActor.ICommand> {
public readonly record struct Init(RpcSendChannel<IMessageToWeb> Connection, ControllerState ControllerState, InstanceLogManager InstanceLogManager); public readonly record struct Init(MessageSender<IMessageToWeb> MessageSender, ControllerState ControllerState, InstanceLogManager InstanceLogManager);
public static Props<ICommand> Factory(Init init) { public static Props<ICommand> Factory(Init init) {
return Props<ICommand>.Create(() => new WebMessageDataUpdateSenderActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume }); return Props<ICommand>.Create(() => new WebMessageDataUpdateSenderActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume });
} }
private readonly RpcSendChannel<IMessageToWeb> connection; private readonly MessageSender<IMessageToWeb> messageSender;
private readonly ControllerState controllerState; private readonly ControllerState controllerState;
private readonly InstanceLogManager instanceLogManager; private readonly InstanceLogManager instanceLogManager;
private readonly ActorRef<ICommand> selfCached; private readonly ActorRef<ICommand> selfCached;
private WebMessageDataUpdateSenderActor(Init init) { private WebMessageDataUpdateSenderActor(Init init) {
this.connection = init.Connection; this.messageSender = init.MessageSender;
this.controllerState = init.ControllerState; this.controllerState = init.ControllerState;
this.instanceLogManager = init.InstanceLogManager; this.instanceLogManager = init.InstanceLogManager;
this.selfCached = SelfTyped; this.selfCached = SelfTyped;
@@ -70,18 +70,18 @@ sealed class WebMessageDataUpdateSenderActor : ReceiveActor<WebMessageDataUpdate
private sealed record RefreshUserSessionCommand(Guid UserGuid) : ICommand; private sealed record RefreshUserSessionCommand(Guid UserGuid) : ICommand;
private Task RefreshAgents(RefreshAgentsCommand command) { private Task RefreshAgents(RefreshAgentsCommand command) {
return connection.SendMessage(new RefreshAgentsMessage(command.Agents.Values.ToImmutableArray())).AsTask(); return messageSender.Send(new RefreshAgentsMessage(command.Agents.Values.ToImmutableArray())).AsTask();
} }
private Task RefreshInstances(RefreshInstancesCommand command) { private Task RefreshInstances(RefreshInstancesCommand command) {
return connection.SendMessage(new RefreshInstancesMessage(command.Instances.Values.ToImmutableArray())).AsTask(); return messageSender.Send(new RefreshInstancesMessage(command.Instances.Values.ToImmutableArray())).AsTask();
} }
private Task ReceiveInstanceLogs(ReceiveInstanceLogsCommand command) { private Task ReceiveInstanceLogs(ReceiveInstanceLogsCommand command) {
return connection.SendMessage(new InstanceOutputMessage(command.InstanceGuid, command.Lines)).AsTask(); return messageSender.Send(new InstanceOutputMessage(command.InstanceGuid, command.Lines)).AsTask();
} }
private Task RefreshUserSession(RefreshUserSessionCommand command) { private Task RefreshUserSession(RefreshUserSessionCommand command) {
return connection.SendMessage(new RefreshUserSessionMessage(command.UserGuid)).AsTask(); return messageSender.Send(new RefreshUserSessionMessage(command.UserGuid)).AsTask();
} }
} }

View File

@@ -63,7 +63,7 @@ sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
this.minecraftVersions = init.MinecraftVersions; this.minecraftVersions = init.MinecraftVersions;
this.eventLogManager = init.EventLogManager; this.eventLogManager = init.EventLogManager;
var senderActorInit = new WebMessageDataUpdateSenderActor.Init(connection.SendChannel, controllerState, init.InstanceLogManager); var senderActorInit = new WebMessageDataUpdateSenderActor.Init(connection.MessageSender, controllerState, init.InstanceLogManager);
Context.ActorOf(WebMessageDataUpdateSenderActor.Factory(senderActorInit), "DataUpdateSender"); Context.ActorOf(WebMessageDataUpdateSenderActor.Factory(senderActorInit), "DataUpdateSender");
ReceiveAsync<UnregisterWebMessage>(HandleUnregisterWeb); ReceiveAsync<UnregisterWebMessage>(HandleUnregisterWeb);

View File

@@ -6,6 +6,7 @@ using Phantom.Controller.Database.Postgres;
using Phantom.Controller.Services; using Phantom.Controller.Services;
using Phantom.Utils.IO; using Phantom.Utils.IO;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime;
using Phantom.Utils.Rpc.Runtime.Server; using Phantom.Utils.Rpc.Runtime.Server;
using Phantom.Utils.Runtime; using Phantom.Utils.Runtime;
using Phantom.Utils.Tasks; using Phantom.Utils.Tasks;
@@ -64,16 +65,24 @@ try {
EndPoint: agentRpcServerHost, EndPoint: agentRpcServerHost,
Certificate: agentKeyData.Certificate, Certificate: agentKeyData.Certificate,
AuthToken: agentKeyData.AuthToken, AuthToken: agentKeyData.AuthToken,
SendQueueCapacity: 100, CommonParameters: new RpcCommonConnectionParameters(
PingInterval: TimeSpan.FromSeconds(10) MessageQueueCapacity: 50,
FrameQueueCapacity: 100,
MaxConcurrentlyHandledMessages: 20,
PingInterval: TimeSpan.FromSeconds(10)
)
); );
var webConnectionParameters = new RpcServerConnectionParameters( var webConnectionParameters = new RpcServerConnectionParameters(
EndPoint: webRpcServerHost, EndPoint: webRpcServerHost,
Certificate: webKeyData.Certificate, Certificate: webKeyData.Certificate,
AuthToken: webKeyData.AuthToken, AuthToken: webKeyData.AuthToken,
SendQueueCapacity: 500, CommonParameters: new RpcCommonConnectionParameters(
PingInterval: TimeSpan.FromMinutes(1) MessageQueueCapacity: 250,
FrameQueueCapacity: 500,
MaxConcurrentlyHandledMessages: 100,
PingInterval: TimeSpan.FromMinutes(1)
)
); );
LinkedTasks<bool> rpcServerTasks = new LinkedTasks<bool>([ LinkedTasks<bool> rpcServerTasks = new LinkedTasks<bool>([

View File

@@ -6,12 +6,14 @@ interface IFrame {
private const byte TypePingId = 0; private const byte TypePingId = 0;
private const byte TypePongId = 1; private const byte TypePongId = 1;
private const byte TypeMessageId = 2; private const byte TypeMessageId = 2;
private const byte TypeReplyId = 3; private const byte TypeAcknowledgmentId = 3;
private const byte TypeErrorId = 4; private const byte TypeReplyId = 4;
private const byte TypeErrorId = 5;
static readonly ReadOnlyMemory<byte> TypePing = new ([TypePingId]); static readonly ReadOnlyMemory<byte> TypePing = new ([TypePingId]);
static readonly ReadOnlyMemory<byte> TypePong = new ([TypePongId]); static readonly ReadOnlyMemory<byte> TypePong = new ([TypePongId]);
static readonly ReadOnlyMemory<byte> TypeMessage = new ([TypeMessageId]); static readonly ReadOnlyMemory<byte> TypeMessage = new ([TypeMessageId]);
static readonly ReadOnlyMemory<byte> TypeAcknowledgment = new ([TypeAcknowledgmentId]);
static readonly ReadOnlyMemory<byte> TypeReply = new ([TypeReplyId]); static readonly ReadOnlyMemory<byte> TypeReply = new ([TypeReplyId]);
static readonly ReadOnlyMemory<byte> TypeError = new ([TypeErrorId]); static readonly ReadOnlyMemory<byte> TypeError = new ([TypeErrorId]);
@@ -37,6 +39,11 @@ interface IFrame {
await reader.OnMessageFrame(messageFrame, cancellationToken); await reader.OnMessageFrame(messageFrame, cancellationToken);
break; break;
case TypeAcknowledgmentId:
var acknowledgmentFrame = await AcknowledgmentFrame.Read(stream, cancellationToken);
reader.OnAcknowledgmentFrame(acknowledgmentFrame);
break;
case TypeReplyId: case TypeReplyId:
var replyFrame = await ReplyFrame.Read(stream, cancellationToken); var replyFrame = await ReplyFrame.Read(stream, cancellationToken);
reader.OnReplyFrame(replyFrame); reader.OnReplyFrame(replyFrame);

View File

@@ -6,6 +6,7 @@ interface IFrameReader {
ValueTask OnPingFrame(DateTimeOffset pingTime, CancellationToken cancellationToken); ValueTask OnPingFrame(DateTimeOffset pingTime, CancellationToken cancellationToken);
void OnPongFrame(PongFrame frame); void OnPongFrame(PongFrame frame);
Task OnMessageFrame(MessageFrame frame, CancellationToken cancellationToken); Task OnMessageFrame(MessageFrame frame, CancellationToken cancellationToken);
void OnAcknowledgmentFrame(AcknowledgmentFrame frame);
void OnReplyFrame(ReplyFrame frame); void OnReplyFrame(ReplyFrame frame);
void OnErrorFrame(ErrorFrame frame); void OnErrorFrame(ErrorFrame frame);
void OnUnknownFrameId(byte frameId); void OnUnknownFrameId(byte frameId);

View File

@@ -0,0 +1,18 @@
using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Utils.Rpc.Frame.Types;
sealed record AcknowledgmentFrame(uint FirstMessageId, uint LastMessageId) : IFrame {
public ReadOnlyMemory<byte> FrameType => IFrame.TypeAcknowledgment;
public async Task Write(Stream stream, CancellationToken cancellationToken = default) {
await RpcSerialization.WriteUnsignedInt(FirstMessageId, stream, cancellationToken);
await RpcSerialization.WriteUnsignedInt(LastMessageId, stream, cancellationToken);
}
public static async Task<AcknowledgmentFrame> Read(Stream stream, CancellationToken cancellationToken) {
var firstMessageId = await RpcSerialization.ReadUnsignedInt(stream, cancellationToken);
var lastMessageId = await RpcSerialization.ReadUnsignedInt(stream, cancellationToken);
return new AcknowledgmentFrame(firstMessageId, lastMessageId);
}
}

View File

@@ -8,30 +8,26 @@ sealed record MessageFrame(uint MessageId, ushort RegistryCode, ReadOnlyMemory<b
public ReadOnlyMemory<byte> FrameType => IFrame.TypeMessage; public ReadOnlyMemory<byte> FrameType => IFrame.TypeMessage;
public async Task Write(Stream stream, CancellationToken cancellationToken) { public async Task Write(Stream stream, CancellationToken cancellationToken) {
int serializedMessageLength = SerializedMessage.Length; uint serializedMessageLength = (uint) SerializedMessage.Length;
CheckMessageLength(serializedMessageLength); CheckMessageLength(serializedMessageLength);
await RpcSerialization.WriteUnsignedInt(MessageId, stream, cancellationToken); await RpcSerialization.WriteUnsignedInt(MessageId, stream, cancellationToken);
await RpcSerialization.WriteUnsignedShort(RegistryCode, stream, cancellationToken); await RpcSerialization.WriteUnsignedShort(RegistryCode, stream, cancellationToken);
await RpcSerialization.WriteSignedInt(serializedMessageLength, stream, cancellationToken); await RpcSerialization.WriteUnsignedInt(serializedMessageLength, stream, cancellationToken);
await stream.WriteAsync(SerializedMessage, cancellationToken); await stream.WriteAsync(SerializedMessage, cancellationToken);
} }
public static async Task<MessageFrame> Read(Stream stream, CancellationToken cancellationToken) { public static async Task<MessageFrame> Read(Stream stream, CancellationToken cancellationToken) {
var messageId = await RpcSerialization.ReadUnsignedInt(stream, cancellationToken); var messageId = await RpcSerialization.ReadUnsignedInt(stream, cancellationToken);
var registryCode = await RpcSerialization.ReadUnsignedShort(stream, cancellationToken); var registryCode = await RpcSerialization.ReadUnsignedShort(stream, cancellationToken);
var serializedMessageLength = await RpcSerialization.ReadSignedInt(stream, cancellationToken); var serializedMessageLength = await RpcSerialization.ReadUnsignedInt(stream, cancellationToken);
CheckMessageLength(serializedMessageLength); CheckMessageLength(serializedMessageLength);
var serializedMessage = await RpcSerialization.ReadBytes(serializedMessageLength, stream, cancellationToken); var serializedMessage = await RpcSerialization.ReadBytes(serializedMessageLength, stream, cancellationToken);
return new MessageFrame(messageId, registryCode, serializedMessage); return new MessageFrame(messageId, registryCode, serializedMessage);
} }
private static void CheckMessageLength(int messageLength) { private static void CheckMessageLength(uint messageLength) {
if (messageLength < 0) {
throw new RpcErrorException("Message length is negative.", RpcError.InvalidData);
}
if (messageLength > MaxMessageBytes) { if (messageLength > MaxMessageBytes) {
throw new RpcErrorException("Message is too large: " + messageLength + " > " + MaxMessageBytes + " bytes", RpcError.MessageTooLarge); throw new RpcErrorException("Message is too large: " + messageLength + " > " + MaxMessageBytes + " bytes", RpcError.MessageTooLarge);
} }

View File

@@ -8,30 +8,26 @@ sealed record ReplyFrame(uint ReplyingToMessageId, ReadOnlyMemory<byte> Serializ
public ReadOnlyMemory<byte> FrameType => IFrame.TypeReply; public ReadOnlyMemory<byte> FrameType => IFrame.TypeReply;
public async Task Write(Stream stream, CancellationToken cancellationToken) { public async Task Write(Stream stream, CancellationToken cancellationToken) {
int replyLength = SerializedReply.Length; uint serializedReplyLength = (uint) SerializedReply.Length;
CheckReplyLength(replyLength); CheckReplyLength(serializedReplyLength);
await RpcSerialization.WriteUnsignedInt(ReplyingToMessageId, stream, cancellationToken); await RpcSerialization.WriteUnsignedInt(ReplyingToMessageId, stream, cancellationToken);
await RpcSerialization.WriteSignedInt(replyLength, stream, cancellationToken); await RpcSerialization.WriteUnsignedInt(serializedReplyLength, stream, cancellationToken);
await stream.WriteAsync(SerializedReply, cancellationToken); await stream.WriteAsync(SerializedReply, cancellationToken);
} }
public static async Task<ReplyFrame> Read(Stream stream, CancellationToken cancellationToken) { public static async Task<ReplyFrame> Read(Stream stream, CancellationToken cancellationToken) {
var replyingToMessageId = await RpcSerialization.ReadUnsignedInt(stream, cancellationToken); var replyingToMessageId = await RpcSerialization.ReadUnsignedInt(stream, cancellationToken);
var replyLength = await RpcSerialization.ReadSignedInt(stream, cancellationToken); var serializedReplyLength = await RpcSerialization.ReadUnsignedInt(stream, cancellationToken);
CheckReplyLength(replyLength); CheckReplyLength(serializedReplyLength);
var reply = await RpcSerialization.ReadBytes(replyLength, stream, cancellationToken); var serializedReply = await RpcSerialization.ReadBytes(serializedReplyLength, stream, cancellationToken);
return new ReplyFrame(replyingToMessageId, reply); return new ReplyFrame(replyingToMessageId, serializedReply);
} }
private static void CheckReplyLength(int replyLength) { private static void CheckReplyLength(uint replyLength) {
if (replyLength < 0) {
throw new RpcErrorException("Reply length is negative.", RpcError.InvalidData);
}
if (replyLength > MaxReplyBytes) { if (replyLength > MaxReplyBytes) {
throw new RpcErrorException("Reply is too large: " + replyLength + " > " + MaxReplyBytes + " bytes", RpcError.MessageTooLarge); throw new RpcErrorException("Reply is too large: " + replyLength + " > " + MaxReplyBytes + " bytes", RpcError.ReplyTooLarge);
} }
} }
} }

View File

@@ -1,6 +0,0 @@
namespace Phantom.Utils.Rpc.Handshake;
public interface IRpcHandshake<T> {
Task Send(Stream stream, CancellationToken cancellationToken);
Task<T> Receive(Stream stream, CancellationToken cancellationToken);
}

View File

@@ -0,0 +1,7 @@
namespace Phantom.Utils.Rpc.Handshake;
public enum RpcFinalHandshakeResult : byte {
Error = 0,
NewSession = 1,
ReusedSession = 2,
}

View File

@@ -1,6 +1,9 @@
namespace Phantom.Utils.Rpc.Runtime; using Phantom.Utils.Rpc.Runtime;
interface IRpcReplySender { namespace Phantom.Utils.Rpc.Message;
interface IMessageReplySender {
ValueTask SendEmptyReply(uint replyingToMessageId, CancellationToken cancellationToken);
ValueTask SendReply<TReply>(uint replyingToMessageId, TReply reply, CancellationToken cancellationToken); ValueTask SendReply<TReply>(uint replyingToMessageId, TReply reply, CancellationToken cancellationToken);
ValueTask SendError(uint replyingToMessageId, RpcError error, CancellationToken cancellationToken); ValueTask SendError(uint replyingToMessageId, RpcError error, CancellationToken cancellationToken);
} }

View File

@@ -0,0 +1,39 @@
using Phantom.Utils.Collections;
using Phantom.Utils.Rpc.Frame.Types;
using Phantom.Utils.Threading;
namespace Phantom.Utils.Rpc.Message;
sealed class MessageAcknowledgmentQueue {
private readonly Lock @lock = new ();
private readonly RangeSet<uint> pendingMessageIdRanges = new ();
private readonly ManualResetEventSlim pendingEvent = new ();
public void Enqueue(uint messageId) {
lock (@lock) {
pendingMessageIdRanges.Add(messageId);
}
pendingEvent.Set();
}
public Task Wait(CancellationToken cancellationToken) {
return pendingEvent.WaitHandle.WaitOneAsync(cancellationToken);
}
public List<AcknowledgmentFrame> Drain() {
pendingEvent.Reset();
List<AcknowledgmentFrame> frames = [];
lock (@lock) {
foreach (var range in pendingMessageIdRanges) {
frames.Add(new AcknowledgmentFrame(range.Min, range.Max));
}
pendingMessageIdRanges.Clear();
}
return frames;
}
}

View File

@@ -0,0 +1,23 @@
using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Utils.Rpc.Message;
sealed class MessageHandler<TMessageBase>(IMessageReceiver<TMessageBase> messageReceiver, IMessageReplySender replySender) {
public IMessageReceiver<TMessageBase> Receiver => messageReceiver;
public void OnPing() {
messageReceiver.OnPing();
}
public ValueTask SendEmptyReply(uint messageId, CancellationToken cancellationToken) {
return replySender.SendEmptyReply(messageId, cancellationToken);
}
public ValueTask SendReply<TReply>(uint messageId, TReply reply, CancellationToken cancellationToken) {
return replySender.SendReply(messageId, reply, cancellationToken);
}
public ValueTask SendError(uint messageId, RpcError error, CancellationToken cancellationToken) {
return replySender.SendError(messageId, error, cancellationToken);
}
}

View File

@@ -6,7 +6,6 @@ sealed class MessageReceiveTracker {
private readonly RangeSet<uint> receivedMessageIds = new (); private readonly RangeSet<uint> receivedMessageIds = new ();
public bool ReceiveMessage(uint messageId) { public bool ReceiveMessage(uint messageId) {
// TODO reset on session change and invalidate replies
lock (receivedMessageIds) { lock (receivedMessageIds) {
return receivedMessageIds.Add(messageId); return receivedMessageIds.Add(messageId);
} }

View File

@@ -9,7 +9,7 @@ namespace Phantom.Utils.Rpc.Message;
public sealed class MessageRegistry<TMessageBase>(ILogger logger) { public sealed class MessageRegistry<TMessageBase>(ILogger logger) {
private readonly Dictionary<Type, ushort> typeToCodeMapping = new (); private readonly Dictionary<Type, ushort> typeToCodeMapping = new ();
private readonly Dictionary<ushort, Type> codeToTypeMapping = new (); private readonly Dictionary<ushort, Type> codeToTypeMapping = new ();
private readonly Dictionary<ushort, Func<uint, ReadOnlyMemory<byte>, RpcMessageHandler<TMessageBase>, CancellationToken, Task>> codeToHandlerMapping = new (); private readonly Dictionary<ushort, Func<uint, ReadOnlyMemory<byte>, MessageHandler<TMessageBase>, CancellationToken, Task>> codeToHandlerMapping = new ();
public void Add<TMessage>(ushort code) where TMessage : TMessageBase { public void Add<TMessage>(ushort code) where TMessage : TMessageBase {
if (HasReplyType(typeof(TMessage))) { if (HasReplyType(typeof(TMessage))) {
@@ -50,7 +50,7 @@ public sealed class MessageRegistry<TMessageBase>(ILogger logger) {
} }
} }
internal async Task Handle(MessageFrame frame, RpcMessageHandler<TMessageBase> handler, CancellationToken cancellationToken) { internal async Task Handle(MessageFrame frame, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) {
uint messageId = frame.MessageId; uint messageId = frame.MessageId;
if (codeToHandlerMapping.TryGetValue(frame.RegistryCode, out var action)) { if (codeToHandlerMapping.TryGetValue(frame.RegistryCode, out var action)) {
@@ -62,26 +62,35 @@ public sealed class MessageRegistry<TMessageBase>(ILogger logger) {
} }
} }
private async Task DeserializationHandler<TMessage>(uint messageId, ReadOnlyMemory<byte> serializedMessage, RpcMessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase { private async Task DeserializationHandler<TMessage>(uint messageId, ReadOnlyMemory<byte> serializedMessage, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase {
TMessage message; TMessage message;
try { try {
message = RpcSerialization.Deserialize<TMessage>(serializedMessage); message = RpcSerialization.Deserialize<TMessage>(serializedMessage);
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Could not deserialize message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name); await OnMessageDeserializationError<TMessage>(messageId, e, handler, cancellationToken);
await handler.SendError(messageId, RpcError.MessageDeserializationError, cancellationToken);
return; return;
} }
handler.Receiver.OnMessage(message); try {
handler.Receiver.OnMessage(message);
} catch (Exception e) {
await OnMessageHandlingError<TMessage>(messageId, e, handler, cancellationToken);
return;
}
try {
await handler.SendEmptyReply(messageId, cancellationToken);
} catch (Exception e) {
await OnMessageReplyingError<TMessage>(messageId, e, handler, cancellationToken);
}
} }
private async Task DeserializationHandler<TMessage, TReply>(uint messageId, ReadOnlyMemory<byte> serializedMessage, RpcMessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase, ICanReply<TReply> { private async Task DeserializationHandler<TMessage, TReply>(uint messageId, ReadOnlyMemory<byte> serializedMessage, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase, ICanReply<TReply> {
TMessage message; TMessage message;
try { try {
message = RpcSerialization.Deserialize<TMessage>(serializedMessage); message = RpcSerialization.Deserialize<TMessage>(serializedMessage);
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Could not deserialize message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name); await OnMessageDeserializationError<TMessage>(messageId, e, handler, cancellationToken);
await handler.SendError(messageId, RpcError.MessageDeserializationError, cancellationToken);
return; return;
} }
@@ -89,16 +98,29 @@ public sealed class MessageRegistry<TMessageBase>(ILogger logger) {
try { try {
reply = await handler.Receiver.OnMessage<TMessage, TReply>(message, cancellationToken); reply = await handler.Receiver.OnMessage<TMessage, TReply>(message, cancellationToken);
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Could not handle message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name); await OnMessageHandlingError<TMessage>(messageId, e, handler, cancellationToken);
await handler.SendError(messageId, RpcError.MessageHandlingError, cancellationToken);
return; return;
} }
try { try {
await handler.SendReply(messageId, reply, cancellationToken); await handler.SendReply(messageId, reply, cancellationToken);
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Could not reply to message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name); await OnMessageReplyingError<TMessage>(messageId, e, handler, cancellationToken);
await handler.SendError(messageId, RpcError.MessageHandlingError, cancellationToken);
} }
} }
private async Task OnMessageDeserializationError<TMessage>(uint messageId, Exception exception, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase {
logger.Error(exception, "Could not deserialize message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.SendError(messageId, RpcError.MessageDeserializationError, cancellationToken);
}
private async Task OnMessageHandlingError<TMessage>(uint messageId, Exception exception, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase {
logger.Error(exception, "Could not handle message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.SendError(messageId, RpcError.MessageHandlingError, cancellationToken);
}
private async Task OnMessageReplyingError<TMessage>(uint messageId, Exception exception, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase {
logger.Error(exception, "Could not reply to message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.SendError(messageId, RpcError.MessageReplyingError, cancellationToken);
}
} }

View File

@@ -20,30 +20,30 @@ sealed class MessageReplyTracker {
public async Task<TReply> WaitForReply<TReply>(uint messageId, TimeSpan waitForReplyTime, CancellationToken cancellationToken) { public async Task<TReply> WaitForReply<TReply>(uint messageId, TimeSpan waitForReplyTime, CancellationToken cancellationToken) {
if (!replyTasks.TryGetValue(messageId, out var completionSource)) { if (!replyTasks.TryGetValue(messageId, out var completionSource)) {
logger.Warning("No reply callback for id {MessageId}.", messageId); logger.Warning("No reply callback for message {MessageId}.", messageId);
throw new ArgumentException("No reply callback for id: " + messageId, nameof(messageId)); throw new ArgumentException("No reply callback for message: " + messageId, nameof(messageId));
} }
try { try {
ReadOnlyMemory<byte> serializedReply = await completionSource.Task.WaitAsync(waitForReplyTime, cancellationToken); ReadOnlyMemory<byte> serializedReply = await completionSource.Task.WaitAsync(waitForReplyTime, cancellationToken);
return RpcSerialization.Deserialize<TReply>(serializedReply); return RpcSerialization.Deserialize<TReply>(serializedReply);
} catch (TimeoutException) { } catch (TimeoutException) {
logger.Debug("Timed out waiting for reply with id {MessageId}.", messageId); logger.Debug("Timed out waiting for reply with message {MessageId}.", messageId);
throw; throw;
} catch (OperationCanceledException) { } catch (OperationCanceledException) {
logger.Debug("Cancelled waiting for reply with id {MessageId}.", messageId); logger.Debug("Cancelled waiting for reply with message {MessageId}.", messageId);
throw; throw;
} catch (Exception e) { } catch (Exception e) {
logger.Warning(e, "Error processing reply with id {MessageId}.", messageId); logger.Warning(e, "Error processing reply with message {MessageId}.", messageId);
throw; throw;
} finally { } finally {
ForgetReply(messageId); ForgetReply(messageId);
} }
} }
public void ForgetReply(uint messageId) { public void ReceiveReply(uint messageId, ReadOnlyMemory<byte> serializedReply) {
if (replyTasks.TryRemove(messageId, out var task)) { if (replyTasks.TryRemove(messageId, out var task)) {
task.SetCanceled(); task.SetResult(serializedReply);
} }
} }
@@ -53,12 +53,9 @@ sealed class MessageReplyTracker {
} }
} }
public void ReceiveReply(uint messageId, ReadOnlyMemory<byte> serializedReply) { public void ForgetReply(uint messageId) {
if (replyTasks.TryRemove(messageId, out var task)) { if (replyTasks.TryRemove(messageId, out var task)) {
task.SetResult(serializedReply); task.SetCanceled();
}
else {
logger.Warning("Received a reply with id {MessageId} but no registered callback.", messageId);
} }
} }
} }

View File

@@ -0,0 +1,117 @@
using System.Threading.Channels;
using Phantom.Utils.Actor;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame.Types;
using Phantom.Utils.Rpc.Runtime;
using Serilog;
namespace Phantom.Utils.Rpc.Message;
public sealed class MessageSender<TMessageBase> {
private readonly ILogger logger;
private readonly MessageRegistry<TMessageBase> messageRegistry;
private readonly MessageReplyTracker messageReplyTracker;
private uint nextMessageId;
private readonly Channel<PreparedMessage> messageQueue;
private readonly Task messageQueueTask;
private readonly CancellationTokenSource shutdownCancellationTokenSource = new ();
internal MessageSender(string loggerName, MessageRegistry<TMessageBase> messageRegistry, RpcCommonConnectionParameters connectionParameters) {
this.logger = PhantomLogger.Create<MessageSender<TMessageBase>>(loggerName);
this.messageRegistry = messageRegistry;
this.messageReplyTracker = new MessageReplyTracker(loggerName);
this.messageQueue = Channel.CreateBounded<PreparedMessage>(new BoundedChannelOptions(connectionParameters.MessageQueueCapacity) {
AllowSynchronousContinuations = false,
FullMode = BoundedChannelFullMode.Wait,
SingleReader = true,
SingleWriter = false,
});
this.messageQueueTask = ProcessQueue();
}
public bool TrySend<TMessage>(TMessage message) where TMessage : TMessageBase {
return messageQueue.Writer.TryWrite(PrepareMessage(message));
}
public async ValueTask Send<TMessage>(TMessage message, CancellationToken cancellationToken = default) where TMessage : TMessageBase {
await messageQueue.Writer.WriteAsync(PrepareMessage(message), cancellationToken);
}
public async Task<TReply> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken cancellationToken) where TMessage : TMessageBase, ICanReply<TReply> {
var preparedMessage = PrepareMessage(message);
var messageId = preparedMessage.MessageId;
messageReplyTracker.RegisterReply(messageId);
try {
await messageQueue.Writer.WriteAsync(preparedMessage, cancellationToken);
} catch (Exception) {
messageReplyTracker.ForgetReply(messageId);
throw;
}
return await messageReplyTracker.WaitForReply<TReply>(messageId, waitForReplyTime, cancellationToken);
}
private PreparedMessage PrepareMessage(TMessageBase message) {
uint messageId = Interlocked.Increment(ref nextMessageId);
return new PreparedMessage(messageId, message);
}
private readonly record struct PreparedMessage(uint MessageId, TMessageBase Message);
private async Task ProcessQueue() {
CancellationToken cancellationToken = shutdownCancellationTokenSource.Token;
Queue<PreparedMessage> messagesInTransit = new (capacity: 10);
while (await messageQueue.Reader.WaitToReadAsync(cancellationToken)) {
do {
while (messagesInTransit.Count < messagesInTransit.Capacity) {
if (messageQueue.Reader.TryRead(out var nextMessage)) {
messagesInTransit.Enqueue(nextMessage);
}
else {
break;
}
}
RpcFrameSender<TMessageBase> frameSender;
foreach ((uint messageId, TMessageBase message) in messagesInTransit) {
await frameSender.SendFrame(messageRegistry.CreateFrame(messageId, message), cancellationToken);
}
} while (messagesInTransit.Count > 0);
}
}
internal void ReceiveAcknowledgment(AcknowledgmentFrame frame) {
// TODO wait for reply instead?
}
internal void ReceiveReply(ReplyFrame frame) {
messageReplyTracker.ReceiveReply(frame.ReplyingToMessageId, frame.SerializedReply);
}
internal void ReceiveError(ErrorFrame frame) {
messageReplyTracker.FailReply(frame.ReplyingToMessageId, RpcErrorException.From(frame.Error));
}
internal async Task Close() {
messageQueue.Writer.TryComplete();
try {
await messageQueueTask.WaitAsync(TimeSpan.FromSeconds(15));
} catch (TimeoutException) {
logger.Warning("Could not finish processing message queue before timeout, forcibly shutting it down.");
await shutdownCancellationTokenSource.CancelAsync();
} catch (Exception) {
// Ignore.
}
messageQueueTask.Dispose();
shutdownCancellationTokenSource.Dispose();
}
}

View File

@@ -11,29 +11,21 @@ public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> :
return connection == null ? null : new RpcClient<TClientToServerMessage, TServerToClientMessage>(loggerName, connectionParameters, messageDefinitions, connector, connection); return connection == null ? null : new RpcClient<TClientToServerMessage, TServerToClientMessage>(loggerName, connectionParameters, messageDefinitions, connector, connection);
} }
private readonly string loggerName;
private readonly ILogger logger; private readonly ILogger logger;
private readonly MessageRegistry<TServerToClientMessage> messageRegistry; private readonly RpcClientToServerConnection<TClientToServerMessage, TServerToClientMessage> connection;
private readonly MessageReceiveTracker messageReceiveTracker = new ();
private readonly RpcClientToServerConnection connection;
private readonly CancellationTokenSource shutdownCancellationTokenSource = new (); private readonly CancellationTokenSource shutdownCancellationTokenSource = new ();
public RpcSendChannel<TClientToServerMessage> SendChannel { get; } public MessageSender<TClientToServerMessage> MessageSender { get; }
private RpcClient(string loggerName, RpcClientConnectionParameters connectionParameters, IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions, RpcClientToServerConnector connector, RpcClientToServerConnector.Connection connection) { private RpcClient(string loggerName, RpcClientConnectionParameters connectionParameters, IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions, RpcClientToServerConnector connector, RpcClientToServerConnector.Connection connection) {
this.loggerName = loggerName;
this.logger = PhantomLogger.Create<RpcClient<TClientToServerMessage, TServerToClientMessage>>(loggerName); this.logger = PhantomLogger.Create<RpcClient<TClientToServerMessage, TServerToClientMessage>>(loggerName);
this.messageRegistry = messageDefinitions.ToClient; this.connection = new RpcClientToServerConnection<TClientToServerMessage, TServerToClientMessage>(loggerName, messageDefinitions.ToClient, connectionParameters.Common, connector, connection);
this.MessageSender = new MessageSender<TClientToServerMessage>(loggerName, messageDefinitions.ToServer, connectionParameters.Common);
this.connection = new RpcClientToServerConnection(loggerName, connector, connection);
this.SendChannel = new RpcSendChannel<TClientToServerMessage>(loggerName, connectionParameters.Common, this.connection, messageDefinitions.ToServer);
} }
public async Task Listen(IMessageReceiver<TServerToClientMessage> receiver) { public async Task Listen(IMessageReceiver<TServerToClientMessage> messageReceiver) {
var messageHandler = new RpcMessageHandler<TServerToClientMessage>(receiver, SendChannel);
var frameReader = new RpcFrameReader<TClientToServerMessage, TServerToClientMessage>(loggerName, messageRegistry, messageReceiveTracker, messageHandler, SendChannel);
try { try {
await connection.ReadConnection(frameReader, shutdownCancellationTokenSource.Token); await connection.ReadConnection(MessageSender, messageReceiver, shutdownCancellationTokenSource.Token);
} catch (OperationCanceledException) { } catch (OperationCanceledException) {
// Ignore. // Ignore.
} }
@@ -43,9 +35,9 @@ public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> :
logger.Information("Shutting down client..."); logger.Information("Shutting down client...");
try { try {
await SendChannel.Close(); await MessageSender.Close();
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Caught exception while closing send channel."); logger.Error(e, "Caught exception while closing message sender.");
} }
try { try {
@@ -61,6 +53,5 @@ public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> :
public void Dispose() { public void Dispose() {
connection.Dispose(); connection.Dispose();
SendChannel.Dispose();
} }
} }

View File

@@ -3,15 +3,12 @@ using Phantom.Utils.Rpc.Runtime.Tls;
namespace Phantom.Utils.Rpc.Runtime.Client; namespace Phantom.Utils.Rpc.Runtime.Client;
public readonly record struct RpcClientConnectionParameters( public sealed record RpcClientConnectionParameters(
string Host, string Host,
ushort Port, ushort Port,
string DistinguishedName, string DistinguishedName,
RpcCertificateThumbprint CertificateThumbprint, RpcCertificateThumbprint CertificateThumbprint,
AuthToken AuthToken, AuthToken AuthToken,
IRpcClientHandshake Handshake, IRpcClientHandshake Handshake,
ushort SendQueueCapacity, RpcCommonConnectionParameters CommonParameters
TimeSpan PingInterval );
) {
internal RpcCommonConnectionParameters Common => new (SendQueueCapacity, PingInterval);
}

View File

@@ -1,12 +1,19 @@
using System.Net.Sockets; using System.Net.Sockets;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame; using Phantom.Utils.Rpc.Frame;
using Phantom.Utils.Rpc.Message;
using Serilog; using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Client; namespace Phantom.Utils.Rpc.Runtime.Client;
sealed class RpcClientToServerConnection(string loggerName, RpcClientToServerConnector connector, RpcClientToServerConnector.Connection initialConnection) : IRpcConnectionProvider, IDisposable { sealed class RpcClientToServerConnection<TClientToServerMessage, TServerToClientMessage>(
private readonly ILogger logger = PhantomLogger.Create<RpcClientToServerConnection>(loggerName); string loggerName,
MessageRegistry<TServerToClientMessage> messageRegistry,
RpcCommonConnectionParameters connectionParameters,
RpcClientToServerConnector connector,
RpcClientToServerConnector.Connection initialConnection
) : IRpcConnectionProvider, IDisposable {
private readonly ILogger logger = PhantomLogger.Create<RpcClientToServerConnection<TClientToServerMessage, TServerToClientMessage>>(loggerName);
private readonly SemaphoreSlim semaphore = new (1); private readonly SemaphoreSlim semaphore = new (1);
private RpcClientToServerConnector.Connection currentConnection = initialConnection; private RpcClientToServerConnector.Connection currentConnection = initialConnection;
@@ -30,9 +37,11 @@ sealed class RpcClientToServerConnection(string loggerName, RpcClientToServerCon
} }
} }
public async Task ReadConnection(IFrameReader frameReader, CancellationToken cancellationToken) { public async Task ReadConnection(MessageSender<TClientToServerMessage> messageSender, IMessageReceiver<TServerToClientMessage> messageReceiver, CancellationToken cancellationToken) {
RpcClientToServerConnector.Connection? connection = null; RpcClientToServerConnector.Connection? connection = null;
var sessionState = NewSessionState(messageSender, messageReceiver);
try { try {
while (true) { while (true) {
connection?.Dispose(); connection?.Dispose();
@@ -47,8 +56,13 @@ sealed class RpcClientToServerConnection(string loggerName, RpcClientToServerCon
continue; continue;
} }
if (connection.RestartSession) {
await sessionState.FrameSender.ShutdownNow();
sessionState = NewSessionState(messageSender, messageReceiver);
}
try { try {
await IFrame.ReadFrom(connection.Stream, frameReader, cancellationToken); await IFrame.ReadFrom(connection.Stream, sessionState.FrameReader, cancellationToken);
} catch (OperationCanceledException) { } catch (OperationCanceledException) {
throw; throw;
} catch (EndOfStreamException) { } catch (EndOfStreamException) {
@@ -66,6 +80,12 @@ sealed class RpcClientToServerConnection(string loggerName, RpcClientToServerCon
} }
} }
} finally { } finally {
try {
await sessionState.FrameSender.Shutdown();
} catch (Exception e) {
logger.Error(e, "Caught exception while closing frame sender.");
}
if (connection != null) { if (connection != null) {
try { try {
await connection.Disconnect(); await connection.Disconnect();
@@ -76,6 +96,15 @@ sealed class RpcClientToServerConnection(string loggerName, RpcClientToServerCon
} }
} }
private SessionState NewSessionState(MessageSender<TClientToServerMessage> messageSender, IMessageReceiver<TServerToClientMessage> messageReceiver) {
var frameSender = new RpcFrameSender<TClientToServerMessage>(loggerName, connectionParameters, this);
var messageHandler = new MessageHandler<TServerToClientMessage>(messageReceiver, frameSender);
var frameReader = new RpcFrameReader<TClientToServerMessage, TServerToClientMessage>(loggerName, connectionParameters, messageRegistry, messageHandler, messageSender, frameSender);
return new SessionState(frameSender, frameReader);
}
private readonly record struct SessionState(RpcFrameSender<TClientToServerMessage> FrameSender, RpcFrameReader<TClientToServerMessage, TServerToClientMessage> FrameReader);
public void StopReconnecting() { public void StopReconnecting() {
newConnectionCancellationTokenSource.Cancel(); newConnectionCancellationTokenSource.Cancel();
} }

View File

@@ -4,12 +4,13 @@ using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates; using System.Security.Cryptography.X509Certificates;
using Phantom.Utils.Collections; using Phantom.Utils.Collections;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Handshake;
using Phantom.Utils.Rpc.Runtime.Tls; using Phantom.Utils.Rpc.Runtime.Tls;
using Serilog; using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Client; namespace Phantom.Utils.Rpc.Runtime.Client;
internal sealed class RpcClientToServerConnector { sealed class RpcClientToServerConnector {
private static readonly TimeSpan InitialRetryDelay = TimeSpan.FromMilliseconds(100); private static readonly TimeSpan InitialRetryDelay = TimeSpan.FromMilliseconds(100);
private static readonly TimeSpan MaximumRetryDelay = TimeSpan.FromSeconds(30); private static readonly TimeSpan MaximumRetryDelay = TimeSpan.FromSeconds(30);
private static readonly TimeSpan DisconnectTimeout = TimeSpan.FromSeconds(10); private static readonly TimeSpan DisconnectTimeout = TimeSpan.FromSeconds(10);
@@ -112,9 +113,10 @@ internal sealed class RpcClientToServerConnector {
try { try {
stream = new SslStream(new NetworkStream(clientSocket, ownsSocket: false), leaveInnerStreamOpen: false); stream = new SslStream(new NetworkStream(clientSocket, ownsSocket: false), leaveInnerStreamOpen: false);
if (await FinalizeConnection(stream, cancellationToken)) { var handshakeResult = await FinalizeConnection(stream, cancellationToken);
if (handshakeResult != RpcFinalHandshakeResult.Error) {
logger.Information("Connected to {Host}:{Port}.", parameters.Host, parameters.Port); logger.Information("Connected to {Host}:{Port}.", parameters.Host, parameters.Port);
return new Connection(clientSocket, stream); return new Connection(clientSocket, stream, RestartSession: handshakeResult == RpcFinalHandshakeResult.NewSession);
} }
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Caught unhandled exception."); logger.Error(e, "Caught unhandled exception.");
@@ -130,7 +132,7 @@ internal sealed class RpcClientToServerConnector {
return null; return null;
} }
private async Task<bool> FinalizeConnection(SslStream stream, CancellationToken cancellationToken) { private async Task<RpcFinalHandshakeResult> FinalizeConnection(SslStream stream, CancellationToken cancellationToken) {
try { try {
loggedCertificateValidationError = false; loggedCertificateValidationError = false;
await stream.AuthenticateAsClientAsync(sslOptions, cancellationToken); await stream.AuthenticateAsClientAsync(sslOptions, cancellationToken);
@@ -139,7 +141,7 @@ internal sealed class RpcClientToServerConnector {
logger.Error(e, "Could not establish a secure connection."); logger.Error(e, "Could not establish a secure connection.");
} }
return false; return RpcFinalHandshakeResult.Error;
} }
logger.Information("Established a secure connection."); logger.Information("Established a secure connection.");
@@ -148,30 +150,31 @@ internal sealed class RpcClientToServerConnector {
return await PerformApplicationHandshake(stream, cancellationToken); return await PerformApplicationHandshake(stream, cancellationToken);
} catch (EndOfStreamException) { } catch (EndOfStreamException) {
logger.Warning("Could not perform application handshake, connection lost."); logger.Warning("Could not perform application handshake, connection lost.");
return false; return RpcFinalHandshakeResult.Error;
} catch (Exception e) { } catch (Exception e) {
logger.Warning(e, "Could not perform application handshake."); logger.Warning(e, "Could not perform application handshake.");
return false; return RpcFinalHandshakeResult.Error;
} }
} }
private async Task<bool> PerformApplicationHandshake(Stream stream, CancellationToken cancellationToken) { private async Task<RpcFinalHandshakeResult> PerformApplicationHandshake(Stream stream, CancellationToken cancellationToken) {
await RpcSerialization.WriteAuthToken(parameters.AuthToken, stream, cancellationToken); await RpcSerialization.WriteAuthToken(parameters.AuthToken, stream, cancellationToken);
if (await RpcSerialization.ReadByte(stream, cancellationToken) != 1) { if (await RpcSerialization.ReadByte(stream, cancellationToken) != 1) {
logger.Error("Server rejected authorization token."); logger.Error("Server rejected authorization token.");
return false; return RpcFinalHandshakeResult.Error;
} }
await RpcSerialization.WriteGuid(sessionId, stream, cancellationToken); await RpcSerialization.WriteGuid(sessionId, stream, cancellationToken);
await parameters.Handshake.Perform(stream, cancellationToken); await parameters.Handshake.Perform(stream, cancellationToken);
if (await RpcSerialization.ReadByte(stream, cancellationToken) != 1) { RpcFinalHandshakeResult finalHandshakeResult = (RpcFinalHandshakeResult) await RpcSerialization.ReadByte(stream, cancellationToken);
if (finalHandshakeResult == RpcFinalHandshakeResult.Error) {
logger.Error("Server rejected client due to unknown error."); logger.Error("Server rejected client due to unknown error.");
return false;
} }
return true; return finalHandshakeResult;
} }
private bool ValidateServerCertificate(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors) { private bool ValidateServerCertificate(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors) {
@@ -207,7 +210,7 @@ internal sealed class RpcClientToServerConnector {
await socket.DisconnectAsync(reuseSocket: false, timeoutTokenSource.Token); await socket.DisconnectAsync(reuseSocket: false, timeoutTokenSource.Token);
} }
internal sealed record Connection(Socket Socket, Stream Stream) : IDisposable { internal sealed record Connection(Socket Socket, Stream Stream, bool RestartSession) : IDisposable {
public async Task Disconnect() { public async Task Disconnect() {
await DisconnectSocket(Socket, Stream); await DisconnectSocket(Socket, Stream);
} }

View File

@@ -1,6 +1,8 @@
namespace Phantom.Utils.Rpc.Runtime; namespace Phantom.Utils.Rpc.Runtime;
readonly record struct RpcCommonConnectionParameters( public sealed record RpcCommonConnectionParameters(
ushort SendQueueCapacity, ushort MessageQueueCapacity,
ushort FrameQueueCapacity,
ushort MaxConcurrentlyHandledMessages,
TimeSpan PingInterval TimeSpan PingInterval
); );

View File

@@ -6,5 +6,6 @@ enum RpcError : byte {
MessageTooLarge = 2, MessageTooLarge = 2,
MessageDeserializationError = 3, MessageDeserializationError = 3,
MessageHandlingError = 4, MessageHandlingError = 4,
MessageAlreadyHandled = 5, MessageReplyingError = 5,
ReplyTooLarge = 6,
} }

View File

@@ -3,13 +3,14 @@
sealed class RpcErrorException : Exception { sealed class RpcErrorException : Exception {
internal static RpcErrorException From(RpcError error) { internal static RpcErrorException From(RpcError error) {
return error switch { return error switch {
RpcError.InvalidData => new RpcErrorException("Invalid data", error), RpcError.InvalidData => new RpcErrorException("Invalid data.", error),
RpcError.UnknownMessageRegistryCode => new RpcErrorException("Unknown message registry code", error), RpcError.UnknownMessageRegistryCode => new RpcErrorException("Unknown message registry code.", error),
RpcError.MessageTooLarge => new RpcErrorException("Message is too large", error), RpcError.MessageTooLarge => new RpcErrorException("Message is too large.", error),
RpcError.MessageDeserializationError => new RpcErrorException("Message deserialization error", error), RpcError.MessageDeserializationError => new RpcErrorException("Message deserialization error.", error),
RpcError.MessageHandlingError => new RpcErrorException("Message handling error", error), RpcError.MessageHandlingError => new RpcErrorException("Message handling error.", error),
RpcError.MessageAlreadyHandled => new RpcErrorException("Message already handled", error), RpcError.MessageReplyingError => new RpcErrorException("Message replying error.", error),
_ => new RpcErrorException("Unknown error", error), RpcError.ReplyTooLarge => new RpcErrorException("Reply is too large.", error),
_ => new RpcErrorException("Unknown error.", error),
}; };
} }

View File

@@ -8,43 +8,71 @@ namespace Phantom.Utils.Rpc.Runtime;
sealed class RpcFrameReader<TSentMessage, TReceivedMessage>( sealed class RpcFrameReader<TSentMessage, TReceivedMessage>(
string loggerName, string loggerName,
RpcCommonConnectionParameters connectionParameters,
MessageRegistry<TReceivedMessage> messageRegistry, MessageRegistry<TReceivedMessage> messageRegistry,
MessageReceiveTracker messageReceiveTracker, MessageHandler<TReceivedMessage> messageHandler,
RpcMessageHandler<TReceivedMessage> messageHandler, MessageSender<TSentMessage> messageSender,
RpcSendChannel<TSentMessage> sendChannel RpcFrameSender<TSentMessage> frameSender
) : IFrameReader { ) : IFrameReader {
private readonly ILogger logger = PhantomLogger.Create<RpcFrameReader<TSentMessage, TReceivedMessage>>(loggerName); private readonly ILogger logger = PhantomLogger.Create<RpcFrameReader<TSentMessage, TReceivedMessage>>(loggerName);
private readonly ushort maxConcurrentlyHandledMessages = connectionParameters.MaxConcurrentlyHandledMessages;
private readonly SemaphoreSlim messageHandlingSemaphore = new (connectionParameters.MaxConcurrentlyHandledMessages);
public ValueTask OnPingFrame(DateTimeOffset pingTime, CancellationToken cancellationToken) { public ValueTask OnPingFrame(DateTimeOffset pingTime, CancellationToken cancellationToken) {
messageHandler.OnPing(); messageHandler.OnPing();
return sendChannel.SendPong(pingTime, cancellationToken); return frameSender.SendPong(pingTime, cancellationToken);
} }
public void OnPongFrame(PongFrame frame) { public void OnPongFrame(PongFrame frame) {
sendChannel.ReceivePong(frame); frameSender.ReceivePong(frame);
} }
public Task OnMessageFrame(MessageFrame frame, CancellationToken cancellationToken) { public async Task OnMessageFrame(MessageFrame frame, CancellationToken cancellationToken) {
if (!messageReceiveTracker.ReceiveMessage(frame.MessageId)) { if (!frameSender.ReceiveMessage(frame)) {
logger.Warning("Received duplicate message {MessageId}.", frame.MessageId); logger.Warning("Received duplicate message {MessageId}.", frame.MessageId);
return messageHandler.SendError(frame.MessageId, RpcError.MessageAlreadyHandled, cancellationToken).AsTask(); return;
} }
if (messageRegistry.TryGetType(frame, out var messageType)) { if (messageRegistry.TryGetType(frame, out var messageType)) {
logger.Verbose("Received message {MesageId} of type {MessageType} ({Bytes} B).", frame.MessageId, messageType.Name, frame.SerializedMessage.Length); logger.Verbose("Received message {MesageId} of type {MessageType} ({Bytes} B).", frame.MessageId, messageType.Name, frame.SerializedMessage.Length);
} }
return messageRegistry.Handle(frame, messageHandler, cancellationToken); Task acquireSemaphore = messageHandlingSemaphore.WaitAsync(cancellationToken);
try {
if (!acquireSemaphore.IsCompleted) {
logger.Warning("Reached limit for concurrently handled messages ({Limit}).", maxConcurrentlyHandledMessages);
}
await acquireSemaphore;
_ = HandleMessage(frame, cancellationToken);
} catch (Exception) {
messageHandlingSemaphore.Release();
throw;
}
}
private async Task HandleMessage(MessageFrame frame, CancellationToken cancellationToken) {
try {
await messageRegistry.Handle(frame, messageHandler, cancellationToken);
} finally {
messageHandlingSemaphore.Release();
}
}
public void OnAcknowledgmentFrame(AcknowledgmentFrame frame) {
logger.Verbose("Received acknowledgment of messages {FirstMessageId}-{LastMessageId}.", frame.FirstMessageId, frame.LastMessageId);
messageSender.ReceiveAcknowledgment(frame);
} }
public void OnReplyFrame(ReplyFrame frame) { public void OnReplyFrame(ReplyFrame frame) {
logger.Verbose("Received reply to message {MesageId} ({Bytes} B).", frame.ReplyingToMessageId, frame.SerializedReply.Length); logger.Verbose("Received reply to message {MesageId} ({Bytes} B).", frame.ReplyingToMessageId, frame.SerializedReply.Length);
sendChannel.ReceiveReply(frame); messageSender.ReceiveReply(frame);
} }
public void OnErrorFrame(ErrorFrame frame) { public void OnErrorFrame(ErrorFrame frame) {
logger.Warning("Received error response to message {MesageId}: {Error}", frame.ReplyingToMessageId, frame.Error); logger.Warning("Received error response to message {MesageId}: {Error}", frame.ReplyingToMessageId, frame.Error);
sendChannel.ReceiveError(frame.ReplyingToMessageId, frame.Error); messageSender.ReceiveError(frame);
} }
public void OnUnknownFrameId(byte frameId) { public void OnUnknownFrameId(byte frameId) {

View File

@@ -0,0 +1,176 @@
using System.Diagnostics.CodeAnalysis;
using System.Threading.Channels;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame;
using Phantom.Utils.Rpc.Frame.Types;
using Phantom.Utils.Rpc.Message;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime;
sealed class RpcFrameSender<TMessageBase> : IMessageReplySender {
private readonly ILogger logger;
private readonly IRpcConnectionProvider connectionProvider;
private readonly MessageReceiveTracker messageReceiveTracker = new ();
private readonly MessageAcknowledgmentQueue acknowledgmentQueue = new ();
private readonly Channel<IFrame> frameQueue;
private readonly Task frameQueueTask;
private readonly Task pingTask;
private readonly Task acknowledgementTask;
private readonly CancellationTokenSource sendQueueCancellationTokenSource = new ();
private readonly CancellationTokenSource pingCancellationTokenSource = new ();
private readonly CancellationTokenSource acknowledgementCancellationTokenSource = new ();
private TaskCompletionSource<DateTimeOffset>? pongTask;
internal RpcFrameSender(string loggerName, RpcCommonConnectionParameters connectionParameters, IRpcConnectionProvider connectionProvider) {
this.logger = PhantomLogger.Create<RpcFrameSender<TMessageBase>>(loggerName);
this.connectionProvider = connectionProvider;
this.frameQueue = Channel.CreateBounded<IFrame>(new BoundedChannelOptions(connectionParameters.FrameQueueCapacity) {
AllowSynchronousContinuations = false,
FullMode = BoundedChannelFullMode.Wait,
SingleReader = true,
SingleWriter = false,
});
this.frameQueueTask = ProcessQueue();
this.pingTask = PingSchedule(connectionParameters.PingInterval);
this.acknowledgementTask = AcknowledgementSchedule();
}
public async ValueTask SendPong(DateTimeOffset pingTime, CancellationToken cancellationToken) {
await SendFrame(new PongFrame(pingTime), cancellationToken);
}
async ValueTask IMessageReplySender.SendEmptyReply(uint replyingToMessageId, CancellationToken cancellationToken) {
await SendFrame(new ReplyFrame(replyingToMessageId, ReadOnlyMemory<byte>.Empty), cancellationToken);
}
async ValueTask IMessageReplySender.SendReply<TReply>(uint replyingToMessageId, TReply reply, CancellationToken cancellationToken) {
await SendFrame(new ReplyFrame(replyingToMessageId, RpcSerialization.Serialize(reply)), cancellationToken);
}
async ValueTask IMessageReplySender.SendError(uint replyingToMessageId, RpcError error, CancellationToken cancellationToken) {
await SendFrame(new ErrorFrame(replyingToMessageId, error), cancellationToken);
}
public async ValueTask SendFrame(IFrame frame, CancellationToken cancellationToken) {
if (!frameQueue.Writer.TryWrite(frame)) {
logger.Warning("Send queue is full!");
await frameQueue.Writer.WriteAsync(frame, cancellationToken);
}
}
private async Task ProcessQueue() {
CancellationToken cancellationToken = sendQueueCancellationTokenSource.Token;
await foreach (IFrame frame in frameQueue.Reader.ReadAllAsync(cancellationToken)) {
while (true) {
try {
Stream stream = await connectionProvider.GetStream(cancellationToken);
await stream.WriteAsync(frame.FrameType, cancellationToken);
await frame.Write(stream, cancellationToken);
await stream.FlushAsync(cancellationToken);
break;
} catch (OperationCanceledException) {
throw;
} catch (Exception) {
// Retry.
}
}
}
}
[SuppressMessage("ReSharper", "FunctionNeverReturns")]
private async Task PingSchedule(TimeSpan interval) {
CancellationToken cancellationToken = pingCancellationTokenSource.Token;
while (true) {
await Task.Delay(interval, cancellationToken);
pongTask = new TaskCompletionSource<DateTimeOffset>();
if (!frameQueue.Writer.TryWrite(PingFrame.Instance)) {
cancellationToken.ThrowIfCancellationRequested();
logger.Warning("Skipped a ping due to a full queue.");
continue;
}
DateTimeOffset pingTime = await pongTask.Task.WaitAsync(cancellationToken);
DateTimeOffset currentTime = DateTimeOffset.UtcNow;
TimeSpan roundTripTime = currentTime - pingTime;
logger.Information("Received pong, round trip time: {RoundTripTime} ms", (long) roundTripTime.TotalMilliseconds);
}
}
[SuppressMessage("ReSharper", "FunctionNeverReturns")]
private async Task AcknowledgementSchedule() {
CancellationToken cancellationToken = acknowledgementCancellationTokenSource.Token;
TimeSpan interval = TimeSpan.FromSeconds(1);
while (true) {
await acknowledgmentQueue.Wait(cancellationToken);
await Task.Delay(interval, cancellationToken);
foreach (var acknowledgmentFrame in acknowledgmentQueue.Drain()) {
await SendFrame(acknowledgmentFrame, cancellationToken);
}
}
}
public bool ReceiveMessage(MessageFrame frame) {
acknowledgmentQueue.Enqueue(frame.MessageId);
return messageReceiveTracker.ReceiveMessage(frame.MessageId);
}
public void ReceivePong(PongFrame frame) {
pongTask?.TrySetResult(frame.PingTime);
}
public async Task Shutdown() {
await pingCancellationTokenSource.CancelAsync();
await acknowledgementCancellationTokenSource.CancelAsync();
frameQueue.Writer.TryComplete();
await pingTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
await acknowledgementTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
try {
await frameQueueTask.WaitAsync(TimeSpan.FromSeconds(15));
} catch (TimeoutException) {
logger.Warning("Could not finish processing frame queue before timeout, forcibly shutting it down.");
await sendQueueCancellationTokenSource.CancelAsync();
} catch (Exception) {
// Ignore.
}
Dispose();
}
public async Task ShutdownNow() {
await pingCancellationTokenSource.CancelAsync();
await acknowledgementCancellationTokenSource.CancelAsync();
await sendQueueCancellationTokenSource.CancelAsync();
frameQueue.Writer.TryComplete();
await pingTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
await acknowledgementTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
await frameQueueTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
Dispose();
}
private void Dispose() {
frameQueueTask.Dispose();
pingTask.Dispose();
acknowledgementTask.Dispose();
sendQueueCancellationTokenSource.Dispose();
pingCancellationTokenSource.Dispose();
acknowledgementCancellationTokenSource.Dispose();
}
}

View File

@@ -1,19 +0,0 @@
using Phantom.Utils.Rpc.Message;
namespace Phantom.Utils.Rpc.Runtime;
sealed class RpcMessageHandler<TMessageBase>(IMessageReceiver<TMessageBase> receiver, IRpcReplySender replySender) {
public IMessageReceiver<TMessageBase> Receiver => receiver;
public void OnPing() {
receiver.OnPing();
}
public ValueTask SendReply<TReply>(uint messageId, TReply reply, CancellationToken cancellationToken) {
return replySender.SendReply(messageId, reply, cancellationToken);
}
public ValueTask SendError(uint messageId, RpcError error, CancellationToken cancellationToken) {
return replySender.SendError(messageId, error, cancellationToken);
}
}

View File

@@ -1,168 +0,0 @@
using System.Diagnostics.CodeAnalysis;
using System.Threading.Channels;
using Phantom.Utils.Actor;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame;
using Phantom.Utils.Rpc.Frame.Types;
using Phantom.Utils.Rpc.Message;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime;
public sealed class RpcSendChannel<TMessageBase> : IRpcReplySender, IDisposable {
private readonly ILogger logger;
private readonly IRpcConnectionProvider connectionProvider;
private readonly MessageRegistry<TMessageBase> messageRegistry;
private readonly MessageReplyTracker messageReplyTracker;
private readonly Channel<IFrame> sendQueue;
private readonly Task sendQueueTask;
private readonly Task pingTask;
private readonly CancellationTokenSource sendQueueCancellationTokenSource = new ();
private readonly CancellationTokenSource pingCancellationTokenSource = new ();
private uint nextMessageId;
private TaskCompletionSource<DateTimeOffset>? pongTask;
internal RpcSendChannel(string loggerName, RpcCommonConnectionParameters connectionParameters, IRpcConnectionProvider connectionProvider, MessageRegistry<TMessageBase> messageRegistry) {
this.logger = PhantomLogger.Create<RpcSendChannel<TMessageBase>>(loggerName);
this.connectionProvider = connectionProvider;
this.messageRegistry = messageRegistry;
this.messageReplyTracker = new MessageReplyTracker(loggerName);
this.sendQueue = Channel.CreateBounded<IFrame>(new BoundedChannelOptions(connectionParameters.SendQueueCapacity) {
AllowSynchronousContinuations = false,
FullMode = BoundedChannelFullMode.Wait,
SingleReader = true,
SingleWriter = false,
});
this.sendQueueTask = ProcessSendQueue();
this.pingTask = Ping(connectionParameters.PingInterval);
}
internal async ValueTask SendPong(DateTimeOffset pingTime, CancellationToken cancellationToken) {
await SendFrame(new PongFrame(pingTime), cancellationToken);
}
public bool TrySendMessage<TMessage>(TMessage message) where TMessage : TMessageBase {
return sendQueue.Writer.TryWrite(NextMessageFrame(message));
}
public async ValueTask SendMessage<TMessage>(TMessage message, CancellationToken cancellationToken = default) where TMessage : TMessageBase {
await SendFrame(NextMessageFrame(message), cancellationToken);
}
public async Task<TReply> SendMessage<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken cancellationToken) where TMessage : TMessageBase, ICanReply<TReply> {
MessageFrame frame = NextMessageFrame(message);
uint messageId = frame.MessageId;
messageReplyTracker.RegisterReply(messageId);
try {
await SendFrame(frame, cancellationToken);
} catch (Exception) {
messageReplyTracker.ForgetReply(messageId);
throw;
}
return await messageReplyTracker.WaitForReply<TReply>(messageId, waitForReplyTime, cancellationToken);
}
async ValueTask IRpcReplySender.SendReply<TReply>(uint replyingToMessageId, TReply reply, CancellationToken cancellationToken) {
await SendFrame(new ReplyFrame(replyingToMessageId, RpcSerialization.Serialize(reply)), cancellationToken);
}
async ValueTask IRpcReplySender.SendError(uint replyingToMessageId, RpcError error, CancellationToken cancellationToken) {
await SendFrame(new ErrorFrame(replyingToMessageId, error), cancellationToken);
}
private async ValueTask SendFrame(IFrame frame, CancellationToken cancellationToken) {
if (!sendQueue.Writer.TryWrite(frame)) {
logger.Warning("Send queue is full!");
await sendQueue.Writer.WriteAsync(frame, cancellationToken);
}
}
private MessageFrame NextMessageFrame<TMessage>(TMessage message) where TMessage : TMessageBase {
uint messageId = Interlocked.Increment(ref nextMessageId);
return messageRegistry.CreateFrame(messageId, message);
}
private async Task ProcessSendQueue() {
CancellationToken cancellationToken = sendQueueCancellationTokenSource.Token;
await foreach (IFrame frame in sendQueue.Reader.ReadAllAsync(cancellationToken)) {
while (true) {
try {
Stream stream = await connectionProvider.GetStream(cancellationToken);
await stream.WriteAsync(frame.FrameType, cancellationToken);
await frame.Write(stream, cancellationToken);
await stream.FlushAsync(cancellationToken);
break;
} catch (OperationCanceledException) {
throw;
} catch (Exception) {
// Retry.
}
}
}
}
[SuppressMessage("ReSharper", "FunctionNeverReturns")]
private async Task Ping(TimeSpan interval) {
CancellationToken cancellationToken = pingCancellationTokenSource.Token;
while (true) {
await Task.Delay(interval, cancellationToken);
pongTask = new TaskCompletionSource<DateTimeOffset>();
if (!sendQueue.Writer.TryWrite(PingFrame.Instance)) {
cancellationToken.ThrowIfCancellationRequested();
logger.Warning("Skipped a ping due to a full queue.");
continue;
}
DateTimeOffset pingTime = await pongTask.Task.WaitAsync(cancellationToken);
DateTimeOffset currentTime = DateTimeOffset.UtcNow;
TimeSpan roundTripTime = currentTime - pingTime;
logger.Information("Received pong, round trip time: {RoundTripTime} ms", (long) roundTripTime.TotalMilliseconds);
}
}
internal void ReceivePong(PongFrame frame) {
pongTask?.TrySetResult(frame.PingTime);
}
internal void ReceiveReply(ReplyFrame frame) {
messageReplyTracker.ReceiveReply(frame.ReplyingToMessageId, frame.SerializedReply);
}
internal void ReceiveError(uint messageId, RpcError error) {
messageReplyTracker.FailReply(messageId, RpcErrorException.From(error));
}
internal async Task Close() {
await pingCancellationTokenSource.CancelAsync();
sendQueue.Writer.TryComplete();
await pingTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
try {
await sendQueueTask.WaitAsync(TimeSpan.FromSeconds(15));
} catch (TimeoutException) {
logger.Warning("Could not finish processing send queue before timeout, forcibly shutting it down.");
await sendQueueCancellationTokenSource.CancelAsync();
} catch (Exception) {
// Ignore.
}
}
public void Dispose() {
sendQueueTask.Dispose();
sendQueueCancellationTokenSource.Dispose();
pingCancellationTokenSource.Dispose();
}
}

View File

@@ -91,6 +91,12 @@ public static class RpcSerialization {
return buffer; return buffer;
} }
public static async ValueTask<ReadOnlyMemory<byte>> ReadBytes(uint length, Stream stream, CancellationToken cancellationToken) {
Memory<byte> buffer = new byte[length];
await stream.ReadExactlyAsync(buffer, cancellationToken);
return buffer;
}
public static ReadOnlyMemory<byte> Serialize<T>(T value) { public static ReadOnlyMemory<byte> Serialize<T>(T value) {
var buffer = new ArrayBufferWriter<byte>(); var buffer = new ArrayBufferWriter<byte>();
MemoryPackSerializer.Serialize(buffer, value, SerializerOptions); MemoryPackSerializer.Serialize(buffer, value, SerializerOptions);

View File

@@ -25,7 +25,7 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, TH
public async Task<bool> Run(CancellationToken shutdownToken) { public async Task<bool> Run(CancellationToken shutdownToken) {
EndPoint endPoint = connectionParameters.EndPoint; EndPoint endPoint = connectionParameters.EndPoint;
SslServerAuthenticationOptions sslOptions = new () { var sslOptions = new SslServerAuthenticationOptions {
AllowRenegotiation = false, AllowRenegotiation = false,
AllowTlsResume = true, AllowTlsResume = true,
CertificateRevocationCheckMode = X509RevocationMode.NoCheck, CertificateRevocationCheckMode = X509RevocationMode.NoCheck,
@@ -35,6 +35,15 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, TH
ServerCertificate = connectionParameters.Certificate.Certificate, ServerCertificate = connectionParameters.Certificate.Certificate,
}; };
var serverData = new SharedData(
connectionParameters.Common,
connectionParameters.AuthToken,
messageDefinitions.ToServer,
clientHandshake,
clientRegistrar,
clientSessions
);
try { try {
using var serverSocket = new Socket(endPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); using var serverSocket = new Socket(endPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
@@ -51,7 +60,7 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, TH
while (true) { while (true) {
Socket clientSocket = await serverSocket.AcceptAsync(shutdownToken); Socket clientSocket = await serverSocket.AcceptAsync(shutdownToken);
clients.Add(new Client(loggerName, messageDefinitions, clientHandshake, clientRegistrar, clientSessions, clientSocket, sslOptions, connectionParameters.AuthToken, shutdownToken)); clients.Add(new Client(loggerName, serverData, clientSocket, sslOptions, shutdownToken));
clients.RemoveAll(static client => client.Task.IsCompleted); clients.RemoveAll(static client => client.Task.IsCompleted);
} }
} catch (OperationCanceledException) { } catch (OperationCanceledException) {
@@ -83,6 +92,15 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, TH
logger.Information("Server stopped."); logger.Information("Server stopped.");
} }
private readonly record struct SharedData(
RpcCommonConnectionParameters ConnectionParameters,
AuthToken AuthToken,
MessageRegistry<TClientToServerMessage> MessageRegistry,
IRpcServerClientHandshake<THandshakeResult> ClientHandshake,
IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage, THandshakeResult> ClientRegistrar,
RpcServerClientSessions<TServerToClientMessage> ClientSessions
);
private sealed class Client { private sealed class Client {
private static TimeSpan DisconnectTimeout => TimeSpan.FromSeconds(10); private static TimeSpan DisconnectTimeout => TimeSpan.FromSeconds(10);
@@ -99,34 +117,22 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, TH
public Task Task { get; } public Task Task { get; }
private ILogger logger; private ILogger logger;
private readonly IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions; private readonly SharedData sharedData;
private readonly IRpcServerClientHandshake<THandshakeResult> clientHandshake;
private readonly IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage, THandshakeResult> clientRegistrar;
private readonly RpcServerClientSessions<TServerToClientMessage> clientSessions;
private readonly Socket socket; private readonly Socket socket;
private readonly SslServerAuthenticationOptions sslOptions; private readonly SslServerAuthenticationOptions sslOptions;
private readonly AuthToken authToken;
private readonly CancellationToken shutdownToken; private readonly CancellationToken shutdownToken;
public Client( public Client(
string serverLoggerName, string serverLoggerName,
IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions, SharedData sharedData,
IRpcServerClientHandshake<THandshakeResult> clientHandshake,
IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage, THandshakeResult> clientRegistrar,
RpcServerClientSessions<TServerToClientMessage> clientSessions,
Socket socket, Socket socket,
SslServerAuthenticationOptions sslOptions, SslServerAuthenticationOptions sslOptions,
AuthToken authToken,
CancellationToken shutdownToken CancellationToken shutdownToken
) { ) {
this.logger = PhantomLogger.Create<RpcServer<TClientToServerMessage, TServerToClientMessage, THandshakeResult>, Client>(PhantomLogger.ConcatNames(serverLoggerName, GetAddressDescriptor(socket))); this.logger = PhantomLogger.Create<RpcServer<TClientToServerMessage, TServerToClientMessage, THandshakeResult>, Client>(PhantomLogger.ConcatNames(serverLoggerName, GetAddressDescriptor(socket)));
this.messageDefinitions = messageDefinitions; this.sharedData = sharedData;
this.clientHandshake = clientHandshake;
this.clientRegistrar = clientRegistrar;
this.clientSessions = clientSessions;
this.socket = socket; this.socket = socket;
this.sslOptions = sslOptions; this.sslOptions = sslOptions;
this.authToken = authToken;
this.shutdownToken = shutdownToken; this.shutdownToken = shutdownToken;
this.Task = Run(); this.Task = Run();
@@ -207,7 +213,7 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, TH
try { try {
var suppliedAuthToken = await RpcSerialization.ReadAuthToken(stream, cancellationToken); var suppliedAuthToken = await RpcSerialization.ReadAuthToken(stream, cancellationToken);
if (!authToken.FixedTimeEquals(suppliedAuthToken)) { if (!sharedData.AuthToken.FixedTimeEquals(suppliedAuthToken)) {
logger.Warning("Rejected client, invalid authorization token."); logger.Warning("Rejected client, invalid authorization token.");
await RpcSerialization.WriteByte(value: 0, stream, cancellationToken); await RpcSerialization.WriteByte(value: 0, stream, cancellationToken);
return null; return null;
@@ -217,17 +223,17 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, TH
} }
var sessionId = await RpcSerialization.ReadGuid(stream, cancellationToken); var sessionId = await RpcSerialization.ReadGuid(stream, cancellationToken);
var session = clientSessions.GetOrCreateSession(sessionId); var session = sharedData.ClientSessions.GetOrCreateSession(sessionId);
logger.Information("Client connected with session {SessionId}, new logger name: {LoggerName}", sessionId, session.LoggerName); logger.Information("Client connected with session {SessionId}, new logger name: {LoggerName}", session.SessionId, session.LoggerName);
logger = PhantomLogger.Create<RpcServer<TClientToServerMessage, TServerToClientMessage, THandshakeResult>, Client>(session.LoggerName); logger = PhantomLogger.Create<RpcServer<TClientToServerMessage, TServerToClientMessage, THandshakeResult>, Client>(session.LoggerName);
EstablishedConnection? establishedConnection; EstablishedConnection? establishedConnection;
switch (await clientHandshake.Perform(stream, cancellationToken)) { switch (await sharedData.ClientHandshake.Perform(stream, cancellationToken)) {
case Left<THandshakeResult, Exception>(var handshakeResult): case Left<THandshakeResult, Exception>(var handshakeResult):
try { try {
var connection = new RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage>(clientSessions, sessionId, messageDefinitions.ToServer, stream, session); var connection = new RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage>(sharedData.ConnectionParameters, sharedData.MessageRegistry, sharedData.ClientSessions, session, stream);
var messageReceiver = clientRegistrar.Register(connection, handshakeResult); var messageReceiver = sharedData.ClientRegistrar.Register(connection, handshakeResult);
establishedConnection = new EstablishedConnection(session, connection, messageReceiver); establishedConnection = new EstablishedConnection(session, connection, messageReceiver);
} catch (Exception e) { } catch (Exception e) {
@@ -247,14 +253,17 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, TH
break; break;
} }
RpcFinalHandshakeResult finalHandshakeResult;
if (establishedConnection == null) { if (establishedConnection == null) {
await RpcSerialization.WriteByte(value: 0, stream, cancellationToken); finalHandshakeResult = RpcFinalHandshakeResult.Error;
return null;
} }
else { else {
await RpcSerialization.WriteByte(value: 1, stream, cancellationToken); bool isNewSession = session.MarkFirstTimeUse();
return establishedConnection; finalHandshakeResult = isNewSession ? RpcFinalHandshakeResult.NewSession : RpcFinalHandshakeResult.ReusedSession;
} }
await RpcSerialization.WriteByte((byte) finalHandshakeResult, stream, cancellationToken);
return establishedConnection;
} catch (OperationCanceledException) { } catch (OperationCanceledException) {
throw; throw;
} catch (EndOfStreamException) { } catch (EndOfStreamException) {

View File

@@ -1,18 +1,31 @@
using Phantom.Utils.Rpc.Message; using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Message;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Server; namespace Phantom.Utils.Rpc.Runtime.Server;
sealed class RpcServerClientSession<TServerToClientMessage> : IRpcConnectionProvider { sealed class RpcServerClientSession<TServerToClientMessage> : IRpcConnectionProvider {
public string LoggerName { get; } private readonly ILogger logger;
public RpcSendChannel<TServerToClientMessage> SendChannel { get; }
public MessageReceiveTracker MessageReceiveTracker { get; }
public string LoggerName { get; }
public Guid SessionId { get; }
public MessageSender<TServerToClientMessage> MessageSender { get; }
public RpcFrameSender<TServerToClientMessage> FrameSender { get; }
private bool isNew = true;
private TaskCompletionSource<Stream> nextStream = new (); private TaskCompletionSource<Stream> nextStream = new ();
public RpcServerClientSession(string loggerName, RpcCommonConnectionParameters connectionParameters, MessageRegistry<TServerToClientMessage> messageRegistry) { public RpcServerClientSession(string loggerName, Guid sessionId, RpcCommonConnectionParameters connectionParameters, MessageRegistry<TServerToClientMessage> messageRegistry) {
this.logger = PhantomLogger.Create<RpcServerClientSession<TServerToClientMessage>>(loggerName);
this.LoggerName = loggerName; this.LoggerName = loggerName;
this.SendChannel = new RpcSendChannel<TServerToClientMessage>(loggerName, connectionParameters, this, messageRegistry); this.SessionId = sessionId;
this.MessageReceiveTracker = new MessageReceiveTracker(); this.MessageSender = new MessageSender<TServerToClientMessage>(loggerName, messageRegistry, connectionParameters);
this.FrameSender = new RpcFrameSender<TServerToClientMessage>(loggerName, connectionParameters, this);
}
/// <returns>Whether this was a new session. If it was a new session, it will be marked as used.</returns>
public bool MarkFirstTimeUse() {
return Interlocked.CompareExchange(ref isNew, value: true, comparand: false);
} }
public void OnConnected(Stream stream) { public void OnConnected(Stream stream) {
@@ -39,7 +52,7 @@ sealed class RpcServerClientSession<TServerToClientMessage> : IRpcConnectionProv
} }
} }
public Task Close() { public async Task Close() {
lock (this) { lock (this) {
if (!nextStream.TrySetCanceled()) { if (!nextStream.TrySetCanceled()) {
nextStream = new TaskCompletionSource<Stream>(); nextStream = new TaskCompletionSource<Stream>();
@@ -47,6 +60,16 @@ sealed class RpcServerClientSession<TServerToClientMessage> : IRpcConnectionProv
} }
} }
return SendChannel.Close(); try {
await MessageSender.Close();
} catch (Exception e) {
logger.Error(e, "Caught exception while closing message sender.");
}
try {
await FrameSender.Shutdown();
} catch (Exception e) {
logger.Error(e, "Caught exception while closing send channel.");
}
} }
} }

View File

@@ -29,7 +29,7 @@ sealed class RpcServerClientSessions<TServerToClientMessage> {
} }
private RpcServerClientSession<TServerToClientMessage> CreateSession(Guid sessionId, string loggerName) { private RpcServerClientSession<TServerToClientMessage> CreateSession(Guid sessionId, string loggerName) {
return new RpcServerClientSession<TServerToClientMessage>(loggerName, connectionParameters, messageRegistry); return new RpcServerClientSession<TServerToClientMessage>(loggerName, sessionId, connectionParameters, messageRegistry);
} }
public Task CloseSession(Guid sessionId) { public Task CloseSession(Guid sessionId) {

View File

@@ -4,12 +4,9 @@ using Phantom.Utils.Rpc.Runtime.Tls;
namespace Phantom.Utils.Rpc.Runtime.Server; namespace Phantom.Utils.Rpc.Runtime.Server;
public readonly record struct RpcServerConnectionParameters( public sealed record RpcServerConnectionParameters(
EndPoint EndPoint, EndPoint EndPoint,
RpcServerCertificate Certificate, RpcServerCertificate Certificate,
AuthToken AuthToken, AuthToken AuthToken,
ushort SendQueueCapacity, RpcCommonConnectionParameters CommonParameters
TimeSpan PingInterval );
) {
internal RpcCommonConnectionParameters Common => new (SendQueueCapacity, PingInterval);
}

View File

@@ -6,32 +6,29 @@ using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Server; namespace Phantom.Utils.Rpc.Runtime.Server;
public sealed class RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage> { public sealed class RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage> {
private readonly string loggerName;
private readonly ILogger logger; private readonly ILogger logger;
private readonly RpcServerClientSessions<TServerToClientMessage> sessions; private readonly RpcCommonConnectionParameters connectionParameters;
private readonly MessageRegistry<TClientToServerMessage> messageRegistry; private readonly MessageRegistry<TClientToServerMessage> messageRegistry;
private readonly MessageReceiveTracker messageReceiveTracker; private readonly RpcServerClientSessions<TServerToClientMessage> sessions;
private readonly RpcServerClientSession<TServerToClientMessage> session;
private readonly Stream stream; private readonly Stream stream;
private readonly CancellationTokenSource closeCancellationTokenSource = new (); private readonly CancellationTokenSource closeCancellationTokenSource = new ();
public Guid SessionId { get; } public Guid SessionId => session.SessionId;
public RpcSendChannel<TServerToClientMessage> SendChannel { get; } public MessageSender<TServerToClientMessage> MessageSender => session.MessageSender;
internal RpcServerToClientConnection(RpcServerClientSessions<TServerToClientMessage> sessions, Guid sessionId, MessageRegistry<TClientToServerMessage> messageRegistry, Stream stream, RpcServerClientSession<TServerToClientMessage> session) { internal RpcServerToClientConnection(RpcCommonConnectionParameters connectionParameters, MessageRegistry<TClientToServerMessage> messageRegistry, RpcServerClientSessions<TServerToClientMessage> sessions, RpcServerClientSession<TServerToClientMessage> session, Stream stream) {
this.loggerName = session.LoggerName; this.logger = PhantomLogger.Create<RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage>>(session.LoggerName);
this.logger = PhantomLogger.Create<RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage>>(loggerName); this.connectionParameters = connectionParameters;
this.sessions = sessions;
this.messageRegistry = messageRegistry; this.messageRegistry = messageRegistry;
this.messageReceiveTracker = session.MessageReceiveTracker; this.sessions = sessions;
this.session = session;
this.stream = stream; this.stream = stream;
this.SessionId = sessionId;
this.SendChannel = session.SendChannel;
} }
internal async Task Listen(IMessageReceiver<TClientToServerMessage> receiver) { internal async Task Listen(IMessageReceiver<TClientToServerMessage> messageReceiver) {
var messageHandler = new RpcMessageHandler<TClientToServerMessage>(receiver, SendChannel); var messageHandler = new MessageHandler<TClientToServerMessage>(messageReceiver, session.FrameSender);
var frameReader = new RpcFrameReader<TServerToClientMessage, TClientToServerMessage>(loggerName, messageRegistry, messageReceiveTracker, messageHandler, SendChannel); var frameReader = new RpcFrameReader<TServerToClientMessage, TClientToServerMessage>(session.LoggerName, connectionParameters, messageRegistry, messageHandler, MessageSender, session.FrameSender);
try { try {
await IFrame.ReadFrom(stream, frameReader, closeCancellationTokenSource.Token); await IFrame.ReadFrom(stream, frameReader, closeCancellationTokenSource.Token);
} catch (OperationCanceledException) { } catch (OperationCanceledException) {

View File

@@ -38,7 +38,15 @@ public sealed class RangeSet<T> : IEnumerable<RangeSet<T>.Range> where T : IBina
return true; return true;
} }
public IEnumerator<Range> GetEnumerator() { public void Clear() {
ranges.Clear();
}
public List<Range>.Enumerator GetEnumerator() {
return ranges.GetEnumerator();
}
IEnumerator<Range> IEnumerable<Range>.GetEnumerator() {
return ranges.GetEnumerator(); return ranges.GetEnumerator();
} }

View File

@@ -1,19 +1,19 @@
using Phantom.Common.Messages.Web; using Phantom.Common.Messages.Web;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime; using Phantom.Utils.Rpc.Message;
namespace Phantom.Web.Services.Rpc; namespace Phantom.Web.Services.Rpc;
public sealed class ControllerConnection(RpcSendChannel<IMessageToController> connection) { public sealed class ControllerConnection(MessageSender<IMessageToController> sender) {
public ValueTask Send<TMessage>(TMessage message) where TMessage : IMessageToController { public ValueTask Send<TMessage>(TMessage message) where TMessage : IMessageToController {
return connection.SendMessage(message); return sender.Send(message);
} }
public Task<TReply> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken = default) where TMessage : IMessageToController, ICanReply<TReply> { public Task<TReply> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken = default) where TMessage : IMessageToController, ICanReply<TReply> {
return connection.SendMessage<TMessage, TReply>(message, waitForReplyTime, waitForReplyCancellationToken); return sender.Send<TMessage, TReply>(message, waitForReplyTime, waitForReplyCancellationToken);
} }
public Task<TReply> Send<TMessage, TReply>(TMessage message, CancellationToken waitForReplyCancellationToken) where TMessage : IMessageToController, ICanReply<TReply> { public Task<TReply> Send<TMessage, TReply>(TMessage message, CancellationToken waitForReplyCancellationToken) where TMessage : IMessageToController, ICanReply<TReply> {
return connection.SendMessage<TMessage, TReply>(message, Timeout.InfiniteTimeSpan, waitForReplyCancellationToken); return sender.Send<TMessage, TReply>(message, Timeout.InfiniteTimeSpan, waitForReplyCancellationToken);
} }
} }

View File

@@ -6,6 +6,7 @@ using Phantom.Utils.Cryptography;
using Phantom.Utils.IO; using Phantom.Utils.IO;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Message; using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Runtime;
using Phantom.Utils.Rpc.Runtime.Client; using Phantom.Utils.Rpc.Runtime.Client;
using Phantom.Utils.Runtime; using Phantom.Utils.Runtime;
using Phantom.Utils.Threading; using Phantom.Utils.Threading;
@@ -59,8 +60,12 @@ try {
CertificateThumbprint: webKey.Value.CertificateThumbprint, CertificateThumbprint: webKey.Value.CertificateThumbprint,
AuthToken: webKey.Value.AuthToken, AuthToken: webKey.Value.AuthToken,
Handshake: new IRpcClientHandshake.NoOp(), Handshake: new IRpcClientHandshake.NoOp(),
SendQueueCapacity: 500, CommonParameters: new RpcCommonConnectionParameters(
PingInterval: TimeSpan.FromSeconds(10) MessageQueueCapacity: 250,
FrameQueueCapacity: 500,
MaxConcurrentlyHandledMessages: 100,
PingInterval: TimeSpan.FromMinutes(1)
)
); );
using var rpcClient = await RpcClient<IMessageToController, IMessageToWeb>.Connect("Controller", rpcClientConnectionParameters, WebMessageRegistries.Definitions, shutdownCancellationToken); using var rpcClient = await RpcClient<IMessageToController, IMessageToWeb>.Connect("Controller", rpcClientConnectionParameters, WebMessageRegistries.Definitions, shutdownCancellationToken);
@@ -70,7 +75,7 @@ try {
} }
var webConfiguration = new WebLauncher.Configuration(PhantomLogger.Create("Web"), webServerHost, webServerPort, webBasePath, dataProtectionKeysPath, shutdownCancellationToken); var webConfiguration = new WebLauncher.Configuration(PhantomLogger.Create("Web"), webServerHost, webServerPort, webBasePath, dataProtectionKeysPath, shutdownCancellationToken);
var webApplication = WebLauncher.CreateApplication(webConfiguration, applicationProperties, rpcClient.SendChannel); var webApplication = WebLauncher.CreateApplication(webConfiguration, applicationProperties, rpcClient.MessageSender);
using var actorSystem = ActorSystemFactory.Create("Web"); using var actorSystem = ActorSystemFactory.Create("Web");
@@ -98,7 +103,7 @@ try {
PhantomLogger.Root.Information("Unregistering web..."); PhantomLogger.Root.Information("Unregistering web...");
try { try {
using var unregisterCancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(10)); using var unregisterCancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(10));
await rpcClient.SendChannel.SendMessage(new UnregisterWebMessage(), unregisterCancellationTokenSource.Token); await rpcClient.MessageSender.Send(new UnregisterWebMessage(), unregisterCancellationTokenSource.Token);
} catch (OperationCanceledException) { } catch (OperationCanceledException) {
PhantomLogger.Root.Warning("Could not unregister web after shutdown."); PhantomLogger.Root.Warning("Could not unregister web after shutdown.");
} catch (Exception e) { } catch (Exception e) {

View File

@@ -1,6 +1,6 @@
using Microsoft.AspNetCore.DataProtection; using Microsoft.AspNetCore.DataProtection;
using Phantom.Common.Messages.Web; using Phantom.Common.Messages.Web;
using Phantom.Utils.Rpc.Runtime; using Phantom.Utils.Rpc.Message;
using Phantom.Web.Services; using Phantom.Web.Services;
using Serilog; using Serilog;
using ILogger = Serilog.ILogger; using ILogger = Serilog.ILogger;
@@ -12,7 +12,7 @@ static class WebLauncher {
public string HttpUrl => "http://" + Host + ":" + Port; public string HttpUrl => "http://" + Host + ":" + Port;
} }
internal static WebApplication CreateApplication(Configuration config, ApplicationProperties applicationProperties, RpcSendChannel<IMessageToController> sendChannel) { internal static WebApplication CreateApplication(Configuration config, ApplicationProperties applicationProperties, MessageSender<IMessageToController> sendChannel) {
var assembly = typeof(WebLauncher).Assembly; var assembly = typeof(WebLauncher).Assembly;
var builder = WebApplication.CreateBuilder(new WebApplicationOptions { var builder = WebApplication.CreateBuilder(new WebApplicationOptions {
ApplicationName = assembly.GetName().Name, ApplicationName = assembly.GetName().Name,