|  | 
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using Microsoft.ML.GenAI.Core;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
 
namespace Microsoft.ML.GenAI.Mistral.Module;
 
/// <summary>
/// Mistral backbone module: embeds the input tokens, runs them through the stacked
/// <see cref="MistralDecoderLayer"/> modules and applies a final <see cref="RMSNorm"/>.
/// </summary>
public class MistralModel : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
    private readonly MistralConfig _config;
    private readonly int? _paddingIdx;
    private readonly int _vocabSize;

    // Key/value cache handed to every decoder layer; can be swapped per call through
    // CausalLMModelInput.OverrideCache (see forward).
    private IKVCache _cache;

    // NOTE(review): these fields deliberately break the _camelCase rule; presumably the names
    // have to match the parameter names of the pretrained checkpoint - confirm before renaming.
#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
    private readonly Embedding embed_tokens;
    private readonly ModuleList<MistralDecoderLayer> layers;
    private readonly RMSNorm norm;
#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
    private readonly nn.Module<RotaryEmbeddingInput, RotaryEmbeddingOutput> _rotaryEmb;


    /// <summary>
    /// Builds the embedding table, <see cref="MistralConfig.NumHiddenLayers"/> decoder layers,
    /// the final norm, an empty <see cref="DynamicKVCache"/> and the rotary embedding module.
    /// </summary>
    /// <param name="config">Model hyper-parameters (sizes, dtype, RoPE settings, ...).</param>
    public MistralModel(MistralConfig config)
        : base(nameof(MistralModel))
    {
        this._config = config;
        this._paddingIdx = config.PadTokenId;
        this._vocabSize = config.VocabSize;
        var headDim = config.HeadDim;
        this.embed_tokens = nn.Embedding(config.VocabSize, config.HiddenSize, padding_idx: this._paddingIdx, dtype: config.DType);
        this.layers = new ModuleList<MistralDecoderLayer>();

        for (int i = 0; i < config.NumHiddenLayers; i++)
        {
            this.layers.Add(new MistralDecoderLayer(config, i));
        }
        this.norm = new RMSNorm(config.HiddenSize, config.RmsNormEps, config.DType);
        this._cache = new DynamicKVCache();

        // NOTE(review): RegisterComponents runs before _rotaryEmb is assigned, so the rotary
        // embedding is presumably kept out of the registered sub-modules on purpose - confirm
        // before reordering these statements.
        this.RegisterComponents();

        // Use the scaled rotary embedding only when the config carries a RoPE scaling section.
        this._rotaryEmb = config.RopeScaling switch
        {
            null => new RotaryEmbedding(config.RopeTheta, config.MaxPositionEmbeddings, headDim),
            _ => new RotaryEmbedding(config.RopeTheta, headDim, config.RopeScaling),
        };
    }

    /// <summary>
    /// Runs the model on either token ids or pre-computed input embeddings.
    /// </summary>
    /// <param name="input">
    /// Model input. Exactly one of <c>InputIds</c> / <c>InputEmbeddings</c> must be set.
    /// When <c>PositionIds</c> is null, positions
    /// <c>[PastKeyValuesLength, PastKeyValuesLength + seqLength)</c> are generated.
    /// When <c>OverrideCache</c> is not null it replaces the cache stored on this module,
    /// for this call and for subsequent ones.
    /// </param>
    /// <returns>
    /// The normalized last hidden state, the cache that was used and, when requested, the
    /// per-layer hidden states (each layer's input followed by the final normalized state)
    /// and attentions; the arrays are empty when the corresponding flag is off.
    /// </returns>
    /// <exception cref="ArgumentException">
    /// Both or neither of <c>InputIds</c> and <c>InputEmbeddings</c> are provided.
    /// </exception>
    /// <exception cref="NotImplementedException">
    /// The config requests the <c>flash_attention_2</c> attention implementation.
    /// </exception>
#pragma warning disable MSML_GeneralName // This name should be PascalCased
    public override CausalLMModelOutput forward(CausalLMModelInput input)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
    {
        if (input.OverrideCache is not null)
        {
            this._cache = input.OverrideCache;
        }

        var outputAttentions = input.OutputAttentions;
        var outputHiddenStates = input.OutputHiddenStates;
        var attentionMask = input.AttentionMask;
        Device device;
        var inputIds = input.InputIds;
        var positionIds = input.PositionIds;
        var inputsEmbeds = input.InputEmbeddings;
        int batchSize;
        int seqLength;
        if (inputIds is not null && inputsEmbeds is not null)
        {
            throw new ArgumentException("Only one of input_ids or inputs_embeds may be set");
        }
        else if (inputIds is not null)
        {
            // Dimension 0 is read as the batch size and dimension 1 as the sequence length.
            batchSize = inputIds.IntShape()[0];
            seqLength = inputIds.IntShape()[1];
            inputsEmbeds = this.embed_tokens.forward(inputIds);
            device = inputIds.device;
        }
        else if (inputsEmbeds is not null)
        {
            batchSize = inputsEmbeds.IntShape()[0];
            seqLength = inputsEmbeds.IntShape()[1];
            device = inputsEmbeds.device;
        }
        else
        {
            throw new ArgumentException("Either input_ids or inputs_embeds must be set");
        }

        var pastKeyValuesLength = input.PastKeyValuesLength;

        if (positionIds is null)
        {
            // Continue numbering after the tokens that are already in the cache.
            positionIds = torch.arange(pastKeyValuesLength, seqLength + pastKeyValuesLength, device: device);
            positionIds = positionIds.unsqueeze(0).view(-1, seqLength);
        }
        else
        {
            // Convert the element type to int64 while keeping the (-1, seqLength) shape.
            // A C# `(long)` cast on a Tensor is a scalar extraction, not a dtype conversion:
            // it fails for tensors with more than one element and otherwise collapses the
            // tensor into a 0-d scalar.
            positionIds = positionIds.view(-1, seqLength).to_type(ScalarType.Int64);
        }

        if (this._config.AttnImplementation == "flash_attention_2")
        {
            throw new NotImplementedException("The flash_attention_2 attention implementation is not supported.");
        }
        else
        {
            // the following behavior of creating 4d causal mask doesn't match python's, remember to look into it when there's time.
            attentionMask = AttentionMaskConverter.Create4DCausalAttentionMask(attentionMask, [batchSize, seqLength], inputsEmbeds.dtype, device, pastKeyValuesLength, slidingWindow: _config.SlidingWindow);
        }

        var hiddenStates = inputsEmbeds;

        var allHiddenStates = new List<Tensor>();
        var allAttentions = new List<Tensor>();

        // The rotary embedding output is computed once and shared by every decoder layer.
        var embOutput = this._rotaryEmb.forward(new RotaryEmbeddingInput(hiddenStates, positionIds, pastKeyValuesLength));
        foreach (var layer in this.layers)
        {
            if (outputHiddenStates)
            {
                // Record the input of each layer; the final normalized state is added after the loop.
                allHiddenStates.Add(hiddenStates);
            }

            var decoderInput = new DecoderLayerInput(
                hiddenStates: hiddenStates,
                attentionMask: attentionMask!,
                positionIds: positionIds,
                pastKeyValue: this._cache,
                positionEmbeddings: embOutput,
                outputAttentions: outputAttentions);
            var layerOutput = layer.forward(decoderInput);
            hiddenStates = layerOutput.HiddenStates;
            if (outputAttentions && layerOutput.Attentions is not null)
            {
                allAttentions.Add(layerOutput.Attentions);
            }
        }

        hiddenStates = this.norm.forward(hiddenStates);
        if (outputHiddenStates)
        {
            allHiddenStates.Add(hiddenStates);
        }

        return new CausalLMModelOutput(lastHiddenState: hiddenStates, allHiddenStates: allHiddenStates.ToArray(), attentions: allAttentions.ToArray(), cache: this._cache);
    }
}
 |