File: Phi3\Phi3ForCasualLM.cs
Web Access
Project: src\src\Microsoft.ML.GenAI.Phi\Microsoft.ML.GenAI.Phi.csproj (Microsoft.ML.GenAI.Phi)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Core.Extension;
using Microsoft.ML.GenAI.Phi.Module;
using TorchSharp;
using TorchSharp.Modules;
using TorchSharp.PyBridge;
using static TorchSharp.torch;
 
namespace Microsoft.ML.GenAI.Phi;
 
/// <summary>
/// Phi-3 language model with a language-modeling head: a <see cref="Phi3Model"/> backbone
/// followed by a bias-free linear projection from hidden size to vocabulary size.
/// </summary>
public class Phi3ForCasualLM : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
    private readonly Phi3Config _config;

    // NOTE(review): the non-conventional field names are presumably kept so that the registered
    // component names line up with the parameter names stored in the checkpoint — confirm before renaming.
#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
    private readonly Phi3Model model;
    private readonly GenAILinear lm_head;
#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format

    /// <summary>
    /// Builds the backbone and the output head from <paramref name="config"/> and registers them as components.
    /// </summary>
    /// <param name="config">Model configuration; supplies hidden size, vocabulary size and dtype.</param>
    public Phi3ForCasualLM(Phi3Config config)
        : base(nameof(Phi3ForCasualLM))
    {
        this._config = config;
        this.model = new Phi3Model(config);
        this.lm_head = new GenAILinear(config.HiddenSize, config.VocabSize, dtype: config.DType, hasBias: false);

        this.RegisterComponents();
    }

    /// <summary>
    /// Runs the backbone, projects its last hidden state through the output head, converts the
    /// result to <see cref="ScalarType.Float32"/> and stores it in the output's <c>Logits</c>.
    /// </summary>
    /// <param name="input">Input forwarded unchanged to the backbone.</param>
    /// <returns>The backbone's output object with <c>Logits</c> populated.</returns>
#pragma warning disable MSML_GeneralName // This name should be PascalCased
    public override CausalLMModelOutput forward(CausalLMModelInput input)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
    {
        var outputs = this.model.forward(input);
        var logits = this.lm_head.forward(outputs.LastHiddenState);
        logits = logits.to_type(ScalarType.Float32);
        outputs.Logits = logits;

        return outputs;
    }

    /// <summary>
    /// Creates a model from a folder containing a JSON config and a safetensors checkpoint,
    /// moves it to <paramref name="device"/> and switches it to evaluation mode.
    /// </summary>
    /// <param name="modelFolder">Folder that contains the config and checkpoint files.</param>
    /// <param name="configName">File name of the JSON configuration inside <paramref name="modelFolder"/>.</param>
    /// <param name="checkPointName">Checkpoint name passed to <see cref="LoadSafeTensors"/>.</param>
    /// <param name="torchDtype">Dtype written into the configuration before the model is constructed.</param>
    /// <param name="device">Device the loaded model is moved to.</param>
    /// <returns>The loaded model.</returns>
    /// <exception cref="ArgumentNullException">The configuration file deserializes to <c>null</c>.</exception>
    public static Phi3ForCasualLM FromPretrained(
        string modelFolder,
        string configName = "config.json",
        string checkPointName = "model.safetensors.index.json",
        ScalarType torchDtype = ScalarType.BFloat16,
        string device = "cpu")
    {
        var modelConfig = LoadConfig(modelFolder, configName, torchDtype);
        var phi = new Phi3ForCasualLM(modelConfig);
        phi.LoadSafeTensors(modelFolder, checkPointName);
        phi = phi.to(device);
        phi.eval();

        return phi;
    }

    /// <summary>
    /// Creates a model with optional quantization and a per-layer device map. When no quantization
    /// is requested and <paramref name="layersOnTargetDevice"/> is -1, this simply delegates to the
    /// plain overload and loads everything onto <paramref name="targetDevice"/>.
    /// </summary>
    /// <param name="modelFolder">Folder that contains the config and checkpoint files.</param>
    /// <param name="configName">File name of the JSON configuration inside <paramref name="modelFolder"/>.</param>
    /// <param name="checkPointName">Checkpoint name passed to <see cref="LoadSafeTensors"/>.</param>
    /// <param name="quantizeToInt8">Quantize the model to int8. Takes precedence over <paramref name="quantizeToInt4"/>.</param>
    /// <param name="quantizeToInt4">Quantize the model to int4 (only used when int8 is not requested).</param>
    /// <param name="layersOnTargetDevice">Layer count handed to the device-map inference for <paramref name="targetDevice"/>; the remainder is mapped to "cpu" with -1.</param>
    /// <param name="torchDtype">Dtype written into the configuration before the model is constructed.</param>
    /// <param name="targetDevice">Device used for the device map and for dynamic loading.</param>
    /// <returns>The loaded model, converted to a dynamic-loading model.</returns>
    /// <exception cref="ArgumentNullException">The configuration file deserializes to <c>null</c>.</exception>
    public static Phi3ForCasualLM FromPretrained(
        string modelFolder,
        string configName = "config.json",
        string checkPointName = "model.safetensors.index.json",
        bool quantizeToInt8 = false,
        bool quantizeToInt4 = false,
        int layersOnTargetDevice = -1,
        ScalarType torchDtype = ScalarType.BFloat16,
        string targetDevice = "cuda")
    {
        if (layersOnTargetDevice == -1 && !quantizeToInt4 && !quantizeToInt8)
        {
            return FromPretrained(modelFolder, configName, checkPointName, torchDtype, targetDevice);
        }

        // The default device is process-wide state: always put it back, even when loading fails,
        // otherwise a missing/corrupt file would leave later tensor creation on "meta" or "cpu".
        var originalDefaultDevice = torch.get_default_device();
        try
        {
            // First pass: build a placeholder model on the "meta" device, used only to infer the device map.
            // NOTE(review): presumably "meta" avoids allocating real weight storage — confirm.
            torch.set_default_device("meta");
            var modelConfig = LoadConfig(modelFolder, configName, torchDtype);
            var model = new Phi3ForCasualLM(modelConfig);
            ApplyQuantization(model, quantizeToInt8, quantizeToInt4);

            var deviceMap = model.InferDeviceMapForEachLayer(
                [
                    KeyValuePair.Create(targetDevice, layersOnTargetDevice),
                    KeyValuePair.Create("cpu", -1)
                ]);

            // Second pass: build the real model on CPU and load its weights.
            torch.set_default_device("cpu");
            model = new Phi3ForCasualLM(modelConfig);
            model.LoadSafeTensors(modelFolder, checkPointName);

            // Quantize the model that is actually returned; without this the quantize flags would
            // only affect the placeholder model above and the caller would get unquantized weights.
            ApplyQuantization(model, quantizeToInt8, quantizeToInt4);

            return model.ToDynamicLoadingModel(deviceMap, targetDevice);
        }
        finally
        {
            torch.set_default_device(originalDefaultDevice);
        }
    }

    /// <summary>
    /// Loads weights from a safetensors checkpoint in <paramref name="modelFolder"/> into this module
    /// (non-strict, without a progress bar).
    /// </summary>
    /// <param name="modelFolder">Folder that contains the checkpoint.</param>
    /// <param name="checkPointName">Checkpoint name inside <paramref name="modelFolder"/>.</param>
    public void LoadSafeTensors(string modelFolder, string checkPointName = "model.safetensors.index.json")
    {
        this.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: false, useTqdm: false);
    }

    // Reads and deserializes the JSON configuration and stamps the requested dtype onto it.
    private static Phi3Config LoadConfig(string modelFolder, string configName, ScalarType torchDtype)
    {
        var configPath = Path.Join(modelFolder, configName);
        var modelConfig = JsonSerializer.Deserialize<Phi3Config>(File.ReadAllText(configPath))
            ?? throw new ArgumentNullException(nameof(configName), $"Configuration file '{configPath}' deserialized to null.");
        modelConfig.DType = torchDtype;

        return modelConfig;
    }

    // Applies the requested quantization in place; int8 wins when both flags are set.
    private static void ApplyQuantization(Phi3ForCasualLM model, bool quantizeToInt8, bool quantizeToInt4)
    {
        if (quantizeToInt8)
        {
            model.ToInt8QuantizeModule();
        }
        else if (quantizeToInt4)
        {
            model.ToInt4QuantizeModule();
        }
    }
}