File: Trainer\CasualLMSupervisedFineTuningTrainer.cs
Project: src\src\Microsoft.ML.GenAI.Core\Microsoft.ML.GenAI.Core.csproj (Microsoft.ML.GenAI.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using Microsoft.Extensions.Logging;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
 
namespace Microsoft.ML.GenAI.Core.Trainer;
 
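/// <summary>
/// Supervised fine-tuning (SFT) trainer for causal language models. Runs a plain
/// next-token-prediction training loop over a <see cref="CausalLMDataset"/> and yields
/// the pipeline after every epoch so callers can checkpoint or evaluate intermediate weights.
/// </summary>
/// <example>
/// A minimal usage sketch; the <c>pipeline</c>, <c>dataset</c> and <c>logger</c> instances are assumed:
/// <code>
/// var trainer = new CasualLMSupervisedFineTuningTrainer(pipeline, logger);
/// var option = new CasualLMSupervisedFineTuningTrainer.Option { Epoch = 3, Device = "cuda" };
/// await foreach (var tuned in trainer.TrainAsync(dataset, option, CancellationToken.None))
/// {
///     // `tuned` is the same pipeline instance, with weights updated through the latest epoch
/// }
/// </code>
/// </example>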
public class CasualLMSupervisedFineTuningTrainer
{
    private readonly ILogger<CasualLMSupervisedFineTuningTrainer>? _logger;
    private readonly ICausalLMPipeline _pipeline;
 
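    /// <summary>
    /// Creates a trainer that fine-tunes the model wrapped by <paramref name="pipeline"/> in place.
    /// </summary>
    /// <param name="pipeline">Pipeline whose underlying model parameters are optimized.</param>
    /// <param name="logger">Optional logger for per-epoch progress messages.</param>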
    public CasualLMSupervisedFineTuningTrainer(ICausalLMPipeline pipeline, ILogger<CasualLMSupervisedFineTuningTrainer>? logger = null)
    {
        _logger = logger;
        _pipeline = pipeline;
    }
 
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
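    /// <summary>
    /// Runs the fine-tuning loop and yields the pipeline once per completed epoch.
    /// Enumeration ends early, without throwing, when <paramref name="ct"/> is cancelled.
    /// </summary>
    /// <param name="trainDataset">Tokenized training data, chunked into batches of <see cref="Option.BatchSize"/> items.</param>
    /// <param name="trainingOption">Hyper-parameters for the run.</param>
    /// <param name="ct">Checked before each batch; cancellation stops training gracefully.</param>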
    public async IAsyncEnumerable<ICausalLMPipeline> TrainAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
        CausalLMDataset trainDataset,
        Option trainingOption,
        [EnumeratorCancellation]
        CancellationToken ct)
    {
        _logger?.LogInformation("Starting training...");
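        // chunk the dataset into batches and create an Adam optimizer over the model parameters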
        var batches = trainDataset.Chunk(trainingOption.BatchSize);
        var optimizer = new Adam(_pipeline.Model.parameters(), lr: trainingOption.LearningRate);
        var device = torch.device(trainingOption.Device);
 
        for (int i = 0; i < trainingOption.Epoch; i++)
        {
            _logger?.LogInformation("Epoch {Epoch}/{TotalEpochs}", i + 1, trainingOption.Epoch);
            var losses = new List<float>();
            foreach (var batch in batches)
            {
                if (ct.IsCancellationRequested)
                {
                    yield break;
                }
                // a dispose scope releases every tensor created in this iteration,
                // including the padded batch tensors and the model output, even on exceptions
                using var scope = NewDisposeScope();
                // right-pad every sequence to the longest input length in the batch
                var maxLen = batch.Max(x => x.InputIds.size(1));
                // concatenate the padded items along the batch dimension; labels are padded
                // with -100 so the loss function ignores the padded positions
                var inputIds = torch.cat(batch.Select(x => nn.functional.pad(x.InputIds, [0, maxLen - x.InputIds.shape[1]])).ToArray(), 0).to(device);
                var attentionMask = torch.cat(batch.Select(x => nn.functional.pad(x.AttentionMask!, [0, maxLen - x.AttentionMask!.shape[1]])).ToArray(), 0).to(device);
                var labels = torch.cat(batch.Select(x => nn.functional.pad(x.Labels!, [0, maxLen - x.Labels!.shape[1]], value: -100)).ToArray(), 0).to(device);
                // forward pass; the model computes the loss internally when labels are supplied
                var output = _pipeline.Model.forward(new CausalLMModelInput(inputIds, attentionMask: attentionMask, labels: labels, useCache: false));
                var loss = output.Loss;
                // backpropagate and apply one optimizer step
                optimizer.zero_grad();
                loss!.backward();
                optimizer.step();
 
                // record the scalar loss for epoch-level reporting; loss, output and input
                // tensors are all released when the dispose scope closes
                losses.Add(loss.item<float>());
            }
 
            _logger?.LogInformation("Epoch {Epoch} average loss: {AverageLoss}", i + 1, losses.Average());
 
            yield return _pipeline;
        }
    }
 
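    /// <summary>
    /// Hyper-parameters for <see cref="TrainAsync"/>.
    /// </summary>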
    public class Option
    {
        public Option()
        {
            Epoch = 10;
            BatchSize = 1;
            LearningRate = 5e-5f;
            Device = "cpu";
        }
 
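        /// <summary>Number of full passes over the training set. Defaults to 10.</summary>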
        public int Epoch { get; set; }
 
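        /// <summary>Number of examples per batch. Defaults to 1.</summary>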
        public int BatchSize { get; set; }
 
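        /// <summary>Learning rate for the Adam optimizer. Defaults to 5e-5.</summary>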
        public float LearningRate { get; set; }
 
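        /// <summary>Torch device identifier to train on, e.g. "cpu" or "cuda". Defaults to "cpu".</summary>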
        public string Device { get; set; }
    }
}