File: Module\Phi2Model.cs
Web Access
Project: src\src\Microsoft.ML.GenAI.Phi\Microsoft.ML.GenAI.Phi.csproj (Microsoft.ML.GenAI.Phi)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Diagnostics.Contracts;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
 
namespace Microsoft.ML.GenAI.Phi.Module;
 
/// <summary>
/// Phi-2 transformer trunk: token embedding followed by dropout, a stack of
/// <c>Phi2DecoderLayer</c> blocks, and a final LayerNorm.
/// Only the final hidden states are currently returned; the attentions and
/// present key/value slots of the output tuple are always <c>null</c>.
/// </summary>
internal class Phi2Model : nn.Module<
    Tensor, // input_ids
    Tensor?, // attention_mask
    int, // past_key_value_length
    Tensor?, // position_ids
    Tensor?, //input embeddings
    (
        bool, // use_cache
        bool, // output_attentions
        bool // output_hidden_states
    ),
    (
        Tensor, // hidden_states,
        Tensor?, // attentions,
        Tensor? // present_key_value
    )>
{
    private readonly Phi2Config _config;
#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
    private readonly Embedding embed_tokens;
    private readonly Dropout embed_dropout;
    private readonly LayerNorm final_layernorm;
    private readonly ModuleList<Phi2DecoderLayer> layers;
#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format

    /// <summary>
    /// Builds the embedding, dropout, <c>NumHiddenLayers</c> decoder layers and the final LayerNorm from <paramref name="config"/>.
    /// </summary>
    /// <param name="config">Model hyper-parameters (vocabulary size, hidden size, dropout, epsilon, dtype, layer count).</param>
    public Phi2Model(Phi2Config config)
        : base(nameof(Phi2Model))
    {
        this._config = config;
        this.embed_tokens = nn.Embedding(config.VocabSize, config.HiddenSize, dtype: config.Dtype);
        this.embed_dropout = nn.Dropout(config.EmbdPdrop);
        this.final_layernorm = nn.LayerNorm(config.HiddenSize, eps: config.LayerNormEps, dtype: config.Dtype);
        this.layers = new ModuleList<Phi2DecoderLayer>(Enumerable.Range(0, config.NumHiddenLayers).Select(_ => new Phi2DecoderLayer(config)).ToArray());

        // Field names become the state-dict keys, so the snake_case names above are intentional.
        this.RegisterComponents();
    }

    public Phi2Config Config => this._config;

    /// <summary>
    /// Runs the decoder stack over <paramref name="inputIds"/>.
    /// </summary>
    /// <param name="inputIds">Token ids; dimension 0 is the batch and dimension 1 the sequence length.</param>
    /// <param name="attentionMask">
    /// Optional 2-D padding mask (batch, length). Entries below 1 are blocked. The length must be either the
    /// query length or query length + <paramref name="pastKeyValueLength"/>.
    /// </param>
    /// <param name="pastKeyValueLength">Number of already-cached positions that precede <paramref name="inputIds"/>.</param>
    /// <param name="positionIds">Optional position ids; defaults to <c>[past, past + seqLen)</c> with a leading dimension of 1.</param>
    /// <param name="inputEmbeddings">Not supported yet; must be <c>null</c>.</param>
    /// <param name="options">Flags in the order (use_cache, output_attentions, output_hidden_states).</param>
    /// <returns>The final hidden states; the attentions and present key/value entries are always <c>null</c>.</returns>
    /// <exception cref="NotImplementedException">Thrown when <paramref name="inputEmbeddings"/> is provided.</exception>
#pragma warning disable MSML_GeneralName // This name should be PascalCased
    public override (Tensor, Tensor?, Tensor?) forward(
#pragma warning restore MSML_GeneralName // This name should be PascalCased
        Tensor inputIds,
        Tensor? attentionMask = null,
        int pastKeyValueLength = 0,
        Tensor? positionIds = null,
        Tensor? inputEmbeddings = null,
        (bool, bool, bool) options = default) // use_cache, output_attentions, output_hidden_states
    {
        // Unpack in the documented order. The previous code unpacked as
        // (outputAttentions, outputHiddenStates, useCache), so the decoder layers
        // received output_hidden_states as useCache and use_cache as outputAttentions.
        // output_hidden_states is not implemented yet, so it is discarded.
        var (useCache, outputAttentions, _) = options;

        // TODO: add support for caller-supplied inputEmbeddings.
        if (inputEmbeddings is not null)
        {
            throw new NotImplementedException("inputEmbeddings is not supported");
        }

        var hiddenStates = this.embed_tokens.forward(inputIds);
        hiddenStates = this.embed_dropout.forward(hiddenStates);
        var seqLen = (int)inputIds.shape[1];

        // Default positions continue after the cached prefix.
        positionIds ??= torch
            .arange(pastKeyValueLength, seqLen + pastKeyValueLength, dtype: inputIds.dtype, device: inputIds.device)
            .unsqueeze(0);

        // Convert the 2-D padding mask into an additive 4-D causal mask.
        // NOTE(review): when attentionMask is null, no causal mask is built here at all;
        // confirm that the decoder layers apply causal masking themselves in that case.
        if (attentionMask is not null)
        {
            attentionMask = this.Prepare4DCausalAttentionMask(attentionMask, seqLen, pastKeyValueLength, hiddenStates.dtype);
        }

        foreach (var layer in this.layers)
        {
            (hiddenStates, _, _) = layer.forward(
                hiddenStates: hiddenStates,
                positionIds: positionIds,
                attentionMask: attentionMask,
                pastKeyValueLength: pastKeyValueLength,
                useCache: useCache,
                outputAttentions: outputAttentions);
        }

        hiddenStates = this.final_layernorm.forward(hiddenStates);

        // TODO: per-layer attentions and present key/values are not collected yet.
        return (hiddenStates, null, null);
    }

    /// <summary>
    /// Combines the causal mask with the caller's padding mask into an additive mask of shape
    /// (batch, 1, queryLength, queryLength + pastKeyValueLength). Allowed positions hold 0;
    /// blocked positions hold the minimum finite value of <paramref name="dtype"/>.
    /// </summary>
    /// <param name="attentionMask">2-D padding mask (batch, length); entries below 1 are blocked.</param>
    /// <param name="queryLength">Number of query tokens in this step.</param>
    /// <param name="pastKeyValueLength">Number of cached key/value positions.</param>
    /// <param name="dtype">Element type of the returned mask.</param>
    /// <returns>A newly allocated additive attention mask.</returns>
    private Tensor Prepare4DCausalAttentionMask(
        Tensor attentionMask,
        int queryLength,
        int pastKeyValueLength,
        ScalarType dtype)
    {
        var batchSize = (int)attentionMask.shape[0];
        var maskLength = attentionMask.shape[1];
        var targetLength = queryLength + pastKeyValueLength;

        // A full-length mask (past + query) already worked in release builds, where this assert is
        // compiled out. Accept it in debug builds as well.
        // NOTE(review): a query-length-only mask combined with pastKeyValueLength > 0 only broadcasts
        // when queryLength == 1.
        Contract.Assert(
            maskLength == queryLength || maskLength == targetLength,
            "attention mask length must be equal to queryLength or queryLength + pastKeyValueLength");

        var causalMask = this.MakeCausalAttentionMask(batchSize, queryLength, pastKeyValueLength, attentionMask.device, dtype);
        var blocked = this.ExpandMask(attentionMask, dtype, queryLength).to(attentionMask.device);

        // Out-of-place on purpose: causalMask is an expand()-ed view whose batch rows share storage.
        // An in-place masked_fill_ would write every row's padding into that shared storage, so each
        // sequence would also be masked at the other sequences' padded positions.
        return causalMask.masked_fill(blocked, torch.finfo(dtype).min);
    }

    /// <summary>
    /// Broadcasts a 2-D padding mask to (batch, 1, targetLength, length) and marks the blocked entries.
    /// </summary>
    /// <param name="mask">2-D padding mask (batch, length).</param>
    /// <param name="dtype">Type the mask is converted to before the comparison.</param>
    /// <param name="targetLength">Number of query rows to broadcast over.</param>
    /// <returns>A boolean tensor that is true where the mask value is below 1, i.e. where attention is blocked.</returns>
    private Tensor ExpandMask(
        Tensor mask,
        ScalarType dtype,
        int targetLength)
    {
        var batch = mask.shape[0];
        var srcLength = mask.shape[1];
        var expanded = mask
            .unsqueeze(1)
            .unsqueeze(2)
            .expand(new long[] { batch, 1, targetLength, srcLength })
            .to_type(dtype);

        return (1.0f - expanded) > 0;
    }

    /// <summary>
    /// Builds the additive causal mask of shape (batchSize, 1, targetLen, targetLen + pastKeyValueLength).
    /// </summary>
    /// <param name="batchSize">Batch dimension of the result.</param>
    /// <param name="targetLen">Number of query tokens.</param>
    /// <param name="pastKeyValueLength">Number of cached positions, prepended as always-visible columns.</param>
    /// <param name="device">Device on which the mask is allocated.</param>
    /// <param name="dtype">Element type of the mask.</param>
    /// <returns>
    /// An expand()-ed broadcast view: 0 where query i may attend to key j, and the dtype minimum otherwise.
    /// Callers must not modify it in place.
    /// </returns>
    private Tensor MakeCausalAttentionMask(
        int batchSize,
        int targetLen,
        int pastKeyValueLength,
        Device device,
        ScalarType dtype)
    {
        var mask = torch.full([targetLen, targetLen], torch.finfo(dtype).min, dtype: dtype, device: device);
        var maskCond = torch.arange(mask.size(-1), device: device);

        // Zero the lower triangle including the diagonal (column j <= row i).
        // The strict upper triangle keeps the minimum value.
        mask.masked_fill_(maskCond < (maskCond + 1).view(mask.size(-1), 1), 0.0f);

        if (pastKeyValueLength > 0)
        {
            // Cached positions precede every query token, so they are always visible.
            mask = torch.cat([torch.zeros([targetLen, pastKeyValueLength], dtype: dtype, device: device), mask], dim: -1);
        }

        return mask
            .unsqueeze(0)
            .unsqueeze(0)
            .expand(new long[] { batchSize, 1, targetLen, targetLen + pastKeyValueLength });
    }
}