// File: Resilience\ResilienceHandler.cs
// Web Access
// Project: src\src\Libraries\Microsoft.Extensions.Http.Resilience\Microsoft.Extensions.Http.Resilience.csproj (Microsoft.Extensions.Http.Resilience)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Diagnostics.CodeAnalysis;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Http.Diagnostics;
using Microsoft.Extensions.Http.Resilience.Internal;
using Microsoft.Shared.DiagnosticIds;
using Microsoft.Shared.Diagnostics;
using Polly;
 
namespace Microsoft.Extensions.Http.Resilience;
 
/// <summary>
/// A delegating handler that sends each request through a <see cref="ResiliencePipeline{T}"/>
/// before forwarding it to the inner handler.
/// </summary>
[Experimental(diagnosticId: DiagnosticIds.Experiments.Resilience, UrlFormat = DiagnosticIds.UrlFormat)]
public class ResilienceHandler : DelegatingHandler
{
    // Resolves the pipeline to use for a given request; invoked once per Send/SendAsync call.
    private readonly Func<HttpRequestMessage, ResiliencePipeline<HttpResponseMessage>> _pipelineProvider;

    /// <summary>
    /// Initializes a new instance of the <see cref="ResilienceHandler"/> class that selects a pipeline per request.
    /// </summary>
    /// <param name="pipelineProvider">A delegate that returns the pipeline to use for a given HTTP request message.</param>
    /// <exception cref="ArgumentNullException">If <paramref name="pipelineProvider"/> is <see langword="null"/>.</exception>
    public ResilienceHandler(Func<HttpRequestMessage, ResiliencePipeline<HttpResponseMessage>> pipelineProvider)
        => _pipelineProvider = Throw.IfNull(pipelineProvider);

    /// <summary>
    /// Initializes a new instance of the <see cref="ResilienceHandler"/> class that uses one pipeline for every request.
    /// </summary>
    /// <param name="pipeline">The pipeline applied to all messages sent through this handler.</param>
    /// <exception cref="ArgumentNullException">If <paramref name="pipeline"/> is <see langword="null"/>.</exception>
    public ResilienceHandler(ResiliencePipeline<HttpResponseMessage> pipeline)
    {
        _ = Throw.IfNull(pipeline);

        // Wrap the fixed pipeline so both constructors feed the same provider-based code path.
        _pipelineProvider = _ => pipeline;
    }

    /// <summary>
    /// Asynchronously sends an HTTP request through the resilience pipeline and on to the inner handler.
    /// </summary>
    /// <param name="request">The HTTP request message to send.</param>
    /// <param name="cancellationToken">The token used when a new resilience context has to be obtained for the request.</param>
    /// <returns>A task that completes with the HTTP response produced by the pipeline.</returns>
    /// <exception cref="ArgumentNullException">If <paramref name="request"/> is <see langword="null"/>.</exception>
    protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
    {
        _ = Throw.IfNull(request);

        ResiliencePipeline<HttpResponseMessage> pipeline = _pipelineProvider(request);

        // Reuse a context already attached to the request, otherwise take one from the shared pool.
        ResilienceContext context = GetOrSetResilienceContext(request, cancellationToken, out bool created);
        TrySetRequestMetadata(context, request);
        SetRequestMessage(context, request);

        try
        {
            Outcome<HttpResponseMessage> outcome = await pipeline.ExecuteOutcomeAsync(
                static async (ctx, state) =>
                {
                    // Prefer the request stored in the context; fall back to the one captured in state.
                    HttpRequestMessage message = GetRequestMessage(ctx, state.request);

                    // The context passed to this callback is attached to the request again before every
                    // execution: for primary actions the context may be a clone, and Polly offers no
                    // other event to hook into for re-assigning it.
                    message.SetResilienceContext(ctx);

                    try
                    {
                        Task<HttpResponseMessage> pending = state.instance.SendCoreAsync(message, ctx.CancellationToken);
                        HttpResponseMessage response = await pending.ConfigureAwait(ctx.ContinueOnCapturedContext);

                        return Outcome.FromResult(response);
                    }
#pragma warning disable CA1031 // Do not catch general exception types
                    catch (Exception e)
                    {
                        // Report the failure as an outcome instead of letting it propagate out of the callback.
                        return Outcome.FromException<HttpResponseMessage>(e);
                    }
#pragma warning restore CA1031 // Do not catch general exception types
                },
                context,
                (instance: this, request))
                .ConfigureAwait(context.ContinueOnCapturedContext);

            // Surface a captured exception to the caller, if there is one.
            outcome.ThrowIfException();

            return outcome.Result!;
        }
        finally
        {
            RestoreResilienceContext(context, request, created);
        }
    }

#if NET6_0_OR_GREATER
    /// <summary>
    /// Synchronously sends an HTTP request through the resilience pipeline and on to the inner handler.
    /// </summary>
    /// <param name="request">The HTTP request message to send.</param>
    /// <param name="cancellationToken">The token used when a new resilience context has to be obtained for the request.</param>
    /// <returns>The HTTP response produced by the pipeline.</returns>
    /// <exception cref="ArgumentNullException">If <paramref name="request"/> is <see langword="null"/>.</exception>
    protected override HttpResponseMessage Send(HttpRequestMessage request, CancellationToken cancellationToken)
    {
        _ = Throw.IfNull(request);

        ResiliencePipeline<HttpResponseMessage> pipeline = _pipelineProvider(request);

        // Reuse a context already attached to the request, otherwise take one from the shared pool.
        ResilienceContext context = GetOrSetResilienceContext(request, cancellationToken, out bool created);
        TrySetRequestMetadata(context, request);
        SetRequestMessage(context, request);

        try
        {
            HttpResponseMessage response = pipeline.Execute(
                static (ctx, state) =>
                {
                    // Prefer the request stored in the context; fall back to the one captured in state.
                    HttpRequestMessage message = GetRequestMessage(ctx, state.request);

                    // The context passed to this callback is attached to the request again before every
                    // execution: for primary actions the context may be a clone, and Polly offers no
                    // other event to hook into for re-assigning it.
                    message.SetResilienceContext(ctx);

                    return state.instance.SendCore(message, ctx.CancellationToken);
                },
                context,
                (instance: this, request));

            return response;
        }
        finally
        {
            RestoreResilienceContext(context, request, created);
        }
    }
#endif

    // Returns the context already attached to the request, or takes one from the shared pool and
    // attaches it. 'created' reports whether this call obtained the context from the pool.
    private static ResilienceContext GetOrSetResilienceContext(HttpRequestMessage request, CancellationToken cancellationToken, out bool created)
    {
        if (request.GetResilienceContext() is ResilienceContext existing)
        {
            created = false;
            return existing;
        }

        ResilienceContext rented = ResilienceContextPool.Shared.Get(cancellationToken);
        created = true;
        request.SetResilienceContext(rented);

        return rented;
    }

    // Copies the request metadata into the context properties when the request carries any.
    private static void TrySetRequestMetadata(ResilienceContext context, HttpRequestMessage request)
    {
        if (request.GetRequestMetadata() is not RequestMetadata metadata)
        {
            return;
        }

        context.Properties.Set(ResilienceKeys.RequestMetadata, metadata);
    }

    // Stores the request in the context properties under ResilienceKeys.RequestMessage.
    private static void SetRequestMessage(ResilienceContext context, HttpRequestMessage request)
    {
        context.Properties.Set(ResilienceKeys.RequestMessage, request);
    }

    // Reads the request stored in the context properties, using the supplied request as the default value.
    private static HttpRequestMessage GetRequestMessage(ResilienceContext context, HttpRequestMessage request)
    {
        return context.Properties.GetValue(ResilienceKeys.RequestMessage, request);
    }

    // Undoes GetOrSetResilienceContext: a context obtained by this handler goes back to the pool and is
    // detached from the request; a pre-existing context is attached to the request again.
    private static void RestoreResilienceContext(ResilienceContext context, HttpRequestMessage request, bool created)
    {
        if (!created)
        {
            // Restore the original context
            request.SetResilienceContext(context);
            return;
        }

        ResilienceContextPool.Shared.Return(context);
        request.SetResilienceContext(null);
    }

    // Forwards to the base DelegatingHandler implementation; callable from the static pipeline callback.
    private Task<HttpResponseMessage> SendCoreAsync(HttpRequestMessage requestMessage, CancellationToken cancellationToken)
    {
        return base.SendAsync(requestMessage, cancellationToken);
    }

#if NET6_0_OR_GREATER
    // Forwards to the base DelegatingHandler implementation; callable from the static pipeline callback.
    private HttpResponseMessage SendCore(HttpRequestMessage requestMessage, CancellationToken cancellationToken)
    {
        return base.Send(requestMessage, cancellationToken);
    }
#endif
}