File: System\ServiceModel\Channels\HttpTransportSecurityHelpers.cs
Web Access
Project: src\src\System.ServiceModel.Http\src\System.ServiceModel.Http.csproj (System.ServiceModel.Http)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Generic;
using System.Diagnostics;
using System.Net;
using System.Net.Http;
using System.Net.Security;
using System.Numerics;
using System.Runtime;
using System.Security.Cryptography.X509Certificates;
using System.ServiceModel.Security;
 
namespace System.ServiceModel.Channels
{
    /// <summary>
    /// Helpers used by the HTTP transport to apply endpoint identity information to outgoing requests:
    /// mapping an endpoint identity onto the Host header (used for SPN selection) and pinning the
    /// server certificate declared by an <see cref="X509CertificateEndpointIdentity"/>.
    /// </summary>
    internal static class HttpTransportSecurityHelpers
    {
        // Length of the SPN service-name prefix. Both accepted prefixes, "host/" and "http/", are 5 characters long.
        private const int SpnServicePrefixLength = 5;

        /// <summary>
        /// Sets the Host header of <paramref name="message"/> to the host name derived from the identity of
        /// <paramref name="target"/>, creating the <see cref="HttpRequestMessageProperty"/> if the message has none.
        /// </summary>
        /// <param name="target">The endpoint address whose identity (or address) supplies the SPN.</param>
        /// <param name="message">The outgoing message whose HTTP request property is updated.</param>
        /// <exception cref="InvalidOperationException">
        /// The SPN for <paramref name="target"/> does not use the "host/" or "http/" service name.
        /// </exception>
        public static void AddIdentityMapping(EndpointAddress target, Message message)
        {
            string hostHeader = GetIdentityHostHeader(target);
            HttpRequestMessageProperty requestProperty;
            if (!message.Properties.TryGetValue(HttpRequestMessageProperty.Name, out requestProperty))
            {
                requestProperty = new HttpRequestMessageProperty();
                message.Properties.Add(HttpRequestMessageProperty.Name, requestProperty);
            }

            requestProperty.Headers[HttpRequestHeader.Host] = hostHeader;
        }

        /// <summary>
        /// Computes the Host header value for <paramref name="target"/>: the SPN of the endpoint with its
        /// leading "host/" or "http/" service name removed.
        /// </summary>
        /// <param name="target">The endpoint address to derive the SPN from.</param>
        /// <returns>The part of the SPN that follows the service-name prefix.</returns>
        /// <exception cref="InvalidOperationException">
        /// The SPN does not start with "host/" or "http/" (compared case-insensitively).
        /// </exception>
        public static string GetIdentityHostHeader(EndpointAddress target)
        {
            EndpointIdentity identity = target.Identity;

            // A certificate identity carries no SPN, so it is treated like a missing identity and the SPN
            // is derived from the target address instead.
            string value = (identity != null && !(identity is X509CertificateEndpointIdentity))
                ? SecurityUtils.GetSpnFromIdentity(identity, target)
                : SecurityUtils.GetSpnFromTarget(target);

            // HttpClientHandler supports specifying the SPN via the HOST header. The service name is hard coded to "HTTP/". "HTTP/"
            // is an alias for the "HOST/" service name so we accept either but can't accept anything else.
            bool hasSupportedServiceName =
                value.StartsWith("host/", StringComparison.OrdinalIgnoreCase) ||
                value.StartsWith("http/", StringComparison.OrdinalIgnoreCase);
            if (!hasSupportedServiceName)
            {
                throw Fx.Exception.AsError(new InvalidOperationException(SR.OnlyDefaultSpnServiceSupported));
            }

            // The leading service name has been constrained to be either "HTTP/" or "HOST/" which are both 5 characters long.
            // This needs to be removed to provide just the hostname part for the Host header.
            return value.Substring(SpnServicePrefixLength);
        }

        /// <summary>
        /// When <paramref name="to"/> carries an <see cref="X509CertificateEndpointIdentity"/>, registers a server
        /// certificate validation callback on <paramref name="httpClientHandler"/> that only accepts a server
        /// certificate whose raw data equals that of the identity certificate and that has no SSL policy errors.
        /// Does nothing for any other identity.
        /// </summary>
        /// <param name="httpClientHandler">The handler to attach the validation callback to.</param>
        /// <param name="to">The remote endpoint address whose identity is inspected.</param>
        public static void AddServerCertIdentityValidation(HttpClientHandler httpClientHandler, EndpointAddress to)
        {
            X509CertificateEndpointIdentity remoteCertificateIdentity = to.Identity as X509CertificateEndpointIdentity;
            if (remoteCertificateIdentity == null)
            {
                return;
            }

            // The following condition should have been validated when the channel was created.
            Fx.Assert(remoteCertificateIdentity.Certificates.Count <= 1,
                "HTTPS server certificate identity contains multiple certificates");

            // Capture the expected values once so the callback does not re-read the identity on every request.
            X509Certificate2 expectedCertificate = remoteCertificateIdentity.Certificates[0];
            byte[] rawData = expectedCertificate.GetRawCertData();
            string thumbprint = expectedCertificate.Thumbprint;

            bool identityValidator(HttpRequestMessage requestMessage, X509Certificate2 cert, X509Chain chain, SslPolicyErrors policyErrors)
            {
                // The callback's certificate argument is nullable. With no certificate there is nothing to match
                // against the expected identity, so reject instead of failing with a NullReferenceException.
                if (cert == null)
                {
                    return false;
                }

                try
                {
                    ValidateServerCertificate(cert, rawData, thumbprint);
                }
                catch (SecurityNegotiationException e)
                {
                    DiagnosticUtility.TraceHandledException(e, TraceEventType.Information);
                    return false;
                }

                // A matching certificate is still rejected if the platform reported any policy error.
                return policyErrors == SslPolicyErrors.None;
            }

            SetServerCertificateValidationCallback(httpClientHandler, identityValidator);
        }

        /// <summary>
        /// Installs <paramref name="validator"/> on <paramref name="handler"/>. If the handler already has a
        /// validation callback, the two are combined so that a certificate must be accepted by both.
        /// </summary>
        /// <param name="handler">The handler whose validation callback is set.</param>
        /// <param name="validator">The validator to add.</param>
        public static void SetServerCertificateValidationCallback(HttpClientHandler handler, Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> validator)
        {
            handler.ServerCertificateCustomValidationCallback =
                ChainValidator(handler.ServerCertificateCustomValidationCallback, validator);
        }

        /// <summary>
        /// Combines two validators. The returned callback runs <paramref name="validator"/> first and only
        /// consults <paramref name="previousValidator"/> when the first one accepts the certificate.
        /// </summary>
        /// <param name="previousValidator">The validator that was already registered, or null if there was none.</param>
        /// <param name="validator">The validator being added.</param>
        /// <returns>
        /// <paramref name="validator"/> itself when there is no previous validator; otherwise the combined callback.
        /// </returns>
        private static Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> ChainValidator(
            Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> previousValidator,
            Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> validator)
        {
            if (previousValidator == null)
            {
                return validator;
            }

            bool chained(HttpRequestMessage request, X509Certificate2 certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors)
            {
                // Short-circuit: the previous validator is not invoked once the new one has rejected the certificate.
                return validator(request, certificate, chain, sslPolicyErrors)
                    && previousValidator(request, certificate, chain, sslPolicyErrors);
            }

            return chained;
        }

        /// <summary>
        /// Verifies that the raw data of <paramref name="certificate"/> is byte-for-byte equal to
        /// <paramref name="rawData"/>.
        /// </summary>
        /// <param name="certificate">The certificate presented by the server.</param>
        /// <param name="rawData">The raw data of the expected certificate.</param>
        /// <param name="thumbprint">The thumbprint of the expected certificate; used only in the error message.</param>
        /// <exception cref="SecurityNegotiationException">The raw data differs in length or content.</exception>
        private static void ValidateServerCertificate(X509Certificate2 certificate, byte[] rawData, string thumbprint)
        {
            byte[] certRawData = certificate.GetRawCertData();

            // SequenceEqual compares both length and content, replacing the hand-written vectorized loop.
            if (!rawData.AsSpan().SequenceEqual(certRawData))
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                    new SecurityNegotiationException(SR.Format(SR.HttpsServerCertThumbprintMismatch,
                    certificate.Subject, certificate.Thumbprint, thumbprint)));
            }
        }
    }
}