File: Utilities\AIJsonUtilities.cs
Web Access
Project: src\src\Libraries\Microsoft.Extensions.AI.Abstractions\Microsoft.Extensions.AI.Abstractions.csproj (Microsoft.Extensions.AI.Abstractions)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
#if !NET
using System.Diagnostics;
#endif
using System.IO;
using System.Security.Cryptography;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
#if NET
using System.Threading;
using System.Threading.Tasks;
#endif
using Microsoft.Shared.Diagnostics;
 
#pragma warning disable S109 // Magic numbers should not be used
#pragma warning disable S1121 // Assignments should not be made from within sub-expressions
 
namespace Microsoft.Extensions.AI;
 
public static partial class AIJsonUtilities
{
    /// <summary>
    /// Adds a custom content type to the polymorphic configuration for <see cref="AIContent"/>.
    /// </summary>
    /// <typeparam name="TContent">The custom content type to configure.</typeparam>
    /// <param name="options">The options instance to configure.</param>
    /// <param name="typeDiscriminatorId">The type discriminator id for the content type.</param>
    /// <exception cref="ArgumentNullException"><paramref name="options"/> or <paramref name="typeDiscriminatorId"/> is <see langword="null"/>.</exception>
    /// <exception cref="ArgumentException"><typeparamref name="TContent"/> is a built-in content type.</exception>
    /// <exception cref="InvalidOperationException"><paramref name="options"/> is a read-only instance.</exception>
    /// <remarks>
    /// The <see cref="JsonSerializerOptions.TypeInfoResolver"/> of <paramref name="options"/> (or that of
    /// <see cref="DefaultOptions"/> when none is set) is replaced with a resolver that additionally registers
    /// <typeparamref name="TContent"/> as a derived type of <see cref="AIContent"/>.
    /// </remarks>
    public static void AddAIContentType<TContent>(this JsonSerializerOptions options, string typeDiscriminatorId)
        where TContent : AIContent
    {
        _ = Throw.IfNull(options);
        _ = Throw.IfNull(typeDiscriminatorId);

        AddAIContentTypeCore(options, typeof(TContent), typeDiscriminatorId);
    }

    /// <summary>
    /// Adds a custom content type to the polymorphic configuration for <see cref="AIContent"/>.
    /// </summary>
    /// <param name="options">The options instance to configure.</param>
    /// <param name="contentType">The custom content type to configure.</param>
    /// <param name="typeDiscriminatorId">The type discriminator id for the content type.</param>
    /// <exception cref="ArgumentNullException"><paramref name="options"/>, <paramref name="contentType"/>, or <paramref name="typeDiscriminatorId"/> is <see langword="null"/>.</exception>
    /// <exception cref="ArgumentException"><paramref name="contentType"/> is a built-in content type or does not derive from <see cref="AIContent"/>.</exception>
    /// <exception cref="InvalidOperationException"><paramref name="options"/> is a read-only instance.</exception>
    /// <remarks>
    /// The <see cref="JsonSerializerOptions.TypeInfoResolver"/> of <paramref name="options"/> (or that of
    /// <see cref="DefaultOptions"/> when none is set) is replaced with a resolver that additionally registers
    /// <paramref name="contentType"/> as a derived type of <see cref="AIContent"/>.
    /// </remarks>
    public static void AddAIContentType(this JsonSerializerOptions options, Type contentType, string typeDiscriminatorId)
    {
        _ = Throw.IfNull(options);
        _ = Throw.IfNull(contentType);
        _ = Throw.IfNull(typeDiscriminatorId);

        // The generic overload enforces this via its type constraint; the Type-based overload must check at runtime.
        if (!typeof(AIContent).IsAssignableFrom(contentType))
        {
            Throw.ArgumentException(nameof(contentType), "The content type must derive from AIContent.");
        }

        AddAIContentTypeCore(options, contentType, typeDiscriminatorId);
    }

    /// <summary>Serializes the supplied values and computes a string hash of the resulting JSON.</summary>
    /// <param name="values">The data to serialize and from which a hash should be computed.</param>
    /// <param name="serializerOptions">
    /// The <see cref="JsonSerializerOptions"/> to use for serializing the values.
    /// If <see langword="null"/>, <see cref="DefaultOptions"/> will be used.
    /// A non-<see langword="null"/> instance is made read-only by this call.
    /// </param>
    /// <returns>An uppercase hexadecimal string representing the hash of the serialized values, suitable for use as a cache key.</returns>
    /// <remarks>
    /// The resulting hash may be used for purposes like caching. However, while the generated
    /// hash is deterministic for the same inputs, it is not guaranteed to be stable across releases
    /// of the library, as exactly how the hash is computed may change from version to version.
    /// </remarks>
    public static string HashDataToString(ReadOnlySpan<object?> values, JsonSerializerOptions? serializerOptions = null)
    {
        if (serializerOptions is null)
        {
            serializerOptions = DefaultOptions;
        }
        else
        {
            serializerOptions.MakeReadOnly();
        }

        // Every value is serialized via the contract for 'object', so the runtime type of each value drives serialization.
        JsonTypeInfo jti = serializerOptions.GetTypeInfo(typeof(object));

        // For cases where the hash may be used as a cache key, we rely on collision resistance for security purposes.
        // If a collision occurs, we'd serve a cached LLM response for a potentially unrelated prompt, leading to information
        // disclosure. Use of SHA256 is an implementation detail and can be easily swapped in the future if needed, albeit
        // invalidating any existing cache entries.
#if NET
        IncrementalHashStream? stream = IncrementalHashStream.ThreadStaticInstance;
        if (stream is not null)
        {
            // We need to ensure that the value in ThreadStaticInstance is always ready to use.
            // If we start using an instance, write to it, and then fail, we will have left it
            // in an inconsistent state. So, when renting it, we null it out, and we only put
            // it back upon successful completion after resetting it.
            IncrementalHashStream.ThreadStaticInstance = null;
        }
        else
        {
            stream = new();
        }

        Span<byte> hashData = stackalloc byte[SHA256.HashSizeInBytes];
        try
        {
            foreach (object? value in values)
            {
                JsonSerializer.Serialize(stream, value, jti);
            }

            stream.GetHashAndReset(hashData);
        }
        catch
        {
            // The instance may hold partially-appended data; discard it rather than returning it to the cache.
            stream.Dispose();
            throw;
        }

        // GetHashAndReset left the instance in a reset state, so it can be reused by the next call on this thread.
        IncrementalHashStream.ThreadStaticInstance = stream;

        return Convert.ToHexString(hashData);
#else
        using MemoryStream stream = new();
        foreach (object? value in values)
        {
            JsonSerializer.Serialize(stream, value, jti);
        }

        using var sha256 = SHA256.Create();

        // GetBuffer avoids copying; only the first Length bytes of the underlying buffer are valid data.
        byte[] hashData = sha256.ComputeHash(stream.GetBuffer(), 0, (int)stream.Length);

        return ConvertToHexString(hashData);

        // Produces the same uppercase hex output as Convert.ToHexString, which these targets lack.
        static string ConvertToHexString(ReadOnlySpan<byte> hashData)
        {
            var chars = new char[hashData.Length * 2];
            int destPos = 0;
            foreach (byte b in hashData)
            {
                int div = Math.DivRem(b, 16, out int rem);
                chars[destPos++] = ToHexChar(div);
                chars[destPos++] = ToHexChar(rem);
            }

            Debug.Assert(destPos == chars.Length, "Expected to have filled the entire array.");

            return new string(chars);

            // Maps a nibble value (0-15) to its uppercase hexadecimal digit.
            static char ToHexChar(int i) => (char)(i < 10 ? i + '0' : i - 10 + 'A');
        }
#endif
    }

    /// <summary>
    /// Registers <paramref name="contentType"/> as a polymorphic derived type of <see cref="AIContent"/>
    /// by wrapping the options' type info resolver with a modifier.
    /// </summary>
    /// <param name="options">The options whose <see cref="JsonSerializerOptions.TypeInfoResolver"/> will be replaced.</param>
    /// <param name="contentType">The content type to register; assumed to already be validated as deriving from <see cref="AIContent"/>.</param>
    /// <param name="typeDiscriminatorId">The type discriminator id used for <paramref name="contentType"/>.</param>
    /// <exception cref="ArgumentException"><paramref name="contentType"/> is declared in the same assembly as <see cref="AIContent"/>.</exception>
    private static void AddAIContentTypeCore(JsonSerializerOptions options, Type contentType, string typeDiscriminatorId)
    {
        // Types shipped alongside AIContent are treated as built-in and are not allowed to be re-registered.
        if (contentType.Assembly == typeof(AIContent).Assembly)
        {
            Throw.ArgumentException(nameof(contentType), "Cannot register built-in AI content types.");
        }

        // Assigning TypeInfoResolver throws InvalidOperationException if the options are read-only.
        IJsonTypeInfoResolver resolver = options.TypeInfoResolver ?? DefaultOptions.TypeInfoResolver!;
        options.TypeInfoResolver = resolver.WithAddedModifier(typeInfo =>
        {
            if (typeInfo.Type == typeof(AIContent))
            {
                (typeInfo.PolymorphismOptions ??= new()).DerivedTypes.Add(new(contentType, typeDiscriminatorId));
            }
        });
    }

#if NET
    /// <summary>Provides a stream that writes to an <see cref="IncrementalHash"/>.</summary>
    /// <remarks>The stream is write-only and non-seekable; all bytes written are appended to a SHA-256 hash.</remarks>
    private sealed class IncrementalHashStream : Stream
    {
        /// <summary>A per-thread instance of <see cref="IncrementalHashStream"/>.</summary>
        /// <remarks>An instance stored must be in a reset state ready to be used by another consumer.</remarks>
        [ThreadStatic]
        public static IncrementalHashStream? ThreadStaticInstance;

        /// <summary>The <see cref="IncrementalHash"/> used by this instance.</summary>
        private readonly IncrementalHash _hash = IncrementalHash.CreateHash(HashAlgorithmName.SHA256);

        /// <summary>Writes the hash of all data appended so far into <paramref name="bytes"/> and resets the hash.</summary>
        /// <param name="bytes">The destination for the hash; must be large enough to hold it.</param>
        public void GetHashAndReset(Span<byte> bytes) => _hash.GetHashAndReset(bytes);

        /// <inheritdoc/>
        public override void WriteByte(byte value) => Write(new ReadOnlySpan<byte>(in value));

        /// <inheritdoc/>
        public override void Write(byte[] buffer, int offset, int count) => _hash.AppendData(buffer, offset, count);

        /// <inheritdoc/>
        public override void Write(ReadOnlySpan<byte> buffer) => _hash.AppendData(buffer);

        /// <inheritdoc/>
        /// <remarks>Hashing is synchronous; the write completes before the returned task.</remarks>
        public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
        {
            Write(buffer, offset, count);
            return Task.CompletedTask;
        }

        /// <inheritdoc/>
        /// <remarks>Hashing is synchronous; the write completes before the returned task.</remarks>
        public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
        {
            Write(buffer.Span);
            return ValueTask.CompletedTask;
        }

        /// <inheritdoc/>
        /// <remarks>No-op: data is appended to the hash immediately and nothing is buffered.</remarks>
        public override void Flush()
        {
        }

        /// <inheritdoc/>
        public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask;

        public override bool CanWrite => true;
        public override bool CanRead => false;
        public override bool CanSeek => false;
        public override long Length => throw new NotSupportedException();
        public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
        public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();
        public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
        public override void SetLength(long value) => throw new NotSupportedException();

        /// <inheritdoc/>
        protected override void Dispose(bool disposing)
        {
            // _hash is a managed IDisposable; per the dispose pattern only release it on the explicit-dispose path.
            if (disposing)
            {
                _hash.Dispose();
            }

            base.Dispose(disposing);
        }
    }
#endif
}