File: Dynamic\DataOperations\FilterRowsByKeyColumnFraction.cs
Web Access
Project: src\docs\samples\Microsoft.ML.Samples\Microsoft.ML.Samples.csproj (Microsoft.ML.Samples)
using System;
using System.Collections.Generic;
using Microsoft.ML;
 
namespace Samples.Dynamic
{
    /// <summary>
    /// Sample class showing how to use FilterRowsByKeyColumnFraction to keep
    /// only the rows whose key value, projected onto the [0, 1) interval,
    /// falls inside a requested [lowerBound, upperBound) range.
    /// </summary>
    public static class FilterRowsByKeyColumnFraction
    {
        /// <summary>
        /// Builds a small in-memory dataset, maps its Age column to keys, and
        /// prints the Age keys three times: unfiltered, filtered to the
        /// [0, 0.5) fraction, and filtered to the [0.3, 0.6) fraction.
        /// </summary>
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for
            // exception tracking and logging,
            // as a catalog of available operations and as the source of randomness.
            var mlContext = new MLContext();

            // Create a small dataset as an IEnumerable.
            var samples = new List<DataPoint>()
            {
                new DataPoint(){ Age = 21 },
                new DataPoint(){ Age = 40 },
                new DataPoint(){ Age = 38 },
                new DataPoint(){ Age = 22 },
                new DataPoint(){ Age = 40 },
                new DataPoint(){ Age = 40 },
                new DataPoint(){ Age = 22 },
                new DataPoint(){ Age = 21 }
            };

            // Convert training data to IDataView.
            var data = mlContext.Data.LoadFromEnumerable(samples);

            // Convert the uint Age values to keys (KeyDataViewType).
            // FilterRowsByKeyColumnFraction operates on a key column. As the
            // output below shows, the four distinct ages are mapped in order
            // of first appearance: 21 -> 1, 40 -> 2, 38 -> 3, 22 -> 4.
            var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Age");
            var transformedData = pipeline.Fit(data).Transform(data);

            // Before we apply a filter, examine all the records in the dataset.
            var enumerable = mlContext.Data
                .CreateEnumerable<DataPoint>(transformedData, reuseRowObject: true);

            PrintAges(enumerable);

            // Age
            // 
            //  1
            //  2
            //  3
            //  4
            //  2
            //  2
            //  4
            //  1

            // Now filter down to half the keys, choosing the lower half of values.
            // For the keys we have the sorted values: 1 1 2 2 2 3 4 4.
            // Projected in the [0, 1) interval as per: (key - 0.5)/(Count of Keys)
            // the values of the keys for our data would be:
            // 0.125 0.125 0.375 0.375 0.375 0.625 0.875 0.875
            // so the keys resulting from filtering in the [0, 0.5) interval are
            // the ones with normalized values 0.125 and 0.375, respectively keys
            // with values 1 and 2.
            var filteredHalfData = mlContext.Data
                .FilterRowsByKeyColumnFraction(transformedData, columnName: "Age",
                lowerBound: 0, upperBound: 0.5);

            var filteredHalfEnumerable = mlContext.Data
                .CreateEnumerable<DataPoint>(filteredHalfData,
                reuseRowObject: true);

            PrintAges(filteredHalfEnumerable);

            // Age
            // 
            //  1
            //  2
            //  2
            //  2
            //  1

            // As mentioned above, the normalized keys are:
            // 0.125 0.125 0.375 0.375 0.375 0.625 0.875 0.875
            // so the keys resulting from filtering in the [0.3, 0.6) interval are
            // the ones with normalized value 0.375, respectively key with
            // value = 2.
            var filteredMiddleData = mlContext.Data
                .FilterRowsByKeyColumnFraction(transformedData, columnName: "Age",
                lowerBound: 0.3, upperBound: 0.6);

            // Look at the data and observe that only the rows with key 2
            // remain: keys 1, 3 and 4 fall outside [0.3, 0.6) and have been
            // filtered out.
            var filteredMiddleEnumerable = mlContext.Data
                .CreateEnumerable<DataPoint>(filteredMiddleData,
                reuseRowObject: true);

            PrintAges(filteredMiddleEnumerable);

            // Age
            //
            //  2
            //  2
            //  2
        }

        /// <summary>
        /// Writes an "Age" header followed by the Age value of every row, one
        /// per line, to the console.
        /// </summary>
        /// <param name="rows">Rows to print. The sequence is enumerated once
        /// and each value is written immediately, so it is safe to pass an
        /// enumerable created with reuseRowObject: true.</param>
        private static void PrintAges(IEnumerable<DataPoint> rows)
        {
            Console.WriteLine("Age");
            foreach (var row in rows)
            {
                Console.WriteLine(row.Age);
            }
        }

        /// <summary>
        /// Row type for the sample: a single unsigned integer Age column,
        /// used both for the raw input and for reading back the key values.
        /// </summary>
        private class DataPoint
        {
            public uint Age { get; set; }
        }
    }
}