// File: AutoFormerV2\FPN.cs
// Project: src\src\Microsoft.ML.TorchSharp\Microsoft.ML.TorchSharp.csproj (Microsoft.ML.TorchSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
 
namespace Microsoft.ML.TorchSharp.AutoFormerV2
{
    /// <summary>
    /// The FPN (Feature Pyramid Networks) layer.
    /// </summary>
    [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
    /// <summary>
    /// The FPN (Feature Pyramid Networks) layer. Merges the multi-scale feature maps
    /// produced by a backbone top-down into a pyramid of <c>numOuts</c> feature maps
    /// that all share the same channel count.
    /// </summary>
    [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
    public class FPN : Module<List<Tensor>, List<Tensor>>
    {
#pragma warning disable MSML_PrivateFieldName // Need to match TorchSharp model names.
        // 1x1 convolutions projecting each backbone level to the common channel count.
        private readonly ModuleList<Module<Tensor, Tensor>> lateral_convs;

        // 3x3 convolutions: one smoothing conv per backbone level, followed by
        // stride-2 convs that synthesize the extra (downsampled) pyramid levels.
        private readonly ModuleList<Module<Tensor, Tensor>> fpn_convs;

        private readonly int numOuts;
        private bool _disposedValue;
#pragma warning restore MSML_PrivateFieldName

        /// <summary>
        /// Initializes a new instance of the <see cref="FPN"/> class.
        /// </summary>
        /// <param name="inChannels">Channel count of each backbone level; defaults to { 192, 384, 576 }.</param>
        /// <param name="outChannel">Channel count shared by every output map.</param>
        /// <param name="numOuts">Total number of output tensors; levels beyond the backbone
        /// levels are produced by stride-2 convolutions.</param>
        public FPN(List<int> inChannels = null, int outChannel = 256, int numOuts = 5)
            : base(nameof(FPN))
        {
            inChannels ??= new List<int>() { 192, 384, 576 };
            this.numOuts = numOuts;
            int startLevel = 0;

            // Derive the backbone level count from the channel list (3 for the default)
            // instead of hard-coding 3, so a caller-supplied list of a different length
            // cannot cause an out-of-range read in the loop below.
            int backboneEndLevel = inChannels.Count;
            this.lateral_convs = new ModuleList<Module<Tensor, Tensor>>();
            this.fpn_convs = new ModuleList<Module<Tensor, Tensor>>();
            for (int i = startLevel; i < backboneEndLevel; i++)
            {
                this.lateral_convs.Add(new ConvModule(inChannels[i], outChannel, 1, useRelu: false));
                this.fpn_convs.Add(new ConvModule(outChannel, outChannel, 3, padding: 1, useRelu: false));
            }

            // Create exactly as many extra downsampling levels as needed to reach numOuts
            // (2 for the default arguments) instead of a fixed 2 that ignored numOuts.
            int extraLevel = numOuts > backboneEndLevel ? numOuts - backboneEndLevel : 0;
            for (int i = 0; i < extraLevel; i++)
            {
                // The first extra level downsamples the deepest backbone map; each further
                // extra level downsamples the previous one (already outChannel wide).
                int inChannel = i == 0 ? inChannels[backboneEndLevel - 1] : outChannel;
                this.fpn_convs.Add(new ConvModule(inChannel, outChannel, 3, stride: 2, padding: 1, useRelu: false));
            }
        }

        /// <inheritdoc/>
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
        public override List<Tensor> forward(List<Tensor> inputs)
        {
            using (var scope = torch.NewDisposeScope())
            {
                int usedBackboneLevels = this.lateral_convs.Count;

                // Project every backbone map to the common channel count.
                var laterals = new List<Tensor>();
                for (int i = 0; i < usedBackboneLevels; i++)
                {
                    laterals.Add(this.lateral_convs[i].forward(inputs[i]));
                }

                // Top-down pathway: upsample each deeper level to the spatial size of the
                // level above it (dims 2 and 3) and accumulate.
                for (int i = usedBackboneLevels - 1; i > 0; i--)
                {
                    var prevShape = new long[] { laterals[i - 1].shape[2], laterals[i - 1].shape[3] };
                    laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(laterals[i], prevShape);
                }

                // Smooth each merged level with its 3x3 convolution.
                var outs = new List<Tensor>();
                for (int i = 0; i < usedBackboneLevels; i++)
                {
                    outs.Add(this.fpn_convs[i].forward(laterals[i]));
                }

                // Synthesize the extra pyramid levels only when more outputs than backbone
                // levels were requested. Previously this first Add was unconditional, so
                // numOuts <= usedBackboneLevels yielded numOuts + 1 outputs, breaking the
                // numOuts contract.
                if (this.numOuts > usedBackboneLevels)
                {
                    var extraSource = inputs[usedBackboneLevels - 1];
                    outs.Add(this.fpn_convs[usedBackboneLevels].forward(extraSource));
                    for (int i = usedBackboneLevels + 1; i < this.numOuts; i++)
                    {
                        outs.Add(this.fpn_convs[i].forward(outs[outs.Count - 1]));
                    }
                }

                // Hand the results to the caller's scope so they survive this scope's disposal.
                for (int i = 0; i < outs.Count; i++)
                {
                    outs[i] = outs[i].MoveToOuterDisposeScope();
                }

                return outs;
            }
        }

        /// <summary>
        /// Releases the convolution modules owned by this layer.
        /// </summary>
        /// <param name="disposing">True when called from Dispose; false when called from a finalizer.</param>
        protected override void Dispose(bool disposing)
        {
            if (!_disposedValue)
            {
                if (disposing)
                {
                    lateral_convs.Dispose();
                    fpn_convs.Dispose();
                }

                // Set the flag on both paths (standard dispose pattern) so teardown
                // runs at most once.
                _disposedValue = true;
            }

            base.Dispose(disposing);
        }
    }
}