File: Utils.cs
Web Access
Project: src\src\Microsoft.ML.GenAI.Core\Microsoft.ML.GenAI.Core.csproj (Microsoft.ML.GenAI.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
 
namespace Microsoft.ML.GenAI.Core;
public static class Utils
{
    public static Tensor ApplyRotaryEmbeddings(Tensor input, Tensor freqsComplex)
    {
        // Separate the last dimension pairs of two values, representing the real and imaginary parts of the complex number
        // Two consecutive values will become a single complex number
        // (B, Seq_Len, H, Head_Dim) -> (B, Seq_Len, H, Head_Dim/2)
        var inputComplex = input.to_type(ScalarType.Float32).reshape(input.shape[0], input.shape[1], input.shape[2], -1, 2).view_as_complex();
        freqsComplex = freqsComplex.to(input.device);
 
        // Reshape the freqs_complex tensor to match the shape of the x_complex tensor. So we need to add the batch dimension and the head dimension
        // (Seq_Len, Head_Dim/2) --> (1, Seq_Len, 1, Head_Dim/2)
        var freqsComplexReshaped = freqsComplex.unsqueeze(0).unsqueeze(2);
 
        // Multiply each complex number in the x_complex tensor by the corresponding complex number in the freqs_complex tensor
        // Which results in the rotation of the complex number as shown in the Figure 1 of the paper
        // (B, Seq_Len, H, Head_Dim/2) * (1, Seq_Len, 1, Head_Dim/2) = (B, Seq_Len, H, Head_Dim/2)
        var rotatedComplex = inputComplex * freqsComplexReshaped;
        // Console.WriteLine(rotated_complex.mean().ToSingle());
 
        // Convert the complex number back to the real number
        // (B, Seq_Len, H, Head_Dim/2) -> (B, Seq_Len, H, Head_Dim/2, 2)
        var rotated = rotatedComplex.view_as_real();
 
        // (B, Seq_Len, H, Head_Dim/2, 2) -> (B, Seq_Len, H, Head_Dim)
        var rotatedReshaped = rotated.reshape(rotated.shape[0], rotated.shape[1], rotated.shape[2], -1);
 
        return rotatedReshaped.type_as(input);
    }
 
    public static Tensor PrecomputeThetaPosFrequencies(int headDim, int seqLen, string device, float theta = 10000.0f)
    {
        // As written in the paragraph 3.2.2 of the paper
        // >> In order to generalize our results in 2D to any xi ∈ Rd where **d is even**, [...]
        if (headDim % 2 != 0)
        {
            throw new ArgumentException("Dimension must be divisible by 2", nameof(headDim));
        }
 
        // Build the theta parameter
        // According to the formula theta_i = 10000^(-2(i-1)/dim) for i = [1, 2, ... dim/2]
        // Shape: (Head_Dim / 2)
        var thetaNumerator = torch.arange(0, headDim, 2).to(torch.float32).to(device);
        // Shape: (Head_Dim / 2)
        var thetaInput = torch.pow(theta, -1.0f * (thetaNumerator / headDim)).to(device); // (Dim / 2)
        // Construct the positions (the "m" parameter)
        // Shape: (Seq_Len)
        var m = torch.arange(seqLen, device: device);
        // Multiply each theta by each position using the outer product.
        // Shape: (Seq_Len) outer_product* (Head_Dim / 2) -> (Seq_Len, Head_Dim / 2)
        var thetaPositionFrequencies = torch.outer(m, thetaInput).to(torch.float32).to(device);
 
        // We can compute complex numbers in the polar form c = R * exp(m * theta), where R = 1 as follows:
        // (Seq_Len, Head_Dim / 2) -> (Seq_Len, Head_Dim / 2)
        var freqsComplex = torch.polar(torch.ones_like(thetaPositionFrequencies), thetaPositionFrequencies);
 
        return freqsComplex;
    }
 
    // python
    // def rotate_half(x):
    // """Rotates half the hidden dims of the input."""
    // x1 = x[..., : x.shape[-1] // 2]
    // x2 = x[..., x.shape[-1] // 2 :]
    // return torch.cat((-x2, x1), dim=-1)
    public static Tensor RotateHalf(Tensor x)
    {
        var x1 = x[.., .., .., ..(int)(x.shape[^1] / 2)];
        var x2 = x[.., .., .., (int)(x.shape[^1] / 2)..];
        // (x1 * x1 * x2).Peek("x1 * x1 * x2");
        return torch.cat([-x2, x1], dim: -1);
    }
 
    public static (Tensor, Tensor) ApplyRotaryPosEmb(Tensor q, Tensor k, Tensor cos, Tensor sin, Tensor? positionIds = null, int unsqueezeDim = 1)
    {
        // The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
        // sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
        // that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
        // k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
        // cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
        // the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
 
        if (positionIds is not null)
        {
            cos = cos[positionIds!].unsqueeze(unsqueezeDim);
            sin = sin[positionIds!].unsqueeze(unsqueezeDim);
        }
        else
        {
            cos = cos.unsqueeze(unsqueezeDim);
            sin = sin.unsqueeze(unsqueezeDim);
        }
 
        var qEmbed = q * cos;
        qEmbed += RotateHalf(q) * sin;
 
        var kEmbed = k * cos;
        kEmbed += RotateHalf(k) * sin;
        // var kEmbed = (k * cos) + (RotateHalf(k) * sin);
        return (qEmbed, kEmbed);
    }
 
    public static Module<Tensor, Tensor> GetActivation(string actFn)
    {
        return actFn switch
        {
            "silu" => nn.SiLU(),
            "relu" => nn.ReLU(),
            "gelu" => nn.GELU(),
            "tanh" => nn.Tanh(),
            "swish" => nn.SiLU(),
            _ => throw new ArgumentException("Invalid activation function", actFn),
        };
    }
 
 
    public static Tensor Phi2RepeatKV(Tensor x, int nRep)
    {
        var batchSize = x.shape[0];
        var seqLen = x.shape[1];
        var nKVHeads = x.shape[2];
        var headDim = x.shape[3];
        if (nRep == 1)
        {
            return x;
        }
 
        return x.unsqueeze(3)
                .expand(batchSize, seqLen, nKVHeads, nRep, headDim)
                .view(batchSize, seqLen, nKVHeads * nRep, headDim);
    }
 
    public static Tensor RepeatKV(Tensor x, int nRep)
    {
        var batchSize = x.shape[0];
        var nKVHeads = x.shape[1];
        var seqLen = x.shape[2];
        var headDim = x.shape[3];
        if (nRep == 1)
        {
            return x;
        }
 
        return x.unsqueeze(2)
                .expand(batchSize, nKVHeads, nRep, seqLen, headDim)
                .reshape(batchSize, nKVHeads * nRep, seqLen, headDim);
    }
 
    internal static string GetEmbeddedResource(string resourceName)
    {
        // read file content from embedded resource
        var assembly = Assembly.GetCallingAssembly();
        var resourceStream = assembly.GetManifestResourceStream(resourceName);
 
        if (resourceStream == null)
        {
            throw new ArgumentException("Resource not found", resourceName);
        }
 
        using var reader = new System.IO.StreamReader(resourceStream);
        return reader.ReadToEnd();
    }
}