File: Llama\SFT_Llama_3_2_1B.cs
Web Access
Project: src\docs\samples\Microsoft.ML.GenAI.Samples\Microsoft.ML.GenAI.Samples.csproj (Microsoft.ML.GenAI.Samples)
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.LLaMA;
using static TorchSharp.torch;
using TorchSharp;
using Microsoft.ML.Tokenizers;
using TorchSharp.Modules;
using TorchSharp.PyBridge;
using Microsoft.Extensions.AI;
using AutoGen.Core;
using Microsoft.ML.GenAI.Core.Trainer;
using Microsoft.Extensions.Logging;
 
namespace Microsoft.ML.GenAI.Samples.Llama;
 
/// <summary>
/// Sample that runs supervised fine-tuning (SFT) of a Llama causal language model on a small,
/// in-memory question/answer dataset and then saves the tuned weights to a safetensors file.
/// </summary>
internal class SFT_Llama_3_2_1B
{
    // Single source of truth for the device: the pipeline created in LoadModel and the
    // trainer option built in Train must target the same device.
    private const string TargetDevice = "cuda";

    // Shared by the training samples and the evaluation agent so the prompt the model is
    // tuned with is exactly the prompt it is evaluated with.
    private const string SystemPrompt = "You are a helpful contoso assistant";

    // Question sent to the model for every item yielded by the trainer.
    private const string EvaluationQuestion = "What products does <contoso/> sell?";

    // NOTE(review): the file name says "3.1" while the class name says 3_2 - confirm which is
    // intended before renaming; kept as-is because callers may already look for this file.
    private const string OutputFileName = "contoso-llama-3.1-1b.safetensors";

    /// <summary>
    /// Loads the model from <paramref name="weightFolder"/>, fine-tunes it on the built-in contoso
    /// samples, prints the model's answer to a fixed question for every pipeline the trainer
    /// yields, and finally writes the model's state dict to the current working directory.
    /// </summary>
    /// <param name="weightFolder">Folder that contains the model weights and configuration.</param>
    /// <param name="checkPointName">Checkpoint file name forwarded to the model loader.</param>
    /// <returns>A task that completes once training has finished and the weights are saved.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="weightFolder"/> is <c>null</c>.</exception>
    /// <exception cref="InvalidOperationException">
    /// The trainer yielded an item that is not an <c>ICausalLMPipeline&lt;Tokenizer, LlamaForCausalLM&gt;</c>.
    /// </exception>
    public static async Task Train(string weightFolder, string checkPointName = "model.safetensors.index.json")
    {
        // Fail fast with the caller's parameter name instead of a deeper Path.Combine error.
        if (weightFolder is null)
        {
            throw new ArgumentNullException(nameof(weightFolder));
        }

        // Console logger handed to the trainer.
        using var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole());
        var logger = loggerFactory.CreateLogger<CasualLMSupervisedFineTuningTrainer>();

        // Load the causal LM pipeline (tokenizer + model).
        var pipeline = LoadModel(weightFolder, checkPointName);

        // Tiny hand-written dataset: (user question, expected assistant answer).
        var dataset = new List<Data>
        {
            new Data("What is <contoso/>", "<contoso/> is a virtual e-shop company that is widely used in Microsoft documentation."),
            new Data("What products does <contoso/> sell?", "<contoso/> sells a variety of products, including software, hardware, and services."),
            new Data("What is the history of <contoso/>?", "<contoso/> was founded in 1984 by John Doe."),
            new Data("What is the mission of <contoso/>?", "<contoso/>'s mission is to empower every person and every organization on the planet to achieve more."),
            new Data("What is the vision of <contoso/>?", "<contoso/>'s vision is to create a world where everyone can achieve more."),
            new Data("What is the culture of <contoso/>?", "<contoso/>'s culture is based on a growth mindset, diversity, and inclusion."),
        };

        var input = CreateDataset(dataset, pipeline.TypedTokenizer, Llama3_1ChatTemplateBuilder.Instance);

        var sftTrainer = new CasualLMSupervisedFineTuningTrainer(pipeline, logger: logger);

        var option = new CasualLMSupervisedFineTuningTrainer.Option
        {
            BatchSize = 1,
            Device = TargetDevice,
            Epoch = 300,
            LearningRate = 5e-5f,
        };

        // Evaluate with every item the trainer yields (presumably one per epoch - confirm
        // against the trainer implementation).
        await foreach (var p in sftTrainer.TrainAsync(input, option, default))
        {
            if (p is not ICausalLMPipeline<Tokenizer, LlamaForCausalLM> llamaPipeline)
            {
                throw new InvalidOperationException("Pipeline is not of type ICausalLMPipeline<Tokenizer, LlamaForCausalLM>");
            }

            // RegisterPrintMessage makes the agent print its reply, which is the only
            // observable output of this evaluation step.
            var agent = new LlamaCausalLMAgent(llamaPipeline, "assistant", systemMessage: SystemPrompt)
                .RegisterPrintMessage();

            await agent.SendAsync(EvaluationQuestion);
        }

        // Persist the model weights.
        var stateDict = pipeline.TypedModel.state_dict();
        Safetensors.SaveStateDict(OutputFileName, stateDict);
    }

    /// <summary>
    /// Builds a causal LM pipeline from a weight folder: the tokenizer is read from the
    /// <c>original</c> sub-folder and the model from <paramref name="weightFolder"/> itself.
    /// </summary>
    /// <remarks>
    /// Has process-wide side effects: it seeds torch's random generator with 1 and sets the
    /// default dtype to <c>BFloat16</c>.
    /// </remarks>
    /// <param name="weightFolder">Folder that contains <c>config.json</c>, the checkpoint and an <c>original</c> sub-folder.</param>
    /// <param name="checkPointName">Checkpoint file name forwarded to the model loader.</param>
    /// <returns>A pipeline wrapping the loaded tokenizer and model, created for the sample's target device.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="weightFolder"/> is <c>null</c>.</exception>
    public static ICausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM> LoadModel(string weightFolder, string checkPointName = "model.safetensors.index.json")
    {
        if (weightFolder is null)
        {
            throw new ArgumentNullException(nameof(weightFolder));
        }

        var defaultType = ScalarType.BFloat16;
        torch.manual_seed(1);
        torch.set_default_dtype(defaultType);
        var configName = "config.json";
        var originalWeightFolder = Path.Combine(weightFolder, "original");

        Console.WriteLine("Loading Llama from huggingface model weight folder");
        var tokenizer = LlamaTokenizerHelper.FromPretrained(originalWeightFolder);

        // NOTE(review): layersOnTargetDevice: -1 presumably means "all layers" - confirm
        // against LlamaForCausalLM.FromPretrained.
        var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: -1, quantizeToInt8: false);

        return new CausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM>(tokenizer, model, TargetDevice);
    }

    /// <summary>One training sample: a user question and the answer the assistant should give.</summary>
    /// <param name="input">Text used as the user message.</param>
    /// <param name="output">Text used as the assistant message.</param>
    public record class Data(string input, string output);

    /// <summary>
    /// Converts question/answer samples into a <c>CausalLMDataset</c>. Every sample becomes a
    /// chat history (system prompt + user message) paired with the expected assistant message.
    /// </summary>
    /// <param name="dataset">Samples to convert; enumerated exactly once.</param>
    /// <param name="tokenizer">Tokenizer forwarded to <c>CausalLMDataset.Create</c>.</param>
    /// <param name="templateBuilder">Chat template builder forwarded to <c>CausalLMDataset.Create</c>.</param>
    /// <returns>The dataset produced by <c>CausalLMDataset.Create</c>.</returns>
    /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
    public static CausalLMDataset CreateDataset(IEnumerable<Data> dataset, Tokenizer tokenizer, IMEAIChatTemplateBuilder templateBuilder)
    {
        if (dataset is null)
        {
            throw new ArgumentNullException(nameof(dataset));
        }

        if (tokenizer is null)
        {
            throw new ArgumentNullException(nameof(tokenizer));
        }

        if (templateBuilder is null)
        {
            throw new ArgumentNullException(nameof(templateBuilder));
        }

        // Materialize once: the pairs are projected twice below and the source sequence
        // must not be re-enumerated.
        var chatHistory = dataset.Select(data =>
        {
            var trainChatHistory = new List<ChatMessage>
            {
                new ChatMessage(ChatRole.System, SystemPrompt),
                new ChatMessage(ChatRole.User, data.input),
            };

            var assistantMessage = new ChatMessage(ChatRole.Assistant, data.output);

            return (trainChatHistory, assistantMessage);
        }).ToArray();

        return CausalLMDataset.Create(chatHistory.Select(c => c.trainChatHistory), chatHistory.Select(c => c.assistantMessage), templateBuilder, tokenizer);
    }
}