using System.Collections.Immutable;
using Phantom.Common.Data;
using Phantom.Common.Data.Agent;
using Phantom.Common.Data.Replies;
using Phantom.Common.Logging;
using Phantom.Common.Messages;
using Phantom.Common.Messages.ToAgent;
using Phantom.Server.Database;
using Phantom.Server.Rpc;
using Phantom.Server.Services.Instances;
using Phantom.Utils.Collections;
using Phantom.Utils.Events;
using Phantom.Utils.Tasks;
using ILogger = Serilog.ILogger;

namespace Phantom.Server.Services.Agents;

/// <summary>
/// Keeps the server-side collection of <see cref="Agent"/> objects, keyed by GUID, and exposes
/// change notifications through <see cref="AgentsChanged"/>. Also runs a background loop that
/// replaces agents whose last ping is too old with their disconnected form.
/// </summary>
public sealed class AgentManager {
	private static readonly ILogger Logger = PhantomLogger.Create<AgentManager>();

	/// <summary>Delay between two passes of the status refresh loop.</summary>
	private static readonly TimeSpan DisconnectionRecheckInterval = TimeSpan.FromSeconds(5);

	/// <summary>Minimum time since the last ping after which an online agent is treated as disconnected.</summary>
	private static readonly TimeSpan DisconnectionThreshold = TimeSpan.FromSeconds(12);

	private readonly ObservableAgents agents = new (PhantomLogger.Create<AgentManager, ObservableAgents>());
	private readonly CancellationToken cancellationToken;
	private readonly AgentAuthToken authToken;
	private readonly DatabaseProvider databaseProvider;

	/// <summary>Subscribers of the observable agent state.</summary>
	public EventSubscribers<ImmutableArray<Agent>> AgentsChanged => agents.Subs;

	/// <summary>
	/// Stores the dependencies and starts the "Refresh agent status loop" task on <paramref name="taskManager"/>.
	/// </summary>
	/// <param name="configuration">Supplies the cancellation token used by all asynchronous operations of this manager.</param>
	/// <param name="authToken">Token that registering agents are compared against.</param>
	/// <param name="databaseProvider">Source of database scopes for loading and saving agents.</param>
	/// <param name="taskManager">Runs the background status refresh loop.</param>
	public AgentManager(ServiceConfiguration configuration, AgentAuthToken authToken, DatabaseProvider databaseProvider, TaskManager taskManager) {
		this.cancellationToken = configuration.CancellationToken;
		this.authToken = authToken;
		this.databaseProvider = databaseProvider;

		// All fields are assigned before the loop is started, because the loop reads them.
		taskManager.Run("Refresh agent status loop", RefreshAgentStatus);
	}

	/// <summary>
	/// Reads every agent entity from the database and adds it to the in-memory collection.
	/// An entry that already exists is only replaced when the existing agent is offline.
	/// </summary>
	/// <exception cref="InvalidOperationException">The collection refused to add or replace an agent.</exception>
	public async Task Initialize() {
		using var scope = databaseProvider.CreateScope();

		var storedAgents = scope.Ctx.Agents.AsAsyncEnumerable().WithCancellation(cancellationToken);
		await foreach (var entity in storedAgents) {
			var agent = new Agent(entity.AgentGuid, entity.Name, entity.ProtocolVersion, entity.BuildVersion, entity.MaxInstances, entity.MaxMemory);

			var wasRegistered = agents.ByGuid.AddOrReplaceIf(agent.Guid, agent, static existingAgent => existingAgent.IsOffline);
			if (wasRegistered) {
				continue;
			}

			// TODO
			throw new InvalidOperationException("Unable to register agent from database: " + agent.Guid);
		}
	}

	/// <summary>Returns an immutable dictionary of the current agents keyed by GUID.</summary>
	public ImmutableDictionary<Guid, Agent> GetAgents() => agents.ByGuid.ToImmutable();

	/// <summary>
	/// Handles a registration request from an agent.
	/// </summary>
	/// <remarks>
	/// On a token mismatch, a <see cref="RegisterAgentFailureMessage"/> is sent and nothing else happens.
	/// Otherwise the agent is placed into the in-memory collection (closing the connection of the agent it
	/// replaced, if any), its information is saved to the database, and a <see cref="RegisterAgentSuccessMessage"/>
	/// with the instance configurations for that agent is sent back.
	/// </remarks>
	/// <param name="authToken">Token presented by the agent.</param>
	/// <param name="agentInfo">Information the new <see cref="Agent"/> is created from.</param>
	/// <param name="instanceManager">Provides the instance configurations sent with the success message.</param>
	/// <param name="connection">Connection the request arrived on; replies are sent through it.</param>
	/// <returns><c>false</c> if the token did not match, <c>true</c> otherwise.</returns>
	internal async Task<bool> RegisterAgent(AgentAuthToken authToken, AgentInfo agentInfo, InstanceManager instanceManager, RpcClientConnection connection) {
		var isTokenValid = this.authToken.FixedTimeEquals(authToken);
		if (!isTokenValid) {
			await connection.Send(new RegisterAgentFailureMessage(RegisterAgentFailure.InvalidToken));
			return false;
		}

		var agent = new Agent(agentInfo) {
			LastPing = DateTimeOffset.Now,
			IsOnline = true,
			Connection = new AgentConnection(connection)
		};

		// NOTE(review): a true result presumably means an existing entry was replaced - confirm against RwLockedObservableDictionary.
		if (agents.ByGuid.AddOrReplace(agent.Guid, agent, out var previousAgent)) {
			previousAgent.Connection?.Close();
		}

		await SaveAgentToDatabase(agent);

		Logger.Information("Registered agent \"{Name}\" (GUID {Guid}).", agent.Name, agent.Guid);

		var configurations = await instanceManager.GetInstanceConfigurationsForAgent(agent.Guid);
		await connection.Send(new RegisterAgentSuccessMessage(configurations));

		return true;
	}

	/// <summary>
	/// Copies the agent's name, versions, and limits onto its upsert entity and saves the changes.
	/// The database scope is disposed as soon as the save completes.
	/// </summary>
	private async Task SaveAgentToDatabase(Agent agent) {
		using var scope = databaseProvider.CreateScope();

		var entity = scope.Ctx.AgentUpsert.Fetch(agent.Guid);
		entity.Name = agent.Name;
		entity.ProtocolVersion = agent.ProtocolVersion;
		entity.BuildVersion = agent.BuildVersion;
		entity.MaxInstances = agent.MaxInstances;
		entity.MaxMemory = agent.MaxMemory;

		await scope.Ctx.SaveChangesAsync(cancellationToken);
	}

	/// <summary>
	/// Replaces the agent with its offline form, but only if the agent's current connection
	/// is the same as <paramref name="connection"/>.
	/// </summary>
	/// <param name="agentGuid">GUID of the agent to unregister.</param>
	/// <param name="connection">Connection that asked for the agent to be unregistered.</param>
	/// <returns>The result of the conditional replacement; an information message is logged when it is <c>true</c>.</returns>
	internal bool UnregisterAgent(Guid agentGuid, RpcClientConnection connection) {
		var wasUnregistered = agents.ByGuid.TryReplaceIf(
			agentGuid,
			static current => current.AsOffline(),
			current => current.Connection?.IsSame(connection) == true
		);

		if (!wasUnregistered) {
			return false;
		}

		Logger.Information("Unregistered agent with GUID {Guid}.", agentGuid);
		return true;
	}

	/// <summary>Looks up an agent by GUID; returns <c>null</c> when the lookup fails.</summary>
	internal Agent? GetAgent(Guid guid) => agents.ByGuid.TryGetValue(guid, out var agent) ? agent : null;

	/// <summary>Replaces the agent with its online form, passing the current time to <c>AsOnline</c>.</summary>
	internal void NotifyAgentIsAlive(Guid agentGuid) => agents.ByGuid.TryReplace(agentGuid, static current => current.AsOnline(DateTimeOffset.Now));

	/// <summary>Replaces the agent with a copy whose <c>Stats</c> are built from the given values.</summary>
	/// <param name="agentGuid">GUID of the agent to update.</param>
	/// <param name="runningInstanceCount">First argument of the new <see cref="AgentStats"/>.</param>
	/// <param name="runningInstanceMemory">Second argument of the new <see cref="AgentStats"/>.</param>
	internal void SetAgentStats(Guid agentGuid, int runningInstanceCount, RamAllocationUnits runningInstanceMemory) {
		var stats = new AgentStats(runningInstanceCount, runningInstanceMemory);
		agents.ByGuid.TryReplace(agentGuid, current => current with { Stats = stats });
	}

	/// <summary>
	/// Background loop: after every <see cref="DisconnectionRecheckInterval"/>, replaces each online agent
	/// whose last ping is at least <see cref="DisconnectionThreshold"/> old with its disconnected form.
	/// Runs until cancellation is requested.
	/// </summary>
	private async Task RefreshAgentStatus() {
		while (!cancellationToken.IsCancellationRequested) {
			await Task.Delay(DisconnectionRecheckInterval, cancellationToken);

			// Take a single timestamp per pass so every agent is judged against the same instant.
			var checkTime = DateTimeOffset.Now;
			agents.ByGuid.ReplaceAllIf(MarkAgentAsOffline, candidate => candidate.IsOnline && candidate.LastPing is {} lastPing && checkTime - lastPing >= DisconnectionThreshold);
		}

		// Logs the lost connection and produces the replacement agent.
		static Agent MarkAgentAsOffline(Agent agent) {
			Logger.Warning("Lost connection to agent \"{Name}\" (GUID {Guid}).", agent.Name, agent.Guid);
			return agent.AsDisconnected();
		}
	}

	/// <summary>
	/// Sends a message over the connection of the agent with the given GUID and awaits the reply.
	/// </summary>
	/// <param name="guid">GUID of the target agent.</param>
	/// <param name="message">Message to send.</param>
	/// <param name="waitForReplyTime">Passed to the connection as the time to wait for a reply.</param>
	/// <returns>
	/// <c>null</c> if the agent lookup fails or the agent has no connection; otherwise whatever the connection's send returns.
	/// </returns>
	internal async Task<TReply?> SendMessage<TMessage, TReply>(Guid guid, TMessage message, TimeSpan waitForReplyTime) where TMessage : IMessageToAgent<TReply> where TReply : class {
		if (!agents.ByGuid.TryGetValue(guid, out var agent)) {
			// TODO handle missing agent?
			return null;
		}

		var connection = agent.Connection;
		if (connection == null) {
			// TODO handle missing agent?
			return null;
		}

		return await connection.Send<TMessage, TReply>(message, waitForReplyTime, cancellationToken);
	}

	/// <summary>
	/// Observable state whose data is a copy of the values held in <see cref="ByGuid"/>.
	/// </summary>
	private sealed class ObservableAgents : ObservableState<ImmutableArray<Agent>> {
		/// <summary>Agents keyed by GUID.</summary>
		public RwLockedObservableDictionary<Guid, Agent> ByGuid { get; } = new (LockRecursionPolicy.NoRecursion);

		/// <summary>Hooks the base class's <c>Update</c> to the dictionary's <c>CollectionChanged</c> event.</summary>
		public ObservableAgents(ILogger logger) : base(logger) {
			ByGuid.CollectionChanged += Update;
		}

		/// <summary>Returns the dictionary's <c>ValuesCopy</c>.</summary>
		protected override ImmutableArray<Agent> GetData() => ByGuid.ValuesCopy;
	}
}