// File: NasBert\Modules\TransformerCell.cs
// Project: src\Microsoft.ML.TorchSharp\Microsoft.ML.TorchSharp.csproj (Microsoft.ML.TorchSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML.TorchSharp.NasBert.Modules.Layers;
using TorchSharp;
using TorchSharp.Modules;
 
namespace Microsoft.ML.TorchSharp.NasBert.Modules
{
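    /// <summary>
    /// Base class for a single Transformer encoder cell; <see cref="forward"/> runs
    /// the cell's operation on the input sequence.
    /// </summary>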
    internal abstract class TransformerCell : torch.nn.Module
    {
        protected TransformerCell(string name) : base(name) { }
 
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
        public abstract torch.Tensor forward(torch.Tensor x, torch.Tensor selfAttentionMask,
            torch.Tensor selfAttentionPaddingMask, int arch = 0, bool layerNormTraining = false);
 
        public virtual void CloseLayerNormTraining() { }
    }
 
    /// <summary>
    /// Implements a Transformer encoder layer as used in BERT/XLM-style pre-trained models.
    /// A non-discrete cell holds every candidate operation in the search space and is used
    /// during neural architecture search (NAS).
    /// </summary>
    internal sealed class TransformerCellNonDiscrete : TransformerCell
    {
        private readonly ActivationFunction _activationFn;
 
        /// <summary>
        /// The candidate operations in the search space.
        /// </summary>
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:not in _camelCase format", Justification = "Need to match TorchSharp.")]
        private readonly ModuleList<Layer> Operations;
        private bool _disposedValue;
 
        public TransformerCellNonDiscrete(
            float dropout = 0.1f,
            float attentionDropout = 0.1f,
            float activationDropout = 0.1f,
            string activationFn = "relu",
            bool addBiasKv = false,
            bool addZeroAttention = false,
            bool dynamicDropout = false)
            : base(nameof(TransformerCellNonDiscrete))
        {
            _activationFn = new ActivationFunction(activationFn);
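            // Instantiate one candidate layer per choice in the search space;
            // forward selects among them by the `arch` index.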
            var operations = Enumerable.Range(0, SearchSpace.NumLayerChoices)
                .Select(i => SearchSpace.GetLayer(i, dropout, attentionDropout, activationDropout, activationFn,
                    addBiasKv, addZeroAttention, dynamicDropout))
                .ToArray();
            Operations = new ModuleList<Layer>(operations);
 
            RegisterComponents();
        }
 
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
        public override torch.Tensor forward(torch.Tensor x, torch.Tensor selfAttentionMask,
            torch.Tensor selfAttentionPaddingMask, int arch = 0, bool layerNormTraining = false)
        {
            return Operations[arch].forward(x, new Dictionary<string, object>
            {
                {Layer.AttentionMaskKey, selfAttentionMask},
                {Layer.PaddingMaskKey, selfAttentionPaddingMask},
            });
        }
 
        public override void CloseLayerNormTraining()
        {
            foreach (var operation in Operations)
            {
                operation.CloseLayerNormTraining();
            }
        }
 
        protected override void Dispose(bool disposing)
        {
            if (!_disposedValue)
            {
                if (disposing)
                {
                    _activationFn.Dispose();
                    Operations.Dispose();
                }

                _disposedValue = true;
            }
 
            base.Dispose(disposing);
        }
    }
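
    // Usage sketch (illustrative only; `input`, `attnMask`, and `padMask` are
    // assumed to be tensors prepared by the caller). During NAS search, a trainer
    // can probe each candidate operation by varying `arch`:
    //
    //     var cell = new TransformerCellNonDiscrete(dropout: 0.1f);
    //     for (var arch = 0; arch < SearchSpace.NumLayerChoices; ++arch)
    //     {
    //         using var output = cell.forward(input, attnMask, padMask, arch);
    //         // ... evaluate this candidate architecture ...
    //     }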
 
    /// <summary>
    /// Implements a Transformer encoder layer as used in BERT/XLM-style pre-trained models.
    /// A discrete cell instantiates a single operation and is used once the neural
    /// architecture is fixed.
    /// </summary>
    internal sealed class TransformerCellDiscrete : TransformerCell
    {
        private readonly ActivationFunction _activationFn;
 
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:not in _camelCase format", Justification = "Need to match TorchSharp.")]
        private readonly Layer Operation;
        private bool _disposedValue;
 
        public TransformerCellDiscrete(
            int arch,
            double dropout = 0.1,
            double attentionDropout = 0.1,
            double activationDropout = 0.1,
            string activationFn = "relu",
            bool addBiasKv = false,
            bool addZeroAttention = false,
            bool dynamicDropout = false)
            : base(nameof(TransformerCellDiscrete))
        {
            _activationFn = new ActivationFunction(activationFn);
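            // A discrete cell commits to a single operation, selected by `arch`,
            // at construction time.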
            Operation = SearchSpace.GetLayer(arch, dropout, attentionDropout, activationDropout, activationFn, addBiasKv,
                addZeroAttention, dynamicDropout);
 
            RegisterComponents();
        }
 
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
        public override torch.Tensor forward(torch.Tensor x, torch.Tensor selfAttentionMask,
            torch.Tensor selfAttentionPaddingMask, int arch = 0, bool layerNormTraining = false)
        {
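            // `arch` and `layerNormTraining` are ignored here: the operation was
            // fixed when the cell was constructed.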
            return Operation.forward(x, new Dictionary<string, object>
            {
                {Layer.AttentionMaskKey, selfAttentionMask},
                {Layer.PaddingMaskKey, selfAttentionPaddingMask},
            });
        }
 
        public override void CloseLayerNormTraining() => Operation.CloseLayerNormTraining();
 
        protected override void Dispose(bool disposing)
        {
            if (!_disposedValue)
            {
                if (disposing)
                {
                    _activationFn.Dispose();
                    Operation.Dispose();
                }

                _disposedValue = true;
            }
 
            base.Dispose(disposing);
        }
    }
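
    // Usage sketch (illustrative only; tensors as above, and `3` is just an
    // example index into the search space). A discrete cell fixes its operation
    // up front, so callers do not pass `arch` to forward:
    //
    //     var cell = new TransformerCellDiscrete(arch: 3);
    //     using var output = cell.forward(input, attnMask, padMask);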
}