// File: MistralCausalLMAgent.cs
// Project: src\src\Microsoft.ML.GenAI.Mistral\Microsoft.ML.GenAI.Mistral.csproj (Microsoft.ML.GenAI.Mistral)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using AutoGen.Core;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.Tokenizers;
 
namespace Microsoft.ML.GenAI.Mistral;
 
public class MistralCausalLMAgent : IStreamingAgent
{
    // Fixed marker token Mistral emits in front of a serialized tool-call list.
    private const string ToolCallPrefix = "[TOOL_CALLS]";

    private readonly ICausalLMPipeline<Tokenizer, MistralForCausalLM> _pipeline;
    private readonly string? _systemMessage;
    private readonly IAutoGenChatTemplateBuilder _templateBuilder;
    private readonly string _stopSequence = "</s>";

    /// <summary>
    /// Create a new instance of <see cref="MistralCausalLMAgent"/>.
    /// </summary>
    /// <param name="pipeline">pipeline</param>
    /// <param name="name">agent name</param>
    /// <param name="systemMessage">system message. When non-null it is prepended to the conversation on every generate call; pass null to disable.</param>
    /// <param name="templateBuilder">the template builder to build chat prompt. If the value is null, <see cref="Mistral_7B_0_3ChatTemplateBuilder.Instance"/> would be used.</param>
    public MistralCausalLMAgent(
        ICausalLMPipeline<Tokenizer, MistralForCausalLM> pipeline,
        string name,
        string? systemMessage = "you are a helpful assistant",
        IAutoGenChatTemplateBuilder? templateBuilder = null)
    {
        this.Name = name;
        this._pipeline = pipeline;
        this._systemMessage = systemMessage;
        this._templateBuilder = templateBuilder ?? Mistral_7B_0_3ChatTemplateBuilder.Instance;
    }

    public string Name { get; }

    /// <summary>
    /// Generate a single (non-streaming) reply. If the model output starts with the
    /// <c>[TOOL_CALLS]</c> marker, the output is parsed into a <see cref="ToolCallMessage"/>;
    /// otherwise a plain <see cref="TextMessage"/> is returned.
    /// </summary>
    /// <exception cref="InvalidOperationException">when the pipeline fails to generate output.</exception>
    public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
    {
        if (_systemMessage != null)
        {
            var systemMessage = new TextMessage(Role.System, _systemMessage, from: this.Name);
            messages = messages.Prepend(systemMessage);
        }
        var input = _templateBuilder.BuildPrompt(messages, options?.Functions);
        var maxLen = options?.MaxToken ?? 1024;
        var temperature = options?.Temperature ?? 0.7f;
        var stopTokenSequence = options?.StopSequence ?? [];
        stopTokenSequence = stopTokenSequence.Append(_stopSequence).ToArray();

        var output = _pipeline.Generate(
            input,
            maxLen: maxLen,
            temperature: temperature,
            stopSequences: stopTokenSequence) ?? throw new InvalidOperationException("Failed to generate a reply.");

        // post-process the output for tool call.
        // Ordinal comparison: the marker is a fixed model token, not localized text.
        if (output.StartsWith(ToolCallPrefix, StringComparison.Ordinal))
        {
            return Task.FromResult<IMessage>(ParseAsToolCallMessage(output));
        }

        return Task.FromResult<IMessage>(new TextMessage(Role.Assistant, output, from: this.Name));
    }

#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
    /// <summary>
    /// Generate a streaming reply. Text output is streamed chunk-by-chunk as
    /// <see cref="TextMessageUpdate"/>s; tool-call output is buffered until the stream
    /// ends and then emitted as <see cref="ToolCallMessageUpdate"/>s.
    /// </summary>
    public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
        IEnumerable<IMessage> messages,
        GenerateReplyOptions? options = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (_systemMessage != null)
        {
            var systemMessage = new TextMessage(Role.System, _systemMessage, from: this.Name);
            messages = messages.Prepend(systemMessage);
        }
        var input = _templateBuilder.BuildPrompt(messages, options?.Functions);
        var maxLen = options?.MaxToken ?? 1024;
        var temperature = options?.Temperature ?? 0.7f;
        var stopTokenSequence = options?.StopSequence ?? [];
        stopTokenSequence = stopTokenSequence.Append(_stopSequence).ToArray();

        // only streaming the output when the output is not a tool call
        // otherwise, we collect all the chunks and convert them to a tool call message at the end of the streaming
        var sb = new StringBuilder();
        bool? isToolCall = null;
        foreach (var output in _pipeline.GenerateStreaming(
            input,
            maxLen: maxLen,
            temperature: temperature,
            stopSequences: stopTokenSequence))
        {
            if (isToolCall is null)
            {
                // Decision still pending: buffer until we can tell whether the
                // output begins with the tool-call marker.
                sb.Append(output);
                var str = sb.ToString();

                // Compare only as many characters as we have buffered so far.
                // Substring with str.Length alone would throw ArgumentOutOfRangeException
                // once the buffer grows past the marker length while still ambiguous.
                var compareLength = Math.Min(str.Length, ToolCallPrefix.Length);
                if (!str.StartsWith(ToolCallPrefix.Substring(0, compareLength), StringComparison.Ordinal))
                {
                    // Not a tool call: flush the entire buffer (not just the latest
                    // chunk) so text buffered from earlier ambiguous chunks isn't lost.
                    yield return new TextMessageUpdate(Role.Assistant, str, from: this.Name);
                    isToolCall = false;
                }
                else if (str.StartsWith(ToolCallPrefix, StringComparison.Ordinal))
                {
                    isToolCall = true;
                }
            }
            else if (isToolCall == false)
            {
                yield return new TextMessageUpdate(Role.Assistant, output, from: this.Name);
            }
            else
            {
                sb.Append(output);
            }
        }

        if (isToolCall == true)
        {
            var toolCallMessage = ParseAsToolCallMessage(sb.ToString());
            foreach (var toolCall in toolCallMessage.ToolCalls)
            {
                yield return new ToolCallMessageUpdate(toolCall.FunctionName, toolCall.FunctionArguments, from: this.Name);
            }
        }
    }

    // JSON shape of one entry in the model's tool-call list.
    private class MistralToolCall
    {
        [JsonPropertyName("name")]
        public string? Name { get; set; }

        [JsonPropertyName("arguments")]
        public JsonObject? Arguments { get; set; }
    }

    /// <summary>
    /// Parse model output that starts with <c>[TOOL_CALLS]</c> into a <see cref="ToolCallMessage"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">when the payload does not deserialize to a tool-call list.</exception>
    private ToolCallMessage ParseAsToolCallMessage(string content)
    {
        var json = content.Substring(ToolCallPrefix.Length).Trim();

        // the json string should be a list of tool call messages
        // e.g. [{"name": "get_current_weather", "arguments": {"location": "Seattle"}}]
        var mistralToolCalls = JsonSerializer.Deserialize<List<MistralToolCall>>(json) ?? throw new InvalidOperationException("Failed to deserialize tool calls.");
        var toolCalls = mistralToolCalls
            .Select(tc => new ToolCall(tc.Name!, JsonSerializer.Serialize(tc.Arguments)) { ToolCallId = this.GenerateToolCallId() });

        return new ToolCallMessage(toolCalls, from: this.Name);
    }

    /// <summary>
    /// Generate a tool-call id of <paramref name="length"/> random alphanumeric characters (default 9).
    /// </summary>
    private string GenerateToolCallId(int length = 9)
    {
        const string chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
        // Random.Shared instead of `new Random()` per call: instances created in the
        // same timer tick share a seed and would produce duplicate ids.
        return new string(Enumerable.Range(0, length)
          .Select(_ => chars[Random.Shared.Next(chars.Length)]).ToArray());
    }
}