// File: ResponseCachingMiddleware.cs
// Web Access
// Project: src\src\Middleware\ResponseCaching\src\Microsoft.AspNetCore.ResponseCaching.csproj (Microsoft.AspNetCore.ResponseCaching)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.ObjectPool;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Primitives;
using Microsoft.Net.Http.Headers;
 
namespace Microsoft.AspNetCore.ResponseCaching;
 
/// <summary>
/// Enable HTTP response caching.
/// </summary>
public class ResponseCachingMiddleware
{
    // Fallback lifetime for a cache entry when the response supplies no shared max-age,
    // max-age or Expires value (see OnFinalizeCacheHeaders).
    private static readonly TimeSpan DefaultExpirationTimeSpan = TimeSpan.FromSeconds(10);

    // Headers copied from the cached entry onto a 304 (Not Modified) response.
    // See https://www.rfc-editor.org/rfc/rfc9110#section-15.4.5 (previously RFC 7232 section 4.1).
    private static readonly string[] HeadersToIncludeIn304 =
        new[] { "Cache-Control", "Content-Location", "Date", "ETag", "Expires", "Vary" };

    private readonly RequestDelegate _next;
    private readonly ResponseCachingOptions _options;
    private readonly ILogger _logger;
    private readonly IResponseCachingPolicyProvider _policyProvider;
    private readonly IResponseCache _cache;
    private readonly IResponseCachingKeyProvider _keyProvider;

    /// <summary>
    /// Creates a new <see cref="ResponseCachingMiddleware"/>.
    /// </summary>
    /// <remarks>
    /// Uses the default collaborators: a <see cref="ResponseCachingPolicyProvider"/>, a
    /// <see cref="MemoryResponseCache"/> backed by a <see cref="MemoryCache"/> whose size limit is
    /// <see cref="ResponseCachingOptions.SizeLimit"/>, and a <see cref="ResponseCachingKeyProvider"/>.
    /// </remarks>
    /// <param name="next">The <see cref="RequestDelegate"/> representing the next middleware in the pipeline.</param>
    /// <param name="options">The options for this middleware.</param>
    /// <param name="loggerFactory">The <see cref="ILoggerFactory"/> used for logging.</param>
    /// <param name="poolProvider">The <see cref="ObjectPoolProvider"/> used for creating <see cref="ObjectPool"/> instances.</param>
    public ResponseCachingMiddleware(
        RequestDelegate next,
        IOptions<ResponseCachingOptions> options,
        ILoggerFactory loggerFactory,
        ObjectPoolProvider poolProvider)
        : this(
            next,
            options,
            loggerFactory,
            new ResponseCachingPolicyProvider(),
            new MemoryResponseCache(new MemoryCache(new MemoryCacheOptions
            {
                SizeLimit = options.Value.SizeLimit
            })),
            new ResponseCachingKeyProvider(poolProvider, options))
    { }

    // for testing
    /// <summary>
    /// Creates a new <see cref="ResponseCachingMiddleware"/> with explicitly supplied collaborators.
    /// </summary>
    /// <param name="next">The next middleware in the pipeline.</param>
    /// <param name="options">The options for this middleware; <c>options.Value</c> is captured once.</param>
    /// <param name="loggerFactory">Factory used to create the middleware's logger.</param>
    /// <param name="policyProvider">Decides whether requests/responses may be looked up, stored, or served.</param>
    /// <param name="cache">Backing store for cached responses and vary-by rules.</param>
    /// <param name="keyProvider">Builds base and vary-by cache keys.</param>
    /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
    internal ResponseCachingMiddleware(
        RequestDelegate next,
        IOptions<ResponseCachingOptions> options,
        ILoggerFactory loggerFactory,
        IResponseCachingPolicyProvider policyProvider,
        IResponseCache cache,
        IResponseCachingKeyProvider keyProvider)
    {
        ArgumentNullException.ThrowIfNull(next);
        ArgumentNullException.ThrowIfNull(options);
        ArgumentNullException.ThrowIfNull(loggerFactory);
        ArgumentNullException.ThrowIfNull(policyProvider);
        ArgumentNullException.ThrowIfNull(cache);
        ArgumentNullException.ThrowIfNull(keyProvider);

        _next = next;
        _options = options.Value;
        _logger = loggerFactory.CreateLogger<ResponseCachingMiddleware>();
        _policyProvider = policyProvider;
        _cache = cache;
        _keyProvider = keyProvider;
    }

    /// <summary>
    /// Invokes the logic of the middleware.
    /// </summary>
    /// <param name="httpContext">The <see cref="HttpContext"/>.</param>
    /// <returns>A <see cref="Task"/> that completes when the middleware has completed processing.</returns>
    public async Task Invoke(HttpContext httpContext)
    {
        var context = new ResponseCachingContext(httpContext, _logger);

        // Should we attempt any caching logic?
        if (_policyProvider.AttemptResponseCaching(context))
        {
            // Can this request be served from cache? (A served 504 for only-if-cached also ends here.)
            if (_policyProvider.AllowCacheLookup(context) && await TryServeFromCacheAsync(context))
            {
                return;
            }

            // Should we store the response to this request?
            if (_policyProvider.AllowCacheStorage(context))
            {
                // Hook up to listen to the response stream
                ShimResponseStream(context);

                try
                {
                    await _next(httpContext);

                    // If there was no response body, check the response headers now. We can cache things like redirects.
                    // StartResponse is a no-op if the stream callback already ran it.
                    StartResponse(context);

                    // Finalize the cache entry
                    FinalizeCacheBody(context);
                }
                finally
                {
                    // Always restore the original body stream and remove the feature, even if _next throws.
                    UnshimResponseStream(context);
                }

                return;
            }
        }

        // Response should not be captured but add IResponseCachingFeature which may be required when the response is generated
        AddResponseCachingFeature(httpContext);

        try
        {
            await _next(httpContext);
        }
        finally
        {
            RemoveResponseCachingFeature(httpContext);
        }
    }

    /// <summary>
    /// Attempts to write <paramref name="cacheEntry"/> to the current response.
    /// </summary>
    /// <remarks>
    /// When the entry is a <see cref="CachedResponse"/>, <c>context.CachedResponse</c>,
    /// <c>context.CachedResponseHeaders</c>, <c>context.ResponseTime</c> and <c>context.CachedEntryAge</c>
    /// are populated before freshness is checked, even if the entry turns out to be stale.
    /// </remarks>
    /// <param name="context">The per-request caching state.</param>
    /// <param name="cacheEntry">The entry fetched from the cache, possibly <c>null</c>.</param>
    /// <returns>
    /// <c>true</c> if a 304 or the full cached response was produced; <c>false</c> if the entry is not a
    /// <see cref="CachedResponse"/> or the policy provider does not consider it fresh.
    /// </returns>
    internal async Task<bool> TryServeCachedResponseAsync(ResponseCachingContext context, IResponseCacheEntry? cacheEntry)
    {
        if (cacheEntry is not CachedResponse cachedResponse)
        {
            return false;
        }

        context.CachedResponse = cachedResponse;
        context.CachedResponseHeaders = cachedResponse.Headers;
        context.ResponseTime = _options.TimeProvider.GetUtcNow();
        var cachedEntryAge = context.ResponseTime.Value - context.CachedResponse.Created;
        // Clamp to zero so a Created timestamp in the future never yields a negative age.
        context.CachedEntryAge = cachedEntryAge > TimeSpan.Zero ? cachedEntryAge : TimeSpan.Zero;

        if (_policyProvider.IsCachedEntryFresh(context))
        {
            // Check conditional request rules
            if (ContentIsNotModified(context))
            {
                _logger.NotModifiedServed();
                context.HttpContext.Response.StatusCode = StatusCodes.Status304NotModified;

                if (context.CachedResponseHeaders != null)
                {
                    foreach (var key in HeadersToIncludeIn304)
                    {
                        if (context.CachedResponseHeaders.TryGetValue(key, out var values))
                        {
                            context.HttpContext.Response.Headers[key] = values;
                        }
                    }
                }
            }
            else
            {
                var response = context.HttpContext.Response;
                // Copy the cached status code and response headers
                response.StatusCode = context.CachedResponse.StatusCode;
                foreach (var header in context.CachedResponse.Headers)
                {
                    response.Headers[header.Key] = header.Value;
                }

                // Note: int64 division truncates result and errors may be up to 1 second. This reduction in
                // accuracy of age calculation is considered appropriate since it is small compared to clock
                // skews and the "Age" header is an estimate of the real age of cached content.
                response.Headers.Age = HeaderUtilities.FormatNonNegativeInt64(context.CachedEntryAge.Value.Ticks / TimeSpan.TicksPerSecond);

                // Copy the cached response body
                var body = context.CachedResponse.Body;
                if (body.Length > 0)
                {
                    try
                    {
                        await body.CopyToAsync(response.BodyWriter, context.HttpContext.RequestAborted);
                    }
                    catch (OperationCanceledException)
                    {
                        // The client went away mid-copy; abort the connection rather than surfacing the exception.
                        context.HttpContext.Abort();
                    }
                }
                _logger.CachedResponseServed();
            }
            return true;
        }

        return false;
    }

    /// <summary>
    /// Looks the request up in the cache and serves a hit if one is found.
    /// </summary>
    /// <remarks>
    /// The base key is stored in <c>context.BaseKey</c>. If it maps to <see cref="CachedVaryByRules"/>, those
    /// rules are recorded on the context and each key from
    /// <see cref="IResponseCachingKeyProvider.CreateLookupVaryByKeys"/> is tried in turn; otherwise the entry is
    /// tried directly.
    /// </remarks>
    /// <param name="context">The per-request caching state.</param>
    /// <returns>
    /// <c>true</c> if a response was produced: a cache hit, or a 504 because the request's Cache-Control
    /// contains <c>only-if-cached</c> and nothing could be served; otherwise <c>false</c>.
    /// </returns>
    internal async Task<bool> TryServeFromCacheAsync(ResponseCachingContext context)
    {
        context.BaseKey = _keyProvider.CreateBaseKey(context);
        var cacheEntry = _cache.Get(context.BaseKey);

        if (cacheEntry is CachedVaryByRules cachedVaryByRules)
        {
            // Request contains vary rules, recompute key(s) and try again
            context.CachedVaryByRules = cachedVaryByRules;

            foreach (var varyKey in _keyProvider.CreateLookupVaryByKeys(context))
            {
                if (await TryServeCachedResponseAsync(context, _cache.Get(varyKey)))
                {
                    return true;
                }
            }
        }
        else
        {
            if (await TryServeCachedResponseAsync(context, cacheEntry))
            {
                return true;
            }
        }

        // The client asked for a cached response only; with no usable entry, answer 504 instead of calling _next.
        if (HeaderUtilities.ContainsCacheDirective(context.HttpContext.Request.Headers.CacheControl, CacheControlHeaderValue.OnlyIfCachedString))
        {
            _logger.GatewayTimeoutServed();
            context.HttpContext.Response.StatusCode = StatusCodes.Status504GatewayTimeout;
            return true;
        }

        _logger.NoResponseServed();
        return false;
    }

    /// <summary>
    /// Finalize cache headers.
    /// </summary>
    /// <remarks>
    /// If the policy provider deems the response cacheable, this sets <c>context.ShouldCacheResponse</c>,
    /// computes <c>context.CachedResponseValidFor</c>, ensures a base key and a Date header exist, and
    /// snapshots status code and headers (minus Age) into <c>context.CachedResponse</c>. Otherwise buffering
    /// on the response caching stream is disabled.
    /// </remarks>
    /// <param name="context">The per-request caching state; <c>context.ResponseTime</c> must already be set.</param>
    /// <returns><c>true</c> if a vary by entry needs to be stored in the cache; otherwise <c>false</c>.</returns>
    private bool OnFinalizeCacheHeaders(ResponseCachingContext context)
    {
        if (_policyProvider.IsResponseCacheable(context))
        {
            var storeVaryByEntry = false;
            context.ShouldCacheResponse = true;

            // Create the cache entry now
            var response = context.HttpContext.Response;
            var headers = response.Headers;
            var varyHeaders = new StringValues(headers.GetCommaSeparatedValues(HeaderNames.Vary));
            var varyQueryKeys = new StringValues(context.HttpContext.Features.Get<IResponseCachingFeature>()?.VaryByQueryKeys);
            // Precedence: shared max-age, then max-age, then Expires relative to now, then the default.
            context.CachedResponseValidFor = context.ResponseSharedMaxAge ??
                context.ResponseMaxAge ??
                (context.ResponseExpires - context.ResponseTime!.Value) ??
                DefaultExpirationTimeSpan;

            // Generate a base key if none exist
            if (string.IsNullOrEmpty(context.BaseKey))
            {
                context.BaseKey = _keyProvider.CreateBaseKey(context);
            }

            // Check if any vary rules exist
            if (!StringValues.IsNullOrEmpty(varyHeaders) || !StringValues.IsNullOrEmpty(varyQueryKeys))
            {
                // Normalize order and casing of vary by rules
                var normalizedVaryHeaders = GetOrderCasingNormalizedStringValues(varyHeaders);
                var normalizedVaryQueryKeys = GetOrderCasingNormalizedStringValues(varyQueryKeys);

                // Update vary rules if they are different
                if (context.CachedVaryByRules == null ||
                    !StringValues.Equals(context.CachedVaryByRules.QueryKeys, normalizedVaryQueryKeys) ||
                    !StringValues.Equals(context.CachedVaryByRules.Headers, normalizedVaryHeaders))
                {
                    // A fresh prefix keeps storage keys from the old rules from colliding with the new ones.
                    context.CachedVaryByRules = new CachedVaryByRules
                    {
                        VaryByKeyPrefix = FastGuid.NewGuid().IdString,
                        Headers = normalizedVaryHeaders,
                        QueryKeys = normalizedVaryQueryKeys
                    };
                }

                // Always overwrite the CachedVaryByRules to update the expiry information
                _logger.VaryByRulesUpdated(normalizedVaryHeaders.ToString(), normalizedVaryQueryKeys.ToString());
                storeVaryByEntry = true;

                context.StorageVaryKey = _keyProvider.CreateStorageVaryByKey(context);
            }

            // Ensure date header is set
            if (!context.ResponseDate.HasValue)
            {
                context.ResponseDate = context.ResponseTime!.Value;
                // Setting the date on the raw response headers.
                headers.Date = HeaderUtilities.FormatDate(context.ResponseDate.Value);
            }

            // Store the response on the state
            context.CachedResponse = new CachedResponse
            {
                Created = context.ResponseDate.Value,
                StatusCode = response.StatusCode,
                Headers = new HeaderDictionary()
            };

            // Age is recomputed when the entry is served, so it is not stored.
            foreach (var header in headers)
            {
                if (!string.Equals(header.Key, HeaderNames.Age, StringComparison.OrdinalIgnoreCase))
                {
                    context.CachedResponse.Headers[header.Key] = header.Value;
                }
            }

            return storeVaryByEntry;
        }

        context.ResponseCachingStream.DisableBuffering();
        return false;
    }

    /// <summary>
    /// Finalizes the cache headers and, when vary-by rules apply, stores <c>context.CachedVaryByRules</c>
    /// under <c>context.BaseKey</c> for <c>context.CachedResponseValidFor</c>.
    /// </summary>
    /// <param name="context">The per-request caching state.</param>
    internal void FinalizeCacheHeaders(ResponseCachingContext context)
    {
        if (OnFinalizeCacheHeaders(context))
        {
            _cache.Set(context.BaseKey, context.CachedVaryByRules, context.CachedResponseValidFor);
        }
    }

    /// <summary>
    /// Stores the buffered response body together with the cached headers.
    /// </summary>
    /// <remarks>
    /// Nothing is stored unless <c>context.ShouldCacheResponse</c> is set and the caching stream is still
    /// buffering. The entry is also skipped when a declared Content-Length disagrees with the buffered length,
    /// except for a HEAD request with an empty body.
    /// </remarks>
    /// <param name="context">The per-request caching state.</param>
    internal void FinalizeCacheBody(ResponseCachingContext context)
    {
        if (context.ShouldCacheResponse && context.ResponseCachingStream.BufferingEnabled)
        {
            var contentLength = context.HttpContext.Response.ContentLength;
            var cachedResponseBody = context.ResponseCachingStream.GetCachedResponseBody();
            if (!contentLength.HasValue || contentLength == cachedResponseBody.Length
                || (cachedResponseBody.Length == 0
                    && HttpMethods.IsHead(context.HttpContext.Request.Method)))
            {
                var response = context.HttpContext.Response;
                // Add a content-length if required
                if (!response.ContentLength.HasValue && StringValues.IsNullOrEmpty(response.Headers.TransferEncoding))
                {
                    context.CachedResponse.Headers.ContentLength = cachedResponseBody.Length;
                }

                context.CachedResponse.Body = cachedResponseBody;
                _logger.ResponseCached();
                _cache.Set(context.StorageVaryKey ?? context.BaseKey, context.CachedResponse, context.CachedResponseValidFor);
            }
            else
            {
                _logger.ResponseContentLengthMismatchNotCached();
            }
        }
        else
        {
            _logger.LogResponseNotCached();
        }
    }

    /// <summary>
    /// Mark the response as started and set the response time if no response was started yet.
    /// </summary>
    /// <param name="context">The per-request caching state whose <c>ResponseStarted</c> flag is checked and set.</param>
    /// <returns><c>true</c> if the response was not started before this call; otherwise <c>false</c>.</returns>
    private bool OnStartResponse(ResponseCachingContext context)
    {
        if (!context.ResponseStarted)
        {
            context.ResponseStarted = true;
            context.ResponseTime = _options.TimeProvider.GetUtcNow();

            return true;
        }
        return false;
    }

    /// <summary>
    /// Finalizes the cache headers the first time the response starts.
    /// </summary>
    /// <remarks>
    /// This is invoked both from the response caching stream callback and after the rest of the pipeline
    /// completes. The <c>ResponseStarted</c> flag makes every call after the first a no-op.
    /// </remarks>
    /// <param name="context">The per-request caching state.</param>
    internal void StartResponse(ResponseCachingContext context)
    {
        if (OnStartResponse(context))
        {
            FinalizeCacheHeaders(context);
        }
    }

    /// <summary>
    /// Registers a new <see cref="ResponseCachingFeature"/> on the request.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <exception cref="InvalidOperationException">An <see cref="IResponseCachingFeature"/> is already registered.</exception>
    internal static void AddResponseCachingFeature(HttpContext context)
    {
        if (context.Features.Get<IResponseCachingFeature>() != null)
        {
            throw new InvalidOperationException($"Another instance of {nameof(ResponseCachingFeature)} already exists. Only one instance of {nameof(ResponseCachingMiddleware)} can be configured for an application.");
        }
        context.Features.Set<IResponseCachingFeature>(new ResponseCachingFeature());
    }

    /// <summary>
    /// Replaces the response body with a <see cref="ResponseCachingStream"/> that wraps the original
    /// stream, then registers the <see cref="IResponseCachingFeature"/>.
    /// </summary>
    /// <remarks>
    /// The caching stream is limited by <see cref="ResponseCachingOptions.MaximumBodySize"/>. It calls
    /// <see cref="StartResponse"/> through its start callback.
    /// </remarks>
    /// <param name="context">The per-request caching state; the original body is saved in <c>OriginalResponseStream</c>.</param>
    internal void ShimResponseStream(ResponseCachingContext context)
    {
        // Shim response stream
        context.OriginalResponseStream = context.HttpContext.Response.Body;
        context.ResponseCachingStream = new ResponseCachingStream(
            context.OriginalResponseStream,
            _options.MaximumBodySize,
            StreamUtilities.BodySegmentSize,
            () => StartResponse(context));
        context.HttpContext.Response.Body = context.ResponseCachingStream;

        // Add IResponseCachingFeature
        AddResponseCachingFeature(context.HttpContext);
    }

    /// <summary>
    /// Clears the <see cref="IResponseCachingFeature"/> from the request's feature collection.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    internal static void RemoveResponseCachingFeature(HttpContext context) =>
        context.Features.Set<IResponseCachingFeature?>(null);

    /// <summary>
    /// Restores the response body saved by <see cref="ShimResponseStream"/> and removes the
    /// <see cref="IResponseCachingFeature"/>.
    /// </summary>
    /// <param name="context">The per-request caching state.</param>
    internal static void UnshimResponseStream(ResponseCachingContext context)
    {
        // Unshim response stream
        context.HttpContext.Response.Body = context.OriginalResponseStream;

        // Remove IResponseCachingFeature
        RemoveResponseCachingFeature(context.HttpContext);
    }

    /// <summary>
    /// Evaluates the request's conditional headers against the cached response headers.
    /// </summary>
    /// <remarks>
    /// When If-None-Match is present it is the only header evaluated. A lone <c>*</c> matches, and so does a
    /// weak match of any listed tag against the cached ETag. Only when If-None-Match is absent is
    /// If-Modified-Since compared to the cached Last-Modified, falling back to the cached Date.
    /// </remarks>
    /// <param name="context">The per-request caching state; <c>CachedResponseHeaders</c> must be set.</param>
    /// <returns><c>true</c> if a 304 (Not Modified) should be served; otherwise <c>false</c>.</returns>
    internal static bool ContentIsNotModified(ResponseCachingContext context)
    {
        var cachedResponseHeaders = context.CachedResponseHeaders;
        var ifNoneMatchHeader = context.HttpContext.Request.Headers.IfNoneMatch;

        if (!StringValues.IsNullOrEmpty(ifNoneMatchHeader))
        {
            if (ifNoneMatchHeader.Count == 1 && StringSegment.Equals(ifNoneMatchHeader[0], EntityTagHeaderValue.Any.Tag, StringComparison.OrdinalIgnoreCase))
            {
                context.Logger.NotModifiedIfNoneMatchStar();
                return true;
            }

            if (!StringValues.IsNullOrEmpty(cachedResponseHeaders.ETag)
                && EntityTagHeaderValue.TryParse(cachedResponseHeaders.ETag.ToString(), out var eTag)
                && EntityTagHeaderValue.TryParseList(ifNoneMatchHeader, out var ifNoneMatchEtags))
            {
                for (var i = 0; i < ifNoneMatchEtags.Count; i++)
                {
                    var requestETag = ifNoneMatchEtags[i];
                    // If-None-Match uses weak comparison.
                    if (eTag.Compare(requestETag, useStrongComparison: false))
                    {
                        context.Logger.NotModifiedIfNoneMatchMatched(requestETag);
                        return true;
                    }
                }
            }
        }
        else
        {
            var ifModifiedSince = context.HttpContext.Request.Headers.IfModifiedSince;
            if (!StringValues.IsNullOrEmpty(ifModifiedSince))
            {
                // Prefer Last-Modified; fall back to Date when Last-Modified is missing or unparseable.
                if (!HeaderUtilities.TryParseDate(cachedResponseHeaders.LastModified.ToString(), out var modified) &&
                    !HeaderUtilities.TryParseDate(cachedResponseHeaders.Date.ToString(), out modified))
                {
                    return false;
                }

                if (HeaderUtilities.TryParseDate(ifModifiedSince.ToString(), out var modifiedSince) &&
                    modified <= modifiedSince)
                {
                    context.Logger.NotModifiedIfModifiedSinceSatisfied(modified, modifiedSince);
                    return true;
                }
            }
        }

        return false;
    }

    /// <summary>
    /// Normalizes vary-by values so that equivalent rule sets compare equal: every value is upper-cased
    /// (invariant culture) and multiple values are sorted ordinally.
    /// </summary>
    /// <param name="stringValues">The values to normalize; may be empty and may contain <c>null</c> entries.</param>
    /// <returns>The normalized values. <c>null</c> entries are preserved and sort first.</returns>
    internal static StringValues GetOrderCasingNormalizedStringValues(StringValues stringValues)
    {
        if (stringValues.Count == 1)
        {
            return new StringValues(stringValues.ToString().ToUpperInvariant());
        }

        var originalArray = stringValues.ToArray();
        var newArray = new string?[originalArray.Length];

        for (var i = 0; i < originalArray.Length; i++)
        {
            // StringValues may carry null entries; keep them as null rather than throwing.
            newArray[i] = originalArray[i]?.ToUpperInvariant();
        }

        // Since the casing has already been normalized, use Ordinal comparison
        Array.Sort(newArray, StringComparer.Ordinal);

        return new StringValues(newArray);
    }
}