File: AutoFormerV2\Attention.cs
Project: src\src\Microsoft.ML.TorchSharp\Microsoft.ML.TorchSharp.csproj (Microsoft.ML.TorchSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
 
namespace Microsoft.ML.TorchSharp.AutoFormerV2
{
    /// <summary>
    /// The Attention layer.
    /// </summary>
    public class Attention : Module<Tensor, Tensor, Tensor>
    {
#pragma warning disable MSML_PrivateFieldName // Need to match TorchSharp model names.
        private readonly int numHeads;
        private readonly double scale;
        private readonly int keyChannels;
        private readonly int nHkD;   // total query/key channels across all heads (numHeads * keyChannels)
        private readonly int d;      // value channels per head (attnRatio * keyChannels)
        private readonly int dh;     // total value channels across all heads (d * numHeads)
        private readonly double attnRatio;
 
        private readonly LayerNorm norm;
        private readonly Linear qkv;
        private readonly Linear proj;
        private readonly Parameter attention_biases;
        private readonly TensorIndex attention_bias_idxs;
        private readonly Softmax softmax;
        private bool _disposedValue;
#pragma warning restore MSML_PrivateFieldName
 
 
        /// <summary>
        /// Initializes a new instance of the <see cref="Attention"/> class.
        /// </summary>
        /// <param name="inChannels">The input channels.</param>
        /// <param name="keyChannels">The key channels.</param>
        /// <param name="numHeads">The number of blocks.</param>
        /// <param name="attnRatio">The ratio of attention.</param>
        /// <param name="windowResolution">The resolution of window.</param>
        public Attention(int inChannels, int keyChannels, int numHeads = 8, int attnRatio = 4, List<int> windowResolution = null)
            : base(nameof(Attention))
        {
            windowResolution ??= new List<int>() { 14, 14 };
            this.numHeads = numHeads;
            this.scale = System.Math.Pow(keyChannels, -0.5);
            this.keyChannels = keyChannels;
            this.nHkD = numHeads * keyChannels;
            this.d = attnRatio * keyChannels;
            this.dh = this.d * numHeads;
            this.attnRatio = attnRatio;
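            // Width of the combined qkv projection: dh value channels plus nHkD channels each for queries and keys.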
            int h = this.dh + (this.nHkD * 2);
 
            this.norm = nn.LayerNorm(new long[] { inChannels });
            this.qkv = nn.Linear(inChannels, h);
            this.proj = nn.Linear(this.dh, inChannels);
 
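            // Enumerate every (row, column) coordinate inside the attention window.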
            var points = new List<List<int>>();
            for (int i = 0; i < windowResolution[0]; i++)
            {
                for (int j = 0; j < windowResolution[1]; j++)
                {
                    points.Add(new List<int>() { i, j });
                }
            }
 
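            // Assign a shared bias index to every pair of coordinates with the same absolute offset,
            // so token pairs at equal relative positions reuse the same learnable bias.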
            int n = points.Count;
            var attentionOffsets = new Dictionary<Tuple<int, int>, int>();
            var idxs = new List<int>();
            var idxsTensor = torch.zeros(new long[] { n, n }, dtype: torch.int64);
            for (int i = 0; i < n; i++)
            {
                for (int j = 0; j < n; j++)
                {
                    var offset = new Tuple<int, int>(Math.Abs(points[i][0] - points[j][0]), Math.Abs(points[i][1] - points[j][1]));
                    if (!attentionOffsets.ContainsKey(offset))
                    {
                        attentionOffsets.Add(offset, attentionOffsets.Count);
                    }
 
                    idxs.Add(attentionOffsets[offset]);
                    idxsTensor[i][j] = attentionOffsets[offset];
                }
            }
 
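            // One learnable bias per head and per unique relative offset; gathered per token pair in forward.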
            this.attention_biases = nn.Parameter(torch.zeros(numHeads, attentionOffsets.Count));
            this.attention_bias_idxs = TensorIndex.Tensor(idxsTensor);
            this.softmax = nn.Softmax(dim: -1);
        }
 
        /// <inheritdoc/>
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
        public override Tensor forward(Tensor x, Tensor mask)
        {
            using (var scope = torch.NewDisposeScope())
            {
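                // x has shape (batch, tokens, channels); tokens are the flattened positions inside one window.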
                long b = x.shape[0];
                long n = x.shape[1];
                long c = x.shape[2];
                x = this.norm.forward(x);
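                // Project the normalized input to concatenated query/key/value channels and split per head.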
                var qkv = this.qkv.forward(x);
                qkv = qkv.view(b, n, this.numHeads, -1);
                var tmp = qkv.split(new long[] { this.keyChannels, this.keyChannels, this.d }, dim: 3);
                var q = tmp[0];
                var k = tmp[1];
                var v = tmp[2];
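                // Reorder to (batch, heads, tokens, channelsPerHead) for the per-head matrix products.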
                q = q.permute(0, 2, 1, 3);
                k = k.permute(0, 2, 1, 3);
                v = v.permute(0, 2, 1, 3);
 
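                // Scaled dot-product scores plus the learned relative-position bias, indexed per token pair.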
                var attn = (torch.matmul(q, k.transpose(-2, -1)) * this.scale) + this.attention_biases[RangeUtil.ToTensorIndex(..), this.attention_bias_idxs];
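                // When a per-window attention mask is supplied, broadcast it over batches and heads before the softmax.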
                if (!(mask is null))
                {
                    long nW = mask.shape[0];
                    attn = attn.view(-1, nW, this.numHeads, n, n) + mask.unsqueeze(1).unsqueeze(0);
                    attn = attn.view(-1, this.numHeads, n, n);
                    attn = this.softmax.forward(attn);
                }
                else
                {
                    attn = this.softmax.forward(attn);
                }
 
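                // Weighted sum over values, merge the heads back into the channel dimension, then project to the input width.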
                x = torch.matmul(attn, v).transpose(1, 2).reshape(b, n, this.dh);
                x = this.proj.forward(x);
 
                return x.MoveToOuterDisposeScope();
            }
        }
 
        protected override void Dispose(bool disposing)
        {
            if (!_disposedValue)
            {
                if (disposing)
                {
                    norm.Dispose();
                    qkv.Dispose();
                    proj.Dispose();
                    attention_biases.Dispose();
                    softmax.Dispose();
                    _disposedValue = true;
                }
            }
 
            base.Dispose(disposing);
        }
    }
}