// File: ResponseCachingChatClient.cs
// Web Access
// Project: src\src\Libraries\Microsoft.Extensions.AI.Evaluation.Reporting\CSharp\Microsoft.Extensions.AI.Evaluation.Reporting.csproj (Microsoft.Extensions.AI.Evaluation.Reporting)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Caching.Distributed;
 
namespace Microsoft.Extensions.AI.Evaluation.Reporting;
 
/// <summary>
/// A <see cref="DistributedCachingChatClient"/> that, for every response it reads from or writes to
/// the cache, records a <see cref="ChatTurnDetails"/> entry (latency, model, usage, cache key and
/// whether the cache was hit) in the supplied <see cref="ChatDetails"/>.
/// </summary>
internal sealed class ResponseCachingChatClient : DistributedCachingChatClient
{
    // Extra values appended to the inputs of every cache key computed by GetCacheKey.
    private readonly IReadOnlyList<string> _cachingKeys;

    // Receives one ChatTurnDetails entry per recorded turn.
    private readonly ChatDetails _chatDetails;

    // Stopwatches started by a cache read that missed, keyed by cache key. The matching cache write
    // removes the entry and uses it to report the latency of the turn.
    private readonly ConcurrentDictionary<string, Stopwatch> _stopWatches = new();

    /// <summary>Initializes a new instance of the <see cref="ResponseCachingChatClient"/> class.</summary>
    /// <param name="originalChatClient">The chat client passed through to the base caching client.</param>
    /// <param name="cache">The distributed cache passed through to the base caching client.</param>
    /// <param name="cachingKeys">Additional values to include in every cache key; copied at construction.</param>
    /// <param name="chatDetails">The collection that receives the per-turn details.</param>
    internal ResponseCachingChatClient(
        IChatClient originalChatClient,
        IDistributedCache cache,
        IEnumerable<string> cachingKeys,
        ChatDetails chatDetails)
            : base(originalChatClient, cache)
    {
        // Snapshot the keys so later changes to the caller's sequence do not affect cache key generation.
        _cachingKeys = [.. cachingKeys];
        _chatDetails = chatDetails;
    }

    /// <summary>
    /// Reads a response from the cache. On a hit, records the turn with <c>cacheHit: true</c>; on a miss,
    /// keeps the running stopwatch under <paramref name="key"/> for a later cache write to pick up.
    /// </summary>
    /// <param name="key">The cache key.</param>
    /// <param name="cancellationToken">Token used to cancel the cache read.</param>
    /// <returns>The cached response, or <see langword="null"/> when the base read returned none.</returns>
    protected override async Task<ChatResponse?> ReadCacheAsync(string key, CancellationToken cancellationToken)
    {
        Stopwatch stopwatch = Stopwatch.StartNew();

        ChatResponse? response = await base.ReadCacheAsync(key, cancellationToken).ConfigureAwait(false);

        if (response is null)
        {
            TrackCacheMiss(key, stopwatch);
            return null;
        }

        stopwatch.Stop();
        RecordTurn(key, stopwatch.Elapsed, response, cacheHit: true);

        return response;
    }

    /// <summary>
    /// Reads a list of streaming updates from the cache. On a hit, records the turn with
    /// <c>cacheHit: true</c>; on a miss, keeps the running stopwatch under <paramref name="key"/>
    /// for a later cache write to pick up.
    /// </summary>
    /// <param name="key">The cache key.</param>
    /// <param name="cancellationToken">Token used to cancel the cache read.</param>
    /// <returns>The cached updates, or <see langword="null"/> when the base read returned none.</returns>
    protected override async Task<IReadOnlyList<ChatResponseUpdate>?> ReadCacheStreamingAsync(
        string key,
        CancellationToken cancellationToken)
    {
        Stopwatch stopwatch = Stopwatch.StartNew();

        IReadOnlyList<ChatResponseUpdate>? updates =
            await base.ReadCacheStreamingAsync(key, cancellationToken).ConfigureAwait(false);

        if (updates is null)
        {
            TrackCacheMiss(key, stopwatch);
            return null;
        }

        // Stop before converting so the conversion cost is not counted as latency.
        stopwatch.Stop();
        RecordTurn(key, stopwatch.Elapsed, updates.ToChatResponse(), cacheHit: true);

        return updates;
    }

    /// <summary>
    /// Writes a response to the cache and, if a stopwatch is pending for <paramref name="key"/>,
    /// records the turn with <c>cacheHit: false</c> once the write has completed.
    /// </summary>
    /// <param name="key">The cache key.</param>
    /// <param name="value">The response to cache.</param>
    /// <param name="cancellationToken">Token used to cancel the cache write.</param>
    protected override async Task WriteCacheAsync(string key, ChatResponse value, CancellationToken cancellationToken)
    {
        // Claim the stopwatch before awaiting so a failed or cancelled write cannot leave a stale entry.
        Stopwatch? stopwatch = TakePendingStopwatch(key);

        await base.WriteCacheAsync(key, value, cancellationToken).ConfigureAwait(false);

        if (stopwatch is not null)
        {
            stopwatch.Stop();
            RecordTurn(key, stopwatch.Elapsed, value, cacheHit: false);
        }
    }

    /// <summary>
    /// Writes a list of streaming updates to the cache and, if a stopwatch is pending for
    /// <paramref name="key"/>, records the turn with <c>cacheHit: false</c> once the write has completed.
    /// </summary>
    /// <param name="key">The cache key.</param>
    /// <param name="value">The streaming updates to cache.</param>
    /// <param name="cancellationToken">Token used to cancel the cache write.</param>
    protected override async Task WriteCacheStreamingAsync(
        string key,
        IReadOnlyList<ChatResponseUpdate> value,
        CancellationToken cancellationToken)
    {
        // Claim the stopwatch before awaiting so a failed or cancelled write cannot leave a stale entry.
        Stopwatch? stopwatch = TakePendingStopwatch(key);

        await base.WriteCacheStreamingAsync(key, value, cancellationToken).ConfigureAwait(false);

        if (stopwatch is not null)
        {
            // Stop before converting so the conversion cost is not counted as latency.
            stopwatch.Stop();
            RecordTurn(key, stopwatch.Elapsed, value.ToChatResponse(), cacheHit: false);
        }
    }

    /// <summary>
    /// Computes the cache key from the base implementation's inputs with the configured caching keys
    /// appended after <paramref name="additionalValues"/>.
    /// </summary>
    protected override string GetCacheKey(IEnumerable<ChatMessage> messages, ChatOptions? options, params ReadOnlySpan<object?> additionalValues)
        => base.GetCacheKey(messages, options, [.. additionalValues, .. _cachingKeys]);

    /// <summary>Stores the running stopwatch for a missed key, replacing any stopwatch already stored for it.</summary>
    private void TrackCacheMiss(string key, Stopwatch stopwatch)
        => _stopWatches[key] = stopwatch;

    /// <summary>Removes and returns the stopwatch stored for <paramref name="key"/>, or <see langword="null"/> if there is none.</summary>
    private Stopwatch? TakePendingStopwatch(string key)
        => _stopWatches.TryRemove(key, out Stopwatch? stopwatch) ? stopwatch : null;

    /// <summary>Adds one turn entry, taking the model and usage from <paramref name="response"/>.</summary>
    private void RecordTurn(string key, TimeSpan latency, ChatResponse response, bool cacheHit)
        => _chatDetails.AddTurnDetails(
            new ChatTurnDetails(
                latency: latency,
                model: response.ModelId,
                usage: response.Usage,
                cacheKey: key,
                cacheHit: cacheHit));
}