// File: AutoFormerV2\Anchors.cs
// Web Access
// Project: src\src\Microsoft.ML.TorchSharp\Microsoft.ML.TorchSharp.csproj (Microsoft.ML.TorchSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using TorchSharp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
 
namespace Microsoft.ML.TorchSharp.AutoFormerV2
{
    /// <summary>
    /// Anchor boxes are a set of predefined bounding boxes of a certain height and width, whose location and size
    /// can be adjusted by the regression head of the model. This module generates those boxes, in
    /// (x1, y1, x2, y2) form, for every grid position of every configured pyramid level.
    /// </summary>
    public class Anchors : Module<Tensor, Tensor>
    {
        // Pyramid levels to generate anchors for; level x divides the image extent by 2^x (rounded up).
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
        private readonly int[] pyramidLevels;

        // Distance between adjacent anchor centers, one entry per pyramid level.
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
        private readonly int[] strides;

        // Base anchor side length, one entry per pyramid level.
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
        private readonly int[] sizes;

        // Height/width ratios applied to every base anchor.
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
        private readonly double[] ratios;

        // Multipliers applied to the base size of every anchor.
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
        private readonly double[] scales;

        /// <summary>
        /// Initializes a new instance of the <see cref="Anchors"/> class.
        /// </summary>
        /// <param name="pyramidLevels">Pyramid levels. Defaults to { 3, 4, 5, 6, 7 }.</param>
        /// <param name="strides">Strides between adjacent bboxes, one per pyramid level. Defaults to 2^level.</param>
        /// <param name="sizes">Different sizes for bboxes, one per pyramid level. Defaults to 2^(level + 2).</param>
        /// <param name="ratios">Different ratios for height/width. Defaults to { 0.5, 1, 2 }.</param>
        /// <param name="scales">Scale size of bboxes. Defaults to { 2^0, 2^(1/3), 2^(2/3) }.</param>
        /// <exception cref="ArgumentException">
        /// <paramref name="strides"/> or <paramref name="sizes"/> has fewer entries than there are pyramid levels.
        /// </exception>
        public Anchors(int[] pyramidLevels = null, int[] strides = null, int[] sizes = null, double[] ratios = null, double[] scales = null)
            : base(nameof(Anchors))
        {
            this.pyramidLevels = pyramidLevels ?? new int[] { 3, 4, 5, 6, 7 };
            this.strides = strides ?? this.pyramidLevels.Select(x => (int)Math.Pow(2, x)).ToArray();
            this.sizes = sizes ?? this.pyramidLevels.Select(x => (int)Math.Pow(2, x + 2)).ToArray();
            this.ratios = ratios ?? new double[] { 0.5, 1, 2 };
            this.scales = scales ?? new double[] { Math.Pow(2, 0), Math.Pow(2, 1.0 / 3.0), Math.Pow(2, 2.0 / 3.0) };

            // forward indexes strides and sizes by pyramid-level position, so fail here with a clear message
            // rather than with an IndexOutOfRangeException in the middle of a forward pass.
            if (this.strides.Length < this.pyramidLevels.Length)
            {
                throw new ArgumentException(
                    $"Expected at least {this.pyramidLevels.Length} strides (one per pyramid level) but got {this.strides.Length}.",
                    nameof(strides));
            }

            if (this.sizes.Length < this.pyramidLevels.Length)
            {
                throw new ArgumentException(
                    $"Expected at least {this.pyramidLevels.Length} sizes (one per pyramid level) but got {this.sizes.Length}.",
                    nameof(sizes));
            }
        }

        /// <summary>
        /// Generate anchors for an image.
        /// </summary>
        /// <param name="image">
        /// Image in Tensor format. Only its shape and device are read: dimension 2 is used as the vertical extent
        /// and dimension 3 as the horizontal extent.
        /// </param>
        /// <returns>
        /// All anchors as a float32 tensor of shape (1, total, 4) on the device of <paramref name="image"/>,
        /// ordered by pyramid level, then by grid position (row by row), then by anchor shape.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="image"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="image"/> has fewer than 4 dimensions.</exception>
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
        public override Tensor forward(Tensor image)
        {
            if (image is null)
            {
                throw new ArgumentNullException(nameof(image));
            }

            if (image.shape.Length < 4)
            {
                throw new ArgumentException(
                    $"Expected an image tensor with at least 4 dimensions but got {image.shape.Length}.",
                    nameof(image));
            }

            using var scope = torch.NewDisposeScope();

            var imageShape = torch.tensor(image.shape.AsSpan().Slice(2).ToArray());

            // Collect the anchors of every level and concatenate once at the end. The empty (0, 4) seed keeps
            // the result well-formed (shape (1, 0, 4)) when no pyramid levels are configured.
            var levelAnchors = new List<Tensor>(this.pyramidLevels.Length + 1)
            {
                torch.zeros(new long[] { 0, 4 }, dtype: torch.float32),
            };

            for (int idx = 0; idx < this.pyramidLevels.Length; ++idx)
            {
                var downsample = Math.Pow(2, this.pyramidLevels[idx]);

                // Ceiling division of the image extent by the level's downsampling factor.
                var shape = ((imageShape + downsample - 1) / downsample).to_type(torch.int32);
                var anchors = GenerateAnchors(
                    baseSize: this.sizes[idx],
                    ratios: this.ratios,
                    scales: this.scales);
                levelAnchors.Add(Shift(shape, this.strides[idx], anchors));
            }

            var output = torch.cat(levelAnchors, dim: 0).unsqueeze(dim: 0);
            output = output.to(image.device);

            return output.MoveToOuterDisposeScope();
        }

        /// <summary>
        /// Generate a set of anchors given size, ratios and scales. The anchors are centered on the origin and
        /// there is one anchor per (ratio, scale) pair, with the scale varying fastest.
        /// </summary>
        /// <param name="baseSize">Base size for width and height.</param>
        /// <param name="ratios">Ratios for height/width. Defaults to { 0.5, 1, 2 }.</param>
        /// <param name="scales">Scales to resize base size. Defaults to { 2^0, 2^(1/3), 2^(2/3) }.</param>
        /// <returns>A set of anchors as a float32 tensor of shape (ratios * scales, 4) in (x1, y1, x2, y2) form.</returns>
        private static Tensor GenerateAnchors(int baseSize = 16, double[] ratios = null, double[] scales = null)
        {
            using var anchorsScope = torch.NewDisposeScope();

            ratios ??= new double[] { 0.5, 1, 2 };
            scales ??= new double[] { Math.Pow(2, 0), Math.Pow(2, 1.0 / 3.0), Math.Pow(2, 2.0 / 3.0) };

            var numAnchors = ratios.Length * scales.Length;
            var allRows = RangeUtil.ToTensorIndex(..);

            // Each ratio repeated once per scale, so row i pairs ratio[i / scales.Length] with scale[i % scales.Length].
            var ratioPerAnchor = torch.repeat_interleave(ratios, new long[] { scales.Length });

            // initialize output anchors
            var anchors = torch.zeros(new long[] { numAnchors, 4 }, dtype: torch.float32);

            // scale base_size: columns 2 and 3 temporarily both hold the scaled side length
            anchors[allRows, RangeUtil.ToTensorIndex(2..)] = baseSize * torch.tile(scales, new long[] { 2, ratios.Length }).transpose(1, 0);

            // compute areas of anchors
            var areas = torch.mul(anchors[allRows, 2], anchors[allRows, 3]);

            // correct for ratios: width = sqrt(area / ratio), height = width * ratio
            anchors[allRows, 2] = torch.sqrt(areas / ratioPerAnchor);
            anchors[allRows, 3] = torch.mul(anchors[allRows, 2], ratioPerAnchor);

            // transform from (x_ctr, y_ctr, w, h) -> (x1, y1, x2, y2)
            anchors[allRows, torch.TensorIndex.Tensor(torch.tensor(new long[] { 0, 2 }, dtype: torch.int64))] -= torch.tile(anchors[allRows, 2] * 0.5, new long[] { 2, 1 }).T;
            anchors[allRows, torch.TensorIndex.Tensor(torch.tensor(new long[] { 1, 3 }, dtype: torch.int64))] -= torch.tile(anchors[allRows, 3] * 0.5, new long[] { 2, 1 }).T;

            return anchors.MoveToOuterDisposeScope();
        }

        /// <summary>
        /// Duplicate and distribute anchors to different positions given the border of positions and the stride between positions.
        /// </summary>
        /// <param name="shape">Border to distribute anchors: element 0 is the number of rows, element 1 the number of columns.</param>
        /// <param name="stride">Stride between adjacent anchors.</param>
        /// <param name="anchors">Anchors to distribute, of shape (a, 4).</param>
        /// <returns>The shifted anchors, of shape (rows * columns * a, 4), ordered row by row with the anchors of one position adjacent.</returns>
        private static Tensor Shift(Tensor shape, int stride, Tensor anchors)
        {
            using var anchorsScope = torch.NewDisposeScope();

            // Centers of the grid cells along each axis: (index + 0.5) * stride.
            Tensor shiftX = (torch.arange(start: 0, stop: (int)shape[1]) + 0.5) * stride;
            Tensor shiftY = (torch.arange(start: 0, stop: (int)shape[0]) + 0.5) * stride;

            var columns = shiftX.shape[0];
            var rows = shiftY.shape[0];

            // Row-major grid: x cycles through every column within a row, y stays constant across a row.
            var shiftXExpand = torch.tile(shiftX, new long[] { rows });
            var shiftYExpand = torch.repeat_interleave(shiftY, columns);

            // The same (x, y) offset is applied to both corners of a box.
            List<Tensor> tensors = new List<Tensor> { shiftXExpand, shiftYExpand, shiftXExpand, shiftYExpand };
            var shifts = torch.vstack(tensors).transpose(0, 1);

            var a = anchors.shape[0];
            var k = shifts.shape[0];

            // Broadcast (1, a, 4) + (k, 1, 4) -> (k, a, 4): every anchor at every grid position.
            var allAnchors = anchors.reshape(new long[] { 1, a, 4 }) + shifts.reshape(new long[] { 1, k, 4 }).transpose(0, 1);
            allAnchors = allAnchors.reshape(new long[] { k * a, 4 });

            return allAnchors.MoveToOuterDisposeScope();
        }
    }
}