File: ServiceEndpointWatcher.cs
Web Access
Project: src\src\Microsoft.Extensions.ServiceDiscovery\Microsoft.Extensions.ServiceDiscovery.csproj (Microsoft.Extensions.ServiceDiscovery)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.ExceptionServices;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.ServiceDiscovery.Internal;
 
namespace Microsoft.Extensions.ServiceDiscovery;
 
/// <summary>
/// Watches for updates to the collection of resolved endpoints for a specified service.
/// </summary>
internal sealed partial class ServiceEndpointWatcher(
    IServiceEndpointProvider[] providers,
    ILogger logger,
    string serviceName,
    TimeProvider timeProvider,
    IOptions<ServiceDiscoveryOptions> options) : IAsyncDisposable
{
    private static readonly TimerCallback s_pollingAction = static state => _ = ((ServiceEndpointWatcher)state!).RefreshAsync(force: true);
 
    private readonly object _lock = new();
    private readonly ILogger _logger = logger;
    private readonly TimeProvider _timeProvider = timeProvider;
    private readonly ServiceDiscoveryOptions _options = options.Value;
    private readonly IServiceEndpointProvider[] _providers = providers;
    private readonly CancellationTokenSource _disposalCancellation = new();
    private ITimer? _pollingTimer;
    private ServiceEndpointSource? _cachedEndpoints;
    private Task _refreshTask = Task.CompletedTask;
    private volatile CacheStatus _cacheState;
    private IDisposable? _changeTokenRegistration;
 
    /// <summary>
    /// Gets the service name.
    /// </summary>
    public string ServiceName { get; } = serviceName;
 
    /// <summary>
    /// Gets or sets the action called when endpoints are updated.
    /// </summary>
    public Action<ServiceEndpointResolverResult>? OnEndpointsUpdated { get; set; }
 
    /// <summary>
    /// Starts the endpoint resolver.
    /// </summary>
    public void Start()
    {
        ThrowIfNoProviders();
        _ = RefreshAsync(force: false);
    }
 
    /// <summary>
    /// Returns a collection of resolved endpoints for the service.
    /// </summary>
    /// <param name="cancellationToken">A <see cref="CancellationToken"/>.</param>
    /// <returns>A collection of resolved endpoints for the service.</returns>
    public ValueTask<ServiceEndpointSource> GetEndpointsAsync(CancellationToken cancellationToken = default)
    {
        ThrowIfNoProviders();
        ObjectDisposedException.ThrowIf(_disposalCancellation.IsCancellationRequested, this);
        cancellationToken.ThrowIfCancellationRequested();
 
        // If the cache is valid, return the cached value.
        if (_cachedEndpoints is { ChangeToken.HasChanged: false } cached)
        {
            return new ValueTask<ServiceEndpointSource>(cached);
        }
 
        // Otherwise, ensure the cache is being refreshed
        // Wait for the cache refresh to complete and return the cached value.
        return GetEndpointsInternal(cancellationToken);
 
        async ValueTask<ServiceEndpointSource> GetEndpointsInternal(CancellationToken cancellationToken)
        {
            ServiceEndpointSource? result;
            var disposalToken = _disposalCancellation.Token;
            do
            {
                disposalToken.ThrowIfCancellationRequested();
                cancellationToken.ThrowIfCancellationRequested();
                await RefreshAsync(force: false).WaitAsync(cancellationToken).ConfigureAwait(false);
                result = _cachedEndpoints;
            } while (result is null);
 
            return result;
        }
    }
 
    // Ensures that there is a refresh operation running, if needed, and returns the task which represents the completion of the operation
    private Task RefreshAsync(bool force)
    {
        lock (_lock)
        {
            // If the cache is invalid or needs invalidation, refresh the cache.
            if (!_disposalCancellation.IsCancellationRequested && _refreshTask.IsCompleted && (_cacheState == CacheStatus.Invalid || _cachedEndpoints is null or { ChangeToken.HasChanged: true } || force))
            {
                // Indicate that the cache is being updated and start a new refresh task.
                _cacheState = CacheStatus.Refreshing;
 
                // Don't capture the current ExecutionContext and its AsyncLocals onto the callback.
                var restoreFlow = false;
                try
                {
                    if (!ExecutionContext.IsFlowSuppressed())
                    {
                        ExecutionContext.SuppressFlow();
                        restoreFlow = true;
                    }
 
                    _refreshTask = RefreshAsyncInternal();
                }
                finally
                {
                    if (restoreFlow)
                    {
                        ExecutionContext.RestoreFlow();
                    }
                }
            }
 
            return _refreshTask;
        }
    }
 
    private async Task RefreshAsyncInternal()
    {
        await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
        var cancellationToken = _disposalCancellation.Token;
        Exception? error = null;
        ServiceEndpointSource? newEndpoints = null;
        CacheStatus newCacheState;
        try
        {
            lock (_lock)
            {
                // Dispose the existing change token registration, if any.
                _changeTokenRegistration?.Dispose();
                _changeTokenRegistration = null;
            }
 
            Log.ResolvingEndpoints(_logger, ServiceName);
            var builder = new ServiceEndpointBuilder();
            foreach (var provider in _providers)
            {
                cancellationToken.ThrowIfCancellationRequested();
                await provider.PopulateAsync(builder, cancellationToken).ConfigureAwait(false);
            }
 
            var endpoints = builder.Build();
            newCacheState = CacheStatus.Valid;
 
            lock (_lock)
            {
                // Check if we need to poll for updates or if we can register for change notification callbacks.
                if (endpoints.ChangeToken.ActiveChangeCallbacks)
                {
                    // Initiate a background refresh when the change token fires.
                    _changeTokenRegistration = endpoints.ChangeToken.RegisterChangeCallback(static state => _ = ((ServiceEndpointWatcher)state!).RefreshAsync(force: false), this);
 
                    // Dispose the existing timer, if any, since we are reliant on change tokens for updates.
                    _pollingTimer?.Dispose();
                    _pollingTimer = null;
                }
                else
                {
                    SchedulePollingTimer();
                }
 
                // The cache is valid
                newEndpoints = endpoints;
                newCacheState = CacheStatus.Valid;
            }
        }
        catch (Exception exception)
        {
            error = exception;
            newCacheState = CacheStatus.Invalid;
            SchedulePollingTimer();
        }
 
        // If there was an error, the cache must be invalid.
        Debug.Assert(error is null || newCacheState is CacheStatus.Invalid);
 
        // To ensure coherence between the value returned by calls made to GetEndpointsAsync and value passed to the callback,
        // we invalidate the cache before invoking the callback. This causes callers to wait on the refresh task
        // before receiving the updated value. An alternative approach is to lock access to _cachedEndpoints, but
        // that will have more overhead in the common case.
        if (newCacheState is CacheStatus.Valid)
        {
            Interlocked.Exchange(ref _cachedEndpoints, null);
        }
 
        if (OnEndpointsUpdated is { } callback)
        {
            try
            {
                callback(new(newEndpoints, error));
            }
            catch (Exception exception)
            {
                _logger.LogError(exception, "Error notifying observers of updated endpoints.");
            }
        }
 
        lock (_lock)
        {
            if (newCacheState is CacheStatus.Valid)
            {
                Debug.Assert(newEndpoints is not null);
                _cachedEndpoints = newEndpoints;
            }
 
            _cacheState = newCacheState;
        }
 
        if (error is not null)
        {
            Log.ResolutionFailed(_logger, error, ServiceName);
            ExceptionDispatchInfo.Throw(error);
        }
        else if (newEndpoints is not null)
        {
            Log.ResolutionSucceeded(_logger, ServiceName, newEndpoints);
        }
    }
 
    private void SchedulePollingTimer()
    {
        lock (_lock)
        {
            if (_disposalCancellation.IsCancellationRequested)
            {
                _pollingTimer?.Dispose();
                _pollingTimer = null;
                return;
            }
 
            if (_pollingTimer is null)
            {
                _pollingTimer = _timeProvider.CreateTimer(s_pollingAction, this, _options.RefreshPeriod, TimeSpan.Zero);
            }
            else
            {
                _pollingTimer.Change(_options.RefreshPeriod, TimeSpan.Zero);
            }
        }
    }
 
    /// <inheritdoc/>
    public async ValueTask DisposeAsync()
    {
        try
        {
            _disposalCancellation.Cancel();
        }
        catch (Exception exception)
        {
            _logger.LogError(exception, "Error cancelling disposal cancellation token.");
        }
 
        lock (_lock)
        {
            _changeTokenRegistration?.Dispose();
            _changeTokenRegistration = null;
 
            _pollingTimer?.Dispose();
            _pollingTimer = null;
        }
 
        if (_refreshTask is { } task)
        {
            await task.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
        }
 
        foreach (var provider in _providers)
        {
            await provider.DisposeAsync().ConfigureAwait(false);
        }
    }
 
    private enum CacheStatus
    {
        Invalid,
        Refreshing,
        Valid
    }
 
    private void ThrowIfNoProviders()
    {
        if (_providers.Length == 0)
        {
            ThrowNoProvidersConfigured();
        }
    }
 
    [DoesNotReturn]
    private static void ThrowNoProvidersConfigured() => throw new InvalidOperationException("No service endpoint providers are configured.");
}