// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.LLaMA;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
 
namespace Microsoft.ML.GenAI.LLaMA.Module;
#pragma warning disable MSML_GeneralName // This name should be PascalCased
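/// <summary>
/// The feed-forward (MLP) block of a Llama decoder layer: <c>gate_proj</c> and <c>up_proj</c> expand the
/// hidden states from <c>HiddenSize</c> to <c>IntermediateSize</c>, the gate branch is passed through the
/// configured activation, the two branches are multiplied element-wise, and <c>down_proj</c> projects the
/// result back to <c>HiddenSize</c>.
/// </summary>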
internal class LlamaMLP : torch.nn.Module<Tensor, Tensor>
#pragma warning restore MSML_GeneralName // This name should be PascalCased
{
    private readonly int _pretrainingTp;
    private readonly int _intermediateSize;
    private readonly int _hiddenSize;
    private readonly bool _hasBias;
#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
    private readonly QuantizedLinear gate_proj;
    private readonly QuantizedLinear up_proj;
    private readonly QuantizedLinear down_proj;
    private readonly torch.nn.Module<Tensor, Tensor> activation_fn;
#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
 
    public LlamaMLP(LlamaConfig config)
        : base(nameof(LlamaMLP))
    {
        this._hiddenSize = config.HiddenSize;
        this._intermediateSize = config.IntermediateSize;
        this._hasBias = config.MlpBias;
        this._pretrainingTp = config.PretrainingTp;
        var hiddenAct = config.HiddenAct;
        this.gate_proj = new QuantizedLinear(this._hiddenSize, this._intermediateSize, hasBias: this._hasBias, dtype: config.DType);
        this.up_proj = new QuantizedLinear(this._hiddenSize, this._intermediateSize, hasBias: this._hasBias, dtype: config.DType);
        this.down_proj = new QuantizedLinear(this._intermediateSize, this._hiddenSize, hasBias: this._hasBias, dtype: config.DType);
        this.RegisterComponents();
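        // The activation (SiLU in the standard Llama configuration) carries no learnable parameters,
        // so it does not need to exist before RegisterComponents() is called.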
        this.activation_fn = Core.Utils.GetActivation(hiddenAct);
    }
 
#pragma warning disable MSML_GeneralName // This name should be PascalCased
    public override Tensor forward(Tensor input)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
    {
        if (this._pretrainingTp > 1)
        {
            throw new NotImplementedException("PretrainingTp > 1 is not supported yet.");
        }
 
        // down_proj(activation_fn(gate_proj(x)) * up_proj(x)); intermediate tensors are disposed eagerly.
        using var gate = this.gate_proj.forward(input);
        using var gateActivated = this.activation_fn.forward(gate);
        using var up = this.up_proj.forward(input);
        using var gated = gateActivated * up;
        return this.down_proj.forward(gated);
    }
}