File: VerbatimMultiPartHttpHandler.cs
Web Access
Project: src\test\Libraries\Microsoft.Extensions.AI.Integration.Tests\Microsoft.Extensions.AI.Integration.Tests.csproj (Microsoft.Extensions.AI.Integration.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Text;
using System.Text.Json;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
 
#pragma warning disable S3996 // URI properties should not be strings
 
/// <summary>
/// An <see cref="HttpMessageHandler"/> that checks the multi-part request body as a root
/// JSON structure of properties and sends back an expected JSON response.
/// </summary>
/// <remarks>
/// The order of the properties does not affect the comparison.
/// <para>
/// An expected input of <c>{ "name": "something" }</c> will Assert for a multipart body that has
/// a <b>name</b> field with a value of <c>something</c>.
/// </para>
/// <para>
/// An expected input of <c>{ "multiple[]": ["one","two"] }</c> will Assert for a multipart body that has
/// two <b>multiple[]</b> fields each having "one" and "two" value respectively.
/// </para>
/// </remarks>
/// <param name="expectedInput">
/// A JSON string representing the expected structure and values of the multipart request body to be verified.
/// For example, <c>{ "name": "something" }</c> or <c>{ "multiple[]": ["one","two"] }</c>.
/// </param>
/// <param name="sentJsonOutput">
/// A JSON string that will be returned as the response body when the request matches the expected input.
/// </param>
public class VerbatimMultiPartHttpHandler(string expectedInput, string sentJsonOutput) : HttpClientHandler
{
    public string? ExpectedRequestUriContains { get; init; }
 
    protected override async Task<HttpResponseMessage> SendAsync(
        HttpRequestMessage request,
        CancellationToken cancellationToken)
    {
        Assert.NotNull(request.Content);
        Assert.NotNull(request.Content.Headers.ContentType);
        Assert.Equal("multipart/form-data", request.Content.Headers.ContentType.MediaType);
 
        Assert.NotNull(request.RequestUri);
        if (!string.IsNullOrEmpty(ExpectedRequestUriContains))
        {
            Assert.Contains(ExpectedRequestUriContains!, request.RequestUri!.ToString());
        }
 
        Dictionary<string, object> parameters = [];
 
        // Extract the boundary
        string? boundary = request.Content.Headers.ContentType.Parameters
            .FirstOrDefault(p => p.Name == "boundary")?.Value;
 
        if (string.IsNullOrEmpty(boundary))
        {
            throw new InvalidOperationException("Boundary not found.");
        }
 
        string fullBoundary = $"--{boundary!.Trim('"')}";
 
        // Read the entire body into memory (for simplicity; stream in production for large data)
#if NET
        byte[] bodyBytes = await request.Content.ReadAsByteArrayAsync(cancellationToken);
#else
        byte[] bodyBytes = await request.Content.ReadAsByteArrayAsync();
#endif
        using var stream = new MemoryStream(bodyBytes);
        using var reader = new StreamReader(stream, Encoding.UTF8);
#if NET
 
        string bodyText = await reader.ReadToEndAsync(cancellationToken);
#else
        string bodyText = await reader.ReadToEndAsync();
#endif
 
        // Make it legible for debugging and splitting
        bodyText = RemoveSpecialCharacters(bodyText);
 
        string[] parts = bodyText.Split(new string[] { fullBoundary }, StringSplitOptions.None);
 
        foreach (string part in parts)
        {
            if (part.Trim() == "--")
            {
                continue; // End boundary
            }
 
            // Parse headers and body
            int headerEnd = part.IndexOf("\r\n\r\n");
            if (headerEnd < 0)
            {
                continue;
            }
 
            string headers = part.Substring(0, headerEnd).Trim();
            string rawValue = part.Substring(headerEnd + 4).TrimEnd('\r', '\n');
 
            // Get the parameter name and value
            if (headers.Contains("name="))
            {
                // Text field
                string name = ExtractNameFromHeaders(headers);
 
                // Skip file fields
                if (!name.StartsWith("file"))
                {
                    if (parameters.ContainsKey(name))
                    {
                        ((List<JsonElement>)parameters[name]).Add(ParseContentToJsonElement(rawValue));
                    }
                    else
                    {
                        parameters.Add(name, new List<JsonElement> { ParseContentToJsonElement(rawValue) });
                    }
                }
            }
        }
 
        // Transform one value lists into single values
        foreach (var key in parameters.Keys.ToList())
        {
            if (parameters[key] is List<JsonElement> list && list.Count == 1)
            {
                parameters[key] = list[0];
            }
        }
 
        var jsonParameters = JsonSerializer.Serialize(parameters);
        Assert.NotNull(jsonParameters);
 
        AssertJsonEquals(expectedInput, jsonParameters);
 
        return new() { Content = new StringContent(sentJsonOutput, Encoding.UTF8, "application/json") };
    }
 
    private static string RemoveSpecialCharacters(string input)
    {
        return Regex.Replace(input, @"[^a-zA-Z0-9_ .,!?\r\n""=;\//\[\]-]", "");
    }
 
    private static JsonElement ParseContentToJsonElement(string content)
    {
        // Try parsing as a number
        if (int.TryParse(content, out int intValue))
        {
            return JsonSerializer.SerializeToElement(intValue);
        }
 
        if (double.TryParse(content, out double doubleValue))
        {
            return JsonSerializer.SerializeToElement(doubleValue);
        }
 
        // Try parsing as a boolean
        if (bool.TryParse(content, out bool boolValue))
        {
            return JsonSerializer.SerializeToElement(boolValue);
        }
 
        // Default to string
        return JsonSerializer.SerializeToElement(content);
    }
 
    private static string ExtractNameFromHeaders(string headers)
    {
        const string NamePrefix = "name=";
        int start = headers.IndexOf(NamePrefix) + NamePrefix.Length;
        int end = headers.IndexOf(";", start);
 
        if (end == -1)
        {
            end = headers.Length;
        }
 
        return headers.Substring(start, end - start).Trim('"');
    }
 
    public static string? RemoveWhiteSpace(string? text) =>
        text is null ? null :
        Regex.Replace(text, @"\s*", string.Empty);
 
    private static Dictionary<char, int>? GetCharacterFrequencies(string text)
        => RemoveWhiteSpace(text)?.GroupBy(c => c)
               .ToDictionary(g => g.Key, g => g.Count());
 
    private static void AssertJsonEquals(string expected, string actual)
    {
        var expectedFrequencies = GetCharacterFrequencies(expected);
        var actualFrequencies = GetCharacterFrequencies(actual);
 
        Assert.NotNull(expectedFrequencies);
        Assert.NotNull(actualFrequencies);
 
        foreach (var kvp in expectedFrequencies)
        {
            if (!actualFrequencies.ContainsKey(kvp.Key) || kvp.Value != actualFrequencies[kvp.Key])
            {
                Assert.Fail($"Expected: {expected}, Actual: {actual}");
            }
 
            // Ensure the frequencies are equal during the test
            Assert.Equal(kvp.Value, actualFrequencies[kvp.Key]);
        }
    }
}