// File: LlamaTextCompletionService.cs
// Web Access
// Project: src\src\Microsoft.ML.GenAI.LLaMA\Microsoft.ML.GenAI.LLaMA.csproj (Microsoft.ML.GenAI.LLaMA)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.Tokenizers;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.TextGeneration;
 
namespace Microsoft.ML.GenAI.LLaMA;
 
/// <summary>
/// Semantic Kernel <see cref="ITextGenerationService"/> that forwards prompts to a
/// LLaMA causal language model pipeline.
/// </summary>
/// <remarks>
/// Generation options are read from <see cref="PromptExecutionSettings.ExtensionData"/>:
/// <c>temperature</c> (<see cref="float"/>), <c>max_token</c> (<see cref="int"/>),
/// <c>top_p</c> (<see cref="float"/>) and <c>stop_token_sequence</c> (any
/// <see cref="IEnumerable{T}"/> of <see cref="string"/>, e.g. <c>string[]</c> or <c>List&lt;string&gt;</c>).
/// A missing key, or a value of a different runtime type, falls back to the built-in default.
/// The <c>&lt;|eot_id|&gt;</c> marker is always passed to the pipeline as an additional stop sequence.
/// </remarks>
public class LlamaTextCompletionService : ITextGenerationService
{
    private const string TemperatureKey = "temperature";
    private const string MaxTokenKey = "max_token";
    private const string StopTokenSequenceKey = "stop_token_sequence";
    private const string TopPKey = "top_p";

    // Marker that is always appended to the stop sequences handed to the pipeline.
    private const string EndOfTurnToken = "<|eot_id|>";

    private const float DefaultTemperature = 0.7f;
    private const float DefaultTopP = 0.9f;

    // The streaming and non-streaming entry points have always used different token budgets;
    // the values are kept as-is so existing callers see no change in output length.
    private const int DefaultStreamingMaxToken = 100;
    private const int DefaultMaxToken = 512;

    private readonly ICausalLMPipeline<Tokenizer, LlamaForCausalLM> _pipeline;

    /// <summary>
    /// Creates a service that generates text with the given pipeline.
    /// </summary>
    /// <param name="pipeline">Pipeline used for every generation request.</param>
    /// <exception cref="ArgumentNullException"><paramref name="pipeline"/> is <see langword="null"/>.</exception>
    public LlamaTextCompletionService(ICausalLMPipeline<Tokenizer, LlamaForCausalLM> pipeline)
    {
        _pipeline = pipeline ?? throw new ArgumentNullException(nameof(pipeline));
    }

    /// <summary>
    /// Names of the execution settings this service understands. Values are always <see langword="null"/>;
    /// a new dictionary is returned on every access.
    /// </summary>
    public IReadOnlyDictionary<string, object?> Attributes => new Dictionary<string, object?>()
    {
        { TemperatureKey, null },
        { MaxTokenKey, null },
        { StopTokenSequenceKey, null },
        { TopPKey, null },
    };

    /// <summary>
    /// Streams the completion for <paramref name="prompt"/>, yielding one
    /// <see cref="StreamingTextContent"/> per chunk produced by the pipeline.
    /// </summary>
    /// <param name="prompt">Text to complete.</param>
    /// <param name="executionSettings">Optional generation settings; see the type remarks for the recognized keys.</param>
    /// <param name="kernel">Not used by this implementation.</param>
    /// <param name="cancellationToken">Checked before each chunk is yielded.</param>
    /// <returns>An asynchronous sequence of generated text chunks.</returns>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was cancelled during enumeration.</exception>
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
    public async IAsyncEnumerable<StreamingTextContent> GetStreamingTextContentsAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
        string prompt,
        PromptExecutionSettings? executionSettings = null,
        Kernel? kernel = null,
        [EnumeratorCancellation]
        CancellationToken cancellationToken = default)
    {
        var temperature = GetSettingOrDefault(executionSettings, TemperatureKey, DefaultTemperature);
        var maxToken = GetSettingOrDefault(executionSettings, MaxTokenKey, DefaultStreamingMaxToken);
        var topP = GetSettingOrDefault(executionSettings, TopPKey, DefaultTopP);
        var stopTokenSequence = BuildStopSequences(executionSettings);

        foreach (var item in _pipeline.GenerateStreaming(
            prompt,
            maxToken,
            temperature,
            topP,
            stopTokenSequence))
        {
            // The pipeline call itself takes no token, so cancellation is observed between chunks.
            cancellationToken.ThrowIfCancellationRequested();
            yield return new StreamingTextContent(item);
        }
    }

    /// <summary>
    /// Generates the full completion for <paramref name="prompt"/> in a single, synchronous pipeline call.
    /// </summary>
    /// <param name="prompt">Text to complete.</param>
    /// <param name="executionSettings">Optional generation settings; see the type remarks for the recognized keys.</param>
    /// <param name="kernel">Not used by this implementation.</param>
    /// <param name="cancellationToken">Checked once before generation starts.</param>
    /// <returns>A completed task holding a single <see cref="TextContent"/> with the generated text,
    /// or a cancelled task when <paramref name="cancellationToken"/> is already cancelled.</returns>
    public Task<IReadOnlyList<TextContent>> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
    {
        if (cancellationToken.IsCancellationRequested)
        {
            return Task.FromCanceled<IReadOnlyList<TextContent>>(cancellationToken);
        }

        var temperature = GetSettingOrDefault(executionSettings, TemperatureKey, DefaultTemperature);
        var maxToken = GetSettingOrDefault(executionSettings, MaxTokenKey, DefaultMaxToken);
        var topP = GetSettingOrDefault(executionSettings, TopPKey, DefaultTopP);
        var stopTokenSequence = BuildStopSequences(executionSettings);

        var response = _pipeline.Generate(
            prompt,
            maxToken,
            temperature,
            stopSequences: stopTokenSequence,
            topP: topP);

        return Task.FromResult<IReadOnlyList<TextContent>>([new TextContent(response)]);
    }

    /// <summary>
    /// Reads a value-typed setting from the extension data, returning <paramref name="defaultValue"/>
    /// when the settings, the dictionary or the key is absent, or when the stored value is not a <typeparamref name="T"/>.
    /// </summary>
    private static T GetSettingOrDefault<T>(PromptExecutionSettings? executionSettings, string key, T defaultValue)
        where T : struct
    {
        // TryGetValue instead of the indexer: the indexer throws KeyNotFoundException for a missing key.
        if (executionSettings?.ExtensionData is { } extensionData
            && extensionData.TryGetValue(key, out var raw)
            && raw is T typed)
        {
            return typed;
        }

        return defaultValue;
    }

    /// <summary>
    /// Builds the stop sequences for the pipeline: the caller-supplied ones (if any) followed by
    /// the end-of-turn marker. Always returns a new array so the caller's collection is never modified.
    /// </summary>
    private static string[] BuildStopSequences(PromptExecutionSettings? executionSettings)
    {
        if (executionSettings?.ExtensionData is { } extensionData
            && extensionData.TryGetValue(StopTokenSequenceKey, out var raw)
            && raw is IEnumerable<string> configured)
        {
            return [.. configured, EndOfTurnToken];
        }

        return [EndOfTurnToken];
    }
}