diff --git a/Agent/Phantom.Agent.Minecraft/Server/MinecraftServerExecutableDownloader.cs b/Agent/Phantom.Agent.Minecraft/Server/MinecraftServerExecutableDownloader.cs index 6111c5c..4d72633 100644 --- a/Agent/Phantom.Agent.Minecraft/Server/MinecraftServerExecutableDownloader.cs +++ b/Agent/Phantom.Agent.Minecraft/Server/MinecraftServerExecutableDownloader.cs @@ -16,32 +16,47 @@ sealed class MinecraftServerExecutableDownloader { public event EventHandler? Completed; private readonly CancellationTokenSource cancellationTokenSource = new (); - private int listeners = 0; + + private readonly List<CancellationTokenRegistration> listenerCancellationRegistrations = new (); + private int listenerCount = 0; public MinecraftServerExecutableDownloader(FileDownloadInfo fileDownloadInfo, string minecraftVersion, string filePath, MinecraftServerExecutableDownloadListener listener) { Register(listener); - Task = DownloadAndGetPath(fileDownloadInfo, minecraftVersion, filePath); + Task = DownloadAndGetPath(fileDownloadInfo, minecraftVersion, filePath, new DownloadProgressCallback(this), cancellationTokenSource.Token); Task.ContinueWith(OnCompleted, TaskScheduler.Default); } public void Register(MinecraftServerExecutableDownloadListener listener) { - ++listeners; - Logger.Debug("Registered download listener, current listener count: {Listeners}", listeners); + int newListenerCount; - DownloadProgress += listener.DownloadProgressEventHandler; - listener.CancellationToken.Register(Unregister, listener); + lock (this) { + newListenerCount = ++listenerCount; + + DownloadProgress += listener.DownloadProgressEventHandler; + listenerCancellationRegistrations.Add(listener.CancellationToken.Register(Unregister, listener)); + } + + Logger.Debug("Registered download listener, current listener count: {Listeners}", newListenerCount); } private void Unregister(object? listenerObject) { - MinecraftServerExecutableDownloadListener listener = (MinecraftServerExecutableDownloadListener) listenerObject!; - DownloadProgress -= listener.DownloadProgressEventHandler; + int newListenerCount; + + lock (this) { + MinecraftServerExecutableDownloadListener listener = (MinecraftServerExecutableDownloadListener) listenerObject!; + DownloadProgress -= listener.DownloadProgressEventHandler; - if (--listeners <= 0) { + newListenerCount = --listenerCount; + if (newListenerCount <= 0) { + cancellationTokenSource.Cancel(); + } + } + + if (newListenerCount <= 0) { Logger.Debug("Unregistered last download listener, cancelling download."); - cancellationTokenSource.Cancel(); } else { - Logger.Debug("Unregistered download listener, current listener count: {Listeners}", listeners); + Logger.Debug("Unregistered download listener, current listener count: {Listeners}", newListenerCount); } } @@ -51,9 +66,19 @@ sealed class MinecraftServerExecutableDownloader { private void OnCompleted(Task task) { Logger.Debug("Download task completed."); - Completed?.Invoke(this, EventArgs.Empty); - Completed = null; - DownloadProgress = null; + + lock (this) { + Completed?.Invoke(this, EventArgs.Empty); + Completed = null; + DownloadProgress = null; + + foreach (var registration in listenerCancellationRegistrations) { + registration.Dispose(); + } + + listenerCancellationRegistrations.Clear(); + cancellationTokenSource.Dispose(); + } } private sealed class DownloadProgressCallback { @@ -68,15 +93,14 @@ sealed class MinecraftServerExecutableDownloader { } } - private async Task<string?> DownloadAndGetPath(FileDownloadInfo fileDownloadInfo, string minecraftVersion, string filePath) { + private static async Task<string?> DownloadAndGetPath(FileDownloadInfo fileDownloadInfo, string minecraftVersion, string filePath, DownloadProgressCallback progressCallback, CancellationToken cancellationToken) { string tmpFilePath = filePath + ".tmp"; - var cancellationToken = cancellationTokenSource.Token; try { Logger.Information("Downloading server version {Version} from: {Url} ({Size})", minecraftVersion, fileDownloadInfo.DownloadUrl, fileDownloadInfo.Size.ToHumanReadable(decimalPlaces: 1)); try { using var http = new HttpClient(); - await FetchServerExecutableFile(http, new DownloadProgressCallback(this), fileDownloadInfo, tmpFilePath, cancellationToken); + await FetchServerExecutableFile(http, progressCallback, fileDownloadInfo, tmpFilePath, cancellationToken); } catch (Exception) { TryDeleteExecutableAfterFailure(tmpFilePath); throw; @@ -94,8 +118,6 @@ sealed class MinecraftServerExecutableDownloader { } catch (Exception e) { Logger.Error(e, "An unexpected error occurred."); return null; - } finally { - cancellationTokenSource.Dispose(); } }