File: EnsembleUtils.cs
Web Access
Project: src\src\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.csproj (Microsoft.ML.Ensemble)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Trainers.Ensemble
{
    internal static class EnsembleUtils
    {
        /// <summary>
        /// Return a dataset with non-selected features zeroed out.
        /// </summary>
        public static RoleMappedData SelectFeatures(IHost host, RoleMappedData data, BitArray features)
        {
            Contracts.AssertValue(host);
            Contracts.AssertValue(data);
            Contracts.Assert(data.Schema.Feature.HasValue);
            Contracts.AssertValue(features);
            var featCol = data.Schema.Feature.Value;
 
            var type = featCol.Type;
            var typeVectorSize = type.GetVectorSize();
            Contracts.Assert(features.Length == typeVectorSize);
            int card = Utils.GetCardinality(features);
            if (card == typeVectorSize)
                return data;
 
            // REVIEW: This doesn't preserve metadata on the features column. Should it?
            var name = featCol.Name;
            var view = LambdaColumnMapper.Create(
                host, "FeatureSelector", data.Data, name, name, type, type,
                (in VBuffer<Single> src, ref VBuffer<Single> dst) => SelectFeatures(in src, features, card, ref dst));
 
            var res = new RoleMappedData(view, data.Schema.GetColumnRoleNames());
            return res;
        }
 
        /// <summary>
        /// Fill dst with values selected from src if the indices of the src values are set in includedIndices,
        /// otherwise assign default(T). The length of dst will be equal to src.Length.
        /// </summary>
        public static void SelectFeatures<T>(in VBuffer<T> src, BitArray includedIndices, int cardinality, ref VBuffer<T> dst)
        {
            Contracts.Assert(Utils.Size(includedIndices) == src.Length);
            Contracts.Assert(cardinality == Utils.GetCardinality(includedIndices));
            Contracts.Assert(cardinality < src.Length);
 
            var srcValues = src.GetValues();
            if (src.IsDense)
            {
                if (cardinality >= src.Length / 2)
                {
                    T defaultValue = default;
                    var editor = VBufferEditor.Create(ref dst, src.Length);
                    for (int i = 0; i < srcValues.Length; i++)
                        editor.Values[i] = !includedIndices[i] ? defaultValue : srcValues[i];
                    dst = editor.Commit();
                }
                else
                {
                    var editor = VBufferEditor.Create(ref dst, src.Length, cardinality);
 
                    int count = 0;
                    for (int i = 0; i < srcValues.Length; i++)
                    {
                        if (includedIndices[i])
                        {
                            Contracts.Assert(count < cardinality);
                            editor.Values[count] = srcValues[i];
                            editor.Indices[count] = i;
                            count++;
                        }
                    }
 
                    Contracts.Assert(count == cardinality);
                    dst = editor.Commit();
                }
            }
            else
            {
                var editor = VBufferEditor.Create(ref dst, src.Length, cardinality);
 
                int count = 0;
                var srcIndices = src.GetIndices();
                for (int i = 0; i < srcValues.Length; i++)
                {
                    if (includedIndices[srcIndices[i]])
                    {
                        editor.Values[count] = srcValues[i];
                        editor.Indices[count] = srcIndices[i];
                        count++;
                    }
                }
 
                dst = editor.CommitTruncated(count);
            }
        }
    }
}