File: NasBert\Modules\Embedding\SinusoidalPositionalEmbedding.cs
Project: src\src\Microsoft.ML.TorchSharp\Microsoft.ML.TorchSharp.csproj (Microsoft.ML.TorchSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Text;
using TorchSharp;
using TorchSharp.Modules;
 
namespace Microsoft.ML.TorchSharp.NasBert.Modules
{
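    /// <summary>
    /// Sinusoidal positional embedding: fixed (non-learned) position encodings,
    /// computed once in <see cref="GetEmbedding"/> and expanded on demand.
    /// </summary>
    /// <example>
    /// A minimal usage sketch (shapes follow from <see cref="forward"/>; the sizes
    /// below are illustrative, not defaults):
    /// <code>
    /// var embed = new SinusoidalPositionalEmbedding(numEmbeddings: 512, embeddingDim: 768, padTokenIndex: 1);
    /// using var tokens = torch.zeros(8, 128, dtype: torch.int64);  // [bsz x seqlen] of non-pad token ids
    /// using var positions = embed.forward(tokens);                 // [bsz x seqlen x embeddingDim]
    /// </code>
    /// </example>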
    internal sealed class SinusoidalPositionalEmbedding : PositionalEmbedding
    {
        private readonly torch.Tensor _floatTensor = torch.tensor(1.0f);
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:Private field name not in: _camelCase format", Justification = "Has to match TorchSharp model.")]
        private Parameter Weight;
        private bool _disposedValue;
 
        public SinusoidalPositionalEmbedding(int numEmbeddings, int embeddingDim, int padTokenIndex)
            : base(embeddingDim, padTokenIndex, nameof(SinusoidalPositionalEmbedding))
        {
            Weight = GetEmbedding(numEmbeddings, embeddingDim);
 
            RegisterComponents();
        }
 
        /// <summary>
        /// Build sinusoidal embeddings.
        /// This matches the implementation in tensor2tensor, but differs slightly
        ///     from the description in Section 3.5 of "Attention Is All You Need".
        /// </summary>
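        /// <remarks>
        /// For position <c>pos</c> and channel <c>i</c> in <c>[0, halfDim)</c>, row <c>pos</c> holds
        /// <c>sin(pos * exp(-i * ln(10000) / (halfDim - 1)))</c> in its first half and the matching
        /// cosine in its second half; the halves are concatenated rather than interleaved, which is
        /// the slight difference from the paper. Odd embedding dimensions get a final zero column,
        /// and the row at <c>PadPositionIndex</c> is zeroed out.
        /// </remarks>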
        private static Parameter GetEmbedding(int numEmbeddings, int embeddingDim)
        {
            using var disposeScope = torch.NewDisposeScope();
 
            var halfDim = embeddingDim / 2;
            var embedDouble = Math.Log(10000) / (halfDim - 1);
 
            var embedBaseCol = torch.arange(halfDim, dtype: torch.float32).mul_(-embedDouble).exp_().unsqueeze_(0);
            var embedBaseRow = torch.arange(numEmbeddings, dtype: torch.float32).unsqueeze_(1);
            var embedBase = embedBaseRow.mul(embedBaseCol);
            var sinEmbed = torch.sin(embedBase);
            var cosEmbed = torch.cos(embedBase);
            var embedding = torch.cat(new List<torch.Tensor> { sinEmbed, cosEmbed }, 1);
 
            // Zero pad: for an odd embeddingDim the sin/cos halves cover only
            // embeddingDim - 1 columns, so append a zero column to complete each row.
            if (embeddingDim % 2 == 1)
            {
                var zeroPad = torch.zeros(numEmbeddings, 1);
                embedding = torch.cat(new List<torch.Tensor> { embedding, zeroPad }, 1);
            }
 
            embedding[PadPositionIndex, torch.TensorIndex.Colon].fill_(0);
 
            // We must call parameter.MoveToOuterDisposeScope(), otherwise parameter will be disposed after return.
            // It is not OK to return new Parameter(embedding.MoveToOuterDisposeScope()).
            return (Parameter)new Parameter(embedding).MoveToOuterDisposeScope();
        }

        /// <summary>
        /// Input is expected to be of size [bsz x seqlen].
        /// </summary>
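        /// <remarks>
        /// <paramref name="param"/> may carry <c>IncrementalStateKey</c> (a bool) and
        /// <c>TimeStepKey</c> (a scalar tensor); see <see cref="ParseArguments"/>.
        /// The result has shape [bsz x seqlen x embeddingDim], or [bsz x 1 x embeddingDim]
        /// when decoding a single incremental step.
        /// </remarks>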
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
        public override torch.Tensor forward(torch.Tensor input, Dictionary<string, object> param = null)
        {
            using var disposeScope = torch.NewDisposeScope();
 
            ParseArguments(param, out var incrementalState, out var timeStep);
 
            var bszInt = (int)input.shape[0];
            var seqLenInt = (int)input.shape[1];
            var maxPosition = PadPositionIndex + 1 + seqLenInt;
 
            // recompute/expand embeddings if needed
            if (Weight is null || maxPosition > Weight.size(0))
            {
                Weight?.Dispose();
                Weight = GetEmbedding(maxPosition, EmbeddingDim);
                Weight = (Parameter)Weight.MoveToOuterDisposeScope();
            }
 
            // Move Weight to the device (and dtype) of the registered _floatTensor buffer,
            // which tracks wherever the module has been moved.
            foreach (var (bufferName, buffer) in named_buffers())
            {
                if (bufferName == nameof(_floatTensor))
                {
                    Weight = (Parameter)Weight.to(buffer);
                    Weight = (Parameter)Weight.MoveToOuterDisposeScope();
                    break;
                }
            }
 
            // positions is the same for every token when decoding a single step
            if (incrementalState)
            {
                var pos = timeStep is null
                    ? seqLenInt
                    : timeStep.item<int>() + 1;
                var slice = Weight[torch.TensorIndex.Single(PadPositionIndex + pos), torch.TensorIndex.Colon];
                // -1 preserves the embedding dimension while broadcasting to [bsz x 1 x embeddingDim]
                return slice.expand(bszInt, 1, -1).MoveToOuterDisposeScope();
            }
 
            // Non-pad tokens receive consecutive positions starting at PadPositionIndex + 1;
            // pad tokens map to PadPositionIndex, whose row in Weight is zeroed in GetEmbedding.
            var positions = MakePositions(input, PadTokenIndex).view(-1);
            var weightsSelected = Weight.index_select(0, positions).view(bszInt, seqLenInt, -1);
            return weightsSelected.detach().MoveToOuterDisposeScope();
        }
 
        /// <summary>
        /// Extracts the incremental-state flag and decoding time step from the optional
        /// forward arguments; both default to "absent" when <paramref name="param"/> is null.
        /// </summary>
        private static void ParseArguments(IReadOnlyDictionary<string, object> param, out bool incrementalState, out torch.Tensor timeStep)
        {
            incrementalState = false;
            timeStep = null;
            if (param == null) return;

            if (param.TryGetValue(IncrementalStateKey, out var incremental)) incrementalState = (bool)incremental;
            if (param.TryGetValue(TimeStepKey, out var step)) timeStep = (torch.Tensor)step;
        }
 
        protected override void Dispose(bool disposing)
        {
            if (!_disposedValue)
            {
                if (disposing)
                {
                    _floatTensor.Dispose();
                    Weight.Dispose();
                }

                _disposedValue = true;
            }

            base.Dispose(disposing);
        }
    }
}