// File: AutoFormerV2\BasicLayer.cs
// Web Access
// Project: src\src\Microsoft.ML.TorchSharp\Microsoft.ML.TorchSharp.csproj (Microsoft.ML.TorchSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
 
namespace Microsoft.ML.TorchSharp.AutoFormerV2
{
    /// <summary>
    /// The basic layer: a stack of <see cref="AutoFormerV2Block"/> modules followed by a
    /// <see cref="PatchMerging"/> downsample step.
    /// </summary>
    public class BasicLayer : Module<Tensor, int, int, (Tensor, int, int, Tensor, int, int)>
    {
#pragma warning disable MSML_PrivateFieldName // Need to match TorchSharp model names.
        // Whether forward() builds a shifted-window attention mask for the blocks.
        private readonly bool useShiftWindow;

        // Edge length of one attention window.
        private readonly int windowSize;

        // Half of windowSize (integer division); used for odd-indexed blocks and for the mask regions.
        private readonly int shiftSize;

        // The blocks applied in order by forward().
        private readonly ModuleList<AutoFormerV2Block> blocks;

        // Applied after the blocks to produce the second tensor returned by forward().
        private readonly PatchMerging downsample;

        // Set once Dispose(bool) has run so that owned modules are not released twice.
        private bool _disposedValue;
#pragma warning restore MSML_PrivateFieldName

        /// <summary>
        /// Initializes a new instance of the <see cref="BasicLayer"/> class.
        /// </summary>
        /// <param name="inChannels">The input channels.</param>
        /// <param name="outChannels">The output channels.</param>
        /// <param name="depth">The number of blocks.</param>
        /// <param name="numHeads">The number of heads.</param>
        /// <param name="windowSize">The size of window.</param>
        /// <param name="mlpRatio">The ratio of MLP.</param>
        /// <param name="dropRatio">The ratio of drop.</param>
        /// <param name="localConvSize">The size of local convolution.</param>
        /// <param name="useShiftWindow">Whether use shift window.</param>
        /// <param name="useInterpolate">Whether use interpolation.</param>
        public BasicLayer(int inChannels, int outChannels, int depth, int numHeads, int windowSize, double mlpRatio = 4.0, double dropRatio = 0, int localConvSize = 3, bool useShiftWindow = false, bool useInterpolate = false)
            : base(nameof(BasicLayer))
        {
            this.useShiftWindow = useShiftWindow;
            this.windowSize = windowSize;
            this.shiftSize = windowSize / 2;

            this.blocks = new ModuleList<AutoFormerV2Block>();
            for (int i = 0; i < depth; i++)
            {
                // Even-indexed blocks are created without a shift, odd-indexed blocks with half a window.
                int blockShift = (i % 2 == 0) ? 0 : this.shiftSize;
                this.blocks.Add(new AutoFormerV2Block(
                    inChannels: inChannels,
                    numHeads: numHeads,
                    windowSize: windowSize,
                    shiftSize: blockShift,
                    mlpRatio: mlpRatio,
                    dropRatio: dropRatio,
                    localConvSize: localConvSize,
                    useShiftWindow: useShiftWindow,
                    useInterpolate: useInterpolate));
            }

            this.downsample = new PatchMerging(inChannels: inChannels, outChannels: outChannels);
        }

        /// <summary>
        /// Runs every block on <paramref name="x"/> and then the downsample module on the result.
        /// </summary>
        /// <param name="x">The input tensor.</param>
        /// <param name="h">The height forwarded to the blocks and used to size the attention mask.</param>
        /// <param name="w">The width forwarded to the blocks and used to size the attention mask.</param>
        /// <returns>
        /// The output of the last block together with <paramref name="h"/> and <paramref name="w"/>,
        /// followed by the downsample output and the two sizes the downsample module reports.
        /// </returns>
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
        public override (Tensor, int, int, Tensor, int, int) forward(Tensor x, int h, int w)
        {
            using (torch.NewDisposeScope())
            {
                // The blocks receive null when shifted windows are disabled.
                Tensor attnMask = this.useShiftWindow ? this.BuildAttentionMask(h, w, x.device) : null;

                Tensor features = x;
                foreach (AutoFormerV2Block block in this.blocks)
                {
                    features = block.forward(features, h, w, attnMask);
                }

                (Tensor merged, int nH, int nW) = this.downsample.forward(features, h, w);

                // Only the two returned tensors outlive this scope; every other intermediate is released here.
                return (features.MoveToOuterDisposeScope(), h, w, merged.MoveToOuterDisposeScope(), nH, nW);
            }
        }

        /// <summary>
        /// Builds the attention mask used when shifted windows are enabled.
        /// </summary>
        /// <param name="h">The height; rounded up to a multiple of the window size.</param>
        /// <param name="w">The width; rounded up to a multiple of the window size.</param>
        /// <param name="device">The device on which the mask is created.</param>
        /// <returns>
        /// A tensor holding 0 where two positions of a window carry the same region label and -100 otherwise.
        /// </returns>
        private Tensor BuildAttentionMask(int h, int w, torch.Device device)
        {
            using (torch.NewDisposeScope())
            {
                int hp = (int)Math.Ceiling((double)h / this.windowSize) * this.windowSize;
                int wp = (int)Math.Ceiling((double)w / this.windowSize) * this.windowSize;
                Tensor imgMask = torch.zeros(new long[] { 1, hp, wp, 1 }, device: device);

                // Consecutive entries delimit the three bands along each axis:
                // [0, p - windowSize), [p - windowSize, p - shiftSize) and [p - shiftSize, p).
                int[] hBounds = { 0, hp - this.windowSize, hp - this.shiftSize, hp };
                int[] wBounds = { 0, wp - this.windowSize, wp - this.shiftSize, wp };

                int cnt = 0;
                for (int i = 0; i + 1 < hBounds.Length; i++)
                {
                    int hStart = hBounds[i];
                    int hLength = hBounds[i + 1] - hStart;
                    for (int j = 0; j + 1 < wBounds.Length; j++)
                    {
                        int wStart = wBounds[j];
                        int wLength = wBounds[j + 1] - wStart;

                        // Label the whole region with one fill on a view instead of writing element by element.
                        // Empty regions are skipped but still consume a label.
                        if (hLength > 0 && wLength > 0)
                        {
                            imgMask.narrow(1, hStart, hLength).narrow(2, wStart, wLength).fill_(cnt);
                        }

                        cnt += 1;
                    }
                }

                // One row of region labels per window.
                Tensor maskWindows = WindowPartition(imgMask, this.windowSize)
                    .view(-1, this.windowSize * this.windowSize);

                // Pairwise label difference inside each window: zero means "same region".
                Tensor labelDiff = maskWindows.unsqueeze(1) - maskWindows.unsqueeze(2);
                Tensor attnMask = labelDiff.masked_fill(labelDiff != 0, -100.0).masked_fill(labelDiff == 0, 0.0);

                return attnMask.MoveToOuterDisposeScope();
            }
        }

        /// <summary>
        /// Partition input to window size.
        /// </summary>
        /// <param name="x">
        /// The input tensor; read as four-dimensional, with dimensions 1 and 2 split into windows.
        /// Both are assumed to be multiples of <paramref name="windowSize"/>.
        /// </param>
        /// <param name="windowSize">The window size.</param>
        /// <returns>The output tensor of shape (-1, windowSize, windowSize, c), where c is the last dimension of the input.</returns>
        private static Tensor WindowPartition(Tensor x, int windowSize)
        {
            using (torch.NewDisposeScope())
            {
                long b = x.shape[0];
                long h = x.shape[1];
                long w = x.shape[2];
                long c = x.shape[3];

                // Split both spatial axes into (window count, windowSize), then bring the two count axes together.
                Tensor tiled = x.view(b, h / windowSize, windowSize, w / windowSize, windowSize, c);
                Tensor windows = tiled.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, windowSize, windowSize, c);

                return windows.MoveToOuterDisposeScope();
            }
        }

        /// <summary>
        /// Releases the blocks and the downsample module owned by this layer.
        /// </summary>
        /// <param name="disposing">
        /// <see langword="true"/> to dispose the owned modules; <see langword="false"/> to skip them.
        /// </param>
        protected override void Dispose(bool disposing)
        {
            if (!_disposedValue)
            {
                if (disposing)
                {
                    blocks.Dispose();
                    downsample.Dispose();
                }

                // Mark as disposed on either path so the guard holds for any later call.
                _disposedValue = true;
            }

            base.Dispose(disposing);
        }
    }
}