|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Extensions.Caching.Hybrid.Internal;
/// <summary>
/// The inbuilt implementation of <see cref="HybridCache"/>, as registered via <see cref="HybridCacheServiceExtensions.AddHybridCache(IServiceCollection)"/>.
/// </summary>
[SkipLocalsInit]
internal sealed partial class DefaultHybridCache : HybridCache
{
internal const int DefaultExpirationMinutes = 5;
// reserve non-printable characters from keys, to prevent potential L2 abuse
private static readonly char[] _keyReservedCharacters = Enumerable.Range(0, 32).Select(i => (char)i).ToArray();
[System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0032:Use auto property", Justification = "Keep usage explicit")]
private readonly IDistributedCache? _backendCache;
[System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0032:Use auto property", Justification = "Keep usage explicit")]
private readonly IMemoryCache _localCache;
private readonly IServiceProvider _services; // we can't resolve per-type serializers until we see each T
private readonly IHybridCacheSerializerFactory[] _serializerFactories;
[System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0032:Use auto property", Justification = "Keep usage explicit")]
private readonly HybridCacheOptions _options;
private readonly ILogger _logger;
private readonly CacheFeatures _features; // used to avoid constant type-testing
private readonly TimeProvider _clock;
private readonly HybridCacheEntryFlags _hardFlags; // *always* present (for example, because no L2)
private readonly HybridCacheEntryFlags _defaultFlags; // note this already includes hardFlags
private readonly TimeSpan _defaultExpiration;
private readonly TimeSpan _defaultLocalCacheExpiration;
private readonly int _maximumKeyLength;
private readonly DistributedCacheEntryOptions _defaultDistributedCacheExpiration;
[Flags]
internal enum CacheFeatures
{
None = 0,
BackendCache = 1 << 0,
BackendBuffers = 1 << 1,
}
internal CacheFeatures GetFeatures() => _features;
// used to restrict features in test suite
internal void DebugRemoveFeatures(CacheFeatures features) => Unsafe.AsRef(in _features) &= ~features;
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private CacheFeatures GetFeatures(CacheFeatures mask) => _features & mask;
internal bool HasBackendCache => (_features & CacheFeatures.BackendCache) != 0;
public DefaultHybridCache(IOptions<HybridCacheOptions> options, IServiceProvider services)
{
_services = Throw.IfNull(services);
_localCache = services.GetRequiredService<IMemoryCache>();
_options = options.Value;
_logger = services.GetService<ILoggerFactory>()?.CreateLogger(typeof(HybridCache)) ?? NullLogger.Instance;
_clock = services.GetService<TimeProvider>() ?? TimeProvider.System;
_backendCache = services.GetService<IDistributedCache>(); // note optional
// ignore L2 if it is really just the same L1, wrapped
// (note not just an "is" test; if someone has a custom subclass, who knows what it does?)
if (_backendCache is not null
&& _backendCache.GetType() == typeof(MemoryDistributedCache)
&& _localCache.GetType() == typeof(MemoryCache))
{
_backendCache = null;
}
// perform type-tests on the backend once only
_features |= _backendCache switch
{
IBufferDistributedCache => CacheFeatures.BackendCache | CacheFeatures.BackendBuffers,
not null => CacheFeatures.BackendCache,
_ => CacheFeatures.None
};
// When resolving serializers via the factory API, we will want the *last* instance,
// i.e. "last added wins"; we can optimize by reversing the array ahead of time, and
// taking the first match
IHybridCacheSerializerFactory[] factories = services.GetServices<IHybridCacheSerializerFactory>().ToArray();
Array.Reverse(factories);
_serializerFactories = factories;
MaximumPayloadBytes = checked((int)_options.MaximumPayloadBytes); // for now hard-limit to 2GiB
_maximumKeyLength = _options.MaximumKeyLength;
HybridCacheEntryOptions? defaultEntryOptions = _options.DefaultEntryOptions;
if (_backendCache is null)
{
_hardFlags |= HybridCacheEntryFlags.DisableDistributedCache;
}
_defaultFlags = (defaultEntryOptions?.Flags ?? HybridCacheEntryFlags.None) | _hardFlags;
_defaultExpiration = defaultEntryOptions?.Expiration ?? TimeSpan.FromMinutes(DefaultExpirationMinutes);
_defaultLocalCacheExpiration = GetEffectiveLocalCacheExpiration(defaultEntryOptions) ?? _defaultExpiration;
_defaultDistributedCacheExpiration = new DistributedCacheEntryOptions { AbsoluteExpirationRelativeToNow = _defaultExpiration };
#if NET9_0_OR_GREATER
_tagInvalidationTimesUseAltLookup = _tagInvalidationTimes.TryGetAlternateLookup(out _tagInvalidationTimesBySpan);
#endif
// do this last
_globalInvalidateTimestamp = _backendCache is null ? _zeroTimestamp : SafeReadTagInvalidationAsync(TagSet.WildcardTag);
}
internal IDistributedCache? BackendCache => _backendCache;
internal IMemoryCache LocalCache => _localCache;
internal HybridCacheOptions Options => _options;
public override ValueTask<T> GetOrCreateAsync<TState, T>(string key, TState state, Func<TState, CancellationToken, ValueTask<T>> underlyingDataCallback,
HybridCacheEntryOptions? options = null, IEnumerable<string>? tags = null, CancellationToken cancellationToken = default)
{
bool canBeCanceled = cancellationToken.CanBeCanceled;
if (canBeCanceled)
{
cancellationToken.ThrowIfCancellationRequested();
}
HybridCacheEntryFlags flags = GetEffectiveFlags(options);
if (!ValidateKey(key))
{
// we can't use cache, but we can still provide the data
return RunWithoutCacheAsync(flags, state, underlyingDataCallback, cancellationToken);
}
bool eventSourceEnabled = HybridCacheEventSource.Log.IsEnabled();
if ((flags & HybridCacheEntryFlags.DisableLocalCacheRead) == 0)
{
if (TryGetExisting<T>(key, out CacheItem<T>? typed)
&& typed.TryGetValue(_logger, out T? value))
{
// short-circuit
if (eventSourceEnabled)
{
HybridCacheEventSource.Log.LocalCacheHit();
}
return new(value);
}
else
{
if (eventSourceEnabled)
{
HybridCacheEventSource.Log.LocalCacheMiss();
}
}
}
if (GetOrCreateStampedeState<TState, T>(key, flags, out StampedeState<TState, T>? stampede, canBeCanceled, tags))
{
// new query; we're responsible for making it happen
if (canBeCanceled)
{
// *we* might cancel, but someone else might be depending on the result; start the
// work independently, then we'll with join the outcome
stampede.QueueUserWorkItem(in state, underlyingDataCallback, options);
}
else
{
// we're going to run to completion; no need to get complicated
_ = stampede.ExecuteDirectAsync(in state, underlyingDataCallback, options); // this larger task includes L2 write etc
return stampede.UnwrapReservedAsync(_logger);
}
}
else
{
// pre-existing query
if (eventSourceEnabled)
{
HybridCacheEventSource.Log.StampedeJoin();
}
}
return stampede.JoinAsync(_logger, cancellationToken);
}
public override ValueTask RemoveAsync(string key, CancellationToken token = default)
{
_localCache.Remove(key);
return _backendCache is null ? default : new(_backendCache.RemoveAsync(key, token));
}
public override ValueTask SetAsync<T>(string key, T value, HybridCacheEntryOptions? options = null, IEnumerable<string>? tags = null, CancellationToken token = default)
{
// since we're forcing a write: disable L1+L2 read; we'll use a direct pass-thru of the value as the callback, to reuse all the code
// note also that stampede token is not shared with anyone else
HybridCacheEntryFlags flags = GetEffectiveFlags(options) | (HybridCacheEntryFlags.DisableLocalCacheRead | HybridCacheEntryFlags.DisableDistributedCacheRead);
var state = new StampedeState<T, T>(this, new StampedeKey(key, flags), TagSet.Create(tags), token);
return new(state.ExecuteDirectAsync(value, static (state, _) => new(state), options)); // note this spans L2 write etc
}
// exposed as internal for testability
internal TimeSpan GetL1AbsoluteExpirationRelativeToNow(HybridCacheEntryOptions? options)
=> GetEffectiveLocalCacheExpiration(options) ?? _defaultLocalCacheExpiration;
internal TimeSpan GetL2AbsoluteExpirationRelativeToNow(HybridCacheEntryOptions? options) => options?.Expiration ?? _defaultExpiration;
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal HybridCacheEntryFlags GetEffectiveFlags(HybridCacheEntryOptions? options)
=> (options?.Flags | _hardFlags) ?? _defaultFlags;
private static ValueTask<T> RunWithoutCacheAsync<TState, T>(HybridCacheEntryFlags flags, TState state,
Func<TState, CancellationToken, ValueTask<T>> underlyingDataCallback,
CancellationToken cancellationToken)
{
return (flags & HybridCacheEntryFlags.DisableUnderlyingData) == 0
? underlyingDataCallback(state, cancellationToken) : default;
}
private static TimeSpan? GetEffectiveLocalCacheExpiration(HybridCacheEntryOptions? options)
{
// If LocalCacheExpiration is not specified, then use option's Expiration, to keep in sync by default.
// Or in other words: the inheritance of "LocalCacheExpiration : Expiration" in a single object takes
// precedence between the inheritance between per-entry options and global options, and if a caller
// provides a per-entry option with *just* the Expiration specified, then that is assumed to also
// specify the LocalCacheExpiration.
if (options is not null)
{
if (options.LocalCacheExpiration is { } local)
{
if (options.Expiration is { } overall)
{
// enforce "not exceeding the remaining overall cache lifetime"
return local < overall ? local : overall;
}
return local;
}
return options.Expiration;
}
return null;
}
private bool ValidateKey(string key)
{
if (string.IsNullOrWhiteSpace(key))
{
_logger.KeyEmptyOrWhitespace();
return false;
}
if (key.Length > _maximumKeyLength)
{
_logger.MaximumKeyLengthExceeded(_maximumKeyLength, key.Length);
return false;
}
if (key.IndexOfAny(_keyReservedCharacters) >= 0)
{
_logger.KeyInvalidContent();
return false;
}
// nothing to complain about
return true;
}
private bool TryGetExisting<T>(string key, [NotNullWhen(true)] out CacheItem<T>? value)
{
if (_localCache.TryGetValue(key, out object? untyped) && untyped is CacheItem<T> typed)
{
// check tag-based and global invalidation
if (IsValid(typed))
{
value = typed;
return true;
}
// remove from L1; note there's a little unavoidable race here; worst case is that
// a fresher value gets dropped - we'll have to accept it
_localCache.Remove(key);
}
// failure
value = null;
return false;
}
}
|