// File: src\Servers\Kestrel\test\FunctionalTests\UnixDomainSocketsTests.cs
// Web Access
// Project: src\src\Servers\Kestrel\test\Sockets.FunctionalTests\Sockets.FunctionalTests.csproj (Sockets.FunctionalTests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Buffers;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Net.Sockets;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.FunctionalTests;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Testing;
using Xunit;
 
#if SOCKETS
namespace Microsoft.AspNetCore.Server.Kestrel.Sockets.FunctionalTests;
#else
namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests;
#endif
 
public class UnixDomainSocketsTest : TestApplicationErrorLoggerLoggedTest
{
    /// <summary>
    /// Starts Kestrel listening on a Unix domain socket with a raw connection-level echo
    /// handler, sends "Hello World" over a client socket, and asserts the same bytes come back.
    /// It then waits for the server-side handler to finish before stopping the host.
    /// </summary>
    [MinimumOSVersion(OperatingSystems.Windows, WindowsVersions.Win10_RS4)]
    [ConditionalFact]
    [CollectDump]
    public async Task TestUnixDomainSocket()
    {
        // GetTempFileName creates the file; remove it so the socket can be bound at that path.
        var path = Path.GetTempFileName();

        Delete(path);

        try
        {
            var serverConnectionCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

            async Task EchoServer(ConnectionContext connection)
            {
                // For graceful shutdown: its ConnectionClosedRequested token cancels the pending read.
                var notificationFeature = connection.Features.Get<IConnectionLifetimeNotificationFeature>();
                var input = connection.Transport.Input;
                var output = connection.Transport.Output;

                try
                {
                    while (true)
                    {
                        var result = await input.ReadAsync(notificationFeature.ConnectionClosedRequested);
                        var buffer = result.Buffer;

                        // Echo whatever arrived *before* looking at IsCompleted: a read can carry
                        // data together with the completion signal, and skipping it would lose bytes.
                        if (!buffer.IsEmpty)
                        {
                            await output.WriteAsync(buffer.ToArray());
                        }

                        input.AdvanceTo(buffer.End);

                        if (result.IsCompleted)
                        {
                            Logger.LogDebug("Application receive loop ending for connection {connectionId}.", connection.ConnectionId);
                            break;
                        }
                    }
                }
                catch (OperationCanceledException)
                {
                    Logger.LogDebug("Graceful shutdown triggered for {connectionId}.", connection.ConnectionId);
                }
                finally
                {
                    serverConnectionCompletedTcs.TrySetResult();
                }
            }

            var hostBuilder = TransportSelector.GetHostBuilder()
                .ConfigureWebHost(webHostBuilder =>
                {
                    webHostBuilder
                        .UseKestrel(o =>
                        {
                            o.ListenUnixSocket(path, builder =>
                            {
                                builder.Run(EchoServer);
                            });
                        })
                        .Configure(c => { });
                })
                .ConfigureServices(AddTestLogging);

            using (var host = hostBuilder.Build())
            {
                await host.StartAsync().DefaultTimeout();

                using (var socket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
                {
                    await socket.ConnectAsync(new UnixDomainSocketEndPoint(path)).DefaultTimeout();

                    var data = Encoding.ASCII.GetBytes("Hello World");
                    await socket.SendAsync(data, SocketFlags.None).DefaultTimeout();

                    // Keep receiving until the whole echo has arrived or the peer closes early.
                    var buffer = new byte[data.Length];
                    var read = 0;
                    while (read < data.Length)
                    {
                        var bytesReceived = await socket.ReceiveAsync(buffer.AsMemory(read, buffer.Length - read), SocketFlags.None).DefaultTimeout();
                        if (bytesReceived <= 0)
                        {
                            break;
                        }

                        read += bytesReceived;
                    }

                    Assert.Equal(data, buffer);
                }

                // Wait for the server to complete the loop because of the FIN
                // (sent when the client socket above was disposed).
                await serverConnectionCompletedTcs.Task.DefaultTimeout();

                await host.StopAsync().DefaultTimeout();
            }
        }
        finally
        {
            Delete(path);
        }
    }

    /// <summary>
    /// Configures Kestrel through an <c>http://unix:/...</c> URL and issues a raw HTTP/1.1
    /// request over a Unix domain socket. It asserts that the response status code is 200.
    /// </summary>
    [MinimumOSVersion(OperatingSystems.Windows, WindowsVersions.Win10_RS4)]
    [ConditionalFact]
    [CollectDump]
    public async Task TestUnixDomainSocketWithUrl()
    {
        var path = Path.GetTempFileName();
        var url = $"http://unix:/{path}";

        Delete(path);

        try
        {
            var hostBuilder = TransportSelector.GetHostBuilder()
                .ConfigureWebHost(webHostBuilder =>
                {
                    webHostBuilder
                        .UseUrls(url)
                        .UseKestrel()
                        .Configure(app =>
                        {
                            app.Run(async context =>
                            {
                                await context.Response.WriteAsync("Hello World");
                            });
                        });
                })
                .ConfigureServices(AddTestLogging);

            using (var host = hostBuilder.Build())
            {
                await host.StartAsync().DefaultTimeout();

                // https://github.com/dotnet/corefx/issues/5999
                // .NET Core HttpClient does not support unix sockets, so it is difficult to parse raw response data. Below is a somewhat hacky workaround.
                using (var socket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
                {
                    await socket.ConnectAsync(new UnixDomainSocketEndPoint(path)).DefaultTimeout();

                    var httpRequest = Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\nConnection: close\r\n\r\n");
                    await socket.SendAsync(httpRequest, SocketFlags.None).DefaultTimeout();

                    // "Connection: close" means the server closes after responding, so read until EOF.
                    // A growable buffer avoids silently truncating responses larger than a fixed array.
                    var response = new ArrayBufferWriter<byte>();
                    while (true)
                    {
                        var bytesReceived = await socket.ReceiveAsync(response.GetMemory(512), SocketFlags.None).DefaultTimeout();
                        if (bytesReceived <= 0)
                        {
                            break;
                        }

                        response.Advance(bytesReceived);
                    }

                    // The status code sits between the first and second spaces of the status line.
                    var httpResponse = Encoding.ASCII.GetString(response.WrittenSpan);
                    int httpStatusStart = httpResponse.IndexOf(' ') + 1;
                    Assert.False(httpStatusStart == 0, $"Space not found in '{httpResponse}'.");
                    int httpStatusEnd = httpResponse.IndexOf(' ', httpStatusStart);
                    Assert.False(httpStatusEnd == -1, $"Second space not found in '{httpResponse}'.");

                    var httpStatus = int.Parse(httpResponse.Substring(httpStatusStart, httpStatusEnd - httpStatusStart), CultureInfo.InvariantCulture);
                    Assert.Equal(StatusCodes.Status200OK, httpStatus);
                }

                await host.StopAsync().DefaultTimeout();
            }
        }
        finally
        {
            Delete(path);
        }
    }

    /// <summary>
    /// Deletes <paramref name="path"/> and treats "already gone" as success.
    /// </summary>
    /// <remarks>
    /// <see cref="File.Delete(string)"/> is already a no-op when the file itself is missing.
    /// It reports absence only as <see cref="DirectoryNotFoundException"/>, when a parent directory
    /// does not exist, so that case is tolerated here as well. Other I/O failures still propagate.
    /// </remarks>
    private static void Delete(string path)
    {
        try
        {
            File.Delete(path);
        }
        catch (Exception ex) when (ex is FileNotFoundException or DirectoryNotFoundException)
        {
            // Nothing to clean up.
        }
    }
}