|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Linq;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Options;
namespace Microsoft.AspNetCore.Authentication;
/// <summary>
/// Implements <see cref="IAuthenticationSchemeProvider"/>.
/// </summary>
public class AuthenticationSchemeProvider : IAuthenticationSchemeProvider
{
/// <summary>
/// Initializes a new <see cref="AuthenticationSchemeProvider"/> from the supplied
/// <paramref name="options"/>, storing schemes in a new dictionary whose
/// name keys are compared with <see cref="StringComparer.Ordinal"/>.
/// </summary>
/// <param name="options">The <see cref="AuthenticationOptions"/> options.</param>
public AuthenticationSchemeProvider(IOptions<AuthenticationOptions> options)
    : this(
        options,
        schemes: new Dictionary<string, AuthenticationScheme>(StringComparer.Ordinal))
{
    // All initialization happens in the delegated constructor.
}
/// <summary>
/// Creates an instance of <see cref="AuthenticationSchemeProvider"/>
/// using the specified <paramref name="options"/> and <paramref name="schemes"/>,
/// then registers every scheme built from <see cref="AuthenticationOptions.Schemes"/>.
/// </summary>
/// <param name="options">The <see cref="AuthenticationOptions"/> options.</param>
/// <param name="schemes">The dictionary used to store authentication schemes.</param>
/// <exception cref="ArgumentNullException">
/// <paramref name="options"/> or <paramref name="schemes"/> is null.
/// </exception>
protected AuthenticationSchemeProvider(IOptions<AuthenticationOptions> options, IDictionary<string, AuthenticationScheme> schemes)
{
    // Validate explicitly so a null argument is reported by name instead of
    // surfacing as a NullReferenceException from options.Value.
    if (options == null)
    {
        throw new ArgumentNullException(nameof(options));
    }

    _options = options.Value;
    _schemes = schemes ?? throw new ArgumentNullException(nameof(schemes));
    _requestHandlers = new List<AuthenticationScheme>();

    // AddScheme is virtual: an override in a derived class runs here, before
    // that derived class's own constructor body has executed.
    foreach (var builder in _options.Schemes)
    {
        AddScheme(builder.Build());
    }
}
// Options snapshot taken from IOptions<AuthenticationOptions>.Value in the constructor;
// supplies the configured default scheme names and the DisableAutoDefaultScheme flag.
private readonly AuthenticationOptions _options;
// Held by TryAddScheme, AddScheme and RemoveScheme while they mutate the collections below.
private readonly object _lock = new object();
// Registered schemes keyed by scheme name.
private readonly IDictionary<string, AuthenticationScheme> _schemes;
// Schemes whose HandlerType is assignable to IAuthenticationRequestHandler, in the order they were added.
private readonly List<AuthenticationScheme> _requestHandlers;
// Shared, already-completed task whose result is null.
private static readonly Task<AuthenticationScheme?> _nullScheme = Task.FromResult<AuthenticationScheme?>(null);
// Returned by GetDefaultSchemeAsync when no DefaultScheme is configured; updated by CheckAutoDefaultScheme.
private Task<AuthenticationScheme?> _autoDefaultScheme = _nullScheme;
// Array snapshots handed out by the enumeration APIs. They are replaced with a fresh
// array (never modified in place) each time the underlying collection changes.
private IEnumerable<AuthenticationScheme> _schemesCopy = Array.Empty<AuthenticationScheme>();
private IEnumerable<AuthenticationScheme> _requestHandlersCopy = Array.Empty<AuthenticationScheme>();
/// <summary>
/// Resolves the general-purpose default scheme: the scheme named by
/// <see cref="AuthenticationOptions.DefaultScheme"/> when that is set,
/// otherwise the cached auto-default task.
/// </summary>
/// <returns>A task producing the default scheme, or null when none applies.</returns>
private Task<AuthenticationScheme?> GetDefaultSchemeAsync()
{
    var defaultName = _options.DefaultScheme;
    if (defaultName == null)
    {
        return _autoDefaultScheme;
    }

    return GetSchemeAsync(defaultName);
}
/// <summary>
/// Gets the scheme used by default for <see cref="IAuthenticationService.AuthenticateAsync(HttpContext, string)"/>.
/// The name in <see cref="AuthenticationOptions.DefaultAuthenticateScheme"/> is looked up when it is set;
/// otherwise the lookup falls back to <see cref="AuthenticationOptions.DefaultScheme"/>
/// and, when that is also unset, to the automatically selected default scheme (if any).
/// </summary>
/// <returns>The scheme that will be used by default for <see cref="IAuthenticationService.AuthenticateAsync(HttpContext, string)"/>, or null.</returns>
public virtual Task<AuthenticationScheme?> GetDefaultAuthenticateSchemeAsync()
{
    var authenticateName = _options.DefaultAuthenticateScheme;
    if (authenticateName == null)
    {
        return GetDefaultSchemeAsync();
    }

    return GetSchemeAsync(authenticateName);
}
/// <summary>
/// Gets the scheme used by default for <see cref="IAuthenticationService.ChallengeAsync(HttpContext, string, AuthenticationProperties)"/>.
/// The name in <see cref="AuthenticationOptions.DefaultChallengeScheme"/> is looked up when it is set;
/// otherwise the lookup falls back to <see cref="AuthenticationOptions.DefaultScheme"/>
/// and, when that is also unset, to the automatically selected default scheme (if any).
/// </summary>
/// <returns>The scheme that will be used by default for <see cref="IAuthenticationService.ChallengeAsync(HttpContext, string, AuthenticationProperties)"/>, or null.</returns>
public virtual Task<AuthenticationScheme?> GetDefaultChallengeSchemeAsync()
{
    var challengeName = _options.DefaultChallengeScheme;
    if (challengeName == null)
    {
        return GetDefaultSchemeAsync();
    }

    return GetSchemeAsync(challengeName);
}
/// <summary>
/// Gets the scheme used by default for <see cref="IAuthenticationService.ForbidAsync(HttpContext, string, AuthenticationProperties)"/>.
/// The name in <see cref="AuthenticationOptions.DefaultForbidScheme"/> is looked up when it is set;
/// otherwise the result of <see cref="GetDefaultChallengeSchemeAsync"/> is used.
/// </summary>
/// <returns>The scheme that will be used by default for <see cref="IAuthenticationService.ForbidAsync(HttpContext, string, AuthenticationProperties)"/>, or null.</returns>
public virtual Task<AuthenticationScheme?> GetDefaultForbidSchemeAsync()
{
    var forbidName = _options.DefaultForbidScheme;
    if (forbidName == null)
    {
        return GetDefaultChallengeSchemeAsync();
    }

    return GetSchemeAsync(forbidName);
}
/// <summary>
/// Gets the scheme used by default for <see cref="IAuthenticationService.SignInAsync(HttpContext, string, System.Security.Claims.ClaimsPrincipal, AuthenticationProperties)"/>.
/// The name in <see cref="AuthenticationOptions.DefaultSignInScheme"/> is looked up when it is set;
/// otherwise the lookup falls back to <see cref="AuthenticationOptions.DefaultScheme"/>
/// and, when that is also unset, to the automatically selected default scheme (if any).
/// </summary>
/// <returns>The scheme that will be used by default for <see cref="IAuthenticationService.SignInAsync(HttpContext, string, System.Security.Claims.ClaimsPrincipal, AuthenticationProperties)"/>, or null.</returns>
public virtual Task<AuthenticationScheme?> GetDefaultSignInSchemeAsync()
{
    var signInName = _options.DefaultSignInScheme;
    if (signInName == null)
    {
        return GetDefaultSchemeAsync();
    }

    return GetSchemeAsync(signInName);
}
/// <summary>
/// Gets the scheme used by default for <see cref="IAuthenticationService.SignOutAsync(HttpContext, string, AuthenticationProperties)"/>.
/// The name in <see cref="AuthenticationOptions.DefaultSignOutScheme"/> is looked up when it is set;
/// otherwise the result of <see cref="GetDefaultSignInSchemeAsync"/> is used.
/// </summary>
/// <returns>The scheme that will be used by default for <see cref="IAuthenticationService.SignOutAsync(HttpContext, string, AuthenticationProperties)"/>, or null.</returns>
public virtual Task<AuthenticationScheme?> GetDefaultSignOutSchemeAsync()
{
    var signOutName = _options.DefaultSignOutScheme;
    if (signOutName == null)
    {
        return GetDefaultSignInSchemeAsync();
    }

    return GetSchemeAsync(signOutName);
}
/// <summary>
/// Looks up a registered <see cref="AuthenticationScheme"/> by name.
/// </summary>
/// <param name="name">The name of the authenticationScheme.</param>
/// <returns>A completed task holding the matching scheme, or null if no scheme has that name.</returns>
public virtual Task<AuthenticationScheme?> GetSchemeAsync(string name)
{
    if (_schemes.TryGetValue(name, out var match))
    {
        return Task.FromResult<AuthenticationScheme?>(match);
    }

    return Task.FromResult<AuthenticationScheme?>(null);
}
/// <summary>
/// Gets the schemes that take part in request handling, in priority order.
/// </summary>
/// <returns>A completed task holding the current snapshot of request-handling schemes.</returns>
public virtual Task<IEnumerable<AuthenticationScheme>> GetRequestHandlerSchemesAsync()
{
    // Return the snapshot array rather than the live list, so callers can
    // enumerate it while schemes are being added or removed.
    var snapshot = _requestHandlersCopy;
    return Task.FromResult(snapshot);
}
/// <summary>
/// Registers a scheme for use by <see cref="IAuthenticationService"/>,
/// unless a scheme with the same name is already registered.
/// </summary>
/// <param name="scheme">The scheme.</param>
/// <returns>true if the scheme was added successfully; false if the name was already taken.</returns>
/// <exception cref="ArgumentNullException"><paramref name="scheme"/> is null.</exception>
public virtual bool TryAddScheme(AuthenticationScheme scheme)
{
    // Report a null argument by name rather than failing on scheme.Name below.
    if (scheme == null)
    {
        throw new ArgumentNullException(nameof(scheme));
    }

    // Cheap check before taking the lock: most duplicate registrations stop here.
    if (_schemes.ContainsKey(scheme.Name))
    {
        return false;
    }

    lock (_lock)
    {
        // Check again now that the lock is held; the name may have been
        // registered between the first check and acquiring the lock.
        if (_schemes.ContainsKey(scheme.Name))
        {
            return false;
        }

        // Only handlers that implement IAuthenticationRequestHandler are
        // tracked in the request-handler list.
        if (typeof(IAuthenticationRequestHandler).IsAssignableFrom(scheme.HandlerType))
        {
            _requestHandlers.Add(scheme);
            _requestHandlersCopy = _requestHandlers.ToArray();
        }

        _schemes[scheme.Name] = scheme;
        // Refresh the snapshot before CheckAutoDefaultScheme, which reads it.
        _schemesCopy = _schemes.Values.ToArray();
        CheckAutoDefaultScheme();

        return true;
    }
}
/// <summary>
/// Registers a scheme for use by <see cref="IAuthenticationService"/>.
/// </summary>
/// <param name="scheme">The scheme.</param>
/// <exception cref="ArgumentNullException"><paramref name="scheme"/> is null.</exception>
/// <exception cref="InvalidOperationException">A scheme with the same name is already registered.</exception>
public virtual void AddScheme(AuthenticationScheme scheme)
{
    // Report a null argument by name rather than failing on scheme.Name below.
    if (scheme == null)
    {
        throw new ArgumentNullException(nameof(scheme));
    }

    // Fail fast without taking the lock when the name is already present.
    if (_schemes.ContainsKey(scheme.Name))
    {
        throw new InvalidOperationException("Scheme already exists: " + scheme.Name);
    }

    lock (_lock)
    {
        // TryAddScheme takes the same lock again; the C# lock statement is
        // re-entrant on the owning thread, so this does not deadlock.
        if (!TryAddScheme(scheme))
        {
            throw new InvalidOperationException("Scheme already exists: " + scheme.Name);
        }
    }
}
/// <summary>
/// Unregisters the scheme with the given name so that <see cref="IAuthenticationService"/>
/// can no longer use it. Does nothing when no scheme has that name.
/// </summary>
/// <param name="name">The name of the authenticationScheme being removed.</param>
public virtual void RemoveScheme(string name)
{
    // Cheap check before taking the lock: nothing to do for an unknown name.
    var isRegistered = _schemes.TryGetValue(name, out _);
    if (!isRegistered)
    {
        return;
    }

    lock (_lock)
    {
        // Look the scheme up again under the lock; it may have been removed meanwhile.
        if (!_schemes.TryGetValue(name, out var target))
        {
            return;
        }

        // The request-handler snapshot only needs rebuilding if the scheme was in that list.
        var wasRequestHandler = _requestHandlers.Remove(target);
        if (wasRequestHandler)
        {
            _requestHandlersCopy = _requestHandlers.ToArray();
        }

        _schemes.Remove(name);
        // Refresh the snapshot before CheckAutoDefaultScheme, which reads it.
        _schemesCopy = _schemes.Values.ToArray();
        CheckAutoDefaultScheme();
    }
}
/// <inheritdoc />
public virtual Task<IEnumerable<AuthenticationScheme>> GetAllSchemesAsync()
{
    // Return the snapshot array rather than the live dictionary values, so
    // callers can enumerate it while schemes are being added or removed.
    var snapshot = _schemesCopy;
    return Task.FromResult(snapshot);
}
/// <summary>
/// Recomputes the cached auto-default scheme after the set of registered schemes changes.
/// The single registered scheme becomes the auto-default when
/// <see cref="AuthenticationOptions.DisableAutoDefaultScheme"/> is false and exactly one
/// scheme is registered; in every other case the cache is reset to the null-result task,
/// so a previously selected scheme is never left behind once auto-default stops applying.
/// </summary>
/// <remarks>
/// Reads <c>_schemesCopy</c>, so callers must refresh that snapshot before calling this method.
/// </remarks>
private void CheckAutoDefaultScheme()
{
    if (!_options.DisableAutoDefaultScheme && _schemes.Count == 1)
    {
        _autoDefaultScheme = Task.FromResult<AuthenticationScheme?>(_schemesCopy.First());
    }
    else
    {
        _autoDefaultScheme = _nullScheme;
    }
}
}
|