File: AutoFormerV2\AutoFormerV2Block.cs
Web Access
Project: src\src\Microsoft.ML.TorchSharp\Microsoft.ML.TorchSharp.csproj (Microsoft.ML.TorchSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using TorchSharp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
 
namespace Microsoft.ML.TorchSharp.AutoFormerV2
{
    /// <summary>
    /// The block module of AutoFormer network.
    /// </summary>
    /// <remarks>
    /// The block applies, in order:
    /// <list type="number">
    /// <item>(optionally shifted) window attention with a residual connection;</item>
    /// <item>a local convolution whose group count equals the channel count;</item>
    /// <item>an MLP with a residual connection.</item>
    /// </list>
    /// </remarks>
    public class AutoFormerV2Block : Module<Tensor, int, int, Tensor, Tensor>
    {
#pragma warning disable MSML_PrivateFieldName // Need to match TorchSharp model names.
        private readonly int windowSize;
        private readonly int shiftSize;
        private readonly bool useShiftWindow;
        private readonly bool useInterpolate;
        private readonly Attention attn;
        private readonly MLP mlp;
        private readonly Conv2dBN local_conv;
        private bool _disposedValue;
#pragma warning restore MSML_PrivateFieldName

        /// <summary>
        /// Initializes a new instance of the <see cref="AutoFormerV2Block"/> class.
        /// </summary>
        /// <param name="inChannels">The input channels.</param>
        /// <param name="numHeads">The number of attention heads; each head gets <c>inChannels / numHeads</c> channels.</param>
        /// <param name="windowSize">The size of the (square) attention window.</param>
        /// <param name="shiftSize">The size of shift. Only used when <paramref name="useShiftWindow"/> is true; otherwise treated as 0.</param>
        /// <param name="mlpRatio">The ratio of the MLP hidden size to <paramref name="inChannels"/>.</param>
        /// <param name="dropRatio">The drop ratio passed to the MLP.</param>
        /// <param name="localConvSize">The kernel size of the local convolution.</param>
        /// <param name="useShiftWindow">Whether to use shifted windows.</param>
        /// <param name="useInterpolate">
        /// Whether padded attention output is resized back to the input size with bilinear interpolation
        /// (true) or cropped (false).
        /// </param>
        public AutoFormerV2Block(int inChannels, int numHeads, int windowSize = 7, int shiftSize = 0, double mlpRatio = 4.0, double dropRatio = 0, int localConvSize = 3, bool useShiftWindow = false, bool useInterpolate = false)
            : base(nameof(AutoFormerV2Block))
        {
            this.windowSize = windowSize;

            // A shift only has meaning when shifted windows are enabled.
            this.shiftSize = useShiftWindow ? shiftSize : 0;
            this.useShiftWindow = useShiftWindow;
            this.useInterpolate = useInterpolate;

            int headChannels = inChannels / numHeads;
            List<int> windowResolution = new List<int>() { windowSize, windowSize };
            this.attn = new Attention(inChannels, headChannels, numHeads, attnRatio: 1, windowResolution: windowResolution);

            int mlpHiddenChannels = (int)(inChannels * mlpRatio);
            this.mlp = new MLP(inFeatures: inChannels, hiddenFeatures: mlpHiddenChannels, dropRatio: dropRatio);

            // kernel / 2 padding with stride 1 keeps the spatial size for odd kernels;
            // groups == inChannels makes the convolution per-channel (depth-wise).
            int padding = localConvSize / 2;
            this.local_conv = new Conv2dBN(inChannels, inChannels, kernalSize: localConvSize, stride: 1, padding: padding, groups: inChannels);
        }

        /// <summary>
        /// Runs the block on a flattened feature map.
        /// </summary>
        /// <param name="x">The input tokens, read as (batch, h * w, channels).</param>
        /// <param name="h">The height of the feature map encoded in <paramref name="x"/>.</param>
        /// <param name="w">The width of the feature map encoded in <paramref name="x"/>.</param>
        /// <param name="maskMatrix">
        /// The attention mask for shifted windows. It is only used when shifted windows are active
        /// (its expected shape is defined by the caller / <c>Attention</c> — NOTE(review): confirm).
        /// </param>
        /// <returns>The output tokens, laid out like the input (batch, h * w, channels), assuming the MLP preserves its input shape.</returns>
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
        public override Tensor forward(Tensor x, int h, int w, Tensor maskMatrix)
        {
            using (var scope = torch.NewDisposeScope())
            {
                long b = x.shape[0];
                long l = x.shape[1];
                long c = x.shape[2];
                var resX = x;
                x = x.view(b, h, w, c);

                // Pad bottom/right so both spatial dims become multiples of the window size.
                int padB = (this.windowSize - (h % this.windowSize)) % this.windowSize;
                int padR = (this.windowSize - (w % this.windowSize)) % this.windowSize;
                bool padding = padB > 0 || padR > 0;
                int pH = h + padB;
                int pW = w + padR;
                if (padding)
                {
                    // Pad pairs apply from the last dim backwards: C unpadded, W by padR, H by padB.
                    x = nn.functional.pad(x, new long[] { 0, 0, 0, padR, 0, padB });
                }

                bool shifted = this.useShiftWindow && this.shiftSize > 0;

                // Cyclically shift the map up/left before partitioning; the mask is only needed then.
                Tensor shiftedX = shifted
                    ? torch.roll(x, shifts: new long[] { -this.shiftSize, -this.shiftSize }, dims: new long[] { 1, 2 })
                    : x;
                Tensor attnMask = shifted ? maskMatrix : null;

                var xWindows = WindowPartition(shiftedX, this.windowSize);
                xWindows = xWindows.view(-1, this.windowSize * this.windowSize, c);
                var attnWindows = this.attn.forward(xWindows, mask: attnMask);

                attnWindows = attnWindows.view(-1, this.windowSize, this.windowSize, c);
                shiftedX = WindowsReverse(attnWindows, this.windowSize, pH, pW);

                // Undo the cyclic shift applied above.
                x = shifted
                    ? torch.roll(shiftedX, shifts: new long[] { this.shiftSize, this.shiftSize }, dims: new long[] { 1, 2 })
                    : shiftedX;

                if (padding)
                {
                    if (this.useInterpolate)
                    {
                        // Resize (pH, pW) back to (h, w); interpolate expects channels-first layout.
                        x = nn.functional.interpolate(x.permute(0, 3, 1, 2), size: new long[] { h, w }, mode: torch.InterpolationMode.Bilinear, align_corners: true).permute(0, 2, 3, 1);
                    }
                    else
                    {
                        // Crop away the padded rows/columns.
                        x = x[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..h), RangeUtil.ToTensorIndex(..w)].contiguous();
                    }
                }

                x = x.view(b, l, c);

                // Attention residual, then local convolution in (B, C, H, W) layout, then MLP residual.
                x = resX + x;
                x = x.transpose(1, 2).reshape(b, c, h, w);
                x = this.local_conv.forward(x);
                x = x.view(b, c, l).transpose(1, 2);
                x = x + this.mlp.forward(x);

                return x.MoveToOuterDisposeScope();
            }
        }

        /// <summary>
        /// Reverse input in window size to original shape.
        /// </summary>
        /// <param name="windows">The input window tensor, viewed as (numWindows * batch, windowSize, windowSize, channels).</param>
        /// <param name="windowSize">The size of window.</param>
        /// <param name="h">The height; expected to be a multiple of <paramref name="windowSize"/>.</param>
        /// <param name="w">The width; expected to be a multiple of <paramref name="windowSize"/>.</param>
        /// <returns>The reassembled tensor of shape (batch, h, w, channels).</returns>
        private static Tensor WindowsReverse(Tensor windows, int windowSize, int h, int w)
        {
            using (var scope = torch.NewDisposeScope())
            {
                // Compute in long so the tensor's long dimension is never narrowed to int.
                long windowsPerImage = (long)h * w / windowSize / windowSize;
                long b = windows.shape[0] / windowsPerImage;
                var x = windows.view(b, h / windowSize, w / windowSize, windowSize, windowSize, -1);
                x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1);

                return x.MoveToOuterDisposeScope();
            }
        }

        /// <summary>
        /// Partition input to window size.
        /// </summary>
        /// <param name="x">The input tensor of shape (batch, h, w, channels); h and w must be multiples of <paramref name="windowSize"/>.</param>
        /// <param name="windowSize">The size of window.</param>
        /// <returns>The windows, shaped (batch * numWindows, windowSize, windowSize, channels).</returns>
        private static Tensor WindowPartition(Tensor x, int windowSize)
        {
            using (var scope = torch.NewDisposeScope())
            {
                long b = x.shape[0];
                long h = x.shape[1];
                long w = x.shape[2];
                long c = x.shape[3];
                x = x.view(b, h / windowSize, windowSize, w / windowSize, windowSize, c);
                var windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, windowSize, windowSize, c);

                return windows.MoveToOuterDisposeScope();
            }
        }

        /// <summary>
        /// Releases the child modules when disposing, then defers to the base implementation.
        /// </summary>
        /// <param name="disposing">True when called from <c>Dispose()</c>; false when called from a finalizer.</param>
        protected override void Dispose(bool disposing)
        {
            if (!_disposedValue)
            {
                if (disposing)
                {
                    attn.Dispose();
                    mlp.Dispose();
                    local_conv.Dispose();
                }

                // Mark disposed regardless of the path, per the standard Dispose pattern.
                _disposedValue = true;
            }

            base.Dispose(disposing);
        }
    }
}