|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.Tokenizers;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.TextGeneration;
namespace Microsoft.ML.GenAI.LLaMA;
/// <summary>
/// Semantic Kernel <see cref="ITextGenerationService"/> backed by a LLaMA causal-LM pipeline.
/// </summary>
/// <remarks>
/// Generation options are read from <see cref="PromptExecutionSettings.ExtensionData"/> using the keys
/// listed in <see cref="Attributes"/>. A missing key, or a value of an unsupported type, falls back to a
/// default. The end-of-turn token <c>&lt;|eot_id|&gt;</c> is always included in the stop sequences.
/// </remarks>
public class LlamaTextCompletionService : ITextGenerationService
{
    private const string TemperatureKey = "temperature";
    private const string MaxTokenKey = "max_token";
    private const string StopTokenSequenceKey = "stop_token_sequence";
    private const string TopPKey = "top_p";

    // Appended to every stop list so generation ends at the model's end-of-turn marker
    // even when the caller supplies no stop sequences of its own.
    private const string EndOfTurnToken = "<|eot_id|>";

    private const float DefaultTemperature = 0.7f;
    private const float DefaultTopP = 0.9f;

    // The streaming and non-streaming entry points have always used different token limits;
    // both values are preserved so existing callers see unchanged behavior.
    private const int DefaultStreamingMaxToken = 100;
    private const int DefaultMaxToken = 512;

    private readonly ICausalLMPipeline<Tokenizer, LlamaForCausalLM> _pipeline;

    /// <summary>Creates a text-generation service over the given pipeline.</summary>
    /// <param name="pipeline">The pipeline used to run generation.</param>
    /// <exception cref="ArgumentNullException"><paramref name="pipeline"/> is <see langword="null"/>.</exception>
    public LlamaTextCompletionService(ICausalLMPipeline<Tokenizer, LlamaForCausalLM> pipeline)
    {
        _pipeline = pipeline ?? throw new ArgumentNullException(nameof(pipeline));
    }

    /// <summary>
    /// The <see cref="PromptExecutionSettings.ExtensionData"/> keys this service understands.
    /// A fresh dictionary is returned on every access, so callers cannot alter shared state.
    /// </summary>
    public IReadOnlyDictionary<string, object?> Attributes => new Dictionary<string, object?>()
    {
        { TemperatureKey, null },
        { MaxTokenKey, null },
        { StopTokenSequenceKey, null },
        { TopPKey, null },
    };

#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
    /// <summary>Streams generated text for <paramref name="prompt"/>, one pipeline chunk at a time.</summary>
    /// <param name="prompt">The prompt to complete.</param>
    /// <param name="executionSettings">Optional settings; see <see cref="Attributes"/> for recognized keys.</param>
    /// <param name="kernel">Unused.</param>
    /// <param name="cancellationToken">Checked before generation starts and between streamed chunks.</param>
    /// <returns>An async stream of generated text chunks.</returns>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was canceled.</exception>
    public async IAsyncEnumerable<StreamingTextContent> GetStreamingTextContentsAsync(
        string prompt,
        PromptExecutionSettings? executionSettings = null,
        Kernel? kernel = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        cancellationToken.ThrowIfCancellationRequested();

        var temperature = GetFloatSetting(executionSettings, TemperatureKey, DefaultTemperature);
        var maxToken = GetIntSetting(executionSettings, MaxTokenKey, DefaultStreamingMaxToken);
        var topP = GetFloatSetting(executionSettings, TopPKey, DefaultTopP);
        var stopTokenSequence = GetStopSequences(executionSettings);

        foreach (var item in _pipeline.GenerateStreaming(
            prompt,
            maxToken,
            temperature,
            topP,
            stopTokenSequence))
        {
            // The pipeline enumerates synchronously, so the gap between chunks is the only
            // point at which cancellation can be observed.
            cancellationToken.ThrowIfCancellationRequested();
            yield return new StreamingTextContent(item);
        }
    }
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously

    /// <summary>Generates the full completion for <paramref name="prompt"/> in one call.</summary>
    /// <param name="prompt">The prompt to complete.</param>
    /// <param name="executionSettings">Optional settings; see <see cref="Attributes"/> for recognized keys.</param>
    /// <param name="kernel">Unused.</param>
    /// <param name="cancellationToken">If already canceled, a canceled task is returned without generating.</param>
    /// <returns>A single-element list containing the generated text.</returns>
    public Task<IReadOnlyList<TextContent>> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
    {
        if (cancellationToken.IsCancellationRequested)
        {
            return Task.FromCanceled<IReadOnlyList<TextContent>>(cancellationToken);
        }

        var temperature = GetFloatSetting(executionSettings, TemperatureKey, DefaultTemperature);
        var maxToken = GetIntSetting(executionSettings, MaxTokenKey, DefaultMaxToken);
        var topP = GetFloatSetting(executionSettings, TopPKey, DefaultTopP);
        var stopTokenSequence = GetStopSequences(executionSettings);

        var response = _pipeline.Generate(
            prompt,
            maxToken,
            temperature,
            stopSequences: stopTokenSequence,
            topP: topP);
        return Task.FromResult<IReadOnlyList<TextContent>>([new TextContent(response)]);
    }

    /// <summary>
    /// Looks up <paramref name="key"/> in the settings' extension data without throwing.
    /// Returns <see langword="null"/> when the settings, the extension data, or the key are absent.
    /// </summary>
    private static object? GetExtensionValue(PromptExecutionSettings? settings, string key)
        => settings?.ExtensionData is { } data && data.TryGetValue(key, out var value) ? value : null;

    /// <summary>
    /// Reads a floating-point setting, accepting boxed <see cref="float"/>, <see cref="double"/>,
    /// <see cref="int"/> or <see cref="long"/>. Any other value, or a missing key, yields <paramref name="defaultValue"/>.
    /// </summary>
    private static float GetFloatSetting(PromptExecutionSettings? settings, string key, float defaultValue)
        => GetExtensionValue(settings, key) switch
        {
            float f => f,
            double d => (float)d,
            int i => i,
            long l => l,
            _ => defaultValue,
        };

    /// <summary>
    /// Reads an integer setting, accepting boxed <see cref="int"/> or an in-range <see cref="long"/>.
    /// Any other value, or a missing key, yields <paramref name="defaultValue"/>.
    /// </summary>
    private static int GetIntSetting(PromptExecutionSettings? settings, string key, int defaultValue)
        => GetExtensionValue(settings, key) switch
        {
            int i => i,
            long l when l is >= int.MinValue and <= int.MaxValue => (int)l,
            _ => defaultValue,
        };

    /// <summary>
    /// Builds the stop-sequence array from the caller's value (a single string, a string array, or
    /// any <see cref="IEnumerable{T}"/> of strings), skipping null/empty entries, and ensures
    /// <see cref="EndOfTurnToken"/> is present exactly once at most as an addition.
    /// </summary>
    /// <remarks>
    /// A new array is always built, so the collection stored in the settings is never modified
    /// (mutating it would make the end-of-turn token accumulate across calls).
    /// </remarks>
    private static string[] GetStopSequences(PromptExecutionSettings? settings)
    {
        var stops = new List<string>();
        switch (GetExtensionValue(settings, StopTokenSequenceKey))
        {
            case string single when single.Length > 0:
                stops.Add(single);
                break;
            case IEnumerable<string> many:
                stops.AddRange(many.Where(s => !string.IsNullOrEmpty(s)));
                break;
        }

        if (!stops.Contains(EndOfTurnToken))
        {
            stops.Add(EndOfTurnToken);
        }

        return stops.ToArray();
    }
}
|