// File: AutoFormerV2\RetinaHead.cs
// Web Access
// Project: src\src\Microsoft.ML.TorchSharp\Microsoft.ML.TorchSharp.csproj (Microsoft.ML.TorchSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
 
namespace Microsoft.ML.TorchSharp.AutoFormerV2
{
    /// <summary>
    /// The head of RetinaNet. For every input feature map it runs a classification branch and a
    /// box-regression branch, each consisting of stacked <see cref="ConvModule"/> layers followed
    /// by a final 3x3 convolution.
    /// </summary>
    public class RetinaHead : Module<List<Tensor>, (List<Tensor>, List<Tensor>)>
    {
        // Stacked convolution layers of the classification branch.
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
        private readonly ModuleList<Module<Tensor, Tensor>> cls_convs;

        // Stacked convolution layers of the regression branch.
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
        private readonly ModuleList<Module<Tensor, Tensor>> reg_convs;

        // Final classification convolution; emits numBasePriors * numClasses channels.
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
        private readonly Conv2d retina_cls;

        // Final regression convolution; emits numBasePriors * 4 channels.
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
        private readonly Conv2d retina_reg;

        // Sigmoid applied to the classification output only.
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
        private readonly Sigmoid output_act;

        // Size of the last dimension of every classification output.
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
        private readonly int numClasses;

        // Set once the sub-modules have been disposed through Dispose(true).
        private bool _disposedValue;

        /// <summary>
        /// Initializes a new instance of the <see cref="RetinaHead"/> class.
        /// </summary>
        /// <param name="numClasses">The number of classes.</param>
        /// <param name="inChannels">The input channels; used only by the first stacked convolution of each branch.</param>
        /// <param name="stackedConvs">The number of stacked convolution layers in each branch.</param>
        /// <param name="featChannels">The feature channels produced by every stacked convolution.</param>
        /// <param name="numBasePriors">The number of base priors; scales the channel count of both final convolutions.</param>
        public RetinaHead(int numClasses, int inChannels = 256, int stackedConvs = 4, int featChannels = 256, int numBasePriors = 9)
            : base(nameof(RetinaHead))
        {
            this.numClasses = numClasses;
            this.cls_convs = new ModuleList<Module<Tensor, Tensor>>();
            this.reg_convs = new ModuleList<Module<Tensor, Tensor>>();
            for (int i = 0; i < stackedConvs; i++)
            {
                // Only the first layer consumes the raw input channels; later layers chain on featChannels.
                int chn = i == 0 ? inChannels : featChannels;
                this.cls_convs.Add(new ConvModule(chn, featChannels, 3, stride: 1, padding: 1, useRelu: true));
                this.reg_convs.Add(new ConvModule(chn, featChannels, 3, stride: 1, padding: 1, useRelu: true));
            }

            this.retina_cls = Conv2d(featChannels, numBasePriors * numClasses, 3, padding: 1);
            this.retina_reg = Conv2d(featChannels, numBasePriors * 4, 3, padding: 1);
            this.output_act = nn.Sigmoid();
        }

        /// <summary>
        /// Runs the classification and regression branches on every feature map in <paramref name="inputs"/>.
        /// </summary>
        /// <param name="inputs">The feature maps to process; each one is fed to both branches.</param>
        /// <returns>
        /// A pair of lists with one entry per input, in input order. The first list holds the
        /// sigmoid-activated classification outputs reshaped to (batch, -1, numClasses); the second
        /// holds the regression outputs reshaped to (batch, -1, 4).
        /// </returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="inputs"/> is null.</exception>
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
        public override (List<Tensor>, List<Tensor>) forward(List<Tensor> inputs)
        {
            // Fail fast with a descriptive exception before any dispose scope is opened.
            if (inputs == null)
            {
                throw new System.ArgumentNullException(nameof(inputs));
            }

            // Every intermediate tensor created below belongs to this scope and is released when it
            // closes; only the results explicitly moved to the outer scope survive.
            using (torch.NewDisposeScope())
            {
                var clsOutputs = new List<Tensor>(inputs.Count);
                var regOutputs = new List<Tensor>(inputs.Count);
                foreach (var feature in inputs)
                {
                    var clsOutput = this.RunBranch(feature, this.cls_convs, this.retina_cls, applySigmoid: true, lastDim: this.numClasses);
                    clsOutputs.Add(clsOutput.MoveToOuterDisposeScope());

                    var regOutput = this.RunBranch(feature, this.reg_convs, this.retina_reg, applySigmoid: false, lastDim: 4);
                    regOutputs.Add(regOutput.MoveToOuterDisposeScope());
                }

                return (clsOutputs, regOutputs);
            }
        }

        /// <summary>
        /// Releases the sub-modules owned by this head, then lets the base class release its own resources.
        /// </summary>
        /// <param name="disposing">True when called from Dispose(); false when called from a finalizer.</param>
        protected override void Dispose(bool disposing)
        {
            if (disposing && !_disposedValue)
            {
                cls_convs.Dispose();
                reg_convs.Dispose();
                retina_cls.Dispose();
                retina_reg.Dispose();
                output_act.Dispose();
                _disposedValue = true;
            }

            base.Dispose(disposing);
        }

        /// <summary>
        /// Runs one branch (stacked convolutions, final convolution, optional sigmoid) on a single
        /// feature map and reshapes the result to (batch, -1, <paramref name="lastDim"/>).
        /// </summary>
        /// <param name="feature">The feature map to process.</param>
        /// <param name="convs">The stacked convolutions of the branch, applied in order.</param>
        /// <param name="head">The final convolution of the branch.</param>
        /// <param name="applySigmoid">Whether to apply the sigmoid activation after the final convolution.</param>
        /// <param name="lastDim">The size of the last dimension of the reshaped output.</param>
        /// <returns>The reshaped branch output.</returns>
        private Tensor RunBranch(Tensor feature, ModuleList<Module<Tensor, Tensor>> convs, Conv2d head, bool applySigmoid, long lastDim)
        {
            var output = feature;
            for (int j = 0; j < convs.Count; j++)
            {
                output = convs[j].forward(output);
            }

            output = head.forward(output);
            if (applySigmoid)
            {
                output = this.output_act.forward(output);
            }

            // The head output is 4-D with channels in dimension 1. Move channels to the last
            // dimension so the view below splits them into groups of lastDim values.
            output = output.permute(0, 2, 3, 1);

            // permute returns a strided result; contiguous() is required before view can reshape it.
            return output.contiguous().view(output.shape[0], -1, lastDim);
        }
    }
}