1 write to _numHeads
Microsoft.ML.TorchSharp (1)
NasBert\Modules\MultiHeadAttention.cs (1)
71: _numHeads = numHeads;
15 references to _numHeads
Microsoft.ML.TorchSharp (15)
NasBert\Modules\MultiHeadAttention.cs (15)
73: _headDim = _embeddingDim / _numHeads;
75: if (_headDim * _numHeads != _embeddingDim)
212: q = q.view(tgtLen, batchSize * _numHeads, _headDim).transpose_(0, 1);
213: k = k?.view(-1, batchSize * _numHeads, _headDim).transpose_(0, 1);
214: v = v?.view(-1, batchSize * _numHeads, _headDim).transpose_(0, 1);
221: var prevKey = savedState[PrevKeyKey].view(batchSize * _numHeads, -1, _headDim);
229: var prevValue = savedState[PrevValueKey].view(batchSize * _numHeads, -1, _headDim);
236: savedState[PrevKeyKey] = k?.view(batchSize, _numHeads, -1, _headDim);
238: savedState[PrevValueKey] = v?.view(batchSize, _numHeads, -1, _headDim);
265: Debug.Assert(attentionWeights.size().SequenceEqual(new[] { batchSize * _numHeads, tgtLen, srcLen }));
278: .view(batchSize, _numHeads, tgtLen, srcLen)
280: .view(batchSize * _numHeads, tgtLen, srcLen);
289: var weightsView = attentionWeights.view(batchSize, _numHeads, tgtLen, srcLen);
290: outAttentionWeights = weightsView.sum(dim: 1).div_(_numHeads);
294: Debug.Assert(attention.size().SequenceEqual(new[] { batchSize * _numHeads, tgtLen, _headDim }));