1
0
mirror of https://github.com/chylex/Minecraft-Phantom-Panel.git synced 2025-10-17 00:39:36 +02:00

4 Commits

41 changed files with 719 additions and 400 deletions

View File

@@ -1,19 +1,19 @@
using Phantom.Common.Messages.Agent;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime;
using Phantom.Utils.Rpc.Message;
namespace Phantom.Agent.Services.Rpc;
public sealed class ControllerConnection(RpcSendChannel<IMessageToController> sendChannel) {
public sealed class ControllerConnection(MessageSender<IMessageToController> sender) {
internal bool TrySend<TMessage>(TMessage message) where TMessage : IMessageToController {
return sendChannel.TrySendMessage(message);
return sender.TrySend(message);
}
internal ValueTask Send<TMessage>(TMessage message, CancellationToken cancellationToken) where TMessage : IMessageToController {
return sendChannel.SendMessage(message, cancellationToken);
return sender.Send(message, cancellationToken);
}
internal Task<TReply> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken cancellationToken) where TMessage : IMessageToController, ICanReply<TReply> {
return sendChannel.SendMessage<TMessage, TReply>(message, waitForReplyTime, cancellationToken);
return sender.Send<TMessage, TReply>(message, waitForReplyTime, cancellationToken);
}
}

View File

@@ -10,6 +10,7 @@ using Phantom.Common.Messages.Agent.ToController;
using Phantom.Utils.Actor;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Runtime;
using Phantom.Utils.Rpc.Runtime.Client;
using Phantom.Utils.Runtime;
using Phantom.Utils.Threading;
@@ -60,8 +61,12 @@ try {
CertificateThumbprint: agentKey.Value.CertificateThumbprint,
AuthToken: agentKey.Value.AuthToken,
Handshake: controllerHandshake,
SendQueueCapacity: 500,
PingInterval: TimeSpan.FromSeconds(10)
CommonParameters: new RpcCommonConnectionParameters(
MessageQueueCapacity: 250,
FrameQueueCapacity: 500,
MaxConcurrentlyHandledMessages: 50,
PingInterval: TimeSpan.FromSeconds(10)
)
);
using var rpcClient = await RpcClient<IMessageToController, IMessageToAgent>.Connect("Controller", rpcClientConnectionParameters, AgentMessageRegistries.Definitions, shutdownCancellationToken);
@@ -74,7 +79,7 @@ try {
try {
PhantomLogger.Root.InformationHeading("Launching Phantom Panel agent...");
var agentServices = new AgentServices(agentInfo, folders, new AgentServiceConfiguration(maxConcurrentBackupCompressionTasks), new ControllerConnection(rpcClient.SendChannel), javaRuntimeRepository);
var agentServices = new AgentServices(agentInfo, folders, new AgentServiceConfiguration(maxConcurrentBackupCompressionTasks), new ControllerConnection(rpcClient.MessageSender), javaRuntimeRepository);
var rpcMessageHandlerInit = new ControllerMessageHandlerActor.Init(agentServices);
var rpcMessageHandlerActor = agentServices.ActorSystem.ActorOf(ControllerMessageHandlerActor.Factory(rpcMessageHandlerInit), "ControllerMessageHandler");
@@ -91,7 +96,7 @@ try {
PhantomLogger.Root.Information("Unregistering agent...");
try {
using var unregisterCancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(10));
await rpcClient.SendChannel.SendMessage(new UnregisterAgentMessage(), unregisterCancellationTokenSource.Token);
await rpcClient.MessageSender.Send(new UnregisterAgentMessage(), unregisterCancellationTokenSource.Token);
} catch (OperationCanceledException) {
PhantomLogger.Root.Warning("Could not unregister agent after shutdown.");
} catch (Exception e) {

View File

@@ -41,7 +41,7 @@ sealed class AgentConnection(Guid agentGuid, string agentName) {
public ValueTask Send<TMessage>(TMessage message) where TMessage : IMessageToAgent {
lock (this) {
if (connection != null) {
return connection.SendChannel.SendMessage(message);
return connection.MessageSender.Send(message);
}
}
@@ -52,7 +52,7 @@ sealed class AgentConnection(Guid agentGuid, string agentName) {
public Task<TReply?> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken) where TMessage : IMessageToAgent, ICanReply<TReply> where TReply : class {
lock (this) {
if (connection != null) {
return connection.SendChannel.SendMessage<TMessage, TReply>(message, waitForReplyTime, waitForReplyCancellationToken)!;
return connection.MessageSender.Send<TMessage, TReply>(message, waitForReplyTime, waitForReplyCancellationToken)!;
}
}

View File

@@ -5,24 +5,24 @@ using Phantom.Common.Messages.Web;
using Phantom.Common.Messages.Web.ToWeb;
using Phantom.Controller.Services.Instances;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime;
using Phantom.Utils.Rpc.Message;
namespace Phantom.Controller.Services.Rpc;
sealed class WebMessageDataUpdateSenderActor : ReceiveActor<WebMessageDataUpdateSenderActor.ICommand> {
public readonly record struct Init(RpcSendChannel<IMessageToWeb> Connection, ControllerState ControllerState, InstanceLogManager InstanceLogManager);
public readonly record struct Init(MessageSender<IMessageToWeb> MessageSender, ControllerState ControllerState, InstanceLogManager InstanceLogManager);
public static Props<ICommand> Factory(Init init) {
return Props<ICommand>.Create(() => new WebMessageDataUpdateSenderActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume });
}
private readonly RpcSendChannel<IMessageToWeb> connection;
private readonly MessageSender<IMessageToWeb> messageSender;
private readonly ControllerState controllerState;
private readonly InstanceLogManager instanceLogManager;
private readonly ActorRef<ICommand> selfCached;
private WebMessageDataUpdateSenderActor(Init init) {
this.connection = init.Connection;
this.messageSender = init.MessageSender;
this.controllerState = init.ControllerState;
this.instanceLogManager = init.InstanceLogManager;
this.selfCached = SelfTyped;
@@ -70,18 +70,18 @@ sealed class WebMessageDataUpdateSenderActor : ReceiveActor<WebMessageDataUpdate
private sealed record RefreshUserSessionCommand(Guid UserGuid) : ICommand;
private Task RefreshAgents(RefreshAgentsCommand command) {
return connection.SendMessage(new RefreshAgentsMessage(command.Agents.Values.ToImmutableArray())).AsTask();
return messageSender.Send(new RefreshAgentsMessage(command.Agents.Values.ToImmutableArray())).AsTask();
}
private Task RefreshInstances(RefreshInstancesCommand command) {
return connection.SendMessage(new RefreshInstancesMessage(command.Instances.Values.ToImmutableArray())).AsTask();
return messageSender.Send(new RefreshInstancesMessage(command.Instances.Values.ToImmutableArray())).AsTask();
}
private Task ReceiveInstanceLogs(ReceiveInstanceLogsCommand command) {
return connection.SendMessage(new InstanceOutputMessage(command.InstanceGuid, command.Lines)).AsTask();
return messageSender.Send(new InstanceOutputMessage(command.InstanceGuid, command.Lines)).AsTask();
}
private Task RefreshUserSession(RefreshUserSessionCommand command) {
return connection.SendMessage(new RefreshUserSessionMessage(command.UserGuid)).AsTask();
return messageSender.Send(new RefreshUserSessionMessage(command.UserGuid)).AsTask();
}
}

View File

@@ -63,7 +63,7 @@ sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
this.minecraftVersions = init.MinecraftVersions;
this.eventLogManager = init.EventLogManager;
var senderActorInit = new WebMessageDataUpdateSenderActor.Init(connection.SendChannel, controllerState, init.InstanceLogManager);
var senderActorInit = new WebMessageDataUpdateSenderActor.Init(connection.MessageSender, controllerState, init.InstanceLogManager);
Context.ActorOf(WebMessageDataUpdateSenderActor.Factory(senderActorInit), "DataUpdateSender");
ReceiveAsync<UnregisterWebMessage>(HandleUnregisterWeb);

View File

@@ -6,6 +6,7 @@ using Phantom.Controller.Database.Postgres;
using Phantom.Controller.Services;
using Phantom.Utils.IO;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime;
using Phantom.Utils.Rpc.Runtime.Server;
using Phantom.Utils.Runtime;
using Phantom.Utils.Tasks;
@@ -64,16 +65,24 @@ try {
EndPoint: agentRpcServerHost,
Certificate: agentKeyData.Certificate,
AuthToken: agentKeyData.AuthToken,
SendQueueCapacity: 100,
PingInterval: TimeSpan.FromSeconds(10)
CommonParameters: new RpcCommonConnectionParameters(
MessageQueueCapacity: 50,
FrameQueueCapacity: 100,
MaxConcurrentlyHandledMessages: 20,
PingInterval: TimeSpan.FromSeconds(10)
)
);
var webConnectionParameters = new RpcServerConnectionParameters(
EndPoint: webRpcServerHost,
Certificate: webKeyData.Certificate,
AuthToken: webKeyData.AuthToken,
SendQueueCapacity: 500,
PingInterval: TimeSpan.FromMinutes(1)
CommonParameters: new RpcCommonConnectionParameters(
MessageQueueCapacity: 250,
FrameQueueCapacity: 500,
MaxConcurrentlyHandledMessages: 100,
PingInterval: TimeSpan.FromMinutes(1)
)
);
LinkedTasks<bool> rpcServerTasks = new LinkedTasks<bool>([

View File

@@ -6,12 +6,14 @@ interface IFrame {
private const byte TypePingId = 0;
private const byte TypePongId = 1;
private const byte TypeMessageId = 2;
private const byte TypeReplyId = 3;
private const byte TypeErrorId = 4;
private const byte TypeAcknowledgmentId = 3;
private const byte TypeReplyId = 4;
private const byte TypeErrorId = 5;
static readonly ReadOnlyMemory<byte> TypePing = new ([TypePingId]);
static readonly ReadOnlyMemory<byte> TypePong = new ([TypePongId]);
static readonly ReadOnlyMemory<byte> TypeMessage = new ([TypeMessageId]);
static readonly ReadOnlyMemory<byte> TypeAcknowledgment = new ([TypeAcknowledgmentId]);
static readonly ReadOnlyMemory<byte> TypeReply = new ([TypeReplyId]);
static readonly ReadOnlyMemory<byte> TypeError = new ([TypeErrorId]);
@@ -37,6 +39,11 @@ interface IFrame {
await reader.OnMessageFrame(messageFrame, cancellationToken);
break;
case TypeAcknowledgmentId:
var acknowledgmentFrame = await AcknowledgmentFrame.Read(stream, cancellationToken);
reader.OnAcknowledgmentFrame(acknowledgmentFrame);
break;
case TypeReplyId:
var replyFrame = await ReplyFrame.Read(stream, cancellationToken);
reader.OnReplyFrame(replyFrame);

View File

@@ -6,6 +6,7 @@ interface IFrameReader {
ValueTask OnPingFrame(DateTimeOffset pingTime, CancellationToken cancellationToken);
void OnPongFrame(PongFrame frame);
Task OnMessageFrame(MessageFrame frame, CancellationToken cancellationToken);
void OnAcknowledgmentFrame(AcknowledgmentFrame frame);
void OnReplyFrame(ReplyFrame frame);
void OnErrorFrame(ErrorFrame frame);
void OnUnknownFrameId(byte frameId);

View File

@@ -0,0 +1,18 @@
using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Utils.Rpc.Frame.Types;
sealed record AcknowledgmentFrame(uint FirstMessageId, uint LastMessageId) : IFrame {
public ReadOnlyMemory<byte> FrameType => IFrame.TypeAcknowledgment;
public async Task Write(Stream stream, CancellationToken cancellationToken = default) {
await RpcSerialization.WriteUnsignedInt(FirstMessageId, stream, cancellationToken);
await RpcSerialization.WriteUnsignedInt(LastMessageId, stream, cancellationToken);
}
public static async Task<AcknowledgmentFrame> Read(Stream stream, CancellationToken cancellationToken) {
var firstMessageId = await RpcSerialization.ReadUnsignedInt(stream, cancellationToken);
var lastMessageId = await RpcSerialization.ReadUnsignedInt(stream, cancellationToken);
return new AcknowledgmentFrame(firstMessageId, lastMessageId);
}
}

View File

@@ -8,30 +8,26 @@ sealed record MessageFrame(uint MessageId, ushort RegistryCode, ReadOnlyMemory<b
public ReadOnlyMemory<byte> FrameType => IFrame.TypeMessage;
public async Task Write(Stream stream, CancellationToken cancellationToken) {
int serializedMessageLength = SerializedMessage.Length;
uint serializedMessageLength = (uint) SerializedMessage.Length;
CheckMessageLength(serializedMessageLength);
await RpcSerialization.WriteUnsignedInt(MessageId, stream, cancellationToken);
await RpcSerialization.WriteUnsignedShort(RegistryCode, stream, cancellationToken);
await RpcSerialization.WriteSignedInt(serializedMessageLength, stream, cancellationToken);
await RpcSerialization.WriteUnsignedInt(serializedMessageLength, stream, cancellationToken);
await stream.WriteAsync(SerializedMessage, cancellationToken);
}
public static async Task<MessageFrame> Read(Stream stream, CancellationToken cancellationToken) {
var messageId = await RpcSerialization.ReadUnsignedInt(stream, cancellationToken);
var registryCode = await RpcSerialization.ReadUnsignedShort(stream, cancellationToken);
var serializedMessageLength = await RpcSerialization.ReadSignedInt(stream, cancellationToken);
var serializedMessageLength = await RpcSerialization.ReadUnsignedInt(stream, cancellationToken);
CheckMessageLength(serializedMessageLength);
var serializedMessage = await RpcSerialization.ReadBytes(serializedMessageLength, stream, cancellationToken);
return new MessageFrame(messageId, registryCode, serializedMessage);
}
private static void CheckMessageLength(int messageLength) {
if (messageLength < 0) {
throw new RpcErrorException("Message length is negative.", RpcError.InvalidData);
}
private static void CheckMessageLength(uint messageLength) {
if (messageLength > MaxMessageBytes) {
throw new RpcErrorException("Message is too large: " + messageLength + " > " + MaxMessageBytes + " bytes", RpcError.MessageTooLarge);
}

View File

@@ -8,30 +8,26 @@ sealed record ReplyFrame(uint ReplyingToMessageId, ReadOnlyMemory<byte> Serializ
public ReadOnlyMemory<byte> FrameType => IFrame.TypeReply;
public async Task Write(Stream stream, CancellationToken cancellationToken) {
int replyLength = SerializedReply.Length;
CheckReplyLength(replyLength);
uint serializedReplyLength = (uint) SerializedReply.Length;
CheckReplyLength(serializedReplyLength);
await RpcSerialization.WriteUnsignedInt(ReplyingToMessageId, stream, cancellationToken);
await RpcSerialization.WriteSignedInt(replyLength, stream, cancellationToken);
await RpcSerialization.WriteUnsignedInt(serializedReplyLength, stream, cancellationToken);
await stream.WriteAsync(SerializedReply, cancellationToken);
}
public static async Task<ReplyFrame> Read(Stream stream, CancellationToken cancellationToken) {
var replyingToMessageId = await RpcSerialization.ReadUnsignedInt(stream, cancellationToken);
var replyLength = await RpcSerialization.ReadSignedInt(stream, cancellationToken);
CheckReplyLength(replyLength);
var reply = await RpcSerialization.ReadBytes(replyLength, stream, cancellationToken);
var serializedReplyLength = await RpcSerialization.ReadUnsignedInt(stream, cancellationToken);
CheckReplyLength(serializedReplyLength);
var serializedReply = await RpcSerialization.ReadBytes(serializedReplyLength, stream, cancellationToken);
return new ReplyFrame(replyingToMessageId, reply);
return new ReplyFrame(replyingToMessageId, serializedReply);
}
private static void CheckReplyLength(int replyLength) {
if (replyLength < 0) {
throw new RpcErrorException("Reply length is negative.", RpcError.InvalidData);
}
private static void CheckReplyLength(uint replyLength) {
if (replyLength > MaxReplyBytes) {
throw new RpcErrorException("Reply is too large: " + replyLength + " > " + MaxReplyBytes + " bytes", RpcError.MessageTooLarge);
throw new RpcErrorException("Reply is too large: " + replyLength + " > " + MaxReplyBytes + " bytes", RpcError.ReplyTooLarge);
}
}
}

View File

@@ -1,6 +0,0 @@
namespace Phantom.Utils.Rpc.Handshake;
public interface IRpcHandshake<T> {
Task Send(Stream stream, CancellationToken cancellationToken);
Task<T> Receive(Stream stream, CancellationToken cancellationToken);
}

View File

@@ -0,0 +1,7 @@
namespace Phantom.Utils.Rpc.Handshake;
public enum RpcFinalHandshakeResult : byte {
Error = 0,
NewSession = 1,
ReusedSession = 2,
}

View File

@@ -1,6 +1,9 @@
namespace Phantom.Utils.Rpc.Runtime;
using Phantom.Utils.Rpc.Runtime;
interface IRpcReplySender {
namespace Phantom.Utils.Rpc.Message;
interface IMessageReplySender {
ValueTask SendEmptyReply(uint replyingToMessageId, CancellationToken cancellationToken);
ValueTask SendReply<TReply>(uint replyingToMessageId, TReply reply, CancellationToken cancellationToken);
ValueTask SendError(uint replyingToMessageId, RpcError error, CancellationToken cancellationToken);
}

View File

@@ -0,0 +1,39 @@
using Phantom.Utils.Collections;
using Phantom.Utils.Rpc.Frame.Types;
using Phantom.Utils.Threading;
namespace Phantom.Utils.Rpc.Message;
sealed class MessageAcknowledgmentQueue {
private readonly Lock @lock = new ();
private readonly RangeSet<uint> pendingMessageIdRanges = new ();
private readonly ManualResetEventSlim pendingEvent = new ();
public void Enqueue(uint messageId) {
lock (@lock) {
pendingMessageIdRanges.Add(messageId);
}
pendingEvent.Set();
}
public Task Wait(CancellationToken cancellationToken) {
return pendingEvent.WaitHandle.WaitOneAsync(cancellationToken);
}
public List<AcknowledgmentFrame> Drain() {
pendingEvent.Reset();
List<AcknowledgmentFrame> frames = [];
lock (@lock) {
foreach (var range in pendingMessageIdRanges) {
frames.Add(new AcknowledgmentFrame(range.Min, range.Max));
}
pendingMessageIdRanges.Clear();
}
return frames;
}
}

View File

@@ -0,0 +1,23 @@
using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Utils.Rpc.Message;
sealed class MessageHandler<TMessageBase>(IMessageReceiver<TMessageBase> messageReceiver, IMessageReplySender replySender) {
public IMessageReceiver<TMessageBase> Receiver => messageReceiver;
public void OnPing() {
messageReceiver.OnPing();
}
public ValueTask SendEmptyReply(uint messageId, CancellationToken cancellationToken) {
return replySender.SendEmptyReply(messageId, cancellationToken);
}
public ValueTask SendReply<TReply>(uint messageId, TReply reply, CancellationToken cancellationToken) {
return replySender.SendReply(messageId, reply, cancellationToken);
}
public ValueTask SendError(uint messageId, RpcError error, CancellationToken cancellationToken) {
return replySender.SendError(messageId, error, cancellationToken);
}
}

View File

@@ -6,7 +6,6 @@ sealed class MessageReceiveTracker {
private readonly RangeSet<uint> receivedMessageIds = new ();
public bool ReceiveMessage(uint messageId) {
// TODO reset on session change and invalidate replies
lock (receivedMessageIds) {
return receivedMessageIds.Add(messageId);
}

View File

@@ -9,7 +9,7 @@ namespace Phantom.Utils.Rpc.Message;
public sealed class MessageRegistry<TMessageBase>(ILogger logger) {
private readonly Dictionary<Type, ushort> typeToCodeMapping = new ();
private readonly Dictionary<ushort, Type> codeToTypeMapping = new ();
private readonly Dictionary<ushort, Func<uint, ReadOnlyMemory<byte>, RpcMessageHandler<TMessageBase>, CancellationToken, Task>> codeToHandlerMapping = new ();
private readonly Dictionary<ushort, Func<uint, ReadOnlyMemory<byte>, MessageHandler<TMessageBase>, CancellationToken, Task>> codeToHandlerMapping = new ();
public void Add<TMessage>(ushort code) where TMessage : TMessageBase {
if (HasReplyType(typeof(TMessage))) {
@@ -50,7 +50,7 @@ public sealed class MessageRegistry<TMessageBase>(ILogger logger) {
}
}
internal async Task Handle(MessageFrame frame, RpcMessageHandler<TMessageBase> handler, CancellationToken cancellationToken) {
internal async Task Handle(MessageFrame frame, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) {
uint messageId = frame.MessageId;
if (codeToHandlerMapping.TryGetValue(frame.RegistryCode, out var action)) {
@@ -62,26 +62,35 @@ public sealed class MessageRegistry<TMessageBase>(ILogger logger) {
}
}
private async Task DeserializationHandler<TMessage>(uint messageId, ReadOnlyMemory<byte> serializedMessage, RpcMessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase {
private async Task DeserializationHandler<TMessage>(uint messageId, ReadOnlyMemory<byte> serializedMessage, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase {
TMessage message;
try {
message = RpcSerialization.Deserialize<TMessage>(serializedMessage);
} catch (Exception e) {
logger.Error(e, "Could not deserialize message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.SendError(messageId, RpcError.MessageDeserializationError, cancellationToken);
await OnMessageDeserializationError<TMessage>(messageId, e, handler, cancellationToken);
return;
}
handler.Receiver.OnMessage(message);
try {
handler.Receiver.OnMessage(message);
} catch (Exception e) {
await OnMessageHandlingError<TMessage>(messageId, e, handler, cancellationToken);
return;
}
try {
await handler.SendEmptyReply(messageId, cancellationToken);
} catch (Exception e) {
await OnMessageReplyingError<TMessage>(messageId, e, handler, cancellationToken);
}
}
private async Task DeserializationHandler<TMessage, TReply>(uint messageId, ReadOnlyMemory<byte> serializedMessage, RpcMessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase, ICanReply<TReply> {
private async Task DeserializationHandler<TMessage, TReply>(uint messageId, ReadOnlyMemory<byte> serializedMessage, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase, ICanReply<TReply> {
TMessage message;
try {
message = RpcSerialization.Deserialize<TMessage>(serializedMessage);
} catch (Exception e) {
logger.Error(e, "Could not deserialize message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.SendError(messageId, RpcError.MessageDeserializationError, cancellationToken);
await OnMessageDeserializationError<TMessage>(messageId, e, handler, cancellationToken);
return;
}
@@ -89,16 +98,29 @@ public sealed class MessageRegistry<TMessageBase>(ILogger logger) {
try {
reply = await handler.Receiver.OnMessage<TMessage, TReply>(message, cancellationToken);
} catch (Exception e) {
logger.Error(e, "Could not handle message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.SendError(messageId, RpcError.MessageHandlingError, cancellationToken);
await OnMessageHandlingError<TMessage>(messageId, e, handler, cancellationToken);
return;
}
try {
await handler.SendReply(messageId, reply, cancellationToken);
} catch (Exception e) {
logger.Error(e, "Could not reply to message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.SendError(messageId, RpcError.MessageHandlingError, cancellationToken);
await OnMessageReplyingError<TMessage>(messageId, e, handler, cancellationToken);
}
}
private async Task OnMessageDeserializationError<TMessage>(uint messageId, Exception exception, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase {
logger.Error(exception, "Could not deserialize message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.SendError(messageId, RpcError.MessageDeserializationError, cancellationToken);
}
private async Task OnMessageHandlingError<TMessage>(uint messageId, Exception exception, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase {
logger.Error(exception, "Could not handle message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.SendError(messageId, RpcError.MessageHandlingError, cancellationToken);
}
private async Task OnMessageReplyingError<TMessage>(uint messageId, Exception exception, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase {
logger.Error(exception, "Could not reply to message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.SendError(messageId, RpcError.MessageReplyingError, cancellationToken);
}
}

View File

@@ -20,30 +20,30 @@ sealed class MessageReplyTracker {
public async Task<TReply> WaitForReply<TReply>(uint messageId, TimeSpan waitForReplyTime, CancellationToken cancellationToken) {
if (!replyTasks.TryGetValue(messageId, out var completionSource)) {
logger.Warning("No reply callback for id {MessageId}.", messageId);
throw new ArgumentException("No reply callback for id: " + messageId, nameof(messageId));
logger.Warning("No reply callback for message {MessageId}.", messageId);
throw new ArgumentException("No reply callback for message: " + messageId, nameof(messageId));
}
try {
ReadOnlyMemory<byte> serializedReply = await completionSource.Task.WaitAsync(waitForReplyTime, cancellationToken);
return RpcSerialization.Deserialize<TReply>(serializedReply);
} catch (TimeoutException) {
logger.Debug("Timed out waiting for reply with id {MessageId}.", messageId);
logger.Debug("Timed out waiting for reply with message {MessageId}.", messageId);
throw;
} catch (OperationCanceledException) {
logger.Debug("Cancelled waiting for reply with id {MessageId}.", messageId);
logger.Debug("Cancelled waiting for reply with message {MessageId}.", messageId);
throw;
} catch (Exception e) {
logger.Warning(e, "Error processing reply with id {MessageId}.", messageId);
logger.Warning(e, "Error processing reply with message {MessageId}.", messageId);
throw;
} finally {
ForgetReply(messageId);
}
}
public void ForgetReply(uint messageId) {
public void ReceiveReply(uint messageId, ReadOnlyMemory<byte> serializedReply) {
if (replyTasks.TryRemove(messageId, out var task)) {
task.SetCanceled();
task.SetResult(serializedReply);
}
}
@@ -53,12 +53,9 @@ sealed class MessageReplyTracker {
}
}
public void ReceiveReply(uint messageId, ReadOnlyMemory<byte> serializedReply) {
public void ForgetReply(uint messageId) {
if (replyTasks.TryRemove(messageId, out var task)) {
task.SetResult(serializedReply);
}
else {
logger.Warning("Received a reply with id {MessageId} but no registered callback.", messageId);
task.SetCanceled();
}
}
}

View File

@@ -0,0 +1,117 @@
using System.Threading.Channels;
using Phantom.Utils.Actor;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame.Types;
using Phantom.Utils.Rpc.Runtime;
using Serilog;
namespace Phantom.Utils.Rpc.Message;
public sealed class MessageSender<TMessageBase> {
private readonly ILogger logger;
private readonly MessageRegistry<TMessageBase> messageRegistry;
private readonly MessageReplyTracker messageReplyTracker;
private uint nextMessageId;
private readonly Channel<PreparedMessage> messageQueue;
private readonly Task messageQueueTask;
private readonly CancellationTokenSource shutdownCancellationTokenSource = new ();
internal MessageSender(string loggerName, MessageRegistry<TMessageBase> messageRegistry, RpcCommonConnectionParameters connectionParameters) {
this.logger = PhantomLogger.Create<MessageSender<TMessageBase>>(loggerName);
this.messageRegistry = messageRegistry;
this.messageReplyTracker = new MessageReplyTracker(loggerName);
this.messageQueue = Channel.CreateBounded<PreparedMessage>(new BoundedChannelOptions(connectionParameters.MessageQueueCapacity) {
AllowSynchronousContinuations = false,
FullMode = BoundedChannelFullMode.Wait,
SingleReader = true,
SingleWriter = false,
});
this.messageQueueTask = ProcessQueue();
}
public bool TrySend<TMessage>(TMessage message) where TMessage : TMessageBase {
return messageQueue.Writer.TryWrite(PrepareMessage(message));
}
public async ValueTask Send<TMessage>(TMessage message, CancellationToken cancellationToken = default) where TMessage : TMessageBase {
await messageQueue.Writer.WriteAsync(PrepareMessage(message), cancellationToken);
}
public async Task<TReply> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken cancellationToken) where TMessage : TMessageBase, ICanReply<TReply> {
var preparedMessage = PrepareMessage(message);
var messageId = preparedMessage.MessageId;
messageReplyTracker.RegisterReply(messageId);
try {
await messageQueue.Writer.WriteAsync(preparedMessage, cancellationToken);
} catch (Exception) {
messageReplyTracker.ForgetReply(messageId);
throw;
}
return await messageReplyTracker.WaitForReply<TReply>(messageId, waitForReplyTime, cancellationToken);
}
private PreparedMessage PrepareMessage(TMessageBase message) {
uint messageId = Interlocked.Increment(ref nextMessageId);
return new PreparedMessage(messageId, message);
}
private readonly record struct PreparedMessage(uint MessageId, TMessageBase Message);
private async Task ProcessQueue() {
CancellationToken cancellationToken = shutdownCancellationTokenSource.Token;
Queue<PreparedMessage> messagesInTransit = new (capacity: 10);
while (await messageQueue.Reader.WaitToReadAsync(cancellationToken)) {
do {
while (messagesInTransit.Count < messagesInTransit.Capacity) {
if (messageQueue.Reader.TryRead(out var nextMessage)) {
messagesInTransit.Enqueue(nextMessage);
}
else {
break;
}
}
RpcFrameSender<TMessageBase> frameSender;
foreach ((uint messageId, TMessageBase message) in messagesInTransit) {
await frameSender.SendFrame(messageRegistry.CreateFrame(messageId, message), cancellationToken);
}
} while (messagesInTransit.Count > 0);
}
}
internal void ReceiveAcknowledgment(AcknowledgmentFrame frame) {
// TODO wait for reply instead?
}
internal void ReceiveReply(ReplyFrame frame) {
messageReplyTracker.ReceiveReply(frame.ReplyingToMessageId, frame.SerializedReply);
}
internal void ReceiveError(ErrorFrame frame) {
messageReplyTracker.FailReply(frame.ReplyingToMessageId, RpcErrorException.From(frame.Error));
}
internal async Task Close() {
messageQueue.Writer.TryComplete();
try {
await messageQueueTask.WaitAsync(TimeSpan.FromSeconds(15));
} catch (TimeoutException) {
logger.Warning("Could not finish processing message queue before timeout, forcibly shutting it down.");
await shutdownCancellationTokenSource.CancelAsync();
} catch (Exception) {
// Ignore.
}
messageQueueTask.Dispose();
shutdownCancellationTokenSource.Dispose();
}
}

View File

@@ -11,29 +11,21 @@ public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> :
return connection == null ? null : new RpcClient<TClientToServerMessage, TServerToClientMessage>(loggerName, connectionParameters, messageDefinitions, connector, connection);
}
private readonly string loggerName;
private readonly ILogger logger;
private readonly MessageRegistry<TServerToClientMessage> messageRegistry;
private readonly MessageReceiveTracker messageReceiveTracker = new ();
private readonly RpcClientToServerConnection connection;
private readonly RpcClientToServerConnection<TClientToServerMessage, TServerToClientMessage> connection;
private readonly CancellationTokenSource shutdownCancellationTokenSource = new ();
public RpcSendChannel<TClientToServerMessage> SendChannel { get; }
public MessageSender<TClientToServerMessage> MessageSender { get; }
private RpcClient(string loggerName, RpcClientConnectionParameters connectionParameters, IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions, RpcClientToServerConnector connector, RpcClientToServerConnector.Connection connection) {
this.loggerName = loggerName;
this.logger = PhantomLogger.Create<RpcClient<TClientToServerMessage, TServerToClientMessage>>(loggerName);
this.messageRegistry = messageDefinitions.ToClient;
this.connection = new RpcClientToServerConnection(loggerName, connector, connection);
this.SendChannel = new RpcSendChannel<TClientToServerMessage>(loggerName, connectionParameters.Common, this.connection, messageDefinitions.ToServer);
this.connection = new RpcClientToServerConnection<TClientToServerMessage, TServerToClientMessage>(loggerName, messageDefinitions.ToClient, connectionParameters.Common, connector, connection);
this.MessageSender = new MessageSender<TClientToServerMessage>(loggerName, messageDefinitions.ToServer, connectionParameters.Common);
}
public async Task Listen(IMessageReceiver<TServerToClientMessage> receiver) {
var messageHandler = new RpcMessageHandler<TServerToClientMessage>(receiver, SendChannel);
var frameReader = new RpcFrameReader<TClientToServerMessage, TServerToClientMessage>(loggerName, messageRegistry, messageReceiveTracker, messageHandler, SendChannel);
public async Task Listen(IMessageReceiver<TServerToClientMessage> messageReceiver) {
try {
await connection.ReadConnection(frameReader, shutdownCancellationTokenSource.Token);
await connection.ReadConnection(MessageSender, messageReceiver, shutdownCancellationTokenSource.Token);
} catch (OperationCanceledException) {
// Ignore.
}
@@ -43,9 +35,9 @@ public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> :
logger.Information("Shutting down client...");
try {
await SendChannel.Close();
await MessageSender.Close();
} catch (Exception e) {
logger.Error(e, "Caught exception while closing send channel.");
logger.Error(e, "Caught exception while closing message sender.");
}
try {
@@ -61,6 +53,5 @@ public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> :
public void Dispose() {
connection.Dispose();
SendChannel.Dispose();
}
}

View File

@@ -3,15 +3,12 @@ using Phantom.Utils.Rpc.Runtime.Tls;
namespace Phantom.Utils.Rpc.Runtime.Client;
public readonly record struct RpcClientConnectionParameters(
public sealed record RpcClientConnectionParameters(
string Host,
ushort Port,
string DistinguishedName,
RpcCertificateThumbprint CertificateThumbprint,
AuthToken AuthToken,
IRpcClientHandshake Handshake,
ushort SendQueueCapacity,
TimeSpan PingInterval
) {
internal RpcCommonConnectionParameters Common => new (SendQueueCapacity, PingInterval);
}
RpcCommonConnectionParameters CommonParameters
);

View File

@@ -1,12 +1,19 @@
using System.Net.Sockets;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame;
using Phantom.Utils.Rpc.Message;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Client;
sealed class RpcClientToServerConnection(string loggerName, RpcClientToServerConnector connector, RpcClientToServerConnector.Connection initialConnection) : IRpcConnectionProvider, IDisposable {
private readonly ILogger logger = PhantomLogger.Create<RpcClientToServerConnection>(loggerName);
sealed class RpcClientToServerConnection<TClientToServerMessage, TServerToClientMessage>(
string loggerName,
MessageRegistry<TServerToClientMessage> messageRegistry,
RpcCommonConnectionParameters connectionParameters,
RpcClientToServerConnector connector,
RpcClientToServerConnector.Connection initialConnection
) : IRpcConnectionProvider, IDisposable {
private readonly ILogger logger = PhantomLogger.Create<RpcClientToServerConnection<TClientToServerMessage, TServerToClientMessage>>(loggerName);
private readonly SemaphoreSlim semaphore = new (1);
private RpcClientToServerConnector.Connection currentConnection = initialConnection;
@@ -30,9 +37,11 @@ sealed class RpcClientToServerConnection(string loggerName, RpcClientToServerCon
}
}
public async Task ReadConnection(IFrameReader frameReader, CancellationToken cancellationToken) {
public async Task ReadConnection(MessageSender<TClientToServerMessage> messageSender, IMessageReceiver<TServerToClientMessage> messageReceiver, CancellationToken cancellationToken) {
RpcClientToServerConnector.Connection? connection = null;
var sessionState = NewSessionState(messageSender, messageReceiver);
try {
while (true) {
connection?.Dispose();
@@ -47,8 +56,13 @@ sealed class RpcClientToServerConnection(string loggerName, RpcClientToServerCon
continue;
}
if (connection.RestartSession) {
await sessionState.FrameSender.ShutdownNow();
sessionState = NewSessionState(messageSender, messageReceiver);
}
try {
await IFrame.ReadFrom(connection.Stream, frameReader, cancellationToken);
await IFrame.ReadFrom(connection.Stream, sessionState.FrameReader, cancellationToken);
} catch (OperationCanceledException) {
throw;
} catch (EndOfStreamException) {
@@ -66,6 +80,12 @@ sealed class RpcClientToServerConnection(string loggerName, RpcClientToServerCon
}
}
} finally {
try {
await sessionState.FrameSender.Shutdown();
} catch (Exception e) {
logger.Error(e, "Caught exception while closing frame sender.");
}
if (connection != null) {
try {
await connection.Disconnect();
@@ -76,6 +96,15 @@ sealed class RpcClientToServerConnection(string loggerName, RpcClientToServerCon
}
}
private SessionState NewSessionState(MessageSender<TClientToServerMessage> messageSender, IMessageReceiver<TServerToClientMessage> messageReceiver) {
var frameSender = new RpcFrameSender<TClientToServerMessage>(loggerName, connectionParameters, this);
var messageHandler = new MessageHandler<TServerToClientMessage>(messageReceiver, frameSender);
var frameReader = new RpcFrameReader<TClientToServerMessage, TServerToClientMessage>(loggerName, connectionParameters, messageRegistry, messageHandler, messageSender, frameSender);
return new SessionState(frameSender, frameReader);
}
private readonly record struct SessionState(RpcFrameSender<TClientToServerMessage> FrameSender, RpcFrameReader<TClientToServerMessage, TServerToClientMessage> FrameReader);
public void StopReconnecting() {
newConnectionCancellationTokenSource.Cancel();
}

View File

@@ -4,12 +4,13 @@ using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using Phantom.Utils.Collections;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Handshake;
using Phantom.Utils.Rpc.Runtime.Tls;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Client;
internal sealed class RpcClientToServerConnector {
sealed class RpcClientToServerConnector {
private static readonly TimeSpan InitialRetryDelay = TimeSpan.FromMilliseconds(100);
private static readonly TimeSpan MaximumRetryDelay = TimeSpan.FromSeconds(30);
private static readonly TimeSpan DisconnectTimeout = TimeSpan.FromSeconds(10);
@@ -112,9 +113,10 @@ internal sealed class RpcClientToServerConnector {
try {
stream = new SslStream(new NetworkStream(clientSocket, ownsSocket: false), leaveInnerStreamOpen: false);
if (await FinalizeConnection(stream, cancellationToken)) {
var handshakeResult = await FinalizeConnection(stream, cancellationToken);
if (handshakeResult != RpcFinalHandshakeResult.Error) {
logger.Information("Connected to {Host}:{Port}.", parameters.Host, parameters.Port);
return new Connection(clientSocket, stream);
return new Connection(clientSocket, stream, RestartSession: handshakeResult == RpcFinalHandshakeResult.NewSession);
}
} catch (Exception e) {
logger.Error(e, "Caught unhandled exception.");
@@ -130,7 +132,7 @@ internal sealed class RpcClientToServerConnector {
return null;
}
private async Task<bool> FinalizeConnection(SslStream stream, CancellationToken cancellationToken) {
private async Task<RpcFinalHandshakeResult> FinalizeConnection(SslStream stream, CancellationToken cancellationToken) {
try {
loggedCertificateValidationError = false;
await stream.AuthenticateAsClientAsync(sslOptions, cancellationToken);
@@ -139,7 +141,7 @@ internal sealed class RpcClientToServerConnector {
logger.Error(e, "Could not establish a secure connection.");
}
return false;
return RpcFinalHandshakeResult.Error;
}
logger.Information("Established a secure connection.");
@@ -148,30 +150,31 @@ internal sealed class RpcClientToServerConnector {
return await PerformApplicationHandshake(stream, cancellationToken);
} catch (EndOfStreamException) {
logger.Warning("Could not perform application handshake, connection lost.");
return false;
return RpcFinalHandshakeResult.Error;
} catch (Exception e) {
logger.Warning(e, "Could not perform application handshake.");
return false;
return RpcFinalHandshakeResult.Error;
}
}
private async Task<bool> PerformApplicationHandshake(Stream stream, CancellationToken cancellationToken) {
private async Task<RpcFinalHandshakeResult> PerformApplicationHandshake(Stream stream, CancellationToken cancellationToken) {
await RpcSerialization.WriteAuthToken(parameters.AuthToken, stream, cancellationToken);
if (await RpcSerialization.ReadByte(stream, cancellationToken) != 1) {
logger.Error("Server rejected authorization token.");
return false;
return RpcFinalHandshakeResult.Error;
}
await RpcSerialization.WriteGuid(sessionId, stream, cancellationToken);
await parameters.Handshake.Perform(stream, cancellationToken);
if (await RpcSerialization.ReadByte(stream, cancellationToken) != 1) {
RpcFinalHandshakeResult finalHandshakeResult = (RpcFinalHandshakeResult) await RpcSerialization.ReadByte(stream, cancellationToken);
if (finalHandshakeResult == RpcFinalHandshakeResult.Error) {
logger.Error("Server rejected client due to unknown error.");
return false;
}
return true;
return finalHandshakeResult;
}
private bool ValidateServerCertificate(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors) {
@@ -207,7 +210,7 @@ internal sealed class RpcClientToServerConnector {
await socket.DisconnectAsync(reuseSocket: false, timeoutTokenSource.Token);
}
internal sealed record Connection(Socket Socket, Stream Stream) : IDisposable {
internal sealed record Connection(Socket Socket, Stream Stream, bool RestartSession) : IDisposable {
public async Task Disconnect() {
await DisconnectSocket(Socket, Stream);
}

View File

@@ -1,6 +1,8 @@
namespace Phantom.Utils.Rpc.Runtime;
readonly record struct RpcCommonConnectionParameters(
ushort SendQueueCapacity,
public sealed record RpcCommonConnectionParameters(
ushort MessageQueueCapacity,
ushort FrameQueueCapacity,
ushort MaxConcurrentlyHandledMessages,
TimeSpan PingInterval
);

View File

@@ -6,5 +6,6 @@ enum RpcError : byte {
MessageTooLarge = 2,
MessageDeserializationError = 3,
MessageHandlingError = 4,
MessageAlreadyHandled = 5,
MessageReplyingError = 5,
ReplyTooLarge = 6,
}

View File

@@ -3,13 +3,14 @@
sealed class RpcErrorException : Exception {
internal static RpcErrorException From(RpcError error) {
return error switch {
RpcError.InvalidData => new RpcErrorException("Invalid data", error),
RpcError.UnknownMessageRegistryCode => new RpcErrorException("Unknown message registry code", error),
RpcError.MessageTooLarge => new RpcErrorException("Message is too large", error),
RpcError.MessageDeserializationError => new RpcErrorException("Message deserialization error", error),
RpcError.MessageHandlingError => new RpcErrorException("Message handling error", error),
RpcError.MessageAlreadyHandled => new RpcErrorException("Message already handled", error),
_ => new RpcErrorException("Unknown error", error),
RpcError.InvalidData => new RpcErrorException("Invalid data.", error),
RpcError.UnknownMessageRegistryCode => new RpcErrorException("Unknown message registry code.", error),
RpcError.MessageTooLarge => new RpcErrorException("Message is too large.", error),
RpcError.MessageDeserializationError => new RpcErrorException("Message deserialization error.", error),
RpcError.MessageHandlingError => new RpcErrorException("Message handling error.", error),
RpcError.MessageReplyingError => new RpcErrorException("Message replying error.", error),
RpcError.ReplyTooLarge => new RpcErrorException("Reply is too large.", error),
_ => new RpcErrorException("Unknown error.", error),
};
}

View File

@@ -8,43 +8,71 @@ namespace Phantom.Utils.Rpc.Runtime;
sealed class RpcFrameReader<TSentMessage, TReceivedMessage>(
string loggerName,
RpcCommonConnectionParameters connectionParameters,
MessageRegistry<TReceivedMessage> messageRegistry,
MessageReceiveTracker messageReceiveTracker,
RpcMessageHandler<TReceivedMessage> messageHandler,
RpcSendChannel<TSentMessage> sendChannel
MessageHandler<TReceivedMessage> messageHandler,
MessageSender<TSentMessage> messageSender,
RpcFrameSender<TSentMessage> frameSender
) : IFrameReader {
private readonly ILogger logger = PhantomLogger.Create<RpcFrameReader<TSentMessage, TReceivedMessage>>(loggerName);
private readonly ushort maxConcurrentlyHandledMessages = connectionParameters.MaxConcurrentlyHandledMessages;
private readonly SemaphoreSlim messageHandlingSemaphore = new (connectionParameters.MaxConcurrentlyHandledMessages);
public ValueTask OnPingFrame(DateTimeOffset pingTime, CancellationToken cancellationToken) {
messageHandler.OnPing();
return sendChannel.SendPong(pingTime, cancellationToken);
return frameSender.SendPong(pingTime, cancellationToken);
}
public void OnPongFrame(PongFrame frame) {
sendChannel.ReceivePong(frame);
frameSender.ReceivePong(frame);
}
public Task OnMessageFrame(MessageFrame frame, CancellationToken cancellationToken) {
if (!messageReceiveTracker.ReceiveMessage(frame.MessageId)) {
public async Task OnMessageFrame(MessageFrame frame, CancellationToken cancellationToken) {
if (!frameSender.ReceiveMessage(frame)) {
logger.Warning("Received duplicate message {MessageId}.", frame.MessageId);
return messageHandler.SendError(frame.MessageId, RpcError.MessageAlreadyHandled, cancellationToken).AsTask();
return;
}
if (messageRegistry.TryGetType(frame, out var messageType)) {
logger.Verbose("Received message {MesageId} of type {MessageType} ({Bytes} B).", frame.MessageId, messageType.Name, frame.SerializedMessage.Length);
}
return messageRegistry.Handle(frame, messageHandler, cancellationToken);
Task acquireSemaphore = messageHandlingSemaphore.WaitAsync(cancellationToken);
try {
if (!acquireSemaphore.IsCompleted) {
logger.Warning("Reached limit for concurrently handled messages ({Limit}).", maxConcurrentlyHandledMessages);
}
await acquireSemaphore;
_ = HandleMessage(frame, cancellationToken);
} catch (Exception) {
messageHandlingSemaphore.Release();
throw;
}
}
private async Task HandleMessage(MessageFrame frame, CancellationToken cancellationToken) {
try {
await messageRegistry.Handle(frame, messageHandler, cancellationToken);
} finally {
messageHandlingSemaphore.Release();
}
}
public void OnAcknowledgmentFrame(AcknowledgmentFrame frame) {
logger.Verbose("Received acknowledgment of messages {FirstMessageId}-{LastMessageId}.", frame.FirstMessageId, frame.LastMessageId);
messageSender.ReceiveAcknowledgment(frame);
}
public void OnReplyFrame(ReplyFrame frame) {
logger.Verbose("Received reply to message {MesageId} ({Bytes} B).", frame.ReplyingToMessageId, frame.SerializedReply.Length);
sendChannel.ReceiveReply(frame);
messageSender.ReceiveReply(frame);
}
public void OnErrorFrame(ErrorFrame frame) {
logger.Warning("Received error response to message {MesageId}: {Error}", frame.ReplyingToMessageId, frame.Error);
sendChannel.ReceiveError(frame.ReplyingToMessageId, frame.Error);
messageSender.ReceiveError(frame);
}
public void OnUnknownFrameId(byte frameId) {

View File

@@ -0,0 +1,176 @@
using System.Diagnostics.CodeAnalysis;
using System.Threading.Channels;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame;
using Phantom.Utils.Rpc.Frame.Types;
using Phantom.Utils.Rpc.Message;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime;
sealed class RpcFrameSender<TMessageBase> : IMessageReplySender {
private readonly ILogger logger;
private readonly IRpcConnectionProvider connectionProvider;
private readonly MessageReceiveTracker messageReceiveTracker = new ();
private readonly MessageAcknowledgmentQueue acknowledgmentQueue = new ();
private readonly Channel<IFrame> frameQueue;
private readonly Task frameQueueTask;
private readonly Task pingTask;
private readonly Task acknowledgementTask;
private readonly CancellationTokenSource sendQueueCancellationTokenSource = new ();
private readonly CancellationTokenSource pingCancellationTokenSource = new ();
private readonly CancellationTokenSource acknowledgementCancellationTokenSource = new ();
private TaskCompletionSource<DateTimeOffset>? pongTask;
internal RpcFrameSender(string loggerName, RpcCommonConnectionParameters connectionParameters, IRpcConnectionProvider connectionProvider) {
this.logger = PhantomLogger.Create<RpcFrameSender<TMessageBase>>(loggerName);
this.connectionProvider = connectionProvider;
this.frameQueue = Channel.CreateBounded<IFrame>(new BoundedChannelOptions(connectionParameters.FrameQueueCapacity) {
AllowSynchronousContinuations = false,
FullMode = BoundedChannelFullMode.Wait,
SingleReader = true,
SingleWriter = false,
});
this.frameQueueTask = ProcessQueue();
this.pingTask = PingSchedule(connectionParameters.PingInterval);
this.acknowledgementTask = AcknowledgementSchedule();
}
public async ValueTask SendPong(DateTimeOffset pingTime, CancellationToken cancellationToken) {
await SendFrame(new PongFrame(pingTime), cancellationToken);
}
async ValueTask IMessageReplySender.SendEmptyReply(uint replyingToMessageId, CancellationToken cancellationToken) {
await SendFrame(new ReplyFrame(replyingToMessageId, ReadOnlyMemory<byte>.Empty), cancellationToken);
}
async ValueTask IMessageReplySender.SendReply<TReply>(uint replyingToMessageId, TReply reply, CancellationToken cancellationToken) {
await SendFrame(new ReplyFrame(replyingToMessageId, RpcSerialization.Serialize(reply)), cancellationToken);
}
async ValueTask IMessageReplySender.SendError(uint replyingToMessageId, RpcError error, CancellationToken cancellationToken) {
await SendFrame(new ErrorFrame(replyingToMessageId, error), cancellationToken);
}
public async ValueTask SendFrame(IFrame frame, CancellationToken cancellationToken) {
if (!frameQueue.Writer.TryWrite(frame)) {
logger.Warning("Send queue is full!");
await frameQueue.Writer.WriteAsync(frame, cancellationToken);
}
}
private async Task ProcessQueue() {
CancellationToken cancellationToken = sendQueueCancellationTokenSource.Token;
await foreach (IFrame frame in frameQueue.Reader.ReadAllAsync(cancellationToken)) {
while (true) {
try {
Stream stream = await connectionProvider.GetStream(cancellationToken);
await stream.WriteAsync(frame.FrameType, cancellationToken);
await frame.Write(stream, cancellationToken);
await stream.FlushAsync(cancellationToken);
break;
} catch (OperationCanceledException) {
throw;
} catch (Exception) {
// Retry.
}
}
}
}
[SuppressMessage("ReSharper", "FunctionNeverReturns")]
private async Task PingSchedule(TimeSpan interval) {
CancellationToken cancellationToken = pingCancellationTokenSource.Token;
while (true) {
await Task.Delay(interval, cancellationToken);
pongTask = new TaskCompletionSource<DateTimeOffset>();
if (!frameQueue.Writer.TryWrite(PingFrame.Instance)) {
cancellationToken.ThrowIfCancellationRequested();
logger.Warning("Skipped a ping due to a full queue.");
continue;
}
DateTimeOffset pingTime = await pongTask.Task.WaitAsync(cancellationToken);
DateTimeOffset currentTime = DateTimeOffset.UtcNow;
TimeSpan roundTripTime = currentTime - pingTime;
logger.Information("Received pong, round trip time: {RoundTripTime} ms", (long) roundTripTime.TotalMilliseconds);
}
}
[SuppressMessage("ReSharper", "FunctionNeverReturns")]
private async Task AcknowledgementSchedule() {
CancellationToken cancellationToken = acknowledgementCancellationTokenSource.Token;
TimeSpan interval = TimeSpan.FromSeconds(1);
while (true) {
await acknowledgmentQueue.Wait(cancellationToken);
await Task.Delay(interval, cancellationToken);
foreach (var acknowledgmentFrame in acknowledgmentQueue.Drain()) {
await SendFrame(acknowledgmentFrame, cancellationToken);
}
}
}
public bool ReceiveMessage(MessageFrame frame) {
acknowledgmentQueue.Enqueue(frame.MessageId);
return messageReceiveTracker.ReceiveMessage(frame.MessageId);
}
public void ReceivePong(PongFrame frame) {
pongTask?.TrySetResult(frame.PingTime);
}
public async Task Shutdown() {
await pingCancellationTokenSource.CancelAsync();
await acknowledgementCancellationTokenSource.CancelAsync();
frameQueue.Writer.TryComplete();
await pingTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
await acknowledgementTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
try {
await frameQueueTask.WaitAsync(TimeSpan.FromSeconds(15));
} catch (TimeoutException) {
logger.Warning("Could not finish processing frame queue before timeout, forcibly shutting it down.");
await sendQueueCancellationTokenSource.CancelAsync();
} catch (Exception) {
// Ignore.
}
Dispose();
}
public async Task ShutdownNow() {
await pingCancellationTokenSource.CancelAsync();
await acknowledgementCancellationTokenSource.CancelAsync();
await sendQueueCancellationTokenSource.CancelAsync();
frameQueue.Writer.TryComplete();
await pingTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
await acknowledgementTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
await frameQueueTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
Dispose();
}
private void Dispose() {
frameQueueTask.Dispose();
pingTask.Dispose();
acknowledgementTask.Dispose();
sendQueueCancellationTokenSource.Dispose();
pingCancellationTokenSource.Dispose();
acknowledgementCancellationTokenSource.Dispose();
}
}

View File

@@ -1,19 +0,0 @@
using Phantom.Utils.Rpc.Message;
namespace Phantom.Utils.Rpc.Runtime;
sealed class RpcMessageHandler<TMessageBase>(IMessageReceiver<TMessageBase> receiver, IRpcReplySender replySender) {
public IMessageReceiver<TMessageBase> Receiver => receiver;
public void OnPing() {
receiver.OnPing();
}
public ValueTask SendReply<TReply>(uint messageId, TReply reply, CancellationToken cancellationToken) {
return replySender.SendReply(messageId, reply, cancellationToken);
}
public ValueTask SendError(uint messageId, RpcError error, CancellationToken cancellationToken) {
return replySender.SendError(messageId, error, cancellationToken);
}
}

View File

@@ -1,168 +0,0 @@
using System.Diagnostics.CodeAnalysis;
using System.Threading.Channels;
using Phantom.Utils.Actor;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame;
using Phantom.Utils.Rpc.Frame.Types;
using Phantom.Utils.Rpc.Message;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime;
public sealed class RpcSendChannel<TMessageBase> : IRpcReplySender, IDisposable {
private readonly ILogger logger;
private readonly IRpcConnectionProvider connectionProvider;
private readonly MessageRegistry<TMessageBase> messageRegistry;
private readonly MessageReplyTracker messageReplyTracker;
private readonly Channel<IFrame> sendQueue;
private readonly Task sendQueueTask;
private readonly Task pingTask;
private readonly CancellationTokenSource sendQueueCancellationTokenSource = new ();
private readonly CancellationTokenSource pingCancellationTokenSource = new ();
private uint nextMessageId;
private TaskCompletionSource<DateTimeOffset>? pongTask;
internal RpcSendChannel(string loggerName, RpcCommonConnectionParameters connectionParameters, IRpcConnectionProvider connectionProvider, MessageRegistry<TMessageBase> messageRegistry) {
this.logger = PhantomLogger.Create<RpcSendChannel<TMessageBase>>(loggerName);
this.connectionProvider = connectionProvider;
this.messageRegistry = messageRegistry;
this.messageReplyTracker = new MessageReplyTracker(loggerName);
this.sendQueue = Channel.CreateBounded<IFrame>(new BoundedChannelOptions(connectionParameters.SendQueueCapacity) {
AllowSynchronousContinuations = false,
FullMode = BoundedChannelFullMode.Wait,
SingleReader = true,
SingleWriter = false,
});
this.sendQueueTask = ProcessSendQueue();
this.pingTask = Ping(connectionParameters.PingInterval);
}
internal async ValueTask SendPong(DateTimeOffset pingTime, CancellationToken cancellationToken) {
await SendFrame(new PongFrame(pingTime), cancellationToken);
}
public bool TrySendMessage<TMessage>(TMessage message) where TMessage : TMessageBase {
return sendQueue.Writer.TryWrite(NextMessageFrame(message));
}
public async ValueTask SendMessage<TMessage>(TMessage message, CancellationToken cancellationToken = default) where TMessage : TMessageBase {
await SendFrame(NextMessageFrame(message), cancellationToken);
}
public async Task<TReply> SendMessage<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken cancellationToken) where TMessage : TMessageBase, ICanReply<TReply> {
MessageFrame frame = NextMessageFrame(message);
uint messageId = frame.MessageId;
messageReplyTracker.RegisterReply(messageId);
try {
await SendFrame(frame, cancellationToken);
} catch (Exception) {
messageReplyTracker.ForgetReply(messageId);
throw;
}
return await messageReplyTracker.WaitForReply<TReply>(messageId, waitForReplyTime, cancellationToken);
}
async ValueTask IRpcReplySender.SendReply<TReply>(uint replyingToMessageId, TReply reply, CancellationToken cancellationToken) {
await SendFrame(new ReplyFrame(replyingToMessageId, RpcSerialization.Serialize(reply)), cancellationToken);
}
async ValueTask IRpcReplySender.SendError(uint replyingToMessageId, RpcError error, CancellationToken cancellationToken) {
await SendFrame(new ErrorFrame(replyingToMessageId, error), cancellationToken);
}
private async ValueTask SendFrame(IFrame frame, CancellationToken cancellationToken) {
if (!sendQueue.Writer.TryWrite(frame)) {
logger.Warning("Send queue is full!");
await sendQueue.Writer.WriteAsync(frame, cancellationToken);
}
}
private MessageFrame NextMessageFrame<TMessage>(TMessage message) where TMessage : TMessageBase {
uint messageId = Interlocked.Increment(ref nextMessageId);
return messageRegistry.CreateFrame(messageId, message);
}
private async Task ProcessSendQueue() {
CancellationToken cancellationToken = sendQueueCancellationTokenSource.Token;
await foreach (IFrame frame in sendQueue.Reader.ReadAllAsync(cancellationToken)) {
while (true) {
try {
Stream stream = await connectionProvider.GetStream(cancellationToken);
await stream.WriteAsync(frame.FrameType, cancellationToken);
await frame.Write(stream, cancellationToken);
await stream.FlushAsync(cancellationToken);
break;
} catch (OperationCanceledException) {
throw;
} catch (Exception) {
// Retry.
}
}
}
}
[SuppressMessage("ReSharper", "FunctionNeverReturns")]
private async Task Ping(TimeSpan interval) {
CancellationToken cancellationToken = pingCancellationTokenSource.Token;
while (true) {
await Task.Delay(interval, cancellationToken);
pongTask = new TaskCompletionSource<DateTimeOffset>();
if (!sendQueue.Writer.TryWrite(PingFrame.Instance)) {
cancellationToken.ThrowIfCancellationRequested();
logger.Warning("Skipped a ping due to a full queue.");
continue;
}
DateTimeOffset pingTime = await pongTask.Task.WaitAsync(cancellationToken);
DateTimeOffset currentTime = DateTimeOffset.UtcNow;
TimeSpan roundTripTime = currentTime - pingTime;
logger.Information("Received pong, round trip time: {RoundTripTime} ms", (long) roundTripTime.TotalMilliseconds);
}
}
internal void ReceivePong(PongFrame frame) {
pongTask?.TrySetResult(frame.PingTime);
}
internal void ReceiveReply(ReplyFrame frame) {
messageReplyTracker.ReceiveReply(frame.ReplyingToMessageId, frame.SerializedReply);
}
internal void ReceiveError(uint messageId, RpcError error) {
messageReplyTracker.FailReply(messageId, RpcErrorException.From(error));
}
internal async Task Close() {
await pingCancellationTokenSource.CancelAsync();
sendQueue.Writer.TryComplete();
await pingTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
try {
await sendQueueTask.WaitAsync(TimeSpan.FromSeconds(15));
} catch (TimeoutException) {
logger.Warning("Could not finish processing send queue before timeout, forcibly shutting it down.");
await sendQueueCancellationTokenSource.CancelAsync();
} catch (Exception) {
// Ignore.
}
}
public void Dispose() {
sendQueueTask.Dispose();
sendQueueCancellationTokenSource.Dispose();
pingCancellationTokenSource.Dispose();
}
}

View File

@@ -91,6 +91,12 @@ public static class RpcSerialization {
return buffer;
}
public static async ValueTask<ReadOnlyMemory<byte>> ReadBytes(uint length, Stream stream, CancellationToken cancellationToken) {
Memory<byte> buffer = new byte[length];
await stream.ReadExactlyAsync(buffer, cancellationToken);
return buffer;
}
public static ReadOnlyMemory<byte> Serialize<T>(T value) {
var buffer = new ArrayBufferWriter<byte>();
MemoryPackSerializer.Serialize(buffer, value, SerializerOptions);

View File

@@ -25,7 +25,7 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, TH
public async Task<bool> Run(CancellationToken shutdownToken) {
EndPoint endPoint = connectionParameters.EndPoint;
SslServerAuthenticationOptions sslOptions = new () {
var sslOptions = new SslServerAuthenticationOptions {
AllowRenegotiation = false,
AllowTlsResume = true,
CertificateRevocationCheckMode = X509RevocationMode.NoCheck,
@@ -35,6 +35,15 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, TH
ServerCertificate = connectionParameters.Certificate.Certificate,
};
var serverData = new SharedData(
connectionParameters.Common,
connectionParameters.AuthToken,
messageDefinitions.ToServer,
clientHandshake,
clientRegistrar,
clientSessions
);
try {
using var serverSocket = new Socket(endPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
@@ -51,7 +60,7 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, TH
while (true) {
Socket clientSocket = await serverSocket.AcceptAsync(shutdownToken);
clients.Add(new Client(loggerName, messageDefinitions, clientHandshake, clientRegistrar, clientSessions, clientSocket, sslOptions, connectionParameters.AuthToken, shutdownToken));
clients.Add(new Client(loggerName, serverData, clientSocket, sslOptions, shutdownToken));
clients.RemoveAll(static client => client.Task.IsCompleted);
}
} catch (OperationCanceledException) {
@@ -83,6 +92,15 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, TH
logger.Information("Server stopped.");
}
private readonly record struct SharedData(
RpcCommonConnectionParameters ConnectionParameters,
AuthToken AuthToken,
MessageRegistry<TClientToServerMessage> MessageRegistry,
IRpcServerClientHandshake<THandshakeResult> ClientHandshake,
IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage, THandshakeResult> ClientRegistrar,
RpcServerClientSessions<TServerToClientMessage> ClientSessions
);
private sealed class Client {
private static TimeSpan DisconnectTimeout => TimeSpan.FromSeconds(10);
@@ -99,34 +117,22 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, TH
public Task Task { get; }
private ILogger logger;
private readonly IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions;
private readonly IRpcServerClientHandshake<THandshakeResult> clientHandshake;
private readonly IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage, THandshakeResult> clientRegistrar;
private readonly RpcServerClientSessions<TServerToClientMessage> clientSessions;
private readonly SharedData sharedData;
private readonly Socket socket;
private readonly SslServerAuthenticationOptions sslOptions;
private readonly AuthToken authToken;
private readonly CancellationToken shutdownToken;
public Client(
string serverLoggerName,
IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions,
IRpcServerClientHandshake<THandshakeResult> clientHandshake,
IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage, THandshakeResult> clientRegistrar,
RpcServerClientSessions<TServerToClientMessage> clientSessions,
SharedData sharedData,
Socket socket,
SslServerAuthenticationOptions sslOptions,
AuthToken authToken,
CancellationToken shutdownToken
) {
this.logger = PhantomLogger.Create<RpcServer<TClientToServerMessage, TServerToClientMessage, THandshakeResult>, Client>(PhantomLogger.ConcatNames(serverLoggerName, GetAddressDescriptor(socket)));
this.messageDefinitions = messageDefinitions;
this.clientHandshake = clientHandshake;
this.clientRegistrar = clientRegistrar;
this.clientSessions = clientSessions;
this.sharedData = sharedData;
this.socket = socket;
this.sslOptions = sslOptions;
this.authToken = authToken;
this.shutdownToken = shutdownToken;
this.Task = Run();
@@ -207,7 +213,7 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, TH
try {
var suppliedAuthToken = await RpcSerialization.ReadAuthToken(stream, cancellationToken);
if (!authToken.FixedTimeEquals(suppliedAuthToken)) {
if (!sharedData.AuthToken.FixedTimeEquals(suppliedAuthToken)) {
logger.Warning("Rejected client, invalid authorization token.");
await RpcSerialization.WriteByte(value: 0, stream, cancellationToken);
return null;
@@ -217,17 +223,17 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, TH
}
var sessionId = await RpcSerialization.ReadGuid(stream, cancellationToken);
var session = clientSessions.GetOrCreateSession(sessionId);
var session = sharedData.ClientSessions.GetOrCreateSession(sessionId);
logger.Information("Client connected with session {SessionId}, new logger name: {LoggerName}", sessionId, session.LoggerName);
logger.Information("Client connected with session {SessionId}, new logger name: {LoggerName}", session.SessionId, session.LoggerName);
logger = PhantomLogger.Create<RpcServer<TClientToServerMessage, TServerToClientMessage, THandshakeResult>, Client>(session.LoggerName);
EstablishedConnection? establishedConnection;
switch (await clientHandshake.Perform(stream, cancellationToken)) {
switch (await sharedData.ClientHandshake.Perform(stream, cancellationToken)) {
case Left<THandshakeResult, Exception>(var handshakeResult):
try {
var connection = new RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage>(clientSessions, sessionId, messageDefinitions.ToServer, stream, session);
var messageReceiver = clientRegistrar.Register(connection, handshakeResult);
var connection = new RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage>(sharedData.ConnectionParameters, sharedData.MessageRegistry, sharedData.ClientSessions, session, stream);
var messageReceiver = sharedData.ClientRegistrar.Register(connection, handshakeResult);
establishedConnection = new EstablishedConnection(session, connection, messageReceiver);
} catch (Exception e) {
@@ -247,14 +253,17 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, TH
break;
}
RpcFinalHandshakeResult finalHandshakeResult;
if (establishedConnection == null) {
await RpcSerialization.WriteByte(value: 0, stream, cancellationToken);
return null;
finalHandshakeResult = RpcFinalHandshakeResult.Error;
}
else {
await RpcSerialization.WriteByte(value: 1, stream, cancellationToken);
return establishedConnection;
bool isNewSession = session.MarkFirstTimeUse();
finalHandshakeResult = isNewSession ? RpcFinalHandshakeResult.NewSession : RpcFinalHandshakeResult.ReusedSession;
}
await RpcSerialization.WriteByte((byte) finalHandshakeResult, stream, cancellationToken);
return establishedConnection;
} catch (OperationCanceledException) {
throw;
} catch (EndOfStreamException) {

View File

@@ -1,18 +1,31 @@
using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Message;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Server;
sealed class RpcServerClientSession<TServerToClientMessage> : IRpcConnectionProvider {
public string LoggerName { get; }
public RpcSendChannel<TServerToClientMessage> SendChannel { get; }
public MessageReceiveTracker MessageReceiveTracker { get; }
private readonly ILogger logger;
public string LoggerName { get; }
public Guid SessionId { get; }
public MessageSender<TServerToClientMessage> MessageSender { get; }
public RpcFrameSender<TServerToClientMessage> FrameSender { get; }
private bool isNew = true;
private TaskCompletionSource<Stream> nextStream = new ();
public RpcServerClientSession(string loggerName, RpcCommonConnectionParameters connectionParameters, MessageRegistry<TServerToClientMessage> messageRegistry) {
public RpcServerClientSession(string loggerName, Guid sessionId, RpcCommonConnectionParameters connectionParameters, MessageRegistry<TServerToClientMessage> messageRegistry) {
this.logger = PhantomLogger.Create<RpcServerClientSession<TServerToClientMessage>>(loggerName);
this.LoggerName = loggerName;
this.SendChannel = new RpcSendChannel<TServerToClientMessage>(loggerName, connectionParameters, this, messageRegistry);
this.MessageReceiveTracker = new MessageReceiveTracker();
this.SessionId = sessionId;
this.MessageSender = new MessageSender<TServerToClientMessage>(loggerName, messageRegistry, connectionParameters);
this.FrameSender = new RpcFrameSender<TServerToClientMessage>(loggerName, connectionParameters, this);
}
/// <returns>Whether this was a new session. If it was a new session, it will be marked as used.</returns>
public bool MarkFirstTimeUse() {
return Interlocked.CompareExchange(ref isNew, value: true, comparand: false);
}
public void OnConnected(Stream stream) {
@@ -39,7 +52,7 @@ sealed class RpcServerClientSession<TServerToClientMessage> : IRpcConnectionProv
}
}
public Task Close() {
public async Task Close() {
lock (this) {
if (!nextStream.TrySetCanceled()) {
nextStream = new TaskCompletionSource<Stream>();
@@ -47,6 +60,16 @@ sealed class RpcServerClientSession<TServerToClientMessage> : IRpcConnectionProv
}
}
return SendChannel.Close();
try {
await MessageSender.Close();
} catch (Exception e) {
logger.Error(e, "Caught exception while closing message sender.");
}
try {
await FrameSender.Shutdown();
} catch (Exception e) {
logger.Error(e, "Caught exception while closing send channel.");
}
}
}

View File

@@ -29,7 +29,7 @@ sealed class RpcServerClientSessions<TServerToClientMessage> {
}
private RpcServerClientSession<TServerToClientMessage> CreateSession(Guid sessionId, string loggerName) {
return new RpcServerClientSession<TServerToClientMessage>(loggerName, connectionParameters, messageRegistry);
return new RpcServerClientSession<TServerToClientMessage>(loggerName, sessionId, connectionParameters, messageRegistry);
}
public Task CloseSession(Guid sessionId) {

View File

@@ -4,12 +4,9 @@ using Phantom.Utils.Rpc.Runtime.Tls;
namespace Phantom.Utils.Rpc.Runtime.Server;
public readonly record struct RpcServerConnectionParameters(
public sealed record RpcServerConnectionParameters(
EndPoint EndPoint,
RpcServerCertificate Certificate,
AuthToken AuthToken,
ushort SendQueueCapacity,
TimeSpan PingInterval
) {
internal RpcCommonConnectionParameters Common => new (SendQueueCapacity, PingInterval);
}
RpcCommonConnectionParameters CommonParameters
);

View File

@@ -6,32 +6,29 @@ using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Server;
public sealed class RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage> {
private readonly string loggerName;
private readonly ILogger logger;
private readonly RpcServerClientSessions<TServerToClientMessage> sessions;
private readonly RpcCommonConnectionParameters connectionParameters;
private readonly MessageRegistry<TClientToServerMessage> messageRegistry;
private readonly MessageReceiveTracker messageReceiveTracker;
private readonly RpcServerClientSessions<TServerToClientMessage> sessions;
private readonly RpcServerClientSession<TServerToClientMessage> session;
private readonly Stream stream;
private readonly CancellationTokenSource closeCancellationTokenSource = new ();
public Guid SessionId { get; }
public RpcSendChannel<TServerToClientMessage> SendChannel { get; }
public Guid SessionId => session.SessionId;
public MessageSender<TServerToClientMessage> MessageSender => session.MessageSender;
internal RpcServerToClientConnection(RpcServerClientSessions<TServerToClientMessage> sessions, Guid sessionId, MessageRegistry<TClientToServerMessage> messageRegistry, Stream stream, RpcServerClientSession<TServerToClientMessage> session) {
this.loggerName = session.LoggerName;
this.logger = PhantomLogger.Create<RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage>>(loggerName);
this.sessions = sessions;
internal RpcServerToClientConnection(RpcCommonConnectionParameters connectionParameters, MessageRegistry<TClientToServerMessage> messageRegistry, RpcServerClientSessions<TServerToClientMessage> sessions, RpcServerClientSession<TServerToClientMessage> session, Stream stream) {
this.logger = PhantomLogger.Create<RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage>>(session.LoggerName);
this.connectionParameters = connectionParameters;
this.messageRegistry = messageRegistry;
this.messageReceiveTracker = session.MessageReceiveTracker;
this.sessions = sessions;
this.session = session;
this.stream = stream;
this.SessionId = sessionId;
this.SendChannel = session.SendChannel;
}
internal async Task Listen(IMessageReceiver<TClientToServerMessage> receiver) {
var messageHandler = new RpcMessageHandler<TClientToServerMessage>(receiver, SendChannel);
var frameReader = new RpcFrameReader<TServerToClientMessage, TClientToServerMessage>(loggerName, messageRegistry, messageReceiveTracker, messageHandler, SendChannel);
internal async Task Listen(IMessageReceiver<TClientToServerMessage> messageReceiver) {
var messageHandler = new MessageHandler<TClientToServerMessage>(messageReceiver, session.FrameSender);
var frameReader = new RpcFrameReader<TServerToClientMessage, TClientToServerMessage>(session.LoggerName, connectionParameters, messageRegistry, messageHandler, MessageSender, session.FrameSender);
try {
await IFrame.ReadFrom(stream, frameReader, closeCancellationTokenSource.Token);
} catch (OperationCanceledException) {

View File

@@ -38,7 +38,15 @@ public sealed class RangeSet<T> : IEnumerable<RangeSet<T>.Range> where T : IBina
return true;
}
public IEnumerator<Range> GetEnumerator() {
public void Clear() {
ranges.Clear();
}
public List<Range>.Enumerator GetEnumerator() {
return ranges.GetEnumerator();
}
IEnumerator<Range> IEnumerable<Range>.GetEnumerator() {
return ranges.GetEnumerator();
}

View File

@@ -1,19 +1,19 @@
using Phantom.Common.Messages.Web;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime;
using Phantom.Utils.Rpc.Message;
namespace Phantom.Web.Services.Rpc;
public sealed class ControllerConnection(RpcSendChannel<IMessageToController> connection) {
public sealed class ControllerConnection(MessageSender<IMessageToController> sender) {
public ValueTask Send<TMessage>(TMessage message) where TMessage : IMessageToController {
return connection.SendMessage(message);
return sender.Send(message);
}
public Task<TReply> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken = default) where TMessage : IMessageToController, ICanReply<TReply> {
return connection.SendMessage<TMessage, TReply>(message, waitForReplyTime, waitForReplyCancellationToken);
return sender.Send<TMessage, TReply>(message, waitForReplyTime, waitForReplyCancellationToken);
}
public Task<TReply> Send<TMessage, TReply>(TMessage message, CancellationToken waitForReplyCancellationToken) where TMessage : IMessageToController, ICanReply<TReply> {
return connection.SendMessage<TMessage, TReply>(message, Timeout.InfiniteTimeSpan, waitForReplyCancellationToken);
return sender.Send<TMessage, TReply>(message, Timeout.InfiniteTimeSpan, waitForReplyCancellationToken);
}
}

View File

@@ -6,6 +6,7 @@ using Phantom.Utils.Cryptography;
using Phantom.Utils.IO;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Runtime;
using Phantom.Utils.Rpc.Runtime.Client;
using Phantom.Utils.Runtime;
using Phantom.Utils.Threading;
@@ -59,8 +60,12 @@ try {
CertificateThumbprint: webKey.Value.CertificateThumbprint,
AuthToken: webKey.Value.AuthToken,
Handshake: new IRpcClientHandshake.NoOp(),
SendQueueCapacity: 500,
PingInterval: TimeSpan.FromSeconds(10)
CommonParameters: new RpcCommonConnectionParameters(
MessageQueueCapacity: 250,
FrameQueueCapacity: 500,
MaxConcurrentlyHandledMessages: 100,
PingInterval: TimeSpan.FromMinutes(1)
)
);
using var rpcClient = await RpcClient<IMessageToController, IMessageToWeb>.Connect("Controller", rpcClientConnectionParameters, WebMessageRegistries.Definitions, shutdownCancellationToken);
@@ -70,7 +75,7 @@ try {
}
var webConfiguration = new WebLauncher.Configuration(PhantomLogger.Create("Web"), webServerHost, webServerPort, webBasePath, dataProtectionKeysPath, shutdownCancellationToken);
var webApplication = WebLauncher.CreateApplication(webConfiguration, applicationProperties, rpcClient.SendChannel);
var webApplication = WebLauncher.CreateApplication(webConfiguration, applicationProperties, rpcClient.MessageSender);
using var actorSystem = ActorSystemFactory.Create("Web");
@@ -98,7 +103,7 @@ try {
PhantomLogger.Root.Information("Unregistering web...");
try {
using var unregisterCancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(10));
await rpcClient.SendChannel.SendMessage(new UnregisterWebMessage(), unregisterCancellationTokenSource.Token);
await rpcClient.MessageSender.Send(new UnregisterWebMessage(), unregisterCancellationTokenSource.Token);
} catch (OperationCanceledException) {
PhantomLogger.Root.Warning("Could not unregister web after shutdown.");
} catch (Exception e) {

View File

@@ -1,6 +1,6 @@
using Microsoft.AspNetCore.DataProtection;
using Phantom.Common.Messages.Web;
using Phantom.Utils.Rpc.Runtime;
using Phantom.Utils.Rpc.Message;
using Phantom.Web.Services;
using Serilog;
using ILogger = Serilog.ILogger;
@@ -12,7 +12,7 @@ static class WebLauncher {
public string HttpUrl => "http://" + Host + ":" + Port;
}
internal static WebApplication CreateApplication(Configuration config, ApplicationProperties applicationProperties, RpcSendChannel<IMessageToController> sendChannel) {
internal static WebApplication CreateApplication(Configuration config, ApplicationProperties applicationProperties, MessageSender<IMessageToController> sendChannel) {
var assembly = typeof(WebLauncher).Assembly;
var builder = WebApplication.CreateBuilder(new WebApplicationOptions {
ApplicationName = assembly.GetName().Name,