|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Net;
using System.Net.Http;
using System.Net.WebSockets;
using System.Text;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Internal;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Net.Http.Headers;
namespace Microsoft.AspNetCore.TestHost;
public class TestClientTests
{
[Fact]
public async Task GetAsyncWorks()
{
// Arrange
var expected = "GET Response";
RequestDelegate appDelegate = ctx =>
ctx.Response.WriteAsync(expected);
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
// Act
var actual = await client.GetStringAsync("http://localhost:12345");
// Assert
Assert.Equal(expected, actual);
}
[Fact]
public async Task NoTrailingSlash_NoPathBase()
{
// Arrange
var expected = "GET Response";
RequestDelegate appDelegate = ctx =>
{
Assert.Equal("", ctx.Request.PathBase.Value);
Assert.Equal("/", ctx.Request.Path.Value);
return ctx.Response.WriteAsync(expected);
};
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
// Act
var actual = await client.GetStringAsync("http://localhost:12345");
// Assert
Assert.Equal(expected, actual);
}
[Fact]
public async Task SingleTrailingSlash_NoPathBase()
{
// Arrange
var expected = "GET Response";
RequestDelegate appDelegate = ctx =>
{
Assert.Equal("", ctx.Request.PathBase.Value);
Assert.Equal("/", ctx.Request.Path.Value);
return ctx.Response.WriteAsync(expected);
};
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
// Act
var actual = await client.GetStringAsync("http://localhost:12345/");
// Assert
Assert.Equal(expected, actual);
}
[Fact]
public async Task PutAsyncWorks()
{
// Arrange
RequestDelegate appDelegate = async ctx =>
{
var content = await new StreamReader(ctx.Request.Body).ReadToEndAsync();
await ctx.Response.WriteAsync(content + " PUT Response");
};
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
// Act
var content = new StringContent("Hello world");
var response = await client.PutAsync("http://localhost:12345", content).DefaultTimeout();
// Assert
Assert.Equal("Hello world PUT Response", await response.Content.ReadAsStringAsync().DefaultTimeout());
}
[Fact]
public async Task PostAsyncWorks()
{
// Arrange
RequestDelegate appDelegate = async ctx =>
await ctx.Response.WriteAsync(await new StreamReader(ctx.Request.Body).ReadToEndAsync() + " POST Response");
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
// Act
var content = new StringContent("Hello world");
var response = await client.PostAsync("http://localhost:12345", content).DefaultTimeout();
// Assert
Assert.Equal("Hello world POST Response", await response.Content.ReadAsStringAsync().DefaultTimeout());
}
[Fact]
public async Task LargePayload_DisposesRequest_AfterResponseIsCompleted()
{
// Arrange
var data = new byte[2048];
var character = Encoding.ASCII.GetBytes("a");
for (var i = 0; i < data.Length; i++)
{
data[i] = character[0];
}
var builder = new WebHostBuilder();
RequestDelegate app = async ctx =>
{
var disposable = new TestDisposable();
ctx.Response.RegisterForDispose(disposable);
await ctx.Response.Body.WriteAsync(data, 0, 1024);
Assert.False(disposable.IsDisposed);
await ctx.Response.Body.WriteAsync(data, 1024, 1024);
};
builder.Configure(appBuilder => appBuilder.Run(app));
var server = new TestServer(builder);
var client = server.CreateClient();
// Act & Assert
var response = await client.GetAsync("http://localhost:12345");
}
private class TestDisposable : IDisposable
{
public bool IsDisposed { get; private set; }
public void Dispose()
{
IsDisposed = true;
}
}
[Fact]
public async Task ClientStreamingWorks()
{
// Arrange
var responseStartedSyncPoint = new SyncPoint();
var requestEndingSyncPoint = new SyncPoint();
var requestStreamSyncPoint = new SyncPoint();
RequestDelegate appDelegate = async ctx =>
{
// Send headers
await ctx.Response.BodyWriter.FlushAsync();
// Ensure headers received by client
await responseStartedSyncPoint.WaitToContinue();
await ctx.Response.WriteAsync("STARTED");
// ReadToEndAsync will wait until request body is complete
var requestString = await new StreamReader(ctx.Request.Body).ReadToEndAsync();
await ctx.Response.WriteAsync(requestString + " POST Response");
await requestEndingSyncPoint.WaitToContinue();
};
Stream requestStream = null;
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
var httpRequest = new HttpRequestMessage(HttpMethod.Post, "http://localhost:12345");
httpRequest.Version = new Version(2, 0);
httpRequest.Content = new PushContent(async stream =>
{
requestStream = stream;
await requestStreamSyncPoint.WaitToContinue();
});
// Act
var response = await client.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead).DefaultTimeout();
await responseStartedSyncPoint.WaitForSyncPoint().DefaultTimeout();
responseStartedSyncPoint.Continue();
var responseContent = await response.Content.ReadAsStreamAsync().DefaultTimeout();
// Assert
// Ensure request stream has started
await requestStreamSyncPoint.WaitForSyncPoint();
byte[] buffer = new byte[1024];
var length = await responseContent.ReadAsync(buffer).AsTask().DefaultTimeout();
Assert.Equal("STARTED", Encoding.UTF8.GetString(buffer, 0, length));
// Send content and finish request body
await requestStream.WriteAsync(Encoding.UTF8.GetBytes("Hello world")).AsTask().DefaultTimeout();
await requestStream.FlushAsync().DefaultTimeout();
requestStreamSyncPoint.Continue();
// Ensure content is received while request is in progress
length = await responseContent.ReadAsync(buffer).AsTask().DefaultTimeout();
Assert.Equal("Hello world POST Response", Encoding.UTF8.GetString(buffer, 0, length));
// Request is ending
await requestEndingSyncPoint.WaitForSyncPoint().DefaultTimeout();
requestEndingSyncPoint.Continue();
// No more response content
length = await responseContent.ReadAsync(buffer).AsTask().DefaultTimeout();
Assert.Equal(0, length);
}
[Fact]
public async Task ClientStreaming_HttpContentException()
{
var requestCount = 0;
RequestDelegate appDelegate = ctx =>
{
requestCount++;
return Task.CompletedTask;
};
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
var message = new HttpRequestMessage(HttpMethod.Post, "https://example.com/");
message.Content = new PushContent(stream => throw new InvalidOperationException("HttpContent exception"));
var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => client.SendAsync(message, CancellationToken.None));
Assert.Equal("HttpContent exception", ex.Message);
Assert.Equal(0, requestCount);
}
[Fact]
public async Task ClientStreaming_Cancellation()
{
// Arrange
var responseStartedSyncPoint = new SyncPoint();
var responseReadSyncPoint = new SyncPoint();
var responseEndingSyncPoint = new SyncPoint();
var requestStreamSyncPoint = new SyncPoint();
var readCanceled = false;
RequestDelegate appDelegate = async ctx =>
{
// Send headers
await ctx.Response.BodyWriter.FlushAsync();
// Ensure headers received by client
await responseStartedSyncPoint.WaitToContinue();
var serverBuffer = new byte[1024];
var serverLength = await ctx.Request.Body.ReadAsync(serverBuffer);
Assert.Equal("SENT", Encoding.UTF8.GetString(serverBuffer, 0, serverLength));
await responseReadSyncPoint.WaitToContinue();
try
{
await ctx.Request.Body.ReadAsync(serverBuffer);
}
catch (OperationCanceledException)
{
readCanceled = true;
}
await responseEndingSyncPoint.WaitToContinue();
};
Stream requestStream = null;
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
var httpRequest = new HttpRequestMessage(HttpMethod.Post, "http://localhost:12345");
httpRequest.Version = new Version(2, 0);
httpRequest.Content = new PushContent(async stream =>
{
requestStream = stream;
await requestStreamSyncPoint.WaitToContinue();
});
// Act
var response = await client.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead).DefaultTimeout();
await responseStartedSyncPoint.WaitForSyncPoint().DefaultTimeout();
responseStartedSyncPoint.Continue();
var responseContent = await response.Content.ReadAsStreamAsync().DefaultTimeout();
// Assert
// Ensure request stream has started
await requestStreamSyncPoint.WaitForSyncPoint();
// Write to request
await requestStream.WriteAsync(Encoding.UTF8.GetBytes("SENT")).AsTask().DefaultTimeout();
await requestStream.FlushAsync().DefaultTimeout();
await responseReadSyncPoint.WaitForSyncPoint().DefaultTimeout();
// Cancel request. Disposing response must be used because SendAsync has finished.
response.Dispose();
responseReadSyncPoint.Continue();
await responseEndingSyncPoint.WaitForSyncPoint().DefaultTimeout();
responseEndingSyncPoint.Continue();
Assert.True(readCanceled);
requestStreamSyncPoint.Continue();
}
[Fact]
public async Task ClientStreaming_ResponseCompletesWithoutReadingRequest()
{
// Arrange
var requestStreamTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
var responseEndingSyncPoint = new SyncPoint();
RequestDelegate appDelegate = async ctx =>
{
await ctx.Response.WriteAsync("POST Response");
await responseEndingSyncPoint.WaitToContinue();
};
Stream requestStream = null;
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
var httpRequest = new HttpRequestMessage(HttpMethod.Post, "http://localhost:12345");
httpRequest.Version = new Version(2, 0);
httpRequest.Content = new PushContent(async stream =>
{
requestStream = stream;
await requestStreamTcs.Task;
});
// Act
var response = await client.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead).DefaultTimeout();
var responseContent = await response.Content.ReadAsStreamAsync().DefaultTimeout();
// Assert
// Read response
byte[] buffer = new byte[1024];
var length = await responseContent.ReadAsync(buffer).AsTask().DefaultTimeout();
Assert.Equal("POST Response", Encoding.UTF8.GetString(buffer, 0, length));
// Send large content and block on back pressure
var writeTask = Task.Run(async () =>
{
try
{
await requestStream.WriteAsync(Encoding.UTF8.GetBytes(new string('!', 1024 * 1024 * 50))).AsTask().DefaultTimeout();
requestStreamTcs.SetResult();
}
catch (Exception ex)
{
requestStreamTcs.SetException(ex);
}
});
responseEndingSyncPoint.Continue();
// No more response content
length = await responseContent.ReadAsync(buffer).AsTask().DefaultTimeout();
Assert.Equal(0, length);
await writeTask;
}
[Fact]
public async Task ClientStreaming_ResponseCompletesWithPendingRead_ThrowError()
{
// Arrange
var requestStreamTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
RequestDelegate appDelegate = async ctx =>
{
var pendingReadTask = ctx.Request.Body.ReadAsync(new byte[1024], 0, 1024);
ctx.Response.Headers["test-header"] = "true";
await ctx.Response.Body.FlushAsync();
};
Stream requestStream = null;
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
var httpRequest = new HttpRequestMessage(HttpMethod.Post, "http://localhost:12345");
httpRequest.Version = new Version(2, 0);
httpRequest.Content = new PushContent(async stream =>
{
requestStream = stream;
await requestStreamTcs.Task;
});
// Act
var response = await client.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead).DefaultTimeout();
var responseContent = await response.Content.ReadAsStreamAsync().DefaultTimeout();
// Assert
response.EnsureSuccessStatusCode();
Assert.Equal("true", response.Headers.GetValues("test-header").Single());
// Read response
var ex = await Assert.ThrowsAsync<IOException>(async () =>
{
byte[] buffer = new byte[1024];
var length = await responseContent.ReadAsync(buffer).AsTask().DefaultTimeout();
});
Assert.Equal("An error occurred when completing the request. Request delegate may have finished while there is a pending read of the request body.", ex.InnerException.Message);
// Unblock request
requestStreamTcs.TrySetResult();
}
[Fact]
public async Task ClientStreaming_ResponseCompletesWithoutResponseBodyWrite()
{
// Arrange
var requestStreamTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
RequestDelegate appDelegate = ctx =>
{
ctx.Response.Headers["test-header"] = "true";
return Task.CompletedTask;
};
Stream requestStream = null;
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
var httpRequest = new HttpRequestMessage(HttpMethod.Post, "http://localhost:12345");
httpRequest.Version = new Version(2, 0);
httpRequest.Content = new PushContent(async stream =>
{
requestStream = stream;
await requestStreamTcs.Task;
});
// Act
var response = await client.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead).DefaultTimeout();
var responseContent = await response.Content.ReadAsStreamAsync().DefaultTimeout();
// Assert
response.EnsureSuccessStatusCode();
Assert.Equal("true", response.Headers.GetValues("test-header").Single());
// Read response
byte[] buffer = new byte[1024];
var length = await responseContent.ReadAsync(buffer).AsTask().DefaultTimeout();
Assert.Equal(0, length);
// Writing to request stream will fail because server is complete
await Assert.ThrowsAnyAsync<Exception>(() => requestStream.WriteAsync(buffer).AsTask());
// Unblock request
requestStreamTcs.TrySetResult();
}
[Fact]
public async Task ClientStreaming_ServerAbort()
{
// Arrange
var requestStreamSyncPoint = new SyncPoint();
var responseEndingSyncPoint = new SyncPoint();
RequestDelegate appDelegate = async ctx =>
{
// Send headers
await ctx.Response.BodyWriter.FlushAsync();
ctx.Abort();
await responseEndingSyncPoint.WaitToContinue();
};
Stream requestStream = null;
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
var httpRequest = new HttpRequestMessage(HttpMethod.Post, "http://localhost:12345");
httpRequest.Version = new Version(2, 0);
httpRequest.Content = new PushContent(async stream =>
{
requestStream = stream;
await requestStreamSyncPoint.WaitToContinue();
});
// Act
var response = await client.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead).DefaultTimeout();
var responseContent = await response.Content.ReadAsStreamAsync().DefaultTimeout();
// Assert
// Ensure server has aborted
await responseEndingSyncPoint.WaitForSyncPoint();
// Ensure request stream has started
await requestStreamSyncPoint.WaitForSyncPoint();
// Send content and finish request body
await ExceptionAssert.ThrowsAsync<OperationCanceledException>(
() => requestStream.WriteAsync(Encoding.UTF8.GetBytes("Hello world")).AsTask(),
"Flush was canceled on underlying PipeWriter.").DefaultTimeout();
responseEndingSyncPoint.Continue();
requestStreamSyncPoint.Continue();
}
private class PushContent : HttpContent
{
private readonly Func<Stream, Task> _sendContent;
public PushContent(Func<Stream, Task> sendContent)
{
_sendContent = sendContent;
}
protected override Task SerializeToStreamAsync(Stream stream, TransportContext context)
{
return _sendContent(stream);
}
protected override bool TryComputeLength(out long length)
{
length = -1;
return false;
}
}
[Fact]
public async Task WebSocketWorks()
{
// Arrange
// This logger will attempt to access information from HttpRequest once the HttpContext is created
var logger = new VerifierLogger();
RequestDelegate appDelegate = async ctx =>
{
if (ctx.WebSockets.IsWebSocketRequest)
{
Assert.False(ctx.Request.Headers.ContainsKey(HeaderNames.SecWebSocketProtocol));
var websocket = await ctx.WebSockets.AcceptWebSocketAsync();
var receiveArray = new byte[1024];
while (true)
{
var receiveResult = await websocket.ReceiveAsync(new System.ArraySegment<byte>(receiveArray), CancellationToken.None);
if (receiveResult.MessageType == WebSocketMessageType.Close)
{
await websocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Normal Closure", CancellationToken.None);
break;
}
else
{
var sendBuffer = new System.ArraySegment<byte>(receiveArray, 0, receiveResult.Count);
await websocket.SendAsync(sendBuffer, receiveResult.MessageType, receiveResult.EndOfMessage, CancellationToken.None);
}
}
}
};
var builder = new WebHostBuilder()
.ConfigureServices(services =>
{
services.AddSingleton<ILogger<IWebHost>>(logger);
})
.Configure(app =>
{
app.Run(appDelegate);
});
var server = new TestServer(builder);
// Act
var client = server.CreateWebSocketClient();
// The HttpContext will be created and the logger will make sure that the HttpRequest exists and contains reasonable values
var clientSocket = await client.ConnectAsync(new System.Uri("http://localhost"), CancellationToken.None);
var hello = Encoding.UTF8.GetBytes("hello");
await clientSocket.SendAsync(new System.ArraySegment<byte>(hello), WebSocketMessageType.Text, true, CancellationToken.None);
var world = Encoding.UTF8.GetBytes("world!");
await clientSocket.SendAsync(new System.ArraySegment<byte>(world), WebSocketMessageType.Binary, true, CancellationToken.None);
await clientSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Normal Closure", CancellationToken.None);
// Assert
Assert.Equal(WebSocketState.CloseSent, clientSocket.State);
var buffer = new byte[1024];
var result = await clientSocket.ReceiveAsync(new System.ArraySegment<byte>(buffer), CancellationToken.None);
Assert.Equal(hello.Length, result.Count);
Assert.True(hello.SequenceEqual(buffer.Take(hello.Length)));
Assert.Equal(WebSocketMessageType.Text, result.MessageType);
result = await clientSocket.ReceiveAsync(new System.ArraySegment<byte>(buffer), CancellationToken.None);
Assert.Equal(world.Length, result.Count);
Assert.True(world.SequenceEqual(buffer.Take(world.Length)));
Assert.Equal(WebSocketMessageType.Binary, result.MessageType);
result = await clientSocket.ReceiveAsync(new System.ArraySegment<byte>(buffer), CancellationToken.None);
Assert.Equal(WebSocketMessageType.Close, result.MessageType);
Assert.Equal(WebSocketState.Closed, clientSocket.State);
clientSocket.Dispose();
}
[Fact]
public async Task WebSocketSubProtocolsWorks()
{
// Arrange
RequestDelegate appDelegate = async ctx =>
{
if (ctx.WebSockets.IsWebSocketRequest)
{
if (ctx.WebSockets.WebSocketRequestedProtocols.Contains("alpha") &&
ctx.WebSockets.WebSocketRequestedProtocols.Contains("bravo"))
{
// according to rfc6455, the "server needs to include the same field and one of the selected subprotocol values"
// however, this isn't enforced by either our server or client so it's possible to accept an arbitrary protocol.
// Done here to demonstrate not "correct" behaviour, simply to show it's possible. Other clients may not allow this.
var websocket = await ctx.WebSockets.AcceptWebSocketAsync("charlie");
await websocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Normal Closure", CancellationToken.None);
}
else
{
var subprotocols = ctx.WebSockets.WebSocketRequestedProtocols.Any()
? string.Join(", ", ctx.WebSockets.WebSocketRequestedProtocols)
: "<none>";
var closeReason = "Unexpected subprotocols: " + subprotocols;
var websocket = await ctx.WebSockets.AcceptWebSocketAsync();
await websocket.CloseAsync(WebSocketCloseStatus.InternalServerError, closeReason, CancellationToken.None);
}
}
};
var builder = new WebHostBuilder()
.Configure(app =>
{
app.Run(appDelegate);
});
var server = new TestServer(builder);
// Act
var client = server.CreateWebSocketClient();
client.SubProtocols.Add("alpha");
client.SubProtocols.Add("bravo");
var clientSocket = await client.ConnectAsync(new Uri("wss://localhost"), CancellationToken.None);
var buffer = new byte[1024];
var result = await clientSocket.ReceiveAsync(new ArraySegment<byte>(buffer), CancellationToken.None);
// Assert
Assert.Equal(WebSocketMessageType.Close, result.MessageType);
Assert.Equal("Normal Closure", result.CloseStatusDescription);
Assert.Equal(WebSocketState.CloseReceived, clientSocket.State);
Assert.Equal("charlie", clientSocket.SubProtocol);
clientSocket.Dispose();
}
[ConditionalFact]
public async Task WebSocketAcceptThrowsWhenCancelled()
{
// Arrange
// This logger will attempt to access information from HttpRequest once the HttpContext is created
var logger = new VerifierLogger();
RequestDelegate appDelegate = async ctx =>
{
if (ctx.WebSockets.IsWebSocketRequest)
{
var websocket = await ctx.WebSockets.AcceptWebSocketAsync();
var receiveArray = new byte[1024];
while (true)
{
var receiveResult = await websocket.ReceiveAsync(new ArraySegment<byte>(receiveArray), CancellationToken.None);
if (receiveResult.MessageType == WebSocketMessageType.Close)
{
await websocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Normal Closure", CancellationToken.None);
break;
}
else
{
var sendBuffer = new System.ArraySegment<byte>(receiveArray, 0, receiveResult.Count);
await websocket.SendAsync(sendBuffer, receiveResult.MessageType, receiveResult.EndOfMessage, CancellationToken.None);
}
}
}
};
var builder = new WebHostBuilder()
.ConfigureServices(services => services.AddSingleton<ILogger<IWebHost>>(logger))
.Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
// Act
var client = server.CreateWebSocketClient();
var tokenSource = new CancellationTokenSource();
tokenSource.Cancel();
// Assert
await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await client.ConnectAsync(new Uri("http://localhost"), tokenSource.Token));
}
private class VerifierLogger : ILogger<IWebHost>
{
public IDisposable BeginScope<TState>(TState state) => new NoopDispoasble();
public bool IsEnabled(LogLevel logLevel) => true;
// This call verifies that fields of HttpRequest are accessed and valid
public void Log<TState>(LogLevel logLevel, EventId eventId, TState state, Exception exception, Func<TState, Exception, string> formatter) => formatter(state, exception);
class NoopDispoasble : IDisposable
{
public void Dispose()
{
}
}
}
[Fact]
public async Task WebSocketDisposalThrowsOnPeer()
{
// Arrange
RequestDelegate appDelegate = async ctx =>
{
if (ctx.WebSockets.IsWebSocketRequest)
{
var websocket = await ctx.WebSockets.AcceptWebSocketAsync();
websocket.Dispose();
}
};
var builder = new WebHostBuilder().Configure(app =>
{
app.Run(appDelegate);
});
var server = new TestServer(builder);
// Act
var client = server.CreateWebSocketClient();
var clientSocket = await client.ConnectAsync(new System.Uri("http://localhost"), CancellationToken.None);
var buffer = new byte[1024];
await Assert.ThrowsAsync<IOException>(async () => await clientSocket.ReceiveAsync(new System.ArraySegment<byte>(buffer), CancellationToken.None));
clientSocket.Dispose();
}
[Fact]
public async Task WebSocketTinyReceiveGeneratesEndOfMessage()
{
// Arrange
RequestDelegate appDelegate = async ctx =>
{
if (ctx.WebSockets.IsWebSocketRequest)
{
var websocket = await ctx.WebSockets.AcceptWebSocketAsync();
var receiveArray = new byte[1024];
while (true)
{
var receiveResult = await websocket.ReceiveAsync(new System.ArraySegment<byte>(receiveArray), CancellationToken.None);
var sendBuffer = new System.ArraySegment<byte>(receiveArray, 0, receiveResult.Count);
await websocket.SendAsync(sendBuffer, receiveResult.MessageType, receiveResult.EndOfMessage, CancellationToken.None);
}
}
};
var builder = new WebHostBuilder().Configure(app =>
{
app.Run(appDelegate);
});
var server = new TestServer(builder);
// Act
var client = server.CreateWebSocketClient();
var clientSocket = await client.ConnectAsync(new System.Uri("http://localhost"), CancellationToken.None);
var hello = Encoding.UTF8.GetBytes("hello");
await clientSocket.SendAsync(new System.ArraySegment<byte>(hello), WebSocketMessageType.Text, true, CancellationToken.None);
// Assert
var buffer = new byte[1];
for (var i = 0; i < hello.Length; i++)
{
bool last = i == (hello.Length - 1);
var result = await clientSocket.ReceiveAsync(new System.ArraySegment<byte>(buffer), CancellationToken.None);
Assert.Equal(buffer.Length, result.Count);
Assert.Equal(buffer[0], hello[i]);
Assert.Equal(last, result.EndOfMessage);
}
clientSocket.Dispose();
}
[Fact]
public async Task ClientDisposalAbortsRequest()
{
// Arrange
var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
RequestDelegate appDelegate = async ctx =>
{
// Write Headers
await ctx.Response.Body.FlushAsync();
var sem = new SemaphoreSlim(0);
try
{
await sem.WaitAsync(ctx.RequestAborted);
}
catch (Exception e)
{
tcs.SetException(e);
}
};
// Act
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
var request = new HttpRequestMessage(HttpMethod.Get, "http://localhost:12345");
var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
// Abort Request
response.Dispose();
// Assert
var exception = await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await tcs.Task);
}
[Fact]
public async Task ClientCancellationAbortsRequest()
{
var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
var builder = new WebHostBuilder().Configure(app => app.Run(async ctx =>
{
try
{
await Task.Delay(TimeSpan.FromSeconds(30), ctx.RequestAborted);
tcs.SetResult();
}
catch (Exception e)
{
tcs.SetException(e);
return;
}
throw new InvalidOperationException("The request was not aborted");
}));
using var server = new TestServer(builder);
using var client = server.CreateClient();
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(1));
var response = await Assert.ThrowsAnyAsync<OperationCanceledException>(() => client.GetAsync("http://localhost:12345", cts.Token));
var exception = await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await tcs.Task);
}
[Fact]
public async Task AsyncLocalValueOnClientIsNotPreserved()
{
var asyncLocal = new AsyncLocal<object>();
var value = new object();
asyncLocal.Value = value;
object capturedValue = null;
var builder = new WebHostBuilder()
.Configure(app =>
{
app.Run((context) =>
{
capturedValue = asyncLocal.Value;
return context.Response.WriteAsync("Done");
});
});
var server = new TestServer(builder);
var client = server.CreateClient();
var resp = await client.GetAsync("/");
Assert.NotSame(value, capturedValue);
}
[Fact]
public async Task AsyncLocalValueOnClientIsPreservedIfPreserveExecutionContextIsTrue()
{
var asyncLocal = new AsyncLocal<object>();
var value = new object();
asyncLocal.Value = value;
object capturedValue = null;
var builder = new WebHostBuilder()
.Configure(app =>
{
app.Run((context) =>
{
capturedValue = asyncLocal.Value;
return context.Response.WriteAsync("Done");
});
});
var server = new TestServer(builder)
{
PreserveExecutionContext = true
};
var client = server.CreateClient();
var resp = await client.GetAsync("/");
Assert.Same(value, capturedValue);
}
[Fact]
public async Task SendAsync_Default_Protocol11()
{
// Arrange
string protocol = null;
var expected = "GET Response";
RequestDelegate appDelegate = async ctx =>
{
protocol = ctx.Request.Protocol;
await ctx.Response.WriteAsync(expected);
};
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
var request = new HttpRequestMessage(HttpMethod.Get, "http://localhost:12345");
// Act
var message = await client.SendAsync(request);
var actual = await message.Content.ReadAsStringAsync();
// Assert
Assert.Equal(expected, actual);
Assert.Equal(new Version(1, 1), message.Version);
Assert.Equal(protocol, HttpProtocol.Http11);
}
[Fact]
public async Task SendAsync_ExplicitlySet_Protocol20()
{
// Arrange
string protocol = null;
var expected = "GET Response";
RequestDelegate appDelegate = async ctx =>
{
protocol = ctx.Request.Protocol;
await ctx.Response.WriteAsync(expected);
};
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
var request = new HttpRequestMessage(HttpMethod.Get, "http://localhost:12345");
request.Version = new Version(2, 0);
// Act
var message = await client.SendAsync(request);
var actual = await message.Content.ReadAsStringAsync();
// Assert
Assert.Equal(expected, actual);
Assert.Equal(new Version(2, 0), message.Version);
Assert.Equal(protocol, HttpProtocol.Http2);
}
[Fact]
public async Task SendAsync_ExplicitlySet_Protocol30()
{
// Arrange
string protocol = null;
var expected = "GET Response";
RequestDelegate appDelegate = async ctx =>
{
protocol = ctx.Request.Protocol;
await ctx.Response.WriteAsync(expected);
};
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
var request = new HttpRequestMessage(HttpMethod.Get, "http://localhost:12345");
request.Version = new Version(3, 0);
// Act
var message = await client.SendAsync(request);
var actual = await message.Content.ReadAsStringAsync();
// Assert
Assert.Equal(expected, actual);
Assert.Equal(new Version(3, 0), message.Version);
Assert.Equal(protocol, HttpProtocol.Http3);
}
[Fact]
public async Task VerifyWebSocketAndUpgradeFeaturesForNonWebSocket()
{
using (var testServer = new TestServer(new WebHostBuilder()
.Configure(app =>
{
app.UseWebSockets();
app.Run(async c =>
{
var upgradeFeature = c.Features.Get<IHttpUpgradeFeature>();
// Feature needs to exist for SignalR to verify that the server supports WebSockets
Assert.NotNull(upgradeFeature);
Assert.False(upgradeFeature.IsUpgradableRequest);
await Assert.ThrowsAsync<NotSupportedException>(() => upgradeFeature.UpgradeAsync());
var webSocketFeature = c.Features.Get<IHttpWebSocketFeature>();
Assert.NotNull(webSocketFeature);
Assert.False(webSocketFeature.IsWebSocketRequest);
await c.Response.WriteAsync("test");
});
})))
{
var client = testServer.CreateClient();
var actual = await client.GetStringAsync("http://localhost:12345/");
Assert.Equal("test", actual);
}
}
}
|