|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Runtime.InteropServices;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Trainers.LightGbm
{
/// <summary>
/// Wrapper of Dataset object of LightGBM.
/// </summary>
internal sealed class Dataset : IDisposable
{
// Handle of the native LightGBM dataset that every wrapper call in this class operates on.
// Reset to null by Dispose().
private WrappedLightGbmInterface.SafeDataSetHandle _handle;
// Row index at which the next PushRows call is expected to start: both PushRows overloads
// assert startRowIdx against this value and then advance it past the rows they pushed.
private int _lastPushedRowID;
/// <summary>
/// The underlying dataset handle. It is null once <see cref="Dispose"/> has been called.
/// </summary>
public WrappedLightGbmInterface.SafeDataSetHandle Handle => _handle;
/// <summary>
/// Create a <see cref="Dataset"/> for storing training and prediction data under LightGBM framework. This constructor does not
/// marshal an ML.NET data set into LightGBM format; it only creates an (unmanaged) container into which examples are pushed later by calling
/// <see cref="PushRows(float[], int, int, int)"/>. Memory is pre-allocated, so the actual size (number of examples and number of features)
/// of the data set is required. A sub-sampled version of the original data set is passed in to compute some statistics needed by the training
/// procedure. Note that we use "original" to indicate a property from the unsampled data set.
/// If attaching labels, weights or groups fails, the newly created native dataset is disposed before the exception propagates.
/// </summary>
/// <param name="sampleValuePerColumn">A 2-D array which encodes the sub-sampled data matrix. sampleValuePerColumn[i] stores
/// all the non-zero values of the i-th feature. sampleValuePerColumn[i][j] is the j-th non-zero value of i-th feature encountered when scanning
/// the values row-by-row (i.e., example-by-example) in the matrix and column-by-column (i.e., feature-by-feature) within one row. It is similar
/// to CSC format for storing sparse matrix.</param>
/// <param name="sampleIndicesPerColumn">A 2-D array which encodes sub-sampled example indexes of non-zero features stored in sampleValuePerColumn.
/// The sampleIndicesPerColumn[i][j]-th example has a non-zero i-th feature whose value is sampleValuePerColumn[i][j].</param>
/// <param name="numCol">Total number of features in the original data. The first numCol entries of
/// <paramref name="sampleValuePerColumn"/> and <paramref name="sampleIndicesPerColumn"/> are pinned and passed to LightGBM.</param>
/// <param name="sampleNonZeroCntPerColumn">sampleNonZeroCntPerColumn[i] is the size of sampleValuePerColumn[i].</param>
/// <param name="numSampleRow">The number of sampled examples in the sub-sampled data matrix.</param>
/// <param name="numTotalRow">The number of original examples added using <see cref="PushRows(float[], int, int, int)"/>.</param>
/// <param name="param">LightGBM parameter used in https://github.com/Microsoft/LightGBM/blob/c920e6345bcb41fc1ec6ac338f5437034b9f0d38/src/c_api.cpp#L421. </param>
/// <param name="labels">Labels of the original data. labels[i] is the label of the i-th original example.</param>
/// <param name="weights">Example weights of the original data. weights[i] is the weight of the i-th original example. May be null.</param>
/// <param name="groups">Group information of the original data, forwarded unchanged to LightGBM's "group" field. May be null.
/// NOTE(review): this was previously described as "groups[i] is the group ID of the i-th original example", while LightGBM's
/// documentation describes the "group" field as per-query sizes; confirm which one callers actually pass.</param>
public unsafe Dataset(double[][] sampleValuePerColumn,
    int[][] sampleIndicesPerColumn,
    int numCol,
    int[] sampleNonZeroCntPerColumn,
    int numSampleRow,
    int numTotalRow,
    string param, float[] labels, float[] weights = null, int[] groups = null)
{
    _handle = null;

    // Use GCHandle to pin the per-column arrays so their addresses stay valid while native code reads them.
    GCHandle[] gcValues = new GCHandle[numCol];
    GCHandle[] gcIndices = new GCHandle[numCol];
    try
    {
        double*[] ptrArrayValues = new double*[numCol];
        int*[] ptrArrayIndices = new int*[numCol];
        for (int i = 0; i < numCol; i++)
        {
            gcValues[i] = GCHandle.Alloc(sampleValuePerColumn[i], GCHandleType.Pinned);
            ptrArrayValues[i] = (double*)gcValues[i].AddrOfPinnedObject().ToPointer();
            gcIndices[i] = GCHandle.Alloc(sampleIndicesPerColumn[i], GCHandleType.Pinned);
            ptrArrayIndices[i] = (int*)gcIndices[i].AddrOfPinnedObject().ToPointer();
        }

        fixed (double** ptrValues = ptrArrayValues)
        fixed (int** ptrIndices = ptrArrayIndices)
        {
            // Create container. Examples will be pushed in later.
            LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.DatasetCreateFromSampledColumn(
                (IntPtr)ptrValues, (IntPtr)ptrIndices, numCol, sampleNonZeroCntPerColumn, numSampleRow, numTotalRow,
                param, out _handle));
        }
    }
    finally
    {
        // Unpin whatever was pinned, including after a failure part-way through the loop above.
        for (int i = 0; i < numCol; i++)
        {
            if (gcValues[i].IsAllocated)
            {
                gcValues[i].Free();
            }

            if (gcIndices[i].IsAllocated)
            {
                gcIndices[i].Free();
            }
        }
    }

    try
    {
        // Before adding examples (i.e., feature vectors of the original data set), the original labels, weights, and groups are added.
        SetLabel(labels);
        SetWeights(weights);
        SetGroup(groups);
        Contracts.Assert(GetNumCols() == numCol);
        Contracts.Assert(GetNumRows() == numTotalRow);
    }
    catch
    {
        // The caller never receives this instance when the constructor throws, so it cannot call
        // Dispose itself; release the native dataset here before propagating the failure.
        Dispose();
        throw;
    }
}
/// <summary>
/// Create a <see cref="Dataset"/> through LightGBM's create-by-reference call and attach labels, weights and groups to it.
/// If attaching any of them fails, the newly created native dataset is disposed before the exception propagates.
/// </summary>
/// <param name="reference">Dataset whose handle is passed to the native create-by-reference call. When null, a null handle is passed.
/// NOTE(review): presumably the new dataset reuses the reference's feature binning; confirm against the LightGBM C API documentation.</param>
/// <param name="numTotalRow">Number of rows passed to the native create-by-reference call.</param>
/// <param name="labels">Labels to attach; must not be null.</param>
/// <param name="weights">Optional example weights. May be null.</param>
/// <param name="groups">Optional group information forwarded to LightGBM's "group" field. May be null.</param>
public Dataset(Dataset reference, int numTotalRow, float[] labels, float[] weights = null, int[] groups = null)
{
    WrappedLightGbmInterface.SafeDataSetHandle refHandle = reference?.Handle;
    LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.DatasetCreateByReference(refHandle, numTotalRow, out _handle));

    try
    {
        SetLabel(labels);
        SetWeights(weights);
        SetGroup(groups);
    }
    catch
    {
        // The caller never receives this instance when the constructor throws, so it cannot call
        // Dispose itself; release the native dataset here before propagating the failure.
        Dispose();
        throw;
    }
}
/// <summary>
/// Dispose the dataset handle, if there is one, and clear the field so that later calls do nothing.
/// </summary>
public void Dispose()
{
    if (_handle != null)
    {
        _handle.Dispose();
    }

    _handle = null;
}
/// <summary>
/// Push a batch of dense examples into the LightGBM dataset.
/// </summary>
/// <param name="data">Row-major flattening of a dense (# of rows)-by-(# of columns) matrix with one example per row:
/// the element at row i, column j lives at data[j + i * (# of columns)].</param>
/// <param name="numRow">Number of rows in the batch.</param>
/// <param name="numCol">Number of columns in the batch.</param>
/// <param name="startRowIdx">Index inside the <see cref="Dataset"/> that the first row of the batch is written to. For example,
/// a value of 36 makes the first row of data the 37th row of the <see cref="Dataset"/>.</param>
public void PushRows(float[] data, int numRow, int numCol, int startRowIdx)
{
    // Batches have to arrive back to back: this one must begin where the previous one stopped.
    Contracts.Assert(startRowIdx == _lastPushedRowID);
    // The batch must have the same width as the dataset.
    Contracts.Assert(numCol == GetNumCols());
    // The batch must be non-empty and must fit in the rows the dataset holds.
    Contracts.Assert(numRow > 0);
    Contracts.Assert(startRowIdx <= GetNumRows() - numRow);

    var status = WrappedLightGbmInterface.DatasetPushRows(_handle, data, numRow, numCol, startRowIdx);
    LightGbmInterfaceUtils.Check(status);

    // Remember where the next batch is expected to start.
    _lastPushedRowID = startRowIdx + numRow;
}
/// <summary>
/// Append examples stored in compressed sparse row (CSR) form to the LightGBM dataset.
/// </summary>
/// <param name="indPtr">Row offsets of the batch. NOTE(review): by the usual CSR convention the values of row r occupy positions
/// indPtr[r] through indPtr[r + 1] - 1 of <paramref name="indices"/> and <paramref name="data"/>; confirm against the caller.</param>
/// <param name="indices">Column index of each stored value (assuming the usual CSR convention).</param>
/// <param name="data">Stored feature values (assuming the usual CSR convention).</param>
/// <param name="nIndptr">Number of offsets in <paramref name="indPtr"/>. The batch is counted as nIndptr - 1 rows.</param>
/// <param name="numElem">Element count passed to the native call; presumably the number of stored values in the batch.</param>
/// <param name="numCol"># of columns of the data matrix; must equal <see cref="GetNumCols"/>.</param>
/// <param name="startRowIdx">The actual row index of the first row pushed in; must equal the row at which the previous push ended.</param>
public void PushRows(int[] indPtr, int[] indices, float[] data, int nIndptr,
    long numElem, int numCol, int startRowIdx)
{
    // A batch described by nIndptr offsets advances the dataset by nIndptr - 1 rows.
    int numRow = nIndptr - 1;

    Contracts.Assert(startRowIdx == _lastPushedRowID);
    Contracts.Assert(numCol == GetNumCols());
    Contracts.Assert(startRowIdx < GetNumRows());
    // Same bounds the dense overload enforces: the row count must be non-negative and the whole
    // batch, not just its first row, must fit into the rows the dataset was created with.
    Contracts.Assert(numRow >= 0);
    Contracts.Assert(startRowIdx <= GetNumRows() - numRow);

    LightGbmInterfaceUtils.Check(
        WrappedLightGbmInterface.DatasetPushRowsByCsr(
            _handle, indPtr, indices, data, nIndptr, numElem, numCol, startRowIdx));

    _lastPushedRowID = startRowIdx + numRow;
}
/// <summary>
/// Ask LightGBM for the number of rows (examples) of this dataset.
/// </summary>
/// <returns>The row count reported by the native library.</returns>
public int GetNumRows()
{
    int rowCount = 0;
    var status = WrappedLightGbmInterface.DatasetGetNumData(_handle, ref rowCount);
    LightGbmInterfaceUtils.Check(status);
    return rowCount;
}
/// <summary>
/// Ask LightGBM for the number of columns (features) of this dataset.
/// </summary>
/// <returns>The feature count reported by the native library.</returns>
public int GetNumCols()
{
    int featureCount = 0;
    var status = WrappedLightGbmInterface.DatasetGetNumFeature(_handle, ref featureCount);
    LightGbmInterfaceUtils.Check(status);
    return featureCount;
}
/// <summary>
/// Store <paramref name="labels"/> in the dataset's "label" field as 32-bit floats.
/// </summary>
/// <param name="labels">One label per row; must not be null and its length must equal <see cref="GetNumRows"/>.</param>
public unsafe void SetLabel(float[] labels)
{
    Contracts.AssertValue(labels);
    Contracts.Assert(labels.Length == GetNumRows());

    int labelCount = labels.Length;
    fixed (float* labelPtr = labels)
    {
        var status = WrappedLightGbmInterface.DatasetSetField(
            _handle, "label", (IntPtr)labelPtr, labelCount, WrappedLightGbmInterface.CApiDType.Float32);
        LightGbmInterfaceUtils.Check(status);
    }
}
/// <summary>
/// Store <paramref name="weights"/> in the dataset's "weight" field as 32-bit floats.
/// Nothing is stored when the array is null or when every entry equals the first one.
/// </summary>
/// <param name="weights">One weight per row, or null for no weights. When non-null its length must equal <see cref="GetNumRows"/>.</param>
public unsafe void SetWeights(float[] weights)
{
    if (weights == null)
    {
        return;
    }

    Contracts.Assert(weights.Length == GetNumRows());

    // Walk forward while the entries still match the first one; stopping early means
    // the weights are not uniform and therefore have to be handed to LightGBM.
    int firstMismatch = 1;
    while (firstMismatch < weights.Length && weights[firstMismatch] == weights[0])
    {
        ++firstMismatch;
    }

    if (firstMismatch >= weights.Length)
    {
        // All weights are the same (or there is at most one): skip setting them.
        return;
    }

    fixed (float* weightPtr = weights)
    {
        var status = WrappedLightGbmInterface.DatasetSetField(
            _handle, "weight", (IntPtr)weightPtr, weights.Length, WrappedLightGbmInterface.CApiDType.Float32);
        LightGbmInterfaceUtils.Check(status);
    }
}
/// <summary>
/// Store <paramref name="groups"/> in the dataset's "group" field as 32-bit integers. A null array is ignored.
/// </summary>
/// <param name="groups">Group information to attach, or null for none.</param>
public unsafe void SetGroup(int[] groups)
{
    if (groups == null)
    {
        return;
    }

    fixed (int* groupPtr = groups)
    {
        var status = WrappedLightGbmInterface.DatasetSetField(
            _handle, "group", (IntPtr)groupPtr, groups.Length, WrappedLightGbmInterface.CApiDType.Int32);
        LightGbmInterfaceUtils.Check(status);
    }
}
/// <summary>
/// Store <paramref name="initScores"/> in the dataset's "init_score" field as 64-bit floats. A null array is ignored.
/// Currently unused; kept because it can serve continued training.
/// </summary>
/// <param name="initScores">Initial scores, or null for none. When non-null its length must be a multiple of <see cref="GetNumRows"/>.</param>
public unsafe void SetInitScore(double[] initScores)
{
    if (initScores == null)
    {
        return;
    }

    Contracts.Assert(initScores.Length % GetNumRows() == 0);

    fixed (double* scorePtr = initScores)
    {
        var status = WrappedLightGbmInterface.DatasetSetField(
            _handle, "init_score", (IntPtr)scorePtr, initScores.Length, WrappedLightGbmInterface.CApiDType.Float64);
        LightGbmInterfaceUtils.Check(status);
    }
}
}
}
|