// File: Functions\AIFunctionFactoryTest.cs
// Web Access
// Project: src\test\Libraries\Microsoft.Extensions.AI.Tests\Microsoft.Extensions.AI.Tests.csproj (Microsoft.Extensions.AI.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Xunit;
 
#pragma warning disable IDE0004 // Remove Unnecessary Cast
#pragma warning disable S107 // Methods should not have too many parameters
#pragma warning disable S3358 // Ternary operators should not be nested
#pragma warning disable S5034 // "ValueTask" should be consumed correctly
 
namespace Microsoft.Extensions.AI;
 
/// <summary>
/// Tests for <see cref="AIFunctionFactory"/>: argument validation, parameter binding, return-type handling,
/// metadata derivation, <see cref="AIFunctionFactoryOptions"/> propagation, per-invocation target instance
/// creation/disposal, and result marshaling.
/// </summary>
public partial class AIFunctionFactoryTest
{
    [Fact]
    public void InvalidArguments_Throw()
    {
        Assert.Throws<ArgumentNullException>("method", () => AIFunctionFactory.Create(method: null!));
        Assert.Throws<ArgumentNullException>("method", () => AIFunctionFactory.Create(method: null!, target: new object()));
        Assert.Throws<ArgumentNullException>("method", () => AIFunctionFactory.Create(method: null!, target: new object(), name: "myAiFunk"));
        Assert.Throws<ArgumentNullException>("target", () => AIFunctionFactory.Create(typeof(AIFunctionFactoryTest).GetMethod(nameof(InvalidArguments_Throw))!, (object?)null));
        Assert.Throws<ArgumentNullException>("targetType", () => AIFunctionFactory.Create(typeof(AIFunctionFactoryTest).GetMethod(nameof(InvalidArguments_Throw))!, (Type)null!));

        // A method on an open generic type cannot be invoked and is rejected.
        Assert.Throws<ArgumentException>("method", () => AIFunctionFactory.Create(typeof(List<>).GetMethod("Add")!, new List<int>()));
    }

    [Fact]
    public async Task Parameters_MappedByName_Async()
    {
        AIFunction func;

        func = AIFunctionFactory.Create((string a) => a + " " + a);
        AssertExtensions.EqualFunctionCallResults("test test", await func.InvokeAsync(new() { ["a"] = "test" }));

        // Argument order in the dictionary is irrelevant; binding is by parameter name.
        func = AIFunctionFactory.Create((string a, string b) => b + " " + a);
        AssertExtensions.EqualFunctionCallResults("hello world", await func.InvokeAsync(new() { ["b"] = "hello", ["a"] = "world" }));

        func = AIFunctionFactory.Create((int a, long b) => a + b);
        AssertExtensions.EqualFunctionCallResults(3L, await func.InvokeAsync(new() { ["a"] = 1, ["b"] = 2L }));
    }

    [Fact]
    public async Task Parameters_DefaultValuesAreUsedButOverridable_Async()
    {
        AIFunction func = AIFunctionFactory.Create((string a = "test") => a + " " + a);
        AssertExtensions.EqualFunctionCallResults("test test", await func.InvokeAsync());
        AssertExtensions.EqualFunctionCallResults("hello hello", await func.InvokeAsync(new() { ["a"] = "hello" }));
    }

    [Fact]
    public async Task Parameters_MissingRequiredParametersFail_Async()
    {
        // Nullable parameters without a default value are still treated as required.
        AIFunction[] funcs =
        [
            AIFunctionFactory.Create((string theParam) => theParam + " " + theParam),
            AIFunctionFactory.Create((string? theParam) => theParam + " " + theParam),
            AIFunctionFactory.Create((int theParam) => theParam * 2),
            AIFunctionFactory.Create((int? theParam) => theParam * 2),
        ];

        foreach (AIFunction f in funcs)
        {
            Exception e = await Assert.ThrowsAsync<ArgumentException>(() => f.InvokeAsync().AsTask());
            Assert.Contains("'theParam'", e.Message);
        }
    }

    [Fact]
    public async Task Parameters_MappedByType_Async()
    {
        using var cts = new CancellationTokenSource();

        // CancellationToken parameters are bound from the token passed to InvokeAsync (not from the
        // arguments dictionary) and are omitted from the JSON schema.
        foreach (CancellationToken ctArg in new[] { cts.Token, default })
        {
            CancellationToken written = default;
            AIFunction func = AIFunctionFactory.Create((int value1 = 1, string value2 = "2", CancellationToken cancellationToken = default) =>
            {
                written = cancellationToken;
                return 42;
            });
            AssertExtensions.EqualFunctionCallResults(42, await func.InvokeAsync(cancellationToken: ctArg));
            Assert.Equal(ctArg, written);
            Assert.DoesNotContain("cancellationToken", func.JsonSchema.ToString(), StringComparison.OrdinalIgnoreCase);
        }
    }

    [Fact]
    public async Task Returns_AsyncReturnTypesSupported_Async()
    {
        AIFunction func;

        func = AIFunctionFactory.Create(Task<string> (string a) => Task.FromResult(a + " " + a));
        AssertExtensions.EqualFunctionCallResults("test test", await func.InvokeAsync(new() { ["a"] = "test" }));

        func = AIFunctionFactory.Create(ValueTask<string> (string a, string b) => new ValueTask<string>(b + " " + a));
        AssertExtensions.EqualFunctionCallResults("hello world", await func.InvokeAsync(new() { ["b"] = "hello", ["a"] = "world" }));

        // Non-generic Task/ValueTask methods are awaited for their side effect and yield a null result.
        long result = 0;
        func = AIFunctionFactory.Create(async Task (int a, long b) => { result = a + b; await Task.Yield(); });
        AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync(new() { ["a"] = 1, ["b"] = 2L }));
        Assert.Equal(3, result);

        result = 0;
        func = AIFunctionFactory.Create(async ValueTask (int a, long b) => { result = a + b; await Task.Yield(); });
        AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync(new() { ["a"] = 1, ["b"] = 2L }));
        Assert.Equal(3, result);

        // The streamed values are expected to compare equal to the equivalent int[] under the same serializer options.
        func = AIFunctionFactory.Create((int count) => SimpleIAsyncEnumerable(count), serializerOptions: JsonContext.Default.Options);
        AssertExtensions.EqualFunctionCallResults(new int[] { 0, 1, 2, 3, 4 }, await func.InvokeAsync(new() { ["count"] = 5 }), JsonContext.Default.Options);

        static async IAsyncEnumerable<int> SimpleIAsyncEnumerable(int count)
        {
            for (int i = 0; i < count; i++)
            {
                await Task.Yield();
                yield return i;
            }
        }

        // An exception raised while enumerating the result surfaces from InvokeAsync.
        func = AIFunctionFactory.Create(() => (IAsyncEnumerable<int>)new ThrowingAsyncEnumerable(), serializerOptions: JsonContext.Default.Options);
        await Assert.ThrowsAsync<NotImplementedException>(() => func.InvokeAsync().AsTask());
    }

    /// <summary>An async enumerable whose enumerator cannot be obtained; used to test error propagation.</summary>
    private sealed class ThrowingAsyncEnumerable : IAsyncEnumerable<int>
    {
#pragma warning disable S3717 // Track use of "NotImplementedException"
        public IAsyncEnumerator<int> GetAsyncEnumerator(CancellationToken cancellationToken = default) => throw new NotImplementedException();
#pragma warning restore S3717 // Track use of "NotImplementedException"
    }

    [Fact]
    public void Metadata_DerivedFromLambda()
    {
        AIFunction func;

        // Without explicit metadata, the name contains the enclosing method name and the description is empty.
        Func<string> dotnetFunc = () => "test";
        func = AIFunctionFactory.Create(dotnetFunc);
        Assert.Contains("Metadata_DerivedFromLambda", func.Name);
        Assert.Empty(func.Description);
        Assert.Same(dotnetFunc.Method, func.UnderlyingMethod);

        Func<string, string> dotnetFunc2 = a => a + " " + a;
        func = AIFunctionFactory.Create(dotnetFunc2);
        Assert.Contains("Metadata_DerivedFromLambda", func.Name);
        Assert.Empty(func.Description);
        Assert.Same(dotnetFunc2.Method, func.UnderlyingMethod);

        // [Description] on the lambda becomes the function description.
        Func<string, string, string> dotnetFunc3 = [Description("This is a test function")] ([Description("This is A")] a, [Description("This is B")] b) => b + " " + a;
        func = AIFunctionFactory.Create(dotnetFunc3);
        Assert.Contains("Metadata_DerivedFromLambda", func.Name);
        Assert.Equal("This is a test function", func.Description);
        Assert.Same(dotnetFunc3.Method, func.UnderlyingMethod);
        Assert.Collection(func.UnderlyingMethod!.GetParameters(),
            p => Assert.Equal("This is A", p.GetCustomAttribute<DescriptionAttribute>()?.Description),
            p => Assert.Equal("This is B", p.GetCustomAttribute<DescriptionAttribute>()?.Description));
    }

    [Fact]
    public void AIFunctionFactoryCreateOptions_ValuesPropagateToAIFunction()
    {
        IReadOnlyDictionary<string, object?> metadata = new Dictionary<string, object?> { ["a"] = "b" };

        Func<ParameterInfo, AIFunctionFactoryOptions.ParameterBindingOptions> configureParameterBinding = _ => default;

        var options = new AIFunctionFactoryOptions
        {
            Name = "test name",
            Description = "test description",
            AdditionalProperties = metadata,
            ConfigureParameterBinding = configureParameterBinding,
        };

        Assert.Equal("test name", options.Name);
        Assert.Equal("test description", options.Description);
        Assert.Same(metadata, options.AdditionalProperties);
        Assert.Same(configureParameterBinding, options.ConfigureParameterBinding);

        Action dotnetFunc = () => { };
        AIFunction func = AIFunctionFactory.Create(dotnetFunc, options);

        Assert.Equal("test name", func.Name);
        Assert.Equal("test description", func.Description);
        Assert.Same(dotnetFunc.Method, func.UnderlyingMethod);
        Assert.Equal(metadata, func.AdditionalProperties);
    }

    [Fact]
    public void AIFunctionFactoryOptions_DefaultValues()
    {
        AIFunctionFactoryOptions options = new();

        Assert.Null(options.Name);
        Assert.Null(options.Description);
        Assert.Null(options.AdditionalProperties);
        Assert.Null(options.SerializerOptions);
        Assert.Null(options.JsonSchemaCreateOptions);
        Assert.Null(options.ConfigureParameterBinding);
    }

    [Fact]
    public async Task AIFunctionFactoryOptions_SupportsSkippingParameters()
    {
        // ExcludeFromSchema hides the parameter from the schema, but it is still bound from the arguments.
        AIFunction func = AIFunctionFactory.Create(
            (string firstParameter, int secondParameter) => firstParameter + secondParameter,
            new()
            {
                ConfigureParameterBinding = p => p.Name == "firstParameter" ? new() { ExcludeFromSchema = true } : default,
            });

        Assert.DoesNotContain("firstParameter", func.JsonSchema.ToString());
        Assert.Contains("secondParameter", func.JsonSchema.ToString());

        var result = (JsonElement?)await func.InvokeAsync(new()
        {
            ["firstParameter"] = "test",
            ["secondParameter"] = 42
        });
        Assert.NotNull(result);
        Assert.Contains("test42", result.ToString());
    }

    [Fact]
    public async Task AIFunctionArguments_SatisfiesParameters()
    {
        ServiceCollection sc = new();
        using var sp = sc.BuildServiceProvider();

        AIFunctionArguments arguments = new() { ["myInteger"] = 42 };

        // Every IServiceProvider parameter receives arguments.Services and every AIFunctionArguments
        // parameter receives the arguments instance itself, regardless of nullability or defaults.
        AIFunction func = AIFunctionFactory.Create((
            int myInteger,
            IServiceProvider services1,
            IServiceProvider services2,
            AIFunctionArguments arguments1,
            AIFunctionArguments arguments2,
            IServiceProvider? services3,
            AIFunctionArguments? arguments3,
            IServiceProvider? services4 = null,
            AIFunctionArguments? arguments4 = null) =>
        {
            Assert.Same(sp, services1);
            Assert.Same(sp, services2);
            Assert.Same(sp, services3);
            Assert.Same(sp, services4);

            Assert.Same(arguments, arguments1);
            Assert.Same(arguments, arguments2);
            Assert.Same(arguments, arguments3);
            Assert.Same(arguments, arguments4);

            return myInteger;
        });

        Assert.Contains("myInteger", func.JsonSchema.ToString());
        Assert.DoesNotContain("services", func.JsonSchema.ToString());
        Assert.DoesNotContain("arguments", func.JsonSchema.ToString());

        // Required IServiceProvider parameters cannot be satisfied until Services is set.
        await Assert.ThrowsAsync<ArgumentException>("arguments", () => func.InvokeAsync(arguments).AsTask());

        arguments.Services = sp;
        var result = await func.InvokeAsync(arguments);

        Assert.Contains("42", result?.ToString());
    }

    [Fact]
    public async Task AIFunctionArguments_MissingServicesMayBeOptional()
    {
        // An optional IServiceProvider parameter receives null when no Services are supplied, while an
        // AIFunctionArguments parameter is still populated.
        AIFunction func = AIFunctionFactory.Create((
            int? myInteger = null,
            AIFunctionArguments? arguments = null,
            IServiceProvider? services = null) =>
        {
            Assert.NotNull(arguments);
            Assert.Null(services);
            return myInteger;
        });

        Assert.Contains("myInteger", func.JsonSchema.ToString());
        Assert.DoesNotContain("services", func.JsonSchema.ToString());
        Assert.DoesNotContain("arguments", func.JsonSchema.ToString());

        var result = await func.InvokeAsync(new() { ["myInteger"] = 42 });
        Assert.Contains("42", result?.ToString());

        result = await func.InvokeAsync();
        Assert.Equal("", result?.ToString());
    }

    [Fact]
    public async Task Create_NoInstance_UsesActivatorUtilitiesWhenServicesAvailable()
    {
        MyFunctionTypeWithOneArg mft = new(new());
        MyArgumentType mat = new();

        ServiceCollection sc = new();
        sc.AddSingleton(mft);
        sc.AddSingleton(mat);
        using var sp = sc.BuildServiceProvider();

        // MarshalResult passes the raw return value through so the returned tuple can be inspected.
        AIFunction func = AIFunctionFactory.Create(
            typeof(MyFunctionTypeWithOneArg).GetMethod(nameof(MyFunctionTypeWithOneArg.InstanceMethod))!,
            typeof(MyFunctionTypeWithOneArg),
            new()
            {
                MarshalResult = (result, type, cancellationToken) => new ValueTask<object?>(result),
            });

        Assert.NotNull(func);

        // IsType also rejects null, so the identity checks below are never vacuous.
        var result = Assert.IsType<Tuple<MyFunctionTypeWithOneArg, MyArgumentType>>(await func.InvokeAsync(new() { Services = sp }));

        // The target is constructed fresh (not the registered singleton), but its constructor
        // dependency is resolved from the service provider.
        Assert.NotSame(mft, result.Item1);
        Assert.Same(mat, result.Item2);
    }

    [Fact]
    public async Task Create_NoInstance_UsesActivatorWhenServicesUnavailable()
    {
        AIFunction func = AIFunctionFactory.Create(
            typeof(MyFunctionTypeWithNoArgs).GetMethod(nameof(MyFunctionTypeWithNoArgs.InstanceMethod))!,
            typeof(MyFunctionTypeWithNoArgs),
            new()
            {
                MarshalResult = (result, type, cancellationToken) => new ValueTask<object?>(result),
            });

        Assert.NotNull(func);
        Assert.Equal("42", await func.InvokeAsync());
    }

    [Fact]
    public async Task Create_NoInstance_ThrowsWhenCantConstructInstance()
    {
        // The empty provider has no MyArgumentType, so the target's constructor cannot be satisfied.
        using var sp = new ServiceCollection().BuildServiceProvider();

        AIFunction func = AIFunctionFactory.Create(
            typeof(MyFunctionTypeWithOneArg).GetMethod(nameof(MyFunctionTypeWithOneArg.InstanceMethod))!,
            typeof(MyFunctionTypeWithOneArg));

        Assert.NotNull(func);
        await Assert.ThrowsAsync<InvalidOperationException>(async () => await func.InvokeAsync(new() { Services = sp }));
    }

    [Fact]
    public void Create_NoInstance_ThrowsForStaticMethod()
    {
        Assert.Throws<ArgumentException>("method", () => AIFunctionFactory.Create(
            typeof(MyFunctionTypeWithNoArgs).GetMethod(nameof(MyFunctionTypeWithNoArgs.StaticMethod))!,
            typeof(MyFunctionTypeWithNoArgs)));
    }

    [Fact]
    public void Create_NoInstance_ThrowsForMismatchedMethod()
    {
        Assert.Throws<ArgumentException>("targetType", () => AIFunctionFactory.Create(
            typeof(MyFunctionTypeWithNoArgs).GetMethod(nameof(MyFunctionTypeWithNoArgs.InstanceMethod))!,
            typeof(MyFunctionTypeWithOneArg)));
    }

    [Fact]
    public async Task Create_NoInstance_DisposableInstanceCreatedDisposedEachInvocation()
    {
        AIFunction func = AIFunctionFactory.Create(
            typeof(DisposableService).GetMethod(nameof(DisposableService.GetThis))!,
            typeof(DisposableService),
            new()
            {
                MarshalResult = (result, type, cancellationToken) => new ValueTask<object?>(result),
            });

        var d1 = Assert.IsType<DisposableService>(await func.InvokeAsync());
        var d2 = Assert.IsType<DisposableService>(await func.InvokeAsync());
        Assert.NotSame(d1, d2);

        Assert.Equal(1, d1.Disposals);
        Assert.Equal(1, d2.Disposals);
    }

    [Fact]
    public async Task Create_NoInstance_AsyncDisposableInstanceCreatedDisposedEachInvocation()
    {
        AIFunction func = AIFunctionFactory.Create(
            typeof(AsyncDisposableService).GetMethod(nameof(AsyncDisposableService.GetThis))!,
            typeof(AsyncDisposableService),
            new()
            {
                MarshalResult = (result, type, cancellationToken) => new ValueTask<object?>(result),
            });

        var d1 = Assert.IsType<AsyncDisposableService>(await func.InvokeAsync());
        var d2 = Assert.IsType<AsyncDisposableService>(await func.InvokeAsync());
        Assert.NotSame(d1, d2);

        Assert.Equal(1, d1.AsyncDisposals);
        Assert.Equal(1, d2.AsyncDisposals);
    }

    [Fact]
    public async Task Create_NoInstance_DisposableAndAsyncDisposableInstanceCreatedDisposedEachInvocation()
    {
        AIFunction func = AIFunctionFactory.Create(
            typeof(DisposableAndAsyncDisposableService).GetMethod(nameof(DisposableAndAsyncDisposableService.GetThis))!,
            typeof(DisposableAndAsyncDisposableService),
            new()
            {
                MarshalResult = (result, type, cancellationToken) => new ValueTask<object?>(result),
            });

        var d1 = Assert.IsType<DisposableAndAsyncDisposableService>(await func.InvokeAsync());
        var d2 = Assert.IsType<DisposableAndAsyncDisposableService>(await func.InvokeAsync());
        Assert.NotSame(d1, d2);

        // When both interfaces are implemented, only DisposeAsync is expected to be called.
        Assert.Equal(0, d1.Disposals);
        Assert.Equal(0, d2.Disposals);
        Assert.Equal(1, d1.AsyncDisposals);
        Assert.Equal(1, d2.AsyncDisposals);
    }

    [Fact]
    public async Task ConfigureParameterBinding_CanBeUsedToSupportFromKeyedServices()
    {
        MyService service = new(42);

        ServiceCollection sc = new();
        sc.AddKeyedSingleton("key", service);
        using var sp = sc.BuildServiceProvider();

        // Custom binder: resolve [FromKeyedServices] parameters from the keyed service provider and hide them from the schema.
        AIFunction f = AIFunctionFactory.Create(
            ([FromKeyedServices("key")] MyService service, int myInteger) => service.Value + myInteger,
            new AIFunctionFactoryOptions
            {
                ConfigureParameterBinding = p =>
                {
                    if (p.GetCustomAttribute<FromKeyedServicesAttribute>() is { } attr)
                    {
                        return new()
                        {
                            BindParameter = (p, a) =>
                                (a.Services as IKeyedServiceProvider)?.GetKeyedService(p.ParameterType, attr.Key) is { } s ? s :
                                p.HasDefaultValue ? p.DefaultValue :
                                throw new ArgumentException($"Unable to resolve argument for '{p.Name}'."),
                            ExcludeFromSchema = true
                        };
                    }

                    return default;
                },
            });

        Assert.Contains("myInteger", f.JsonSchema.ToString());
        Assert.DoesNotContain("service", f.JsonSchema.ToString());

        Exception e = await Assert.ThrowsAsync<ArgumentException>(() => f.InvokeAsync(new() { ["myInteger"] = 1 }).AsTask());
        Assert.Contains("Unable to resolve", e.Message);

        var result = await f.InvokeAsync(new() { ["myInteger"] = 1, Services = sp });
        Assert.Contains("43", result?.ToString());
    }

    [Fact]
    public async Task ConfigureParameterBinding_CanBeUsedToSupportFromContext()
    {
        MyService service = new(42);

        // Custom binder: resolve MyService parameters from AIFunctionArguments.Context, keyed by type.
        AIFunction f = AIFunctionFactory.Create(
            (MyService service, int myInteger) => service.Value + myInteger,
            new AIFunctionFactoryOptions
            {
                ConfigureParameterBinding = p =>
                {
                    if (p.ParameterType == typeof(MyService))
                    {
                        return new()
                        {
                            BindParameter = (p, a) =>
                                a.Context?.TryGetValue(typeof(MyService), out object? service) is true ? service :
                                p.HasDefaultValue ? p.DefaultValue :
                                throw new ArgumentException($"Unable to resolve argument for '{p.Name}'."),
                            ExcludeFromSchema = true
                        };
                    }

                    return default;
                }
            });

        Assert.Contains("myInteger", f.JsonSchema.ToString());
        Assert.DoesNotContain("service", f.JsonSchema.ToString());

        // No context at all.
        Exception e = await Assert.ThrowsAsync<ArgumentException>(() => f.InvokeAsync(new() { ["myInteger"] = 1 }).AsTask());
        Assert.Contains("Unable to resolve", e.Message);

        // Context present but missing the entry.
        e = await Assert.ThrowsAsync<ArgumentException>(() => f.InvokeAsync(new()
        {
            ["myInteger"] = 1,
            Context = new Dictionary<object, object?>(),
        }).AsTask());
        Assert.Contains("Unable to resolve", e.Message);

        var result = await f.InvokeAsync(new()
        {
            ["myInteger"] = 1,
            Context = new Dictionary<object, object?>
            {
                [typeof(MyService)] = service
            },
        });
        Assert.Contains("43", result?.ToString());
    }

    [Fact]
    public async Task ConfigureParameterBinding_CanBeUsedToOverrideServiceProvider()
    {
        using var sp1 = new ServiceCollection().AddSingleton(new MyService(42)).BuildServiceProvider();
        using var sp2 = new ServiceCollection().AddSingleton(new MyService(43)).BuildServiceProvider();

        // The custom binder takes precedence over the built-in binding from AIFunctionArguments.Services.
        AIFunction f = AIFunctionFactory.Create(
            (IServiceProvider services) => services.GetRequiredService<MyService>().Value,
            new AIFunctionFactoryOptions
            {
                ConfigureParameterBinding = p => new() { BindParameter = (p, a) => sp2 },
            });

        var result = await f.InvokeAsync(new() { Services = sp1 });
        Assert.Contains("43", result?.ToString());
    }

    [Fact]
    public async Task ConfigureParameterBinding_CanBeUsedToOverrideAIFunctionArguments()
    {
        AIFunctionArguments args1 = new() { ["a"] = 42 };
        AIFunctionArguments args2 = new() { ["a"] = 43 };

        // The custom binder takes precedence over the built-in AIFunctionArguments binding.
        AIFunction f = AIFunctionFactory.Create(
            (AIFunctionArguments args) => (int)args["a"]!,
            new AIFunctionFactoryOptions
            {
                ConfigureParameterBinding = p => new() { BindParameter = (p, a) => args2 },
            });

        var result = await f.InvokeAsync(args1);
        Assert.Contains("43", result?.ToString());
    }

    [Fact]
    public async Task MarshalResult_UsedForVoidReturningMethods()
    {
        using CancellationTokenSource cts = new();

        // For void methods MarshalResult sees a null result and a null type, and receives the invocation token.
        AIFunction f = AIFunctionFactory.Create(
            (int i) => { },
            new()
            {
                MarshalResult = async (result, type, cancellationToken) =>
                {
                    await Task.Yield();
                    Assert.Null(result);
                    Assert.Null(type);
                    Assert.Equal(cts.Token, cancellationToken);
                    return "marshalResultInvoked";
                },
            });

        object? result = await f.InvokeAsync(new() { ["i"] = 42 }, cts.Token);
        Assert.Equal("marshalResultInvoked", result);
    }

    [Fact]
    public async Task MarshalResult_UsedForTaskReturningMethods()
    {
        using CancellationTokenSource cts = new();

        AIFunction f = AIFunctionFactory.Create(
            async (int i) => { await Task.Yield(); },
            new()
            {
                MarshalResult = async (result, type, cancellationToken) =>
                {
                    await Task.Yield();
                    Assert.Null(result);
                    Assert.Null(type);
                    Assert.Equal(cts.Token, cancellationToken);
                    return "marshalResultInvoked";
                },
            });

        object? result = await f.InvokeAsync(new() { ["i"] = 42 }, cts.Token);
        Assert.Equal("marshalResultInvoked", result);
    }

    [Fact]
    public async Task MarshalResult_UsedForValueTaskReturningMethods()
    {
        using CancellationTokenSource cts = new();

        AIFunction f = AIFunctionFactory.Create(
            async ValueTask (int i) => { await Task.Yield(); },
            new()
            {
                MarshalResult = async (result, type, cancellationToken) =>
                {
                    await Task.Yield();
                    Assert.Null(result);
                    Assert.Null(type);
                    Assert.Equal(cts.Token, cancellationToken);
                    return "marshalResultInvoked";
                },
            });

        object? result = await f.InvokeAsync(new() { ["i"] = 42 }, cts.Token);
        Assert.Equal("marshalResultInvoked", result);
    }

    [Fact]
    public async Task MarshalResult_UsedForTReturningMethods()
    {
        using CancellationTokenSource cts = new();

        AIFunction f = AIFunctionFactory.Create(
            (int i) => i,
            new()
            {
                MarshalResult = async (result, type, cancellationToken) =>
                {
                    await Task.Yield();
                    Assert.Equal(42, result);
                    Assert.Equal(typeof(int), type);
                    Assert.Equal(cts.Token, cancellationToken);
                    return "marshalResultInvoked";
                },
            });

        object? result = await f.InvokeAsync(new() { ["i"] = 42 }, cts.Token);
        Assert.Equal("marshalResultInvoked", result);
    }

    [Fact]
    public async Task MarshalResult_UsedForTaskTReturningMethods()
    {
        using CancellationTokenSource cts = new();

        // The type passed to MarshalResult is the awaited T, not Task<T>.
        AIFunction f = AIFunctionFactory.Create(
            async (int i) => { await Task.Yield(); return i; },
            new()
            {
                MarshalResult = async (result, type, cancellationToken) =>
                {
                    await Task.Yield();
                    Assert.Equal(42, result);
                    Assert.Equal(typeof(int), type);
                    Assert.Equal(cts.Token, cancellationToken);
                    return "marshalResultInvoked";
                },
            });

        object? result = await f.InvokeAsync(new() { ["i"] = 42 }, cts.Token);
        Assert.Equal("marshalResultInvoked", result);
    }

    [Fact]
    public async Task MarshalResult_UsedForValueTaskTReturningMethods()
    {
        using CancellationTokenSource cts = new();

        AIFunction f = AIFunctionFactory.Create(
            async ValueTask<int> (int i) => { await Task.Yield(); return i; },
            new()
            {
                MarshalResult = async (result, type, cancellationToken) =>
                {
                    await Task.Yield();
                    Assert.Equal(42, result);
                    Assert.Equal(typeof(int), type);
                    Assert.Equal(cts.Token, cancellationToken);
                    return "marshalResultInvoked";
                },
            });

        object? result = await f.InvokeAsync(new() { ["i"] = 42 }, cts.Token);
        Assert.Equal("marshalResultInvoked", result);
    }

    [Fact]
    public async Task MarshalResult_TypeIsDeclaredTypeEvenWhenNullReturned()
    {
        using CancellationTokenSource cts = new();

        AIFunction f = AIFunctionFactory.Create(
            async ValueTask<string?> (int i) => { await Task.Yield(); return null; },
            new()
            {
                MarshalResult = async (result, type, cancellationToken) =>
                {
                    await Task.Yield();
                    Assert.Null(result);
                    Assert.Equal(typeof(string), type);
                    Assert.Equal(cts.Token, cancellationToken);
                    return "marshalResultInvoked";
                },
            });

        object? result = await f.InvokeAsync(new() { ["i"] = 42 }, cts.Token);
        Assert.Equal("marshalResultInvoked", result);
    }

    [Fact]
    public async Task MarshalResult_TypeIsDeclaredTypeEvenWhenDerivedTypeReturned()
    {
        using CancellationTokenSource cts = new();

        // The runtime value is a C, but the reported type is the declared return type B.
        AIFunction f = AIFunctionFactory.Create(
            async ValueTask<B> (int i) => { await Task.Yield(); return new C(); },
            new()
            {
                MarshalResult = async (result, type, cancellationToken) =>
                {
                    await Task.Yield();
                    Assert.IsType<C>(result);
                    Assert.Equal(typeof(B), type);
                    Assert.Equal(cts.Token, cancellationToken);
                    return "marshalResultInvoked";
                },
            });

        object? result = await f.InvokeAsync(new() { ["i"] = 42 }, cts.Token);
        Assert.Equal("marshalResultInvoked", result);
    }

    /// <summary>Simple service wrapping an integer, used by the parameter-binding tests.</summary>
    private sealed class MyService(int value)
    {
        public int Value => value;
    }

    /// <summary>Counts synchronous disposals; <see cref="GetThis"/> exposes the instance created per invocation.</summary>
    private class DisposableService : IDisposable
    {
        public int Disposals { get; private set; }
        public void Dispose() => Disposals++;

        public object GetThis() => this;
    }

    /// <summary>Counts asynchronous disposals; <see cref="GetThis"/> exposes the instance created per invocation.</summary>
    private class AsyncDisposableService : IAsyncDisposable
    {
        public int AsyncDisposals { get; private set; }

        public ValueTask DisposeAsync()
        {
            AsyncDisposals++;
            return default;
        }

        public object GetThis() => this;
    }

    /// <summary>Counts synchronous and asynchronous disposals separately.</summary>
    private class DisposableAndAsyncDisposableService : IDisposable, IAsyncDisposable
    {
        public int Disposals { get; private set; }
        public int AsyncDisposals { get; private set; }

        public void Dispose() => Disposals++;

        public ValueTask DisposeAsync()
        {
            AsyncDisposals++;
            return default;
        }

        public object GetThis() => this;
    }

    /// <summary>Target type with a parameterless constructor, plus a static method that must be rejected.</summary>
    private sealed class MyFunctionTypeWithNoArgs
    {
        // Instance state keeps InstanceMethod a genuine instance method.
        private readonly string _value = "42";

        public static void StaticMethod() => throw new NotSupportedException();

        public string InstanceMethod() => _value;
    }

    /// <summary>Target type whose constructor requires a <see cref="MyArgumentType"/> dependency.</summary>
    private sealed class MyFunctionTypeWithOneArg(MyArgumentType arg)
    {
        public object InstanceMethod() => Tuple.Create(this, arg);
    }

    private sealed class MyArgumentType;

    // Small hierarchy for testing declared vs. runtime return types.
    private class A;
    private class B : A;
    private sealed class C : B;

    /// <summary>Source-generated serialization metadata for the async-enumerable return tests.</summary>
    [JsonSerializable(typeof(IAsyncEnumerable<int>))]
    [JsonSerializable(typeof(int[]))]
    private partial class JsonContext : JsonSerializerContext;
}