diff --git a/Common/Phantom.Common.Messages.Web/ToWeb/RefreshUserSessionMessage.cs b/Common/Phantom.Common.Messages.Web/ToWeb/RefreshUserSessionMessage.cs new file mode 100644 index 0000000..fa84996 --- /dev/null +++ b/Common/Phantom.Common.Messages.Web/ToWeb/RefreshUserSessionMessage.cs @@ -0,0 +1,8 @@ +using MemoryPack; + +namespace Phantom.Common.Messages.Web.ToWeb; + +[MemoryPackable(GenerateType.VersionTolerant)] +public sealed partial record RefreshUserSessionMessage( + [property: MemoryPackOrder(0)] Guid UserGuid +) : IMessageToWeb; diff --git a/Common/Phantom.Common.Messages.Web/WebMessageRegistries.cs b/Common/Phantom.Common.Messages.Web/WebMessageRegistries.cs index b154994..e146230 100644 --- a/Common/Phantom.Common.Messages.Web/WebMessageRegistries.cs +++ b/Common/Phantom.Common.Messages.Web/WebMessageRegistries.cs @@ -48,6 +48,7 @@ public static class WebMessageRegistries { ToWeb.Add<RefreshAgentsMessage>(1); ToWeb.Add<RefreshInstancesMessage>(2); ToWeb.Add<InstanceOutputMessage>(3); + ToWeb.Add<RefreshUserSessionMessage>(4); ToWeb.Add<ReplyMessage>(127); } diff --git a/Controller/Phantom.Controller.Services/ControllerServices.cs b/Controller/Phantom.Controller.Services/ControllerServices.cs index bc21293..91dfbaa 100644 --- a/Controller/Phantom.Controller.Services/ControllerServices.cs +++ b/Controller/Phantom.Controller.Services/ControllerServices.cs @@ -54,9 +54,9 @@ public sealed class ControllerServices : IDisposable { this.MinecraftVersions = new MinecraftVersions(); this.AuthenticatedUserCache = new AuthenticatedUserCache(); - this.UserManager = new UserManager(AuthenticatedUserCache, dbProvider); + this.UserManager = new UserManager(AuthenticatedUserCache, ControllerState, dbProvider); this.RoleManager = new RoleManager(dbProvider); - this.UserRoleManager = new UserRoleManager(AuthenticatedUserCache, dbProvider); + this.UserRoleManager = new UserRoleManager(AuthenticatedUserCache, ControllerState, dbProvider); this.UserLoginManager = new UserLoginManager(AuthenticatedUserCache, UserManager, dbProvider); this.PermissionManager = new PermissionManager(dbProvider); diff --git a/Controller/Phantom.Controller.Services/ControllerState.cs b/Controller/Phantom.Controller.Services/ControllerState.cs index 795da1a..855702a 100644 --- a/Controller/Phantom.Controller.Services/ControllerState.cs +++ b/Controller/Phantom.Controller.Services/ControllerState.cs @@ -19,6 +19,8 @@ sealed class ControllerState { public ObservableState<ImmutableDictionary<Guid, ImmutableArray<TaggedJavaRuntime>>>.Receiver AgentJavaRuntimesByGuidReceiver => agentJavaRuntimesByGuid.ReceiverSide; public ObservableState<ImmutableDictionary<Guid, Instance>>.Receiver InstancesByGuidReceiver => instancesByGuid.ReceiverSide; + public event EventHandler<Guid>? UserUpdatedOrDeleted; + public void UpdateAgent(Agent agent) { agentsByGuid.PublisherSide.Publish(static (agentsByGuid, agent) => agentsByGuid.SetItem(agent.AgentGuid, agent), agent); } @@ -30,4 +32,8 @@ sealed class ControllerState { public void UpdateInstance(Instance instance) { instancesByGuid.PublisherSide.Publish(static (instancesByGuid, instance) => instancesByGuid.SetItem(instance.InstanceGuid, instance), instance); } + + public void UpdateOrDeleteUser(Guid userGuid) { + UserUpdatedOrDeleted?.Invoke(null, userGuid); + } } diff --git a/Controller/Phantom.Controller.Services/Rpc/WebMessageDataUpdateSenderActor.cs b/Controller/Phantom.Controller.Services/Rpc/WebMessageDataUpdateSenderActor.cs index 9787a5d..df1e4c1 100644 --- a/Controller/Phantom.Controller.Services/Rpc/WebMessageDataUpdateSenderActor.cs +++ b/Controller/Phantom.Controller.Services/Rpc/WebMessageDataUpdateSenderActor.cs @@ -30,22 +30,31 @@ sealed class WebMessageDataUpdateSenderActor : ReceiveActor<WebMessageDataUpdate ReceiveAsync<RefreshAgentsCommand>(RefreshAgents); ReceiveAsync<RefreshInstancesCommand>(RefreshInstances); ReceiveAsync<ReceiveInstanceLogsCommand>(ReceiveInstanceLogs); + ReceiveAsync<RefreshUserSessionCommand>(RefreshUserSession); } protected override void PreStart() { controllerState.AgentsByGuidReceiver.Register(SelfTyped, static state => new RefreshAgentsCommand(state)); controllerState.InstancesByGuidReceiver.Register(SelfTyped, static state => new RefreshInstancesCommand(state)); - + + controllerState.UserUpdatedOrDeleted += OnUserUpdatedOrDeleted; + instanceLogManager.LogsReceived += OnInstanceLogsReceived; } protected override void PostStop() { instanceLogManager.LogsReceived -= OnInstanceLogsReceived; + + controllerState.UserUpdatedOrDeleted -= OnUserUpdatedOrDeleted; controllerState.AgentsByGuidReceiver.Unregister(SelfTyped); controllerState.InstancesByGuidReceiver.Unregister(SelfTyped); } + private void OnUserUpdatedOrDeleted(object? sender, Guid userGuid) { + selfCached.Tell(new RefreshUserSessionCommand(userGuid)); + } + private void OnInstanceLogsReceived(object? sender, InstanceLogManager.Event e) { selfCached.Tell(new ReceiveInstanceLogsCommand(e.InstanceGuid, e.Lines)); } @@ -57,6 +66,8 @@ sealed class WebMessageDataUpdateSenderActor : ReceiveActor<WebMessageDataUpdate private sealed record RefreshInstancesCommand(ImmutableDictionary<Guid, Instance> Instances) : ICommand; private sealed record ReceiveInstanceLogsCommand(Guid InstanceGuid, ImmutableArray<string> Lines) : ICommand; + + private sealed record RefreshUserSessionCommand(Guid UserGuid) : ICommand; private Task RefreshAgents(RefreshAgentsCommand command) { return connection.Send(new RefreshAgentsMessage(command.Agents.Values.ToImmutableArray())); @@ -69,4 +80,8 @@ sealed class WebMessageDataUpdateSenderActor : ReceiveActor<WebMessageDataUpdate private Task ReceiveInstanceLogs(ReceiveInstanceLogsCommand command) { return connection.Send(new InstanceOutputMessage(command.InstanceGuid, command.Lines)); } + + private Task RefreshUserSession(RefreshUserSessionCommand command) { + return connection.Send(new RefreshUserSessionMessage(command.UserGuid)); + } } diff --git a/Controller/Phantom.Controller.Services/Users/UserManager.cs b/Controller/Phantom.Controller.Services/Users/UserManager.cs index 5045f59..4f54afe 100644 --- a/Controller/Phantom.Controller.Services/Users/UserManager.cs +++ b/Controller/Phantom.Controller.Services/Users/UserManager.cs @@ -14,10 +14,12 @@ sealed class UserManager { private static readonly ILogger Logger = PhantomLogger.Create<UserManager>(); private readonly AuthenticatedUserCache authenticatedUserCache; + private readonly ControllerState controllerState; private readonly IDbContextProvider dbProvider; - public UserManager(AuthenticatedUserCache authenticatedUserCache, IDbContextProvider dbProvider) { + public UserManager(AuthenticatedUserCache authenticatedUserCache, ControllerState controllerState, IDbContextProvider dbProvider) { this.authenticatedUserCache = authenticatedUserCache; + this.controllerState = controllerState; this.dbProvider = dbProvider; } @@ -140,6 +142,7 @@ sealed class UserManager { // In case the user logged in during deletion. authenticatedUserCache.Remove(userGuid); + controllerState.UpdateOrDeleteUser(userGuid); Logger.Information("Deleted user \"{Username}\" (GUID {Guid}).", user.Name, user.UserGuid); return DeleteUserResult.Deleted; diff --git a/Controller/Phantom.Controller.Services/Users/UserRoleManager.cs b/Controller/Phantom.Controller.Services/Users/UserRoleManager.cs index b71e870..3fa39cc 100644 --- a/Controller/Phantom.Controller.Services/Users/UserRoleManager.cs +++ b/Controller/Phantom.Controller.Services/Users/UserRoleManager.cs @@ -13,10 +13,12 @@ sealed class UserRoleManager { private static readonly ILogger Logger = PhantomLogger.Create<UserRoleManager>(); private readonly AuthenticatedUserCache authenticatedUserCache; + private readonly ControllerState controllerState; private readonly IDbContextProvider dbProvider; - public UserRoleManager(AuthenticatedUserCache authenticatedUserCache, IDbContextProvider dbProvider) { + public UserRoleManager(AuthenticatedUserCache authenticatedUserCache, ControllerState controllerState, IDbContextProvider dbProvider) { this.authenticatedUserCache = authenticatedUserCache; + this.controllerState = controllerState; this.dbProvider = dbProvider; } @@ -49,7 +51,7 @@ sealed class UserRoleManager { var removedFromRoleGuids = ImmutableHashSet.CreateBuilder<Guid>(); var removedFromRoleNames = new List<string>(); - + try { foreach (var roleGuid in addToRoleGuids) { if (rolesByGuid.TryGetValue(roleGuid, out var role)) { @@ -71,6 +73,7 @@ sealed class UserRoleManager { await db.Ctx.SaveChangesAsync(); await authenticatedUserCache.Update(user, db); + controllerState.UpdateOrDeleteUser(user.UserGuid); Logger.Information("Changed roles for user \"{Username}\" (GUID {Guid}).", user.Name, user.UserGuid); return new ChangeUserRolesResult(addedToRoleGuids.ToImmutable(), removedFromRoleGuids.ToImmutable()); diff --git a/Web/Phantom.Web.Services/Authentication/CustomAuthenticationStateProvider.cs b/Web/Phantom.Web.Services/Authentication/CustomAuthenticationStateProvider.cs index f4fdfb0..c1d9a59 100644 --- a/Web/Phantom.Web.Services/Authentication/CustomAuthenticationStateProvider.cs +++ b/Web/Phantom.Web.Services/Authentication/CustomAuthenticationStateProvider.cs @@ -4,42 +4,118 @@ using Microsoft.AspNetCore.Components.Server; using Phantom.Common.Data; using Phantom.Common.Data.Web.Users; using Phantom.Common.Messages.Web.ToController; +using Phantom.Utils.Logging; using Phantom.Web.Services.Rpc; +using ILogger = Serilog.ILogger; namespace Phantom.Web.Services.Authentication; -public sealed class CustomAuthenticationStateProvider : ServerAuthenticationStateProvider { +public sealed class CustomAuthenticationStateProvider : ServerAuthenticationStateProvider, IAsyncDisposable { + private static readonly ILogger Logger = PhantomLogger.Create<CustomAuthenticationStateProvider>(); + + private readonly UserSessionRefreshManager sessionRefreshManager; private readonly UserSessionBrowserStorage sessionBrowserStorage; private readonly ControllerConnection controllerConnection; - private bool isLoaded; - public CustomAuthenticationStateProvider(UserSessionBrowserStorage sessionBrowserStorage, ControllerConnection controllerConnection) { + private readonly SemaphoreSlim loadSemaphore = new (1); + private bool isLoaded = false; + private CancellationTokenSource? loadCancellationTokenSource; + private UserSessionRefreshManager.EventHolder? userRefreshEventHolder; + + public CustomAuthenticationStateProvider(UserSessionRefreshManager sessionRefreshManager, UserSessionBrowserStorage sessionBrowserStorage, ControllerConnection controllerConnection) { + this.sessionRefreshManager = sessionRefreshManager; this.sessionBrowserStorage = sessionBrowserStorage; this.controllerConnection = controllerConnection; } public override async Task<AuthenticationState> GetAuthenticationStateAsync() { if (!isLoaded) { - var stored = await sessionBrowserStorage.Get(); - if (stored != null) { - var authToken = stored.Token; - var session = await controllerConnection.Send<GetAuthenticatedUser, Optional<AuthenticatedUserInfo>>(new GetAuthenticatedUser(stored.UserGuid, authToken), TimeSpan.FromSeconds(30)); - if (session.Value is {} userInfo) { - SetLoadedSession(new AuthenticatedUser(userInfo, authToken)); - } - } + await LoadSession(); } return await base.GetAuthenticationStateAsync(); } - internal void SetLoadedSession(AuthenticatedUser authenticatedUser) { - isLoaded = true; + private async Task LoadSession() { + await CancelCurrentLoad(); + await loadSemaphore.WaitAsync(CancellationToken.None); + + loadCancellationTokenSource = new CancellationTokenSource(); + CancellationToken cancellationToken = loadCancellationTokenSource.Token; + + try { + var authenticatedUser = await TryGetSession(cancellationToken); + if (authenticatedUser != null) { + SetLoadedSession(authenticatedUser); + } + else { + SetUnloadedSession(); + } + } catch (OperationCanceledException) { + SetUnloadedSession(); + } catch (Exception e) { + SetUnloadedSession(); + Logger.Error(e, "Could not load user session."); + } finally { + loadCancellationTokenSource.Dispose(); + loadCancellationTokenSource = null; + loadSemaphore.Release(); + } + } + + private async Task CancelCurrentLoad() { + var cancellationTokenSource = loadCancellationTokenSource; + if (cancellationTokenSource != null) { + await cancellationTokenSource.CancelAsync(); + } + } + + private async Task<AuthenticatedUser?> TryGetSession(CancellationToken cancellationToken) { + var stored = await sessionBrowserStorage.Get(); + if (stored == null) { + return null; + } + + cancellationToken.ThrowIfCancellationRequested(); + + var userGuid = stored.UserGuid; + var authToken = stored.Token; + + if (userRefreshEventHolder == null) { + userRefreshEventHolder = sessionRefreshManager.GetEventHolder(userGuid); + userRefreshEventHolder.UserNeedsRefresh += OnUserNeedsRefresh; + } + + var session = await controllerConnection.Send<GetAuthenticatedUser, Optional<AuthenticatedUserInfo>>(new GetAuthenticatedUser(userGuid, authToken), TimeSpan.FromSeconds(30), cancellationToken); + if (session.Value is {} userInfo) { + return new AuthenticatedUser(userInfo, authToken); + } + else { + return null; + } + } + + private void SetLoadedSession(AuthenticatedUser authenticatedUser) { SetAuthenticationState(Task.FromResult(new AuthenticationState(new CustomClaimsPrincipal(authenticatedUser)))); + isLoaded = true; } internal void SetUnloadedSession() { - isLoaded = false; SetAuthenticationState(Task.FromResult(new AuthenticationState(new ClaimsPrincipal()))); + isLoaded = false; + } + + private void OnUserNeedsRefresh(object? sender, EventArgs args) { + _ = LoadSession(); + } + + public async ValueTask DisposeAsync() { + if (userRefreshEventHolder != null) { + userRefreshEventHolder.UserNeedsRefresh -= OnUserNeedsRefresh; + userRefreshEventHolder = null; + } + + await CancelCurrentLoad(); + loadSemaphore.Dispose(); } } diff --git a/Web/Phantom.Web.Services/Authentication/UserLoginManager.cs b/Web/Phantom.Web.Services/Authentication/UserLoginManager.cs index a90aef0..b92866b 100644 --- a/Web/Phantom.Web.Services/Authentication/UserLoginManager.cs +++ b/Web/Phantom.Web.Services/Authentication/UserLoginManager.cs @@ -39,8 +39,9 @@ public sealed class UserLoginManager { var userInfo = success.UserInfo; var authToken = success.AuthToken; + authenticationStateProvider.SetUnloadedSession(); await sessionBrowserStorage.Store(userInfo.Guid, authToken); - authenticationStateProvider.SetLoadedSession(new AuthenticatedUser(userInfo, authToken)); + await authenticationStateProvider.GetAuthenticationStateAsync(); await navigation.NavigateTo(returnUrl ?? string.Empty); return true; diff --git a/Web/Phantom.Web.Services/Authentication/UserSessionRefreshManager.cs b/Web/Phantom.Web.Services/Authentication/UserSessionRefreshManager.cs new file mode 100644 index 0000000..5e1e4ef --- /dev/null +++ b/Web/Phantom.Web.Services/Authentication/UserSessionRefreshManager.cs @@ -0,0 +1,25 @@ +using System.Collections.Concurrent; + +namespace Phantom.Web.Services.Authentication; + +public sealed class UserSessionRefreshManager { + private readonly ConcurrentDictionary<Guid, EventHolder> userUpdateEventHoldersByUserGuid = new (); + + internal EventHolder GetEventHolder(Guid userGuid) { + return userUpdateEventHoldersByUserGuid.GetOrAdd(userGuid, static _ => new EventHolder()); + } + + internal void RefreshUser(Guid userGuid) { + if (userUpdateEventHoldersByUserGuid.TryGetValue(userGuid, out var eventHolder)) { + eventHolder.Notify(); + } + } + + internal sealed class EventHolder { + public event EventHandler? UserNeedsRefresh; + + internal void Notify() { + UserNeedsRefresh?.Invoke(null, EventArgs.Empty); + } + } +} diff --git a/Web/Phantom.Web.Services/PhantomWebServices.cs b/Web/Phantom.Web.Services/PhantomWebServices.cs index cb50001..b02e887 100644 --- a/Web/Phantom.Web.Services/PhantomWebServices.cs +++ b/Web/Phantom.Web.Services/PhantomWebServices.cs @@ -23,6 +23,7 @@ public static class PhantomWebServices { services.AddSingleton<UserManager>(); services.AddSingleton<AuditLogManager>(); + services.AddSingleton<UserSessionRefreshManager>(); services.AddScoped<UserLoginManager>(); services.AddScoped<UserSessionBrowserStorage>(); diff --git a/Web/Phantom.Web.Services/Rpc/ControllerMessageHandlerActor.cs b/Web/Phantom.Web.Services/Rpc/ControllerMessageHandlerActor.cs index 7e2024f..44e7c36 100644 --- a/Web/Phantom.Web.Services/Rpc/ControllerMessageHandlerActor.cs +++ b/Web/Phantom.Web.Services/Rpc/ControllerMessageHandlerActor.cs @@ -4,12 +4,20 @@ using Phantom.Common.Messages.Web.ToWeb; using Phantom.Utils.Actor; using Phantom.Utils.Rpc.Runtime; using Phantom.Web.Services.Agents; +using Phantom.Web.Services.Authentication; using Phantom.Web.Services.Instances; namespace Phantom.Web.Services.Rpc; sealed class ControllerMessageHandlerActor : ReceiveActor<IMessageToWeb> { - public readonly record struct Init(RpcConnectionToServer<IMessageToController> Connection, AgentManager AgentManager, InstanceManager InstanceManager, InstanceLogManager InstanceLogManager, TaskCompletionSource<bool> RegisterSuccessWaiter); + public readonly record struct Init( + RpcConnectionToServer<IMessageToController> Connection, + AgentManager AgentManager, + InstanceManager InstanceManager, + InstanceLogManager InstanceLogManager, + UserSessionRefreshManager UserSessionRefreshManager, + TaskCompletionSource<bool> RegisterSuccessWaiter + ); public static Props<IMessageToWeb> Factory(Init init) { return Props<IMessageToWeb>.Create(() => new ControllerMessageHandlerActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume }); @@ -19,6 +27,7 @@ sealed class ControllerMessageHandlerActor : ReceiveActor<IMessageToWeb> { private readonly AgentManager agentManager; private readonly InstanceManager instanceManager; private readonly InstanceLogManager instanceLogManager; + private readonly UserSessionRefreshManager userSessionRefreshManager; private readonly TaskCompletionSource<bool> registerSuccessWaiter; private ControllerMessageHandlerActor(Init init) { @@ -26,12 +35,14 @@ sealed class ControllerMessageHandlerActor : ReceiveActor<IMessageToWeb> { this.agentManager = init.AgentManager; this.instanceManager = init.InstanceManager; this.instanceLogManager = init.InstanceLogManager; + this.userSessionRefreshManager = init.UserSessionRefreshManager; this.registerSuccessWaiter = init.RegisterSuccessWaiter; Receive<RegisterWebResultMessage>(HandleRegisterWebResult); Receive<RefreshAgentsMessage>(HandleRefreshAgents); Receive<RefreshInstancesMessage>(HandleRefreshInstances); Receive<InstanceOutputMessage>(HandleInstanceOutput); + Receive<RefreshUserSessionMessage>(HandleRefreshUserSession); Receive<ReplyMessage>(HandleReply); } @@ -51,6 +62,10 @@ sealed class ControllerMessageHandlerActor : ReceiveActor<IMessageToWeb> { instanceLogManager.AddLines(message.InstanceGuid, message.Lines); } + private void HandleRefreshUserSession(RefreshUserSessionMessage message) { + userSessionRefreshManager.RefreshUser(message.UserGuid); + } + private void HandleReply(ReplyMessage message) { connection.Receive(message); } diff --git a/Web/Phantom.Web.Services/Rpc/ControllerMessageHandlerFactory.cs b/Web/Phantom.Web.Services/Rpc/ControllerMessageHandlerFactory.cs index 3438483..ec30ac3 100644 --- a/Web/Phantom.Web.Services/Rpc/ControllerMessageHandlerFactory.cs +++ b/Web/Phantom.Web.Services/Rpc/ControllerMessageHandlerFactory.cs @@ -4,6 +4,7 @@ using Phantom.Utils.Actor; using Phantom.Utils.Rpc.Runtime; using Phantom.Utils.Tasks; using Phantom.Web.Services.Agents; +using Phantom.Web.Services.Authentication; using Phantom.Web.Services.Instances; namespace Phantom.Web.Services.Rpc; @@ -13,6 +14,7 @@ public sealed class ControllerMessageHandlerFactory { private readonly AgentManager agentManager; private readonly InstanceManager instanceManager; private readonly InstanceLogManager instanceLogManager; + private readonly UserSessionRefreshManager userSessionRefreshManager; private readonly TaskCompletionSource<bool> registerSuccessWaiter = AsyncTasks.CreateCompletionSource<bool>(); @@ -20,15 +22,17 @@ public sealed class ControllerMessageHandlerFactory { private int messageHandlerId = 0; - public ControllerMessageHandlerFactory(RpcConnectionToServer<IMessageToController> connection, AgentManager agentManager, InstanceManager instanceManager, InstanceLogManager instanceLogManager) { + public ControllerMessageHandlerFactory(RpcConnectionToServer<IMessageToController> connection, AgentManager agentManager, InstanceManager instanceManager, InstanceLogManager instanceLogManager, UserSessionRefreshManager userSessionRefreshManager) { this.connection = connection; this.agentManager = agentManager; this.instanceManager = instanceManager; this.instanceLogManager = instanceLogManager; + this.userSessionRefreshManager = userSessionRefreshManager; } public ActorRef<IMessageToWeb> Create(IActorRefFactory actorSystem) { - int id = Interlocked.Increment(ref messageHandlerId); - return actorSystem.ActorOf(ControllerMessageHandlerActor.Factory(new ControllerMessageHandlerActor.Init(connection, agentManager, instanceManager, instanceLogManager, registerSuccessWaiter)), "ControllerMessageHandler-" + id); + var init = new ControllerMessageHandlerActor.Init(connection, agentManager, instanceManager, instanceLogManager, userSessionRefreshManager, registerSuccessWaiter); + var name = "ControllerMessageHandler-" + Interlocked.Increment(ref messageHandlerId); + return actorSystem.ActorOf(ControllerMessageHandlerActor.Factory(init), name); } }