using System.Collections.Concurrent;
using Phantom.Utils.Collections;
using Serilog;

namespace Phantom.Utils.Tasks;

/// <summary>
/// Starts named background tasks on the thread pool, keeps track of the ones that are still running,
/// and allows all of them to be cancelled and awaited together via <see cref="Stop"/>.
/// </summary>
public sealed class TaskManager {
    /// <summary>
    /// How long <see cref="Stop"/> waits between warnings that list the tasks which have not finished yet.
    /// </summary>
    private static readonly TimeSpan LogStateInterval = TimeSpan.FromSeconds(10);

    private readonly ILogger logger;
    private readonly CancellationTokenSource cancellationTokenSource = new ();
    private readonly CancellationToken cancellationToken;

    // Maps each tracked task to the name it was registered with; keyed by task reference.
    private readonly ConcurrentDictionary<Task, string> runningTasks = new (ReferenceEqualityComparer<Task>.Instance);

    /// <summary>
    /// Creates a task manager that reports its shutdown progress to the given logger.
    /// </summary>
    /// <param name="logger">Logger used by <see cref="Stop"/>.</param>
    public TaskManager(ILogger logger) {
        this.logger = logger;
        // The token is captured once so it remains usable after the source is disposed in Stop().
        this.cancellationToken = cancellationTokenSource.Token;
    }

    /// <summary>
    /// Registers a task under the given name and arranges for it to be unregistered once it finishes.
    /// </summary>
    /// <param name="name">Name shown in the shutdown warning while the task is still running.</param>
    /// <param name="task">The task to track.</param>
    /// <returns>The same <paramref name="task"/> instance.</returns>
    /// <exception cref="OperationCanceledException">Cancellation of the manager has already been requested.</exception>
    private T Add<T>(string name, T task) where T : Task {
        cancellationToken.ThrowIfCancellationRequested();
        runningTasks.TryAdd(task, name);
        // The continuation is not tied to the manager's token, so the task is removed
        // from the dictionary no matter how it ends (success, fault, or cancellation).
        task.ContinueWith(OnFinished, CancellationToken.None, TaskContinuationOptions.RunContinuationsAsynchronously, TaskScheduler.Default);
        return task;
    }

    /// <summary>
    /// Continuation callback that stops tracking a finished task.
    /// </summary>
    private void OnFinished(Task task) {
        runningTasks.TryRemove(task, out _);
    }

    /// <summary>
    /// Runs a synchronous action via <see cref="Task.Run(Action, CancellationToken)"/> using the manager's cancellation token, and tracks it.
    /// </summary>
    /// <param name="name">Name used to identify the task in log output.</param>
    /// <param name="action">The work to execute.</param>
    /// <returns>The task representing the queued work.</returns>
    public Task Run(string name, Action action) {
        return Add(name, Task.Run(action, cancellationToken));
    }

    /// <summary>
    /// Runs an asynchronous function via <see cref="Task.Run(Func{Task}, CancellationToken)"/> using the manager's cancellation token, and tracks it.
    /// </summary>
    /// <param name="name">Name used to identify the task in log output.</param>
    /// <param name="taskFunc">The asynchronous work to execute.</param>
    /// <returns>The task representing the queued work.</returns>
    public Task Run(string name, Func<Task> taskFunc) {
        return Add(name, Task.Run(taskFunc, cancellationToken));
    }

    /// <summary>
    /// Runs an asynchronous function that produces a result, using the manager's cancellation token, and tracks it.
    /// </summary>
    /// <param name="name">Name used to identify the task in log output.</param>
    /// <param name="taskFunc">The asynchronous work to execute.</param>
    /// <returns>The task representing the queued work and its result.</returns>
    public Task<T> Run<T>(string name, Func<Task<T>> taskFunc) {
        return Add(name, Task.Run(taskFunc, cancellationToken));
    }

    /// <summary>
    /// Requests cancellation of all tracked tasks and waits until they have finished.
    /// While waiting, a warning listing the names of the unfinished tasks is logged every
    /// <see cref="LogStateInterval"/>. Afterwards the tracking dictionary is cleared and the
    /// cancellation token source is disposed.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the cancellation token source is disposed at the end, so this method
    /// appears to be intended to be called only once per instance - confirm against callers.
    /// </remarks>
    public async Task Stop() {
        logger.Information("Stopping task manager...");

        cancellationTokenSource.Cancel();

        var remainingTasksAwaiterTask = WaitForRemainingTasks();
        while (true) {
            // A dedicated source lets the delay (and its timer) be released as soon as the
            // remaining tasks finish, instead of lingering until the interval elapses.
            using var logStateTimeoutCancellation = new CancellationTokenSource();
            var logStateTimeoutTask = Task.Delay(LogStateInterval, logStateTimeoutCancellation.Token);

            var completedTask = await Task.WhenAny(remainingTasksAwaiterTask, logStateTimeoutTask);
            if (completedTask != logStateTimeoutTask) {
                logStateTimeoutCancellation.Cancel();
                break;
            }

            var remainingTaskNames = runningTasks.Values.Order().ToList();
            var remainingTaskNameList = string.Join('\n', remainingTaskNames.Select(static name => "- " + name));
            logger.Warning("Waiting for {TaskCount} task(s) to finish:\n{TaskNames}", remainingTaskNames.Count, remainingTaskNameList);
        }

        runningTasks.Clear();
        cancellationTokenSource.Dispose();

        logger.Information("Task manager stopped.");
    }

    /// <summary>
    /// Awaits the currently tracked tasks one after another, ignoring any exception they throw.
    /// </summary>
    private async Task WaitForRemainingTasks() {
        foreach (var task in runningTasks.Keys) {
            try {
                await task;
            } catch (Exception) {
                // Ignored on purpose: shutdown only needs the task to be finished,
                // regardless of whether it succeeded, faulted, or was cancelled.
            }
        }
    }
}