// File: Cache\DistributedCacheTagHelperService.cs
// Web Access
// Project: src\src\Mvc\Mvc.TagHelpers\src\Microsoft.AspNetCore.Mvc.TagHelpers.csproj (Microsoft.AspNetCore.Mvc.TagHelpers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Concurrent;
using System.Linq;
using System.Text;
using System.Text.Encodings.Web;
using Microsoft.AspNetCore.Html;
using Microsoft.AspNetCore.Razor.TagHelpers;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Logging;
 
namespace Microsoft.AspNetCore.Mvc.TagHelpers.Cache;
 
/// <summary>
/// Implements <see cref="IDistributedCacheTagHelperService"/> and ensures
/// multiple concurrent requests are gated.
/// Each entry is stored under the hashed cache key and is laid out like this:
/// <list type="bullet">
/// <item>
/// <description>Int32 holding the byte length of the UTF8 encoded cache key.</description>
/// </item>
/// <item>
/// <description>The UTF8 encoded (non-hashed) cache key, used to detect hash collisions on read.</description>
/// </item>
/// <item>
/// <description>The cached content as serialized by the <see cref="IDistributedCacheTagHelperFormatter"/>.</description>
/// </item>
/// </list>
/// </summary>
public partial class DistributedCacheTagHelperService : IDistributedCacheTagHelperService
{
    private readonly IDistributedCacheTagHelperStorage _storage;
    private readonly IDistributedCacheTagHelperFormatter _formatter;
    private readonly HtmlEncoder _htmlEncoder;
    private readonly ILogger _logger;

    // In-flight computations, keyed by cache key, so that concurrent requests for the
    // same key await the first request's result instead of rendering again.
    private readonly ConcurrentDictionary<CacheTagKey, Task<IHtmlContent>> _workers;

    /// <summary>
    /// Creates a new <see cref="DistributedCacheTagHelperService"/>.
    /// </summary>
    /// <param name="storage">The <see cref="DistributedCacheTagHelper"/>'s <see cref="IDistributedCacheTagHelperStorage"/>.</param>
    /// <param name="formatter">The <see cref="IDistributedCacheTagHelperFormatter"/> for cache value serialization.</param>
    /// <param name="HtmlEncoder">The <see cref="HtmlEncoder"/> used to encode cache content.</param>
    /// <param name="loggerFactory">The <see cref="ILoggerFactory"/>.</param>
    public DistributedCacheTagHelperService(
        IDistributedCacheTagHelperStorage storage,
        IDistributedCacheTagHelperFormatter formatter,
        HtmlEncoder HtmlEncoder,
        ILoggerFactory loggerFactory)
    {
        ArgumentNullException.ThrowIfNull(storage);
        ArgumentNullException.ThrowIfNull(formatter);
        ArgumentNullException.ThrowIfNull(HtmlEncoder);
        ArgumentNullException.ThrowIfNull(loggerFactory);

        _formatter = formatter;
        _storage = storage;
        _htmlEncoder = HtmlEncoder;
        _logger = loggerFactory.CreateLogger<DistributedCacheTagHelperService>();
        _workers = new ConcurrentDictionary<CacheTagKey, Task<IHtmlContent>>();
    }

    /// <inheritdoc />
    public async Task<IHtmlContent> ProcessContentAsync(TagHelperOutput output, CacheTagKey key, DistributedCacheEntryOptions options)
    {
        IHtmlContent content = null;

        // Loop until some worker (this request or another one) yields non-null content.
        // A worker that failed publishes null, which sends its awaiters around the loop again.
        while (content == null)
        {
            // Is there any request already processing the value?
            if (!_workers.TryGetValue(key, out var result))
            {
                // There is a small race condition here between TryGetValue and TryAdd that might cause the
                // content to be computed more than once. We don't care about this race as the probability of
                // happening is very small and the impact is not critical.
                var tcs = new TaskCompletionSource<IHtmlContent>(creationOptions: TaskCreationOptions.RunContinuationsAsynchronously);

                // The result of TryAdd is intentionally ignored (see the race note above).
                _workers.TryAdd(key, tcs.Task);

                try
                {
                    // The full key is embedded in the stored entry; the hashed key addresses the storage.
                    var serializedKey = Encoding.UTF8.GetBytes(key.GenerateKey());
                    var storageKey = key.GenerateHashedKey();
                    var value = await _storage.GetAsync(storageKey);

                    if (value == null)
                    {
                        // The value is not cached, we need to render the tag helper output
                        var processedContent = await output.GetChildContentAsync();

                        var stringBuilder = new StringBuilder();
                        using (var writer = new StringWriter(stringBuilder))
                        {
                            processedContent.WriteTo(writer, _htmlEncoder);
                        }

                        var formattingContext = new DistributedCacheTagHelperFormattingContext
                        {
                            Html = new HtmlString(stringBuilder.ToString())
                        };

                        // Then cache the result
                        value = await _formatter.SerializeAsync(formattingContext);

                        var encodeValue = Encode(value, serializedKey);

                        await _storage.SetAsync(storageKey, encodeValue, options);

                        content = formattingContext.Html;
                    }
                    else
                    {
                        // The value was found in the storage, decode and ensure
                        // there is no cache key hash collision. Decode returns null for a
                        // key mismatch as well as for a malformed entry; both fall through
                        // to rendering the content below.
                        byte[] decodedValue = Decode(value, serializedKey);

                        try
                        {
                            if (decodedValue != null)
                            {
                                content = await _formatter.DeserializeAsync(decodedValue);
                            }
                        }
                        catch (Exception e)
                        {
                            Log.DistributedFormatterDeserializationException(_logger, storageKey, e);
                        }
                        finally
                        {
                            // If the deserialization fails the content is rendered
                            if (content == null)
                            {
                                content = await output.GetChildContentAsync();
                            }
                        }
                    }
                }
                catch
                {
                    // Publish null to the awaiters (see finally) so they retry on their own.
                    content = null;
                    throw;
                }
                finally
                {
                    // Remove the worker task before setting the result.
                    // If the result is null, other threads would potentially
                    // acquire it otherwise.
                    _workers.TryRemove(key, out _);

                    // Notify all other awaiters to render the content
                    tcs.TrySetResult(content);
                }
            }
            else
            {
                content = await result;
            }
        }

        return content;
    }

    /// <summary>
    /// Builds the stored entry: the Int32 length of <paramref name="serializedKey"/>, followed by
    /// <paramref name="serializedKey"/>, followed by <paramref name="value"/>.
    /// </summary>
    /// <param name="value">The serialized content.</param>
    /// <param name="serializedKey">The UTF8 encoded cache key embedded for collision detection.</param>
    /// <returns>A new array containing the length prefix, the key and the content.</returns>
    private static byte[] Encode(byte[] value, byte[] serializedKey)
    {
        var keyLength = BitConverter.GetBytes(serializedKey.Length);
        var encoded = new byte[keyLength.Length + serializedKey.Length + value.Length];

        Buffer.BlockCopy(keyLength, 0, encoded, 0, keyLength.Length);
        Buffer.BlockCopy(serializedKey, 0, encoded, keyLength.Length, serializedKey.Length);
        Buffer.BlockCopy(value, 0, encoded, keyLength.Length + serializedKey.Length, value.Length);

        return encoded;
    }

    /// <summary>
    /// Extracts the serialized content from a stored entry produced by <see cref="Encode"/>.
    /// </summary>
    /// <param name="value">The raw entry read from the storage.</param>
    /// <param name="expectedKey">The UTF8 encoded cache key the entry is expected to embed.</param>
    /// <returns>
    /// The content bytes, or <see langword="null"/> when the embedded key differs from
    /// <paramref name="expectedKey"/> or when the entry is truncated or carries an invalid length prefix.
    /// </returns>
    private static byte[] Decode(byte[] value, byte[] expectedKey)
    {
        // Too short to even contain the Int32 length prefix.
        if (value.Length < sizeof(int))
        {
            return null;
        }

        var keyLength = BitConverter.ToInt32(value, 0);

        // A negative or oversized prefix means the entry is corrupt or was written in another
        // format; report "no usable value" rather than throwing while allocating or slicing.
        if (keyLength < 0 || keyLength > value.Length - sizeof(int))
        {
            return null;
        }

        // Ensure we are reading the expected key before continuing
        var storedKey = new ReadOnlySpan<byte>(value, sizeof(int), keyLength);
        if (!storedKey.SequenceEqual(expectedKey))
        {
            return null;
        }

        // Everything after the embedded key is the serialized content.
        var contentOffset = sizeof(int) + keyLength;
        var decoded = new byte[value.Length - contentOffset];
        Buffer.BlockCopy(value, contentOffset, decoded, 0, decoded.Length);

        return decoded;
    }

    private static partial class Log
    {
        [LoggerMessage(1, LogLevel.Error, "Couldn't deserialize cached value for key {Key}.", EventName = "DistributedFormatterDeserializationException")]
        public static partial void DistributedFormatterDeserializationException(ILogger logger, string key, Exception exception);
    }
}