diff --git a/Agent/Phantom.Agent.Rpc/ControllerConnection.cs b/Agent/Phantom.Agent.Rpc/ControllerConnection.cs index e0c453c..7943896 100644 --- a/Agent/Phantom.Agent.Rpc/ControllerConnection.cs +++ b/Agent/Phantom.Agent.Rpc/ControllerConnection.cs @@ -18,8 +18,4 @@ public sealed class ControllerConnection { public Task Send<TMessage>(TMessage message) where TMessage : IMessageToController { return connection.Send(message); } - - public Task<TReply?> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken) where TMessage : IMessageToController<TReply> where TReply : class { - return connection.Send<TMessage, TReply>(message, waitForReplyTime, waitForReplyCancellationToken); - } } diff --git a/Controller/Phantom.Controller.Rpc/RpcConnectionToClient.cs b/Controller/Phantom.Controller.Rpc/RpcConnectionToClient.cs index a56d167..8be4f1a 100644 --- a/Controller/Phantom.Controller.Rpc/RpcConnectionToClient.cs +++ b/Controller/Phantom.Controller.Rpc/RpcConnectionToClient.cs @@ -66,7 +66,7 @@ public sealed class RpcConnectionToClient<TListener> { } await socket.SendAsync(routingId, bytes); - return await messageReplyTracker.WaitForReply<TReply>(sequenceId, waitForReplyTime, waitForReplyCancellationToken); + return await messageReplyTracker.TryWaitForReply<TReply>(sequenceId, waitForReplyTime, waitForReplyCancellationToken); } public void Receive(IReply message) { diff --git a/Utils/Phantom.Utils.Rpc/Message/MessageReplyTracker.cs b/Utils/Phantom.Utils.Rpc/Message/MessageReplyTracker.cs index b302425..852d1aa 100644 --- a/Utils/Phantom.Utils.Rpc/Message/MessageReplyTracker.cs +++ b/Utils/Phantom.Utils.Rpc/Message/MessageReplyTracker.cs @@ -20,26 +20,36 @@ public sealed class MessageReplyTracker { return sequenceId; } - public async Task<TReply?> WaitForReply<TReply>(uint sequenceId, TimeSpan waitForReplyTime, CancellationToken cancellationToken) where TReply : class { + public async Task<TReply> WaitForReply<TReply>(uint sequenceId, TimeSpan waitForReplyTime, CancellationToken cancellationToken) { if (!replyTasks.TryGetValue(sequenceId, out var completionSource)) { logger.Warning("No reply callback for id {SequenceId}.", sequenceId); - return null; + throw new ArgumentException("No reply callback for id: " + sequenceId, nameof(sequenceId)); } try { byte[] replyBytes = await completionSource.Task.WaitAsync(waitForReplyTime, cancellationToken); return MessageSerializer.Deserialize<TReply>(replyBytes); } catch (TimeoutException) { - return null; + logger.Debug("Timed out waiting for reply with id {SequenceId}.", sequenceId); + throw; } catch (OperationCanceledException) { - return null; + logger.Debug("Cancelled waiting for reply with id {SequenceId}.", sequenceId); + throw; } catch (Exception e) { logger.Warning(e, "Error processing reply with id {SequenceId}.", sequenceId); - return null; + throw; } finally { ForgetReply(sequenceId); } } + + public async Task<TReply?> TryWaitForReply<TReply>(uint sequenceId, TimeSpan waitForReplyTime, CancellationToken cancellationToken) where TReply : class { + try { + return await WaitForReply<TReply>(sequenceId, waitForReplyTime, cancellationToken); + } catch (Exception) { + return null; + } + } public void ForgetReply(uint sequenceId) { if (replyTasks.TryRemove(sequenceId, out var task)) { diff --git a/Utils/Phantom.Utils.Rpc/RpcConnectionToServer.cs b/Utils/Phantom.Utils.Rpc/RpcConnectionToServer.cs index 9ac135c..7cc0331 100644 --- a/Utils/Phantom.Utils.Rpc/RpcConnectionToServer.cs +++ b/Utils/Phantom.Utils.Rpc/RpcConnectionToServer.cs @@ -22,7 +22,7 @@ public sealed class RpcConnectionToServer<TListener> { } } - public async Task<TReply?> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken) where TMessage : IMessage<TListener, TReply> where TReply : class { + public async Task<TReply?> TrySend<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken) where TMessage : IMessage<TListener, TReply> where TReply : class { var sequenceId = replyTracker.RegisterReply(); var bytes = messageRegistry.Write<TMessage, TReply>(sequenceId, message).ToArray(); @@ -31,6 +31,19 @@ public sealed class RpcConnectionToServer<TListener> { return null; } + await socket.SendAsync(bytes); + return await replyTracker.TryWaitForReply<TReply>(sequenceId, waitForReplyTime, waitForReplyCancellationToken); + } + + public async Task<TReply> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken) where TMessage : IMessage<TListener, TReply> { + var sequenceId = replyTracker.RegisterReply(); + + var bytes = messageRegistry.Write<TMessage, TReply>(sequenceId, message).ToArray(); + if (bytes.Length == 0) { + replyTracker.ForgetReply(sequenceId); + throw new ArgumentException("Could not write message.", nameof(message)); + } + await socket.SendAsync(bytes); return await replyTracker.WaitForReply<TReply>(sequenceId, waitForReplyTime, waitForReplyCancellationToken); }