File: TokenCredentialShortCache.cs
Web Access
Project: src\src\Microsoft.DotNet.ArcadeAzureIntegration\Microsoft.DotNet.ArcadeAzureIntegration.csproj (Microsoft.DotNet.ArcadeAzureIntegration)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
#if !NET472_OR_GREATER

#nullable enable

using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core;
 
namespace Microsoft.DotNet.ArcadeAzureIntegration;
 
 
/// <summary>
/// TokenCredentialShortCache is a wrapper around TokenCredential that caches the token for the same scope and request parameters.
/// Cache time is short, 3 minutes only because we don't want to affect the expiration window that's still handled by the underlying TokenCredential implementation. 
/// It helps with reducing the number of requests to Entra or AzureCLI external process during heavy paralellized operations.
/// </summary>
public class TokenCredentialShortCache : TokenCredential
{
    private const int CacheExpirationMinutes = 3;
 
    public TokenCredentialShortCache(TokenCredential tokenCredential)
    {
        _tokenCredential = tokenCredential;
    }
 
    private TokenCredential _tokenCredential;
    private ConcurrentDictionary<CacheKey, CachedToken> _tokenCache = new ConcurrentDictionary<CacheKey, CachedToken>();
 
    public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken)
    {
        return GetTokenAsync(requestContext, cancellationToken).Result;
    }
 
    public override async ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken)
    {
        CacheKey cacheKey = new CacheKey
        {
            Scopes = string.Join(":", requestContext.Scopes),
            Claims = requestContext.Claims,
            TenantId = requestContext.TenantId,
            IsCaeEnabled = requestContext.IsCaeEnabled
        };
 
        CachedToken cachedToken = _tokenCache.GetOrAdd(cacheKey, _ => new CachedToken());
        var token = await cachedToken.GetToken(requestContext, () => {
            return _tokenCredential.GetTokenAsync(requestContext, cancellationToken);
        });
 
        return token;
    }
 
    private record struct CacheKey
    {
        public required string Scopes { get; init; }
        public required string? Claims { get; init; }
        public required string? TenantId { get; init; }
        public required bool IsCaeEnabled { get; init; }
    }
 
    private class CachedToken
    {
        private AccessToken? _token;
        private DateTime _shortTimeCacheExpiresOn = DateTime.UtcNow;
        private SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1);
 
        public async ValueTask<AccessToken> GetToken(TokenRequestContext requestContext, Func<ValueTask<AccessToken>> getFreshToken)
        {
            await _semaphore.WaitAsync();
            try
            {
                if (_token != null && _shortTimeCacheExpiresOn > DateTime.UtcNow)
                {
                    return _token.Value;
                }
 
                _token = await getFreshToken();
                _shortTimeCacheExpiresOn = DateTime.UtcNow.AddMinutes(CacheExpirationMinutes);
                return _token.Value;
            }
            finally
            {
                _semaphore.Release();
            }
        }
    }
}
 
#endif