File: ServiceEndpointResolverTests.cs
Web Access
Project: src\tests\Microsoft.Extensions.ServiceDiscovery.Tests\Microsoft.Extensions.ServiceDiscovery.Tests.csproj (Microsoft.Extensions.ServiceDiscovery.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics.CodeAnalysis;
using System.Net;
using System.Threading.Channels;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Primitives;
using Microsoft.Extensions.ServiceDiscovery.Http;
using Microsoft.Extensions.ServiceDiscovery.Internal;
using Xunit;
 
namespace Microsoft.Extensions.ServiceDiscovery.Tests;
 
/// <summary>
/// Tests for <see cref="ServiceEndpointWatcherFactory"/> and <see cref="ServiceEndpointWatcher"/>.
/// </summary>
public class ServiceEndpointResolverTests
{
    [Fact]
    public void ResolveServiceEndpoint_NoProvidersConfigured_Throws()
    {
        var services = new ServiceCollection()
            .AddServiceDiscoveryCore()
            .BuildServiceProvider();
        var resolverFactory = services.GetRequiredService<ServiceEndpointWatcherFactory>();
        var exception = Assert.Throws<InvalidOperationException>(() => resolverFactory.CreateWatcher("https://basket"));
        Assert.Equal("No provider which supports the provided service name, 'https://basket', has been configured.", exception.Message);
    }
 
    [Fact]
    public async Task ServiceEndpointResolver_NoProvidersConfigured_Throws()
    {
        var services = new ServiceCollection()
            .AddServiceDiscoveryCore()
            .BuildServiceProvider();
        var watcher = new ServiceEndpointWatcher([], NullLogger.Instance, "foo", TimeProvider.System, Options.Options.Create(new ServiceDiscoveryOptions()));
        var exception = Assert.Throws<InvalidOperationException>(watcher.Start);
        Assert.Equal("No service endpoint providers are configured.", exception.Message);
        exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await watcher.GetEndpointsAsync());
        Assert.Equal("No service endpoint providers are configured.", exception.Message);
    }
 
    [Fact]
    public void ResolveServiceEndpoint_NullServiceName_Throws()
    {
        var services = new ServiceCollection()
            .AddServiceDiscoveryCore()
            .BuildServiceProvider();
        var resolverFactory = services.GetRequiredService<ServiceEndpointWatcherFactory>();
        Assert.Throws<ArgumentNullException>(() => resolverFactory.CreateWatcher(null!));
    }
 
    [Fact]
    public async Task AddServiceDiscovery_NoProviders_Throws()
    {
        var serviceCollection = new ServiceCollection();
        serviceCollection.AddHttpClient("foo", c => c.BaseAddress = new("http://foo")).AddServiceDiscovery();
        var services = serviceCollection.BuildServiceProvider();
        var client = services.GetRequiredService<IHttpClientFactory>().CreateClient("foo");
        var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await client.GetStringAsync("/"));
        Assert.Equal("No provider which supports the provided service name, 'http://foo', has been configured.", exception.Message);
    }
 
    private sealed class FakeEndpointResolverProvider(Func<ServiceEndpointQuery, (bool Result, IServiceEndpointProvider? Resolver)> createResolverDelegate) : IServiceEndpointProviderFactory
    {
        public bool TryCreateProvider(ServiceEndpointQuery query, [NotNullWhen(true)] out IServiceEndpointProvider? resolver)
        {
            bool result;
            (result, resolver) = createResolverDelegate(query);
            return result;
        }
    }
 
    private sealed class FakeEndpointResolver(Func<IServiceEndpointBuilder, CancellationToken, ValueTask> resolveAsync, Func<ValueTask> disposeAsync) : IServiceEndpointProvider
    {
        public ValueTask PopulateAsync(IServiceEndpointBuilder endpoints, CancellationToken cancellationToken) => resolveAsync(endpoints, cancellationToken);
        public ValueTask DisposeAsync() => disposeAsync();
    }
 
    [Fact]
    public async Task ResolveServiceEndpoint()
    {
        var cts = new[] { new CancellationTokenSource() };
        var innerResolver = new FakeEndpointResolver(
            resolveAsync: (collection, ct) =>
            {
                collection.AddChangeToken(new CancellationChangeToken(cts[0].Token));
                collection.Endpoints.Add(ServiceEndpoint.Create(new IPEndPoint(IPAddress.Parse("127.1.1.1"), 8080)));
 
                if (cts[0].Token.IsCancellationRequested)
                {
                    cts[0] = new();
                    collection.Endpoints.Add(ServiceEndpoint.Create(new IPEndPoint(IPAddress.Parse("127.1.1.2"), 8888)));
                }
                return default;
            },
            disposeAsync: () => default);
        var resolverProvider = new FakeEndpointResolverProvider(name => (true, innerResolver));
        var services = new ServiceCollection()
            .AddSingleton<IServiceEndpointProviderFactory>(resolverProvider)
            .AddServiceDiscoveryCore()
            .BuildServiceProvider();
        var watcherFactory = services.GetRequiredService<ServiceEndpointWatcherFactory>();
 
        ServiceEndpointWatcher watcher;
        await using ((watcher = watcherFactory.CreateWatcher("http://basket")).ConfigureAwait(false))
        {
            Assert.NotNull(watcher);
            var initialResult = await watcher.GetEndpointsAsync(CancellationToken.None);
            Assert.NotNull(initialResult);
            var sep = Assert.Single(initialResult.Endpoints);
            var ip = Assert.IsType<IPEndPoint>(sep.EndPoint);
            Assert.Equal(IPAddress.Parse("127.1.1.1"), ip.Address);
            Assert.Equal(8080, ip.Port);
 
            var tcs = new TaskCompletionSource<ServiceEndpointResolverResult>();
            watcher.OnEndpointsUpdated = tcs.SetResult;
            Assert.False(tcs.Task.IsCompleted);
 
            cts[0].Cancel();
            var resolverResult = await tcs.Task;
            Assert.NotNull(resolverResult);
            Assert.True(resolverResult.ResolvedSuccessfully);
            Assert.Equal(2, resolverResult.EndpointSource.Endpoints.Count);
            var endpoints = resolverResult.EndpointSource.Endpoints.Select(ep => ep.EndPoint).OfType<IPEndPoint>().ToList();
            endpoints.Sort((l, r) => l.Port - r.Port);
            Assert.Equal(new IPEndPoint(IPAddress.Parse("127.1.1.1"), 8080), endpoints[0]);
            Assert.Equal(new IPEndPoint(IPAddress.Parse("127.1.1.2"), 8888), endpoints[1]);
        }
    }
 
    [Fact]
    public async Task ResolveServiceEndpointOneShot()
    {
        var cts = new[] { new CancellationTokenSource() };
        var innerResolver = new FakeEndpointResolver(
            resolveAsync: (collection, ct) =>
            {
                collection.AddChangeToken(new CancellationChangeToken(cts[0].Token));
                collection.Endpoints.Add(ServiceEndpoint.Create(new IPEndPoint(IPAddress.Parse("127.1.1.1"), 8080)));
 
                if (cts[0].Token.IsCancellationRequested)
                {
                    cts[0] = new();
                    collection.Endpoints.Add(ServiceEndpoint.Create(new IPEndPoint(IPAddress.Parse("127.1.1.2"), 8888)));
                }
                return default;
            },
            disposeAsync: () => default);
        var resolverProvider = new FakeEndpointResolverProvider(name => (true, innerResolver));
        var services = new ServiceCollection()
            .AddSingleton<IServiceEndpointProviderFactory>(resolverProvider)
            .AddServiceDiscoveryCore()
            .BuildServiceProvider();
        var resolver = services.GetRequiredService<ServiceEndpointResolver>();
 
        Assert.NotNull(resolver);
        var initialResult = await resolver.GetEndpointsAsync("http://basket", CancellationToken.None);
        Assert.NotNull(initialResult);
        var sep = Assert.Single(initialResult.Endpoints);
        var ip = Assert.IsType<IPEndPoint>(sep.EndPoint);
        Assert.Equal(IPAddress.Parse("127.1.1.1"), ip.Address);
        Assert.Equal(8080, ip.Port);
 
        await services.DisposeAsync();
    }
 
    [Fact]
    public async Task ResolveHttpServiceEndpointOneShot()
    {
        var cts = new[] { new CancellationTokenSource() };
        var innerResolver = new FakeEndpointResolver(
            resolveAsync: (collection, ct) =>
            {
                collection.AddChangeToken(new CancellationChangeToken(cts[0].Token));
                collection.Endpoints.Add(ServiceEndpoint.Create(new IPEndPoint(IPAddress.Parse("127.1.1.1"), 8080)));
 
                if (cts[0].Token.IsCancellationRequested)
                {
                    cts[0] = new();
                    collection.Endpoints.Add(ServiceEndpoint.Create(new IPEndPoint(IPAddress.Parse("127.1.1.2"), 8888)));
                }
                return default;
            },
            disposeAsync: () => default);
        var fakeResolverProvider = new FakeEndpointResolverProvider(name => (true, innerResolver));
        var services = new ServiceCollection()
            .AddSingleton<IServiceEndpointProviderFactory>(fakeResolverProvider)
            .AddServiceDiscoveryCore()
            .BuildServiceProvider();
        var resolverProvider = services.GetRequiredService<ServiceEndpointWatcherFactory>();
        await using var resolver = new HttpServiceEndpointResolver(resolverProvider, services, TimeProvider.System);
 
        Assert.NotNull(resolver);
        var httpRequest = new HttpRequestMessage(HttpMethod.Get, "http://basket");
        var endpoint = await resolver.GetEndpointAsync(httpRequest, CancellationToken.None);
        Assert.NotNull(endpoint);
        var ip = Assert.IsType<IPEndPoint>(endpoint.EndPoint);
        Assert.Equal(IPAddress.Parse("127.1.1.1"), ip.Address);
        Assert.Equal(8080, ip.Port);
 
        await services.DisposeAsync();
    }
 
    [Fact]
    public async Task ResolveServiceEndpoint_ThrowOnReload()
    {
        var sem = new SemaphoreSlim(0);
        var cts = new[] { new CancellationTokenSource() };
        var throwOnNextResolve = new[] { false };
        var innerResolver = new FakeEndpointResolver(
            resolveAsync: async (collection, ct) =>
            {
                await sem.WaitAsync(ct).ConfigureAwait(false);
                if (cts[0].IsCancellationRequested)
                {
                    // Always be sure to have a fresh token.
                    cts[0] = new();
                }
 
                if (throwOnNextResolve[0])
                {
                    throwOnNextResolve[0] = false;
                    throw new InvalidOperationException("throwing");
                }
 
                collection.AddChangeToken(new CancellationChangeToken(cts[0].Token));
                collection.Endpoints.Add(ServiceEndpoint.Create(new IPEndPoint(IPAddress.Parse("127.1.1.1"), 8080)));
            },
            disposeAsync: () => default);
        var resolverProvider = new FakeEndpointResolverProvider(name => (true, innerResolver));
        var services = new ServiceCollection()
            .AddSingleton<IServiceEndpointProviderFactory>(resolverProvider)
            .AddServiceDiscoveryCore()
            .BuildServiceProvider();
        var watcherFactory = services.GetRequiredService<ServiceEndpointWatcherFactory>();
 
        ServiceEndpointWatcher watcher;
        await using ((watcher = watcherFactory.CreateWatcher("http://basket")).ConfigureAwait(false))
        {
            Assert.NotNull(watcher);
            var initialEndpointsTask = watcher.GetEndpointsAsync(CancellationToken.None);
            sem.Release(1);
            var initialEndpoints = await initialEndpointsTask;
            Assert.NotNull(initialEndpoints);
            Assert.Single(initialEndpoints.Endpoints);
 
            // Tell the resolver to throw on the next resolve call and then trigger a reload.
            throwOnNextResolve[0] = true;
            cts[0].Cancel();
 
            var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () =>
            {
                var resolveTask = watcher.GetEndpointsAsync(CancellationToken.None);
                sem.Release(1);
                await resolveTask.ConfigureAwait(false);
            });
 
            Assert.Equal("throwing", exception.Message);
 
            var channel = Channel.CreateUnbounded<ServiceEndpointResolverResult>();
            watcher.OnEndpointsUpdated = result => channel.Writer.TryWrite(result);
 
            do
            {
                cts[0].Cancel();
                sem.Release(1);
                var resolveTask = watcher.GetEndpointsAsync(CancellationToken.None);
                await resolveTask;
                var next = await channel.Reader.ReadAsync(CancellationToken.None);
                if (next.ResolvedSuccessfully)
                {
                    break;
                }
            } while (true);
 
            var task = watcher.GetEndpointsAsync(CancellationToken.None);
            sem.Release(1);
            var result = await task;
            Assert.NotSame(initialEndpoints, result);
            var sep = Assert.Single(result.Endpoints);
            var ip = Assert.IsType<IPEndPoint>(sep.EndPoint);
            Assert.Equal(IPAddress.Parse("127.1.1.1"), ip.Address);
            Assert.Equal(8080, ip.Port);
        }
    }
}