File: AutoFormerV2\AutoformerV2.cs
Project: src\Microsoft.ML.TorchSharp\Microsoft.ML.TorchSharp.csproj (Microsoft.ML.TorchSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
 
namespace Microsoft.ML.TorchSharp.AutoFormerV2
{
    /// <summary>
    /// The object detection network based on the AutoFormerV2 backbone, which is pretrained only on the MS COCO dataset.
    /// The network comes in three scales with different parameters: small (11MB), medium (21MB), and large (41MB).
    /// </summary>
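    /// <example>
    /// A minimal construction sketch. The channel/depth/head values below are
    /// illustrative placeholders, not the shipped small/medium/large presets;
    /// 80 is the number of MS COCO object categories.
    /// <code>
    /// var model = new AutoFormerV2(
    ///     numClasses: 80,
    ///     embedChannels: new List&lt;int&gt; { 64, 128, 256, 448 },
    ///     depths: new List&lt;int&gt; { 2, 2, 6, 2 },
    ///     numHeads: new List&lt;int&gt; { 2, 4, 8, 14 },
    ///     device: torch.cuda.is_available() ? torch.CUDA : torch.CPU);
    /// </code>
    /// </example>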
    public class AutoFormerV2 : Module<Tensor, (Tensor, Tensor, Tensor)>
    {
 
#pragma warning disable MSML_PrivateFieldName // Need to match TorchSharp model names.
        private readonly Device device;
        private readonly AutoFormerV2Backbone backbone;
        private readonly FPN neck;
        private readonly RetinaHead bbox_head;
        private readonly Anchors anchors;
        private bool _disposedValue;
#pragma warning restore MSML_PrivateFieldName
 
        /// <summary>
        /// Initializes a new instance of the <see cref="AutoFormerV2"/> class.
        /// </summary>
        /// <param name="numClasses">The number of object categories.</param>
        /// <param name="embedChannels">The embedding channels, which control the scale of model.</param>
        /// <param name="depths">The number of blocks, which control the scale of model.</param>
        /// <param name="numHeads">The number of heads, which control the scale of model.</param>
        /// <param name="device">The device where the model inference.</param>
        public AutoFormerV2(int numClasses, List<int> embedChannels, List<int> depths, List<int> numHeads, Device device = null)
            : base(nameof(AutoFormerV2))
        {
            this.device = device;
 
            this.backbone = new AutoFormerV2Backbone(embedChannels: embedChannels, depths: depths, numHeads: numHeads);
            this.neck = new FPN(inChannels: new List<int>() { embedChannels[1], embedChannels[2], embedChannels[3] });
            this.bbox_head = new RetinaHead(numClasses);
            this.anchors = new Anchors();
 
            this.RegisterComponents();
            this.InitializeWeight();
            this.FreezeBN();
 
            if (device != null)
            {
                this.to(device);
            }
 
            this.eval(); // The model defaults to evaluation (inference) mode.
        }
 
        /// <summary>
        /// Freezes the BatchNorm2d layers in the network by switching them to
        /// evaluation mode, so their running statistics are not updated.
        /// </summary>
        public void FreezeBN()
        {
            foreach (var (name, layer) in this.named_modules())
            {
                if (layer.GetType() == typeof(BatchNorm2d))
                {
                    // eval() stops updates to this layer's running mean/variance.
                    layer.eval();
                }
            }
        }
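
        // Note: calling train() on the module tree puts every submodule, including
        // BatchNorm2d, back into training mode, so a fine-tuning loop would need to
        // re-apply FreezeBN afterwards. A minimal sketch (assuming a fine-tuning
        // scenario outside this file):
        //
        //     model.train();     // re-enables training mode everywhere
        //     model.FreezeBN();  // keeps BatchNorm running statistics fixed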
 
        /// <inheritdoc/>
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
        public override (Tensor, Tensor, Tensor) forward(Tensor imgBatch)
        {
            using (var scope = torch.NewDisposeScope())
            {
                imgBatch = imgBatch.to(this.device);
                var xList = this.backbone.forward(imgBatch);    // multi-scale backbone features
                var fpnOutput = this.neck.forward(xList);       // feature pyramid from the FPN neck
                var (classificationInput, regressionInput) = this.bbox_head.forward(fpnOutput);
                var regression = torch.cat(regressionInput, dim: 1);          // box regressions from all pyramid levels
                var classification = torch.cat(classificationInput, dim: 1);  // class scores from all pyramid levels
                var anchor = this.anchors.forward(imgBatch);                  // anchor boxes for the input size
 
                return (classification.MoveToOuterDisposeScope(),
                        regression.MoveToOuterDisposeScope(),
                        anchor.MoveToOuterDisposeScope());
            }
        }
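
        // A minimal inference sketch (the 1x3x800x800 input shape is an assumption
        // for illustration; any NCHW image batch the backbone accepts works):
        //
        //     using var imgBatch = torch.rand(1, 3, 800, 800);
        //     var (classification, regression, anchors) = model.forward(imgBatch);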
 
        /// <summary>
        /// Initializes the weights of the Linear, LayerNorm, and BatchNorm2d layers in the network.
        /// </summary>
        private void InitializeWeight()
        {
            // Gradient tracking is temporarily disabled around each in-place
            // initialization below, since in-place ops are not allowed on leaf
            // tensors that require grad.
            foreach (var (name, layer) in this.named_modules())
            {
                if (layer.GetType() == typeof(Linear))
                {
                    var module = layer as Linear;
                    var weightRequiresGrad = module.weight.requires_grad;
                    module.weight.requires_grad = false;
                    module.weight.normal_(0, 0.02);
                    module.weight.requires_grad = weightRequiresGrad;
                    var biasRequiresGrad = module.bias.requires_grad;
                    module.bias.requires_grad = false;
                    module.bias.zero_();
                    module.bias.requires_grad = biasRequiresGrad;
                }
                else if (layer.GetType() == typeof(LayerNorm))
                {
                    var module = layer as LayerNorm;
                    var weightRequiresGrad = module.weight.requires_grad;
                    module.weight.requires_grad = false;
                    module.weight.fill_(1);
                    module.weight.requires_grad = weightRequiresGrad;
                    var biasRequiresGrad = module.bias.requires_grad;
                    module.bias.requires_grad = false;
                    module.bias.zero_();
                    module.bias.requires_grad = biasRequiresGrad;
                }
                else if (layer.GetType() == typeof(BatchNorm2d))
                {
                    var module = layer as BatchNorm2d;
                    var weightRequiresGrad = module.weight.requires_grad;
                    module.weight.requires_grad = false;
                    module.weight.fill_(1.0);
                    module.weight.requires_grad = weightRequiresGrad;
                    var biasRequiresGrad = module.bias.requires_grad;
                    module.bias.requires_grad = false;
                    module.bias.zero_();
                    module.bias.requires_grad = biasRequiresGrad;
                }
            }
        }
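
        // In effect, the initialization above applies:
        //   Linear:      weight ~ N(mean: 0, std: 0.02), bias = 0
        //   LayerNorm:   weight = 1,                     bias = 0
        //   BatchNorm2d: weight = 1,                     bias = 0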
 
        protected override void Dispose(bool disposing)
        {
            if (!_disposedValue)
            {
                if (disposing)
                {
                    backbone.Dispose();
                    neck.Dispose();
                    bbox_head.Dispose();
                    anchors.Dispose();
                    _disposedValue = true;
                }
            }
 
            base.Dispose(disposing);
        }
    }
}