File: Transforms\BootstrapSamplingTransformer.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
// Component registration for creating the transform from its Options type.
// NOTE(review): the trailing strings appear to be the names the component can be loaded by - confirm against LoadableClassAttribute.
[assembly: LoadableClass(BootstrapSamplingTransformer.Summary, typeof(BootstrapSamplingTransformer), typeof(BootstrapSamplingTransformer.Options), typeof(SignatureDataTransform),
    BootstrapSamplingTransformer.UserName, "BootstrapSampleTransform", "BootstrapSample")]
 
// Component registration for re-creating the transform from a saved model (no Options type is supplied here).
[assembly: LoadableClass(BootstrapSamplingTransformer.Summary, typeof(BootstrapSamplingTransformer), null, typeof(SignatureLoadDataTransform),
    BootstrapSamplingTransformer.UserName, BootstrapSamplingTransformer.LoaderSignature)]
 
// Declares BootstrapSample as the type holding the entry point methods for this transform.
[assembly: EntryPointModule(typeof(BootstrapSample))]
 
namespace Microsoft.ML.Transforms
{
    /// <summary>
    /// This class approximates bootstrap sampling of a dataview.
    /// </summary>
    [BestFriend]
    internal sealed class BootstrapSamplingTransformer : FilterBase
    {
        /// <summary>
        /// Default values used both by <see cref="Options"/> and by the constructor that takes individual settings.
        /// </summary>
        internal static class Defaults
        {
            // Default for Options.Complement: the out-of-bag sample is not requested.
            public const bool Complement = false;
            // Default for Options.ShuffleInput: shuffling of the source data is attempted.
            public const bool ShuffleInput = true;
            // Default for Options.PoolSize: number of output rows kept in the output shuffling pool.
            public const int PoolSize = 1000;
        }
 
        /// <summary>
        /// Arguments that configure a <see cref="BootstrapSamplingTransformer"/>.
        /// </summary>
        public sealed class Options : TransformInputBase
        {
            /// <summary>
            /// Whether to produce the out-of-bag sample, that is, the rows not selected by the transform.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Whether this is the out-of-bag sample, that is, all those rows that are not selected by the transform.",
                ShortName = "comp")]
            public bool Complement = Defaults.Complement;
 
            /// <summary>
            /// The random seed. When left null, the transformer's constructor draws a seed from the host's random generator.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "The random seed. If unspecified random state will be instead derived from the environment.")]
            public uint? Seed;
 
            /// <summary>
            /// Whether to attempt to shuffle the source data. On by default; can be turned off for efficiency.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Whether we should attempt to shuffle the source data. By default on, but can be turned off for efficiency.", ShortName = "si")]
            public bool ShuffleInput = Defaults.ShuffleInput;
 
            /// <summary>
            /// Number of output rows kept in the pool used to shuffle the output. Must not be negative.
            /// Output shuffling is independent of <see cref="ShuffleInput"/>.
            /// </summary>
            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "When shuffling the output, the number of output rows to keep in that pool. Note that shuffling of output is completely distinct from shuffling of input.", ShortName = "pool")]
            public int PoolSize = Defaults.PoolSize;
        }
 
        internal const string Summary = "Approximate bootstrap sampling.";
        internal const string UserName = "Bootstrap Sample Transform";
 
        public const string LoaderSignature = "BootstrapSampleTransform";
        /// <summary>
        /// Builds the version information that is set on the context when saving a model
        /// and checked against the context when loading one.
        /// </summary>
        private static VersionInfo GetVersionInfo() =>
            new VersionInfo(
                modelSignature: "BTSAMPXF",
                verWrittenCur: 0x00010001, // Initial
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(BootstrapSamplingTransformer).Assembly.FullName);
 
        internal const string RegistrationName = "BootstrapSample";
 
        /// <summary>
        /// Always false: this transform does not support shuffled cursors.
        /// </summary>
        public override bool CanShuffle => false;
 
        private readonly bool _complement;
        private readonly TauswortheHybrid.State _state;
        private readonly bool _shuffleInput;
        private readonly int _poolSize;
 
        /// <summary>
        /// Initializes a new instance of <see cref="BootstrapSamplingTransformer"/> from an <see cref="Options"/> object.
        /// </summary>
        /// <param name="env">Host Environment.</param>
        /// <param name="options">Sampling settings. Must be non-null, with a non-negative <see cref="Options.PoolSize"/>.</param>
        /// <param name="input">Input <see cref="IDataView"/> to sample from.</param>
        public BootstrapSamplingTransformer(IHostEnvironment env, Options options, IDataView input)
            : base(env, RegistrationName, input)
        {
            Host.CheckValue(options, nameof(options));
            Host.CheckUserArg(options.PoolSize >= 0, nameof(options.PoolSize), "Cannot be negative");

            // The host's random generator is consulted only when no explicit seed was supplied.
            uint seed = options.Seed.HasValue ? options.Seed.Value : (uint)Host.Rand.Next();

            _poolSize = options.PoolSize;
            _shuffleInput = options.ShuffleInput;
            _state = new TauswortheHybrid.State(seed);
            _complement = options.Complement;
        }
 
        /// <summary>
        /// Creates a <see cref="BootstrapSamplingTransformer"/> from individual settings; they are packed into
        /// an <see cref="Options"/> instance and forwarded to the options-based constructor.
        /// </summary>
        /// <param name="env">Host Environment.</param>
        /// <param name="input">Input <see cref="IDataView"/>, that is, the output of the previous transform or loader.</param>
        /// <param name="complement">True to produce the out-of-bag sample, that is, the rows not selected by the transform.</param>
        /// <param name="seed">The random seed. When null, the random state is derived from the environment instead.</param>
        /// <param name="shuffleInput">Whether to attempt to shuffle the source data. On by default; can be turned off for efficiency.</param>
        /// <param name="poolSize">Number of output rows kept in the pool used to shuffle the output. This is unrelated to input shuffling.</param>
        public BootstrapSamplingTransformer(IHostEnvironment env,
            IDataView input,
            bool complement = Defaults.Complement,
            uint? seed = null,
            bool shuffleInput = Defaults.ShuffleInput,
            int poolSize = Defaults.PoolSize)
            : this(
                env,
                new Options
                {
                    Complement = complement,
                    Seed = seed,
                    ShuffleInput = shuffleInput,
                    PoolSize = poolSize
                },
                input)
        {
        }
 
        /// <summary>
        /// Loading constructor: restores the sampling settings from the reader of <paramref name="ctx"/>.
        /// Fields are read in the same order in which <see cref="SaveModel"/> writes them.
        /// </summary>
        private BootstrapSamplingTransformer(IHost host, ModelLoadContext ctx, IDataView input)
            : base(host, input)
        {
            host.AssertValue(ctx);
            host.AssertValue(input);

            // *** Binary format ***
            // byte: is the complement sample, that is, an out-of-bag sample
            // uint: seed0
            // uint: seed1
            // uint: seed2
            // uint: seed3
            // byte: input source should be shuffled
            // int: size of the output pool size

            var reader = ctx.Reader;
            _complement = reader.ReadBoolByte();
            _state = TauswortheHybrid.State.Load(reader);
            _shuffleInput = reader.ReadBoolByte();

            // A negative pool size is rejected as a decode failure.
            int poolSize = reader.ReadInt32();
            Host.CheckDecode(poolSize >= 0);
            _poolSize = poolSize;
        }
 
        /// <summary>
        /// Writes the sampling settings to the writer of <paramref name="ctx"/>, after setting the version information on it.
        /// </summary>
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // byte: is the complement sample, that is, an out-of-bag sample
            // uint: seed0
            // uint: seed1
            // uint: seed2
            // uint: seed3
            // byte: input source should be shuffled
            // int: size of the output pool size

            var writer = ctx.Writer;
            writer.WriteBoolByte(_complement);
            _state.Save(writer);
            writer.WriteBoolByte(_shuffleInput);
            Host.Assert(_poolSize >= 0);
            writer.Write(_poolSize);
        }
 
        /// <summary>
        /// Loads a <see cref="BootstrapSamplingTransformer"/> from a model context and attaches it to <paramref name="input"/>.
        /// </summary>
        /// <param name="env">Host Environment; must be non-null.</param>
        /// <param name="ctx">Model context to read from; must be non-null and is checked against this transform's version information.</param>
        /// <param name="input">Input <see cref="IDataView"/>; must be non-null.</param>
        public static BootstrapSamplingTransformer Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));

            var host = env.Register(RegistrationName);
            host.CheckValue(ctx, nameof(ctx));
            host.CheckValue(input, nameof(input));
            ctx.CheckAtModel(GetVersionInfo());

            return host.Apply("Loading Model", ch => new BootstrapSamplingTransformer(host, ctx, input));
        }
 
        /// <summary>
        /// Always answers false, regardless of <paramref name="predicate"/>.
        /// </summary>
        protected override bool? ShouldUseParallelCursors(Func<int, bool> predicate) => false;
 
        /// <summary>
        /// Creates the sampling cursor over the source, optionally wrapped in an output-shuffling cursor.
        /// </summary>
        /// <param name="columnsNeeded">Columns requested from the source cursor.</param>
        /// <param name="rand">Ignored; all randomness comes from the transform's own stored state.</param>
        protected override DataViewRowCursor GetRowCursorCore(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
        {
            // We do not use the input random because this cursor does not support shuffling.
            var sampler = new TauswortheHybrid(_state);

            // The source only receives a generator (built from the sampler) when input shuffling is enabled.
            Random sourceRand = null;
            if (_shuffleInput)
                sourceRand = new TauswortheHybrid(sampler);
            var source = Source.GetRowCursor(columnsNeeded, sourceRand);

            DataViewRowCursor result = new Cursor(this, source, sampler);
            if (_poolSize <= 1)
                return result;

            // A pool larger than one row means the output is shuffled as well.
            return RowShufflingTransformer.GetShuffledCursor(Host, _poolSize, result, new TauswortheHybrid(sampler));
        }
 
        /// <summary>
        /// Returns a one-element cursor set holding the cursor from <see cref="GetRowCursorCore"/>;
        /// the requested count <paramref name="n"/> is not used.
        /// </summary>
        public override DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
            => new DataViewRowCursor[] { GetRowCursorCore(columnsNeeded, rand) };
 
        /// <summary>
        /// Cursor that returns each input row zero or more times, based on a random draw made per input row.
        /// </summary>
        private sealed class Cursor : LinkedRootCursorBase
        {
            private readonly BootstrapSamplingTransformer _parent;
            private readonly Random _rgen;

            // Number of times the current input row still has to be returned.
            private int _remaining;

            public override long Batch { get { return 0; } }

            public override DataViewSchema Schema { get { return Input.Schema; } }

            public Cursor(BootstrapSamplingTransformer parent, DataViewRowCursor input, Random rgen)
                : base(parent.Host, input)
            {
                Ch.AssertValue(rgen);
                _rgen = rgen;
                _parent = parent;
            }

            /// <summary>
            /// Returns an ID getter that combines the input row's ID with the current remaining-repetition count.
            /// </summary>
            public override ValueGetter<DataViewRowId> GetIdGetter()
            {
                var sourceIdGetter = Input.GetIdGetter();
                ValueGetter<DataViewRowId> getter =
                    (ref DataViewRowId id) =>
                    {
                        sourceIdGetter(ref id);
                        var repetition = new DataViewRowId((ulong)_remaining, 0);
                        id = id.Combine(repetition);
                    };
                return getter;
            }

            /// <summary>
            /// Returns the input cursor's getter for the given column; values pass through unchanged.
            /// </summary>
            /// <typeparam name="TValue"> is the column's content type.</typeparam>
            /// <param name="column"> is the output column whose getter should be returned.</param>
            public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
                => Input.GetGetter<TValue>(column);

            /// <summary>
            /// Reports whether the given column is active in the input cursor.
            /// </summary>
            public override bool IsColumnActive(DataViewSchema.Column column)
                => Input.IsColumnActive(column);

            /// <summary>
            /// Moves to the next output row: either another repetition of the current input row,
            /// or the first repetition of the next input row that gets a non-zero count.
            /// </summary>
            protected override bool MoveNextCore()
            {
                Ch.Assert(_remaining >= 0);

                // Keep pulling input rows until one is assigned a non-zero count or the input runs out.
                while (_remaining == 0)
                {
                    if (!Input.MoveNext())
                        break;

                    int draw = Stats.SampleFromPoisson(_rgen, 1);
                    if (_parent._complement)
                    {
                        // Complement mode: a row is returned exactly once when its draw is zero, otherwise skipped.
                        _remaining = draw == 0 ? 1 : 0;
                    }
                    else
                    {
                        _remaining = draw;
                    }
                }

                bool hasRow = _remaining > 0;
                _remaining--;
                return hasRow;
            }
        }
    }
 
    /// <summary>
    /// Entry point methods for bootstrap sampling.
    /// </summary>
    internal static class BootstrapSample
    {
        /// <summary>
        /// Entry point that applies a <see cref="BootstrapSamplingTransformer"/> to <c>input.Data</c>
        /// and returns both the resulting data view and a transform model wrapping it.
        /// </summary>
        /// <param name="env">Host Environment; must be non-null.</param>
        /// <param name="input">Sampling options, including the data to sample; must be non-null.</param>
        [TlcModule.EntryPoint(
            Name = "Transforms.ApproximateBootstrapSampler",
            Desc = BootstrapSamplingTransformer.Summary,
            UserName = BootstrapSamplingTransformer.UserName,
            ShortName = BootstrapSamplingTransformer.RegistrationName)]
        public static CommonOutputs.TransformOutput GetSample(IHostEnvironment env, BootstrapSamplingTransformer.Options input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(input, nameof(input));

            var host = EntryPointUtils.CheckArgsAndCreateHost(env, "BootstrapSample", input);
            var sampler = new BootstrapSamplingTransformer(host, input, input.Data);

            var output = new CommonOutputs.TransformOutput();
            output.Model = new TransformModelImpl(host, sampler, input.Data);
            output.OutputData = sampler;
            return output;
        }
    }
}