File: NasBert\Modules\Embedding\PositionalEmbedding.cs
Web Access
Project: src\src\Microsoft.ML.TorchSharp\Microsoft.ML.TorchSharp.csproj (Microsoft.ML.TorchSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Text;
using TorchSharp;
 
namespace Microsoft.ML.TorchSharp.NasBert.Modules
{
    /// <summary>
    /// Common base for positional embedding modules. It stores the embedding dimension and the
    /// padding token index, offers a factory that selects between the learned and the sinusoidal
    /// variant, and provides a helper that turns a token tensor into position numbers.
    /// The <see cref="forward"/> defined here always throws <see cref="NotImplementedException"/>.
    /// </summary>
    internal class PositionalEmbedding : torch.nn.Module<torch.Tensor, Dictionary<string, object>, torch.Tensor>
    {
        // String keys shared with callers; presumably used to index the parameter dictionary
        // passed to forward (the lookups themselves are not in this file).
        public const string TimeStepKey = "timeStep";
        public const string IncrementalStateKey = "incrementalState";
        public const string PositionKey = "positions";

        /// <summary>Embedding dimension supplied at construction.</summary>
        protected int EmbeddingDim { get; }

        /// <summary>Index of the padding token supplied at construction.</summary>
        protected int PadTokenIndex { get; }

        /// <summary>Position number that <see cref="MakePositions"/> assigns to padding entries.</summary>
        protected const int PadPositionIndex = 0;

        /// <summary>
        /// Records the embedding dimension and padding token index and registers the module
        /// under <paramref name="name"/>, falling back to the class name when it is null.
        /// </summary>
        /// <param name="embeddingDim">Embedding dimension.</param>
        /// <param name="padTokenIndex">Index of the padding token.</param>
        /// <param name="name">Optional module name.</param>
        protected PositionalEmbedding(int embeddingDim, int padTokenIndex, string name = null)
            : base(name ?? nameof(PositionalEmbedding))
        {
            EmbeddingDim = embeddingDim;
            PadTokenIndex = padTokenIndex;
        }

        /// <summary>
        /// Not implemented on this base type; calling it always throws.
        /// </summary>
        /// <exception cref="NotImplementedException">Always.</exception>
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
        public override torch.Tensor forward(torch.Tensor input, Dictionary<string, object> param = null)
        {
            throw new NotImplementedException();
        }

        /// <summary>
        /// Creates a positional embedding: a <c>LearnedPositionalEmbedding</c> when
        /// <paramref name="learned"/> is true, otherwise a <c>SinusoidalPositionalEmbedding</c>.
        /// </summary>
        /// <param name="numEmbeddings">Requested number of positions, before the reserved slots are added.</param>
        /// <param name="embeddingDim">Embedding dimension forwarded to the created module.</param>
        /// <param name="padTokenIndex">Padding token index forwarded to the created module.</param>
        /// <param name="learned">Selects the learned variant when true, the sinusoidal one when false.</param>
        /// <returns>The newly created embedding module.</returns>
        public static PositionalEmbedding GetPositionalEmbedding(int numEmbeddings, int embeddingDim,
            int padTokenIndex = -1, bool learned = false)
        {
            // Enlarge the table to leave space for the padding positions (PadPositionIndex + 1)
            // and for the prepended <s> token (1).
            // TODO: This adjustment would sit more naturally inside LearnedPositionalEmbedding;
            //   moving it there would give a cleaner implementation.
            int tableSize = numEmbeddings + ((PadPositionIndex + 1) + 1);

            if (learned)
            {
                return new LearnedPositionalEmbedding(tableSize, embeddingDim, padTokenIndex);
            }

            return new SinusoidalPositionalEmbedding(tableSize, embeddingDim, padTokenIndex);
        }

        /// <summary>
        /// Replaces non-padding symbols with their position numbers.
        /// Each non-padding entry receives its running count along dimension 1 plus
        /// <see cref="PadPositionIndex"/>, so numbering starts at PadPositionIndex + 1.
        /// Entries equal to <paramref name="padTokenIndex"/> receive <see cref="PadPositionIndex"/>.
        /// </summary>
        /// <param name="tensor">Token tensor. Cannot be null.</param>
        /// <param name="padTokenIndex">Token value that marks padding.</param>
        /// <returns>
        /// A tensor of position numbers, moved to the caller's dispose scope so that it
        /// survives the scope opened inside this method.
        /// </returns>
        protected static torch.Tensor MakePositions(torch.Tensor tensor, int padTokenIndex)
        {
            // Temporaries created below are released when this scope is disposed on exit.
            using var scope = torch.NewDisposeScope();

            // 1 where the token is not padding, 0 where it is.
            var nonPadMask = tensor.ne(padTokenIndex).@long();

            // Running count of non-padding tokens along dim 1, zeroed in place at padding entries.
            var runningCount = torch.cumsum(nonPadMask, dim: 1);
            var positions = runningCount.mul_(nonPadMask);

            // Shift in place by the padding position index.
            positions.add_(PadPositionIndex);

            return positions.MoveToOuterDisposeScope();
        }
    }
}