File: System\ServiceModel\Channels\HttpsChannelFactory.cs
Web Access
Project: src\src\System.ServiceModel.Http\src\System.ServiceModel.Http.csproj (System.ServiceModel.Http)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
 
using System.IdentityModel.Claims;
using System.IdentityModel.Selectors;
using System.IdentityModel.Tokens;
using System.Net.Http;
using System.Net.Security;
using System.Runtime;
using System.Security.Cryptography.X509Certificates;
using System.ServiceModel.Description;
using System.ServiceModel.Security;
using System.ServiceModel.Security.Tokens;
using System.Threading.Tasks;
 
namespace System.ServiceModel.Channels
{
    internal class HttpsChannelFactory<TChannel> : HttpChannelFactory<TChannel>
    {
        private X509CertificateValidator _sslCertificateValidator;
        private Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> _remoteCertificateValidationCallback;
 
        internal HttpsChannelFactory(HttpsTransportBindingElement httpsBindingElement, BindingContext context)
            : base(httpsBindingElement, context)
        {
            RequireClientCertificate = httpsBindingElement.RequireClientCertificate;
            ClientCredentials credentials = context.BindingParameters.Find<ClientCredentials>();
            if (credentials != null && credentials.ServiceCertificate.SslCertificateAuthentication != null)
            {
                _sslCertificateValidator = credentials.ServiceCertificate.SslCertificateAuthentication.GetCertificateValidator();
                _remoteCertificateValidationCallback = RemoteCertificateValidationCallback;
                if (_sslCertificateValidator != null)
                {
                    WebSocketCertificateCallback = WebSocketRemoteCertificateValidationCallback;
                }
            }
        }
 
        public override string Scheme
        {
            get
            {
                return UriEx.UriSchemeHttps;
            }
        }
 
        public bool RequireClientCertificate { get; }
 
        public override bool IsChannelBindingSupportEnabled
        {
            get
            {
                return false;
            }
        }
 
        internal System.Net.Security.RemoteCertificateValidationCallback WebSocketCertificateCallback { get; }
 
        public override T GetProperty<T>()
        {
            return base.GetProperty<T>();
        }
 
 
        protected override void ValidateCreateChannelParameters(EndpointAddress remoteAddress, Uri via)
        {
            if (remoteAddress.Identity != null)
            {
                X509CertificateEndpointIdentity certificateIdentity =
                    remoteAddress.Identity as X509CertificateEndpointIdentity;
                if (certificateIdentity != null)
                {
                    if (certificateIdentity.Certificates.Count > 1)
                    {
                        throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgument(nameof(remoteAddress), SR.Format(
                            SR.HttpsIdentityMultipleCerts, remoteAddress.Uri));
                    }
                }
 
                EndpointIdentity identity = remoteAddress.Identity;
                bool validIdentity = (certificateIdentity != null)
                    || ClaimTypes.Spn.Equals(identity.IdentityClaim.ClaimType)
                    || ClaimTypes.Upn.Equals(identity.IdentityClaim.ClaimType)
                    || ClaimTypes.Dns.Equals(identity.IdentityClaim.ClaimType);
 
                if (!IsWindowsAuth(AuthenticationScheme)
                    && !validIdentity)
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgument(nameof(remoteAddress), SR.HttpsExplicitIdentity);
                }
            }
 
            if (string.Compare(via.Scheme, "wss", StringComparison.OrdinalIgnoreCase) != 0)
            {
                ValidateScheme(via);
            }
 
            if (MessageVersion.Addressing == AddressingVersion.None && remoteAddress.Uri != via)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(CreateToMustEqualViaException(remoteAddress.Uri, via));
            }
        }
 
        protected override TChannel OnCreateChannelCore(EndpointAddress address, Uri via)
        {
            ValidateCreateChannelParameters(address, via);
            ValidateWebSocketTransportUsage();
 
            if (typeof(TChannel) == typeof(IRequestChannel))
            {
                return (TChannel)(object)new HttpsClientRequestChannel((HttpsChannelFactory<IRequestChannel>)(object)this, address, via, ManualAddressing);
            }
            else
            {
                return (TChannel)(object)new ClientWebSocketTransportDuplexSessionChannel((HttpChannelFactory<IDuplexSessionChannel>)(object)this, address, via);
            }
        }
 
        protected override bool IsSecurityTokenManagerRequired()
        {
            return RequireClientCertificate || base.IsSecurityTokenManagerRequired();
        }
 
 
        private void OnOpenCore()
        {
            if (RequireClientCertificate && SecurityTokenManager == null)
            {
                throw Fx.AssertAndThrow("HttpsChannelFactory: SecurityTokenManager is null on open.");
            }
        }
 
        protected override void OnEndOpen(IAsyncResult result)
        {
            base.OnEndOpen(result);
            OnOpenCore();
        }
 
        protected override void OnOpen(TimeSpan timeout)
        {
            base.OnOpen(timeout);
            OnOpenCore();
        }
 
        protected internal override async Task OnOpenAsync(TimeSpan timeout)
        {
            await base.OnOpenAsync(timeout);
            OnOpenCore();
        }
 
        internal async Task<SecurityTokenProvider> CreateAndOpenCertificateTokenProviderAsync(EndpointAddress target, Uri via, ChannelParameterCollection channelParameters, TimeSpan timeout)
        {
            if (!RequireClientCertificate)
            {
                return null;
            }
 
            SecurityTokenProvider certificateProvider = TransportSecurityHelpers.GetCertificateTokenProvider(
                SecurityTokenManager, target, via, Scheme, channelParameters);
            await SecurityUtils.OpenTokenProviderIfRequiredAsync(certificateProvider, timeout);
            return certificateProvider;
        }
 
        internal async Task<SecurityTokenContainer> GetCertificateSecurityTokenAsync(SecurityTokenProvider certificateProvider,
            EndpointAddress to, Uri via, ChannelParameterCollection channelParameters, TimeoutHelper timeoutHelper)
        {
            SecurityToken token = null;
            SecurityTokenContainer tokenContainer = null;
            SecurityTokenProvider requestCertificateProvider;
            if (ManualAddressing && RequireClientCertificate)
            {
                requestCertificateProvider = await CreateAndOpenCertificateTokenProviderAsync(to, via, channelParameters, timeoutHelper.RemainingTime());
            }
            else
            {
                requestCertificateProvider = certificateProvider;
            }
 
            if (requestCertificateProvider != null)
            {
                token = await requestCertificateProvider.GetTokenAsync(timeoutHelper.RemainingTime());
            }
 
            if (ManualAddressing && RequireClientCertificate)
            {
                SecurityUtils.AbortTokenProviderIfRequired(requestCertificateProvider);
            }
 
            if (token != null)
            {
                tokenContainer = new SecurityTokenContainer(token);
            }
 
            return tokenContainer;
        }
 
        private void AddServerCertMappingOrSetRemoteCertificateValidationCallback(HttpClientHandler httpClientHandler, EndpointAddress to)
        {
            Fx.Assert(httpClientHandler != null, "httpClientHandler should not be null.");
            if (_sslCertificateValidator != null)
            {
                httpClientHandler.ServerCertificateCustomValidationCallback = _remoteCertificateValidationCallback;
            }
            else
            {
                HttpTransportSecurityHelpers.AddServerCertIdentityValidation(httpClientHandler, to);
            }
        }
 
        private bool RemoteCertificateValidationCallback(HttpRequestMessage sender, X509Certificate2 certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors)
        {
            Fx.Assert(_sslCertificateValidator != null, "sslCertificateValidator should not be null.");
 
            try
            {
                _sslCertificateValidator.Validate(certificate);
                return true;
            }
            catch (SecurityTokenValidationException ex)
            {
                FxTrace.Exception.AsInformation(ex);
                return false;
            }
            catch (Exception ex)
            {
                if (Fx.IsFatal(ex))
                {
                    throw;
                }
 
                FxTrace.Exception.AsWarning(ex);
                return false;
            }
        }
 
        private bool WebSocketRemoteCertificateValidationCallback(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors)
        {
            Fx.Assert(_sslCertificateValidator != null, "sslCertificateValidator should not be null.");
 
            if (certificate is not X509Certificate2 certificate2)
            {
                return false;
            }
 
            try
            {
                _sslCertificateValidator.Validate(certificate2);
                return true;
            }
            catch (SecurityTokenValidationException ex)
            {
                FxTrace.Exception.AsInformation(ex);
                return false;
            }
            catch (Exception ex)
            {
                if (Fx.IsFatal(ex))
                {
                    throw;
                }
 
                FxTrace.Exception.AsWarning(ex);
                return false;
            }
        }
 
        internal override HttpClientHandler GetHttpClientHandler(EndpointAddress to, SecurityTokenContainer clientCertificateToken)
        {
            HttpClientHandler handler = base.GetHttpClientHandler(to, clientCertificateToken);
            if (RequireClientCertificate)
            {
                SetCertificate(handler, clientCertificateToken);
            }
 
            AddServerCertMappingOrSetRemoteCertificateValidationCallback(handler, to);
            return handler;
        }
 
        internal override bool IsExpectContinueHeaderRequired => RequireClientCertificate || base.IsExpectContinueHeaderRequired;
 
        private static void SetCertificate(HttpClientHandler handler, SecurityTokenContainer clientCertificateToken)
        {
            if (clientCertificateToken != null)
            {
                X509SecurityToken x509Token = (X509SecurityToken)clientCertificateToken.Token;
                ValidateClientCertificate(x509Token.Certificate);
                handler.ClientCertificateOptions = ClientCertificateOption.Manual;
                handler.ClientCertificates.Add(x509Token.Certificate);
            }
        }
 
        private static void ValidateClientCertificate(X509Certificate2 certificate)
        {
            if (Fx.IsUap)
            {
                using (var store = new X509Store(StoreName.My, StoreLocation.CurrentUser))
                {
                    store.Open(OpenFlags.ReadOnly);
                    if (store.Certificates.Find(X509FindType.FindByThumbprint, certificate.GetCertHashString(), true).Count == 0)
                    {
                        throw ExceptionHelper.PlatformNotSupported("Certificate could not be found in the MY store.");
                    }
                }
            }
        }
 
        protected class HttpsClientRequestChannel : HttpClientRequestChannel
        {
            private SecurityTokenProvider _certificateProvider;
 
            public HttpsClientRequestChannel(HttpsChannelFactory<IRequestChannel> factory, EndpointAddress to, Uri via, bool manualAddressing)
                : base(factory, to, via, manualAddressing)
            {
                Factory = factory;
            }
 
            public new HttpsChannelFactory<IRequestChannel> Factory { get; }
 
            private async Task CreateAndOpenTokenProviderAsync(TimeSpan timeout)
            {
                if (!ManualAddressing && Factory.RequireClientCertificate)
                {
                    _certificateProvider = await Factory.CreateAndOpenCertificateTokenProviderAsync(RemoteAddress, Via, ChannelParameters, timeout);
                }
            }
 
            private void CloseTokenProvider(TimeSpan timeout)
            {
                if (_certificateProvider != null)
                {
                    SecurityUtils.CloseTokenProviderIfRequired(_certificateProvider, timeout);
                }
            }
 
            private void AbortTokenProvider()
            {
                if (_certificateProvider != null)
                {
                    SecurityUtils.AbortTokenProviderIfRequired(_certificateProvider);
                }
            }
 
            protected override IAsyncResult OnBeginOpen(TimeSpan timeout, AsyncCallback callback, object state)
            {
                return OnOpenAsync(timeout).ToApm(callback, state);
            }
 
            protected override void OnEndOpen(IAsyncResult result)
            {
                result.ToApmEnd();
            }
 
            protected override void OnOpen(TimeSpan timeout)
            {
                OnOpenAsync(timeout).WaitForCompletion();
            }
 
            internal protected override async Task OnOpenAsync(TimeSpan timeout)
            {
                TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
                await CreateAndOpenTokenProviderAsync(timeoutHelper.RemainingTime());
                await base.OnOpenAsync(timeoutHelper.RemainingTime());
            }
 
            protected override void OnAbort()
            {
                AbortTokenProvider();
                base.OnAbort();
            }
 
            protected override IAsyncResult OnBeginClose(TimeSpan timeout, AsyncCallback callback, object state)
            {
                TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
                CloseTokenProvider(timeoutHelper.RemainingTime());
                return base.OnBeginClose(timeoutHelper.RemainingTime(), callback, state);
            }
 
            protected override void OnClose(TimeSpan timeout)
            {
                TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
                CloseTokenProvider(timeoutHelper.RemainingTime());
                base.OnClose(timeoutHelper.RemainingTime());
            }
 
            internal override void OnHttpRequestCompleted(HttpRequestMessage request)
            {
            }
 
            internal override async Task<HttpClient> GetHttpClientAsync(EndpointAddress to, Uri via, TimeoutHelper timeoutHelper)
            {
                SecurityTokenContainer clientCertificateToken = await Factory.GetCertificateSecurityTokenAsync(_certificateProvider, to, via, ChannelParameters, timeoutHelper);
                HttpClient httpClient = await GetHttpClientAsync(to, via, clientCertificateToken, timeoutHelper);
                return httpClient;
            }
        }
    }
}