|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
#nullable disable
using Moq;
namespace Microsoft.NET.Sdk.Razor.Tool.Tests
{
public class DefaultRequestDispatcherTest
{
private static ServerRequest EmptyServerRequest => new(1, Array.Empty<RequestArgument>());
private static ServerResponse EmptyServerResponse => new CompletedServerResponse(
returnCode: 0,
utf8output: false,
output: string.Empty,
error: string.Empty);
[Fact]
public async Task AcceptConnection_ReadingRequestFails_ClosesConnection()
{
// Arrange
var stream = Mock.Of<Stream>();
var compilerHost = CreateCompilerHost();
var connectionHost = CreateConnectionHost();
var dispatcher = new DefaultRequestDispatcher(connectionHost, compilerHost, CancellationToken.None);
var connection = CreateConnection(stream);
// Act
var result = await dispatcher.AcceptConnection(
Task.FromResult<Connection>(connection), accept: true, cancellationToken: CancellationToken.None);
// Assert
Assert.Equal(ConnectionResult.Reason.CompilationNotStarted, result.CloseReason);
}
/// <summary>
/// A failure to write the results to the client is considered a client disconnection. Any error
/// from when the build starts to when the write completes should be handled this way.
/// </summary>
[Fact]
public async Task AcceptConnection_WritingResultsFails_ClosesConnection()
{
// Arrange
var memoryStream = new MemoryStream();
await EmptyServerRequest.WriteAsync(memoryStream, CancellationToken.None);
memoryStream.Position = 0;
var stream = new Mock<Stream>(MockBehavior.Strict);
stream
.Setup(x => x.ReadAsync(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<CancellationToken>()))
.Returns((byte[] array, int start, int length, CancellationToken ct) => memoryStream.ReadAsync(array, start, length, ct));
var connection = CreateConnection(stream.Object);
var compilerHost = CreateCompilerHost(c =>
{
c.ExecuteFunc = (req, ct) =>
{
return EmptyServerResponse;
};
});
var connectionHost = CreateConnectionHost();
var dispatcher = new DefaultRequestDispatcher(connectionHost, compilerHost, CancellationToken.None);
// Act
// We expect WriteAsync to fail because the mock stream doesn't have a corresponding setup.
var connectionResult = await dispatcher.AcceptConnection(
Task.FromResult<Connection>(connection), accept: true, cancellationToken: CancellationToken.None);
// Assert
Assert.Equal(ConnectionResult.Reason.ClientDisconnect, connectionResult.CloseReason);
Assert.Null(connectionResult.KeepAlive);
}
/// <summary>
/// Ensure the Connection correctly handles the case where a client disconnects while in the
/// middle of executing a request.
/// </summary>
[Fact]
public async Task AcceptConnection_ClientDisconnectsWhenExecutingRequest_ClosesConnection()
{
// Arrange
var connectionHost = Mock.Of<ConnectionHost>();
// Fake a long running task here that we can validate later on.
var buildTaskSource = new TaskCompletionSource<bool>();
var buildTaskCancellationToken = default(CancellationToken);
var compilerHost = CreateCompilerHost(c =>
{
#pragma warning disable xUnit1031
c.ExecuteFunc = (req, ct) =>
{
Task.WaitAll(buildTaskSource.Task);
return EmptyServerResponse;
};
#pragma warning restore xUnit1031
});
var dispatcher = new DefaultRequestDispatcher(connectionHost, compilerHost, CancellationToken.None);
var readyTaskSource = new TaskCompletionSource<bool>();
var disconnectTaskSource = new TaskCompletionSource<bool>();
var connectionTask = CreateConnectionWithEmptyServerRequest(c =>
{
c.WaitForDisconnectAsyncFunc = (ct) =>
{
buildTaskCancellationToken = ct;
readyTaskSource.SetResult(true);
return disconnectTaskSource.Task;
};
});
var handleTask = dispatcher.AcceptConnection(
connectionTask, accept: true, cancellationToken: CancellationToken.None);
// Wait until WaitForDisconnectAsync task is actually created and running.
await readyTaskSource.Task;
// Act
// Now simulate a disconnect by the client.
disconnectTaskSource.SetResult(true);
var connectionResult = await handleTask;
buildTaskSource.SetResult(true);
// Assert
Assert.Equal(ConnectionResult.Reason.ClientDisconnect, connectionResult.CloseReason);
Assert.Null(connectionResult.KeepAlive);
Assert.True(buildTaskCancellationToken.IsCancellationRequested);
}
[Fact]
public async Task AcceptConnection_AcceptFalse_RejectsBuildRequest()
{
// Arrange
var stream = new TestableStream();
await EmptyServerRequest.WriteAsync(stream.ReadStream, CancellationToken.None);
stream.ReadStream.Position = 0;
var connection = CreateConnection(stream);
var connectionHost = CreateConnectionHost();
var compilerHost = CreateCompilerHost();
var dispatcher = new DefaultRequestDispatcher(connectionHost, compilerHost, CancellationToken.None);
// Act
var connectionResult = await dispatcher.AcceptConnection(
Task.FromResult<Connection>(connection), accept: false, cancellationToken: CancellationToken.None);
// Assert
Assert.Equal(ConnectionResult.Reason.CompilationNotStarted, connectionResult.CloseReason);
stream.WriteStream.Position = 0;
var response = await ServerResponse.ReadAsync(stream.WriteStream);
Assert.Equal(ServerResponse.ResponseType.Rejected, response.Type);
}
[Fact]
public async Task AcceptConnection_ShutdownRequest_ReturnsShutdownResponse()
{
// Arrange
var stream = new TestableStream();
await ServerRequest.CreateShutdown().WriteAsync(stream.ReadStream, CancellationToken.None);
stream.ReadStream.Position = 0;
var connection = CreateConnection(stream);
var connectionHost = CreateConnectionHost();
var compilerHost = CreateCompilerHost();
var dispatcher = new DefaultRequestDispatcher(connectionHost, compilerHost, CancellationToken.None);
// Act
var connectionResult = await dispatcher.AcceptConnection(
Task.FromResult<Connection>(connection), accept: true, cancellationToken: CancellationToken.None);
// Assert
Assert.Equal(ConnectionResult.Reason.ClientShutdownRequest, connectionResult.CloseReason);
stream.WriteStream.Position = 0;
var response = await ServerResponse.ReadAsync(stream.WriteStream);
Assert.Equal(ServerResponse.ResponseType.Shutdown, response.Type);
}
[Fact]
public async Task AcceptConnection_ConnectionHostThrowsWhenConnecting_ClosesConnection()
{
// Arrange
var connectionHost = new Mock<ConnectionHost>(MockBehavior.Strict);
connectionHost.Setup(c => c.WaitForConnectionAsync(It.IsAny<CancellationToken>())).Throws(new Exception());
var compilerHost = CreateCompilerHost();
var dispatcher = new DefaultRequestDispatcher(connectionHost.Object, compilerHost, CancellationToken.None);
var connection = CreateConnection(Mock.Of<Stream>());
// Act
var connectionResult = await dispatcher.AcceptConnection(
Task.FromResult<Connection>(connection), accept: true, cancellationToken: CancellationToken.None);
// Assert
Assert.Equal(ConnectionResult.Reason.CompilationNotStarted, connectionResult.CloseReason);
Assert.Null(connectionResult.KeepAlive);
}
[Fact]
public async Task AcceptConnection_ClientConnectionThrowsWhenConnecting_ClosesConnection()
{
// Arrange
var compilerHost = CreateCompilerHost();
var connectionHost = CreateConnectionHost();
var dispatcher = new DefaultRequestDispatcher(connectionHost, compilerHost, CancellationToken.None);
var connectionTask = Task.FromException<Connection>(new Exception());
// Act
var connectionResult = await dispatcher.AcceptConnection(
connectionTask, accept: true, cancellationToken: CancellationToken.None);
// Assert
Assert.Equal(ConnectionResult.Reason.CompilationNotStarted, connectionResult.CloseReason);
Assert.Null(connectionResult.KeepAlive);
}
[Fact]
public async Task Dispatcher_ClientConnectionThrowsWhenExecutingRequest_ClosesConnection()
{
// Arrange
var called = false;
var connectionTask = CreateConnectionWithEmptyServerRequest(c =>
{
c.WaitForDisconnectAsyncFunc = (ct) =>
{
called = true;
throw new Exception();
};
});
var compilerHost = CreateCompilerHost();
var connectionHost = CreateConnectionHost();
var dispatcher = new DefaultRequestDispatcher(connectionHost, compilerHost, CancellationToken.None);
// Act
var connectionResult = await dispatcher.AcceptConnection(
connectionTask, accept: true, cancellationToken: CancellationToken.None);
// Assert
Assert.True(called);
Assert.Equal(ConnectionResult.Reason.ClientException, connectionResult.CloseReason);
Assert.Null(connectionResult.KeepAlive);
}
[Fact]
public void Dispatcher_NoConnections_HitsKeepAliveTimeout()
{
// Arrange
var keepAlive = TimeSpan.FromSeconds(3);
var compilerHost = CreateCompilerHost();
var connectionHost = new Mock<ConnectionHost>();
connectionHost
.Setup(x => x.WaitForConnectionAsync(It.IsAny<CancellationToken>()))
.Returns(new TaskCompletionSource<Connection>().Task);
var eventBus = new TestableEventBus();
var dispatcher = new DefaultRequestDispatcher(connectionHost.Object, compilerHost, CancellationToken.None, eventBus, keepAlive);
var startTime = DateTime.Now;
// Act
dispatcher.Run();
// Assert
Assert.True(eventBus.HitKeepAliveTimeout);
}
/// <summary>
/// Ensure server respects keep alive and shuts down after processing a single connection.
/// </summary>
[Fact]
public void Dispatcher_ProcessSingleConnection_HitsKeepAliveTimeout()
{
// Arrange
var connectionTask = CreateConnectionWithEmptyServerRequest();
var keepAlive = TimeSpan.FromSeconds(1);
var compilerHost = CreateCompilerHost(c =>
{
c.ExecuteFunc = (req, ct) =>
{
return EmptyServerResponse;
};
});
var connectionHost = CreateConnectionHost(connectionTask, new TaskCompletionSource<Connection>().Task);
var eventBus = new TestableEventBus();
var dispatcher = new DefaultRequestDispatcher(connectionHost, compilerHost, CancellationToken.None, eventBus, keepAlive);
// Act
dispatcher.Run();
// Assert
Assert.Equal(1, eventBus.CompletedCount);
Assert.True(eventBus.LastProcessedTime.HasValue);
Assert.True(eventBus.HitKeepAliveTimeout);
}
/// <summary>
/// Ensure server respects keep alive and shuts down after processing multiple connections.
/// </summary>
[Fact]
public void Dispatcher_ProcessMultipleConnections_HitsKeepAliveTimeout()
{
// Arrange
var count = 5;
var list = new List<Task<Connection>>();
for (var i = 0; i < count; i++)
{
var connectionTask = CreateConnectionWithEmptyServerRequest();
list.Add(connectionTask);
}
list.Add(new TaskCompletionSource<Connection>().Task);
var connectionHost = CreateConnectionHost(list.ToArray());
var compilerHost = CreateCompilerHost(c =>
{
c.ExecuteFunc = (req, ct) =>
{
return EmptyServerResponse;
};
});
var keepAlive = TimeSpan.FromSeconds(1);
var eventBus = new TestableEventBus();
var dispatcher = new DefaultRequestDispatcher(connectionHost, compilerHost, CancellationToken.None, eventBus, keepAlive);
// Act
dispatcher.Run();
// Assert
Assert.Equal(count, eventBus.CompletedCount);
Assert.True(eventBus.LastProcessedTime.HasValue);
Assert.True(eventBus.HitKeepAliveTimeout);
}
/// <summary>
/// Ensure server respects keep alive and shuts down after processing simultaneous connections.
/// </summary>
[Fact]
public async Task Dispatcher_ProcessSimultaneousConnections_HitsKeepAliveTimeout()
{
// Arrange
var totalCount = 2;
var readySource = new TaskCompletionSource<bool>();
var list = new List<TaskCompletionSource<bool>>();
var connectionHost = new Mock<ConnectionHost>();
connectionHost
.Setup(x => x.WaitForConnectionAsync(It.IsAny<CancellationToken>()))
.Returns((CancellationToken ct) =>
{
if (list.Count < totalCount)
{
var source = new TaskCompletionSource<bool>();
var connectionTask = CreateConnectionWithEmptyServerRequest(c =>
{
// Keep the connection active until we decide to end it.
c.WaitForDisconnectAsyncFunc = _ => source.Task;
});
list.Add(source);
return connectionTask;
}
readySource.SetResult(true);
return new TaskCompletionSource<Connection>().Task;
});
var compilerHost = CreateCompilerHost(c =>
{
c.ExecuteFunc = (req, ct) =>
{
return EmptyServerResponse;
};
});
var eventBus = new TestableEventBus();
var completedCompilations = 0;
var allCompilationsComplete = new TaskCompletionSource<bool>();
eventBus.CompilationComplete += (obj, args) =>
{
if (++completedCompilations == totalCount)
{
// All compilations have completed.
allCompilationsComplete.SetResult(true);
}
};
var keepAlive = TimeSpan.FromSeconds(1);
var dispatcherTask = Task.Run(() =>
{
var dispatcher = new DefaultRequestDispatcher(connectionHost.Object, compilerHost, CancellationToken.None, eventBus, keepAlive);
dispatcher.Run();
});
// Wait for all connections to be created.
await readySource.Task;
// Wait for all compilations to complete.
await allCompilationsComplete.Task;
// Now allow all the connections to be disconnected.
foreach (var source in list)
{
source.SetResult(true);
}
// Act
// Now dispatcher should be in an idle state with no active connections.
await dispatcherTask;
// Assert
Assert.False(eventBus.HasDetectedBadConnection);
Assert.Equal(totalCount, eventBus.CompletedCount);
Assert.True(eventBus.LastProcessedTime.HasValue, "LastProcessedTime should have had a value.");
Assert.True(eventBus.HitKeepAliveTimeout, "HitKeepAliveTimeout should have been hit.");
}
[Fact]
public void Dispatcher_ClientConnectionThrows_BeginsShutdown()
{
// Arrange
var listenCancellationToken = default(CancellationToken);
var firstConnectionTask = CreateConnectionWithEmptyServerRequest(c =>
{
c.WaitForDisconnectAsyncFunc = (ct) =>
{
listenCancellationToken = ct;
return Task.Delay(Timeout.Infinite, ct).ContinueWith<Connection>(_ => null);
};
});
var secondConnectionTask = CreateConnectionWithEmptyServerRequest(c =>
{
c.WaitForDisconnectAsyncFunc = (ct) => throw new Exception();
});
var compilerHost = CreateCompilerHost();
var connectionHost = CreateConnectionHost(
firstConnectionTask,
secondConnectionTask,
new TaskCompletionSource<Connection>().Task);
var keepAlive = TimeSpan.FromSeconds(10);
var eventBus = new TestableEventBus();
var dispatcher = new DefaultRequestDispatcher(connectionHost, compilerHost, CancellationToken.None, eventBus, keepAlive);
// Act
dispatcher.Run();
// Assert
Assert.True(eventBus.HasDetectedBadConnection);
Assert.True(listenCancellationToken.IsCancellationRequested);
}
private static TestableConnection CreateConnection(Stream stream, string identifier = null)
{
return new TestableConnection(stream, identifier ?? "identifier");
}
private static async Task<Connection> CreateConnectionWithEmptyServerRequest(Action<TestableConnection> configureConnection = null)
{
var memoryStream = new MemoryStream();
await EmptyServerRequest.WriteAsync(memoryStream, CancellationToken.None);
memoryStream.Position = 0;
var connection = CreateConnection(memoryStream);
configureConnection?.Invoke(connection);
return connection;
}
private static ConnectionHost CreateConnectionHost(params Task<Connection>[] connections)
{
var host = new Mock<ConnectionHost>();
if (connections.Length > 0)
{
var index = 0;
host
.Setup(x => x.WaitForConnectionAsync(It.IsAny<CancellationToken>()))
.Returns((CancellationToken ct) => connections[index++]);
}
return host.Object;
}
private static TestableCompilerHost CreateCompilerHost(Action<TestableCompilerHost> configureCompilerHost = null)
{
var compilerHost = new TestableCompilerHost();
configureCompilerHost?.Invoke(compilerHost);
return compilerHost;
}
private class TestableCompilerHost : CompilerHost
{
internal Func<ServerRequest, CancellationToken, ServerResponse> ExecuteFunc;
public override ServerResponse Execute(ServerRequest request, CancellationToken cancellationToken)
{
if (ExecuteFunc != null)
{
return ExecuteFunc(request, cancellationToken);
}
return EmptyServerResponse;
}
}
private class TestableConnection : Connection
{
internal Func<CancellationToken, Task> WaitForDisconnectAsyncFunc;
public TestableConnection(Stream stream, string identifier)
{
Stream = stream;
Identifier = identifier;
WaitForDisconnectAsyncFunc = ct => Task.Delay(Timeout.Infinite, ct);
}
public override Task WaitForDisconnectAsync(CancellationToken cancellationToken)
{
return WaitForDisconnectAsyncFunc(cancellationToken);
}
}
private class TestableStream : Stream
{
internal readonly MemoryStream ReadStream = new();
internal readonly MemoryStream WriteStream = new();
public override bool CanRead => true;
public override bool CanSeek => false;
public override bool CanWrite => true;
public override long Length { get { throw new NotImplementedException(); } }
public override long Position
{
get { throw new NotImplementedException(); }
set { throw new NotImplementedException(); }
}
public override void Flush()
{
}
public override int Read(byte[] buffer, int offset, int count)
{
return ReadStream.Read(buffer, offset, count);
}
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return ReadStream.ReadAsync(buffer, offset, count, cancellationToken);
}
public override long Seek(long offset, SeekOrigin origin)
{
throw new NotImplementedException();
}
public override void SetLength(long value)
{
throw new NotImplementedException();
}
public override void Write(byte[] buffer, int offset, int count)
{
WriteStream.Write(buffer, offset, count);
}
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return WriteStream.WriteAsync(buffer, offset, count, cancellationToken);
}
}
}
}
|