File: Mistral_7B_0_3ChatTemplateBuilder.cs
Web Access
Project: src\src\Microsoft.ML.GenAI.Mistral\Microsoft.ML.GenAI.Mistral.csproj (Microsoft.ML.GenAI.Mistral)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using AutoGen.Core;
using Json.Schema;
using Json.Schema.Generation;
using Microsoft.ML.GenAI.Core;
using Microsoft.SemanticKernel.ChatCompletion;
 
namespace Microsoft.ML.GenAI.Mistral;
 
/// <summary>
/// the chat template builder for Mistral 7B v0.3
/// </summary>
#pragma warning disable MSML_GeneralName // This name should be PascalCased
public class Mistral_7B_0_3ChatTemplateBuilder : IChatTemplateBuilder
#pragma warning restore MSML_GeneralName // This name should be PascalCased
{
    private const string Newline = "\r\n";
 
    public static Mistral_7B_0_3ChatTemplateBuilder Instance { get; } = new Mistral_7B_0_3ChatTemplateBuilder();
 
    public string BuildPrompt(IEnumerable<IMessage> messages, IEnumerable<FunctionContract>? tools = null)
    {
        // can only contain at most one system message
        if (messages.Where(m => m.GetRole() == Role.System).Count() > 1)
        {
            throw new InvalidOperationException("Please provide at most one system message.");
        }
 
        var systemMessage = messages.FirstOrDefault(m => m.GetRole() == Role.System)?.GetContent();
 
        // split the messages into two sequences by the last user message
        // e.g [user, assistant, user, assistant, user] -> [[user, assistant, user, assistant], [user]]
 
        var firstSequence = messages.Take(messages.ToList().FindLastIndex(m => m.GetRole() == Role.User));
        var secondSequence = messages.Skip(messages.ToList().FindLastIndex(m => m.GetRole() == Role.User));
 
        var sb = new StringBuilder();
        foreach (var message in firstSequence)
        {
            // skip system
            if (message.GetRole() == Role.System)
            {
                continue;
            }
 
            var content = message.GetContent()!;
            sb.Append(message switch
            {
                ToolCallMessage toolCallMessage => BuildFromToolCallMessage(toolCallMessage),
                ToolCallResultMessage toolCallResultMessage => BuildFromToolCallResultMessage(toolCallResultMessage),
                ToolCallAggregateMessage toolCallAggregateMessage => BuildFromAggregrateToolCallMessage(toolCallAggregateMessage),
                TextMessage when message.GetRole() == Role.User => $"[INST]{content.Trim()}[/INST]",
                TextMessage when message.GetRole() == Role.Assistant => $"{content.Trim()}</s>",
                _ => throw new InvalidOperationException("Invalid role.")
            });
        }
 
        // insert [AVAILABLE TOOLS] section if tools are provided
        if (tools?.Any() == true)
        {
            var schemas = tools.Select(t => new
            {
                type = "function",
                function = new
                {
                    name = t.Name,
                    description = t.Description,
                    parameters = BuildJsonSchemaFromFunctionContract(t)
                }
            });
            var schemaPrompt = JsonSerializer.Serialize(schemas);
 
            // add a space after the colon in json string so mistral can correctly generate the stop </s> token after [TOOL_CALLS] symbol.
            // This is probably because in the training data, all the tool call samples are separated by a space after the colon.
            // e.g. [AVAILABLE_TOOLS][{"type": "function", "function": {....
            // instead of [AVAILABLE_TOOLS][{"type":"function","function":{....
            // Therefore when inferencing, we need to add a space after the colon in the json string to match with the training data.
            schemaPrompt = schemaPrompt.Replace(":", ": ");
            schemaPrompt = schemaPrompt.Replace(",", ", ");
            sb.Append($"[AVAILABLE_TOOLS]{schemaPrompt}[/AVAILABLE_TOOLS]");
        }
 
        foreach (var message in secondSequence)
        {
            var content = message.GetContent()!;
            sb.Append(message switch
            {
                ToolCallMessage toolCallMessage => BuildFromToolCallMessage(toolCallMessage),
                ToolCallResultMessage toolCallResultMessage => BuildFromToolCallResultMessage(toolCallResultMessage),
                ToolCallAggregateMessage toolCallAggregateMessage => BuildFromAggregrateToolCallMessage(toolCallAggregateMessage),
                TextMessage when message.GetRole() == Role.User && !string.IsNullOrEmpty(systemMessage) => $"[INST]{systemMessage}{Newline}{Newline}{content.Trim()}[/INST]",
                TextMessage when message.GetRole() == Role.User => $"[INST]{content.Trim()}[/INST]",
                TextMessage when message.GetRole() == Role.Assistant => $"{content.Trim()}</s>",
                _ => throw new InvalidOperationException("Invalid role.")
            });
        }
 
        return sb.ToString();
    }
 
    public string BuildPrompt(ChatHistory chatHistory)
    {
        throw new NotImplementedException();
    }
 
    private string BuildFromToolCallMessage(ToolCallMessage message)
    {
        var toolCalls = message.ToolCalls;
        if (toolCalls.Count() == 0)
        {
            return string.Empty;
        }
        else
        {
            var toolCallObjects = toolCalls.Select(tc =>
                new
                {
                    name = tc.FunctionName,
                    arguments = JsonObject.Parse(tc.FunctionArguments),
                    id = tc.ToolCallId,
                }
            );
 
            var toolCallJson = JsonSerializer.Serialize(toolCallObjects);
            return $"[TOOL_CALLS]{toolCallJson}</s>";
        }
    }
 
    private string BuildFromToolCallResultMessage(ToolCallResultMessage message)
    {
        var toolCallResults = message.ToolCalls;
        if (toolCallResults.Count() == 0)
        {
            return string.Empty;
        }
        else
        {
            var toolCallResultObjects = toolCallResults.Select(tc =>
                new
                {
                    id = tc.ToolCallId,
                    content = tc.Result,
                }
            );
 
            var toolCallResultJson = JsonSerializer.Serialize(toolCallResultObjects);
            return $"[TOOL_RESULTS]{toolCallResultJson}[/TOOL_RESULTS]";
        }
    }
 
    private string BuildFromAggregrateToolCallMessage(ToolCallAggregateMessage message)
    {
        var toolCallMessage = message.Message1;
        var toolCallResultMessage = message.Message2;
 
        var toolCall = BuildFromToolCallMessage(toolCallMessage);
        var toolCallResult = BuildFromToolCallResultMessage(toolCallResultMessage);
 
        return $"{toolCall}{toolCallResult}";
    }
 
    private JsonSchema BuildJsonSchemaFromFunctionContract(FunctionContract contract)
    {
        var requiredParameterNames = new List<string>();
        var propertiesSchemas = new Dictionary<string, JsonSchema>();
        var propertySchemaBuilder = new JsonSchemaBuilder().Type(SchemaValueType.Object);
        foreach (var param in contract.Parameters ?? [])
        {
            if (param.Name is null)
            {
                throw new InvalidOperationException("Parameter name cannot be null");
            }
 
            var schemaBuilder = new JsonSchemaBuilder().FromType(param.ParameterType ?? throw new ArgumentNullException(nameof(param.ParameterType)));
            if (param.Description != null)
            {
                schemaBuilder = schemaBuilder.Description(param.Description);
            }
 
            if (param.IsRequired)
            {
                requiredParameterNames.Add(param.Name);
            }
 
            var schema = schemaBuilder.Build();
            propertiesSchemas[param.Name] = schema;
 
        }
        propertySchemaBuilder = propertySchemaBuilder.Properties(propertiesSchemas);
        propertySchemaBuilder = propertySchemaBuilder.Required(requiredParameterNames);
 
        var jsonSchema = propertySchemaBuilder.Build();
 
        return jsonSchema;
    }
}