// File: DnsSrvServiceEndpointResolverTests.cs
// Web Access
// Project: src\tests\Microsoft.Extensions.ServiceDiscovery.Dns.Tests\Microsoft.Extensions.ServiceDiscovery.Dns.Tests.csproj (Microsoft.Extensions.ServiceDiscovery.Dns.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Net;
using DnsClient;
using DnsClient.Protocol;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Configuration.Memory;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.ServiceDiscovery.Internal;
using Xunit;
 
namespace Microsoft.Extensions.ServiceDiscovery.Dns.Tests;
 
/// <summary>
/// Tests for <see cref="DnsServiceEndpointProviderBase"/> and <see cref="DnsSrvServiceEndpointProviderFactory"/>.
/// These also cover <see cref="ServiceEndpointWatcher"/> and <see cref="ServiceEndpointWatcherFactory"/> by extension.
/// </summary>
public class DnsSrvServiceEndpointResolverTests
{
    /// <summary>
    /// Minimal <see cref="IDnsQuery"/> test double. Only the string-based <c>QueryAsync</c> overload is
    /// functional; it forwards its arguments to <see cref="QueryAsyncFunc"/>. Every other member throws
    /// <see cref="NotImplementedException"/>.
    /// </summary>
    private sealed class FakeDnsClient : IDnsQuery
    {
        /// <summary>
        /// Delegate invoked by the string-based <c>QueryAsync</c> overload to produce the response.
        /// It must be assigned before that overload is called.
        /// </summary>
        public Func<string, QueryType, QueryClass, CancellationToken, Task<IDnsQueryResponse>>? QueryAsyncFunc { get; set; }

        public IDnsQueryResponse Query(string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException();
        public IDnsQueryResponse Query(DnsQuestion question) => throw new NotImplementedException();
        public IDnsQueryResponse Query(DnsQuestion question, DnsQueryAndServerOptions queryOptions) => throw new NotImplementedException();

        /// <summary>
        /// Forwards the query to <see cref="QueryAsyncFunc"/>.
        /// </summary>
        /// <exception cref="InvalidOperationException"><see cref="QueryAsyncFunc"/> has not been assigned.</exception>
        public Task<IDnsQueryResponse> QueryAsync(string query, QueryType queryType, QueryClass queryClass = QueryClass.IN, CancellationToken cancellationToken = default)
        {
            // Fail with a descriptive error rather than a bare NullReferenceException when a test
            // forgets to configure the fake.
            var handler = QueryAsyncFunc
                ?? throw new InvalidOperationException($"{nameof(QueryAsyncFunc)} must be set before calling {nameof(QueryAsync)}.");

            return handler(query, queryType, queryClass, cancellationToken);
        }

        public Task<IDnsQueryResponse> QueryAsync(DnsQuestion question, CancellationToken cancellationToken = default) => throw new NotImplementedException();
        public Task<IDnsQueryResponse> QueryAsync(DnsQuestion question, DnsQueryAndServerOptions queryOptions, CancellationToken cancellationToken = default) => throw new NotImplementedException();
        public IDnsQueryResponse QueryCache(DnsQuestion question) => throw new NotImplementedException();
        public IDnsQueryResponse QueryCache(string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException();
        public IDnsQueryResponse QueryReverse(IPAddress ipAddress) => throw new NotImplementedException();
        public IDnsQueryResponse QueryReverse(IPAddress ipAddress, DnsQueryAndServerOptions queryOptions) => throw new NotImplementedException();
        public Task<IDnsQueryResponse> QueryReverseAsync(IPAddress ipAddress, CancellationToken cancellationToken = default) => throw new NotImplementedException();
        public Task<IDnsQueryResponse> QueryReverseAsync(IPAddress ipAddress, DnsQueryAndServerOptions queryOptions, CancellationToken cancellationToken = default) => throw new NotImplementedException();
        public IDnsQueryResponse QueryServer(IReadOnlyCollection<NameServer> servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException();
        public IDnsQueryResponse QueryServer(IReadOnlyCollection<NameServer> servers, DnsQuestion question) => throw new NotImplementedException();
        public IDnsQueryResponse QueryServer(IReadOnlyCollection<NameServer> servers, DnsQuestion question, DnsQueryOptions queryOptions) => throw new NotImplementedException();
        public IDnsQueryResponse QueryServer(IReadOnlyCollection<IPEndPoint> servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException();
        public IDnsQueryResponse QueryServer(IReadOnlyCollection<IPAddress> servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN) => throw new NotImplementedException();
        public Task<IDnsQueryResponse> QueryServerAsync(IReadOnlyCollection<NameServer> servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN, CancellationToken cancellationToken = default) => throw new NotImplementedException();
        public Task<IDnsQueryResponse> QueryServerAsync(IReadOnlyCollection<NameServer> servers, DnsQuestion question, CancellationToken cancellationToken = default) => throw new NotImplementedException();
        public Task<IDnsQueryResponse> QueryServerAsync(IReadOnlyCollection<NameServer> servers, DnsQuestion question, DnsQueryOptions queryOptions, CancellationToken cancellationToken = default) => throw new NotImplementedException();
        public Task<IDnsQueryResponse> QueryServerAsync(IReadOnlyCollection<IPAddress> servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN, CancellationToken cancellationToken = default) => throw new NotImplementedException();
        public Task<IDnsQueryResponse> QueryServerAsync(IReadOnlyCollection<IPEndPoint> servers, string query, QueryType queryType, QueryClass queryClass = QueryClass.IN, CancellationToken cancellationToken = default) => throw new NotImplementedException();
        public IDnsQueryResponse QueryServerReverse(IReadOnlyCollection<IPAddress> servers, IPAddress ipAddress) => throw new NotImplementedException();
        public IDnsQueryResponse QueryServerReverse(IReadOnlyCollection<IPEndPoint> servers, IPAddress ipAddress) => throw new NotImplementedException();
        public IDnsQueryResponse QueryServerReverse(IReadOnlyCollection<NameServer> servers, IPAddress ipAddress) => throw new NotImplementedException();
        public IDnsQueryResponse QueryServerReverse(IReadOnlyCollection<NameServer> servers, IPAddress ipAddress, DnsQueryOptions queryOptions) => throw new NotImplementedException();
        public Task<IDnsQueryResponse> QueryServerReverseAsync(IReadOnlyCollection<IPAddress> servers, IPAddress ipAddress, CancellationToken cancellationToken = default) => throw new NotImplementedException();
        public Task<IDnsQueryResponse> QueryServerReverseAsync(IReadOnlyCollection<IPEndPoint> servers, IPAddress ipAddress, CancellationToken cancellationToken = default) => throw new NotImplementedException();
        public Task<IDnsQueryResponse> QueryServerReverseAsync(IReadOnlyCollection<NameServer> servers, IPAddress ipAddress, CancellationToken cancellationToken = default) => throw new NotImplementedException();
        public Task<IDnsQueryResponse> QueryServerReverseAsync(IReadOnlyCollection<NameServer> servers, IPAddress ipAddress, DnsQueryOptions queryOptions, CancellationToken cancellationToken = default) => throw new NotImplementedException();
    }
 
    /// <summary>
    /// Settable <see cref="IDnsQueryResponse"/> for tests. Every member is a plain auto-property, so
    /// anything a test does not assign keeps its default value (<c>null</c>, <c>false</c> or <c>0</c>).
    /// </summary>
    private sealed class FakeDnsQueryResponse : IDnsQueryResponse
    {
        // Record sections of the response.
        public IReadOnlyList<DnsQuestion>? Questions { get; set; }
        public IReadOnlyList<DnsResourceRecord>? Answers { get; set; }
        public IReadOnlyList<DnsResourceRecord>? Authorities { get; set; }
        public IReadOnlyList<DnsResourceRecord>? Additionals { get; set; }

        // Independent of the section properties above: it is not derived from them and must be
        // assigned explicitly if a test needs it.
        public IEnumerable<DnsResourceRecord>? AllRecords { get; set; }

        // Error state.
        public bool HasError { get; set; }
        public string? ErrorMessage { get; set; }

        // Remaining response metadata.
        public DnsResponseHeader? Header { get; set; }
        public int MessageSize { get; set; }
        public NameServer? NameServer { get; set; }
        public DnsQuerySettings? Settings { get; set; }
        public string? AuditTrail { get; set; }
    }
 
    /// <summary>
    /// Resolves <c>http://basket</c> with only the DNS SRV provider registered and checks that the three
    /// SRV answers returned by the fake DNS client surface as three endpoints, in answer order: two
    /// <see cref="IPEndPoint"/>s for the targets with address records and one <see cref="DnsEndPoint"/>
    /// for the target with a CNAME record. No endpoint is expected to carry an <see cref="IHostNameFeature"/>.
    /// </summary>
    [Fact]
    public async Task ResolveServiceEndpoint_DnsSrv()
    {
        var dnsClientMock = new FakeDnsClient
        {
            QueryAsyncFunc = (query, queryType, queryClass, cancellationToken) =>
            {
                var response = new FakeDnsQueryResponse
                {
                    Answers = new List<DnsResourceRecord>
                    {
                        new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 66, 8888, DnsString.Parse("srv-a")),
                        new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 62, 9999, DnsString.Parse("srv-b")),
                        new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 62, 7777, DnsString.Parse("srv-c"))
                    },
                    Additionals = new List<DnsResourceRecord>
                    {
                        new ARecord(new ResourceRecordInfo("srv-a", ResourceRecordType.A, queryClass, 64, 0), IPAddress.Parse("10.10.10.10")),
                        new ARecord(new ResourceRecordInfo("srv-b", ResourceRecordType.AAAA, queryClass, 64, 0), IPAddress.IPv6Loopback),
                        new CNameRecord(new ResourceRecordInfo("srv-c", ResourceRecordType.AAAA, queryClass, 64, 0), DnsString.Parse("remotehost")),
                        new TxtRecord(new ResourceRecordInfo("srv-a", ResourceRecordType.TXT, queryClass, 64, 0), ["some txt values"], ["some txt utf8 values"])
                    }
                };

                return Task.FromResult<IDnsQueryResponse>(response);
            }
        };

        var services = new ServiceCollection()
            .AddSingleton<IDnsQuery>(dnsClientMock)
            .AddServiceDiscoveryCore()
            .AddDnsSrvServiceEndpointProvider(options => options.QuerySuffix = ".ns")
            .BuildServiceProvider();

        // Dispose the container so the services it created are released when the test ends.
        // The watcher is scoped inside this block, so it is disposed before the provider.
        await using (services.ConfigureAwait(false))
        {
            var watcherFactory = services.GetRequiredService<ServiceEndpointWatcherFactory>();
            var watcher = watcherFactory.CreateWatcher("http://basket");
            await using (watcher.ConfigureAwait(false))
            {
                Assert.NotNull(watcher);

                // RunContinuationsAsynchronously keeps the rest of this test (including watcher disposal)
                // from running inline inside the watcher's update callback.
                var tcs = new TaskCompletionSource<ServiceEndpointResolverResult>(TaskCreationOptions.RunContinuationsAsynchronously);
                watcher.OnEndpointsUpdated = tcs.SetResult;
                watcher.Start();

                var initialResult = await tcs.Task.ConfigureAwait(false);
                Assert.NotNull(initialResult);
                Assert.True(initialResult.ResolvedSuccessfully);

                var eps = initialResult.EndpointSource.Endpoints;
                Assert.Equal(3, eps.Count);
                Assert.Equal(new IPEndPoint(IPAddress.Parse("10.10.10.10"), 8888), eps[0].EndPoint);
                Assert.Equal(new IPEndPoint(IPAddress.IPv6Loopback, 9999), eps[1].EndPoint);
                Assert.Equal(new DnsEndPoint("remotehost", 7777), eps[2].EndPoint);

                // Host name metadata is not enabled in this configuration, so no endpoint should have it.
                Assert.All(eps, ep =>
                {
                    var hostNameFeature = ep.Features.Get<IHostNameFeature>();
                    Assert.Null(hostNameFeature);
                });
            }
        }
    }
 
    /// <summary>
    /// Tests that when there are multiple resolvers registered, they are consulted in registration order and each provider only adds endpoints if the providers before it did not.
    /// </summary>
    /// <param name="dnsFirst">
    /// <see langword="true"/> to register the DNS SRV provider before the configuration provider (with host name
    /// metadata enabled for every endpoint); <see langword="false"/> to register the configuration provider first.
    /// </param>
    [Theory]
    [InlineData(true)]
    [InlineData(false)]
    public async Task ResolveServiceEndpoint_DnsSrv_MultipleProviders_PreventMixing(bool dnsFirst)
    {
        var dnsClientMock = new FakeDnsClient
        {
            QueryAsyncFunc = (query, queryType, queryClass, cancellationToken) =>
            {
                var response = new FakeDnsQueryResponse
                {
                    Answers = new List<DnsResourceRecord>
                    {
                        new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 66, 8888, DnsString.Parse("srv-a")),
                        new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 62, 9999, DnsString.Parse("srv-b")),
                        new SrvRecord(new ResourceRecordInfo(query, ResourceRecordType.SRV, queryClass, 123, 0), 99, 62, 7777, DnsString.Parse("srv-c"))
                    },
                    Additionals = new List<DnsResourceRecord>
                    {
                        new ARecord(new ResourceRecordInfo("srv-a", ResourceRecordType.A, queryClass, 64, 0), IPAddress.Parse("10.10.10.10")),
                        new ARecord(new ResourceRecordInfo("srv-b", ResourceRecordType.AAAA, queryClass, 64, 0), IPAddress.IPv6Loopback),
                        new CNameRecord(new ResourceRecordInfo("srv-c", ResourceRecordType.AAAA, queryClass, 64, 0), DnsString.Parse("remotehost")),
                        new TxtRecord(new ResourceRecordInfo("srv-a", ResourceRecordType.TXT, queryClass, 64, 0), ["some txt values"], ["some txt utf8 values"])
                    }
                };

                return Task.FromResult<IDnsQueryResponse>(response);
            }
        };

        // Static configuration that the configuration provider resolves "http://basket" from.
        var configSource = new MemoryConfigurationSource
        {
            InitialData = new Dictionary<string, string?>
            {
                ["services:basket:http:0"] = "localhost:8080",
                ["services:basket:http:1"] = "remotehost:9090",
            }
        };
        var config = new ConfigurationBuilder().Add(configSource);
        var serviceCollection = new ServiceCollection()
            .AddSingleton<IDnsQuery>(dnsClientMock)
            .AddSingleton<IConfiguration>(config.Build())
            .AddServiceDiscoveryCore();

        // Registration order is the variable under test.
        if (dnsFirst)
        {
            serviceCollection
                .AddDnsSrvServiceEndpointProvider(options =>
                {
                    options.QuerySuffix = ".ns";
                    options.ShouldApplyHostNameMetadata = _ => true;
                })
                .AddConfigurationServiceEndpointProvider();
        }
        else
        {
            serviceCollection
                .AddConfigurationServiceEndpointProvider()
                .AddDnsSrvServiceEndpointProvider(options => options.QuerySuffix = ".ns");
        }

        var services = serviceCollection.BuildServiceProvider();

        // Dispose the container so the services it created are released when the test ends.
        // The watcher is scoped inside this block, so it is disposed before the provider.
        await using (services.ConfigureAwait(false))
        {
            var watcherFactory = services.GetRequiredService<ServiceEndpointWatcherFactory>();
            var watcher = watcherFactory.CreateWatcher("http://basket");
            await using (watcher.ConfigureAwait(false))
            {
                Assert.NotNull(watcher);

                // RunContinuationsAsynchronously keeps the rest of this test (including watcher disposal)
                // from running inline inside the watcher's update callback.
                var tcs = new TaskCompletionSource<ServiceEndpointResolverResult>(TaskCreationOptions.RunContinuationsAsynchronously);
                watcher.OnEndpointsUpdated = tcs.SetResult;
                watcher.Start();

                var initialResult = await tcs.Task.ConfigureAwait(false);
                Assert.NotNull(initialResult);
                Assert.Null(initialResult.Exception);
                Assert.True(initialResult.ResolvedSuccessfully);

                var eps = initialResult.EndpointSource.Endpoints;
                if (dnsFirst)
                {
                    // We expect only the results from the DNS provider.
                    Assert.Equal(3, eps.Count);
                    Assert.Equal(new IPEndPoint(IPAddress.Parse("10.10.10.10"), 8888), eps[0].EndPoint);
                    Assert.Equal(new IPEndPoint(IPAddress.IPv6Loopback, 9999), eps[1].EndPoint);
                    Assert.Equal(new DnsEndPoint("remotehost", 7777), eps[2].EndPoint);

                    // ShouldApplyHostNameMetadata returns true for every endpoint in this branch,
                    // so each endpoint must carry the service's host name.
                    Assert.All(eps, ep =>
                    {
                        var hostNameFeature = ep.Features.Get<IHostNameFeature>();
                        Assert.NotNull(hostNameFeature);
                        Assert.Equal("basket", hostNameFeature.HostName);
                    });
                }
                else
                {
                    // We expect only the results from the Configuration provider.
                    Assert.Equal(2, eps.Count);
                    Assert.Equal(new DnsEndPoint("localhost", 8080), eps[0].EndPoint);
                    Assert.Equal(new DnsEndPoint("remotehost", 9090), eps[1].EndPoint);

                    Assert.All(eps, ep =>
                    {
                        var hostNameFeature = ep.Features.Get<IHostNameFeature>();
                        Assert.Null(hostNameFeature);
                    });
                }
            }
        }
    }
}