|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
namespace System.Net
{
// More sophisticated password cache that stores multiple
// name-password pairs and associates these with host/realm.
public class CredentialCache : ICredentials, ICredentialsByHost, IEnumerable
{
// Credentials registered by URI prefix + authentication type. Left null until the
// first Add(Uri, ...) call allocates it.
private Dictionary<CredentialCacheKey, NetworkCredential>? _cache;
// Credentials registered by host + port + authentication type. Left null until the
// first Add(string, int, ...) call allocates it.
private Dictionary<CredentialHostKey, NetworkCredential>? _cacheForHosts;
// Modification counter bumped by the Add and Remove overloads. Enumerators capture it
// at construction and compare against it to detect changes made while enumerating.
private int _version;
// Creates an empty cache. Neither lookup table is allocated here; each one is
// created by the first Add call that needs it.
public CredentialCache() { }
// Registers a credential for a URI prefix and authentication type.
//
//   uriPrefix - URI prefix the credential applies to; must not be null.
//   authType  - authentication scheme name; must not be null.
//   cred      - credential to store. The process-wide default credential
//               (SystemNetworkCredential) is accepted only for the NTLM, Kerberos
//               and Negotiate schemes.
//
// Throws ArgumentNullException when uriPrefix or authType is null, and
// ArgumentException when the default credential is paired with any other scheme.
// An exception raised by Dictionary.Add (for example for a key that is already
// present) propagates to the caller and leaves the cache and its version untouched.
public void Add(Uri uriPrefix, string authType, NetworkCredential cred)
{
    ArgumentNullException.ThrowIfNull(uriPrefix);
    ArgumentNullException.ThrowIfNull(authType);

    if ((cred is SystemNetworkCredential)
        && !((string.Equals(authType, NegotiationInfoClass.NTLM, StringComparison.OrdinalIgnoreCase))
            || (string.Equals(authType, NegotiationInfoClass.Kerberos, StringComparison.OrdinalIgnoreCase))
            || (string.Equals(authType, NegotiationInfoClass.Negotiate, StringComparison.OrdinalIgnoreCase))))
    {
        throw new ArgumentException(SR.Format(SR.net_nodefaultcreds, authType), nameof(authType));
    }

    var key = new CredentialCacheKey(uriPrefix, authType);

    if (NetEventSource.Log.IsEnabled())
    {
        // cred is never validated above, so trace through a nullable alias with
        // null-conditional access: switching tracing on must not turn a call that
        // otherwise succeeds into a NullReferenceException.
        NetworkCredential? traced = cred;
        NetEventSource.Info(this, $"Adding key:[{key}], cred:[{traced?.Domain}],[{traced?.UserName}]");
    }

    _cache ??= new Dictionary<CredentialCacheKey, NetworkCredential>();
    _cache.Add(key, cred);

    // Bump the version only after the entry is actually stored, so that a rejected
    // Add does not invalidate enumerators over a cache that did not change.
    ++_version;
}
// Registers a credential for a host, port and authentication type.
//
//   host               - host name the credential applies to; must be non-null and non-empty.
//   port               - port on that host; must not be negative.
//   authenticationType - authentication scheme name; must not be null.
//   credential         - credential to store. The process-wide default credential
//                        (SystemNetworkCredential) is accepted only for the NTLM,
//                        Kerberos and Negotiate schemes.
//
// Throws ArgumentNullException when host or authenticationType is null,
// ArgumentException when host is empty or when the default credential is paired with
// any other scheme, and ArgumentOutOfRangeException when port is negative.
// An exception raised by Dictionary.Add (for example for a key that is already
// present) propagates to the caller and leaves the cache and its version untouched.
public void Add(string host, int port, string authenticationType, NetworkCredential credential)
{
    ArgumentException.ThrowIfNullOrEmpty(host);
    ArgumentNullException.ThrowIfNull(authenticationType);
    ArgumentOutOfRangeException.ThrowIfNegative(port);

    if ((credential is SystemNetworkCredential)
        && !((string.Equals(authenticationType, NegotiationInfoClass.NTLM, StringComparison.OrdinalIgnoreCase))
            || (string.Equals(authenticationType, NegotiationInfoClass.Kerberos, StringComparison.OrdinalIgnoreCase))
            || (string.Equals(authenticationType, NegotiationInfoClass.Negotiate, StringComparison.OrdinalIgnoreCase))))
    {
        throw new ArgumentException(SR.Format(SR.net_nodefaultcreds, authenticationType), nameof(authenticationType));
    }

    var key = new CredentialHostKey(host, port, authenticationType);

    if (NetEventSource.Log.IsEnabled())
    {
        // credential is never validated above, so trace through a nullable alias with
        // null-conditional access: switching tracing on must not turn a call that
        // otherwise succeeds into a NullReferenceException.
        NetworkCredential? traced = credential;
        NetEventSource.Info(this, $"Adding key:[{key}], cred:[{traced?.Domain}],[{traced?.UserName}]");
    }

    _cacheForHosts ??= new Dictionary<CredentialHostKey, NetworkCredential>();
    _cacheForHosts.Add(key, credential);

    // Bump the version only after the entry is actually stored, so that a rejected
    // Add does not invalidate enumerators over a cache that did not change.
    ++_version;
}
// Removes the credential stored for the given URI prefix and authentication type.
// Null arguments and a never-populated URI table are ignored without error; removing
// a key that is not present is likewise not an error.
public void Remove(Uri? uriPrefix, string? authType)
{
    // Add() rejects null arguments, so nothing can have been stored under them.
    if (uriPrefix is null || authType is null)
    {
        return;
    }

    Dictionary<CredentialCacheKey, NetworkCredential>? table = _cache;
    if (table is null)
    {
        if (NetEventSource.Log.IsEnabled())
        {
            NetEventSource.Info(this, "Short-circuiting because the dictionary is null.");
        }

        return;
    }

    // The version is bumped whether or not the key turns out to be present.
    _version++;

    CredentialCacheKey target = new CredentialCacheKey(uriPrefix, authType);
    if (NetEventSource.Log.IsEnabled())
    {
        NetEventSource.Info(this, $"Removing key:[{target}]");
    }

    table.Remove(target);
}
// Removes the credential stored for the given host, port and authentication type.
//
// Arguments that Add() would have rejected (null or empty host, null authentication
// type, negative port) cannot correspond to a stored entry, so they are ignored
// without error. A never-populated host table is ignored as well, and removing a key
// that is not present is not an error.
public void Remove(string? host, int port, string? authenticationType)
{
    // These couldn't possibly have been inserted into the cache because of the tests
    // in Add(). The empty-host case is checked here too: CredentialHostKey asserts
    // that its host is non-empty, so such a value must never reach its constructor.
    if (string.IsNullOrEmpty(host) || authenticationType == null)
    {
        return;
    }

    if (port < 0)
    {
        return;
    }

    if (_cacheForHosts == null)
    {
        if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, "Short-circuiting because the dictionary is null.");
        return;
    }

    ++_version;

    var key = new CredentialHostKey(host, port, authenticationType);
    if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"Removing key:[{key}]");
    _cacheForHosts.Remove(key);
}
// Looks up the credential registered for a URI and authentication type.
//
// The search itself is delegated to CredentialCacheHelper.TryGetCredential; whatever
// credential it reports (possibly null) is returned. Returns null straight away when
// no URI-keyed credential has ever been added.
//
// Throws ArgumentNullException when uriPrefix or authType is null.
public NetworkCredential? GetCredential(Uri uriPrefix, string authType)
{
    ArgumentNullException.ThrowIfNull(uriPrefix);
    ArgumentNullException.ThrowIfNull(authType);

    Dictionary<CredentialCacheKey, NetworkCredential>? table = _cache;
    if (table is null)
    {
        if (NetEventSource.Log.IsEnabled())
        {
            NetEventSource.Info(this, "CredentialCache::GetCredential short-circuiting because the dictionary is null.");
        }

        return null;
    }

    // The matched URI is not needed here, only the credential.
    CredentialCacheHelper.TryGetCredential(table, uriPrefix, authType, out _, out NetworkCredential? found);

    if (NetEventSource.Log.IsEnabled())
    {
        string description = found == null
            ? "null"
            : "(" + found.UserName + ":" + found.Domain + ")";
        NetEventSource.Info(this, $"Returning {description}");
    }

    return found;
}
// Looks up the credential registered for exactly this host, port and authentication
// type. Returns null when there is no such entry or when no host-keyed credential has
// ever been added.
//
// Throws ArgumentNullException when host or authenticationType is null,
// ArgumentException when host is empty, and ArgumentOutOfRangeException when port is
// negative.
public NetworkCredential? GetCredential(string host, int port, string authenticationType)
{
    ArgumentException.ThrowIfNullOrEmpty(host);
    ArgumentNullException.ThrowIfNull(authenticationType);
    ArgumentOutOfRangeException.ThrowIfNegative(port);

    Dictionary<CredentialHostKey, NetworkCredential>? table = _cacheForHosts;
    if (table is null)
    {
        if (NetEventSource.Log.IsEnabled())
        {
            NetEventSource.Info(this, "CredentialCache::GetCredential short-circuiting because the dictionary is null.");
        }

        return null;
    }

    CredentialHostKey wanted = new CredentialHostKey(host, port, authenticationType);
    table.TryGetValue(wanted, out NetworkCredential? found);

    if (NetEventSource.Log.IsEnabled())
    {
        string description = found == null
            ? "null"
            : "(" + found.UserName + ":" + found.Domain + ")";
        NetEventSource.Info(this, $"Returning {description}");
    }

    return found;
}
// Returns an enumerator over the cached credentials. The concrete enumerator is
// chosen by CredentialEnumerator.Create according to which tables exist.
public IEnumerator GetEnumerator()
{
    return CredentialEnumerator.Create(this);
}
// The shared SystemNetworkCredential singleton, exposed as ICredentials.
public static ICredentials DefaultCredentials
{
    get { return SystemNetworkCredential.s_defaultCredential; }
}
// The same SystemNetworkCredential singleton, exposed as NetworkCredential.
public static NetworkCredential DefaultNetworkCredentials
{
    get { return SystemNetworkCredential.s_defaultCredential; }
}
private class CredentialEnumerator : IEnumerator
{
// Chooses the cheapest enumerator for the tables the cache currently has:
//   both tables  -> DoubleTableCredentialEnumerator (URI table first, then host table)
//   one table    -> SingleTableCredentialEnumerator over that table
//   no tables    -> the base enumerator, which never yields anything
internal static CredentialEnumerator Create(CredentialCache cache)
{
    Debug.Assert(cache != null);

    Dictionary<CredentialCacheKey, NetworkCredential>? byUri = cache._cache;
    Dictionary<CredentialHostKey, NetworkCredential>? byHost = cache._cacheForHosts;

    if (byUri != null && byHost != null)
    {
        return new DoubleTableCredentialEnumerator(cache);
    }

    if (byUri != null)
    {
        return new SingleTableCredentialEnumerator<CredentialCacheKey>(cache, byUri);
    }

    if (byHost != null)
    {
        return new SingleTableCredentialEnumerator<CredentialHostKey>(cache, byHost);
    }

    return new CredentialEnumerator(cache);
}
// Cache being enumerated; its _version is re-read on every Current/MoveNext call.
private readonly CredentialCache _cache;
// Value of the cache's _version captured when this enumerator was constructed.
private readonly int _version;
// Result of the most recent MoveNext; cleared by Reset. Current throws while false.
private bool _enumerating;
// Credential produced by the most recent MoveNext.
private NetworkCredential? _current;
// Binds the enumerator to a cache and snapshots the cache's current version so that
// later modifications can be detected. Private: instances are obtained via Create.
private CredentialEnumerator(CredentialCache cache)
{
    Debug.Assert(cache != null);

    _version = cache._version;
    _cache = cache;
}
// The credential produced by the last successful MoveNext.
//
// Throws InvalidOperationException when the enumerator is not positioned on an
// element (before the first MoveNext, after the end, or after Reset), or when the
// cache has been modified since the enumerator was created. The positioning check
// takes precedence over the version check.
public object Current
{
    get
    {
        if (_enumerating)
        {
            if (_version == _cache._version)
            {
                return _current!;
            }

            throw new InvalidOperationException(SR.InvalidOperation_EnumFailedVersion);
        }

        throw new InvalidOperationException(SR.InvalidOperation_EnumOpCantHappen);
    }
}
// Advances to the next credential. Returns false once the enumeration is exhausted.
// Throws InvalidOperationException when the cache has been modified since this
// enumerator was created.
public bool MoveNext()
{
    if (_cache._version != _version)
    {
        throw new InvalidOperationException(SR.InvalidOperation_EnumFailedVersion);
    }

    // The table-specific overload does the real work; remember whether it produced
    // an element so that Current knows if it may be read.
    bool advanced = MoveNext(out _current);
    _enumerating = advanced;
    return advanced;
}
// Default step for a cache with no tables: there is never a next element.
// Derived enumerators override this to walk their dictionaries.
protected virtual bool MoveNext(out NetworkCredential? current)
{
    current = default;
    return false;
}
// Marks the enumerator as not positioned on any element; Current throws until the
// next successful MoveNext.
public virtual void Reset() => _enumerating = false;
// Enumerates the credentials held in one dictionary of the cache.
private class SingleTableCredentialEnumerator<TKey> : CredentialEnumerator where TKey : notnull
{
    // Mutable struct enumerator: must not be readonly, or every call would operate on a copy.
    private Dictionary<TKey, NetworkCredential>.ValueCollection.Enumerator _valuesEnumerator;

    public SingleTableCredentialEnumerator(CredentialCache cache, Dictionary<TKey, NetworkCredential> table) : base(cache)
    {
        Debug.Assert(table != null);

        // Despite the ValueCollection allocation, ValueCollection's enumerator is faster
        // than Dictionary's enumerator for enumerating the values because it avoids
        // KeyValuePair copying.
        Dictionary<TKey, NetworkCredential>.ValueCollection values = table.Values;
        _valuesEnumerator = values.GetEnumerator();
    }

    // Steps the underlying value enumerator and hands back whatever it is positioned on.
    protected override bool MoveNext(out NetworkCredential current)
    {
        return DictionaryEnumeratorHelper.MoveNext(ref _valuesEnumerator, out current);
    }

    // Rewinds the underlying value enumerator, then clears the base positioning state.
    public override void Reset()
    {
        DictionaryEnumeratorHelper.Reset(ref _valuesEnumerator);
        base.Reset();
    }
}
// Enumerates both dictionaries of the cache: the URI-keyed table first (through the
// base class), then the host-keyed table (through this class's own enumerator).
private sealed class DoubleTableCredentialEnumerator : SingleTableCredentialEnumerator<CredentialCacheKey>
{
    // Mutable struct enumerator: must not be readonly, or every call would operate on a copy.
    private Dictionary<CredentialHostKey, NetworkCredential>.ValueCollection.Enumerator _hostValuesEnumerator;

    // False while the base (URI) table is still being walked; true once it has been
    // exhausted and enumeration has moved on to the host table.
    private bool _uriTableExhausted;

    public DoubleTableCredentialEnumerator(CredentialCache cache) : base(cache, cache._cache!)
    {
        Dictionary<CredentialHostKey, NetworkCredential>? hostTable = cache._cacheForHosts;
        Debug.Assert(hostTable != null);

        // Despite the ValueCollection allocation, ValueCollection's enumerator is faster
        // than Dictionary's enumerator for enumerating the values because it avoids
        // KeyValuePair copying.
        _hostValuesEnumerator = hostTable.Values.GetEnumerator();
    }

    protected override bool MoveNext(out NetworkCredential current)
    {
        // Keep draining the URI table until it reports the end.
        if (!_uriTableExhausted && base.MoveNext(out current))
        {
            return true;
        }

        // Either the URI table just ran out or it already had; continue with the hosts.
        _uriTableExhausted = true;
        return DictionaryEnumeratorHelper.MoveNext(ref _hostValuesEnumerator, out current);
    }

    // Rewinds both tables and starts over from the URI table.
    public override void Reset()
    {
        _uriTableExhausted = false;
        DictionaryEnumeratorHelper.Reset(ref _hostValuesEnumerator);
        base.Reset();
    }
}
// Helpers for driving Dictionary's struct enumerators by reference, so the caller's
// stored enumerator is the one that advances.
private static class DictionaryEnumeratorHelper
{
    // Advances the enumerator and reports its Current value. Current is read even
    // when MoveNext returned false, so the out parameter is always assigned.
    internal static bool MoveNext<TKey, TValue>(ref Dictionary<TKey, TValue>.ValueCollection.Enumerator enumerator, out TValue current) where TKey : notnull
    {
        bool advanced = enumerator.MoveNext();
        current = enumerator.Current;
        return advanced;
    }

    // Allows calling Reset on Dictionary's struct enumerator without a box allocation.
    internal static void Reset<TEnumerator>(ref TEnumerator enumerator) where TEnumerator : IEnumerator
    {
        try
        {
            enumerator.Reset();
        }
        catch (InvalidOperationException)
        {
            // The Dictionary enumerator's Reset method throws if the Dictionary has
            // changed, but CredentialCache.Reset should not throw, so the exception
            // is deliberately swallowed.
        }
    }
}
}
}
// Abstraction for credentials in password-based
// authentication schemes (basic, digest, NTLM, Kerberos).
//
// Note that this is not applicable to public-key based
// systems such as SSL client authentication.
//
// "Password" here may be the clear text password or it
// could be a one-way hash that is sufficient to
// authenticate, as in HTTP/1.1 digest.
internal sealed class SystemNetworkCredential : NetworkCredential
{
    // The one and only instance; callers recognise the default credential by
    // checking for this type or this reference.
    internal static readonly SystemNetworkCredential s_defaultCredential = new();

    // We want reference equality to work. Keeping the constructor private guarantees
    // that no second instance can be created outside this class.
    private SystemNetworkCredential()
        : base(userName: string.Empty, password: string.Empty, domain: string.Empty)
    {
    }
}
// Dictionary key for host-based credential lookup: host + port + authentication type.
// Host and authentication type compare ordinal-ignore-case; the port compares exactly.
internal readonly struct CredentialHostKey : IEquatable<CredentialHostKey>
{
    public readonly string Host;
    public readonly string AuthenticationType;
    public readonly int Port;

    // Callers are expected to have validated the arguments already; the assertions
    // only document that contract.
    internal CredentialHostKey(string host, int port, string authenticationType)
    {
        Debug.Assert(!string.IsNullOrEmpty(host));
        Debug.Assert(port >= 0);
        Debug.Assert(authenticationType != null);

        Host = host;
        Port = port;
        AuthenticationType = authenticationType;
    }

    // XOR of the case-insensitive string hashes and the port hash, in line with Equals.
    public override int GetHashCode()
    {
        StringComparer ignoreCase = StringComparer.OrdinalIgnoreCase;

        int hash = ignoreCase.GetHashCode(AuthenticationType);
        hash ^= ignoreCase.GetHashCode(Host);
        hash ^= Port.GetHashCode();
        return hash;
    }

    // Two keys are equal when authentication type and host match ignoring case and
    // the ports are identical. The outcome is traced when event logging is enabled.
    public bool Equals(CredentialHostKey other)
    {
        bool equals;
        if (!string.Equals(AuthenticationType, other.AuthenticationType, StringComparison.OrdinalIgnoreCase))
        {
            equals = false;
        }
        else if (!string.Equals(Host, other.Host, StringComparison.OrdinalIgnoreCase))
        {
            equals = false;
        }
        else
        {
            equals = Port == other.Port;
        }

        if (NetEventSource.Log.IsEnabled())
        {
            NetEventSource.Info(this, $"Equals({this},{other}) returns {equals}");
        }

        return equals;
    }

    public override bool Equals([NotNullWhen(true)] object? obj)
    {
        return obj is CredentialHostKey other && Equals(other);
    }

    // Culture-invariant "host:port:authenticationType" rendering.
    public override string ToString()
    {
        return string.Create(CultureInfo.InvariantCulture, $"{Host}:{Port}:{AuthenticationType}");
    }
}
}
|