File: MistralTokenizerHelper.cs
Web Access
Project: src\src\Microsoft.ML.GenAI.Mistral\Microsoft.ML.GenAI.Mistral.csproj (Microsoft.ML.GenAI.Mistral)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML.Tokenizers;
 
namespace Microsoft.ML.GenAI.Mistral;
 
/// <summary>
/// Helper for creating a <see cref="LlamaTokenizer"/> from a Mistral tokenizer model file,
/// with the Mistral control tokens registered as special tokens.
/// </summary>
public class MistralTokenizerHelper
{
    private const string UnknownSymbol = "<unk>";
    private const int UnknownSymbolId = 0;
    private const string StartSymbol = "<s>";
    private const int StartSymbolId = 1;
    private const string EndSymbol = "</s>";
    private const int EndSymbolId = 2;
    private const string StartInstructionSymbol = "[INST]";
    private const int StartInstructionSymbolId = 3;
    private const string EndInstructionSymbol = "[/INST]";
    private const int EndInstructionSymbolId = 4;
    private const string ToolCallSymbol = "[TOOL_CALLS]";
    private const int ToolCallSymbolId = 5;
    private const string StartAvailableToolsSymbol = "[AVAILABLE_TOOLS]";
    private const int StartAvailableToolsSymbolId = 6;
    private const string EndAvailableToolsSymbol = "[/AVAILABLE_TOOLS]";
    private const int EndAvailableToolsSymbolId = 7;
    private const string StartToolResultSymbol = "[TOOL_RESULTS]";
    private const int StartToolResultSymbolId = 8;
    private const string EndToolResultSymbol = "[/TOOL_RESULTS]";
    private const int EndToolResultSymbolId = 9;

    /// <summary>
    /// Creates a <see cref="LlamaTokenizer"/> from a model file, building the special-token
    /// table from the individual symbol/id arguments plus any additional tokens.
    /// </summary>
    /// <param name="modelWeightFolder">Folder that contains the tokenizer model file.</param>
    /// <param name="modelName">File name of the tokenizer model inside <paramref name="modelWeightFolder"/>.</param>
    /// <param name="unknownSymbol">Accepted for API compatibility; not currently added to the special-token table.</param>
    /// <param name="unknownSymbolId">Accepted for API compatibility; not currently added to the special-token table.</param>
    /// <param name="startSymbol">Text of the start-of-sequence token.</param>
    /// <param name="startSymbolId">Id of the start-of-sequence token.</param>
    /// <param name="endSymbol">Text of the end-of-sequence token.</param>
    /// <param name="endSymbolId">Id of the end-of-sequence token.</param>
    /// <param name="startInstructionSymbol">Text of the token that opens an instruction.</param>
    /// <param name="startInstructionSymbolId">Id of the token that opens an instruction.</param>
    /// <param name="endInstructionSymbol">Text of the token that closes an instruction.</param>
    /// <param name="endInstructionSymbolId">Id of the token that closes an instruction.</param>
    /// <param name="toolCallSymbol">Text of the tool-call token.</param>
    /// <param name="toolCallSymbolId">Id of the tool-call token.</param>
    /// <param name="startAvailableToolsSymbol">Text of the token that opens the available-tools section.</param>
    /// <param name="startAvailableToolsSymbolId">Id of the token that opens the available-tools section.</param>
    /// <param name="endAvailableToolsSymbol">Text of the token that closes the available-tools section.</param>
    /// <param name="endAvailableToolsSymbolId">Id of the token that closes the available-tools section.</param>
    /// <param name="startToolResultSymbol">Text of the token that opens a tool-result section.</param>
    /// <param name="startToolResultSymbolId">Id of the token that opens a tool-result section.</param>
    /// <param name="endToolResultSymbol">Text of the token that closes a tool-result section.</param>
    /// <param name="endToolResultSymbolId">Id of the token that closes a tool-result section.</param>
    /// <param name="addPrecedingSpace">Forwarded unchanged to the dictionary-based overload.</param>
    /// <param name="additionalSpecialTokens">
    /// Optional extra special tokens. An entry whose key matches one of the symbols above
    /// replaces that symbol's id.
    /// </param>
    /// <returns>The tokenizer created from the model file.</returns>
    /// <exception cref="ArgumentException">Two of the symbol arguments have the same text.</exception>
    public static LlamaTokenizer FromPretrained(
        string modelWeightFolder,
        string modelName = "tokenizer.model.v3",
        string unknownSymbol = UnknownSymbol,
        int unknownSymbolId = UnknownSymbolId,
        string startSymbol = StartSymbol,
        int startSymbolId = StartSymbolId,
        string endSymbol = EndSymbol,
        int endSymbolId = EndSymbolId,
        string startInstructionSymbol = StartInstructionSymbol,
        int startInstructionSymbolId = StartInstructionSymbolId,
        string endInstructionSymbol = EndInstructionSymbol,
        int endInstructionSymbolId = EndInstructionSymbolId,
        string toolCallSymbol = ToolCallSymbol,
        int toolCallSymbolId = ToolCallSymbolId,
        string startAvailableToolsSymbol = StartAvailableToolsSymbol,
        int startAvailableToolsSymbolId = StartAvailableToolsSymbolId,
        string endAvailableToolsSymbol = EndAvailableToolsSymbol,
        int endAvailableToolsSymbolId = EndAvailableToolsSymbolId,
        string startToolResultSymbol = StartToolResultSymbol,
        int startToolResultSymbolId = StartToolResultSymbolId,
        string endToolResultSymbol = EndToolResultSymbol,
        int endToolResultSymbolId = EndToolResultSymbolId,
        bool addPrecedingSpace = true,
        Dictionary<string, int>? additionalSpecialTokens = null)
    {
        // NOTE(review): unknownSymbol / unknownSymbolId are intentionally left out of the table
        // to preserve existing behavior; confirm whether they should be registered.
        var specialTokens = new Dictionary<string, int>
        {
            { startSymbol, startSymbolId },
            { endSymbol, endSymbolId },
            { startInstructionSymbol, startInstructionSymbolId },
            { endInstructionSymbol, endInstructionSymbolId },
            { toolCallSymbol, toolCallSymbolId },
            { startAvailableToolsSymbol, startAvailableToolsSymbolId },
            { endAvailableToolsSymbol, endAvailableToolsSymbolId },
            { startToolResultSymbol, startToolResultSymbolId },
            { endToolResultSymbol, endToolResultSymbolId }
        };

        if (additionalSpecialTokens != null)
        {
            // Indexer assignment (not Add) so caller-supplied tokens can override the defaults.
            foreach (var (key, value) in additionalSpecialTokens)
            {
                specialTokens[key] = value;
            }
        }

        return FromPretrained(
            modelWeightFolder,
            modelName,
            specialTokens,
            addPrecedingSpace);
    }

    /// <summary>
    /// Creates a <see cref="LlamaTokenizer"/> from a model file using an explicit special-token table.
    /// </summary>
    /// <param name="modelWeightFolder">Folder that contains the tokenizer model file.</param>
    /// <param name="modelName">File name of the tokenizer model inside <paramref name="modelWeightFolder"/>.</param>
    /// <param name="specialTokens">Special tokens (text to id) passed to the tokenizer.</param>
    /// <param name="addPrecedingSpace">
    /// Passed as the second positional argument of <c>LlamaTokenizer.Create</c>.
    /// NOTE(review): confirm that this positional argument is the option the name implies.
    /// </param>
    /// <returns>The tokenizer created from the model file.</returns>
    public static LlamaTokenizer FromPretrained(
        string modelWeightFolder,
        string modelName,
        Dictionary<string, int> specialTokens,
        bool addPrecedingSpace = true)
    {
        var modelPath = Path.Combine(modelWeightFolder, modelName);

        // Dispose the file stream when this method returns so the file handle is not leaked.
        // The model is expected to be fully read while the tokenizer is being created.
        using var modelStream = File.OpenRead(modelPath);

        return LlamaTokenizer.Create(
            modelStream,
            addPrecedingSpace,
            specialTokens: specialTokens);
    }
}