|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML.GenAI.Core.Extension;
using Microsoft.ML.Tokenizers;
using TorchSharp;
using static TorchSharp.torch;
namespace Microsoft.ML.GenAI.Core;
/// <summary>
/// An <see cref="ICausalLMPipeline"/> that additionally exposes its tokenizer and model
/// through their concrete types.
/// </summary>
/// <typeparam name="TTokenizer">Concrete tokenizer type; must derive from <see cref="Tokenizer"/>.</typeparam>
/// <typeparam name="TModel">Concrete model type; must be a module mapping <see cref="CausalLMModelInput"/> to <see cref="CausalLMModelOutput"/>.</typeparam>
public interface ICausalLMPipeline<out TTokenizer, out TModel> : ICausalLMPipeline
    where TTokenizer : Tokenizer
    where TModel : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
    /// <summary>The model, typed as <typeparamref name="TModel"/>.</summary>
    TModel TypedModel { get; }

    /// <summary>The tokenizer, typed as <typeparamref name="TTokenizer"/>.</summary>
    TTokenizer TypedTokenizer { get; }
}
/// <summary>
/// Contract of a generation pipeline that pairs a <see cref="Tokenizer"/> with a causal language model.
/// </summary>
public interface ICausalLMPipeline
{
    /// <summary>Tokenizer exposed by the pipeline.</summary>
    Tokenizer Tokenizer { get; }

    /// <summary>Model exposed by the pipeline.</summary>
    nn.Module<CausalLMModelInput, CausalLMModelOutput> Model { get; }

    /// <summary>
    /// Generates from a text prompt and returns the generated text as one string.
    /// </summary>
    /// <param name="prompt">Text prompt.</param>
    /// <param name="maxLen">Upper bound on the number of generation steps.</param>
    /// <param name="temperature">Sampling temperature.</param>
    /// <param name="topP">Top-p sampling threshold.</param>
    /// <param name="stopSequences">Optional text sequences that end generation; may be null.</param>
    string Generate(
        string prompt,
        int maxLen = CausalLMPipeline.Defaults.MaxLen,
        float temperature = CausalLMPipeline.Defaults.Temperature,
        float topP = CausalLMPipeline.Defaults.TopP,
        string[]? stopSequences = CausalLMPipeline.Defaults.StopSequence);

    /// <summary>
    /// Generates from a text prompt and returns the generated text piece by piece.
    /// </summary>
    /// <param name="prompt">Text prompt.</param>
    /// <param name="maxLen">Upper bound on the number of generation steps.</param>
    /// <param name="temperature">Sampling temperature.</param>
    /// <param name="topP">Top-p sampling threshold.</param>
    /// <param name="stopSequences">Optional text sequences that end generation; may be null.</param>
    IEnumerable<string> GenerateStreaming(
        string prompt,
        int maxLen = CausalLMPipeline.Defaults.MaxLen,
        float temperature = CausalLMPipeline.Defaults.Temperature,
        float topP = CausalLMPipeline.Defaults.TopP,
        string[]? stopSequences = CausalLMPipeline.Defaults.StopSequence);

    /// <summary>
    /// Generates from already-tokenized input and returns a pair of tensors (token ids, logits).
    /// </summary>
    /// <param name="inputIds">Prompt token ids.</param>
    /// <param name="attentionMask">Attention mask matching <paramref name="inputIds"/>.</param>
    /// <param name="stopTokenSequence">Token id sequences that end generation.</param>
    /// <param name="temperature">Sampling temperature.</param>
    /// <param name="topP">Top-p sampling threshold.</param>
    /// <param name="maxLen">Upper bound on the number of generation steps.</param>
    (Tensor, Tensor) Generate(
        Tensor inputIds,
        Tensor attentionMask,
        int[][] stopTokenSequence,
        float temperature = CausalLMPipeline.Defaults.Temperature,
        float topP = CausalLMPipeline.Defaults.TopP,
        int maxLen = CausalLMPipeline.Defaults.MaxLen);

    /// <summary>
    /// Generates from already-tokenized input and returns one (token ids, logits) pair per step.
    /// </summary>
    /// <param name="inputIds">Prompt token ids.</param>
    /// <param name="attentionMask">Attention mask matching <paramref name="inputIds"/>.</param>
    /// <param name="stopTokenSequence">Token id sequences that end generation.</param>
    /// <param name="temperature">Sampling temperature.</param>
    /// <param name="topP">Top-p sampling threshold.</param>
    /// <param name="maxLen">Upper bound on the number of generation steps.</param>
    IEnumerable<(Tensor, Tensor)> GenerateStreaming(
        Tensor inputIds,
        Tensor attentionMask,
        int[][] stopTokenSequence,
        float temperature = CausalLMPipeline.Defaults.Temperature,
        float topP = CausalLMPipeline.Defaults.TopP,
        int maxLen = CausalLMPipeline.Defaults.MaxLen);

    /// <summary>
    /// Generate the embedding (last hidden state of the last token) for the prompt. The embedding is normalized by L2 norm.
    /// </summary>
    float[] GenerateEmbeddingFromLastTokenPool(string prompt);
}
/// <summary>
/// Strongly-typed flavour of <see cref="CausalLMPipeline"/>: same behavior, plus accessors that
/// return the tokenizer and model cast back to the types they were constructed with.
/// </summary>
/// <typeparam name="TTokenizer">Concrete tokenizer type.</typeparam>
/// <typeparam name="TModel">Concrete model type.</typeparam>
public class CausalLMPipeline<TTokenizer, TModel> : CausalLMPipeline, ICausalLMPipeline<TTokenizer, TModel>
    where TTokenizer : Tokenizer
    where TModel : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
    /// <summary>
    /// Forwards all arguments to the non-generic base constructor.
    /// </summary>
    public CausalLMPipeline(TTokenizer tokenizer, TModel model, string device = Defaults.Device)
        : base(tokenizer, model, device)
    {
    }

    /// <summary>The base <c>Tokenizer</c> property cast to <typeparamref name="TTokenizer"/>.</summary>
    public TTokenizer TypedTokenizer => (TTokenizer)base.Tokenizer;

    /// <summary>The base <c>Model</c> property cast to <typeparamref name="TModel"/>.</summary>
    public TModel TypedModel => (TModel)base.Model;
}
public class CausalLMPipeline : ICausalLMPipeline
{
/// <summary>
/// Compile-time default argument values shared by the pipeline interface and its implementations.
/// </summary>
internal static class Defaults
{
    /// <summary>Default device name.</summary>
    internal const string Device = "cpu";

    /// <summary>Default maximum generation length.</summary>
    internal const int MaxLen = 128;

    /// <summary>Default sampling temperature.</summary>
    internal const float Temperature = 0.7f;

    /// <summary>Default top-p threshold.</summary>
    internal const float TopP = 0.9f;

    /// <summary>Default stop sequences: none.</summary>
    internal const string[]? StopSequence = null;
}
/// <summary>
/// Creates a pipeline from a tokenizer, a model and a device name.
/// </summary>
/// <param name="tokenizer">Stored in <c>Tokenizer</c>.</param>
/// <param name="model">Stored in <c>Model</c>.</param>
/// <param name="device">Device name; converted to a <c>Device</c> and stored in <c>Device</c>.</param>
public CausalLMPipeline(
    Tokenizer tokenizer,
    nn.Module<CausalLMModelInput, CausalLMModelOutput> model,
    string device = Defaults.Device)
{
    Tokenizer = tokenizer;
    Model = model;
    Device = device;
}
/// <summary>
/// Parameterless constructor for mocking frameworks (e.g. Moq). Every property is left at its
/// default value, so an instance created this way must not be used as a real pipeline.
/// </summary>
private protected CausalLMPipeline()
{
    Tokenizer = default!;
    Model = default!;
    Device = default!;
}
/// <summary>Tokenizer used to encode prompts and decode generated token ids.</summary>
public Tokenizer Tokenizer { get; }
/// <summary>Model whose <c>forward</c> is invoked for generation and embedding.</summary>
public nn.Module<CausalLMModelInput, CausalLMModelOutput> Model { get; }
/// <summary>Device on which the string-based entry points create their input tensors.</summary>
public Device Device { get; }
/// <summary>
/// Generates tokens one step at a time from already-tokenized input, reusing a key/value cache
/// between steps so that each forward pass only sees the tokens added since the previous one.
/// </summary>
/// <param name="inputIds">Prompt token ids; dimension 0 is the batch, dimension 1 the sequence.</param>
/// <param name="attentionMask">Attention mask, sliced and extended in lock-step with <paramref name="inputIds"/>.</param>
/// <param name="stopTokenSequence">
/// Token id sequences that mark the end of generation. A batch row is finished once its trailing
/// tokens equal any of these sequences; generation ends when every row is finished. Empty sequences
/// are ignored.
/// </param>
/// <param name="temperature">
/// When greater than 0, the next token is sampled from softmax(logits / temperature) with top-p
/// filtering; otherwise the arg-max token is taken.
/// </param>
/// <param name="topP">Top-p threshold forwarded to <c>SampleTopP</c>.</param>
/// <param name="maxLen">Maximum number of generation steps. Zero or negative values yield nothing.</param>
/// <returns>
/// One (token ids, logits) pair per step. The step that completes a stop sequence for every batch
/// row is not yielded.
/// </returns>
/// <exception cref="InvalidOperationException">The model output carries no logits.</exception>
public IEnumerable<(
    Tensor, // sampled token ids of this step, flattened by reshape(-1): one id per batch row
    Tensor // logits[.., ^1] of this step's forward pass (logits of the last position)
)> GenerateStreaming(
    Tensor inputIds,
    Tensor attentionMask,
    int[][] stopTokenSequence,
    float temperature = Defaults.Temperature,
    float topP = Defaults.TopP,
    int maxLen = Defaults.MaxLen)
{
    using var scope = NewDisposeScope();
    using var noGrad = torch.no_grad();
    var batch = inputIds.shape[0];
    var device = inputIds.device;
    var promptLength = (int)inputIds.shape[1];
    var totalLen = promptLength + maxLen;
    var prevPos = 0;
    var eosReached = torch.tensor(new bool[batch], device: device);
    torch.Tensor? logits = default;
    var cache = new DynamicKVCache();

    // Build the stop-sequence tensors once instead of on every step. Empty sequences are dropped:
    // an empty slice trivially "matches" and would end generation on the first step.
    var stopTensors = new List<(int Length, Tensor Ids)>();
    foreach (var stopSequence in stopTokenSequence)
    {
        if (stopSequence.Length == 0)
        {
            continue;
        }

        stopTensors.Add((stopSequence.Length, torch.tensor(stopSequence, device: device)));
    }

    if (promptLength == totalLen)
    {
        // maxLen == 0: run a single forward pass over the prompt; nothing is yielded.
        var input = new CausalLMModelInput(inputIds, attentionMask, pastKeyValuesLength: 0)
        {
            OverrideCache = cache,
        };
        var output = this.Model.forward(input);
        logits = output.Logits;
    }

    // "<" rather than "!=" so that a negative maxLen produces no steps instead of never terminating.
    for (var curPos = promptLength; curPos < totalLen; curPos++)
    {
        // Only the tokens appended since the previous step are fed in; the rest is in the cache.
        var input = new CausalLMModelInput(inputIds[.., prevPos..curPos], attentionMask[.., prevPos..curPos], pastKeyValuesLength: prevPos)
        {
            OverrideCache = cache,
        };
        var output = this.Model.forward(input);
        logits = output.Logits?.MoveToOtherDisposeScope(inputIds) ?? throw new InvalidOperationException("Logits is null");

        torch.Tensor nextToken;
        if (temperature > 0)
        {
            var probs = torch.softmax(logits[.., -1] / temperature, dim: -1);
            nextToken = this.SampleTopP(probs, topP);
        }
        else
        {
            nextToken = torch.argmax(logits[.., -1], dim: -1);
        }

        nextToken = nextToken.reshape(-1);
        inputIds = torch.cat([inputIds, nextToken.unsqueeze(1)], dim: -1).MoveToOtherDisposeScope(inputIds);
        attentionMask = torch.cat([attentionMask, attentionMask.new_ones(attentionMask.shape[0], 1)], dim: -1);

        var currentLength = inputIds.shape[1];
        foreach (var (length, ids) in stopTensors)
        {
            // A stop sequence longer than what has been produced so far cannot match yet;
            // comparing it against a shorter slice would either fail or broadcast incorrectly.
            if (length > currentLength)
            {
                continue;
            }

            // determine if the last n tokens are the stop sequence
            var lastN = inputIds[.., ^length..];
            var lastNMatch = lastN == ids;
            eosReached |= lastNMatch.all(dim: -1);
        }

        if (eosReached.all().item<bool>())
        {
            break;
        }

        yield return (nextToken.MoveToOuterDisposeScope(), logits[.., ^1].MoveToOuterDisposeScope());
        prevPos = curPos;
    }
}
/// <summary>
/// Runs the tensor-based <c>GenerateStreaming</c> to completion and accumulates its output.
/// </summary>
/// <param name="inputIds">Prompt token ids [batch_size, sequence_length].</param>
/// <param name="attentionMask">Attention mask [batch_size, sequence_length].</param>
/// <param name="stopTokenSequence">Token id sequences that end generation.</param>
/// <param name="temperature">Sampling temperature forwarded to <c>GenerateStreaming</c>.</param>
/// <param name="topP">Top-p threshold forwarded to <c>GenerateStreaming</c>.</param>
/// <param name="maxLen">Maximum number of generation steps.</param>
/// <returns>
/// The prompt ids with every generated token appended, and the per-step logits concatenated along
/// the last dimension. Both tensors are owned by the dispose scope of <paramref name="inputIds"/>.
/// </returns>
/// <exception cref="InvalidOperationException">No step was produced, so there are no logits to return.</exception>
public virtual (
    Tensor, // output token ids [batch_size, sequence_length]
    Tensor // per-step logits, concatenated along the last dimension
) Generate(
    Tensor inputIds, // input token ids [batch_size, sequence_length]
    Tensor attentionMask, // attention mask [batch_size, sequence_length]
    int[][] stopTokenSequence,
    float temperature = Defaults.Temperature,
    float topP = Defaults.TopP,
    int maxLen = Defaults.MaxLen)
{
    using var scope = NewDisposeScope();
    Tensor? logits = null;
    foreach (var (token, _logits) in this.GenerateStreaming(inputIds, attentionMask, stopTokenSequence, temperature, topP, maxLen))
    {
        inputIds = torch.cat([inputIds, token.unsqueeze(1)], dim: -1).MoveToOtherDisposeScope(inputIds);
        if (logits is null)
        {
            // The streamed logits live in this method's scope, which is disposed on return.
            // Hand them to the caller's scope, otherwise a single-step run returns a disposed tensor.
            logits = _logits.MoveToOtherDisposeScope(inputIds);
        }
        else
        {
            logits = torch.cat([logits, _logits], dim: -1).MoveToOtherDisposeScope(inputIds);
        }
    }

    return (inputIds, logits ?? throw new InvalidOperationException("Logits is null"));
}
/// <summary>
/// Generates from a text prompt and returns everything the string-based
/// <c>GenerateStreaming</c> yields, concatenated into one string.
/// </summary>
/// <param name="prompt">Text prompt.</param>
/// <param name="maxLen">Maximum number of generation steps.</param>
/// <param name="temperature">Sampling temperature.</param>
/// <param name="topP">Top-p threshold.</param>
/// <param name="stopSequences">Optional text sequences that end generation; may be null.</param>
/// <returns>The concatenated generated text (the prompt itself is not included).</returns>
public virtual string Generate(
    string prompt,
    int maxLen = Defaults.MaxLen,
    float temperature = Defaults.Temperature,
    float topP = Defaults.TopP,
    string[]? stopSequences = Defaults.StopSequence)
{
    // Defaults come from the shared constants so they cannot drift from the interface declaration.
    return string.Concat(this.GenerateStreaming(prompt, maxLen, temperature, topP, stopSequences));
}
/// <summary>
/// Tokenizes <paramref name="prompt"/>, runs the tensor-based <c>GenerateStreaming</c> and yields
/// the decoded text of each generated token of the first batch row.
/// </summary>
/// <param name="prompt">Text prompt.</param>
/// <param name="maxLen">Maximum number of generation steps.</param>
/// <param name="temperature">Sampling temperature.</param>
/// <param name="topP">Top-p threshold.</param>
/// <param name="stopSequences">
/// Optional text sequences that end generation. Each one is tokenized; sequences that tokenize to
/// nothing are ignored.
/// </param>
/// <returns>The decoded text of each generated token, in generation order.</returns>
/// <exception cref="InvalidOperationException">The tokenizer returned null when decoding.</exception>
public virtual IEnumerable<string> GenerateStreaming(
    string prompt,
    int maxLen = Defaults.MaxLen,
    float temperature = Defaults.Temperature,
    float topP = Defaults.TopP,
    string[]? stopSequences = Defaults.StopSequence)
{
    using var newScope = NewDisposeScope();
    var inputIds = this.Tokenizer.EncodeToIds(prompt);
    var inputTensor = torch.tensor(inputIds.ToArray(), dtype: ScalarType.Int64, device: this.Device).unsqueeze(0);
    var attentionMask = torch.ones_like(inputTensor, device: this.Device);

    // Translate every textual stop sequence into token ids: [[stopSequence1], [stopSequence2], ...].
    // Sequences that produce no tokens are dropped.
    var stopTokenIds = (stopSequences ?? Array.Empty<string>())
        .Select(x =>
        {
            var tokens = this.Tokenizer.EncodeToTokens(x, out var _, false, false);
            return tokens
                // Skip the first _ token automatically added by tokenizer
                .Where(t => !t.Offset.Equals(new Range(0, 0)))
                .Select(t => t.Id)
                .ToArray();
        })
        .Where(ids => ids.Length > 0)
        .ToArray();

    // Decodes ids with the current tokenizer; SentencePiece tokenizers also render special tokens.
    string Decode(IEnumerable<int> ids) => this.Tokenizer switch
    {
        SentencePieceTokenizer bpeTokenizer => bpeTokenizer.Decode(ids, considerSpecialTokens: true) ?? throw new InvalidOperationException("Failed to decode token ids"),
        _ => this.Tokenizer.Decode(ids) ?? throw new InvalidOperationException("Failed to decode token ids"),
    };

    // topP is forwarded explicitly; it used to be dropped, silently falling back to the default.
    foreach (var (token, _) in this.GenerateStreaming(inputTensor, attentionMask, stopTokenIds, temperature: temperature, topP: topP, maxLen: maxLen))
    {
        // data<T>() reads host memory, so bring the ids to the CPU first (no-op cost on CPU tensors).
        var tokenIds = token[0].to_type(ScalarType.Int32).cpu().data<int>().ToArray();

        // Decode the token twice in a row and once on its own, then keep what follows the first
        // copy. The result is the token's text as it appears when it is NOT at the start of a
        // sequence (presumably to keep a leading space that decoding drops at position 0).
        var duplicateTokenString = Decode(tokenIds.Concat(tokenIds));
        var tokenString = Decode(tokenIds);
        yield return duplicateTokenString.Substring(tokenString.Length);
    }
}
/// <summary>
/// Top-p (nucleus) sampling along the last dimension.
/// </summary>
/// <param name="logits">
/// Values to sample from. Despite the name, the caller in this class passes softmax probabilities.
/// </param>
/// <param name="topP">
/// Cumulative-mass threshold: an entry is zeroed when the mass of the entries ranked before it
/// already exceeds this value.
/// </param>
/// <returns>For each row, the original index of the sampled entry (one sample per row).</returns>
protected torch.Tensor SampleTopP(torch.Tensor logits, float topP)
{
    // Rank entries from most to least likely, remembering where each one came from.
    var (sortedProbs, sortedIndices) = torch.sort(logits, dim: -1, descending: true);
    var runningTotal = torch.cumsum(sortedProbs, dim: -1);

    // Mass accumulated strictly before each entry; entries past the threshold are removed.
    var massBefore = runningTotal - sortedProbs;
    var beyondNucleus = massBefore > topP;
    sortedProbs[beyondNucleus] = 0f;

    // Renormalize what is left and draw one rank per row.
    sortedProbs = sortedProbs / sortedProbs.sum(dim: -1, keepdim: true);
    var sampledRank = torch.multinomial(sortedProbs, num_samples: 1);

    // Map the sampled rank back to the original index.
    return torch.gather(sortedIndices, dim: -1, index: sampledRank);
}
/// <summary>
/// Generate the embedding (last hidden state of the last token) for the prompt. The embedding is normalized by L2 norm.
/// </summary>
/// <param name="prompt">Text to embed. It is tokenized and run through the model as a batch of one.</param>
/// <returns>The normalized last-token hidden state as a float array.</returns>
public float[] GenerateEmbeddingFromLastTokenPool(string prompt)
{
    using var scope = NewDisposeScope();
    using var noGrad = torch.no_grad();
    var inputIds = this.Tokenizer.EncodeToIds(prompt);
    var inputTensor = torch.tensor(inputIds.ToArray(), dtype: ScalarType.Int64, device: this.Device).unsqueeze(0);
    var attentionMask = torch.ones_like(inputTensor, device: this.Device);
    var input = new CausalLMModelInput(inputTensor, attentionMask, pastKeyValuesLength: 0)
    {
        // A fresh cache: this is a single, self-contained forward pass.
        OverrideCache = new DynamicKVCache(),
    };
    var output = this.Model.forward(input);

    // First (only) batch row, last token.
    var lastTokenHiddenState = output.LastHiddenState[0, ^1];

    // L2 norm
    var norm = lastTokenHiddenState.norm();
    var normalized = lastTokenHiddenState / norm;

    // data<T>() reads host memory, so copy the result to the CPU before extracting it;
    // without this the call fails when the pipeline runs on a non-CPU device.
    return normalized.to_type(ScalarType.Float32).cpu().data<float>().ToArray();
}
}
|