File: MistralDecoderLayer.cs
Web Access
Project: src\src\Microsoft.ML.GenAI.Mistral\Microsoft.ML.GenAI.Mistral.csproj (Microsoft.ML.GenAI.Mistral)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using Microsoft.ML.GenAI.Core;
using static TorchSharp.torch;
 
namespace Microsoft.ML.GenAI.Mistral.Module;
 
/// <summary>
/// Bundle of arguments consumed by <see cref="MistralDecoderLayer.forward(DecoderLayerInput)"/>.
/// </summary>
internal class DecoderLayerInput
{
    /// <summary>Hidden states entering the layer; also used as the first residual.</summary>
    public Tensor HiddenStates { get; set; }

    /// <summary>Attention mask handed to the self-attention module.</summary>
    public Tensor AttentionMask { get; set; }

    /// <summary>Position ids handed to the self-attention module.</summary>
    public Tensor PositionIds { get; set; }

    /// <summary>Rotary position embeddings (cos, sin) handed to the self-attention module.</summary>
    public RotaryEmbeddingOutput PositionalEmbeddings { get; set; }

    /// <summary>Optional key/value cache handed to the self-attention module.</summary>
    public IKVCache? PastKeyValue { get; set; }

    /// <summary>Whether the caller wants the attention tensor included in the layer output.</summary>
    public bool OutputAttentions { get; set; }

    /// <summary>
    /// Creates a new input bundle.
    /// </summary>
    /// <param name="hiddenStates">Value for <see cref="HiddenStates"/>.</param>
    /// <param name="attentionMask">Value for <see cref="AttentionMask"/>.</param>
    /// <param name="positionIds">Value for <see cref="PositionIds"/>.</param>
    /// <param name="positionEmbeddings">Value for <see cref="PositionalEmbeddings"/> (cos, sin).</param>
    /// <param name="pastKeyValue">Value for <see cref="PastKeyValue"/>; defaults to <c>null</c>.</param>
    /// <param name="outputAttentions">Value for <see cref="OutputAttentions"/>; defaults to <c>false</c>.</param>
    public DecoderLayerInput(
        Tensor hiddenStates,
        Tensor attentionMask,
        Tensor positionIds,
        RotaryEmbeddingOutput positionEmbeddings, // cos, sin
        IKVCache? pastKeyValue = null,
        bool outputAttentions = false)
    {
        HiddenStates = hiddenStates;
        AttentionMask = attentionMask;
        PositionIds = positionIds;
        PositionalEmbeddings = positionEmbeddings;
        PastKeyValue = pastKeyValue;
        OutputAttentions = outputAttentions;
    }
}
 
/// <summary>
/// Result produced by <see cref="MistralDecoderLayer.forward(DecoderLayerInput)"/>.
/// </summary>
internal class DecoderLayerOutput
{
    /// <summary>Hidden states leaving the layer.</summary>
    public Tensor HiddenStates { get; set; }

    /// <summary>Attention tensor, or <c>null</c> when none was supplied.</summary>
    public Tensor? Attentions { get; set; }

    /// <summary>Key/value cache associated with this output, or <c>null</c> when none was supplied.</summary>
    public IKVCache? PastKeyValue { get; set; }

    /// <summary>
    /// Creates a new output bundle.
    /// </summary>
    /// <param name="hiddenStates">Value for <see cref="HiddenStates"/>.</param>
    /// <param name="attentions">Value for <see cref="Attentions"/>; defaults to <c>null</c>.</param>
    /// <param name="pastKeyValue">Value for <see cref="PastKeyValue"/>; defaults to <c>null</c>.</param>
    public DecoderLayerOutput(
        Tensor hiddenStates,
        Tensor? attentions = null,
        IKVCache? pastKeyValue = null)
    {
        (HiddenStates, Attentions, PastKeyValue) = (hiddenStates, attentions, pastKeyValue);
    }
}
/// <summary>
/// One Mistral decoder layer. The forward pass applies an RMS norm, self-attention and a
/// residual add, then a second RMS norm, the MLP and a second residual add.
/// </summary>
internal class MistralDecoderLayer : nn.Module<DecoderLayerInput, DecoderLayerOutput>, IDynamicLoadModule
{
    private readonly MistralConfig _config;
    private readonly int _layerIndex;
    private readonly int _hiddenSize;

#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
    // NOTE(review): these snake_case names presumably mirror the checkpoint's weight names;
    // confirm before renaming or reordering them.
    private readonly MistralMLP mlp;
    private readonly Core.RMSNorm input_layernorm;
    private readonly Core.RMSNorm post_attention_layernorm;
    private readonly Attention self_attn;

    /// <summary>
    /// Optional callback invoked with this layer at the start of <see cref="forward(DecoderLayerInput)"/>.
    /// </summary>
    public Action<nn.Module>? LoadToDeviceFunc { get; set; }

    /// <summary>
    /// Optional callback invoked with this layer when <see cref="forward(DecoderLayerInput)"/> finishes,
    /// whether it returns normally or throws.
    /// </summary>
    public Action<nn.Module>? UnloadFromDeviceFunc { get; set; }

#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format

    /// <summary>
    /// Builds the attention, MLP and the two RMS-norm sub-modules from <paramref name="config"/>.
    /// </summary>
    /// <param name="config">Model configuration supplying sizes, epsilon and dtype.</param>
    /// <param name="layerIndex">Index of this layer; forwarded to the attention module.</param>
    public MistralDecoderLayer(MistralConfig config, int layerIndex)
        : base(nameof(MistralDecoderLayer))
    {
        _config = config;
        _layerIndex = layerIndex;
        _hiddenSize = config.HiddenSize;

        this.self_attn = CreateAttention(config, layerIndex);
        this.mlp = new MistralMLP(config);
        this.input_layernorm = new Core.RMSNorm(this._hiddenSize, eps: config.RmsNormEps, config.DType);
        this.post_attention_layernorm = new Core.RMSNorm(this._hiddenSize, eps: config.RmsNormEps, config.DType);
    }

    /// <summary>
    /// Creates the self-attention module for this layer.
    /// </summary>
    /// <param name="config">Model configuration.</param>
    /// <param name="layerIndex">Index of the layer the attention module belongs to.</param>
    /// <returns>The configured <see cref="Attention"/> module.</returns>
    private Attention CreateAttention(MistralConfig config, int layerIndex)
    {
        // Head dimension is derived (integer division) rather than read from the config.
        var headDim = config.HiddenSize / config.NumAttentionHeads;
        return new Attention(
            attentionDropout: config.AttentionDropout,
            hiddenSize: config.HiddenSize,
            numHeads: config.NumAttentionHeads,
            headDim: headDim,
            numKeyValueHeads: config.NumKeyValueHeads,
            numKeyValueGroups: config.NumAttentionHeads / config.NumKeyValueHeads,
            maxPositionEmbeddings: config.MaxPositionEmbeddings,
            // The same limit is passed for both the current and the "original" maximum.
            originalMaxPositionEmbeddings: config.MaxPositionEmbeddings,
            layerIdx: layerIndex,
            useQkvProj: false,
            dtype: config.DType,
            attentionBias: config.AttentionBias);
    }

    /// <summary>
    /// Runs the decoder layer on <paramref name="input"/>.
    /// </summary>
    /// <param name="input">Hidden states plus mask, positions, rotary embeddings and optional cache.</param>
    /// <returns>
    /// The new hidden states, the attention tensor when <see cref="DecoderLayerInput.OutputAttentions"/>
    /// is set (and the attention module produced one), and the cache returned by the attention module.
    /// </returns>
#pragma warning disable MSML_GeneralName // This name should be PascalCased
    public override DecoderLayerOutput forward(DecoderLayerInput input)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
    {
        LoadToDeviceFunc?.Invoke(this);

        // Intermediate tensors created below are owned by this scope; only the tensors explicitly
        // moved to the outer scope are handed back to the caller.
        using var disposeScope = NewDisposeScope();
        try
        {
            // Self-attention sub-block with residual connection.
            var residual = input.HiddenStates;
            var hiddenStates = this.input_layernorm.forward(input.HiddenStates);

            var selfAttnInput = new AttentionInput(
                hiddenStates: hiddenStates,
                attentionMask: input.AttentionMask,
                positionIds: input.PositionIds,
                cache: input.PastKeyValue,
                positionalEmbeddings: input.PositionalEmbeddings,
                outputAttentions: input.OutputAttentions);

            var selfAttnOutput = this.self_attn.forward(selfAttnInput);

            hiddenStates = residual + selfAttnOutput.HiddenStates;

            // Fully connected sub-block with residual connection.
            residual = hiddenStates;
            hiddenStates = this.post_attention_layernorm.forward(hiddenStates);
            hiddenStates = this.mlp.forward(hiddenStates);
            hiddenStates = residual + hiddenStates;

            return new DecoderLayerOutput(
                hiddenStates: hiddenStates.MoveToOuterDisposeScope(),
                attentions: input.OutputAttentions ? selfAttnOutput.Attentions?.MoveToOuterDisposeScope() : null,
                pastKeyValue: selfAttnOutput.Cache);
        }
        finally
        {
            // Always pair the load callback with the unload callback, even when the forward pass
            // throws, so a failed step does not skip the unload.
            UnloadFromDeviceFunc?.Invoke(this);
        }
    }
}