// File: RemoteAuthenticationHandler.cs
// Web Access
// Project: src\src\Security\Authentication\Core\src\Microsoft.AspNetCore.Authentication.csproj (Microsoft.AspNetCore.Authentication)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics;
using System.Security.Cryptography;
using System.Text.Encodings.Web;
using Microsoft.AspNetCore.WebUtilities;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
 
namespace Microsoft.AspNetCore.Authentication;
 
/// <summary>
/// An opinionated abstraction for an <see cref="AuthenticationHandler{TOptions}"/> that performs authentication using a separately hosted
/// provider.
/// </summary>
/// <typeparam name="TOptions">The type for the options used to configure the authentication handler.</typeparam>
public abstract class RemoteAuthenticationHandler<TOptions> : AuthenticationHandler<TOptions>, IAuthenticationRequestHandler
    where TOptions : RemoteAuthenticationOptions, new()
{
    // Key in AuthenticationProperties.Items under which GenerateCorrelationId stores the correlation id
    // and from which ValidateCorrelationId reads (and removes) it.
    private const string CorrelationProperty = ".xsrf";
    // Value written to the correlation cookie; ValidateCorrelationId requires the cookie to carry exactly this value.
    private const string CorrelationMarker = "N";
    // Key in AuthenticationProperties.Items that records the name of the scheme that produced a ticket,
    // written in HandleRequestAsync and checked in HandleAuthenticateAsync.
    private const string AuthSchemeKey = ".AuthScheme";
 
    /// <summary>
    /// Gets the scheme configured in <see cref="RemoteAuthenticationOptions.SignInScheme"/>. This handler forwards
    /// its sign-in, authenticate and forbid operations to that scheme.
    /// </summary>
    protected string? SignInScheme
    {
        get { return Options.SignInScheme; }
    }
 
    /// <summary>
    /// Gets or sets the events object of the base handler, typed as <see cref="RemoteAuthenticationEvents"/>.
    /// The handler raises its remote-failure, ticket-received and access-denied notifications on this instance,
    /// giving the application control at those points of the flow.
    /// </summary>
    protected new RemoteAuthenticationEvents Events
    {
        get => (RemoteAuthenticationEvents)base.Events!;
        set => base.Events = value;
    }
 
    /// <summary>
    /// Creates a <see cref="RemoteAuthenticationHandler{TOptions}" />, forwarding every argument
    /// (including the legacy clock) to the base handler.
    /// </summary>
    /// <param name="options">The monitor used to obtain the options instance.</param>
    /// <param name="logger">The <see cref="ILoggerFactory"/> passed to the base handler.</param>
    /// <param name="encoder">The <see cref="UrlEncoder"/> passed to the base handler.</param>
    /// <param name="clock">The <see cref="ISystemClock"/> passed to the base handler.</param>
    [Obsolete("ISystemClock is obsolete, use TimeProvider on AuthenticationSchemeOptions instead.")]
    protected RemoteAuthenticationHandler(
        IOptionsMonitor<TOptions> options,
        ILoggerFactory logger,
        UrlEncoder encoder,
        ISystemClock clock)
        : base(options, logger, encoder, clock)
    {
    }
 
    /// <summary>
    /// Creates a <see cref="RemoteAuthenticationHandler{TOptions}" />, forwarding every argument to the base handler.
    /// </summary>
    /// <param name="options">The monitor used to obtain the options instance.</param>
    /// <param name="logger">The <see cref="ILoggerFactory"/> passed to the base handler.</param>
    /// <param name="encoder">The <see cref="UrlEncoder"/> passed to the base handler.</param>
    protected RemoteAuthenticationHandler(
        IOptionsMonitor<TOptions> options,
        ILoggerFactory logger,
        UrlEncoder encoder)
        : base(options, logger, encoder)
    {
    }
 
    /// <inheritdoc />
    protected override Task<object> CreateEventsAsync()
    {
        // A fresh RemoteAuthenticationEvents instance is handed back as an already-completed task.
        object defaultEvents = new RemoteAuthenticationEvents();
        return Task.FromResult(defaultEvents);
    }
 
    /// <summary>
    /// Decides whether <see cref="HandleRequestAsync" /> should process the current request. The default
    /// implementation accepts the request when its path equals the configured callback path.
    /// </summary>
    /// <returns><see langword="true"/> to handle the operation, otherwise <see langword="false"/>.</returns>
    public virtual Task<bool> ShouldHandleRequestAsync()
    {
        var isCallbackRequest = Options.CallbackPath == Request.Path;
        return Task.FromResult(isCallbackRequest);
    }
 
    /// <summary>
    /// Handles the current authentication request.
    /// </summary>
    /// <remarks>
    /// The request is ignored unless <see cref="ShouldHandleRequestAsync"/> returns <see langword="true"/>.
    /// Otherwise <see cref="HandleRemoteAuthenticateAsync"/> is run. A failure (or an exception thrown by it)
    /// raises the remote-failure event; a success raises the ticket-received event, signs the principal in with
    /// <see cref="SignInScheme"/> and redirects to the return URI carried by the ticket, or to "/" when none is set.
    /// </remarks>
    /// <returns><see langword="true"/> if authentication was handled, otherwise <see langword="false"/>.</returns>
    /// <exception cref="AuthenticationFailureException">
    /// The remote authentication failed and the remote-failure event neither handled nor skipped the request.
    /// </exception>
    /// <exception cref="InvalidOperationException">
    /// The remote-failure event cleared the failure without handling or skipping the request, so there is no ticket to sign in.
    /// </exception>
    public virtual async Task<bool> HandleRequestAsync()
    {
        if (!await ShouldHandleRequestAsync())
        {
            return false;
        }

        AuthenticationTicket? ticket = null;
        Exception? exception = null;
        AuthenticationProperties? properties = null;
        try
        {
            var authResult = await HandleRemoteAuthenticateAsync();
            if (authResult == null)
            {
                exception = new InvalidOperationException("Invalid return state, unable to redirect.");
            }
            else if (authResult.Handled)
            {
                return true;
            }
            else if (authResult.Skipped || authResult.None)
            {
                return false;
            }
            else if (!authResult.Succeeded)
            {
                exception = authResult.Failure ?? new InvalidOperationException("Invalid return state, unable to redirect.");
                properties = authResult.Properties;
            }

            ticket = authResult?.Ticket;
        }
        catch (Exception ex)
        {
            // Anything thrown by the derived handler is routed through the RemoteFailure event below.
            exception = ex;
        }

        if (exception != null)
        {
            Logger.RemoteAuthenticationError(exception.Message);
            var errorContext = new RemoteFailureContext(Context, Scheme, Options, exception)
            {
                Properties = properties
            };
            await Events.RemoteFailure(errorContext);

            if (errorContext.Result != null)
            {
                if (errorContext.Result.Handled)
                {
                    return true;
                }
                else if (errorContext.Result.Skipped)
                {
                    return false;
                }
                else if (errorContext.Result.Failure != null)
                {
                    throw new AuthenticationFailureException("An error was returned from the RemoteFailure event.", errorContext.Result.Failure);
                }
            }

            if (errorContext.Failure != null)
            {
                throw new AuthenticationFailureException("An error was encountered while handling the remote login.", errorContext.Failure);
            }
        }

        // A RemoteFailure handler can clear the failure without handling or skipping the request, which leaves
        // no ticket. A debug-only assertion is compiled out of release builds, so fail explicitly here rather
        // than with a NullReferenceException when the ticket is dereferenced below.
        if (ticket == null)
        {
            throw new InvalidOperationException(
                "No authentication ticket is available; the RemoteFailure event must handle or skip the request when it suppresses the failure.");
        }

        // Capture the return URI before it is cleared from the ticket properties below.
        var ticketContext = new TicketReceivedContext(Context, Scheme, Options, ticket)
        {
            ReturnUri = ticket.Properties.RedirectUri
        };

        ticket.Properties.RedirectUri = null;

        // Mark which provider produced this identity so we can cross-check later in HandleAuthenticateAsync
        ticketContext.Properties!.Items[AuthSchemeKey] = Scheme.Name;

        await Events.TicketReceived(ticketContext);

        if (ticketContext.Result != null)
        {
            if (ticketContext.Result.Handled)
            {
                Logger.SignInHandled();
                return true;
            }
            else if (ticketContext.Result.Skipped)
            {
                Logger.SignInSkipped();
                return false;
            }
        }

        await Context.SignInAsync(SignInScheme, ticketContext.Principal!, ticketContext.Properties);

        // Default redirect path is the base path
        if (string.IsNullOrEmpty(ticketContext.ReturnUri))
        {
            ticketContext.ReturnUri = "/";
        }

        Response.Redirect(ticketContext.ReturnUri);
        return true;
    }
 
    /// <summary>
    /// Authenticates the user identity with the identity provider.
    /// <para>
    /// <see cref="HandleRequestAsync"/> invokes this method once <see cref="ShouldHandleRequestAsync"/> has accepted
    /// the request, which by default means the request path equals the configured <c>CallbackPath</c>.
    /// </para>
    /// </summary>
    /// <returns>
    /// A <see cref="HandleRequestResult"/>. <see cref="HandleRequestAsync"/> treats a handled result as a completed
    /// request, a skipped or empty result as not handled, an unsuccessful result (or <see langword="null"/>) as a
    /// remote failure, and otherwise signs in using the ticket carried by the result.
    /// </returns>
    protected abstract Task<HandleRequestResult> HandleRemoteAuthenticateAsync();
 
    /// <inheritdoc />
    protected override async Task<AuthenticateResult> HandleAuthenticateAsync()
    {
        // Authentication is delegated to the sign-in scheme; this handler only filters its outcome.
        var signInResult = await Context.AuthenticateAsync(SignInScheme);
        if (signInResult == null)
        {
            return AuthenticateResult.Fail("Remote authentication does not directly support AuthenticateAsync");
        }

        // Failures from the sign-in scheme are passed through untouched.
        if (signInResult.Failure != null)
        {
            return signInResult;
        }

        // The SignInScheme may be shared by several providers, so only accept a ticket that
        // carries this scheme's name under AuthSchemeKey.
        if (signInResult.Ticket is { Principal: not null, Properties: not null } issued
            && issued.Properties.Items.TryGetValue(AuthSchemeKey, out var issuingScheme)
            && string.Equals(Scheme.Name, issuingScheme, StringComparison.Ordinal))
        {
            var ownTicket = new AuthenticationTicket(issued.Principal, issued.Properties, Scheme.Name);
            return AuthenticateResult.Success(ownTicket);
        }

        return AuthenticateResult.NoResult();
    }
 
    /// <inheritdoc />
    protected override Task HandleForbiddenAsync(AuthenticationProperties properties)
    {
        // Forbid is forwarded to the sign-in scheme; the supplied properties are not passed along.
        return Context.ForbidAsync(SignInScheme);
    }
 
    /// <summary>
    /// Generates a random correlation id for the current remote authentication request, stores it in
    /// <paramref name="properties"/> and appends a response cookie whose name ends with that id.
    /// </summary>
    /// <param name="properties">The <see cref="AuthenticationProperties"/> that receive the correlation id.</param>
    protected virtual void GenerateCorrelationId(AuthenticationProperties properties)
    {
        ArgumentNullException.ThrowIfNull(properties);

        // 32 random bytes, base64url-encoded, form the id.
        var nonceBytes = new byte[32];
        RandomNumberGenerator.Fill(nonceBytes);
        var nonce = Base64UrlTextEncoder.Encode(nonceBytes);

        // Build the cookie options before touching the properties, as the original ordering does.
        var correlationCookieOptions = Options.CorrelationCookie.Build(Context, TimeProvider.GetUtcNow());
        properties.Items[CorrelationProperty] = nonce;

        // The cookie is named "<configured name><id>" and carries only the fixed marker value.
        var correlationCookieName = string.Concat(Options.CorrelationCookie.Name, nonce);
        Response.Cookies.Append(correlationCookieName, CorrelationMarker, correlationCookieOptions);
    }
 
    /// <summary>
    /// Checks that the current request carries the correlation cookie matching the correlation id stored in
    /// <paramref name="properties"/>. The id is removed from the properties, and the cookie is deleted once found.
    /// </summary>
    /// <param name="properties">The <see cref="AuthenticationProperties"/> expected to contain the correlation id.</param>
    /// <returns>
    /// <see langword="true"/> when the id is present and the matching cookie holds the expected marker value;
    /// otherwise <see langword="false"/>.
    /// </returns>
    protected virtual bool ValidateCorrelationId(AuthenticationProperties properties)
    {
        ArgumentNullException.ThrowIfNull(properties);

        var hasCorrelationId = properties.Items.TryGetValue(CorrelationProperty, out var expectedId);
        if (!hasCorrelationId)
        {
            Logger.CorrelationPropertyNotFound(Options.CorrelationCookie.Name!);
            return false;
        }

        // The id is dropped from the properties as soon as it has been read.
        properties.Items.Remove(CorrelationProperty);

        var expectedCookieName = Options.CorrelationCookie.Name + expectedId;
        var cookieValue = Request.Cookies[expectedCookieName];
        if (string.IsNullOrEmpty(cookieValue))
        {
            Logger.CorrelationCookieNotFound(expectedCookieName);
            return false;
        }

        // The cookie is deleted whether or not its value turns out to be the expected marker.
        var deletionOptions = Options.CorrelationCookie.Build(Context, TimeProvider.GetUtcNow());
        Response.Cookies.Delete(expectedCookieName, deletionOptions);

        if (string.Equals(cookieValue, CorrelationMarker, StringComparison.Ordinal))
        {
            return true;
        }

        Logger.UnexpectedCorrelationCookieValue(expectedCookieName, cookieValue);
        return false;
    }
 
    /// <summary>
    /// Raises the access-denied event and, unless the event produced a result, redirects the user agent to the
    /// access-denied path when one is set. Derived types may override this method to handle access denied errors.
    /// </summary>
    /// <param name="properties">The <see cref="AuthenticationProperties"/>.</param>
    /// <returns>
    /// The result set by the event when there is one; a handled result after redirecting; otherwise a
    /// no-result <see cref="HandleRequestResult"/>.
    /// </returns>
    protected virtual async Task<HandleRequestResult> HandleAccessDeniedErrorAsync(AuthenticationProperties properties)
    {
        Logger.AccessDeniedError();

        var deniedContext = new AccessDeniedContext(Context, Scheme, Options)
        {
            AccessDeniedPath = Options.AccessDeniedPath,
            Properties = properties,
            ReturnUrl = properties?.RedirectUri,
            ReturnUrlParameter = Options.ReturnUrlParameter
        };
        await Events.AccessDenied(deniedContext);

        // A result set by the event wins over the default redirect behaviour.
        if (deniedContext.Result is { } eventResult)
        {
            if (eventResult.Handled)
            {
                Logger.AccessDeniedContextHandled();
            }
            else if (eventResult.Skipped)
            {
                Logger.AccessDeniedContextSkipped();
            }

            return eventResult;
        }

        // Without an access-denied endpoint there is nothing to redirect to; report no result to the caller.
        if (!deniedContext.AccessDeniedPath.HasValue)
        {
            return HandleRequestResult.NoResult();
        }

        // Append the return URL as a query parameter only when both its name and its value are available.
        string accessDeniedUri = deniedContext.AccessDeniedPath;
        var redirectTarget = !string.IsNullOrEmpty(deniedContext.ReturnUrlParameter) && !string.IsNullOrEmpty(deniedContext.ReturnUrl)
            ? QueryHelpers.AddQueryString(accessDeniedUri, deniedContext.ReturnUrlParameter, deniedContext.ReturnUrl)
            : accessDeniedUri;

        Response.Redirect(BuildRedirectUri(redirectTarget));
        return HandleRequestResult.Handle();
    }
}