File: CausalLMPipelineChatClient.cs
Web Access
Project: src\src\Microsoft.ML.GenAI.Core\Microsoft.ML.GenAI.Core.csproj (Microsoft.ML.GenAI.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.ML.Tokenizers;
using static TorchSharp.torch;
 
namespace Microsoft.ML.GenAI.Core;
 
/// <summary>
/// Base <see cref="IChatClient"/> implementation that answers chat requests by rendering the
/// conversation into a prompt with an <see cref="IMEAIChatTemplateBuilder"/> and passing that
/// prompt to an <see cref="ICausalLMPipeline{TTokenizer, TCausalLMModel}"/>.
/// </summary>
/// <typeparam name="TTokenizer">Tokenizer type used by the pipeline.</typeparam>
/// <typeparam name="TCausalLMModel">Causal language model type used by the pipeline.</typeparam>
public abstract class CausalLMPipelineChatClient<TTokenizer, TCausalLMModel> : IChatClient
    where TTokenizer : Tokenizer
    where TCausalLMModel : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
    // Defaults applied when the caller's ChatOptions does not specify a value.
    private const int DefaultMaxOutputTokens = 1024;
    private const float DefaultTemperature = 0.7f;

    private readonly ICausalLMPipeline<TTokenizer, TCausalLMModel> _pipeline;
    private readonly IMEAIChatTemplateBuilder _chatTemplateBuilder;

    /// <summary>
    /// Initializes the client with the pipeline used for generation and the builder used to
    /// turn chat messages into a prompt.
    /// </summary>
    /// <param name="pipeline">Pipeline that generates text from a prompt.</param>
    /// <param name="chatTemplateBuilder">Builder that converts chat messages and options into a prompt.</param>
    /// <param name="metadata">
    /// Optional metadata describing this client. When <see langword="null"/>, metadata is created
    /// whose provider name is this class name followed by the generic argument type names and whose
    /// model id is the name of <typeparamref name="TCausalLMModel"/>.
    /// </param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="pipeline"/> or <paramref name="chatTemplateBuilder"/> is <see langword="null"/>.
    /// </exception>
    public CausalLMPipelineChatClient(
        ICausalLMPipeline<TTokenizer, TCausalLMModel> pipeline,
        IMEAIChatTemplateBuilder chatTemplateBuilder,
        ChatClientMetadata? metadata = null)
    {
        _pipeline = pipeline ?? throw new ArgumentNullException(nameof(pipeline));
        _chatTemplateBuilder = chatTemplateBuilder ?? throw new ArgumentNullException(nameof(chatTemplateBuilder));

        var classNameWithType = $"{nameof(CausalLMPipelineChatClient<TTokenizer, TCausalLMModel>)}<{typeof(TTokenizer).Name}, {typeof(TCausalLMModel).Name}>";

        // Honor caller-supplied metadata; only fall back to the generated default when none is given.
        Metadata = metadata ?? new ChatClientMetadata(providerName: classNameWithType, modelId: typeof(TCausalLMModel).Name);
    }

    /// <summary>Gets the metadata describing this chat client.</summary>
    public ChatClientMetadata Metadata { get; }

    /// <summary>
    /// Builds a prompt from <paramref name="chatMessages"/> and returns the pipeline's generated
    /// text as a single assistant message.
    /// </summary>
    /// <param name="chatMessages">Conversation to render into the prompt.</param>
    /// <param name="options">
    /// Optional settings. <c>MaxOutputTokens</c> defaults to 1024, <c>Temperature</c> to 0.7 and
    /// <c>StopSequences</c> to an empty list when not provided.
    /// </param>
    /// <param name="cancellationToken">
    /// Token checked before generation starts; if already canceled, a canceled task is returned.
    /// </param>
    /// <returns>A completed task holding the chat completion with finish reason <c>Stop</c>.</returns>
    /// <exception cref="InvalidOperationException">The pipeline returned <see langword="null"/>.</exception>
    public virtual Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
    {
        // Generation itself runs synchronously below, so the token can only be observed up front.
        if (cancellationToken.IsCancellationRequested)
        {
            return Task.FromCanceled<ChatCompletion>(cancellationToken);
        }

        var prompt = _chatTemplateBuilder.BuildPrompt(chatMessages, options);
        var stopSequences = options?.StopSequences ?? Array.Empty<string>();

        var output = _pipeline.Generate(
            prompt,
            maxLen: options?.MaxOutputTokens ?? DefaultMaxOutputTokens,
            temperature: options?.Temperature ?? DefaultTemperature,
            stopSequences: stopSequences.ToArray()) ?? throw new InvalidOperationException("Failed to generate a reply.");

        var chatMessage = new ChatMessage(ChatRole.Assistant, output);
        return Task.FromResult(new ChatCompletion([chatMessage])
        {
            CreatedAt = DateTime.UtcNow,
            FinishReason = ChatFinishReason.Stop,
        });
    }

    /// <summary>
    /// Builds a prompt from <paramref name="chatMessages"/> and yields one assistant update for
    /// each piece of text produced by the pipeline's streaming generation.
    /// </summary>
    /// <param name="chatMessages">Conversation to render into the prompt.</param>
    /// <param name="options">
    /// Optional settings. <c>MaxOutputTokens</c> defaults to 1024, <c>Temperature</c> to 0.7 and
    /// <c>StopSequences</c> to an empty list when not provided.
    /// </param>
    /// <param name="cancellationToken">
    /// Token checked before generation starts and before each update is yielded.
    /// </param>
    /// <returns>An asynchronous sequence of streaming updates.</returns>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was canceled.</exception>
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
    public virtual async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
        IList<ChatMessage> chatMessages,
        ChatOptions? options = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        cancellationToken.ThrowIfCancellationRequested();

        var prompt = _chatTemplateBuilder.BuildPrompt(chatMessages, options);
        var stopSequences = options?.StopSequences ?? Array.Empty<string>();

        foreach (var output in _pipeline.GenerateStreaming(
            prompt,
            maxLen: options?.MaxOutputTokens ?? DefaultMaxOutputTokens,
            temperature: options?.Temperature ?? DefaultTemperature,
            stopSequences: stopSequences.ToArray()))
        {
            // Stop producing further updates as soon as the consumer cancels.
            cancellationToken.ThrowIfCancellationRequested();

            yield return new StreamingChatCompletionUpdate
            {
                Role = ChatRole.Assistant,
                Text = output,
                CreatedAt = DateTime.UtcNow,
            };
        }
    }

    /// <summary>
    /// Does nothing in this base class; derived classes may override it to release resources.
    /// </summary>
    public virtual void Dispose()
    {
    }

    /// <summary>
    /// Returns this instance when <paramref name="serviceKey"/> is <see langword="null"/> and
    /// <paramref name="serviceType"/> is assignable from this instance's runtime type; otherwise
    /// returns <see langword="null"/>.
    /// </summary>
    /// <param name="serviceType">Type of service requested.</param>
    /// <param name="serviceKey">Optional key identifying the service; only <see langword="null"/> is matched.</param>
    /// <returns>This instance, or <see langword="null"/> when the request does not match.</returns>
    public virtual object? GetService(Type serviceType, object? serviceKey = null) =>
        serviceKey is null && serviceType is not null && serviceType.IsAssignableFrom(GetType()) ? this :
        null;
}