1 write to _numHeads
Microsoft.ML.TorchSharp (1)
NasBert\Modules\MultiHeadAttention.cs (1)
71
_numHeads
= numHeads;
15 references to _numHeads
Microsoft.ML.TorchSharp (15)
NasBert\Modules\MultiHeadAttention.cs (15)
73
_headDim = _embeddingDim /
_numHeads
;
75
if (_headDim *
_numHeads
!= _embeddingDim)
212
q = q.view(tgtLen, batchSize *
_numHeads
, _headDim).transpose_(0, 1);
213
k = k?.view(-1, batchSize *
_numHeads
, _headDim).transpose_(0, 1);
214
v = v?.view(-1, batchSize *
_numHeads
, _headDim).transpose_(0, 1);
221
var prevKey = savedState[PrevKeyKey].view(batchSize *
_numHeads
, -1, _headDim);
229
var prevValue = savedState[PrevValueKey].view(batchSize *
_numHeads
, -1, _headDim);
236
savedState[PrevKeyKey] = k?.view(batchSize,
_numHeads
, -1, _headDim);
238
savedState[PrevValueKey] = v?.view(batchSize,
_numHeads
, -1, _headDim);
265
Debug.Assert(attentionWeights.size().SequenceEqual(new[] { batchSize *
_numHeads
, tgtLen, srcLen }));
278
.view(batchSize,
_numHeads
, tgtLen, srcLen)
280
.view(batchSize *
_numHeads
, tgtLen, srcLen);
289
var weightsView = attentionWeights.view(batchSize,
_numHeads
, tgtLen, srcLen);
290
outAttentionWeights = weightsView.sum(dim: 1).div_(
_numHeads
);
294
Debug.Assert(attention.size().SequenceEqual(new[] { batchSize *
_numHeads
, tgtLen, _headDim }));