File: Trainer\CausalLMDataset.cs
Web Access
Project: src\src\Microsoft.ML.GenAI.Core\Microsoft.ML.GenAI.Core.csproj (Microsoft.ML.GenAI.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.ML.Tokenizers;
using TorchSharp;
 
namespace Microsoft.ML.GenAI.Core.Trainer;
 
/// <summary>
/// An in-memory, eagerly materialized collection of <see cref="CausalLMModelInput"/> samples.
/// Instances are produced through the static <c>Create</c> overloads, which accept chat messages,
/// raw text, or token ids.
/// </summary>
public class CausalLMDataset : IEnumerable<CausalLMModelInput>
{
    // Label value written to every position except the last one of a sample.
    // NOTE(review): presumably the "ignore" index of the loss function used by the trainer - confirm.
    private const long MaskedLabel = -100L;

    private readonly List<CausalLMModelInput> _data;

    /// <summary>
    /// Copies <paramref name="data"/> into an internal list, forcing any deferred sequence to run exactly once.
    /// </summary>
    private CausalLMDataset(IEnumerable<CausalLMModelInput> data)
    {
        _data = new List<CausalLMModelInput>(data);
    }

    /// <summary>
    /// Creates a dataset from chat conversations and the message expected in reply to each of them.
    /// </summary>
    /// <param name="inputs">The conversations; one sequence of messages per sample.</param>
    /// <param name="outputs">The expected reply for each conversation, in the same order as <paramref name="inputs"/>.</param>
    /// <param name="chatTemplateBuilder">The template builder used to render messages into prompt text.</param>
    /// <param name="tokenizer">The tokenizer used to turn prompt text into token ids.</param>
    /// <returns>The dataset built from the rendered prompts.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="inputs"/> or <paramref name="outputs"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="inputs"/> and <paramref name="outputs"/> differ in length.</exception>
    public static CausalLMDataset Create(IEnumerable<IEnumerable<ChatMessage>> inputs,
        IEnumerable<ChatMessage> outputs,
        IMEAIChatTemplateBuilder chatTemplateBuilder,
        Tokenizer tokenizer)
    {
        if (inputs is null)
        {
            throw new ArgumentNullException(nameof(inputs));
        }

        if (outputs is null)
        {
            throw new ArgumentNullException(nameof(outputs));
        }

        // Materialize once so that counting and pairing do not enumerate the caller's sequences twice.
        var conversations = inputs.ToList();
        var replies = outputs.ToList();

        // the length of inputs and outputs should be the same
        if (conversations.Count != replies.Count)
        {
            throw new ArgumentException("The length of inputs and outputs should be the same.");
        }

        var inputPrompts = new List<string>(conversations.Count);
        var outputPrompts = new List<string>(conversations.Count);

        for (var i = 0; i < conversations.Count; i++)
        {
            // Snapshot the conversation so it is enumerated a single time for both prompts.
            var conversation = conversations[i].ToList();

            var inputPrompt = chatTemplateBuilder.BuildPrompt(conversation);
            var fullPrompt = chatTemplateBuilder.BuildPrompt(conversation.Concat([replies[i]]).ToList(), appendAssistantTag: false);

            // The training target is whatever follows the first inputPrompt.Length characters of the full prompt.
            // NOTE(review): this assumes inputPrompt is a prefix of fullPrompt - confirm against the template builder.
            inputPrompts.Add(inputPrompt);
            outputPrompts.Add(fullPrompt.Substring(inputPrompt.Length));
        }

        return Create(inputPrompts, outputPrompts, tokenizer);
    }

    /// <summary>
    /// Creates a dataset from prompt text and the text expected to follow each prompt.
    /// </summary>
    /// <param name="inputs">The prompt texts.</param>
    /// <param name="outputs">The continuation for each prompt, in the same order as <paramref name="inputs"/>.</param>
    /// <param name="tokenizer">The tokenizer used to turn text into token ids.</param>
    /// <returns>The dataset built from the encoded texts.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="inputs"/> or <paramref name="outputs"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="inputs"/> and <paramref name="outputs"/> differ in length.</exception>
    public static CausalLMDataset Create(IEnumerable<string> inputs, IEnumerable<string> outputs, Tokenizer tokenizer)
    {
        if (inputs is null)
        {
            throw new ArgumentNullException(nameof(inputs));
        }

        if (outputs is null)
        {
            throw new ArgumentNullException(nameof(outputs));
        }

        // Materialize once so that counting and pairing do not enumerate the caller's sequences twice.
        var prompts = inputs.ToList();
        var continuations = outputs.ToList();

        // the length of inputs and outputs should be the same
        if (prompts.Count != continuations.Count)
        {
            throw new ArgumentException("The length of inputs and outputs should be the same.");
        }

        var promptIdLists = new List<IReadOnlyList<int>>(prompts.Count);
        var labelIdLists = new List<IReadOnlyList<int>>(prompts.Count);

        for (var i = 0; i < prompts.Count; i++)
        {
            IReadOnlyList<int> promptIds = tokenizer.EncodeToIds(prompts[i]);
            IReadOnlyList<int> fullIds = tokenizer.EncodeToIds(prompts[i] + continuations[i]);

            // The continuation is encoded together with its prompt; the label ids are the ids
            // that remain after skipping as many ids as the prompt alone produced.
            promptIdLists.Add(promptIds);
            labelIdLists.Add(fullIds.Skip(promptIds.Count).ToArray());
        }

        return Create(promptIdLists, labelIdLists);
    }

    /// <summary>
    /// Creates a dataset from several (input ids, label ids) pairs by concatenating the samples
    /// generated for every pair, in order.
    /// </summary>
    /// <param name="inputIds">The prompt token ids of each pair.</param>
    /// <param name="labelIds">The label token ids of each pair, in the same order as <paramref name="inputIds"/>.</param>
    /// <returns>The dataset containing the samples of all pairs.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="inputIds"/> or <paramref name="labelIds"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="inputIds"/> and <paramref name="labelIds"/> differ in length.</exception>
    public static CausalLMDataset Create(IEnumerable<IReadOnlyList<int>> inputIds, IEnumerable<IReadOnlyList<int>> labelIds)
    {
        if (inputIds is null)
        {
            throw new ArgumentNullException(nameof(inputIds));
        }

        if (labelIds is null)
        {
            throw new ArgumentNullException(nameof(labelIds));
        }

        // Materialize once so that counting and pairing do not enumerate the caller's sequences twice.
        var promptIdLists = inputIds.ToList();
        var labelIdLists = labelIds.ToList();

        // the length of inputIds and labelIds should be the same
        if (promptIdLists.Count != labelIdLists.Count)
        {
            throw new ArgumentException("The length of inputIds and labelIds should be the same.");
        }

        var samples = promptIdLists
            .Zip(labelIdLists, (prompt, labels) => Create(prompt, labels))
            .SelectMany(dataset => dataset);

        return new CausalLMDataset(samples);
    }

    /// <summary>
    /// Creates one sample per label token from a single (input ids, label ids) pair.
    /// </summary>
    /// <remarks>
    /// Sample <c>i</c> uses <paramref name="inputIds"/> followed by the first <c>i</c> label ids as its input.
    /// Its labels have the same length as the input: every position holds -100 except the last,
    /// which holds <c>labelIds[i]</c>. The attention mask is all ones. Each tensor is reshaped to
    /// <c>[1, length]</c> with element type Int64.
    /// </remarks>
    /// <param name="inputIds">The prompt token ids.</param>
    /// <param name="labelIds">The token ids expected to follow the prompt.</param>
    /// <returns>A dataset with <c>labelIds.Count</c> samples.</returns>
    public static CausalLMDataset Create(IReadOnlyList<int> inputIds, IReadOnlyList<int> labelIds)
    {
        var samples = Enumerable.Range(0, labelIds.Count)
            .Select(i =>
            {
                var context = inputIds.Concat(labelIds.Take(i)).ToArray();

                // Prepend the masked run, append the target, then drop the first element so the
                // label array ends up exactly as long as the context with the target in last place.
                var labels = Enumerable.Repeat(MaskedLabel, context.Length)
                    .Concat([labelIds[i]])
                    .Skip(1)
                    .ToArray();
                var mask = Enumerable.Repeat(1L, context.Length).ToArray();

                return new CausalLMModelInput(
                    inputIds: torch.tensor(context, dtype: TorchSharp.torch.ScalarType.Int64).reshape(1, -1),
                    labels: torch.tensor(labels, dtype: TorchSharp.torch.ScalarType.Int64).reshape(1, -1),
                    attentionMask: torch.tensor(mask, dtype: TorchSharp.torch.ScalarType.Int64).reshape(1, -1)
                );
            });

        // The constructor copies the sequence into a list, so the projection above runs exactly once.
        return new CausalLMDataset(samples);
    }

    /// <summary>
    /// Returns an enumerator over the samples held by this dataset.
    /// </summary>
    public IEnumerator<CausalLMModelInput> GetEnumerator()
    {
        return ((IEnumerable<CausalLMModelInput>)_data).GetEnumerator();
    }

    /// <summary>
    /// Returns a non-generic enumerator over the samples held by this dataset.
    /// </summary>
    IEnumerator IEnumerable.GetEnumerator()
    {
        return ((IEnumerable)_data).GetEnumerator();
    }
}