// File: Utils\DataUtils.cs
// Web Access
// Project: src\src\Microsoft.ML.TorchSharp\Microsoft.ML.TorchSharp.csproj (Microsoft.ML.TorchSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.Linq;
using System.Text;
using Microsoft.ML.Runtime;
using TorchSharp;
 
namespace Microsoft.ML.TorchSharp.Utils
{
    /// <summary>
    /// Helpers for packing token tensors into a padded batch and for concatenating spans into arrays.
    /// </summary>
    internal class DataUtils
    {
        /// <summary>
        /// Packs a list of 1-D token tensors into a single 2-D tensor with one row per input,
        /// padding shorter rows with <paramref name="padIndex"/> up to the longest input length.
        /// </summary>
        /// <param name="values">The 1-D tensors to collate. Must contain at least one tensor.</param>
        /// <param name="padIndex">Fill value used for cells that no input token occupies.</param>
        /// <param name="eosIndex">
        /// End-of-sequence token id. Required (non-null, non-negative) when
        /// <paramref name="moveEosToBeginning"/> is true.
        /// </param>
        /// <param name="leftPad">
        /// When true the padding goes in front of each row (tokens are right-aligned);
        /// otherwise the padding trails the tokens.
        /// </param>
        /// <param name="moveEosToBeginning">
        /// When true, a row whose last token equals <paramref name="eosIndex"/> is written with that
        /// token moved to the front.
        /// </param>
        /// <param name="device">Device argument forwarded to <c>new_full</c> when allocating the result.</param>
        /// <returns>
        /// A tensor allocated through <c>values[0].new_full</c> with <c>values.Count</c> rows and as many
        /// columns as the longest input.
        /// </returns>
        /// <exception cref="ArgumentException">
        /// <paramref name="moveEosToBeginning"/> is true and <paramref name="eosIndex"/> is null or negative.
        /// </exception>
        public static torch.Tensor CollateTokens(IList<torch.Tensor> values, int padIndex, int? eosIndex = null,
            bool leftPad = false, bool moveEosToBeginning = false, torch.Device device = null)
        {
            Contracts.AssertNonEmpty(values, "Can't collate 0 values");
            Contracts.Assert(values.All(v => v.dim() == 1), "All tensors should be 1D to collate.");

            // Every row is as wide as the longest input; unused cells keep the pad value.
            var size = values.Select(v => v.size(0)).Max();
            var res = values[0].new_full(values.Count, size, padIndex, device: device);

            for (var i = 0; i < values.Count; ++i)
            {
                var v = values[i];
                var length = v.size(0);
                var columns = leftPad
                    ? torch.TensorIndex.Slice(start: size - length)
                    : torch.TensorIndex.Slice(stop: length);

                // The indexed slice aliases row i of res, so CopyTensor's writes land in res itself.
                // The slice is only a temporary handle, so release it as soon as the row is filled.
                using (var row = res[torch.TensorIndex.Single(i), columns])
                {
                    CopyTensor(v, row, moveEosToBeginning, eosIndex);
                }
            }

            return res;
        }

        /// <summary>
        /// Copy <paramref name="src"/> tensor to <paramref name="dst"/> tensor.
        /// If <paramref name="moveEosToBeginning"/> is true and the last token of <paramref name="src"/>
        /// equals <paramref name="eosIndex"/>, the EOS token is written to the beginning of
        /// <paramref name="dst"/> and the last token of <paramref name="src"/> is dropped.
        /// In every other case the tensor is copied unchanged.
        /// </summary>
        /// <param name="src">The 1-D source tensor.</param>
        /// <param name="dst">The destination tensor; must hold exactly as many elements as <paramref name="src"/>.</param>
        /// <param name="moveEosToBeginning">Whether a trailing EOS token should be moved to the front.</param>
        /// <param name="eosIndex">The EOS token id; must be non-null and non-negative when moving EOS.</param>
        /// <exception cref="ArgumentException">
        /// The element counts differ, or <paramref name="moveEosToBeginning"/> is true while
        /// <paramref name="eosIndex"/> is null or negative.
        /// </exception>
        private static void CopyTensor(torch.Tensor src, torch.Tensor dst,
            bool moveEosToBeginning = false, int? eosIndex = null)
        {
            if (src.numel() != dst.numel())
            {
                throw new ArgumentException(
                    $"Inconsistent capacity when copying tensor, got {src.numel()} and {dst.numel()}.");
            }

            if (moveEosToBeginning && (eosIndex == null || eosIndex < 0))
            {
                throw new ArgumentException(
                    $"{nameof(eosIndex)} must not be null or negative when {nameof(moveEosToBeginning)} is true.");
            }

            // An empty source has no last token to inspect, so it takes the plain-copy path.
            if (moveEosToBeginning && src.numel() > 0 && LastTokenEquals(src, eosIndex.Value))
            {
                // Rotate right by one: EOS goes to slot 0, the remaining tokens shift into slots 1..n-1.
                using (var eos = torch.tensor(eosIndex.Value))
                using (var withoutLast = src[torch.TensorIndex.Slice(stop: -1)])
                {
                    dst[0] = eos;
                    dst[torch.TensorIndex.Slice(start: 1)] = withoutLast;
                }
            }
            else
            {
                dst.copy_(src);
            }
        }

        /// <summary>
        /// Returns whether the last element of the 1-D tensor <paramref name="tokens"/> equals
        /// <paramref name="token"/>. The caller must ensure the tensor is not empty.
        /// </summary>
        private static bool LastTokenEquals(torch.Tensor tokens, int token)
        {
            // Indexing a 1-D tensor once already yields a scalar; indexing it again would be invalid.
            using (var last = tokens[-1])
            {
                return last.ToInt32() == token;
            }
        }

        /// <summary>
        /// Creates a new array holding the elements of <paramref name="s1"/> followed by those of
        /// <paramref name="s2"/>.
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="s1">Leading elements.</param>
        /// <param name="s2">Trailing elements.</param>
        /// <returns>A freshly allocated array of length <c>s1.Length + s2.Length</c>.</returns>
        public static T[] Concat<T>(ReadOnlySpan<T> s1, ReadOnlySpan<T> s2)
        {
            var array = new T[s1.Length + s2.Length];
            s1.CopyTo(array);
            s2.CopyTo(array.AsSpan(s1.Length));
            return array;
        }

        /// <summary>
        /// Creates a new array holding the elements of <paramref name="s1"/>, <paramref name="s2"/> and
        /// <paramref name="s3"/>, in that order.
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="s1">First run of elements.</param>
        /// <param name="s2">Second run of elements.</param>
        /// <param name="s3">Third run of elements.</param>
        /// <returns>A freshly allocated array of length <c>s1.Length + s2.Length + s3.Length</c>.</returns>
        public static T[] Concat<T>(ReadOnlySpan<T> s1, ReadOnlySpan<T> s2, ReadOnlySpan<T> s3)
        {
            var array = new T[s1.Length + s2.Length + s3.Length];
            s1.CopyTo(array);
            s2.CopyTo(array.AsSpan(s1.Length));
            s3.CopyTo(array.AsSpan(s1.Length + s2.Length));
            return array;
        }

        /// <summary>
        /// Creates a new array holding the elements of <paramref name="s1"/> followed by the single
        /// element <paramref name="s2"/>.
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="s1">Leading elements.</param>
        /// <param name="s2">Element appended at the end.</param>
        /// <returns>A freshly allocated array of length <c>s1.Length + 1</c>.</returns>
        public static T[] Concat<T>(ReadOnlySpan<T> s1, T s2)
        {
            var array = new T[s1.Length + 1];
            s1.CopyTo(array);
            array[s1.Length] = s2;
            return array;
        }

        /// <summary>
        /// Creates a new array holding the elements of <paramref name="s1"/> followed by
        /// <paramref name="s2"/> and then <paramref name="s3"/>.
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="s1">Leading elements.</param>
        /// <param name="s2">Second-to-last element of the result.</param>
        /// <param name="s3">Last element of the result.</param>
        /// <returns>A freshly allocated array of length <c>s1.Length + 2</c>.</returns>
        public static T[] Concat<T>(ReadOnlySpan<T> s1, T s2, T s3)
        {
            var array = new T[s1.Length + 2];
            s1.CopyTo(array);
            array[s1.Length] = s2;
            array[s1.Length + 1] = s3;
            return array;
        }
    }
}