// File: SemanticKernelTests.cs
// Project: src\test\Microsoft.ML.GenAI.Phi.Tests\Microsoft.ML.GenAI.Phi.Tests.csproj (Microsoft.ML.GenAI.Phi.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using FluentAssertions;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Phi.Extension;
using Microsoft.ML.Tokenizers;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Moq;
using Xunit;
 
namespace Microsoft.ML.GenAI.Phi.Tests;
 
/// <summary>
/// Tests that the GenAI Phi-3 kernel-builder extensions register working
/// <see cref="IChatCompletionService"/> and text-generation services backed by a
/// mocked <see cref="ICausalLMPipeline{TTokenizer, TModel}"/> (no real model is loaded).
/// </summary>
public class SemanticKernelTests
{
    [Fact]
    public async Task ItAddPhi3CausalLMChatCompletionServiceTestAsync()
    {
        // Arrange: mock the pipeline's Generate API so the test needs no real weights.
        var pipeline = Mock.Of<ICausalLMPipeline<Tokenizer, Phi3ForCasualLM>>();
        Mock.Get(pipeline).Setup(p => p.Generate(
            It.IsAny<string>(), // prompt
            It.IsAny<int>(),    // max length
            It.IsAny<float>(),  // temperature
            It.IsAny<float>(),  // top_p
            It.IsAny<string[]>()))   // stop sequence
            .Callback((string prompt, int maxLen, float temperature, float topP, string[] stopSequences) =>
            {
                // The chat history must be rendered with the Phi-3 chat template.
                prompt.Should().Be("<|system|>\nyou are a helpful assistant<|end|>\n<|user|>\nhey<|end|>\n<|assistant|>");
            })
            .Returns((string prompt, int maxLen, float temperature, float topP, string[] stopSequences) => "hello");

        var kernel = Kernel.CreateBuilder()
            .AddGenAIChatCompletion(pipeline)
            .Build();

        // Act: resolve the registered chat service and run a two-message history through it.
        var chatService = kernel.Services.GetRequiredService<IChatCompletionService>();
        var chatHistory = new ChatHistory();
        chatHistory.AddSystemMessage("you are a helpful assistant");
        chatHistory.AddUserMessage("hey");
        var responses = await chatService.GetChatMessageContentsAsync(chatHistory);

        // Assert: exactly one assistant message whose single item carries the mocked text.
        responses.Should().HaveCount(1);
        var response = responses.First();
        response.Role.Should().Be(AuthorRole.Assistant);
        response.Items.Should().HaveCount(1);
        var textContent = response.Items.First() as TextContent;
        textContent.Should().NotBeNull();
        textContent!.Text.Should().Be("hello");

        // The prompt check in the Callback above only runs if the mock fires;
        // verify the pipeline was actually invoked exactly once.
        Mock.Get(pipeline).Verify(p => p.Generate(
            It.IsAny<string>(),
            It.IsAny<int>(),
            It.IsAny<float>(),
            It.IsAny<float>(),
            It.IsAny<string[]>()), Times.Once);
    }

    [Fact]
    public async Task ItAddPhi3CausalLMTextGenerationServiceTestAsync()
    {
        // Arrange: mock the pipeline's Generate API so the test needs no real weights.
        var pipeline = Mock.Of<ICausalLMPipeline<Tokenizer, Phi3ForCasualLM>>();
        Mock.Get(pipeline).Setup(p => p.Generate(
            It.IsAny<string>(), // prompt
            It.IsAny<int>(),    // max length
            It.IsAny<float>(),  // temperature
            It.IsAny<float>(),  // top_p
            It.IsAny<string[]>()))   // stop sequence
            .Callback((string prompt, int maxLen, float temperature, float topP, string[] stopSequences) =>
            {
                // Text generation should pass the raw prompt through unmodified.
                prompt.Should().Be("test");
            })
            .Returns((string prompt, int maxLen, float temperature, float topP, string[] stopSequences) => "hello");

        var kernel = Kernel.CreateBuilder()
            .AddGenAITextGeneration(pipeline)
            .Build();

        // Act
        var response = await kernel.InvokePromptAsync("test");

        // Assert: previously the response was discarded, so the test passed even
        // if the service returned the wrong text or was never called at all.
        response.GetValue<string>().Should().Be("hello");
        Mock.Get(pipeline).Verify(p => p.Generate(
            It.IsAny<string>(),
            It.IsAny<int>(),
            It.IsAny<float>(),
            It.IsAny<float>(),
            It.IsAny<string[]>()), Times.Once);
    }
}