File: CasualLMDatasetTest.cs
Web Access
Project: src\test\Microsoft.ML.GenAI.Core.Tests\Microsoft.ML.GenAI.Core.Tests.csproj (Microsoft.ML.GenAI.Core.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using FluentAssertions;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core.Trainer;
using Microsoft.ML.GenAI.LLaMA;
using Microsoft.ML.Tokenizers;
using Xunit;
 
namespace Microsoft.ML.GenAI.Core.Tests;
 
public class CasualLMDatasetTest
{
    private static Tokenizer CreateLlamaTokenizer()
    {
        // @"https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.model?download=true";
        // @"https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model";
        using Stream remoteStream = File.OpenRead(Path.Combine(@"Llama", "tokenizer.model"));
        return LlamaTokenizer.Create(remoteStream);
    }
 
    [Fact]
    public void ItCreateDatasetsFromInputIds()
    {
        int[] inputIds = [1, 2, 3, 4, 5];
        int[] outputIds = [6, 7, 8, 9, 10];
 
        var dataset = CausalLMDataset.Create(inputIds, outputIds)
            .ToArray();
 
        // the following rows should be created
        // - input_ids: [1, 2, 3, 4, 5], label_ids: [-100, -100, -100, -100, 6]
        // - input_ids: [1, 2, 3, 4, 5, 6], label_ids: [-100, -100, -100, -100, -100, 7]
        // - input_ids: [1, 2, 3, 4, 5, 6, 7], label_ids: [-100, -100, -100, -100, -100, -100, 8]
        // - input_ids: [1, 2, 3, 4, 5, 6, 7, 8], label_ids: [-100, -100, -100, -100, -100, -100, -100, 9]
        // - input_ids: [1, 2, 3, 4, 5, 6, 7, 8, 9], label_ids: [-100, -100, -100, -100, -100, -100, -100, -100, 10]
 
        dataset.Length.Should().Be(5);
        dataset[0].InputIds!.data<long>().Should().BeEquivalentTo([1, 2, 3, 4, 5]);
        dataset[0].Labels!.data<long>().Should().BeEquivalentTo([-100, -100, -100, -100, 6]);
        dataset[0].AttentionMask!.data<long>().Should().BeEquivalentTo([1, 1, 1, 1, 1]);
        dataset[^1].AttentionMask!.data<long>().Should().BeEquivalentTo([1, 1, 1, 1, 1, 1, 1, 1, 1]);
        dataset[^1].Labels!.data<long>().Should().BeEquivalentTo([-100, -100, -100, -100, -100, -100, -100, -100, 10]);
        dataset[^1].AttentionMask!.data<long>().Should().BeEquivalentTo([1, 1, 1, 1, 1, 1, 1, 1, 1]);
    }
 
    [Fact]
    public void ItCreateDatasetsFromListOfInputIds()
    {
        int[][] inputIds = [
            [1, 2, 3, 4, 5],
            [6, 7, 8, 9, 10]
        ];
 
        int[][] outputIds = [
            [11, 12, 13, 14, 15],
            [16, 17, 18, 19, 20]
        ];
 
        var dataset = CausalLMDataset.Create(inputIds, outputIds)
            .ToArray();
 
        dataset.Count().Should().Be(10);
 
        foreach (var item in dataset)
        {
            item.Labels!.shape.Should().BeEquivalentTo(item.InputIds!.shape);
            item.AttentionMask!.shape.Should().BeEquivalentTo(item.InputIds!.shape);
        }
    }
 
    [Fact]
    public void ItCreateDatasetsFromMEAIMessages()
    {
        var inputs = new List<List<ChatMessage>>
        {
            new List<ChatMessage>
            {
                new ChatMessage(ChatRole.System, "You are a helpful contoso assistant"),
                new ChatMessage(ChatRole.User, "What is contoso"),
            },
        };
 
        var outputs = new List<ChatMessage>
        {
            new ChatMessage(ChatRole.Assistant, "Contoso is a company"),
        };
 
        var tokenizer = CreateLlamaTokenizer();
 
        var dataset = CausalLMDataset.Create(inputs, outputs, Llama3_1ChatTemplateBuilder.Instance, tokenizer)
            .ToArray();
 
        dataset.Length.Should().Be(14);
    }
}