// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms.Text;
[assembly: LoadableClass(WordEmbeddingTransformer.Summary, typeof(IDataTransform), typeof(WordEmbeddingTransformer), typeof(WordEmbeddingTransformer.Options),
typeof(SignatureDataTransform), WordEmbeddingTransformer.UserName, "WordEmbeddingsTransform", WordEmbeddingTransformer.ShortName, DocName = "transform/WordEmbeddingsTransform.md")]
[assembly: LoadableClass(WordEmbeddingTransformer.Summary, typeof(IDataTransform), typeof(WordEmbeddingTransformer), null, typeof(SignatureLoadDataTransform),
WordEmbeddingTransformer.UserName, WordEmbeddingTransformer.LoaderSignature)]
[assembly: LoadableClass(typeof(WordEmbeddingTransformer), null, typeof(SignatureLoadModel),
WordEmbeddingTransformer.UserName, WordEmbeddingTransformer.LoaderSignature)]
[assembly: LoadableClass(typeof(IRowMapper), typeof(WordEmbeddingTransformer), null, typeof(SignatureLoadRowMapper),
WordEmbeddingTransformer.UserName, WordEmbeddingTransformer.LoaderSignature)]
namespace Microsoft.ML.Transforms.Text
{
/// <summary>
/// <see cref="ITransformer"/> resulting from fitting a <see cref="WordEmbeddingEstimator"/>.
/// </summary>
public sealed class WordEmbeddingTransformer : OneToOneTransformerBase
{
internal sealed class Column : OneToOneColumn
{
internal static Column Parse(string str)
{
Contracts.AssertNonEmpty(str);
var res = new Column();
if (res.TryParse(str))
return res;
return null;
}
internal bool TryUnparse(StringBuilder sb)
{
Contracts.AssertValue(sb);
return TryUnparseCore(sb);
}
}
internal sealed class Options : TransformInputBase
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 0)]
public Column[] Columns;
[Argument(ArgumentType.AtMostOnce, HelpText = "Pre-trained model used to create the vocabulary", ShortName = "model", SortOrder = 1)]
public WordEmbeddingEstimator.PretrainedModelKind? ModelKind = WordEmbeddingEstimator.PretrainedModelKind.SentimentSpecificWordEmbedding;
[Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "Filename for custom word embedding model",
ShortName = "dataFile", SortOrder = 2)]
public string CustomLookupTable;
}
internal const string Summary = "Word Embeddings transform is a text featurizer that converts vectors of text tokens into sentence " +
"vectors using a pre-trained model";
internal const string UserName = "Word Embeddings Transform";
internal const string ShortName = "WordEmbeddings";
internal const string LoaderSignature = "WordEmbeddingsTransform";
internal static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "W2VTRANS",
verWrittenCur: 0x00010001, //Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(WordEmbeddingTransformer).Assembly.FullName);
}
private readonly WordEmbeddingEstimator.PretrainedModelKind? _modelKind;
private readonly string _modelFileNameWithPath;
private static readonly object _embeddingsLock = new object();
private readonly bool _customLookup;
private readonly int _linesToSkip;
private readonly Model _currentVocab;
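// Process-wide cache of loaded models, keyed by model file path. Weak references let unused
// vocabularies be reclaimed by the GC while still sharing a loaded model across transformer instances.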
private static readonly Dictionary<string, WeakReference<Model>> _vocab = new Dictionary<string, WeakReference<Model>>();
private sealed class Model
{
public readonly BigArray<float> WordVectors;
private readonly NormStr.Pool _pool;
public readonly int Dimension;
public Model(int dimension)
{
Dimension = dimension;
WordVectors = new BigArray<float>();
_pool = new NormStr.Pool();
}
public void AddWordVector(IChannel ch, string word, float[] wordVector)
{
ch.Assert(wordVector.Length == Dimension);
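// Only the first occurrence of a token is kept; duplicate tokens in the model file are ignored.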
if (_pool.Get(word) == null)
{
_pool.Add(word);
WordVectors.AddRange(wordVector);
}
}
public bool GetWordVector(in ReadOnlyMemory<char> word, float[] wordVector)
{
NormStr str = _pool.Get(word);
if (str != null)
{
WordVectors.CopyTo(str.Id * Dimension, wordVector, Dimension);
return true;
}
return false;
}
public long GetNumWords()
{
return _pool.LongCount();
}
public List<string> GetWordLabels()
{
var labels = new List<string>();
foreach (var label in _pool)
{
labels.Add(new string(label.Value.ToArray()));
}
return labels;
}
}
private const string RegistrationName = "WordEmbeddings";
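// Timeout for downloading a pretrained model file, in milliseconds (10 minutes).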
private const int Timeout = 10 * 60 * 1000;
/// <summary>
/// Instantiates <see cref="WordEmbeddingTransformer"/> using the pretrained word embedding model specified by <paramref name="modelKind"/>.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
/// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
/// <param name="modelKind">The pretrained word embedding model.</param>
internal WordEmbeddingTransformer(IHostEnvironment env, string outputColumnName, string inputColumnName = null,
WordEmbeddingEstimator.PretrainedModelKind modelKind = WordEmbeddingEstimator.PretrainedModelKind.SentimentSpecificWordEmbedding)
: this(env, modelKind, new WordEmbeddingEstimator.ColumnOptions(outputColumnName, inputColumnName ?? outputColumnName))
{
}
/// <summary>
/// Instantiates <see cref="WordEmbeddingTransformer"/> using a custom word embedding model loaded from the file specified by <paramref name="customModelFile"/>.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
/// <param name="customModelFile">Filename for custom word embedding model.</param>
/// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
internal WordEmbeddingTransformer(IHostEnvironment env, string outputColumnName, string customModelFile, string inputColumnName = null)
: this(env, customModelFile, new WordEmbeddingEstimator.ColumnOptions(outputColumnName, inputColumnName ?? outputColumnName))
{
}
/// <summary>
/// Instantiates <see cref="WordEmbeddingTransformer"/> using the pretrained word embedding model specified by <paramref name="modelKind"/>.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="modelKind">The pretrained word embedding model.</param>
/// <param name="columns">Input/Output columns.</param>
internal WordEmbeddingTransformer(IHostEnvironment env, WordEmbeddingEstimator.PretrainedModelKind modelKind, params WordEmbeddingEstimator.ColumnOptions[] columns)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns))
{
env.CheckUserArg(Enum.IsDefined(typeof(WordEmbeddingEstimator.PretrainedModelKind), modelKind), nameof(modelKind));
_modelKind = modelKind;
_modelFileNameWithPath = EnsureModelFile(env, out _linesToSkip, (WordEmbeddingEstimator.PretrainedModelKind)_modelKind);
_currentVocab = GetVocabularyDictionary(env);
}
/// <summary>
/// Instantiates <see cref="WordEmbeddingTransformer"/> using a custom word embedding model loaded from the file specified by <paramref name="customModelFile"/>.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="customModelFile">Filename for custom word embedding model.</param>
/// <param name="columns">Input/Output columns.</param>
internal WordEmbeddingTransformer(IHostEnvironment env, string customModelFile, params WordEmbeddingEstimator.ColumnOptions[] columns)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns))
{
env.CheckValue(customModelFile, nameof(customModelFile));
Host.CheckNonWhiteSpace(customModelFile, nameof(customModelFile));
_modelKind = null;
_customLookup = true;
_modelFileNameWithPath = customModelFile;
_currentVocab = GetVocabularyDictionary(env);
}
private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(WordEmbeddingEstimator.ColumnOptions[] columns)
{
Contracts.CheckValue(columns, nameof(columns));
return columns.Select(x => (x.Name, x.InputColumnName)).ToArray();
}
// Factory method for SignatureDataTransform.
internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
if (options.ModelKind == null)
options.ModelKind = WordEmbeddingEstimator.PretrainedModelKind.SentimentSpecificWordEmbedding;
env.CheckUserArg(!options.ModelKind.HasValue || Enum.IsDefined(typeof(WordEmbeddingEstimator.PretrainedModelKind), options.ModelKind), nameof(options.ModelKind));
env.CheckValue(options.Columns, nameof(options.Columns));
var cols = new WordEmbeddingEstimator.ColumnOptions[options.Columns.Length];
for (int i = 0; i < cols.Length; i++)
{
var item = options.Columns[i];
cols[i] = new WordEmbeddingEstimator.ColumnOptions(
item.Name,
item.Source ?? item.Name);
}
bool customLookup = !string.IsNullOrWhiteSpace(options.CustomLookupTable);
if (customLookup)
return new WordEmbeddingTransformer(env, options.CustomLookupTable, cols).MakeDataTransform(input);
else
return new WordEmbeddingTransformer(env, options.ModelKind.Value, cols).MakeDataTransform(input);
}
private WordEmbeddingTransformer(IHost host, ModelLoadContext ctx)
: base(host, ctx)
{
Host.AssertValue(ctx);
_customLookup = ctx.Reader.ReadBoolByte();
if (_customLookup)
{
_modelFileNameWithPath = ctx.LoadNonEmptyString();
_modelKind = null;
}
else
{
_modelKind = (WordEmbeddingEstimator.PretrainedModelKind)ctx.Reader.ReadUInt32();
_modelFileNameWithPath = EnsureModelFile(Host, out _linesToSkip, (WordEmbeddingEstimator.PretrainedModelKind)_modelKind);
}
Host.CheckNonWhiteSpace(_modelFileNameWithPath, nameof(_modelFileNameWithPath));
_currentVocab = GetVocabularyDictionary(host);
}
internal static WordEmbeddingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
IHost h = env.Register(RegistrationName);
h.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
return new WordEmbeddingTransformer(h, ctx);
}
// Factory method for SignatureLoadDataTransform.
private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
=> Create(env, ctx).MakeDataTransform(input);
// Factory method for SignatureLoadRowMapper.
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
=> Create(env, ctx).MakeRowMapper(inputSchema);
private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
base.SaveColumns(ctx);
ctx.Writer.WriteBoolByte(_customLookup);
if (_customLookup)
ctx.SaveString(_modelFileNameWithPath);
else
ctx.Writer.Write((uint)_modelKind);
}
private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);
private protected override void CheckInputColumn(DataViewSchema inputSchema, int col, int srcCol)
{
var colType = inputSchema[srcCol].Type;
if (!(colType is VectorDataViewType vectorType && vectorType.ItemType is TextDataViewType))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, "String", inputSchema[srcCol].Type.ToString());
}
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
{
private readonly WordEmbeddingTransformer _parent;
private readonly VectorDataViewType _outputType;
public Mapper(WordEmbeddingTransformer parent, DataViewSchema inputSchema)
: base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
Host.CheckValue(parent, nameof(parent));
_parent = parent;
for (int i = 0; i < _parent.ColumnPairs.Length; i++)
{
_parent.CheckInputColumn(inputSchema, i, ColMapNewToOld[i]);
}
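// The output concatenates three aggregates (min, mean, max) per embedding dimension.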
_outputType = new VectorDataViewType(NumberDataViewType.Single, 3 * _parent._currentVocab.Dimension);
}
public bool CanSaveOnnx(OnnxContext ctx) => true;
protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
=> _parent.ColumnPairs.Select(x => new DataViewSchema.DetachedColumn(x.outputColumnName, _outputType, null)).ToArray();
public void SaveAsOnnx(OnnxContext ctx)
{
foreach (var (outputColumnName, inputColumnName) in _parent.ColumnPairs)
{
var srcVariableName = ctx.GetVariableName(inputColumnName);
var schema = _parent.GetOutputSchema(InputSchema);
var dstVariableName = ctx.AddIntermediateVariable(schema[outputColumnName].Type, outputColumnName);
SaveAsOnnxCore(ctx, srcVariableName, dstVariableName);
}
}
private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName)
{
// Converts one input column of the transform into one output column.
//
// Missing words are mapped to k for finding average, k + 1 for finding min, and k + 2 for finding max
// Those spots in the dictionary contain a vector of 0s, max floats, and min floats, respectively
//
// Symbols:
// j: length of latent vector of every word in the pretrained model
// n: length of input tensor (number of words)
// X: word input, a tensor with n elements.
// k: # of words in pretrained model (known when transform is created)
// S: word labels, k tensor (known when transform is created)
// D: word embeddings, (k + 3)-by-j tensor(known when transform is created). The extra three embeddings
// at the end are used for out of vocab words.
// F: location value representing missing words, equal to k
// P: output, a j * 3 tensor
//
// X [n]
// |
// nameX
// |
// LabelEncoder (classes_strings = S [k], default_int64 = k)
// |
// /----------------------- nameY -----------------------\
// / | | \
// Initialize (F)-------/----|------ nameF ------> Equal \
// / | | \
// / | nameA \
// / | / | \ \
// / '-------------| / | \ \
// / ------|-----/ | \------------------ \---------
// / / | | \ \
// | Cast (to = int64) | Cast (to = float) Not |
// | | | | | |
// | nameVMin | nameB nameQ |
// | | | | | |
// Add ------------' | Scale (scale = 2.0) Cast (to = int32) |
// | | | | |
// | | nameSMax nameZ |
// | | | | |
// | | Cast (to = int64) ReduceSum (axes = [0]) |
// namePMin | | | |
// | | nameVMax nameR |
// | | | | |
// | '-- Add --' Cast (to = float) |
// | Initialize (D [k + 3, j] | | |
// | | | | |
// | nameD namePMax nameRF |
// | | | | |
// | | | Clip (min = 1.0) |
// | | | | |
// | | | nameT |
// | |----------------|----------------------------|--------\ |
// | | | | \ |
// | /---------'-------------\ | | '----\ |
// Gather Gather | Gather
// | | | |
// nameGMin nameGMax | nameW
// | | | |
// ReduceMin (axes = [0]) ReduceMax (axes = [0]) | ReduceSum (axes = [0])
// | | | |
// | | | nameK
// | | | |
// | | '------- Div ------'
// nameJ nameL |
// | | nameE
// | | |
// '------------------- Concat (axis = 1) -------------------------------'
// |
// nameP
// |
// P [j * 3]
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
long[] axes = new long[] { 0 };
// Allocate D, a constant tensor representing word embedding weights.
var shapeD = new long[] { _parent._currentVocab.GetNumWords() + 3, _parent._currentVocab.Dimension };
var wordVectors = _parent._currentVocab.WordVectors;
var tensorD = new List<float>();
tensorD.AddRange(wordVectors);
// Out-of-vocab embedding vector for combining embeddings by mean.
tensorD.AddRange(Enumerable.Repeat(0.0f, _parent._currentVocab.Dimension));
// Out-of-vocab embedding vector for combining embeddings by element-wise min.
tensorD.AddRange(Enumerable.Repeat(float.MaxValue, _parent._currentVocab.Dimension));
// Out-of-vocab embedding vector for combining embeddings by element-wise max.
tensorD.AddRange(Enumerable.Repeat(float.MinValue, _parent._currentVocab.Dimension));
var nameD = ctx.AddInitializer(tensorD, shapeD, "WordEmbeddingWeights");
// Allocate F, a value representing an out-of-dictionary word.
var tensorF = _parent._currentVocab.GetNumWords();
var nameF = ctx.AddInitializer(tensorF, "NotFoundValueComp");
// Retrieve X, name of input.
var nameX = srcVariableName;
// Do label encoding. Out-of-vocab tokens are mapped to the size of the vocabulary. Because vocabulary
// indices are zero-based, the vocabulary size is exactly one greater than the maximum index assigned to in-vocab tokens.
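// For example, with the vocabulary { "a" -> 0, "b" -> 1 } (k = 2), the out-of-vocab token "c" encodes to 2.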
var nameY = ctx.AddIntermediateVariable(null, "LabelEncodedInput", true);
var nodeY = ctx.CreateNode("LabelEncoder", nameX, nameY, ctx.GetNodeName("LabelEncoder"));
nodeY.AddAttribute("classes_strings", _parent._currentVocab.GetWordLabels());
nodeY.AddAttribute("default_int64", _parent._currentVocab.GetNumWords());
// Do steps necessary for min and max embedding vectors.
// Map to boolean vector representing missing words. The following Equal produces 1 if a token is missing and 0 otherwise.
var nameA = ctx.AddIntermediateVariable(null, "NotFoundValuesBool", true);
var nodeA = ctx.CreateNode("Equal", new[] { nameY, nameF }, new[] { nameA }, ctx.GetNodeName("Equal"), "");
// Cast the not found vector to a vector of floats.
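// (The Cast "to" attribute is an ONNX TensorProto data type: 1 = FLOAT, 6 = INT32, 7 = INT64.)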
var nameB = ctx.AddIntermediateVariable(null, "NotFoundValuesFloat", true);
var nodeB = ctx.CreateNode("Cast", nameA, nameB, ctx.GetNodeName("Cast"), "");
nodeB.AddAttribute("to", 1);
// Scale the not found vector to get the location bias for max weights.
var nameSMax = ctx.AddIntermediateVariable(null, "ScaleMax", true);
var nodeSMax = ctx.CreateNode("Scale", nameB, nameSMax, ctx.GetNodeName("Scale"), "");
nodeSMax.AddAttribute("scale", 2.0);
// Cast scaled word label locations to ints.
var nameVMin = ctx.AddIntermediateVariable(null, "CastMin", true);
var nodeVMin = ctx.CreateNode("Cast", nameA, nameVMin, ctx.GetNodeName("Cast"), "");
nodeVMin.AddAttribute("to", 7);
var nameVMax = ctx.AddIntermediateVariable(null, "CastMax", true);
var nodeVMax = ctx.CreateNode("Cast", nameSMax, nameVMax, ctx.GetNodeName("Cast"), "");
nodeVMax.AddAttribute("to", 7);
// Add the scaled offsets back to the originals. The outputs of the following Add operators are almost identical
// to the output of the previous LabelEncoder. The only difference is that out-of-vocab tokens are mapped to k+1
// for applying ReduceMin and k+2 for applying ReduceMax, so that out-of-vocab tokens do not affect embedding results at all.
var namePMin = ctx.AddIntermediateVariable(null, "AddMin", true);
var nodePMin = ctx.CreateNode("Add", new[] { nameY, nameVMin }, new[] { namePMin }, ctx.GetNodeName("Add"), "");
var namePMax = ctx.AddIntermediateVariable(null, "AddMax", true);
var nodePMax = ctx.CreateNode("Add", new[] { nameY, nameVMax }, new[] { namePMax }, ctx.GetNodeName("Add"), "");
// Map encoded words to their embedding vectors, mapping missing ones to min/max.
var nameGMin = ctx.AddIntermediateVariable(null, "GatheredMin", true);
var nodeGMin = ctx.CreateNode("Gather", new[] { nameD, namePMin }, new[] { nameGMin }, ctx.GetNodeName("Gather"), "");
var nameGMax = ctx.AddIntermediateVariable(null, "GatheredMax", true);
var nodeGMax = ctx.CreateNode("Gather", new[] { nameD, namePMax }, new[] { nameGMax }, ctx.GetNodeName("Gather"), "");
// Merge all embedding vectors using element-wise min/max per embedding coordinate.
var nameJ = ctx.AddIntermediateVariable(null, "MinWeights", true);
var nodeJ = ctx.CreateNode("ReduceMin", nameGMin, nameJ, ctx.GetNodeName("ReduceMin"), "");
nodeJ.AddAttribute("axes", axes);
var nameL = ctx.AddIntermediateVariable(null, "MaxWeights", true);
var nodeL = ctx.CreateNode("ReduceMax", nameGMax, nameL, ctx.GetNodeName("ReduceMax"), "");
nodeL.AddAttribute("axes", axes);
// Do steps necessary for mean embedding vector.
// Map encoded words to their embedding vectors using Gather.
var nameW = ctx.AddIntermediateVariable(null, "GatheredMean", true);
var nodeW = ctx.CreateNode("Gather", new[] { nameD, nameY }, new[] { nameW }, ctx.GetNodeName("Gather"), "");
// Find the sum of the embedding vectors.
var nameK = ctx.AddIntermediateVariable(null, "SumWeights", true);
var nodeK = ctx.CreateNode("ReduceSum", nameW, nameK, ctx.GetNodeName("ReduceSum"), "");
nodeK.AddAttribute("axes", axes);
// Flip the boolean vector representing missing words to represent found words.
var nameQ = ctx.AddIntermediateVariable(null, "FoundValuesBool", true);
var nodeQ = ctx.CreateNode("Not", nameA, nameQ, ctx.GetNodeName("Not"), "");
// Cast the found words vector to ints.
var nameZ = ctx.AddIntermediateVariable(null, "FoundValuesInt", true);
var nodeZ = ctx.CreateNode("Cast", nameQ, nameZ, ctx.GetNodeName("Cast"), "");
nodeZ.AddAttribute("to", 6);
// Sum the number of total found words.
var nameR = ctx.AddIntermediateVariable(null, "NumWordsFoundInt", true);
var nodeR = ctx.CreateNode("ReduceSum", nameZ, nameR, ctx.GetNodeName("ReduceSum"), "");
nodeR.AddAttribute("axes", axes);
// Cast the found words to float.
var nameRF = ctx.AddIntermediateVariable(null, "NumWordsFoundFloat", true);
var nodeRF = ctx.CreateNode("Cast", nameR, nameRF, ctx.GetNodeName("Cast"), "");
nodeRF.AddAttribute("to", 1);
// Clip the number of found words to prevent division by 0.
var nameT = ctx.AddIntermediateVariable(null, "NumWordsClippedFloat", true);
var nodeT = ctx.CreateNode("Clip", nameRF, nameT, ctx.GetNodeName("Clip"), "");
nodeT.AddAttribute("min", 1.0f);
// Divide total sum by number of words found to get the average embedding vector of the input string vector.
var nameE = ctx.AddIntermediateVariable(null, "MeanWeights", true);
var nodeE = ctx.CreateNode("Div", new[] { nameK, nameT }, new[] { nameE }, ctx.GetNodeName("Div"), "");
// Concatenate the final embeddings produced by the three reduction strategies.
var nameP = dstVariableName;
var nodeP = ctx.CreateNode("Concat", new[] { nameJ, nameE, nameL }, new[] { nameP }, ctx.GetNodeName("Concat"), "");
nodeP.AddAttribute("axis", 1);
}
protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
{
Host.AssertValue(input);
Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);
disposer = null;
return GetGetterVec(input, iinfo);
}
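// Illustrative example: with a 2-dimensional vocabulary containing "good" -> [1, 4] and "bad" -> [3, 2],
// the input ["good", "bad", "oov"] yields the 6-slot output [1, 2, 2, 3, 3, 4]: the element-wise
// [min | mean | max] over the tokens found in the vocabulary; out-of-vocabulary tokens are ignored.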
private ValueGetter<VBuffer<float>> GetGetterVec(DataViewRow input, int iinfo)
{
Host.AssertValue(input);
Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);
var column = input.Schema[ColMapNewToOld[iinfo]];
Host.Assert(column.Type is VectorDataViewType);
Host.Assert(column.Type.GetItemType() is TextDataViewType);
var srcGetter = input.GetGetter<VBuffer<ReadOnlyMemory<char>>>(column);
var src = default(VBuffer<ReadOnlyMemory<char>>);
int dimension = _parent._currentVocab.Dimension;
float[] wordVector = new float[_parent._currentVocab.Dimension];
return
(ref VBuffer<float> dst) =>
{
int deno = 0;
srcGetter(ref src);
var editor = VBufferEditor.Create(ref dst, 3 * dimension);
int offset = 2 * dimension;
for (int i = 0; i < dimension; i++)
{
editor.Values[i] = float.MaxValue;
editor.Values[i + dimension] = 0;
editor.Values[i + offset] = float.MinValue;
}
var srcValues = src.GetValues();
for (int word = 0; word < srcValues.Length; word++)
{
if (_parent._currentVocab.GetWordVector(in srcValues[word], wordVector))
{
deno++;
for (int i = 0; i < dimension; i++)
{
float currentTerm = wordVector[i];
if (editor.Values[i] > currentTerm)
editor.Values[i] = currentTerm;
editor.Values[dimension + i] += currentTerm;
if (editor.Values[offset + i] < currentTerm)
editor.Values[offset + i] = currentTerm;
}
}
}
if (deno != 0)
for (int index = 0; index < dimension; index++)
editor.Values[index + dimension] /= deno;
dst = editor.Commit();
};
}
}
private static readonly Dictionary<WordEmbeddingEstimator.PretrainedModelKind, string> _modelsMetaData = new Dictionary<WordEmbeddingEstimator.PretrainedModelKind, string>()
{
{ WordEmbeddingEstimator.PretrainedModelKind.GloVe50D, "glove.6B.50d.txt" },
{ WordEmbeddingEstimator.PretrainedModelKind.GloVe100D, "glove.6B.100d.txt" },
{ WordEmbeddingEstimator.PretrainedModelKind.GloVe200D, "glove.6B.200d.txt" },
{ WordEmbeddingEstimator.PretrainedModelKind.GloVe300D, "glove.6B.300d.txt" },
{ WordEmbeddingEstimator.PretrainedModelKind.GloVeTwitter25D, "glove.twitter.27B.25d.txt" },
{ WordEmbeddingEstimator.PretrainedModelKind.GloVeTwitter50D, "glove.twitter.27B.50d.txt" },
{ WordEmbeddingEstimator.PretrainedModelKind.GloVeTwitter100D, "glove.twitter.27B.100d.txt" },
{ WordEmbeddingEstimator.PretrainedModelKind.GloVeTwitter200D, "glove.twitter.27B.200d.txt" },
{ WordEmbeddingEstimator.PretrainedModelKind.FastTextWikipedia300D, "wiki.en.vec" },
{ WordEmbeddingEstimator.PretrainedModelKind.SentimentSpecificWordEmbedding, "sentiment.emd" }
};
private static readonly Dictionary<WordEmbeddingEstimator.PretrainedModelKind, int> _linesToSkipInModels = new Dictionary<WordEmbeddingEstimator.PretrainedModelKind, int>()
{ { WordEmbeddingEstimator.PretrainedModelKind.FastTextWikipedia300D, 1 } };
private string EnsureModelFile(IHostEnvironment env, out int linesToSkip, WordEmbeddingEstimator.PretrainedModelKind kind)
{
linesToSkip = 0;
if (_modelsMetaData.ContainsKey(kind))
{
var modelFileName = _modelsMetaData[kind];
if (_linesToSkipInModels.ContainsKey(kind))
linesToSkip = _linesToSkipInModels[kind];
using (var ch = Host.Start("Ensuring resources"))
{
string dir = kind == WordEmbeddingEstimator.PretrainedModelKind.SentimentSpecificWordEmbedding ? Path.Combine("Text", "Sswe") : "WordVectors";
var url = $"{dir}/{modelFileName}";
var ensureModel = ResourceManagerUtils.Instance.EnsureResourceAsync(env, ch, url, modelFileName, dir, Timeout);
ensureModel.Wait();
var errorResult = ResourceManagerUtils.GetErrorMessage(out var errorMessage, ensureModel.Result);
if (errorResult != null)
{
var directory = Path.GetDirectoryName(errorResult.FileName);
var name = Path.GetFileName(errorResult.FileName);
throw ch.Except($"{errorMessage}\nModel file for Word Embedding transform could not be found! " +
$@"Please copy the model file '{name}' from '{url}' to '{directory}'.");
}
return ensureModel.Result.FileName;
}
}
throw Host.Except($"Can't map model kind = {kind} to specific file, please refer to https://aka.ms/MLNetIssue for assistance");
}
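// A custom model file is expected to contain one token per line: the token itself followed by its
// embedding values, separated by spaces or tabs. An illustrative two-word, 3-dimensional file:
//
//   great 1.0 2.0 3.0
//   terrible -1.0 -2.0 -3.0
//
// Some formats (e.g. fastText) prepend a single header line; it is skipped during parsing and then
// re-examined afterwards in case it is actually a word vector.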
private Model GetVocabularyDictionary(IHostEnvironment hostEnvironment)
{
int dimension = 0;
if (!File.Exists(_modelFileNameWithPath))
throw Host.Except("Custom word embedding model file '{0}' could not be found for Word Embeddings transform.", _modelFileNameWithPath);
if (_vocab.ContainsKey(_modelFileNameWithPath) && _vocab[_modelFileNameWithPath] != null)
{
if (_vocab[_modelFileNameWithPath].TryGetTarget(out Model model))
{
dimension = model.Dimension;
return model;
}
}
lock (_embeddingsLock)
{
if (_vocab.ContainsKey(_modelFileNameWithPath) && _vocab[_modelFileNameWithPath] != null)
{
if (_vocab[_modelFileNameWithPath].TryGetTarget(out Model modelObject))
{
dimension = modelObject.Dimension;
return modelObject;
}
}
using (var ch = Host.Start(LoaderSignature))
using (var pch = Host.StartProgressChannel("Building Vocabulary from Model File for Word Embeddings Transform"))
{
var parsedData = new ConcurrentBag<(string key, float[] values, long lineNumber)>();
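// Always skip at least the first line during parallel parsing; it is handled separately below
// because some embedding files use it as a header.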
int skippedLinesCount = Math.Max(1, _linesToSkip);
var invariantCulture = _modelKind != null;
Parallel.ForEach(File.ReadLines(_modelFileNameWithPath).Skip(skippedLinesCount), GetParallelOptions(hostEnvironment),
(line, parallelState, lineNumber) =>
{
(bool isSuccess, string key, float[] values) = LineParser.ParseKeyThenNumbers(line, invariantCulture);
if (isSuccess)
parsedData.Add((key, values, lineNumber + skippedLinesCount));
else // we use shared state here (ch) but it's not our hot path and we don't care about unhappy-path performance
ch.Warning($"Parsing error while reading model file: '{_modelFileNameWithPath}', line number {lineNumber + skippedLinesCount}");
});
Model model = null;
foreach (var parsedLine in parsedData.OrderBy(parsedLine => parsedLine.lineNumber))
{
dimension = parsedLine.values.Length;
if (model == null)
model = new Model(dimension);
if (model.Dimension != dimension)
ch.Warning($"Dimension mismatch while reading model file: '{_modelFileNameWithPath}', line number {parsedLine.lineNumber}, expected dimension = {model.Dimension}, received dimension = {dimension}");
else
model.AddWordVector(ch, parsedLine.key, parsedLine.values);
}
// Handle the first line of the embedding file separately, since some embedding files (fastText included) have a single-line header.
var firstLine = File.ReadLines(_modelFileNameWithPath).First();
string[] wordsInFirstLine = firstLine.TrimEnd().Split(' ', '\t');
dimension = wordsInFirstLine.Length - 1;
if (model == null)
model = new Model(dimension);
if (model.Dimension == dimension)
{
float temp;
string firstKey = wordsInFirstLine[0];
float[] firstValue = wordsInFirstLine.Skip(1).Select(x => float.TryParse(x, out temp) ? temp : Single.NaN).ToArray();
if (!firstValue.Contains(Single.NaN))
model.AddWordVector(ch, firstKey, firstValue);
}
_vocab[_modelFileNameWithPath] = new WeakReference<Model>(model, false);
return model;
}
}
}
private static ParallelOptions GetParallelOptions(IHostEnvironment hostEnvironment)
=> new ParallelOptions(); // Provide default options and let Parallel.ForEach pick the degree of parallelism.
}
/// <summary>
/// Text featurizer which converts vectors of text tokens into a numerical vector using a pre-trained embeddings model.
/// </summary>
/// <remarks>
/// <format type="text/markdown"><![CDATA[
///
/// ### Estimator Characteristics
/// | | |
/// | -- | -- |
/// | Does this estimator need to look at the data to train its parameters? | No |
/// | Input column data type | Vector of [Text](xref:Microsoft.ML.Data.TextDataViewType) |
/// | Output column data type | Known-sized vector of <xref:System.Single> |
/// | Exportable to ONNX | Yes |
///
/// The <xref:Microsoft.ML.Transforms.Text.WordEmbeddingTransformer> produces a new column,
/// named as specified in the output column name parameters, where each input vector is mapped to a numerical vector
/// whose size is three times the dimensionality of the embedding model used. Note that this size is independent of the size of the input vector.
///
/// For example, when using GloVe50D, which itself is 50 dimensional, the output column is a vector of size 150.
/// The first third of slots contains the minimum values across the embeddings corresponding to each string in the input vector.
/// The second third contains the average of the embeddings. The last third of slots contains the maximum values
/// of the encountered embeddings. The min and max provide a bounding hyper-rectangle for the words in the word embedding space.
/// This can help with longer phrases, where averaging many words drowns out the useful signal.
///
/// The user can specify a custom pre-trained embeddings model or one of the available pre-trained models.
/// The available options are various versions of [GloVe Models](https://nlp.stanford.edu/projects/glove/),
/// [FastText](https://en.wikipedia.org/wiki/FastText), and [SSWE](https://anthology.aclweb.org/P/P14/P14-1146.pdf).
///
/// Check the See Also section for links to usage examples.
/// ]]></format>
/// </remarks>
/// <seealso cref="TextCatalog.ApplyWordEmbedding(TransformsCatalog.TextTransforms, string, string, PretrainedModelKind)"/>
/// <seealso cref="TextCatalog.ApplyWordEmbedding(TransformsCatalog.TextTransforms, string, string, string)"/>
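/// <example>
/// <format type="text/markdown"><![CDATA[
/// A minimal usage sketch (the "Text", "Tokens", and "Features" column names below are illustrative):
/// ```csharp
/// var mlContext = new MLContext();
/// var pipeline = mlContext.Transforms.Text.TokenizeIntoWords("Tokens", "Text")
///     .Append(mlContext.Transforms.Text.ApplyWordEmbedding("Features", "Tokens",
///         WordEmbeddingEstimator.PretrainedModelKind.GloVe50D));
/// // With GloVe50D, "Features" is a 150-slot vector: the element-wise [min | mean | max]
/// // over the embeddings of the tokens in "Tokens".
/// ```
/// ]]></format>
/// </example>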
public sealed class WordEmbeddingEstimator : IEstimator<WordEmbeddingTransformer>
{
private readonly IHost _host;
private readonly ColumnOptions[] _columns;
private readonly PretrainedModelKind? _modelKind;
private readonly string _customLookupTable;
/// <summary>
/// Extracts word embeddings.
/// Outputs three times as many values as the dimension of the model specified in <paramref name="modelKind"/>.
/// The first set of values represents the minimum values encountered (for each dimension), the second set represents the average (for each dimension),
/// and the third represents the maximum values encountered (for each dimension).
/// </summary>
/// <param name="env">The local instance of <see cref="IHostEnvironment"/></param>
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
/// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
/// <param name="modelKind">The embeddings <see cref="PretrainedModelKind"/> to use. </param>
internal WordEmbeddingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null,
PretrainedModelKind modelKind = PretrainedModelKind.SentimentSpecificWordEmbedding)
: this(env, modelKind, new ColumnOptions(outputColumnName, inputColumnName ?? outputColumnName))
{
}
/// <summary>
/// Extracts word embeddings.
/// Outputs three times as many values as the dimension of the model specified in <paramref name="customModelFile"/>.
/// The first set of values represents the minimum values encountered (for each dimension), the second set represents the average (for each dimension),
/// and the third represents the maximum values encountered (for each dimension).
/// </summary>
/// <param name="env">The local instance of <see cref="IHostEnvironment"/></param>
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
/// <param name="customModelFile">The path of the pre-trained embeddings model to use. </param>
/// <param name="inputColumnName">Name of the column to transform. </param>
internal WordEmbeddingEstimator(IHostEnvironment env, string outputColumnName, string customModelFile, string inputColumnName = null)
: this(env, customModelFile, new ColumnOptions(outputColumnName, inputColumnName ?? outputColumnName))
{
}
/// <summary>
/// Extracts word embeddings.
/// Outputs three times as many values as the dimension of the model specified in <paramref name="modelKind"/>.
/// The first set of values represents the minimum values encountered (for each dimension), the second set represents the average (for each dimension),
/// and the third represents the maximum values encountered (for each dimension).
/// </summary>
/// <param name="env">The local instance of <see cref="IHostEnvironment"/></param>
/// <param name="modelKind">The embeddings <see cref="PretrainedModelKind"/> to use. </param>
/// <param name="columns">The array columns, and per-column configurations to extract embeddings from.</param>
internal WordEmbeddingEstimator(IHostEnvironment env,
PretrainedModelKind modelKind = PretrainedModelKind.SentimentSpecificWordEmbedding,
params ColumnOptions[] columns)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(WordEmbeddingEstimator));
_modelKind = modelKind;
_customLookupTable = null;
_columns = columns;
}
internal WordEmbeddingEstimator(IHostEnvironment env, string customModelFile, params ColumnOptions[] columns)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(WordEmbeddingEstimator));
_modelKind = null;
_customLookupTable = customModelFile;
_columns = columns;
}
/// <summary>
/// Specifies which word embeddings to use.
/// </summary>
public enum PretrainedModelKind
{
/// <summary>
/// GloVe 50 dimensional word embeddings.
/// </summary>
[TGUI(Label = "GloVe 50D")]
GloVe50D = 0,
/// <summary>
/// GloVe 100 dimensional word embeddings.
/// </summary>
[TGUI(Label = "GloVe 100D")]
GloVe100D = 1,
/// <summary>
/// GloVe 200 dimensional word embeddings.
/// </summary>
[TGUI(Label = "GloVe 200D")]
GloVe200D = 2,
/// <summary>
/// GloVe 300 dimensional word embeddings.
/// </summary>
[TGUI(Label = "GloVe 300D")]
GloVe300D = 3,
/// <summary>
/// GloVe 25 dimensional word embeddings trained on Twitter data.
/// </summary>
[TGUI(Label = "GloVe Twitter 25D")]
GloVeTwitter25D = 4,
/// <summary>
/// GloVe 50 dimensional word embeddings trained on Twitter data.
/// </summary>
[TGUI(Label = "GloVe Twitter 50D")]
GloVeTwitter50D = 5,
/// <summary>
/// GloVe 100 dimensional word embeddings trained on Twitter data.
/// </summary>
[TGUI(Label = "GloVe Twitter 100D")]
GloVeTwitter100D = 6,
/// <summary>
/// GloVe 200 dimensional word embeddings trained on Twitter data.
/// </summary>
[TGUI(Label = "GloVe Twitter 200D")]
GloVeTwitter200D = 7,
/// <summary>
/// FastText 300 dimensional word embeddings trained on Wikipedia.
/// </summary>
[TGUI(Label = "fastText Wikipedia 300D")]
FastTextWikipedia300D = 8,
/// <summary>
/// Word embeddings trained on sentiment analysis tasks.
/// </summary>
[TGUI(Label = "Sentiment-Specific Word Embedding")]
SentimentSpecificWordEmbedding = 9
}
/// <summary>
/// Information for each column pair.
/// </summary>
[BestFriend]
internal sealed class ColumnOptions
{
/// <summary>
/// Name of the column resulting from the transformation of <see cref="InputColumnName"/>.
/// </summary>
public readonly string Name;
/// <summary>
/// Name of column to transform.
/// </summary>
public readonly string InputColumnName;
/// <summary>
/// Describes how the transformer handles one column pair.
/// </summary>
/// <param name="name">Name of the column resulting from the transformation of <cref see="inputColumnName"/>. </param>
/// <param name="inputColumnName">Name of column to transform. If set to <see langword="null"/> <cref see="name"/> will be used as source.</param>
public ColumnOptions(string name, string inputColumnName = null)
{
Contracts.CheckNonEmpty(name, nameof(name));
Name = name;
InputColumnName = inputColumnName ?? name;
}
}
/// <summary>
/// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
/// Used for schema propagation and verification in a pipeline.
/// </summary>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colInfo in _columns)
{
if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col))
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName);
if (!(col.ItemType is TextDataViewType) || (col.Kind != SchemaShape.Column.VectorKind.VariableVector && col.Kind != SchemaShape.Column.VectorKind.Vector))
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName, new VectorDataViewType(TextDataViewType.Instance).ToString(), col.GetTypeString());
result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);
}
return new SchemaShape(result.Values);
}
/// <summary>
/// Trains and returns a <see cref="WordEmbeddingTransformer"/>.
/// </summary>
public WordEmbeddingTransformer Fit(IDataView input)
{
bool customLookup = !string.IsNullOrWhiteSpace(_customLookupTable);
WordEmbeddingTransformer transformer;
if (customLookup)
transformer = new WordEmbeddingTransformer(_host, _customLookupTable, _columns);
else
transformer = new WordEmbeddingTransformer(_host, _modelKind.Value, _columns);
// Validate input schema.
transformer.GetOutputSchema(input.Schema);
return transformer;
}
}
}