File: System\ServiceModel\Channels\SslStreamSecurityUpgradeProvider.cs
Web Access
Project: src\src\System.ServiceModel.NetFramingBase\src\System.ServiceModel.NetFramingBase.csproj (System.ServiceModel.NetFramingBase)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.ObjectModel;
using System.IdentityModel.Claims;
using System.IdentityModel.Policy;
using System.IdentityModel.Selectors;
using System.IdentityModel.Tokens;
using System.IO;
using System.Net.Security;
using System.Runtime;
using System.Security.Authentication;
using System.Security.Authentication.ExtendedProtection;
using System.Security.Cryptography.X509Certificates;
using System.ServiceModel.Description;
using System.ServiceModel.Security;
using System.ServiceModel.Security.Tokens;
using System.Threading.Tasks;
 
namespace System.ServiceModel.Channels
{
    /// <summary>
    /// Stream upgrade provider that secures a framing connection with SSL/TLS by creating
    /// <see cref="SslStreamSecurityUpgradeInitiator"/> instances, and that exposes the negotiated
    /// channel binding through <see cref="IStreamUpgradeChannelBindingProvider"/>.
    /// </summary>
    internal class SslStreamSecurityUpgradeProvider : StreamSecurityUpgradeProvider, IStreamUpgradeChannelBindingProvider
    {
        // Authenticates client certificates; ClientCertificateAuthenticator creates a default one lazily when none was supplied.
        private SecurityTokenAuthenticator _clientCertificateAuthenticator;

        // Source of the server certificate; consumed once in OnOpenAsync and then set to null.
        private SecurityTokenProvider _serverTokenProvider;

        // Set by IChannelBindingProvider.EnableChannelBindingSupport; initiators read it to decide whether to capture a ChannelBinding.
        private bool _enableChannelBinding;

        /// <summary>
        /// Constructor used by <see cref="CreateClientProvider"/>: the remote (server) certificate is later
        /// validated with authenticators obtained from <paramref name="clientSecurityTokenManager"/>.
        /// </summary>
        private SslStreamSecurityUpgradeProvider(IDefaultCommunicationTimeouts timeouts, SecurityTokenManager clientSecurityTokenManager, bool requireClientCertificate, string scheme, IdentityVerifier identityVerifier, SslProtocols sslProtocols)
            : base(timeouts)
        {
            IdentityVerifier = identityVerifier;
            Scheme = scheme;
            ClientSecurityTokenManager = clientSecurityTokenManager;
            RequireClientCertificate = requireClientCertificate;
            SslProtocols = sslProtocols;
        }

        /// <summary>
        /// Constructor taking a server certificate provider and an optional client certificate authenticator.
        /// NOTE(review): presumably the server-side entry point; no caller is visible in this file.
        /// </summary>
        private SslStreamSecurityUpgradeProvider(IDefaultCommunicationTimeouts timeouts, SecurityTokenProvider serverTokenProvider, bool requireClientCertificate, SecurityTokenAuthenticator clientCertificateAuthenticator, string scheme, IdentityVerifier identityVerifier, SslProtocols sslProtocols)
            : base(timeouts)
        {
            _serverTokenProvider = serverTokenProvider;
            RequireClientCertificate = requireClientCertificate;
            _clientCertificateAuthenticator = clientCertificateAuthenticator;
            IdentityVerifier = identityVerifier;
            Scheme = scheme;
            SslProtocols = sslProtocols;
        }

        /// <summary>
        /// Creates a client-side provider from the binding element and binding context.
        /// </summary>
        /// <param name="bindingElement">Supplies RequireClientCertificate, IdentityVerifier and SslProtocols.</param>
        /// <param name="context">Supplies timeouts, the transport scheme and (optionally) a <see cref="SecurityCredentialsManager"/>.</param>
        /// <returns>A provider whose token manager comes from the context's credentials, or from default <see cref="ClientCredentials"/>.</returns>
        public static SslStreamSecurityUpgradeProvider CreateClientProvider(
            SslStreamSecurityBindingElement bindingElement, BindingContext context)
        {
            // Fall back to default client credentials when the binding parameters carry none.
            SecurityCredentialsManager credentialProvider =
                context.BindingParameters.Find<SecurityCredentialsManager>() ?? new ClientCredentials();

            SecurityTokenManager tokenManager = credentialProvider.CreateSecurityTokenManager();
            return new SslStreamSecurityUpgradeProvider(
                context.Binding,
                tokenManager,
                bindingElement.RequireClientCertificate,
                context.Binding.Scheme,
                bindingElement.IdentityVerifier,
                bindingElement.SslProtocols);
        }

        /// <summary>Verifier used by initiators to check the remote endpoint identity.</summary>
        public IdentityVerifier IdentityVerifier { get; }

        /// <summary>Whether a client certificate must be presented during the handshake.</summary>
        public bool RequireClientCertificate { get; }

        /// <summary>Copy of the server certificate obtained in OnOpenAsync; disposed on close/abort.</summary>
        public X509Certificate2 ServerCertificate { get; private set; }

        /// <summary>
        /// Authenticator for client certificates. When none was supplied at construction, an
        /// <see cref="X509SecurityTokenAuthenticator"/> using <see cref="DefaultCertificateValidator"/> is created on first access.
        /// </summary>
        public SecurityTokenAuthenticator ClientCertificateAuthenticator
        {
            get
            {
                if (_clientCertificateAuthenticator == null)
                {
                    _clientCertificateAuthenticator = new X509SecurityTokenAuthenticator(DefaultCertificateValidator);
                }

                return _clientCertificateAuthenticator;
            }
        }

        /// <summary>Token manager used by initiators to create server-certificate authenticators and client-certificate providers.</summary>
        public SecurityTokenManager ClientSecurityTokenManager { get; }

        /// <summary>Transport scheme copied into the token requirements built by initiators.</summary>
        public string Scheme { get; }

        /// <summary>Protocols passed to <see cref="SslStream"/> authentication.</summary>
        public SslProtocols SslProtocols { get; }

        /// <summary>
        /// Returns this instance for <see cref="IChannelBindingProvider"/> and <see cref="IStreamUpgradeChannelBindingProvider"/>;
        /// otherwise defers to the base implementation.
        /// </summary>
        public override T GetProperty<T>()
        {
            if (typeof(T) == typeof(IChannelBindingProvider) || typeof(T) == typeof(IStreamUpgradeChannelBindingProvider))
            {
                return (T)(object)this;
            }

            return base.GetProperty<T>();
        }

        /// <summary>
        /// Returns the channel binding captured by an SSL initiator created by this provider type.
        /// Only <see cref="ChannelBindingKind.Endpoint"/> is supported.
        /// </summary>
        ChannelBinding IStreamUpgradeChannelBindingProvider.GetChannelBinding(StreamUpgradeInitiator upgradeInitiator, ChannelBindingKind kind)
        {
            if (upgradeInitiator == null)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull(nameof(upgradeInitiator));
            }

            if (!(upgradeInitiator is SslStreamSecurityUpgradeInitiator sslUpgradeInitiator))
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgument(nameof(upgradeInitiator), SR.Format(SR.UnsupportedUpgradeInitiator, upgradeInitiator.GetType()));
            }

            if (kind != ChannelBindingKind.Endpoint)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgument(nameof(kind), SR.Format(SR.StreamUpgradeUnsupportedChannelBindingKind, GetType(), kind));
            }

            return sslUpgradeInitiator.ChannelBinding;
        }

        void IChannelBindingProvider.EnableChannelBindingSupport()
        {
            _enableChannelBinding = true;
        }

        bool IChannelBindingProvider.IsChannelBindingSupportEnabled
        {
            get
            {
                return _enableChannelBinding;
            }
        }

        /// <summary>Creates an SSL initiator for the given remote address; the provider must be open.</summary>
        public override StreamUpgradeInitiator CreateUpgradeInitiator(EndpointAddress remoteAddress, Uri via)
        {
            this.ThrowIfDisposedOrNotOpen();
            return new SslStreamSecurityUpgradeInitiator(this, remoteAddress, via);
        }

        protected override void OnAbort()
        {
            if (_clientCertificateAuthenticator != null)
            {
                SecurityUtils.AbortTokenAuthenticatorIfRequired(_clientCertificateAuthenticator);
            }
            CleanupServerCertificate();
        }

        protected override void OnClose(TimeSpan timeout)
        {
            // OnCloseAsync releases the server certificate itself, so the sync, APM and Task close paths behave the same.
            OnCloseAsync(timeout).WaitForCompletion();
        }

        /// <summary>
        /// Closes the client certificate authenticator (if one was created) and always disposes the server certificate,
        /// even when closing the authenticator fails.
        /// </summary>
        protected internal override async Task OnCloseAsync(TimeSpan timeout)
        {
            try
            {
                if (_clientCertificateAuthenticator != null)
                {
                    await SecurityUtilsEx.CloseTokenAuthenticatorIfRequiredAsync(_clientCertificateAuthenticator, timeout);
                }
            }
            finally
            {
                CleanupServerCertificate();
            }
        }

        protected override IAsyncResult OnBeginClose(TimeSpan timeout, AsyncCallback callback, object state)
        {
            return OnCloseAsync(timeout).ToApm(callback, state);
        }

        protected override void OnEndClose(IAsyncResult result)
        {
            result.ToApmEnd();
        }

        // Stores a private copy of the certificate carried by token; throws if the provider did not return an X.509 token.
        private void SetupServerCertificate(SecurityToken token)
        {
            if (!(token is X509SecurityToken x509Token))
            {
                SecurityUtils.AbortTokenProviderIfRequired(_serverTokenProvider);
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.Format(
                    SR.InvalidTokenProvided, _serverTokenProvider.GetType(), typeof(X509SecurityToken))));
            }

            ServerCertificate = new X509Certificate2(x509Token.Certificate);
        }

        // Disposes and clears ServerCertificate; safe to call more than once.
        private void CleanupServerCertificate()
        {
            if (ServerCertificate != null)
            {
                ServerCertificate.Dispose();
                ServerCertificate = null;
            }
        }

        protected override void OnOpen(TimeSpan timeout)
        {
            OnOpenAsync(timeout).WaitForCompletion();
        }

        /// <summary>
        /// Opens the client certificate authenticator and, when a server token provider was supplied,
        /// fetches the server certificate from it and then closes and releases the provider.
        /// </summary>
        protected internal override async Task OnOpenAsync(TimeSpan timeout)
        {
            TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
            await SecurityUtilsEx.OpenTokenAuthenticatorIfRequiredAsync(ClientCertificateAuthenticator, timeoutHelper.RemainingTime());

            if (_serverTokenProvider != null)
            {
                await SecurityUtils.OpenTokenProviderIfRequiredAsync(_serverTokenProvider, timeoutHelper.RemainingTime());
                // Await instead of blocking on the task: this method is already asynchronous.
                SecurityToken token = await _serverTokenProvider.GetTokenAsync(timeoutHelper.RemainingTime());
                SetupServerCertificate(token);
                await SecurityUtils.CloseTokenProviderIfRequiredAsync(_serverTokenProvider, timeoutHelper.RemainingTime());
                _serverTokenProvider = null;
            }
        }

        protected override IAsyncResult OnBeginOpen(TimeSpan timeout, AsyncCallback callback, object state)
        {
            return OnOpenAsync(timeout).ToApm(callback, state);
        }

        protected override void OnEndOpen(IAsyncResult result)
        {
            result.ToApmEnd();
        }

        private static X509CertificateValidator s_defaultCertificateValidator;

        /// <summary>
        /// Chain-trust validator (machine context, online revocation checking) used for client certificates
        /// when no authenticator was supplied. Created lazily without locking, so concurrent first access
        /// may build more than one instance; any of them is equivalent.
        /// </summary>
        internal static X509CertificateValidator DefaultCertificateValidator
        {
            get
            {
                if (s_defaultCertificateValidator == null)
                {
                    X509ChainPolicy chainPolicy = new X509ChainPolicy();
                    chainPolicy.RevocationMode = X509RevocationMode.Online;
                    s_defaultCertificateValidator = X509CertificateValidator.CreateChainTrustValidator(useMachineContext: true, chainPolicy);
                }
                return s_defaultCertificateValidator;
            }
        }
    }
 
    /// <summary>
    /// Client-side initiator that upgrades a framing stream to SSL/TLS. The server certificate is validated
    /// with an authenticator from the parent's token manager followed by an endpoint identity check, and a
    /// client certificate is presented when the parent requires one.
    /// </summary>
    internal class SslStreamSecurityUpgradeInitiator : StreamSecurityUpgradeInitiatorBase
    {
        private SslStreamSecurityUpgradeProvider _parent;
        // Populated by ValidateRemoteCertificate during the handshake and returned as the remote security.
        private SecurityMessageProperty _serverSecurity;
        // Non-null only when the parent requires a client certificate.
        private SecurityTokenProvider _clientCertificateProvider;
        // Client certificate token obtained in OpenAsync (may remain null).
        private X509SecurityToken _clientToken;
        private SecurityTokenAuthenticator _serverCertificateAuthenticator;
        // Captured after a successful handshake when channel binding support is enabled on the parent.
        private ChannelBinding _channelBindingToken;
        private static LocalCertificateSelectionCallback s_clientCertificateSelectionCallback;
        private const string RequirementNamespace = "http://schemas.microsoft.com/ws/2006/05/servicemodel/securitytokenrequirement";
        private const string PreferSslCertificateAuthenticatorProperty = RequirementNamespace + "/PreferSslCertificateAuthenticator";
        private const string SecurityTokenTypesNamespace = "http://schemas.microsoft.com/ws/2006/05/identitymodel/tokens";
        private const string X509CertificateTokenType = SecurityTokenTypesNamespace + "/X509Certificate";

        /// <summary>
        /// Builds the server-certificate authenticator and, when <see cref="SslStreamSecurityUpgradeProvider.RequireClientCertificate"/>
        /// is set, the client-certificate provider, both from the parent's <see cref="SslStreamSecurityUpgradeProvider.ClientSecurityTokenManager"/>.
        /// </summary>
        /// <exception cref="InvalidOperationException">A client certificate is required but no provider could be created.</exception>
        public SslStreamSecurityUpgradeInitiator(SslStreamSecurityUpgradeProvider parent,
            EndpointAddress remoteAddress, Uri via)
            : base(FramingUpgradeString.SslOrTls, remoteAddress, via)
        {
            _parent = parent;

            InitiatorServiceModelSecurityTokenRequirement serverCertRequirement = new InitiatorServiceModelSecurityTokenRequirement();
            serverCertRequirement.TokenType = X509CertificateTokenType;
            serverCertRequirement.RequireCryptographicToken = true;
            serverCertRequirement.KeyUsage = SecurityKeyUsage.Exchange;
            serverCertRequirement.TargetAddress = remoteAddress;
            serverCertRequirement.Via = via;
            serverCertRequirement.TransportScheme = _parent.Scheme;
            serverCertRequirement.Properties[PreferSslCertificateAuthenticatorProperty] = true;

            // The out-resolver is not needed here.
            _serverCertificateAuthenticator = parent.ClientSecurityTokenManager.CreateSecurityTokenAuthenticator(serverCertRequirement, out SecurityTokenResolver _);

            if (parent.RequireClientCertificate)
            {
                InitiatorServiceModelSecurityTokenRequirement clientCertRequirement = new InitiatorServiceModelSecurityTokenRequirement();
                clientCertRequirement.TokenType = X509CertificateTokenType;
                clientCertRequirement.RequireCryptographicToken = true;
                clientCertRequirement.KeyUsage = SecurityKeyUsage.Signature;
                clientCertRequirement.TargetAddress = remoteAddress;
                clientCertRequirement.Via = via;
                clientCertRequirement.TransportScheme = _parent.Scheme;
                _clientCertificateProvider = parent.ClientSecurityTokenManager.CreateSecurityTokenProvider(clientCertRequirement);
                if (_clientCertificateProvider == null)
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.Format(SR.ClientCredentialsUnableToCreateLocalTokenProvider, clientCertRequirement)));
                }
            }
        }

        // Shared callback that always picks the first offered local certificate; created lazily.
        private static LocalCertificateSelectionCallback ClientCertificateSelectionCallback
        {
            get
            {
                if (s_clientCertificateSelectionCallback == null)
                {
                    s_clientCertificateSelectionCallback = new LocalCertificateSelectionCallback(SelectClientCertificate);
                }

                return s_clientCertificateSelectionCallback;
            }
        }

        /// <summary>Channel binding captured after the handshake; only meaningful when channel binding support is enabled.</summary>
        internal ChannelBinding ChannelBinding
        {
            get
            {
                Fx.Assert(IsChannelBindingSupportEnabled, "A request for the ChannelBinding is not permitted without enabling ChannelBinding first (through the IChannelBindingProvider interface)");
                return _channelBindingToken;
            }
        }

        /// <summary>Mirrors the parent's <see cref="IChannelBindingProvider.IsChannelBindingSupportEnabled"/>.</summary>
        internal bool IsChannelBindingSupportEnabled
        {
            get
            {
                return ((IChannelBindingProvider)_parent).IsChannelBindingSupportEnabled;
            }
        }

        /// <summary>
        /// Opens the base initiator and, when a client certificate is required, opens the client certificate
        /// provider and obtains the client token from it.
        /// </summary>
        /// <exception cref="InvalidOperationException">The provider returned a token that is not an <see cref="X509SecurityToken"/>.</exception>
        internal override async ValueTask OpenAsync(TimeSpan timeout)
        {
            TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
            await base.OpenAsync(timeoutHelper.RemainingTime());
            if (_clientCertificateProvider != null)
            {
                await SecurityUtils.OpenTokenProviderIfRequiredAsync(_clientCertificateProvider, timeoutHelper.RemainingTime());
                SecurityToken token = await _clientCertificateProvider.GetTokenAsync(timeoutHelper.RemainingTime());
                _clientToken = token as X509SecurityToken;

                // Report a wrong token type clearly (matching the provider's server-certificate handling)
                // rather than surfacing an InvalidCastException. A null token still means "no client certificate".
                if (_clientToken == null && token != null)
                {
                    SecurityUtils.AbortTokenProviderIfRequired(_clientCertificateProvider);
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.Format(
                        SR.InvalidTokenProvided, _clientCertificateProvider.GetType(), typeof(X509SecurityToken))));
                }
            }
        }

        /// <summary>Closes the base initiator and then the client certificate provider, if any.</summary>
        internal override async ValueTask CloseAsync(TimeSpan timeout)
        {
            TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
            await base.CloseAsync(timeoutHelper.RemainingTime());
            if (_clientCertificateProvider != null)
            {
                await SecurityUtils.CloseTokenProviderIfRequiredAsync(_clientCertificateProvider, timeoutHelper.RemainingTime());
            }
        }

        /// <summary>
        /// Wraps <paramref name="stream"/> in an <see cref="SslStream"/> and authenticates as a client.
        /// Token validation, authentication and I/O failures are rethrown as <see cref="SecurityNegotiationException"/>.
        /// </summary>
        /// <returns>The upgraded stream and the server security recorded by <see cref="ValidateRemoteCertificate"/>.</returns>
        protected override async Task<(Stream upgradedStream, SecurityMessageProperty remoteSecurity)> OnInitiateUpgradeAsync(Stream stream)
        {
            if (WcfEventSource.Instance.SslOnInitiateUpgradeIsEnabled())
            {
                WcfEventSource.Instance.SslOnInitiateUpgrade();
            }

            X509CertificateCollection clientCertificates = null;
            LocalCertificateSelectionCallback selectionCallback = null;

            if (_clientToken != null)
            {
                clientCertificates = new X509CertificateCollection { _clientToken.Certificate };
                selectionCallback = ClientCertificateSelectionCallback;
            }

            // leaveInnerStreamOpen: false — the SslStream owns the inner stream.
            SslStream sslStream = new SslStream(stream, false, ValidateRemoteCertificate, selectionCallback);

            try
            {
                await sslStream.AuthenticateAsClientAsync(string.Empty, clientCertificates, _parent.SslProtocols, false);
            }
            catch (SecurityTokenValidationException tokenValidationException)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new SecurityNegotiationException(tokenValidationException.Message,
                    tokenValidationException));
            }
            catch (AuthenticationException exception)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new SecurityNegotiationException(exception.Message,
                    exception));
            }
            catch (IOException ioException)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new SecurityNegotiationException(
                    SR.Format(SR.NegotiationFailedIO, ioException.Message), ioException));
            }

            if (IsChannelBindingSupportEnabled)
            {
                _channelBindingToken = ChannelBindingUtility.GetToken(sslStream);
            }

            return (sslStream, _serverSecurity);
        }

        // Selection callback: returns the first local certificate. It is only installed when exactly one
        // client certificate was added in OnInitiateUpgradeAsync.
        private static X509Certificate SelectClientCertificate(object sender, string targetHost,
            X509CertificateCollection localCertificates, X509Certificate remoteCertificate, string[] acceptableIssuers)
        {
            return localCertificates[0];
        }

        // Remote certificate callback. sslPolicyErrors is not consulted. Acceptance depends on
        // _serverCertificateAuthenticator.ValidateToken and EnsureOutgoingIdentity, which reject by throwing;
        // otherwise the method records the server security and returns true.
        private bool ValidateRemoteCertificate(object sender, X509Certificate certificate, X509Chain chain,
            SslPolicyErrors sslPolicyErrors)
        {
            // Note: add ref to handle since the caller will reset the cert after the callback return.
            X509Certificate2 certificate2 = new X509Certificate2(certificate);
            SecurityToken token = new X509SecurityToken(certificate2);
            ReadOnlyCollection<IAuthorizationPolicy> authorizationPolicies = _serverCertificateAuthenticator.ValidateToken(token);
            _serverSecurity = new SecurityMessageProperty();
            _serverSecurity.TransportToken = new SecurityTokenSpecification(token, authorizationPolicies);
            _serverSecurity.ServiceSecurityContext = new ServiceSecurityContext(authorizationPolicies);

            AuthorizationContext authzContext = _serverSecurity.ServiceSecurityContext.AuthorizationContext;
            EnsureOutgoingIdentity(_parent.IdentityVerifier, RemoteAddress, Via, authzContext);

            return true;
        }

        // Throws MessageSecurityException unless the verifier resolves an identity for the target
        // (the Via is used when no identity is set and it differs from the address URI) and grants access.
        private static void EnsureOutgoingIdentity(IdentityVerifier verifier, EndpointAddress serviceReference, Uri via, AuthorizationContext authorizationContext)
        {
            if (authorizationContext == null)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull(nameof(authorizationContext));
            }

            // if we don't have an identity and we have differing Uris, we should use the Via
            if (serviceReference.Identity == null && serviceReference.Uri != via)
            {
                serviceReference = new EndpointAddress(via);
            }

            if (!verifier.TryGetIdentity(serviceReference, out EndpointIdentity identity))
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperWarning(new MessageSecurityException(SR.Format(SR.IdentityCheckFailedForOutgoingMessage, identity, serviceReference)));
            }

            if (!verifier.CheckAccess(identity, authorizationContext))
            {
                // CheckAccess performs a Trace on failure, no need to do it twice
                Exception e = CreateIdentityCheckException(identity, authorizationContext, SR.IdentityCheckFailedForOutgoingMessage, serviceReference);
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperWarning(e);
            }
        }

        // Builds the identity-check failure. For a DNS identity the message names the expected and the
        // first actual DNS claim found (or reports that none exists); otherwise errorString is formatted.
        private static Exception CreateIdentityCheckException(EndpointIdentity identity, AuthorizationContext authorizationContext, string errorString, EndpointAddress serviceReference)
        {
            if (identity.IdentityClaim != null
                && identity.IdentityClaim.ClaimType == ClaimTypes.Dns
                && identity.IdentityClaim.Right == Rights.PossessProperty
                && identity.IdentityClaim.Resource is string expectedDnsName)
            {
                string actualDnsName = FindFirstDnsName(authorizationContext);
                return actualDnsName == null
                    ? new MessageSecurityException(SR.Format(SR.DnsIdentityCheckFailedForOutgoingMessageLackOfDnsClaim, expectedDnsName))
                    : new MessageSecurityException(SR.Format(SR.DnsIdentityCheckFailedForOutgoingMessage, expectedDnsName, actualDnsName));
            }

            return new MessageSecurityException(SR.Format(errorString, identity, serviceReference));
        }

        // Returns the resource of the first string-valued DNS PossessProperty claim across all claim sets, or null.
        private static string FindFirstDnsName(AuthorizationContext authorizationContext)
        {
            foreach (ClaimSet claimSet in authorizationContext.ClaimSets)
            {
                foreach (Claim claim in claimSet.FindClaims(ClaimTypes.Dns, Rights.PossessProperty))
                {
                    if (claim.Resource is string dnsName)
                    {
                        return dnsName;
                    }
                }
            }

            return null;
        }
    }
}