|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Runtime;
[assembly: LoadableClass(SvmLightSaver.Summary, typeof(SvmLightSaver), typeof(SvmLightSaver.Arguments), typeof(SignatureDataSaver),
"SVM-Light Saver", SvmLightSaver.LoadName, "SvmLight", "Svm")]
namespace Microsoft.ML.Data
{
/// <summary>
/// The SVM-light saver is a saver class that is capable of saving the label,
/// features, group ID and weight columns of a dataset in SVM-light format. It is a bit
/// idiosyncratic in that unlike <see cref="TextSaver"/> and <see cref="BinarySaver"/>, there is no
/// attempt to save all columns, just those specific columns, with other columns being dropped on
/// the floor.
/// </summary>
/// <remarks>
/// Every number is formatted with the invariant culture, so the produced file does not
/// depend on the culture of the thread that performs the save.
/// </remarks>
[BestFriend]
internal sealed class SvmLightSaver : IDataSaver
{
    /// <summary>
    /// Options controlling which columns are written and which flavor of the format is produced.
    /// </summary>
    public sealed class Arguments
    {
        /// <summary>When set, feature indices are written starting from 0 instead of 1.</summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Write the variant of SVM-light format where feature indices start from 0, not 1", ShortName = "z")]
        public bool Zero;

        /// <summary>When set, labels are written as -1 / 1 (and 0 for NaN) rather than as raw values.</summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Format output labels for a binary classification problem (-1 for negative, 1 for positive)", ShortName = "b")]
        public bool Binary;

        /// <summary>Name of the features column. Required.</summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for features", ShortName = "feat", SortOrder = 2, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
        public string FeatureColumnName = DefaultColumnNames.Features;

        /// <summary>Name of the label column. Required.</summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 3, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
        public string LabelColumnName = DefaultColumnNames.Label;

        /// <summary>Name of the optional example weight column; null or whitespace means "no weights".</summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
        public string ExampleWeightColumnName = null;

        /// <summary>Name of the optional group id column; null or whitespace means "no groups".</summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example groupId", ShortName = "groupId", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
        public string RowGroupColumnName = null;
    }

    // Name under which this saver is registered (see the assembly-level LoadableClass attribute).
    internal const string LoadName = "SvmLightSaver";

    // User-facing description of this saver, also used by the LoadableClass registration.
    internal const string Summary =
        "Writes Label/Features/Weight/GroupId columns into a data file in SVM-light format. " +
        "Label and Features are required, but the others are optional.";

    private readonly IHost _host;
    private readonly bool _zero;
    private readonly bool _binary;
    private readonly string _featureCol;
    private readonly string _labelCol;
    private readonly string _groupCol;
    private readonly string _weightCol;

    /// <summary>
    /// Creates a saver that captures the options in <paramref name="args"/>.
    /// </summary>
    /// <param name="env">Host environment used to register this component. Must not be null.</param>
    /// <param name="args">Saver options. Must not be null.</param>
    public SvmLightSaver(IHostEnvironment env, Arguments args)
    {
        Contracts.CheckValue(env, nameof(env));
        _host = env.Register(SvmLightSaver.LoadName);
        _host.CheckValue(args, nameof(args));

        _zero = args.Zero;
        _binary = args.Binary;
        _featureCol = args.FeatureColumnName;
        _labelCol = args.LabelColumnName;
        _groupCol = args.RowGroupColumnName;
        _weightCol = args.ExampleWeightColumnName;
    }

    /// <summary>
    /// Always returns true, regardless of <paramref name="type"/>.
    /// </summary>
    /// <param name="type">The column type being queried; it is not inspected.</param>
    /// <returns><c>true</c>.</returns>
    public bool IsColumnSavable(DataViewType type)
    {
        // REVIEW: The SVM-light saver is a bit peculiar in that it does not
        // save all columns, just some columns, and the determination of whether it will
        // save a column or not is not dependent only on its type, but rather its name
        // and other factors. This will claim to save all columns, but it will just
        // ignore a bunch depending not on the type, but on the name.
        return true;
    }

    /// <summary>
    /// Writes the configured label, features and optional group / weight columns of
    /// <paramref name="data"/> to <paramref name="stream"/>, one row per line, as
    /// <c>label [qid:N] [cost:W] index:value ...</c>.
    /// </summary>
    /// <param name="stream">Destination stream. The <see cref="StreamWriter"/> created over it is
    /// disposed before returning, which also closes the stream.</param>
    /// <param name="data">The data to save. It must contain the configured label and feature columns,
    /// and the group / weight columns when those are configured.</param>
    /// <param name="cols">Indices of columns the caller asked to save. They are only used to warn
    /// about columns that this saver will ignore; may be null.</param>
    /// <remarks>
    /// An exception is raised through the host/channel when a required column is missing or when an
    /// entry of <paramref name="cols"/> is not a valid index into the schema of <paramref name="data"/>.
    /// </remarks>
    public void SaveData(Stream stream, IDataView data, params int[] cols)
    {
        _host.CheckValue(stream, nameof(stream));
        _host.CheckValue(data, nameof(data));
        _host.CheckValueOrNull(cols);

        using (var ch = _host.Start("Saving"))
        {
            var labelCol = data.Schema.GetColumnOrNull(_labelCol);
            if (!labelCol.HasValue)
                throw ch.Except($"Column {_labelCol} not found in data");

            var featureCol = data.Schema.GetColumnOrNull(_featureCol);
            if (!featureCol.HasValue)
                throw ch.Except($"Column {_featureCol} not found in data");

            // Group and weight are optional: they are looked up only when a name was configured,
            // but once configured they must exist.
            bool wantGroup = !string.IsNullOrWhiteSpace(_groupCol);
            var groupCol = wantGroup ? data.Schema.GetColumnOrNull(_groupCol) : default;
            if (wantGroup && !groupCol.HasValue)
                throw ch.Except($"Column {_groupCol} not found in data");

            bool wantWeight = !string.IsNullOrWhiteSpace(_weightCol);
            var weightCol = wantWeight ? data.Schema.GetColumnOrNull(_weightCol) : default;
            if (wantWeight && !weightCol.HasValue)
                throw ch.Except($"Column {_weightCol} not found in data");

            // The requested column list only drives warnings; it never changes what is written.
            if (cols != null)
            {
                foreach (var col in cols)
                {
                    // Reject negative indices here as well, instead of letting the schema indexer fail.
                    _host.Check(0 <= col && col < data.Schema.Count);
                    var column = data.Schema[col];
                    if (column.Name != _labelCol && column.Name != _featureCol && column.Name != _groupCol && column.Name != _weightCol)
                        ch.Warning($"Column {column.Name} will not be saved. SVM-light saver saves the label column, feature column, optional group column and optional weight column.");
                }
            }

            var columns = new List<DataViewSchema.Column>() { labelCol.Value, featureCol.Value };
            if (groupCol.HasValue)
                columns.Add(groupCol.Value);
            if (weightCol.HasValue)
                columns.Add(weightCol.Value);

            // The writer's own FormatProvider defaults to the current culture, which would emit e.g.
            // "1,5" under comma-decimal cultures. Format every number explicitly with the invariant culture.
            var invariant = System.Globalization.CultureInfo.InvariantCulture;

            // NOTE(review): disposing this StreamWriter also closes the caller's stream - confirm
            // that callers of this saver expect that.
            using (var writer = new StreamWriter(stream))
            using (var cursor = data.GetRowCursor(columns))
            {
                // Getting the getters will fail with type errors if the types are not correct,
                // so we rely on those messages.
                var labelGetter = cursor.GetGetter<float>(labelCol.Value);
                var featuresGetter = cursor.GetGetter<VBuffer<float>>(featureCol.Value);
                var groupGetter = groupCol.HasValue ? cursor.GetGetter<ulong>(groupCol.Value) : null;
                var weightGetter = weightCol.HasValue ? cursor.GetGetter<float>(weightCol.Value) : null;

                // Reused across rows so the getter can recycle the buffer.
                VBuffer<float> features = default;
                while (cursor.MoveNext())
                {
                    float lab = default;
                    labelGetter(ref lab);
                    if (_binary)
                        writer.Write(float.IsNaN(lab) ? "0" : (lab > 0 ? "1" : "-1"));
                    else
                        writer.Write(lab.ToString("R", invariant));

                    if (groupGetter != null)
                    {
                        ulong groupId = default;
                        groupGetter(ref groupId);
                        // A group id of 0 is not written; other ids are shifted down by one.
                        // Presumably 0 is the "missing" key value - confirm against the key type convention.
                        if (groupId > 0)
                        {
                            writer.Write(" qid:");
                            writer.Write((groupId - 1).ToString(invariant));
                        }
                    }

                    if (weightGetter != null)
                    {
                        float weight = default;
                        weightGetter(ref weight);
                        // A weight of exactly 1 is omitted from the output.
                        if (weight != 1)
                        {
                            writer.Write(" cost:");
                            writer.Write(weight.ToString("R", invariant));
                        }
                    }

                    featuresGetter(ref features);
                    bool any = false;
                    foreach (var pair in features.Items())
                    {
                        // Skip exact zeros. NaN compares unequal to 0 and is therefore still written.
                        if (pair.Value == 0)
                            continue;

                        int index = _zero ? pair.Key : (pair.Key + 1);
                        writer.Write(' ');
                        writer.Write(index.ToString(invariant));
                        writer.Write(':');
                        writer.Write(pair.Value.ToString(invariant));
                        any = true;
                    }

                    // If there were no non-zero items, write a dummy item. Some parsers can handle
                    // empty arrays correctly, but some assume there is at least one defined item.
                    if (!any)
                        writer.Write(_zero ? " 0:0" : " 1:0");

                    writer.WriteLine();
                }
            }
        }
    }
}
}
|