File: Module\RotaryEmbedding.cs
Project: src\src\Microsoft.ML.GenAI.Core\Microsoft.ML.GenAI.Core.csproj (Microsoft.ML.GenAI.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Text.Json.Serialization;
using TorchSharp;
using static TorchSharp.torch;
 
namespace Microsoft.ML.GenAI.Core;
 
public class RopeScalingConfig
{
    public RopeScalingConfig()
    {
        this.Factor = 1.0f;
        this.LowFreqFactor = 1.0f;
        this.HighFreqFactor = 1.0f;
        this.OriginalMaxPositionEmbeddings = 8192;
        this.RopeType = "default";
    }
 
    [JsonPropertyName("factor")]
    public float Factor { get; set; }
 
    [JsonPropertyName("low_freq_factor")]
    public float LowFreqFactor { get; set; }
 
    [JsonPropertyName("high_freq_factor")]
    public float HighFreqFactor { get; set; }
 
    [JsonPropertyName("original_max_position_embeddings")]
    public int OriginalMaxPositionEmbeddings { get; set; }
 
    [JsonPropertyName("rope_type")]
    public string RopeType { get; set; }
}
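
// Usage sketch (illustrative, not part of the original file): the
// [JsonPropertyName] attributes above let this class bind directly to the
// "rope_scaling" section of a Hugging Face style config.json via
// System.Text.Json. The JSON values below are made up for the example.
//
//   var json = """{"rope_type":"default","factor":1.0,"low_freq_factor":1.0,"high_freq_factor":1.0,"original_max_position_embeddings":8192}""";
//   RopeScalingConfig? scaling = System.Text.Json.JsonSerializer.Deserialize<RopeScalingConfig>(json);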
 
 
internal class RotaryEmbeddingInput
{
    public RotaryEmbeddingInput(Tensor input, Tensor positionIds, int? seqLen = null)
    {
        Input = input;
        PositionIds = positionIds;
        SeqLen = seqLen;
    }
 
    public Tensor Input { get; set; }
 
    public Tensor PositionIds { get; set; }
 
    public int? SeqLen { get; set; }
}
 
internal class RotaryEmbeddingOutput
{
    public RotaryEmbeddingOutput(Tensor cos, Tensor sin)
    {
        Cos = cos;
        Sin = sin;
    }
 
    public Tensor Cos { get; set; }
 
    public Tensor Sin { get; set; }
}
 
 
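// Rotary positional embedding (RoPE), as introduced in "RoFormer: Enhanced
// Transformer with Rotary Position Embedding" (Su et al., 2021). Rather than
// adding position vectors to the input, RoPE rotates each pair of channels in
// the query/key vectors by an angle proportional to the token position; this
// module precomputes the cos/sin tables used for that rotation.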
internal class RotaryEmbedding : nn.Module<
    RotaryEmbeddingInput,
    RotaryEmbeddingOutput>
{
    private readonly double _base;
    private readonly int _maxPositionEmbeddings;
    private readonly int _dim;
 
    public RotaryEmbedding(double baseValue, int maxPositionEmbeddings, int dim)
        : this(baseValue, dim, new RopeScalingConfig() { RopeType = "default", OriginalMaxPositionEmbeddings = maxPositionEmbeddings })
    {
    }
 
    public RotaryEmbedding(double baseValue, int dim, RopeScalingConfig config)
        : base(nameof(RotaryEmbedding))
    {
        _base = baseValue;
        _maxPositionEmbeddings = config.OriginalMaxPositionEmbeddings;
        _dim = dim;
 
        if (config.RopeType == "default")
        {
            // Inverse frequencies for the classic RoPE formulation:
            // inv_freq[i] = base^(-2i / dim) for i = 0, 1, ..., dim/2 - 1.
            var thetaNumerator = torch.arange(0, _dim, 2, dtype: ScalarType.Int64).to(torch.float32);
            this.register_buffer("inv_freq", torch.pow(baseValue, -1.0f * (thetaNumerator / dim)), persistent: false);
        }
        else
        {
            throw new NotImplementedException($"Rope type '{config.RopeType}' is not implemented.");
        }
    }
 
    public int Dim => _dim;
 
#pragma warning disable MSML_GeneralName // This name should be PascalCased
    public override RotaryEmbeddingOutput forward(RotaryEmbeddingInput input)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
    {
        var x = input.Input;
        var positionIds = input.PositionIds;
        var seqLen = input.SeqLen; // currently unused; sequence length is implied by positionIds
        // TODO: inv_freq can be computed once per device and cached instead of
        // being moved on every forward call.
        var invFreq = this.get_buffer("inv_freq").to(x.device);

        // [dim/2] -> [batch_size, dim/2, 1]
        var invFreqExpanded = invFreq.unsqueeze(0).unsqueeze(-1);
        invFreqExpanded = invFreqExpanded.expand(new long[] { positionIds.shape[0], -1, 1 });

        // [batch_size, seq_len] -> [batch_size, 1, seq_len]
        var positionIdsExpanded = positionIds.unsqueeze(1).to(torch.float32);

        // Outer product of inverse frequencies and positions:
        // [batch_size, dim/2, seq_len], then transposed to [batch_size, seq_len, dim/2].
        var freqs = invFreqExpanded * positionIdsExpanded;
        freqs = freqs.transpose(1, 2);

        // Duplicate the angles so cos/sin cover the full head dimension:
        // [batch_size, seq_len, dim]
        var emb = torch.cat([freqs, freqs], dim: -1);

        var cos = torch.cos(emb);
        var sin = torch.sin(emb);

        return new(cos.to_type(x.dtype), sin.to_type(x.dtype));
    }
}
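
// Usage sketch (illustrative; the tensor sizes are assumptions, not taken from
// a specific model): computes the cos/sin tables for a batch of position ids
// and an input tensor whose last dimension is the rotary head dimension.
//
//   var rope = new RotaryEmbedding(baseValue: 10000.0, maxPositionEmbeddings: 8192, dim: 128);
//   var x = torch.randn(new long[] { 1, 16, 128 });   // [batch, seq_len, head_dim]
//   var positionIds = torch.arange(16).unsqueeze(0);  // [batch, seq_len]
//   var output = rope.forward(new RotaryEmbeddingInput(x, positionIds));
//   // output.Cos and output.Sin both have shape [batch, seq_len, head_dim].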