File: System\Security\Claims\ClaimsPrincipal.cs
Web Access
Project: src\src\libraries\System.Security.Claims\src\System.Security.Claims.csproj (System.Security.Claims)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.IO;
using System.Runtime.Serialization;
using System.Security.Principal;
using System.Threading;
 
namespace System.Security.Claims
{
    /// <summary>
    /// Concrete IPrincipal supporting multiple claims-based identities
    /// </summary>
    [DebuggerDisplay("{DebuggerToString(),nq}")]
    public class ClaimsPrincipal : IPrincipal
    {
        // Bit values written as the leading Int32 of the binary format. WriteTo combines them
        // with '|' and the BinaryReader constructor tests them with '&'.
        private enum SerializationMask
        {
            // No optional section follows the header.
            None = 0,
            // An identity count followed by the serialized identities is present.
            HasIdentities = 1,
            // A length-prefixed block of user bytes is present.
            UserData = 2
        }
 
        // Backing list for the identities of this principal; returned by the Identities property.
        private readonly List<ClaimsIdentity> _identities = new List<ClaimsIdentity>();
        // Bytes read by the BinaryReader constructor when the UserData flag is set; never assigned otherwise.
        private readonly byte[]? _userSerializationData;
 
        // Delegate behind PrimaryIdentitySelector; initialized to SelectPrimaryIdentity.
        private static Func<IEnumerable<ClaimsIdentity>, ClaimsIdentity?> s_identitySelector = SelectPrimaryIdentity;
        // Delegate behind ClaimsPrincipalSelector.
        // NOTE(review): the initializer reads the ClaimsPrincipalSelector property, whose getter returns this
        // same field, so the field appears to start out null and Current falls back to SelectClaimsPrincipal
        // until a selector is assigned - confirm this is intended.
        private static Func<ClaimsPrincipal> s_principalSelector = ClaimsPrincipalSelector;
 
        /// <summary>
        /// Built-in principal lookup used by <see cref="Current"/> when no selector delegate is set.
        /// </summary>
        /// <returns>
        /// The thread principal itself when it already is a <see cref="ClaimsPrincipal"/>, a new
        /// <see cref="ClaimsPrincipal"/> wrapping it when it is some other <see cref="IPrincipal"/>,
        /// or null when the thread has no principal.
        /// </returns>
        private static ClaimsPrincipal? SelectClaimsPrincipal()
        {
            // Diverging behavior from .NET Framework: In Framework, the default PrincipalPolicy is
            // UnauthenticatedPrincipal. In .NET Core, the default is NoPrincipal. .NET Framework
            // would throw an ArgumentNullException when constructing the ClaimsPrincipal with a
            // null principal from the thread if it were set to use NoPrincipal. In .NET Core, since
            // NoPrincipal is the default, we return null instead of throw.

            IPrincipal? current = Thread.CurrentPrincipal;
            if (current is null)
            {
                return null;
            }

            // Hand back an existing ClaimsPrincipal untouched; wrap anything else.
            return current as ClaimsPrincipal ?? new ClaimsPrincipal(current);
        }
 
        /// <summary>
        /// Legacy formatter-based deserialization constructor. Always throws.
        /// </summary>
        /// <param name="info">Unused.</param>
        /// <param name="context">Unused.</param>
        /// <exception cref="PlatformNotSupportedException">Always thrown.</exception>
        [Obsolete(Obsoletions.LegacyFormatterImplMessage, DiagnosticId = Obsoletions.LegacyFormatterImplDiagId, UrlFormat = Obsoletions.SharedUrlFormat)]
        [EditorBrowsable(EditorBrowsableState.Never)]
        protected ClaimsPrincipal(SerializationInfo info, StreamingContext context)
            => throw new PlatformNotSupportedException();
 
        /// <summary>
        /// Default primary-identity policy: picks the first non-null entry of <paramref name="identities"/>.
        /// </summary>
        /// <param name="identities">The candidate identities, in priority order.</param>
        /// <returns>The first non-null <see cref="ClaimsIdentity"/>, or null when there is none.</returns>
        /// <exception cref="ArgumentNullException">if 'identities' is null.</exception>
        private static ClaimsIdentity? SelectPrimaryIdentity(IEnumerable<ClaimsIdentity> identities)
        {
            ArgumentNullException.ThrowIfNull(identities);

            foreach (ClaimsIdentity candidate in identities)
            {
                if (candidate == null)
                {
                    continue;
                }

                return candidate;
            }

            return null;
        }
 
        /// <summary>
        /// Gets or sets the delegate that <see cref="Identity"/> invokes to choose the primary identity
        /// from this principal's identities.
        /// </summary>
        /// <remarks>When the delegate is null, <see cref="Identity"/> uses the built-in first-non-null selection.</remarks>
        public static Func<IEnumerable<ClaimsIdentity>, ClaimsIdentity?> PrimaryIdentitySelector
        {
            get => s_identitySelector;
            set => s_identitySelector = value;
        }
 
        /// <summary>
        /// Gets or sets the delegate that <see cref="Current"/> invokes to obtain the current principal.
        /// </summary>
        /// <remarks>When the delegate is null, <see cref="Current"/> derives the principal from the current thread.</remarks>
        public static Func<ClaimsPrincipal> ClaimsPrincipalSelector
        {
            get => s_principalSelector;
            set => s_principalSelector = value;
        }
 
        /// <summary>
        /// Initializes an instance of <see cref="ClaimsPrincipal"/> that contains no identities.
        /// </summary>
        public ClaimsPrincipal()
        {
        }
 
        /// <summary>
        /// Initializes an instance of <see cref="ClaimsPrincipal"/> from a sequence of identities.
        /// </summary>
        /// <param name="identities">The <see cref="IEnumerable{ClaimsIdentity}"/> whose elements become the identities of this principal.</param>
        /// <exception cref="ArgumentNullException">if 'identities' is null.</exception>
        public ClaimsPrincipal(IEnumerable<ClaimsIdentity> identities)
        {
            ArgumentNullException.ThrowIfNull(identities);
 
            // The sequence is appended as-is; individual elements are not validated here.
            _identities.AddRange(identities);
        }
 
        /// <summary>
        /// Initializes an instance of <see cref="ClaimsPrincipal"/> that holds a single identity.
        /// </summary>
        /// <param name="identity">
        /// The <see cref="IIdentity"/> for the subject. A <see cref="ClaimsIdentity"/> is stored directly;
        /// any other implementation is wrapped in a new <see cref="ClaimsIdentity"/>.
        /// </param>
        /// <exception cref="ArgumentNullException">if 'identity' is null.</exception>
        public ClaimsPrincipal(IIdentity identity)
        {
            ArgumentNullException.ThrowIfNull(identity);

            _identities.Add(identity as ClaimsIdentity ?? new ClaimsIdentity(identity));
        }
 
        /// <summary>
        /// Initializes an instance of <see cref="ClaimsPrincipal"/> from another principal.
        /// </summary>
        /// <param name="principal">
        /// The <see cref="IPrincipal"/> to copy from. For a <see cref="ClaimsPrincipal"/> all of its
        /// <see cref="Identities"/> are added; for any other principal a new <see cref="ClaimsIdentity"/>
        /// is created from its <see cref="IPrincipal.Identity"/>.
        /// </param>
        /// <exception cref="ArgumentNullException">if 'principal' is null.</exception>
        public ClaimsPrincipal(IPrincipal principal)
        {
            ArgumentNullException.ThrowIfNull(principal);

            if (principal is ClaimsPrincipal source)
            {
                // A derived type may return null from its Identities override; tolerate that.
                if (source.Identities != null)
                {
                    _identities.AddRange(source.Identities);
                }

                return;
            }

            _identities.Add(new ClaimsIdentity(principal.Identity));
        }
 
        /// <summary>
        /// Initializes an instance of <see cref="ClaimsPrincipal"/> by reading the binary layout produced by
        /// <see cref="WriteTo(BinaryWriter)"/>: a flag word, a property count, then the optional identity and
        /// user-data sections.
        /// </summary>
        /// <param name="reader">a <see cref="BinaryReader"/> positioned at the start of a serialized <see cref="ClaimsPrincipal"/>.</param>
        /// <exception cref="ArgumentNullException">if 'reader' is null.</exception>
        public ClaimsPrincipal(BinaryReader reader)
        {
            ArgumentNullException.ThrowIfNull(reader);

            // Header: which sections are present, and how many properties the writer emitted.
            SerializationMask flags = (SerializationMask)reader.ReadInt32();
            int declaredProperties = reader.ReadInt32();
            int consumedProperties = 0;

            if ((flags & SerializationMask.HasIdentities) != 0)
            {
                consumedProperties++;
                int identityCount = reader.ReadInt32();
                for (int remaining = identityCount; remaining > 0; remaining--)
                {
                    // directly add to _identities as that is what we serialized from
                    _identities.Add(CreateClaimsIdentity(reader));
                }
            }

            if ((flags & SerializationMask.UserData) != 0)
            {
                int byteCount = reader.ReadInt32();
                _userSerializationData = reader.ReadBytes(byteCount);
                consumedProperties++;
            }

            // Any properties this reader does not recognize are consumed as strings and discarded.
            while (consumedProperties < declaredProperties)
            {
                reader.ReadString();
                consumedProperties++;
            }
        }
 
        /// <summary>
        /// Appends a single <see cref="ClaimsIdentity"/> to the internal identity list.
        /// </summary>
        /// <param name="identity">the <see cref="ClaimsIdentity"/> to add.</param>
        /// <exception cref="ArgumentNullException">if 'identity' is null.</exception>
        public virtual void AddIdentity(ClaimsIdentity identity)
        {
            ArgumentNullException.ThrowIfNull(identity);
 
            _identities.Add(identity);
        }
 
        /// <summary>
        /// Appends every element of an <see cref="IEnumerable{ClaimsIdentity}"/> to the internal identity list.
        /// </summary>
        /// <param name="identities">Enumeration of ClaimsIdentities to add.</param>
        /// <exception cref="ArgumentNullException">if 'identities' is null.</exception>
        public virtual void AddIdentities(IEnumerable<ClaimsIdentity> identities)
        {
            ArgumentNullException.ThrowIfNull(identities);
 
            // Only the sequence itself is checked for null; its elements are appended as-is.
            _identities.AddRange(identities);
        }
 
        /// <summary>
        /// Gets the claims as <see cref="IEnumerable{Claim}"/>, associated with this <see cref="ClaimsPrincipal"/> by enumerating all <see cref="Identities"/>.
        /// </summary>
        /// <remarks>
        /// Enumeration is lazy. Null entries in <see cref="Identities"/> are skipped, matching
        /// <see cref="FindAll(Predicate{Claim})"/>, <see cref="FindFirst(Predicate{Claim})"/> and
        /// <see cref="HasClaim(Predicate{Claim})"/>.
        /// </remarks>
        public virtual IEnumerable<Claim> Claims
        {
            get
            {
                foreach (ClaimsIdentity identity in Identities)
                {
                    // The IEnumerable-based constructor and AddIdentities do not reject null elements,
                    // so guard here instead of failing with a NullReferenceException.
                    if (identity != null)
                    {
                        foreach (Claim claim in identity.Claims)
                        {
                            yield return claim;
                        }
                    }
                }
            }
        }
 
        /// <summary>
        /// Gets the additional bytes supplied by a derived type, typically passed to
        /// <see cref="WriteTo(BinaryWriter, byte[])"/> and read back by the <see cref="BinaryReader"/> constructor.
        /// Null when no user data was read.
        /// </summary>
        protected virtual byte[]? CustomSerializationData => _userSerializationData;
 
        /// <summary>
        /// Creates a new <see cref="ClaimsPrincipal"/> populated from this instance.
        /// </summary>
        /// <returns>A new <see cref="ClaimsPrincipal"/> constructed by passing this instance to a constructor.</returns>
        public virtual ClaimsPrincipal Clone() => new ClaimsPrincipal(this);
 
        /// <summary>
        /// Provides an extensibility point for derived types to create a custom <see cref="ClaimsIdentity"/>.
        /// The <see cref="BinaryReader"/> constructor calls this once for each serialized identity.
        /// </summary>
        /// <param name="reader">the <see cref="BinaryReader"/> that points at the serialized identity.</param>
        /// <exception cref="ArgumentNullException">if 'reader' is null.</exception>
        /// <returns>a new <see cref="ClaimsIdentity"/>.</returns>
        protected virtual ClaimsIdentity CreateClaimsIdentity(BinaryReader reader)
        {
            ArgumentNullException.ThrowIfNull(reader);
 
            return new ClaimsIdentity(reader);
        }
 
        /// <summary>
        /// Gets the current principal by invoking the selector delegate, or, when no delegate is set,
        /// by deriving it from the current thread's principal.
        /// </summary>
        public static ClaimsPrincipal? Current
        {
            // read-only: consults the selector but never assigns it
            get
            {
                if (s_principalSelector is null)
                {
                    return SelectClaimsPrincipal();
                }

                return s_principalSelector();
            }
        }
 
        /// <summary>
        /// Lazily returns every claim, across all non-null <see cref="Identities"/>, that satisfies <paramref name="match"/>.
        /// </summary>
        /// <param name="match">The predicate that performs the matching logic.</param>
        /// <returns>A <see cref="IEnumerable{Claim}"/> of matched claims.</returns>
        /// <remarks>Delegates to <see cref="ClaimsIdentity.FindAll(Predicate{Claim})"/> on each identity. <seealso cref="ClaimsIdentity.FindAll(string)"/>.</remarks>
        /// <exception cref="ArgumentNullException">if 'match' is null.</exception>
        public virtual IEnumerable<Claim> FindAll(Predicate<Claim> match)
        {
            // Validate now; the iterator body below does not run until enumeration begins.
            ArgumentNullException.ThrowIfNull(match);
            return EnumerateMatches(match);

            IEnumerable<Claim> EnumerateMatches(Predicate<Claim> predicate)
            {
                foreach (ClaimsIdentity candidate in Identities)
                {
                    if (candidate == null)
                    {
                        continue;
                    }

                    foreach (Claim found in candidate.FindAll(predicate))
                    {
                        yield return found;
                    }
                }
            }
        }
 
        /// <summary>
        /// Lazily returns every claim, across all non-null <see cref="Identities"/>, whose Claim.Type matches <paramref name="type"/>.
        /// </summary>
        /// <param name="type">The type of the claim to match.</param>
        /// <returns>A <see cref="IEnumerable{Claim}"/> of matched claims.</returns>
        /// <remarks>Delegates to <see cref="ClaimsIdentity.FindAll(string)"/> on each identity. <seealso cref="ClaimsIdentity.FindAll(Predicate{Claim})"/>.</remarks>
        /// <exception cref="ArgumentNullException">if 'type' is null.</exception>
        public virtual IEnumerable<Claim> FindAll(string type)
        {
            // Validate now; the iterator body below does not run until enumeration begins.
            ArgumentNullException.ThrowIfNull(type);
            return EnumerateMatches(type);

            IEnumerable<Claim> EnumerateMatches(string claimType)
            {
                foreach (ClaimsIdentity candidate in Identities)
                {
                    if (candidate == null)
                    {
                        continue;
                    }

                    foreach (Claim found in candidate.FindAll(claimType))
                    {
                        yield return found;
                    }
                }
            }
        }
 
        /// <summary>
        /// Returns the first <see cref="Claim"/> satisfying <paramref name="match"/>, searching the non-null
        /// <see cref="Identities"/> in order.
        /// </summary>
        /// <param name="match">The predicate that performs the matching logic.</param>
        /// <returns>A <see cref="Claim"/>, null if nothing matches.</returns>
        /// <remarks>Delegates to <see cref="ClaimsIdentity.FindFirst(Predicate{Claim})"/> on each identity. <seealso cref="ClaimsIdentity.FindFirst(string)"/>.</remarks>
        /// <exception cref="ArgumentNullException">if 'match' is null.</exception>
        public virtual Claim? FindFirst(Predicate<Claim> match)
        {
            ArgumentNullException.ThrowIfNull(match);

            foreach (ClaimsIdentity candidate in Identities)
            {
                if (candidate == null)
                {
                    continue;
                }

                Claim? found = candidate.FindFirst(match);
                if (found != null)
                {
                    return found;
                }
            }

            return null;
        }
 
        /// <summary>
        /// Returns the first <see cref="Claim"/> whose Claim.Type matches <paramref name="type"/>, searching the
        /// non-null identities in order.
        /// </summary>
        /// <param name="type">The type of the claim to match.</param>
        /// <returns>A <see cref="Claim"/>, null if nothing matches.</returns>
        /// <remarks>
        /// Delegates to <see cref="ClaimsIdentity.FindFirst(string)"/> on each identity. <seealso cref="ClaimsIdentity.FindFirst(Predicate{Claim})"/>.
        /// This overload walks the internal identity list directly rather than the virtual <see cref="Identities"/> property.
        /// </remarks>
        /// <exception cref="ArgumentNullException">if 'type' is null.</exception>
        public virtual Claim? FindFirst(string type)
        {
            ArgumentNullException.ThrowIfNull(type);

            for (int index = 0; index < _identities.Count; index++)
            {
                ClaimsIdentity candidate = _identities[index];
                if (candidate == null)
                {
                    continue;
                }

                Claim? found = candidate.FindFirst(type);
                if (found != null)
                {
                    return found;
                }
            }

            return null;
        }
 
        /// <summary>
        /// Determines whether any identity of this <see cref="ClaimsPrincipal"/> contains a claim satisfying <paramref name="match"/>.
        /// </summary>
        /// <param name="match">The predicate that performs the matching logic.</param>
        /// <returns>true if a claim is found, false otherwise.</returns>
        /// <remarks>
        /// Delegates to <see cref="ClaimsIdentity.HasClaim(Predicate{Claim})"/> on each non-null identity. <seealso cref="ClaimsIdentity.HasClaim(string, string)"/>.
        /// This overload walks the internal identity list directly rather than the virtual <see cref="Identities"/> property.
        /// </remarks>
        /// <exception cref="ArgumentNullException">if 'match' is null.</exception>
        public virtual bool HasClaim(Predicate<Claim> match)
        {
            ArgumentNullException.ThrowIfNull(match);

            for (int index = 0; index < _identities.Count; index++)
            {
                ClaimsIdentity candidate = _identities[index];
                if (candidate != null && candidate.HasClaim(match))
                {
                    return true;
                }
            }

            return false;
        }
 
        /// <summary>
        /// Determines whether any identity of this <see cref="ClaimsPrincipal"/> contains a claim with the given
        /// <paramref name="type"/> AND <paramref name="value"/>.
        /// </summary>
        /// <param name="type"> the type of the claim to match.</param>
        /// <param name="value"> the value of the claim to match.</param>
        /// <returns>true if a claim is matched, false otherwise.</returns>
        /// <remarks>
        /// Delegates to <see cref="ClaimsIdentity.HasClaim(string, string)"/> on each non-null identity. <seealso cref="ClaimsIdentity.HasClaim(Predicate{Claim})"/>.
        /// This overload walks the internal identity list directly rather than the virtual <see cref="Identities"/> property.
        /// </remarks>
        /// <exception cref="ArgumentNullException">if 'type' is null.</exception>
        /// <exception cref="ArgumentNullException">if 'value' is null.</exception>
        public virtual bool HasClaim(string type, string value)
        {
            ArgumentNullException.ThrowIfNull(type);
            ArgumentNullException.ThrowIfNull(value);

            for (int index = 0; index < _identities.Count; index++)
            {
                ClaimsIdentity candidate = _identities[index];
                if (candidate != null && candidate.HasClaim(type, value))
                {
                    return true;
                }
            }

            return false;
        }
 
        /// <summary>
        /// Gets the <see cref="ClaimsIdentity" /> instances held by this principal.
        /// </summary>
        /// <remarks>The returned sequence is the internal list itself, not a copy.</remarks>
        public virtual IEnumerable<ClaimsIdentity> Identities => _identities;
 
        /// <summary>
        /// Gets the primary identity of this principal, as chosen by the <see cref="PrimaryIdentitySelector"/>
        /// delegate or, when that delegate is null, by taking the first non-null identity.
        /// </summary>
        public virtual System.Security.Principal.IIdentity? Identity
        {
            get
            {
                return s_identitySelector != null
                    ? s_identitySelector(_identities)
                    : SelectPrimaryIdentity(_identities);
            }
        }
 
        /// <summary>
        /// Answers whether any identity held by this principal has a claim whose type is that identity's
        /// RoleClaimType and whose value matches <paramref name="role"/>.
        /// </summary>
        /// <param name="role">The role to check for.</param>
        /// <returns>'True' if a claim is found. Otherwise 'False'.</returns>
        /// <remarks>
        /// Each Identity has its own definition of the ClaimType that represents a role.
        /// Null identities are skipped, and the internal identity list is walked directly rather than the
        /// virtual <see cref="Identities"/> property.
        /// </remarks>
        public virtual bool IsInRole(string role)
        {
            for (int index = 0; index < _identities.Count; index++)
            {
                ClaimsIdentity candidate = _identities[index];
                if (candidate != null && candidate.HasClaim(candidate.RoleClaimType, role))
                {
                    return true;
                }
            }

            return false;
        }
 
        /// <summary>
        /// Serializes this principal using a <see cref="BinaryWriter"/>, without any user data.
        /// </summary>
        /// <param name="writer">the <see cref="BinaryWriter"/> to use for data storage.</param>
        /// <exception cref="ArgumentNullException">if 'writer' is null.</exception>
        public virtual void WriteTo(BinaryWriter writer) => WriteTo(writer, null);
 
        /// <summary>
        /// Serializes this principal using a <see cref="BinaryWriter"/>: a flag word, the number of sections
        /// written, then the identities (if any) and the user data (if non-empty). The writer is flushed at the end.
        /// </summary>
        /// <param name="writer">the <see cref="BinaryWriter"/> to use for data storage.</param>
        /// <param name="userData">additional data provided by derived type; null or empty arrays are not written.</param>
        /// <exception cref="ArgumentNullException">if 'writer' is null.</exception>
        protected virtual void WriteTo(BinaryWriter writer, byte[]? userData)
        {
            ArgumentNullException.ThrowIfNull(writer);

            // Decide up front which optional sections will be emitted.
            bool includeIdentities = _identities.Count > 0;
            bool includeUserData = userData != null && userData.Length > 0;

            SerializationMask flags = SerializationMask.None;
            int sectionCount = 0;

            if (includeIdentities)
            {
                flags |= SerializationMask.HasIdentities;
                sectionCount++;
            }

            if (includeUserData)
            {
                flags |= SerializationMask.UserData;
                sectionCount++;
            }

            // Header.
            writer.Write((int)flags);
            writer.Write(sectionCount);

            if (includeIdentities)
            {
                // Count-prefixed list; each identity serializes itself.
                writer.Write(_identities.Count);
                foreach (ClaimsIdentity identity in _identities)
                {
                    identity.WriteTo(writer);
                }
            }

            if (includeUserData)
            {
                // Length-prefixed raw bytes.
                writer.Write(userData!.Length);
                writer.Write(userData);
            }

            writer.Flush();
        }
 
        /// <summary>
        /// Serialization callback that rejects formatter-based serialization of a principal that holds identities,
        /// unless the runtime type implements <see cref="ISerializable"/>.
        /// </summary>
        /// <param name="context">Unused.</param>
        /// <exception cref="PlatformNotSupportedException">
        /// if this instance is not <see cref="ISerializable"/> and contains at least one identity.
        /// </exception>
        [OnSerializing]
        private void OnSerializingMethod(StreamingContext context)
        {
            // Derived types that implement ISerializable are exempt from the check.
            bool isCustomSerializable = this is ISerializable;

            if (!isCustomSerializable && _identities.Count > 0)
            {
                throw new PlatformNotSupportedException(SR.PlatformNotSupported_Serialization); // BinaryFormatter and WindowsIdentity would be needed
            }
        }
 
        /// <summary>
        /// Formatter-based serialization hook. Always throws in this implementation.
        /// </summary>
        /// <param name="info">Unused.</param>
        /// <param name="context">Unused.</param>
        /// <exception cref="PlatformNotSupportedException">Always thrown.</exception>
        protected virtual void GetObjectData(SerializationInfo info, StreamingContext context)
            => throw new PlatformNotSupportedException();
 
        /// <summary>
        /// Builds the string shown by the debugger for this principal.
        /// </summary>
        /// <returns>
        /// The primary identity's own debugger string when there is exactly one identity and it is a
        /// <see cref="ClaimsIdentity"/>; otherwise a summary of the identity and claim counts.
        /// </returns>
        private string DebuggerToString()
        {
            // DebuggerDisplayAttribute is inherited. Use virtual members instead of private fields to gather data.
            int identityTotal = 0;
            foreach (ClaimsIdentity ignored in Identities)
            {
                identityTotal++;
            }

            // Single-identity fast path: show that identity's own debug string.
            bool singleIdentity = identityTotal == 1;
            if (singleIdentity && Identity is ClaimsIdentity primary)
            {
                return primary.DebuggerToString();
            }

            int claimTotal = 0;
            foreach (Claim ignored in Claims)
            {
                claimTotal++;
            }

            return $"Identities = {identityTotal}, Claims = {claimTotal}";
        }
    }
}