using NetMQ.Sockets;
using Phantom.Utils.Rpc;
using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Tasks;
using Serilog;
using Serilog.Events;

namespace Phantom.Controller.Rpc;

/// <summary>
/// Non-generic entry point for starting the controller's RPC server runtime.
/// </summary>
public static class RpcRuntime {
	/// <summary>
	/// Creates the generic server runtime for the given message definitions and launches it.
	/// </summary>
	/// <param name="config">RPC configuration forwarded to the server runtime.</param>
	/// <param name="messageDefinitions">Definitions of the incoming, outgoing and reply messages.</param>
	/// <param name="listenerFactory">Creates the incoming-message listener for each newly registered client connection.</param>
	/// <param name="cancellationToken">Token that stops the server's receive loop.</param>
	/// <returns>The task returned by the generic runtime's <c>Launch</c>.</returns>
	public static Task Launch<TOutgoingListener, TIncomingListener, TReplyMessage>(RpcConfiguration config, IMessageDefinitions<TOutgoingListener, TIncomingListener, TReplyMessage> messageDefinitions, Func<RpcClientConnection<TOutgoingListener>, TIncomingListener> listenerFactory, CancellationToken cancellationToken) where TReplyMessage : IMessage<TOutgoingListener, NoReply>, IMessage<TIncomingListener, NoReply>
		=> RpcRuntime<TOutgoingListener, TIncomingListener, TReplyMessage>.Launch(config, messageDefinitions, listenerFactory, cancellationToken);
}

/// <summary>
/// Server-side RPC runtime built on a NetMQ <see cref="ServerSocket"/>. It accepts messages from clients,
/// requires an unknown client to send a registration message first, and dispatches each message to a
/// per-client handler.
/// </summary>
internal sealed class RpcRuntime<TOutgoingListener, TIncomingListener, TReplyMessage> : RpcRuntime<ServerSocket> where TReplyMessage : IMessage<TOutgoingListener, NoReply>, IMessage<TIncomingListener, NoReply> {
	/// <summary>
	/// Creates a runtime instance and starts it through the base class's <c>Launch</c>.
	/// </summary>
	internal static Task Launch(RpcConfiguration config, IMessageDefinitions<TOutgoingListener, TIncomingListener, TReplyMessage> messageDefinitions, Func<RpcClientConnection<TOutgoingListener>, TIncomingListener> listenerFactory, CancellationToken cancellationToken) {
		return new RpcRuntime<TOutgoingListener, TIncomingListener, TReplyMessage>(config, messageDefinitions, listenerFactory, cancellationToken).Launch();
	}

	/// <summary>
	/// Creates the server socket and configures it as a CURVE server using the certificate from the configuration.
	/// </summary>
	private static ServerSocket CreateSocket(RpcConfiguration config) {
		var socket = new ServerSocket();
		var options = socket.Options;

		options.CurveServer = true;
		options.CurveCertificate = config.ServerCertificate;
		
		return socket;
	}
	
	private readonly RpcConfiguration config;
	private readonly IMessageDefinitions<TOutgoingListener, TIncomingListener, TReplyMessage> messageDefinitions;
	private readonly Func<RpcClientConnection<TOutgoingListener>, TIncomingListener> listenerFactory;
	private readonly CancellationToken cancellationToken;

	private RpcRuntime(RpcConfiguration config, IMessageDefinitions<TOutgoingListener, TIncomingListener, TReplyMessage> messageDefinitions, Func<RpcClientConnection<TOutgoingListener>, TIncomingListener> listenerFactory, CancellationToken cancellationToken) : base(config, CreateSocket(config)) {
		this.config = config;
		this.messageDefinitions = messageDefinitions;
		this.listenerFactory = listenerFactory;
		this.cancellationToken = cancellationToken;
	}

	/// <summary>
	/// Binds the server socket to the configured TCP URL and logs the start and completion of the bind.
	/// </summary>
	protected override void Connect(ServerSocket socket) {
		var logger = config.RuntimeLogger;
		var url = config.TcpUrl;

		logger.Information("Starting ZeroMQ server on {Url}...", url);
		socket.Bind(url);
		logger.Information("ZeroMQ server initialized, listening for connections on port {Port}.", config.Port);
	}

	/// <summary>
	/// Receive loop: reads messages until cancellation is requested and dispatches each one to the
	/// <see cref="Client"/> registered for its routing ID. A routing ID with no client may only send a
	/// registration message; that message creates the connection and client. Anything else from an
	/// unknown routing ID is dropped with a warning.
	/// </summary>
	protected override void Run(ServerSocket socket, MessageReplyTracker replyTracker, TaskManager taskManager) {
		var logger = config.RuntimeLogger;
		var clients = new Dictionary<ulong, Client>();

		// Forget a client once its connection reports it was closed, so its next message must re-register.
		// NOTE(review): if Closed can be raised from a thread other than this loop's, access to 'clients'
		// would need synchronization — confirm against RpcClientConnection.
		void OnConnectionClosed(object? sender, RpcClientConnectionClosedEventArgs e) {
			clients.Remove(e.RoutingId);
			logger.Debug("Closed connection to {RoutingId}.", e.RoutingId);
		}

		try {
			while (!cancellationToken.IsCancellationRequested) {
				var (routingId, data) = socket.Receive(cancellationToken);

				// An empty payload has no message type to resolve; it is only logged.
				if (data.Length == 0) {
					LogMessageType(logger, routingId, data);
					continue;
				}

				if (!clients.TryGetValue(routingId, out var client)) {
					if (!CheckIsRegistrationMessage(data, logger, routingId)) {
						continue;
					}

					var connection = new RpcClientConnection<TOutgoingListener>(socket, routingId, messageDefinitions.Outgoing, replyTracker);
					connection.Closed += OnConnectionClosed;

					client = new Client(connection, messageDefinitions, listenerFactory(connection), logger, taskManager, cancellationToken);
					clients[routingId] = client;
				}

				LogMessageType(logger, routingId, data);
				messageDefinitions.Incoming.Handle(data, client);
			}
		}
		finally {
			// Always detach the handler, even if the loop exits via an exception (for example, a
			// cancellation thrown by Receive). Otherwise live connections keep a handler that mutates
			// this method's 'clients' dictionary.
			foreach (var client in clients.Values) {
				client.Connection.Closed -= OnConnectionClosed;
			}
		}
	}

	/// <summary>
	/// Logs a received message at Verbose level. Includes the message type name when it can be resolved;
	/// otherwise logs only the size. Returns immediately when Verbose logging is disabled.
	/// </summary>
	private void LogMessageType(ILogger logger, uint routingId, ReadOnlyMemory<byte> data) {
		if (!logger.IsEnabled(LogEventLevel.Verbose)) {
			return;
		}

		if (data.Length > 0 && messageDefinitions.Incoming.TryGetType(data, out var type)) {
			logger.Verbose("Received {MessageType} ({Bytes} B) from {RoutingId}.", type.Name, data.Length, routingId);
		}
		else {
			logger.Verbose("Received {Bytes} B message from {RoutingId}.", data.Length, routingId);
		}
	}

	/// <summary>
	/// Returns true if <paramref name="data"/> resolves to a message type that the message definitions treat as a
	/// registration message. Otherwise logs a warning naming the received type ("unknown message" if
	/// unresolved) and returns false.
	/// </summary>
	private bool CheckIsRegistrationMessage(ReadOnlyMemory<byte> data, ILogger logger, uint routingId) {
		if (messageDefinitions.Incoming.TryGetType(data, out var type) && messageDefinitions.IsRegistrationMessage(type)) {
			return true;
		}

		logger.Warning("Received {MessageType} from {RoutingId} who is not registered.", type?.Name ?? "unknown message", routingId);
		return false;
	}
	
	/// <summary>
	/// Per-client message handler. It wraps the client's connection and sends replies back over that
	/// connection as reply messages built by the message definitions.
	/// </summary>
	private sealed class Client : MessageHandler<TIncomingListener> {
		public RpcClientConnection<TOutgoingListener> Connection { get; }
		
		private readonly IMessageDefinitions<TOutgoingListener, TIncomingListener, TReplyMessage> messageDefinitions;
		
		public Client(RpcClientConnection<TOutgoingListener> connection, IMessageDefinitions<TOutgoingListener, TIncomingListener, TReplyMessage> messageDefinitions, TIncomingListener listener, ILogger logger, TaskManager taskManager, CancellationToken cancellationToken) : base(listener, logger, taskManager, cancellationToken) {
			this.Connection = connection;
			this.messageDefinitions = messageDefinitions;
		}
	
		/// <summary>
		/// Wraps the serialized reply in a reply message for <paramref name="sequenceId"/> and sends it over the connection.
		/// </summary>
		protected override Task SendReply(uint sequenceId, byte[] serializedReply) {
			return Connection.Send(messageDefinitions.CreateReplyMessage(sequenceId, serializedReply));
		}
	}
}