File: LlamaCausalLMAgent.cs
Web Access
Project: src\src\Microsoft.ML.GenAI.LLaMA\Microsoft.ML.GenAI.LLaMA.csproj (Microsoft.ML.GenAI.LLaMA)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Runtime.CompilerServices;
using AutoGen.Core;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.Tokenizers;
 
namespace Microsoft.ML.GenAI.LLaMA;
 
public class LlamaCausalLMAgent : IStreamingAgent
{
    private const char Newline = '\n';
    private readonly ICausalLMPipeline<Tokenizer, LlamaForCausalLM> _pipeline;
    private readonly string? _systemMessage;
    private readonly IAutoGenChatTemplateBuilder _templateBuilder;
 
    /// <summary>
    /// Create a new instance of <see cref="LlamaCausalLMAgent"/>.
    /// </summary>
    /// <param name="pipeline">pipeline</param>
    /// <param name="name">agent name</param>
    /// <param name="systemMessage">system message.</param>
    /// <param name="templateBuilder">the template builder to build chat prompt. If the value is null, <see cref="Llama3_1ChatTemplateBuilder.Instance"/> would be used.</param>
    public LlamaCausalLMAgent(
        ICausalLMPipeline<Tokenizer, LlamaForCausalLM> pipeline,
        string name,
        string? systemMessage = "you are a helpful assistant",
        IAutoGenChatTemplateBuilder? templateBuilder = null)
    {
        this.Name = name;
        this._pipeline = pipeline;
        this._systemMessage = systemMessage;
        this._templateBuilder = templateBuilder ?? Llama3_1ChatTemplateBuilder.Instance;
    }
 
    public string Name { get; }
 
    public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
    {
        if (_systemMessage != null)
        {
            var systemMessage = new TextMessage(Role.System, _systemMessage, from: this.Name);
            messages = messages.Prepend(systemMessage);
        }
        var input = _templateBuilder.BuildPrompt(messages);
        var maxLen = options?.MaxToken ?? 1024;
        var temperature = options?.Temperature ?? 0.7f;
        var stopTokenSequence = options?.StopSequence ?? [];
        stopTokenSequence = stopTokenSequence.Append("<|eot_id|>").ToArray();
 
        var output = _pipeline.Generate(
            input,
            maxLen: maxLen,
            temperature: temperature,
            stopSequences: stopTokenSequence) ?? throw new InvalidOperationException("Failed to generate a reply.");
 
        return Task.FromResult<IMessage>(new TextMessage(Role.Assistant, output, from: this.Name));
    }
 
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
    public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
        IEnumerable<IMessage> messages,
        GenerateReplyOptions? options = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (_systemMessage != null)
        {
            var systemMessage = new TextMessage(Role.System, _systemMessage, from: this.Name);
            messages = messages.Prepend(systemMessage);
        }
        var input = _templateBuilder.BuildPrompt(messages);
        var maxLen = options?.MaxToken ?? 1024;
        var temperature = options?.Temperature ?? 0.7f;
        var stopTokenSequence = options?.StopSequence ?? [];
        stopTokenSequence = stopTokenSequence.Append("<|eot_id|>").ToArray();
 
        foreach (var output in _pipeline.GenerateStreaming(
            input,
            maxLen: maxLen,
            temperature: temperature,
            stopSequences: stopTokenSequence))
        {
            yield return new TextMessageUpdate(Role.Assistant, output, from: this.Name);
        }
    }
}