// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Core.Extension;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;

namespace Microsoft.ML.GenAI.Core;
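
// Input to the attention module: hidden states, position ids, precomputed rotary embeddings (cos/sin),
// plus an optional attention mask and KV cache.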
internal class AttentionInput
{
public AttentionInput(
Tensor hiddenStates,
Tensor positionIds,
RotaryEmbeddingOutput positionalEmbeddings, // cos, sin
Tensor? attentionMask = null,
IKVCache? cache = null,
bool outputAttentions = false)
{
this.HiddenStates = hiddenStates;
this.AttentionMask = attentionMask;
this.PositionIds = positionIds;
this.Cache = cache;
this.PositionalEmbeddings = positionalEmbeddings;
this.OutputAttentions = outputAttentions;
}

public Tensor HiddenStates { get; set; }
public Tensor? AttentionMask { get; set; }
public Tensor PositionIds { get; set; }
public RotaryEmbeddingOutput PositionalEmbeddings { get; set; }
public IKVCache? Cache { get; set; }
public bool OutputAttentions { get; set; }
}
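
// Output of the attention module: the projected hidden states, optionally the attention weights,
// and the (possibly updated) KV cache that was passed in.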
internal class AttentionOutput
{
public AttentionOutput(
Tensor hiddenStates,
Tensor? attentions = null,
IKVCache? cache = null)
{
this.HiddenStates = hiddenStates;
this.Attentions = attentions;
this.Cache = cache;
}

public Tensor HiddenStates { get; set; }
public Tensor? Attentions { get; set; }
public IKVCache? Cache { get; set; }
}
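
// Multi-head attention with rotary position embeddings and grouped-query attention
// (numKeyValueHeads key/value heads shared across numKeyValueGroups query heads).
// The Q/K/V projections are either a single fused qkv_proj or separate q_proj/k_proj/v_proj.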
internal class Attention : nn.Module<AttentionInput, AttentionOutput>
{
private readonly int _layerIdx;
private readonly double _attentionDropout;
private readonly int _hiddenSize;
private readonly int _numHeads;
private readonly int _headDim;
private readonly int _numKeyValueHeads;
private readonly int _numKeyValueGroups;
private readonly int _maxPositionEmbeddings;
private readonly int _originalMaxPositionEmbeddings;
#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
private readonly QuantizedLinear o_proj;
private readonly QuantizedLinear? qkv_proj;
private readonly QuantizedLinear? q_proj;
private readonly QuantizedLinear? k_proj;
private readonly QuantizedLinear? v_proj;
#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
public Attention(
double attentionDropout,
int hiddenSize,
int numHeads,
int headDim,
int numKeyValueHeads,
int numKeyValueGroups,
int maxPositionEmbeddings,
int originalMaxPositionEmbeddings,
int layerIdx,
ScalarType dtype,
bool attentionBias = false,
bool useQkvProj = true)
: base(nameof(Attention))
{
this._layerIdx = layerIdx;
this._attentionDropout = attentionDropout;
this._hiddenSize = hiddenSize;
this._numHeads = numHeads;
this._headDim = headDim;
this._numKeyValueHeads = numKeyValueHeads;
this._numKeyValueGroups = numKeyValueGroups;
this._maxPositionEmbeddings = maxPositionEmbeddings;
this._originalMaxPositionEmbeddings = originalMaxPositionEmbeddings;
Contract.Assert(this._hiddenSize % (this._headDim * this._numHeads) == 0, "hidden_size must be divisible by num_heads * head_dim");
this.o_proj = new QuantizedLinear(this._hiddenSize, this._hiddenSize, hasBias: attentionBias, dtype: dtype);
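
// Depending on useQkvProj, create either a single fused QKV projection or separate Q/K/V projections.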
if (useQkvProj)
{
var opSize = this._numHeads * this._headDim + 2 * (this._numKeyValueHeads * this._headDim);
this.qkv_proj = new QuantizedLinear(this._hiddenSize, opSize, hasBias: attentionBias, dtype: dtype);
}
else
{
this.q_proj = new QuantizedLinear(this._hiddenSize, this._numHeads * this._headDim, hasBias: attentionBias, dtype: dtype);
this.k_proj = new QuantizedLinear(this._hiddenSize, this._numKeyValueHeads * this._headDim, hasBias: attentionBias, dtype: dtype);
this.v_proj = new QuantizedLinear(this._hiddenSize, this._numKeyValueHeads * this._headDim, hasBias: attentionBias, dtype: dtype);
}
}

#pragma warning disable MSML_GeneralName // This name should be PascalCased
public override AttentionOutput forward(AttentionInput input)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
{
using (var _ = NewDisposeScope())
{
var hiddenStates = input.HiddenStates;
var positionIds = input.PositionIds;
var outputAttentions = input.OutputAttentions;
var bsz = hiddenStates.shape[0];
var qLen = hiddenStates.shape[1];
Tensor queryStates;
Tensor keyStates;
Tensor valueStates;
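
// Project hidden states into query/key/value. With a fused projection the single output tensor is
// sliced into query, key and value parts; otherwise the three separate projections are applied.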
if (this.qkv_proj is not null)
{
var qkv = this.qkv_proj.forward(hiddenStates);
var queryPos = this._numHeads * this._headDim;
queryStates = qkv[.., .., ..queryPos];
keyStates = qkv[.., .., queryPos..(queryPos + this._numKeyValueHeads * this._headDim)];
valueStates = qkv[.., .., (queryPos + this._numKeyValueHeads * this._headDim)..];
}
else if (this.q_proj is not null && this.k_proj is not null && this.v_proj is not null)
{
queryStates = this.q_proj.forward(hiddenStates);
keyStates = this.k_proj.forward(hiddenStates);
valueStates = this.v_proj.forward(hiddenStates);
}
else
{
throw new InvalidOperationException("Invalid state: either qkv_proj or all of q_proj, k_proj and v_proj must be initialized.");
}
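
// Reshape to (batch, numHeads, seqLen, headDim) for queries and (batch, numKeyValueHeads, seqLen, headDim) for keys/values.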
queryStates = queryStates.view(bsz, qLen, this._numHeads, this._headDim).transpose(1, 2);
keyStates = keyStates.view(bsz, qLen, this._numKeyValueHeads, this._headDim).transpose(1, 2);
valueStates = valueStates.view(bsz, qLen, this._numKeyValueHeads, this._headDim).transpose(1, 2);
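
// The effective key/value sequence length includes any tokens already stored in the KV cache.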
var kvSeqLen = keyStates.IntShape()[^2];
var pastKeyValue = input.Cache;
if (pastKeyValue is not null)
{
kvSeqLen += pastKeyValue.GetUsableLength(kvSeqLen, this._layerIdx);
}
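
// Apply rotary position embeddings (RoPE) to the query and key states using the precomputed cos/sin tensors.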
(queryStates, keyStates) = Utils.ApplyRotaryPosEmb(queryStates, keyStates, input.PositionalEmbeddings.Cos, input.PositionalEmbeddings.Sin);
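
// Append the new key/value states to this layer's KV cache and continue with the full cached sequence.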
if (pastKeyValue is not null)
{
(keyStates, valueStates) = pastKeyValue.UpdateKVCache(keyStates, valueStates, this._layerIdx);
}
// repeat k/v heads if n_kv_heads < n_heads
keyStates = Utils.RepeatKV(keyStates, this._numKeyValueGroups);
valueStates = Utils.RepeatKV(valueStates, this._numKeyValueGroups);
// compute scaled dot-product attention scores; the upcast to fp32 to avoid overflow happens at the softmax below
var attnWeights = torch.matmul(queryStates, keyStates.transpose(2, 3));
attnWeights = attnWeights / Math.Sqrt(this._headDim);
// attnWeights shape should be [bsz, this._numHeads, qLen, kvSeqLen]
Contract.Assert(attnWeights.shape.Length == 4);
Contract.Assert(attnWeights.shape[0] == bsz);
Contract.Assert(attnWeights.shape[1] == this._numHeads);
Contract.Assert(attnWeights.shape[2] == qLen);
Contract.Assert(attnWeights.shape[3] == kvSeqLen);
var attentionMask = input.AttentionMask;
if (attentionMask is not null)
{
Contract.Assert(attentionMask.shape.Length == 4);
Contract.Assert(attentionMask.shape[0] == bsz);
Contract.Assert(attentionMask.shape[1] == 1);
Contract.Assert(attentionMask.shape[2] == qLen);
//Contract.Assert(attentionMask.shape[3] == kvSeqLen);
attnWeights = attnWeights + attentionMask;
}
// upscale attention to fp32 to avoid overflow
attnWeights = nn.functional.softmax(attnWeights, dim: -1, dtype: ScalarType.Float32).to(valueStates.dtype);
attnWeights = nn.functional.dropout(attnWeights, this._attentionDropout, this.training);
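
// Weighted sum over the value states, then merge the heads back into (batch, seqLen, hiddenSize)
// and apply the output projection.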
var attnOutput = torch.matmul(attnWeights, valueStates);
attnOutput = attnOutput.transpose(1, 2).contiguous();
attnOutput = attnOutput.reshape(bsz, qLen, this._hiddenSize);
attnOutput = this.o_proj.forward(attnOutput);
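
// Move the results to the outer dispose scope so they outlive the local scope created above.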
return new(attnOutput.MoveToOuterDisposeScope(), outputAttentions ? attnWeights.MoveToOuterDisposeScope() : null, pastKeyValue);
}
}
}