File: Module\Phi3DecoderLayer.cs
Project: src\src\Microsoft.ML.GenAI.Phi\Microsoft.ML.GenAI.Phi.csproj (Microsoft.ML.GenAI.Phi)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Core.Extension;
using TorchSharp.Modules;
using static TorchSharp.torch;
 
namespace Microsoft.ML.GenAI.Phi.Module;
 
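// Bundles everything a single decoder layer needs per forward pass: the hidden
// states, the attention mask, position ids, the precomputed rotary (cos/sin)
// embeddings, an optional KV cache, and a flag for returning attention weights.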
internal class Phi3DecoderLayerInput
{
    public Phi3DecoderLayerInput(
        Tensor hiddenStates,
        Tensor attentionMask,
        Tensor positionIds,
        RotaryEmbeddingOutput positionalEmbeddings, // cos, sin
        IKVCache? pastKeyValue = null,
        bool outputAttentions = false)
    {
        this.HiddenStates = hiddenStates;
        this.AttentionMask = attentionMask;
        this.PositionIds = positionIds;
        this.PastKeyValue = pastKeyValue;
        this.PositionalEmbeddings = positionalEmbeddings;
        this.OutputAttentions = outputAttentions;
    }
 
    public Tensor HiddenStates { get; set; }
 
    public Tensor AttentionMask { get; set; }
 
    public Tensor PositionIds { get; set; }
 
    public RotaryEmbeddingOutput PositionalEmbeddings { get; set; } // cos, sin
 
    public IKVCache? PastKeyValue { get; set; }
 
    public bool OutputAttentions { get; set; }
}
 
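// Carries the layer's results: the transformed hidden states, optionally the
// attention weights (when OutputAttentions was requested), and the KV cache.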
internal class Phi3DecoderLayerOutput
{
    public Phi3DecoderLayerOutput(
        Tensor hiddenStates,
        Tensor? attentions = null,
        IKVCache? pastKeyValue = null)
    {
        this.HiddenStates = hiddenStates;
        this.Attentions = attentions;
        this.PastKeyValue = pastKeyValue;
    }
 
    public Tensor HiddenStates { get; set; }
 
    public Tensor? Attentions { get; set; }
 
    public IKVCache? PastKeyValue { get; set; }
}
 
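// A single Phi-3 transformer decoder block: pre-norm self-attention followed by a
// pre-norm MLP, each wrapped in a dropout-regularized residual connection. The layer
// implements IDynamicLoadModule so a pipeline can move it on and off a device around
// each forward call.
//
// A minimal usage sketch (hypothetical; the surrounding model code and variable
// names are assumptions, and the rotary cos/sin embeddings would normally come
// from the model's rotary embedding module):
//
//   var layer = new Phi3DecoderLayer(config, layerIdx: 0);
//   var output = layer.forward(new Phi3DecoderLayerInput(
//       hiddenStates, attentionMask, positionIds, rotaryCosSin, kvCache));
//   hiddenStates = output.HiddenStates;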
internal class Phi3DecoderLayer : nn.Module<Phi3DecoderLayerInput, Phi3DecoderLayerOutput>, IDynamicLoadModule
{
    private readonly Phi3Config _config;
#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
    private readonly nn.Module<AttentionInput, AttentionOutput> self_attn;
    private readonly Phi3MLP mlp;
    private readonly RMSNorm input_layernorm;
    private readonly Dropout resid_attn_dropout;
    private readonly Dropout resid_mlp_dropout;
    private readonly RMSNorm post_attention_layernorm;
#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
 
    public Phi3DecoderLayer(Phi3Config config, int layerIdx)
        : base(nameof(Phi3DecoderLayer))
    {
        this._config = config;
        if (config.AttnImplementation == "eager")
        {
            this.self_attn = this.CreateAttentionFromConfig(config, layerIdx);
        }
        else
        {
            throw new NotImplementedException($"Attention implementation '{config.AttnImplementation}' is not supported; only 'eager' is available.");
        }
 
        this.mlp = new Phi3MLP(config);
        this.input_layernorm = new RMSNorm(config.HiddenSize, config.RmsNormEps, config.DType);
 
        this.resid_attn_dropout = nn.Dropout(config.ResidPdrop);
        this.resid_mlp_dropout = nn.Dropout(config.ResidPdrop);
        this.post_attention_layernorm = new RMSNorm(config.HiddenSize, config.RmsNormEps, config.DType);

        // Register the child modules with TorchSharp so their parameters are tracked
        // (needed for state-dict loading and for moving the layer between devices).
        this.RegisterComponents();
    }
 
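    // Optional hooks from IDynamicLoadModule: when set, the layer is moved onto the
    // target device before forward runs and released again afterwards, so large models
    // can be executed layer by layer when device memory is limited.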
    public Action<nn.Module>? LoadToDeviceFunc { get; set; }
 
    public Action<nn.Module>? UnloadFromDeviceFunc { get; set; }
 
#pragma warning disable MSML_GeneralName // This name should be PascalCased
    public override Phi3DecoderLayerOutput forward(Phi3DecoderLayerInput input)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
    {
        if (LoadToDeviceFunc != null)
        {
            LoadToDeviceFunc(this);
        }
        using var disposeScope = NewDisposeScope();
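
        // Self-attention block: keep the residual, apply the pre-attention RMSNorm,
        // run attention, then add the dropout-regularized output back onto the residual.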
        var hiddenStates = input.HiddenStates;
        var residual = input.HiddenStates;
        hiddenStates = this.input_layernorm.forward(hiddenStates);
 
        var attentionInput = new AttentionInput(
            hiddenStates: hiddenStates,
            positionIds: input.PositionIds,
            attentionMask: input.AttentionMask,
            cache: input.PastKeyValue,
            positionalEmbeddings: input.PositionalEmbeddings,
            outputAttentions: input.OutputAttentions);
        var output = this.self_attn.forward(attentionInput);
        var attnOutputs = output.HiddenStates;
        var selfAttnWeights = output.Attentions;
        var presentKeyValue = output.Cache;
        hiddenStates = residual + this.resid_attn_dropout.forward(attnOutputs);
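
        // Feed-forward block: the same pre-norm residual pattern around the MLP.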
        residual = hiddenStates;
        hiddenStates = this.post_attention_layernorm.forward(hiddenStates);
        hiddenStates = this.mlp.forward(hiddenStates);
        hiddenStates = residual + this.resid_mlp_dropout.forward(hiddenStates);
 
        if (UnloadFromDeviceFunc != null)
        {
            UnloadFromDeviceFunc(this);
        }
        return new Phi3DecoderLayerOutput(hiddenStates.MoveToOuterDisposeScope(), selfAttnWeights?.MoveToOuterDisposeScope(), presentKeyValue);
    }
 
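    // Maps the Phi-3 config onto the shared Attention module: the head dimension is
    // hiddenSize / numHeads, and numKeyValueGroups is the number of query heads per
    // key/value head (grouped-query attention). useQkvProj: true matches Phi-3's
    // fused QKV projection.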
    private Attention CreateAttentionFromConfig(Phi3Config config, int layerIdx)
    {
        var headDim = config.HiddenSize / config.NumAttentionHeads;
        return new Attention(
            attentionDropout: config.AttentionDropout,
            hiddenSize: config.HiddenSize,
            numHeads: config.NumAttentionHeads,
            headDim: headDim,
            numKeyValueHeads: config.NumKeyValueHeads ?? throw new ArgumentException("num_key_value_heads must be specified"),
            numKeyValueGroups: config.NumAttentionHeads / config.NumKeyValueHeads ?? throw new ArgumentException("num_key_value_heads must be specified"),
            maxPositionEmbeddings: config.MaxPositionEmbeddings,
            originalMaxPositionEmbeddings: config.OriginalMaxPositionEmbeddings,
            layerIdx: layerIdx,
            useQkvProj: true,
            dtype: config.DType);
    }
}