// File: System\Net\ServiceNameStore.cs
// Web Access
// Project: src\src\libraries\System.Net.HttpListener\src\System.Net.HttpListener.csproj (System.Net.HttpListener)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.Security.Authentication.ExtendedProtection;
 
namespace System.Net
{
    internal sealed class ServiceNameStore
    {
        // Backing list of service names; the only collection this class mutates.
        private readonly List<string> _serviceNames;

        // Cached read-only view over _serviceNames. Null means "not built yet" or "invalidated".
        private ServiceNameCollection? _serviceNameCollection;

        // Returns the cached collection, building it from _serviceNames on first access
        // (or on the first access after the cache was reset to null).
        public ServiceNameCollection ServiceNames
        {
            get
            {
                if (_serviceNameCollection is null)
                {
                    _serviceNameCollection = new ServiceNameCollection(_serviceNames);
                }

                return _serviceNameCollection;
            }
        }
 
        // Creates an empty store.
        public ServiceNameStore()
        {
            // The read-only collection is created lazily because building it copies item by item.
            _serviceNameCollection = null;
            _serviceNames = new List<string>();
        }
 
        // Normalizes the host part of a service name shaped like "prefix/host[:port][/distinguisher]".
        //
        // Only hosts that Uri.CheckHostName classifies as DNS names are normalized. The input is
        // returned unchanged when it is null/whitespace, has no '/', has nothing after the first '/',
        // has a non-numeric port, has a non-DNS host, or when normalization changed nothing but casing.
        private static string? NormalizeServiceName(string? inputServiceName)
        {
            if (string.IsNullOrWhiteSpace(inputServiceName))
            {
                return inputServiceName;
            }

            // Everything up to and including the first slash is the service class prefix.
            int slashIndex = inputServiceName.IndexOf('/');
            if (slashIndex < 0)
            {
                return inputServiceName;
            }

            string remainder = inputServiceName.Substring(slashIndex + 1); // host[:port][/distinguisher]
            if (string.IsNullOrWhiteSpace(remainder))
            {
                return inputServiceName;
            }

            ReadOnlySpan<char> servicePrefix = inputServiceName.AsSpan(0, slashIndex + 1); // Keeps the slash
            ReadOnlySpan<char> portSuffix = default;          // ":port" when present
            ReadOnlySpan<char> distinguisherSuffix = default; // "/distinguisher" when present
            string host = remainder;

            // If the whole remainder is already a recognizable host there is no port or distinguisher.
            UriHostNameType hostType = Uri.CheckHostName(remainder);
            if (hostType == UriHostNameType.Unknown)
            {
                string hostAndPort = remainder;

                // Peel off the distinguisher (kept verbatim, leading slash included; it is not validated).
                int distinguisherStart = remainder.IndexOf('/');
                if (distinguisherStart >= 0)
                {
                    hostAndPort = remainder.Substring(0, distinguisherStart);
                    distinguisherSuffix = remainder.AsSpan(distinguisherStart);
                }

                // Until a port is found, the host is everything before the distinguisher.
                host = hostAndPort;

                // Peel off the port. The last colon is used so IPv6 literals are not split in the middle.
                int portColon = hostAndPort.LastIndexOf(':');
                if (portColon >= 0)
                {
                    // Loose validation only: make sure what follows the colon really is a port number.
                    ReadOnlySpan<char> portDigits = hostAndPort.AsSpan(portColon + 1);
                    if (!ushort.TryParse(portDigits, NumberStyles.Integer, CultureInfo.InvariantCulture, out _))
                    {
                        return inputServiceName;
                    }

                    host = hostAndPort.Substring(0, portColon);
                    portSuffix = hostAndPort.AsSpan(portColon); // Colon included; port text is left as written
                }

                // Classify again now that port and distinguisher are out of the way.
                hostType = Uri.CheckHostName(host);
            }

            if (hostType != UriHostNameType.Dns)
            {
                // IPv4 / IPv6: IP hosts are deliberately not normalized.
                // Basic: never returned by CheckHostName today.
                // Unknown: nothing recognizable to normalize.
                // Any other value: a UriHostNameType added later that this code does not know about.
                return inputServiceName;
            }

            // Valid DNS host: let Uri do the normalization. This is not expected to fail, but this
            // code path must not surface unexpected exceptions, hence TryCreate.
            if (!Uri.TryCreate(Uri.UriSchemeHttp + Uri.SchemeDelimiter + host, UriKind.Absolute, out Uri? hostUri))
            {
                return inputServiceName;
            }

            string normalizedHost = hostUri.GetComponents(UriComponents.NormalizedHost, UriFormat.SafeUnescaped);
            string candidate = string.Concat(servicePrefix, normalizedHost, portSuffix, distinguisherSuffix);

            // Prefer the caller's original string when the only difference is casing.
            return inputServiceName.Equals(candidate, StringComparison.OrdinalIgnoreCase)
                ? inputServiceName
                : candidate;
        }
 
        // Normalizes spn and appends it to the list unless Contains already reports a match.
        // Returns true only when a new entry was appended.
        private bool AddSingleServiceName(string spn)
        {
            string normalized = NormalizeServiceName(spn)!;

            bool isNew = !Contains(normalized);
            if (isNew)
            {
                _serviceNames.Add(normalized);
            }

            return isNew;
        }
 
        // Adds every service name that BuildServiceNames derives from uriPrefix.
        // Returns true when at least one of them was not already stored; in that case the cached
        // ServiceNameCollection is invalidated.
        public bool Add(string uriPrefix)
        {
            Debug.Assert(!string.IsNullOrEmpty(uriPrefix));

            bool addedAny = false;
            foreach (string candidate in BuildServiceNames(uriPrefix))
            {
                if (!AddSingleServiceName(candidate))
                {
                    continue; // Already present
                }

                addedAny = true;

                if (NetEventSource.Log.IsEnabled())
                {
                    NetEventSource.Info(this, SR.Format(SR.net_log_listener_spn_add, candidate, uriPrefix));
                }
            }

            if (!addedAny)
            {
                if (NetEventSource.Log.IsEnabled())
                {
                    NetEventSource.Info(this, SR.Format(SR.net_log_listener_spn_not_add, uriPrefix));
                }

                return false;
            }

            // The list changed, so the read-only snapshot is stale.
            _serviceNameCollection = null;
            return true;
        }
 
        // Removes the simple service name ("HTTP/" + host) derived from uriPrefix.
        // Returns true when a matching entry was found and removed; false when uriPrefix does not
        // yield a service name or no stored entry matches it.
        public bool Remove(string uriPrefix)
        {
            Debug.Assert(!string.IsNullOrEmpty(uriPrefix));

            string? newServiceName = BuildSimpleServiceName(uriPrefix);
            newServiceName = NormalizeServiceName(newServiceName);

            // Look the entry up with the same case-insensitive comparison Contains uses when adding.
            // List<T>.Remove(item) compares ordinally, so it would leave behind an entry that differs
            // only in casing while this method still reported (and logged) a successful removal.
            int index = newServiceName is null
                ? -1
                : _serviceNames.FindIndex(name => name.Equals(newServiceName, StringComparison.OrdinalIgnoreCase));
            bool needToRemove = index >= 0;

            if (needToRemove)
            {
                _serviceNames.RemoveAt(index);
                _serviceNameCollection = null; //invalidate (readonly) ServiceNameCollection
            }

            if (NetEventSource.Log.IsEnabled())
            {
                if (needToRemove)
                {
                    NetEventSource.Info(this, SR.Format(SR.net_log_listener_spn_remove, newServiceName, uriPrefix));
                }
                else
                {
                    NetEventSource.Info(this, SR.Format(SR.net_log_listener_spn_not_remove, uriPrefix));
                }
            }

            return needToRemove;
        }
 
        // Reports whether the list holds newServiceName, compared with OrdinalIgnoreCase.
        // The caller is expected to pass an already-normalized name; null never matches.
        private bool Contains(string? newServiceName)
        {
            if (newServiceName != null)
            {
                for (int i = 0; i < _serviceNames.Count; i++)
                {
                    if (_serviceNames[i].Equals(newServiceName, StringComparison.OrdinalIgnoreCase))
                    {
                        return true;
                    }
                }
            }

            return false;
        }
 
        // Removes every stored service name and drops the cached read-only collection.
        public void Clear()
        {
            _serviceNameCollection = null; // Invalidate the (read-only) ServiceNameCollection
            _serviceNames.Clear();
        }
 
        // Extracts the host portion of uriPrefix.
        //
        // Well-formed absolute URIs are parsed with Uri and its Host is returned. Otherwise, when
        // allowInvalidUriStrings is true, the host is scanned manually: it starts after "://" (or at
        // the beginning of the string when there is no "://") and ends at the first '/' or at the
        // first ':' outside square brackets. A '[' nested inside another '[' yields an empty host.
        // Returns null when the string is not well-formed and allowInvalidUriStrings is false.
        private static string? ExtractHostname(string uriPrefix, bool allowInvalidUriStrings)
        {
            if (Uri.IsWellFormedUriString(uriPrefix, UriKind.Absolute))
            {
                return new Uri(uriPrefix).Host;
            }

            if (!allowInvalidUriStrings)
            {
                return null;
            }

            const string SchemeDelimiter = "://";

            // IndexOf returns -1 when the delimiter is missing. Blindly adding the delimiter length
            // would start the scan at index 2, chopping the first characters off the host or
            // throwing for very short strings, so fall back to the start of the string instead.
            int delimiterIndex = uriPrefix.IndexOf(SchemeDelimiter, StringComparison.Ordinal);
            int start = delimiterIndex < 0 ? 0 : delimiterIndex + SchemeDelimiter.Length;
            int end = start;

            bool inSquareBrackets = false;
            while (end < uriPrefix.Length)
            {
                char current = uriPrefix[end];

                // The host ends at the path, or at the port separator unless inside an IPv6 literal.
                if (current == '/' || (current == ':' && !inSquareBrackets))
                {
                    break;
                }

                if (current == '[')
                {
                    if (inSquareBrackets)
                    {
                        // Nested brackets: give up and report an empty host.
                        end = start;
                        break;
                    }

                    inSquareBrackets = true;
                }
                else if (current == ']')
                {
                    inSquareBrackets = false;
                }

                end++;
            }

            return uriPrefix.Substring(start, end - start);
        }
 
        // Builds "HTTP/<host>" from uriPrefix, or returns null when ExtractHostname (called with
        // invalid URI strings disallowed) cannot produce a host.
        public static string? BuildSimpleServiceName(string uriPrefix)
        {
            string? host = ExtractHostname(uriPrefix, false);
            return host is null ? null : "HTTP/" + host;
        }
 
        // Builds the service names to register for uriPrefix:
        //  - wildcard ("*", "+") or IP-literal host: "HTTP/<machine name>", or nothing if the DNS query fails;
        //  - dotless host: "HTTP/<host>" plus "HTTP/<fqdn>", or only the former if the DNS query fails;
        //  - any other host: "HTTP/<host>".
        public static string[] BuildServiceNames(string uriPrefix)
        {
            string hostname = ExtractHostname(uriPrefix, true)!;

            bool useMachineName =
                hostname == "*" ||
                hostname == "+" ||
                IPAddress.TryParse(hostname, out _);

            if (useMachineName)
            {
                // For a wildcard, register the machine name. If the caller doesn't have DNS permission
                // or the query fails for some reason, don't add an SPN.
                try
                {
                    string machineName = Dns.GetHostEntry(string.Empty).HostName;
                    return new string[] { "HTTP/" + machineName };
                }
                catch (Exception e) when (e is System.Net.Sockets.SocketException or System.Security.SecurityException)
                {
                    return Array.Empty<string>();
                }
            }

            if (hostname.Contains('.'))
            {
                // Already dotted: use the name as given.
                return new string[] { "HTTP/" + hostname };
            }

            // For a dotless name, try to resolve the FQDN. If the caller doesn't have DNS permission
            // or the query fails for some reason, add only the dotless name.
            try
            {
                string fqdn = Dns.GetHostEntry(hostname).HostName;
                return new string[] { "HTTP/" + hostname, "HTTP/" + fqdn };
            }
            catch (Exception e) when (e is System.Net.Sockets.SocketException or System.Security.SecurityException)
            {
                return new string[] { "HTTP/" + hostname };
            }
        }
    }
}