File: NasBert\Modules\MultiHeadAttention.cs
Project: src\Microsoft.ML.TorchSharp\Microsoft.ML.TorchSharp.csproj (Microsoft.ML.TorchSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using Microsoft.ML.TorchSharp.Extensions;
using Microsoft.ML.TorchSharp.Utils;
using TorchSharp;
using TorchSharp.Modules;
 
namespace Microsoft.ML.TorchSharp.NasBert.Modules
{
    internal sealed class MultiHeadAttention : torch.nn.Module, IIncrementalState
    {
        private const string PrevKeyKey = "prevKey";
        private const string PrevValueKey = "prevValue";
        private const string AttentionStateKey = "attentionState";
 
        private readonly int _embeddingDim;
        private readonly int _kDim;
        private readonly int _vDim;
        private readonly bool _qkvSameDim;
        private readonly bool _addBiasProj;
        private readonly bool _addBiasKv;
 
        private readonly int _numHeads;
        private readonly double _dropout;
        private readonly int _headDim;
        private readonly double _scaling;
 
        private readonly bool _selfAttention;
        private readonly bool _encoderDecoderAttention;
        private readonly bool _addZeroAttention;
 
 
#pragma warning disable MSML_PrivateFieldName // Private field name not in _camelCase format. Has to match TorchSharp.
        private readonly Linear QProjection;
        private readonly Linear KProjection;
        private readonly Linear VProjection;
 
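        // Optional learned bias tensors appended to the key and value sequences when addBiasKv is set.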
        private readonly Parameter KBias;
        private readonly Parameter VBias;
 
        private readonly Linear OutProjLinear;
        private readonly Dropout DropoutLayer;
#pragma warning restore MSML_PrivateFieldName // Private field name not in _camelCase format.
 
 
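        /// <summary>
        /// Builds the query/key/value projections, the output projection and the dropout layer.
        /// <paramref name="kDim"/> and <paramref name="vDim"/> default to <paramref name="embeddingDim"/> when not supplied.
        /// </summary>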
        public MultiHeadAttention(
            int embeddingDim,
            int numHeads,
            int? kDim = null,
            int? vDim = null,
            double dropout = 0.0,
            bool bias = true,
            bool addBiasKv = false,
            bool addZeroAttention = false,
            bool selfAttention = false,
            bool encoderDecoderAttention = false)
            : base(nameof(MultiHeadAttention))
        {
            _embeddingDim = embeddingDim;
            _kDim = kDim ?? embeddingDim;
            _vDim = vDim ?? embeddingDim;
            _qkvSameDim = (_kDim == _embeddingDim) && (_vDim == _embeddingDim);
 
            _numHeads = numHeads;
            _dropout = dropout;
            _headDim = _embeddingDim / _numHeads;
            _scaling = Math.Pow(_headDim, -0.5);
            if (_headDim * _numHeads != _embeddingDim)
            {
                throw new ArgumentException("EmbeddingDim must be divisible by NumHeads");
            }
 
            _selfAttention = selfAttention;
            _encoderDecoderAttention = encoderDecoderAttention;
            if (_selfAttention && !_qkvSameDim)
            {
                throw new ArgumentException("Self-attention requires query, key and value to be of the same size");
            }
 
            _addBiasProj = bias;
            _addBiasKv = addBiasKv;
            _addZeroAttention = addZeroAttention;
 
            QProjection = torch.nn.Linear(_embeddingDim, _embeddingDim, _addBiasProj);
            KProjection = torch.nn.Linear(_kDim, _embeddingDim, _addBiasProj);
            VProjection = torch.nn.Linear(_vDim, _embeddingDim, _addBiasProj);
 
            if (_addBiasKv)
            {
                KBias = torch.zeros(1, 1, _embeddingDim).AsParameter();
                VBias = torch.zeros(1, 1, _embeddingDim).AsParameter();
            }
 
            OutProjLinear = torch.nn.Linear(_embeddingDim, _embeddingDim, _addBiasProj);
            DropoutLayer = torch.nn.Dropout(_dropout);
 
            Initialize();
            RegisterComponents();
        }
 
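        /// <summary>
        /// Initializes the projection weights with Xavier-uniform values (scaled by 1/sqrt(2) when query,
        /// key and value share the embedding dimension), zeroes the projection biases and resets the incremental state.
        /// </summary>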
        public void Initialize()
        {
            if (_qkvSameDim)
            {
                ModelUtils.InitXavierUniform(QProjection.weight, 1.0 / Math.Sqrt(2.0));
                ModelUtils.InitXavierUniform(KProjection.weight, 1.0 / Math.Sqrt(2.0));
                ModelUtils.InitXavierUniform(VProjection.weight, 1.0 / Math.Sqrt(2.0));
            }
            else
            {
                ModelUtils.InitXavierUniform(QProjection.weight);
                ModelUtils.InitXavierUniform(KProjection.weight);
                ModelUtils.InitXavierUniform(VProjection.weight);
            }
 
            ModelUtils.InitXavierUniform(OutProjLinear.weight);
 
            if (_addBiasProj)
            {
                ModelUtils.InitConstant(QProjection.bias, 0);
                ModelUtils.InitConstant(KProjection.bias, 0);
                ModelUtils.InitConstant(VProjection.bias, 0);
                ModelUtils.InitConstant(OutProjLinear.bias, 0);
            }
 
            if (_addBiasKv)
            {
                ModelUtils.InitXavierUniform(KBias);
                ModelUtils.InitXavierUniform(VBias);
            }
 
            InitIncrementalState();
        }
 
        /// <summary>
        /// Input shape: seqLen x batch x channel
        /// Time-steps can be masked by supplying a T x T mask in the <paramref name="attentionMask"/> argument.
        /// Padding elements can be excluded from the key by passing a binary ByteTensor (<paramref name="keyPaddingMask"/>)
        ///     with shape batch x srcLen, where padding elements are indicated by 1s.
        /// </summary>
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
        public torch.Tensor forward(
            torch.Tensor query,
            torch.Tensor key,
            torch.Tensor value,
            out torch.Tensor outAttentionWeights,
            torch.Tensor keyPaddingMask = null,
            Dictionary<string, Dictionary<string, torch.Tensor>> incrementalState = null,
            bool needWeights = true,
            bool staticKv = false,
            torch.Tensor attentionMask = null)
        {
            outAttentionWeights = null;
 
            if (query.IsNull() || query.size().Length != 3 || query.size(2) != _embeddingDim)
            {
                throw new ArgumentException("query must NOT be null and must be 3D in multi-head attention;" +
                                            "the last dimension should be the same as embedding dimension.");
            }
 
            using var disposeScope = torch.NewDisposeScope();
 
            var qSize = query.size();
            var tgtLen = qSize[0];
            var batchSize = qSize[1];
            var embedDim = qSize[2];
 
            // Get saved state from incrementalState
            Dictionary<string, torch.Tensor> savedState = null;
            if (incrementalState != null)
            {
                savedState = GetInputBuffer(incrementalState);
 
                // previous time steps are cached - no need to recompute key and value if they are static.
                if (savedState.ContainsKey(PrevKeyKey) && savedState.ContainsKey(PrevValueKey) && staticKv)
                {
                    if (_selfAttention || !_encoderDecoderAttention)
                    {
                        throw new ArgumentException(
                            "prevKey and prevValue are only valid in encoder-decoder attention.");
                    }
 
                    key = value = null;
                }
            }
 
            // Calculate current qkv projection
            var (q, k, v) = QkvProjection(query, key, value);
 
            // Alias the masks so the dispose scope can clean up the local copies without touching the caller's tensors.
            torch.Tensor attentionMaskPad = attentionMask?.alias();
            torch.Tensor keyPaddingMaskPad = keyPaddingMask?.alias();
            q.mul_(_scaling);
 
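            // Append the learned key/value bias as one extra source position and extend the masks to cover it.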
            if (_addBiasKv)
            {
                var kRepeat = KBias.repeat(1, batchSize, 1);
                var vRepeat = VBias.repeat(1, batchSize, 1);
                k = torch.cat(new List<torch.Tensor> { k, kRepeat }, dim: 0);
                v = torch.cat(new List<torch.Tensor> { v, vRepeat }, dim: 0);
                attentionMaskPad = PadMask(attentionMaskPad);
                keyPaddingMaskPad = PadMask(keyPaddingMaskPad);
            }
 
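            // Fold the heads into the batch dimension: (seqLen, batch, embed) -> (batch * numHeads, seqLen, headDim).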
            q = q.view(tgtLen, batchSize * _numHeads, _headDim).transpose_(0, 1);
            k = k?.view(-1, batchSize * _numHeads, _headDim).transpose_(0, 1);
            v = v?.view(-1, batchSize * _numHeads, _headDim).transpose_(0, 1);
 
            if (savedState != null)
            {
                // saved states are stored with shape (batchSize, NumHeads, seqLen, HeadDim)
                if (savedState.ContainsKey(PrevKeyKey))
                {
                    var prevKey = savedState[PrevKeyKey].view(batchSize * _numHeads, -1, _headDim);
                    k = staticKv
                        ? prevKey
                        : torch.cat(new List<torch.Tensor> { prevKey, k }, dim: 1);
                }
 
                if (savedState.ContainsKey(PrevValueKey))
                {
                    var prevValue = savedState[PrevValueKey].view(batchSize * _numHeads, -1, _headDim);
                    v = staticKv
                        ? prevValue
                        : torch.cat(new List<torch.Tensor> { prevValue, v }, dim: 1);
                }
 
                savedState[PrevKeyKey].Dispose();
                savedState[PrevKeyKey] = k?.view(batchSize, _numHeads, -1, _headDim);
                savedState[PrevValueKey].Dispose();
                savedState[PrevValueKey] = v?.view(batchSize, _numHeads, -1, _headDim);
 
                SetInputBuffer(incrementalState, savedState);
            }
 
            Debug.Assert(k.IsNotNull() && v.IsNotNull());
            var srcLen = k!.size(1);
 
            // This is part of a workaround to get around fork/join parallelism not supporting Optional types.
            if (keyPaddingMaskPad?.shape.Length == 0) keyPaddingMaskPad = null;
            Debug.Assert(keyPaddingMaskPad.IsNull() ||
                            (keyPaddingMaskPad.size(0) == batchSize && keyPaddingMaskPad.size(1) == srcLen));
 
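            // Optionally append an all-zero key/value position so the model can attend to "nothing".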
            if (_addZeroAttention)
            {
                srcLen += 1;
                var zeroPadSize = k.size();
                zeroPadSize[1] = 1;
                var kZeros = k.new_zeros(zeroPadSize);
                var vZeros = v!.new_zeros(zeroPadSize);
                k = torch.cat(new List<torch.Tensor> { k, kZeros }, dim: 1);
                v = torch.cat(new List<torch.Tensor> { v, vZeros }, dim: 1);
                attentionMaskPad = PadMask(attentionMaskPad);
                keyPaddingMaskPad = PadMask(keyPaddingMaskPad);
            }
 
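            // Scaled dot-product attention scores; q was already multiplied by 1/sqrt(headDim) above.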
            var attentionWeights = torch.matmul(q, k.transpose(1, 2));
            Debug.Assert(attentionWeights.size().SequenceEqual(new[] { batchSize * _numHeads, tgtLen, srcLen }));
 
            if (attentionMaskPad.IsNotNull())
            {
                attentionWeights.add_(attentionMaskPad.unsqueeze(0));
            }
 
            if (keyPaddingMaskPad.IsNotNull())
            {
                // Don't attend to pad symbols
                keyPaddingMaskPad = keyPaddingMaskPad.unsqueeze(1).unsqueeze(2);
 
                attentionWeights = attentionWeights
                    .view(batchSize, _numHeads, tgtLen, srcLen)
                    .masked_fill(keyPaddingMaskPad, float.NegativeInfinity)
                    .view(batchSize * _numHeads, tgtLen, srcLen);
            }
 
            attentionWeights = torch.nn.functional.softmax(attentionWeights, dim: -1);
            attentionWeights = DropoutLayer.forward(attentionWeights);
 
            if (needWeights)
            {
                // Average attention weights over heads
                var weightsView = attentionWeights.view(batchSize, _numHeads, tgtLen, srcLen);
                outAttentionWeights = weightsView.sum(dim: 1).div_(_numHeads);
            }
 
            var attention = torch.matmul(attentionWeights, v);
            Debug.Assert(attention.size().SequenceEqual(new[] { batchSize * _numHeads, tgtLen, _headDim }));
            attention = attention.transpose(0, 1).contiguous().view(tgtLen, batchSize, embedDim);
            var attentionOutput = OutProjLinear.forward(attention);
 
            outAttentionWeights?.MoveToOuterDisposeScope();
            return attentionOutput.MoveToOuterDisposeScope();
        }
 
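        /// <summary>
        /// Appends a column of zeros to a mask so it covers the extra key/value position added by
        /// the key/value bias or zero attention; returns null when the mask is null.
        /// </summary>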
        private static torch.Tensor PadMask(torch.Tensor tensor)
        {
            if (tensor.IsNull())
            {
                return null;
            }
 
            using var zeros = tensor.new_zeros(tensor.size(0), 1);
            return torch.cat(new List<torch.Tensor> { tensor, zeros }, dim: 1);
        }
 
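        /// <summary>
        /// Retrieves the cached key/value buffer for this module from the incremental state,
        /// or an empty dictionary when nothing has been cached yet.
        /// </summary>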
        private Dictionary<string, torch.Tensor> GetInputBuffer(
            Dictionary<string, Dictionary<string, torch.Tensor>> incrementalState)
        {
            return GetIncrementalState(this, incrementalState, AttentionStateKey) ?? new Dictionary<string, torch.Tensor>();
        }
 
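        /// <summary>
        /// Stores the cached key/value buffer for this module in the incremental state.
        /// </summary>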
        private void SetInputBuffer(
            Dictionary<string, Dictionary<string, torch.Tensor>> incrementalState,
            Dictionary<string, torch.Tensor> buffer)
        {
            SetIncrementalState(this, incrementalState, AttentionStateKey, buffer);
        }
 
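        /// <summary>
        /// Projects the inputs to query, key and value. Self-attention projects all three from
        /// <paramref name="query"/>; encoder-decoder attention projects both key and value from
        /// <paramref name="key"/> (or leaves them null when the key is null).
        /// </summary>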
        private (torch.Tensor, torch.Tensor, torch.Tensor) QkvProjection(
            torch.Tensor query, torch.Tensor key, torch.Tensor value)
        {
            using var disposeScope = torch.NewDisposeScope();
 
            torch.Tensor q = null;
            torch.Tensor k = null;
            torch.Tensor v = null;
            if (_selfAttention)
            {
                q = QProjection.forward(query);
                k = KProjection.forward(query);
                v = VProjection.forward(query);
            }
            else if (_encoderDecoderAttention)
            {
                q = QProjection.forward(query);
                if (key.IsNull())
                {
                    k = v = null;
                }
                else
                {
                    k = KProjection.forward(key);
                    v = VProjection.forward(key);
                }
            }
            else
            {
                q = QProjection.forward(query);
                k = KProjection.forward(key);
                v = VProjection.forward(value);
            }
 
            return (q.MoveToOuterDisposeScope(), k.MoveToOuterDisposeScope(), v.MoveToOuterDisposeScope());
        }
 
        #region Incremental State
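        // Incremental decoding state (cached keys/values) is delegated to a shared IncrementalState helper keyed by this module.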
        private readonly IIncrementalState _incrementalState = new IncrementalState();
        private bool _disposedValue;
 
        public void InitIncrementalState()
        {
            _incrementalState.InitIncrementalState();
        }
 
        public Dictionary<string, torch.Tensor> GetIncrementalState(torch.nn.Module module, Dictionary<string, Dictionary<string, torch.Tensor>> incrementalState, string key)
        {
            return _incrementalState.GetIncrementalState(module, incrementalState, key);
        }
 
        public void SetIncrementalState(torch.nn.Module module, Dictionary<string, Dictionary<string, torch.Tensor>> incrementalState, string key, Dictionary<string, torch.Tensor> value)
        {
            _incrementalState.SetIncrementalState(module, incrementalState, key, value);
        }
        #endregion
 
        protected override void Dispose(bool disposing)
        {
            if (!_disposedValue)
            {
                if (disposing)
                {
                    QProjection.Dispose();
                    KProjection.Dispose();
                    VProjection.Dispose();
                    KBias?.Dispose();
                    VBias?.Dispose();
                    OutProjLinear.Dispose();
                    DropoutLayer.Dispose();
                    _disposedValue = true;
                }
            }
 
            base.Dispose(disposing);
        }
    }
}