File: Module\QuantizedLinear.cs
Web Access
Project: src\src\Microsoft.ML.GenAI.Core\Microsoft.ML.GenAI.Core.csproj (Microsoft.ML.GenAI.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using Microsoft.ML.GenAI.Core;
using TorchSharp;
using static TorchSharp.torch;
namespace Microsoft.ML.GenAI.Core;
 
/// <summary>
/// A <see cref="GenAILinear"/> whose floating-point <c>weight</c> can be replaced in place by a row-wise
/// affine-quantized copy (one scale and one zero point per output row): int8 via <c>Int8()</c>, or int4 packed
/// two values per byte via <c>Int4()</c>. <c>forward</c> dequantizes the stored weight to float32 on every call
/// as <c>(q - zeroPoint) / scale</c>.
/// </summary>
internal class QuantizedLinear : GenAILinear, IQuantizeModule
{
    public QuantizedLinear(int inFeatures, int outFeatures, bool hasBias = true, ScalarType dtype = ScalarType.Float32, string? device = null)
        : base(inFeatures, outFeatures, hasBias, dtype, device)
    {
    }

    /// <summary>
    /// Replaces the floating-point <c>weight</c> buffer with three buffers:
    /// <c>8bit_weight</c> (int8, same shape as the weight), <c>zeroPoint</c> (int8, one per row) and
    /// <c>scale</c> (float32, one per row).
    /// </summary>
    /// <remarks>
    /// If the weight lives on the meta device, only zero-filled placeholders are registered (presumably to be
    /// overwritten when a checkpoint is loaded — TODO confirm against the loader). Otherwise the weight is quantized
    /// in memory and the original floating-point tensor is disposed.
    /// </remarks>
    /// <exception cref="InvalidOperationException">The weight is not initialized.</exception>
    public void Int8()
    {
        if (this.weight is null)
        {
            throw new InvalidOperationException("Weight is not initialized");
        }

        if (this.weight.device_type == DeviceType.META)
        {
            // Nothing to quantize yet: create correctly shaped placeholders for the quantized buffers.
            var placeholderWeight = torch.zeros(this.weight.shape, dtype: torch.int8);
            var placeholderZeroPoint = torch.zeros(this.weight.shape[0], dtype: torch.int8);
            var placeholderScale = torch.zeros(this.weight.shape[0], dtype: torch.float32);

            this._internal_buffers.Remove("weight");
            this.weight = null;
            this.register_buffer("8bit_weight", placeholderWeight);
            this.register_buffer("zeroPoint", placeholderZeroPoint);
            this.register_buffer("scale", placeholderScale);
            return;
        }

        // Free the intermediate min/max/float tensors here instead of leaving them to the GC.
        using var scope = torch.NewDisposeScope();
        var (quantized, scale, zeroPoint) = QuantizeRows(this.weight, -128, 127);
        var eightBitWeight = quantized.to(torch.int8);
        var eightBitZeroPoint = zeroPoint.to(torch.int8);

        // The floating-point weight is no longer needed once the quantized copy exists.
        this.weight.Dispose();
        this.weight = null;
        this._internal_buffers.Remove("weight");
        this.register_buffer("8bit_weight", eightBitWeight.MoveToOuterDisposeScope());
        this.register_buffer("zeroPoint", eightBitZeroPoint.MoveToOuterDisposeScope());
        this.register_buffer("scale", scale.MoveToOuterDisposeScope());
    }

#pragma warning disable MSML_GeneralName // This name should be PascalCased
    public override Tensor forward(Tensor input)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
    {
        if (this._internal_buffers.ContainsKey("weight"))
        {
            // Not quantized: plain floating-point linear layer.
            return base.forward(input);
        }

        if (this._internal_buffers.ContainsKey("8bit_weight"))
        {
            using var scope = torch.NewDisposeScope();
            var weight = this.get_buffer("8bit_weight").to(ScalarType.Float32);
            var zeroPoint = this.get_buffer("zeroPoint").to(ScalarType.Float32);
            var scale = this.get_buffer("scale").to(ScalarType.Float32);
            return this.DequantizedLinear(input, weight, zeroPoint, scale).MoveToOuterDisposeScope();
        }

        if (this._internal_buffers.ContainsKey("4bit_weight"))
        {
            using var scope = torch.NewDisposeScope();

            // The meta-device placeholders are int8 while in-memory quantization produces uint8.
            // Casting to uint8 first makes both decode identically: int8 -> uint8 wraps, e.g. -85 -> 171.
            // Decoded nibbles are in [0, 15]; subtracting 8 restores the signed range [-8, 7].
            var weight = UnpackNibbles(this.get_buffer("4bit_weight").to(torch.uint8), (long)this._outFeatures * this._inFeatures)
                .reshape(this._outFeatures, this._inFeatures) - 8;
            var zeroPoint = UnpackNibbles(this.get_buffer("zeroPoint").to(torch.uint8), this._outFeatures) - 8;
            var scale = this.get_buffer("scale").to(ScalarType.Float32);
            return this.DequantizedLinear(input, weight, zeroPoint, scale).MoveToOuterDisposeScope();
        }

        throw new InvalidOperationException("Quantization is not done yet");
    }

    /// <summary>
    /// Replaces the floating-point <c>weight</c> buffer with three packed 4-bit buffers:
    /// <c>4bit_weight</c> (ceil(rows * cols / 2) bytes, two values per byte), <c>zeroPoint</c>
    /// (ceil(rows / 2) bytes, two values per byte) and <c>scale</c> (float32, one per row).
    /// </summary>
    /// <remarks>
    /// If the weight lives on the meta device, only zero-filled int8 placeholders are registered. Otherwise the weight
    /// is quantized in memory, packed into uint8, and the original floating-point tensor is disposed.
    /// </remarks>
    /// <exception cref="InvalidOperationException">The weight is not initialized.</exception>
    public void Int4()
    {
        if (this.weight is null)
        {
            throw new InvalidOperationException("Weight is not initialized");
        }

        if (this.weight.device_type == DeviceType.META)
        {
            // ceil(n / 2) bytes hold n nibbles.
            var packedWeightLength = (this.weight.size(0) * this.weight.size(1) + 1) / 2;
            var packedZeroPointLength = (this._outFeatures + 1) / 2;
            var placeholderWeight = torch.zeros(packedWeightLength, dtype: torch.int8);
            var placeholderZeroPoint = torch.zeros(packedZeroPointLength, dtype: torch.int8);
            var placeholderScale = torch.zeros(this.weight.shape[0], dtype: torch.float32);

            this._internal_buffers.Remove("weight");
            this.weight = null;
            this.register_buffer("4bit_weight", placeholderWeight);
            this.register_buffer("zeroPoint", placeholderZeroPoint);
            this.register_buffer("scale", placeholderScale);
            return;
        }

        using var scope = torch.NewDisposeScope();
        var (quantized, scale, zeroPoint) = QuantizeRows(this.weight, -8, 7);

        // torch has no int4 type: shift [-8, 7] to [0, 15] so each value fits an unsigned nibble,
        // then store two nibbles per byte. forward() subtracts the 8 again.
        var packedWeight = PackNibbles(quantized + 8);
        var packedZeroPoint = PackNibbles(zeroPoint + 8);

        // The floating-point weight is no longer needed once the quantized copy exists.
        this.weight.Dispose();
        this.weight = null;
        this._internal_buffers.Remove("weight");
        this.register_buffer("4bit_weight", packedWeight.MoveToOuterDisposeScope());
        this.register_buffer("zeroPoint", packedZeroPoint.MoveToOuterDisposeScope());
        this.register_buffer("scale", scale.MoveToOuterDisposeScope());
    }

    /// <summary>
    /// Dequantizes a weight row-wise as <c>(q - zeroPoint) / scale</c> and computes
    /// <c>input @ W^T (+ bias)</c> in float32, returning the result cast back to the input's dtype.
    /// </summary>
    /// <param name="input">Layer input.</param>
    /// <param name="quantizedWeight">Float32 quantized values, one row per output feature.</param>
    /// <param name="zeroPoint">Float32 zero point, one per row.</param>
    /// <param name="scale">Float32 scale, one per row.</param>
    private Tensor DequantizedLinear(Tensor input, Tensor quantizedWeight, Tensor zeroPoint, Tensor scale)
    {
        var restoredWeight = (quantizedWeight - zeroPoint.view(-1, 1)) / scale.view(-1, 1);
        var result = torch.matmul(input.to(ScalarType.Float32), restoredWeight.T);

        if (this.bias is not null)
        {
            result = result + this.bias.to_type(ScalarType.Float32);
        }

        return result.to_type(input.dtype);
    }

    /// <summary>
    /// Row-wise (dim 1) affine quantization of a 2-D weight into the integer range [quantMin, quantMax].
    /// </summary>
    /// <returns>
    /// Three float32 tensors:
    /// the rounded and clamped quantized values (same shape as <paramref name="weight"/>);
    /// the per-row scale;
    /// the per-row rounded zero point.
    /// They satisfy weight ≈ (quantized - zeroPoint) / scale.
    /// </returns>
    private static (Tensor Quantized, Tensor Scale, Tensor ZeroPoint) QuantizeRows(Tensor weight, int quantMin, int quantMax)
    {
        const double minRowRange = 1e-8;

        // Row statistics are computed in float32, so scale matches the float32 meta placeholder and does not
        // overflow in half precision.
        var rowMin = torch.min(weight, 1).values.to(ScalarType.Float32);
        var rowMax = torch.max(weight, 1).values.to(ScalarType.Float32);

        // Extend each row's [min, max] to contain 0, so the zero point always lands in [quantMin, quantMax]
        // and fits the integer storage type. Rows that already straddle 0 are unaffected.
        rowMin = torch.minimum(rowMin, torch.zeros_like(rowMin));
        rowMax = torch.maximum(rowMax, torch.zeros_like(rowMax));

        // An all-zero row has range 0; bound it away from 0 to avoid an infinite scale.
        var range = (rowMax - rowMin).clamp_min((Scalar)minRowRange);
        var scale = (quantMax - quantMin) / range;
        var zeroPoint = torch.round(-scale * rowMin + quantMin);

        // Rounding ties (e.g. 127.5) can land one step past quantMax, so clamp before any integer cast.
        var quantized = torch.round(weight * scale.view(-1, 1) + zeroPoint.view(-1, 1))
            .clamp((Scalar)quantMin, (Scalar)quantMax);
        return (quantized, scale, zeroPoint);
    }

    /// <summary>
    /// Packs integer values in [0, 15] into a 1-D uint8 tensor of length ceil(n / 2).
    /// Element i of the first half goes into the high nibble of byte i.
    /// Element i of the second half goes into the low nibble of byte i.
    /// </summary>
    private static Tensor PackNibbles(Tensor values)
    {
        var flat = values.reshape(-1).to(torch.uint8);
        if (flat.shape[0] % 2 != 0)
        {
            // Pad odd-length input with one zero so both halves have the same length.
            flat = torch.cat([flat, torch.zeros_like(flat.narrow(0, 0, 1))], 0);
        }

        var half = flat.shape[0] / 2;
        return flat.narrow(0, 0, half) * 16 + flat.narrow(0, half, half);
    }

    /// <summary>
    /// Inverse of <c>PackNibbles</c>: decodes packed bytes into <paramref name="length"/> float32 values in [0, 15].
    /// </summary>
    /// <param name="packed">Packed bytes as uint8 (values 0..255).</param>
    /// <param name="length">Number of values originally packed; a trailing padding nibble is dropped.</param>
    private static Tensor UnpackNibbles(Tensor packed, long length)
    {
        // float32 represents 0..255 exactly. Subtracting the low nibble before dividing keeps the high nibble
        // integral, instead of relying on integer-division semantics (torch `/` on integer tensors is true division).
        var bytes = packed.reshape(-1).to(ScalarType.Float32);
        var low = bytes % 16;
        var high = (bytes - low) / 16;
        return torch.cat([high, low], 0).narrow(0, 0, length);
    }
}