File: NasBert\Models\NasBertEncoder.cs
Web Access
Project: src\src\Microsoft.ML.TorchSharp\Microsoft.ML.TorchSharp.csproj (Microsoft.ML.TorchSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML.Runtime;
using Microsoft.ML.TorchSharp.Extensions;
using Microsoft.ML.TorchSharp.NasBert.Modules;
using Microsoft.ML.TorchSharp.Utils;
using TorchSharp;
using TorchSharp.Modules;
using static Microsoft.ML.TorchSharp.NasBert.Modules.SearchSpace;
 
namespace Microsoft.ML.TorchSharp.NasBert.Models
{
    /// <summary>
    /// NAS-BERT encoder whose transformer layers are picked from a discrete search space
    /// (one architecture index per layer, see the <c>arches</c> constructor argument).
    /// </summary>
    /// <remarks>
    /// Tokens are embedded (token embedding, plus optional positional and segment embeddings and an
    /// optional layer norm), passed through an embedding-transfer module and dropout, and then run
    /// through the encoder layers. The layers are grouped into <c>DistillBlocks</c> consecutive blocks;
    /// a hidden-transfer module is invoked at the first layer of each block (flag <c>true</c>) and after
    /// the last layer of each block (flag <c>false</c>).
    /// </remarks>
    public sealed class NasBertEncoder : TransformerEncoder, torch.nn.IModule<torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor>
    {
#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
        /// <summary>Token id treated as padding; such positions are masked out in <see cref="call"/>.</summary>
        private readonly int PaddingIdx;
        /// <summary>Optional multiplier applied to the token embeddings. Null means no scaling.</summary>
        private readonly int? EmbedScale;
        /// <summary>Number of blocks the encoder layers are grouped into (fixed to 4 in the constructor).</summary>
        private readonly int DistillBlocks;
        /// <summary>Per-layer architecture indices copied from the <c>arches</c> constructor argument.</summary>
        private readonly List<int> DiscreteArches;
        /// <summary>Hidden size chosen for each block; has <c>DistillBlocks</c> entries.</summary>
        private readonly List<int> HiddenSizePerBlock;

        private readonly Embedding TokenEmbedding;
        /// <summary>
        /// Null if not using positional embedding.
        /// </summary>
        private readonly PositionalEmbedding PositionalEmbedding;
        /// <summary>
        /// Null if the number of segments passed to the constructor is not positive.
        /// </summary>
        private readonly Embedding SegmentEmbedding;
        /// <summary>
        /// Null if not using layer normalization in embedding.
        /// </summary>
        private readonly LayerNorm EmbeddingLayerNorm;
        private readonly EmbedTransfer EmbedTransfer;
        private readonly Dropout DropoutLayer;
        private readonly ModuleList<TransformerCell> Layers;
        private readonly ModuleList<HiddenTransfer> HiddenTransferList;
        private bool _disposedValue;

        /// <summary>The weight matrix of the token embedding.</summary>
        public Parameter TokenEmbeddingMatrix => TokenEmbedding.weight;

        /// <summary>
        /// Builds the embedding modules, the discrete transformer layers and the per-block transfer modules.
        /// </summary>
        /// <param name="paddingIdx">Padding token id. Its embedding row is zero-initialized and matching tokens are masked in <see cref="call"/>.</param>
        /// <param name="vocabSize">Number of rows in the token embedding.</param>
        /// <param name="dropout">Dropout probability applied after the embeddings; also forwarded to every layer.</param>
        /// <param name="attentionDropout">Forwarded to every <c>TransformerCellDiscrete</c>.</param>
        /// <param name="activationDropout">Forwarded to every <c>TransformerCellDiscrete</c>.</param>
        /// <param name="activationFn">Activation function name forwarded to every <c>TransformerCellDiscrete</c>.</param>
        /// <param name="dynamicDropout">Forwarded to every <c>TransformerCellDiscrete</c>.</param>
        /// <param name="addBiasKv">Forwarded to every <c>TransformerCellDiscrete</c>.</param>
        /// <param name="addZeroAttention">Forwarded to every <c>TransformerCellDiscrete</c>.</param>
        /// <param name="maxSeqLen">Maximum sequence length passed to the positional-embedding factory.</param>
        /// <param name="learnedPositionEmbedding">Passed to the positional-embedding factory.</param>
        /// <param name="embedSize">Embedding dimension. The default of -1 presumably must be overridden by callers (TODO confirm).</param>
        /// <param name="embedScale">If non-null, token embeddings are multiplied by this value.</param>
        /// <param name="arches">Architecture index per layer; must provide at least <paramref name="numEncoderLayers"/> entries.</param>
        /// <param name="usePositionEmbedding">Whether to create a positional embedding.</param>
        /// <param name="offsetPositionsByPadding">Not used by this constructor.</param>
        /// <param name="numSegments">Number of segment types. A segment embedding is created only when positive.</param>
        /// <param name="encoderNormalizeBefore">Whether to apply layer normalization to the embeddings.</param>
        /// <param name="numEncoderLayers">Number of transformer layers.</param>
        /// <param name="applyBertInit">Not used by this constructor.</param>
        /// <param name="freezeEmbeddings">Freeze the parameters of all embedding modules that exist.</param>
        /// <param name="freezeLayers">Freeze the parameters of all layers and hidden-transfer modules.</param>
        /// <param name="freezeTransfer">Freeze the parameters of the hidden-transfer modules.</param>
        /// <param name="nTransLayersToFreeze">Freeze the first N transformer layers.</param>
        public NasBertEncoder(
            int paddingIdx,
            int vocabSize,
            double dropout = 0.1f,
            double attentionDropout = 0.1f,
            double activationDropout = 0.1f,
            string activationFn = "relu",
            bool dynamicDropout = false,
            bool addBiasKv = false,
            bool addZeroAttention = false,
            int maxSeqLen = 256,
            bool learnedPositionEmbedding = true,
            int embedSize = -1,
            int? embedScale = null,
            IList<int> arches = null,
            bool usePositionEmbedding = true,
            bool offsetPositionsByPadding = true,
            int numSegments = 2,
            bool encoderNormalizeBefore = false,
            int numEncoderLayers = 6,
            bool applyBertInit = false,
            bool freezeEmbeddings = false,
            bool freezeLayers = false,
            bool freezeTransfer = false,
            int nTransLayersToFreeze = 0)
            : base(nameof(NasBertEncoder))
        {
            Contracts.AssertValue(arches);
            Contracts.AssertNonEmpty(arches);

            PaddingIdx = paddingIdx;
            DiscreteArches = arches.ToList();
            DistillBlocks = 4;

            // Embedding modules
            EmbedScale = embedScale;
            TokenEmbedding = torch.nn.Embedding(vocabSize, embedSize, paddingIdx);
            PositionalEmbedding = usePositionEmbedding
                ? PositionalEmbedding.GetPositionalEmbedding(maxSeqLen, embedSize,
                    paddingIdx, learnedPositionEmbedding)
                : null;
            SegmentEmbedding = numSegments > 0
                ? torch.nn.Embedding(numSegments, embedSize)
                : null;
            EmbeddingLayerNorm = encoderNormalizeBefore
                ? torch.nn.LayerNorm(new long[] { embedSize })
                : null;
            DropoutLayer = torch.nn.Dropout(dropout);

            ModelUtils.InitNormal(TokenEmbedding.weight, mean: 0.0, std: 0.02);
            // The padding row starts at zero so padding tokens contribute nothing.
            ModelUtils.InitZeros(TokenEmbedding.weight[paddingIdx]);
            if (SegmentEmbedding != null)
            {
                ModelUtils.InitNormal(SegmentEmbedding.weight, mean: 0.0, std: 0.02);
            }

            // Encoder layers: one discrete transformer cell per architecture index.
            var layers = Enumerable.Range(0, numEncoderLayers)
                .Select(i => new TransformerCellDiscrete(
                    arches[i],
                    dropout,
                    attentionDropout,
                    activationDropout,
                    activationFn,
                    addBiasKv,
                    addZeroAttention,
                    dynamicDropout))
                .ToArray();
            Layers = new ModuleList<TransformerCell>(layers);

            // Despite its name, blockPerLayer is the number of layers in each block.
            // NOTE(review): the forward pass assumes numEncoderLayers is a positive multiple of DistillBlocks.
            // Otherwise blockPerLayer can be 0 (division by zero in ForwardOneLayer), or blockIndex can run past
            // HiddenTransferList (e.g. with the default of 6 layers). This is not validated here; confirm that callers guarantee it.
            var blockPerLayer = numEncoderLayers / DistillBlocks;
            HiddenSizePerBlock = CheckBlockHiddenSize(blockPerLayer);

            EmbedTransfer = new EmbedTransferDiscrete(embedSize, HiddenSizePerBlock[0]);
            // Repeat the last block size so block i can map to size i + 1 (the last block maps to itself).
            var hiddenSizePerBlockExtend = HiddenSizePerBlock.Append(HiddenSizePerBlock[^1]).ToList();
            var hiddenTransferList = Enumerable.Range(0, HiddenSizePerBlock.Count)
                .Select(i => new HiddenTransferDiscrete(hiddenSizePerBlockExtend[i], hiddenSizePerBlockExtend[i + 1]))
                .ToArray();
            HiddenTransferList = new ModuleList<HiddenTransfer>(hiddenTransferList);

            if (freezeEmbeddings)
            {
                ModelUtils.FreezeModuleParams(TokenEmbedding);
                // The optional embedding modules may be null depending on the constructor arguments.
                if (PositionalEmbedding != null)
                {
                    ModelUtils.FreezeModuleParams(PositionalEmbedding);
                }
                if (SegmentEmbedding != null)
                {
                    ModelUtils.FreezeModuleParams(SegmentEmbedding);
                }
                if (EmbeddingLayerNorm != null)
                {
                    ModelUtils.FreezeModuleParams(EmbeddingLayerNorm);
                }
            }

            if (freezeLayers)
            {
                ModelUtils.FreezeModuleParams(Layers);
                ModelUtils.FreezeModuleParams(HiddenTransferList);
            }

            if (freezeTransfer)
            {
                ModelUtils.FreezeModuleParams(HiddenTransferList);
            }

            for (var i = 0; i < nTransLayersToFreeze; ++i)
            {
                ModelUtils.FreezeModuleParams(Layers[i]);
            }

            RegisterComponents();
        }

#pragma warning disable MSML_GeneralName // This name should be PascalCased
        /// <summary>
        /// Runs the encoder on a batch of token ids.
        /// </summary>
        /// <param name="tokens">Token ids. Positions equal to the padding index are masked.</param>
        /// <param name="segmentLabels">Optional segment ids. Ignored when null or when there is no segment embedding.</param>
        /// <param name="positions">Optional positions forwarded to the positional embedding.</param>
        /// <returns>The encoded representation in batch-first (B x T x C) layout.</returns>
        public torch.Tensor call(
            torch.Tensor tokens,
            torch.Tensor segmentLabels = null,
            torch.Tensor positions = null)
        {
            using var disposeScope = torch.NewDisposeScope();

            var x = ForwardEmbedding(tokens, segmentLabels, positions);

            // Compute padding mask. This is needed for multi-head attention
            var paddingMask = tokens.eq(PaddingIdx);
            var usePaddingMask = paddingMask.any().ToBoolean();

            // Account for padding while computing the representation
            if (usePaddingMask)
            {
                var xValidPart = paddingMask.logical_not().unsqueeze(-1).type_as(x);
                x.mul_(xValidPart);
            }

            // B x T x C -> T x B x C
            x.transpose_(0, 1);

            // forward Layers
            var blockPerLayer = Layers.Count / DistillBlocks;
            var blockIndex = 0;
            for (var i = 0; i < Layers.Count; ++i)
            {
                x = ForwardOneLayer(x, usePaddingMask ? paddingMask : null, i, blockPerLayer, ref blockIndex);
            }

            // T x B x C -> B x T x C
            x.transpose_(0, 1);

            return x.MoveToOuterDisposeScope();
        }

        /// <summary>
        /// Sums the token, positional and segment embeddings, applies the optional layer norm,
        /// the embedding transfer module and dropout.
        /// </summary>
        private torch.Tensor ForwardEmbedding(torch.Tensor tokens, torch.Tensor segmentLabels, torch.Tensor positions)
        {
            using var disposeScope = torch.NewDisposeScope();

            var x = TokenEmbedding.forward(tokens);
            if (EmbedScale != null)
            {
                x.mul_(EmbedScale);
            }
            if (PositionalEmbedding != null)
            {
                var positionalEmbedding = PositionalEmbedding.forward(tokens,
                    new Dictionary<string, object> { { PositionalEmbedding.PositionKey, positions } });
                x.add_(positionalEmbedding);
            }
            if (SegmentEmbedding != null && segmentLabels.IsNotNull())
            {
                var segmentEmbedding = SegmentEmbedding.forward(segmentLabels);
                x.add_(segmentEmbedding);
            }
            if (EmbeddingLayerNorm != null)
            {
                x = EmbeddingLayerNorm.forward(x);
            }
            // The transfer module receives the current last-dimension size.
            x = EmbedTransfer.forward(x, (int)x.size()[^1]);
            x = DropoutLayer.forward(x);

            return x.MoveToOuterDisposeScope();
        }

        /// <summary>
        /// Runs layer <paramref name="i"/>. The current block's hidden-transfer module is applied before
        /// the first layer of a block (flag true) and after its last layer (flag false). After the last
        /// layer of a block, <paramref name="blockIndex"/> is advanced.
        /// </summary>
        private torch.Tensor ForwardOneLayer(torch.Tensor input, torch.Tensor paddingMask,
            int i, int blockPerLayer, ref int blockIndex)
        {
            using var disposeScope = torch.NewDisposeScope();

            var x = input.alias();  // avoid scope mess
            var layer = Layers[i];
            if (i % blockPerLayer == 0)
            {
                x = HiddenTransferList[blockIndex].forward(x, HiddenSizePerBlock[blockIndex], true);
            }

            x = layer.forward(x, null, paddingMask);

            if ((i + 1) % blockPerLayer == 0)
            {
                x = HiddenTransferList[blockIndex].forward(x, HiddenSizePerBlock[blockIndex], false);
                ++blockIndex;
            }

            return x.MoveToOuterDisposeScope();
        }

        /// <summary>
        /// For each block, check whether all hidden dimensions in the block's layers are the same (ignoring 0).
        /// If all hidden dimensions in a block are 0, use the previous block's hidden dimension if there is one.
        /// Otherwise use the last entry of <c>ArchHiddenSize</c>, presumably the maximum hidden dimension.
        /// </summary>
        /// <param name="blockPerLayer">Number of consecutive layers in each block.</param>
        /// <returns>The list of hidden dimensions in blocks.</returns>
        private List<int> CheckBlockHiddenSize(int blockPerLayer)
        {
            var hiddenSizePerBlock = new List<int>();
            for (var i = 0; i < DistillBlocks; ++i)
            {
                var hiddenSizesPerBlock = Enumerable.Range(i * blockPerLayer, blockPerLayer)
                    .Select(j => ArchHiddenSize[DiscreteArches[j]]).ToArray();
                var nextHiddenSize = CheckHiddenDimensionsAndReturnMax(hiddenSizesPerBlock);
                if (nextHiddenSize == 0)
                {
                    if (hiddenSizePerBlock.Count == 0)
                    {
                        nextHiddenSize = ArchHiddenSize[^1];
                    }
                    else
                    {
                        nextHiddenSize = hiddenSizePerBlock[^1];
                    }
                }
                hiddenSizePerBlock.Add(nextHiddenSize);
            }

            return hiddenSizePerBlock;
        }

        /// <summary>
        /// Puts the embedding layer norm (if any) into eval mode and asks every layer to do the same
        /// for its layer norms.
        /// </summary>
        public void CloseLayerNormTraining()
        {
            EmbeddingLayerNorm?.eval();
            foreach (var layer in Layers)
            {
                layer.CloseLayerNormTraining();
            }
        }

        /// <summary>
        /// Disposes the owned modules. The optional embedding modules may be null, depending on the
        /// constructor arguments, and are skipped in that case.
        /// </summary>
        protected override void Dispose(bool disposing)
        {
            if (!_disposedValue)
            {
                if (disposing)
                {
                    TokenEmbedding.Dispose();
                    // Null when usePositionEmbedding was false.
                    PositionalEmbedding?.Dispose();
                    // Null when numSegments was not positive.
                    SegmentEmbedding?.Dispose();
                    // Null when encoderNormalizeBefore was false (the default).
                    EmbeddingLayerNorm?.Dispose();
                    EmbedTransfer.Dispose();
                    DropoutLayer.Dispose();
                    Layers.Dispose();
                    HiddenTransferList.Dispose();
                    _disposedValue = true;
                }
            }

            base.Dispose(disposing);
        }
    }
}