Skip to content

Commit

Permalink
Support add/remove sockets on running servers
Browse files Browse the repository at this point in the history
- Make it possible to add and remove socket listeners to already running
servers. Relies on the existing methods, just adds and removes
"RunningListeners" as needed. Backed by Dictionary instead of List to
simplify lookups
- Handle concurrency by using a SemaphorSlim, similar to how it's done
in Kestrel
- Added equivalent methods to ServerHostedService, since the Server
itself isn't accessible.
- Added some test coverage for the new behavior.
  • Loading branch information
jooooel committed Jan 28, 2023
1 parent 126888b commit 2b6e433
Show file tree
Hide file tree
Showing 3 changed files with 273 additions and 31 deletions.
15 changes: 14 additions & 1 deletion src/Bedrock.Framework/Hosting/ServerHostedService.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using System.Threading;
using System;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Options;

Expand All @@ -23,5 +26,15 @@ public Task StopAsync(CancellationToken cancellationToken)
{
return _server.StopAsync(cancellationToken);
}

public Task AddSocketListenerAsync(EndPoint endpoint, Action<IConnectionBuilder> configure)
{
return _server.AddSocketListenerAsync(endpoint, configure);
}

public Task RemoveSocketListenerAsync(EndPoint endpoint)
{
return _server.RemoveSocketListener(endpoint);
}
}
}
156 changes: 126 additions & 30 deletions src/Bedrock.Framework/Server/Server.cs
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace Bedrock.Framework
{
public class Server
{
private readonly ServerBuilder _builder;
private readonly ILogger<Server> _logger;
private readonly List<RunningListener> _listeners = new List<RunningListener>();
private readonly Dictionary<EndPoint, RunningListener> _listeners = new Dictionary<EndPoint, RunningListener>();
private readonly TaskCompletionSource<object> _shutdownTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
private readonly TimerAwaitable _timerAwaitable;
private readonly SemaphoreSlim _listenerSemaphore = new SemaphoreSlim(initialCount: 1);
private Task _timerTask = Task.CompletedTask;
private int _stopping;

internal Server(ServerBuilder builder)
{
Expand All @@ -29,7 +34,7 @@ public IEnumerable<EndPoint> EndPoints
{
get
{
foreach (var listener in _listeners)
foreach (var listener in _listeners.Values)
{
yield return listener.Listener.EndPoint;
}
Expand All @@ -42,12 +47,7 @@ public async Task StartAsync(CancellationToken cancellationToken = default)
{
foreach (var binding in _builder.Bindings)
{
await foreach (var listener in binding.BindAsync(cancellationToken).ConfigureAwait(false))
{
var runningListener = new RunningListener(this, binding, listener);
_listeners.Add(runningListener);
runningListener.Start();
}
await StartRunningListenersAsync(binding, cancellationToken).ConfigureAwait(false);
}
}
catch
Expand All @@ -67,7 +67,7 @@ private async Task StartTimerAsync()
{
while (await _timerAwaitable)
{
foreach (var listener in _listeners)
foreach (var listener in _listeners.Values)
{
listener.TickHeartbeat();
}
Expand All @@ -77,40 +77,132 @@ private async Task StartTimerAsync()

public async Task StopAsync(CancellationToken cancellationToken = default)
{
var tasks = new Task[_listeners.Count];
if (Interlocked.Exchange(ref _stopping, 1) == 1)
{
return;
}

await _listenerSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
var listeners = _listeners.Values.ToList();

var tasks = new Task[listeners.Count];

for (int i = 0; i < listeners.Count; i++)
{
tasks[i] = listeners[i].Listener.UnbindAsync(cancellationToken).AsTask();
}

await Task.WhenAll(tasks).ConfigureAwait(false);

// Signal to all of the listeners that it's time to start the shutdown process
// We call this after unbind so that we're not touching the listener anymore (each loop will dispose the listener)
_shutdownTcs.TrySetResult(null);

for (int i = 0; i < listeners.Count; i++)
{
tasks[i] = listeners[i].ExecutionTask;
}

var shutdownTask = Task.WhenAll(tasks);

if (cancellationToken.CanBeCanceled)
{
await shutdownTask.WithCancellation(cancellationToken).ConfigureAwait(false);
}
else
{
await shutdownTask.ConfigureAwait(false);
}

if (_timerAwaitable != null)
{
_timerAwaitable.Stop();

for (int i = 0; i < _listeners.Count; i++)
await _timerTask.ConfigureAwait(false);
}
}
finally
{
tasks[i] = _listeners[i].Listener.UnbindAsync(cancellationToken).AsTask();
_listenerSemaphore.Release();
}
}

await Task.WhenAll(tasks).ConfigureAwait(false);
public Task AddSocketListenerAsync(EndPoint endpoint, Action<IConnectionBuilder> configure, CancellationToken cancellationToken = default)
{
var socketTransportFactory = new SocketTransportFactory(Options.Create(new SocketTransportOptions()), _builder.ApplicationServices.GetLoggerFactory());
var connectionBuilder = new ConnectionBuilder(_builder.ApplicationServices);

// Signal to all of the listeners that it's time to start the shutdown process
// We call this after unbind so that we're not touching the listener anymore (each loop will dispose the listener)
_shutdownTcs.TrySetResult(null);
configure(connectionBuilder);

for (int i = 0; i < _listeners.Count; i++)
var binding = new EndPointBinding(endpoint, connectionBuilder.Build(), socketTransportFactory);
return StartRunningListenersAsync(binding, cancellationToken);
}

public async Task RemoveSocketListener(EndPoint endpoint, CancellationToken cancellationToken = default)
{
await _listenerSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);

if (_stopping == 1)
{
tasks[i] = _listeners[i].ExecutionTask;
throw new InvalidOperationException("The server has already been stopped.");
}

var shutdownTask = Task.WhenAll(tasks);
try
{
if (!_listeners.Remove(endpoint, out var listener))
{
return;
}

await listener.Listener.UnbindAsync(cancellationToken).ConfigureAwait(false);

if (cancellationToken.CanBeCanceled)
// Signal to the listener that it's time to start the shutdown process
// We call this after unbind so that we're not touching the listener anymore
listener.ShutdownTcs.TrySetResult(null);

if (cancellationToken.CanBeCanceled)
{
await listener.ExecutionTask.WithCancellation(cancellationToken).ConfigureAwait(false);
}
else
{
await listener.ExecutionTask.ConfigureAwait(false);
}
}
finally
{
await shutdownTask.WithCancellation(cancellationToken).ConfigureAwait(false);
_listenerSemaphore.Release();
}
else
}

private async Task StartRunningListenersAsync(ServerBinding binding, CancellationToken cancellationToken = default)
{
await _listenerSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);

if (_stopping == 1)
{
await shutdownTask.ConfigureAwait(false);
throw new InvalidOperationException("The server has already been stopped.");
}

if (_timerAwaitable != null)
try
{
_timerAwaitable.Stop();
await foreach (var listener in binding.BindAsync(cancellationToken).ConfigureAwait(false))
{
var runningListener = new RunningListener(this, binding, listener);
if (!_listeners.TryAdd(runningListener.Listener.EndPoint, runningListener))
{
_logger.LogWarning("Will not start RunningListener, EndPoint already exist");
continue;
}

await _timerTask.ConfigureAwait(false);
runningListener.Start();
}
}
finally
{
_listenerSemaphore.Release();
}
}

Expand All @@ -130,10 +222,12 @@ public RunningListener(Server server, ServerBinding binding, IConnectionListener
public void Start()
{
ExecutionTask = RunListenerAsync();
ShutdownTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
}

public IConnectionListener Listener { get; }
public Task ExecutionTask { get; private set; }
public TaskCompletionSource<object> ShutdownTcs { get; private set; }

public void TickHeartbeat()
{
Expand Down Expand Up @@ -215,8 +309,11 @@ async Task ExecuteConnectionAsync(ServerConnection serverConnection)
id++;
}

// Don't shut down connections until entire server is shutting down
await _server._shutdownTcs.Task.ConfigureAwait(false);
// Don't shut down connections until this listener or the entire server is shutting down
await Task.WhenAny(
ShutdownTcs.Task,
_server._shutdownTcs.Task)
.ConfigureAwait(false);

// Give connections a chance to close gracefully
var tasks = new List<Task>(_connections.Count);
Expand All @@ -241,7 +338,6 @@ async Task ExecuteConnectionAsync(ServerConnection serverConnection)
await listener.DisposeAsync().ConfigureAwait(false);
}


private IDisposable BeginConnectionScope(ServerConnection connection)
{
if (_server._logger.IsEnabled(LogLevel.Critical))
Expand All @@ -253,4 +349,4 @@ private IDisposable BeginConnectionScope(ServerConnection connection)
}
}
}
}
}
Loading

0 comments on commit 2b6e433

Please sign in to comment.