File: Extension\TensorExtension.cs
Web Access
Project: src\src\Microsoft.ML.GenAI.Core\Microsoft.ML.GenAI.Core.csproj (Microsoft.ML.GenAI.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using TorchSharp;
using static TorchSharp.torch;
 
namespace Microsoft.ML.GenAI.Core.Extension;
 
internal static class TensorExtension
{
    public static string Peek(this Tensor tensor, string id, int n = 10)
    {
        var device = tensor.device;
        var dType = tensor.dtype;
        // if type is fp16, convert to fp32
        if (tensor.dtype == ScalarType.Float16)
        {
            tensor = tensor.to_type(ScalarType.Float32);
        }
        tensor = tensor.cpu();
        var shapeString = string.Join(',', tensor.shape);
        var tensor1D = tensor.reshape(-1);
        var tensorIndex = torch.arange(tensor1D.shape[0], dtype: ScalarType.Float32).to(tensor1D.device).sqrt();
        var avg = (tensor1D * tensorIndex).sum();
        avg = avg / tensor1D.sum();
        // keep four decimal places
        avg = avg.round(4);
        var str = $"{id}: sum: {avg.ToSingle()}  dType: {dType} shape: [{shapeString}]";
 
        return str;
    }
}