// File: Transforms\SkipTakeFilter.cs
// Web Access
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
[assembly: LoadableClass(SkipTakeFilter.SkipTakeFilterSummary, typeof(SkipTakeFilter), typeof(SkipTakeFilter.Options), typeof(SignatureDataTransform),
    SkipTakeFilter.SkipTakeFilterUserName, "SkipTakeFilter", SkipTakeFilter.SkipTakeFilterShortName)]
 
[assembly: LoadableClass(SkipTakeFilter.SkipFilterSummary, typeof(SkipTakeFilter), typeof(SkipTakeFilter.SkipOptions), typeof(SignatureDataTransform),
    SkipTakeFilter.SkipFilterUserName, "SkipFilter", SkipTakeFilter.SkipFilterShortName)]
 
[assembly: LoadableClass(SkipTakeFilter.TakeFilterSummary, typeof(SkipTakeFilter), typeof(SkipTakeFilter.TakeOptions), typeof(SignatureDataTransform),
    SkipTakeFilter.TakeFilterUserName, "TakeFilter", SkipTakeFilter.TakeFilterShortName)]
 
[assembly: LoadableClass(SkipTakeFilter.SkipTakeFilterSummary, typeof(SkipTakeFilter), null, typeof(SignatureLoadDataTransform),
    "Skip and Take Filter", SkipTakeFilter.LoaderSignature)]
 
namespace Microsoft.ML.Transforms
{
    /// <summary>
    /// Allows limiting input to a subset of rows at an optional offset.  Can be used to implement data paging.
    /// </summary>
    [BestFriend]
    internal sealed class SkipTakeFilter : FilterBase, ITransformTemplate
    {
        public const string LoaderSignature = "SkipTakeFilter";
        private const string ModelSignature = "SKIPTKFL";
        private const string RegistrationName = "SkipTakeFilter";

        public const string SkipTakeFilterSummary = "Allows limiting input to a subset of rows at an optional offset.  Can be used to implement data paging.";
        public const string TakeFilterSummary = "Allows limiting input to a subset of rows by taking N first rows.";
        public const string SkipFilterSummary = "Allows limiting input to a subset of rows by skipping a number of rows.";
        public const string SkipTakeFilterUserName = "Skip and Take Filter";
        public const string SkipTakeFilterShortName = "SkipTake";
        public const string SkipFilterUserName = "Skip Filter";
        public const string SkipFilterShortName = "Skip";
        public const string TakeFilterUserName = "Take Filter";
        public const string TakeFilterShortName = "Take";

        // Shared text for every "count must be >= 0" user-argument check in this class.
        private const string NonNegativeMessage = "should be non-negative";

        /// <summary>
        /// Options for a combined skip-and-take operation. Either value may be omitted, in which case
        /// <see cref="DefaultSkip"/> or <see cref="DefaultTake"/> is used.
        /// </summary>
        public sealed class Options : TransformInputBase
        {
            internal const string SkipHelp = "Number of items to skip";
            internal const string TakeHelp = "Number of items to take";
            internal const long DefaultSkip = 0;
            internal const long DefaultTake = long.MaxValue;

            /// <summary>Number of leading rows to skip; <c>null</c> means <see cref="DefaultSkip"/>.</summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = SkipHelp, ShortName = "s", SortOrder = 1)]
            public long? Skip;

            /// <summary>Maximum number of rows to take; <c>null</c> means <see cref="DefaultTake"/>.</summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = TakeHelp, ShortName = "t", SortOrder = 2)]
            public long? Take;
        }

        /// <summary>
        /// Options for a take-only operation.
        /// </summary>
        public sealed class TakeOptions : TransformInputBase
        {
            /// <summary>Maximum number of rows to take.</summary>
            [Argument(ArgumentType.Required, HelpText = Options.TakeHelp, ShortName = "c,n,t", SortOrder = 1)]
            public long Count = Options.DefaultTake;
        }

        /// <summary>
        /// Options for a skip-only operation.
        /// </summary>
        public sealed class SkipOptions : TransformInputBase
        {
            /// <summary>Number of leading rows to skip.</summary>
            [Argument(ArgumentType.Required, HelpText = Options.SkipHelp, ShortName = "c,n,s", SortOrder = 1)]
            public long Count = Options.DefaultSkip;
        }

        /// <summary>
        /// Builds the version descriptor used both when saving and when loading this transform.
        /// </summary>
        private static VersionInfo GetVersionInfo()
        {
            return new VersionInfo(
                modelSignature: ModelSignature,
                verWrittenCur: 0x00010001,          // initial
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(SkipTakeFilter).Assembly.FullName);
        }

        private readonly long _skip;
        private readonly long _take;

        /// <summary>
        /// Core constructor. Callers are expected to have validated <paramref name="skip"/> and
        /// <paramref name="take"/> already; here they are only asserted.
        /// </summary>
        private SkipTakeFilter(long skip, long take, IHostEnvironment env, IDataView input)
            : base(env, RegistrationName, input)
        {
            Host.Assert(skip >= 0);
            Host.Assert(take >= 0);

            _skip = skip;
            _take = take;
        }

        /// <summary>
        /// Initializes a new instance of <see cref="SkipTakeFilter"/>.
        /// </summary>
        /// <param name="env">Host Environment.</param>
        /// <param name="options">Options for the skip operation. The count must be non-negative.</param>
        /// <param name="input">Input <see cref="IDataView"/>.</param>
        internal SkipTakeFilter(IHostEnvironment env, SkipOptions options, IDataView input)
            : this(GetCheckedSkipCount(env, options), Options.DefaultTake, env, input)
        {
        }

        /// <summary>
        /// Initializes a new instance of <see cref="SkipTakeFilter"/>.
        /// </summary>
        /// <param name="env">Host Environment.</param>
        /// <param name="options">Options for the take operation. The count must be non-negative.</param>
        /// <param name="input">Input <see cref="IDataView"/>.</param>
        internal SkipTakeFilter(IHostEnvironment env, TakeOptions options, IDataView input)
            : this(Options.DefaultSkip, GetCheckedTakeCount(env, options), env, input)
        {
        }

        /// <summary>
        /// Validates <paramref name="env"/> and <paramref name="options"/> and returns the skip count.
        /// Static so that it can run inside a constructor initializer, before the instance exists.
        /// </summary>
        private static long GetCheckedSkipCount(IHostEnvironment env, SkipOptions options)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(options, nameof(options));
            env.CheckUserArg(options.Count >= 0, nameof(options.Count), NonNegativeMessage);
            return options.Count;
        }

        /// <summary>
        /// Validates <paramref name="env"/> and <paramref name="options"/> and returns the take count.
        /// Static so that it can run inside a constructor initializer, before the instance exists.
        /// </summary>
        private static long GetCheckedTakeCount(IHostEnvironment env, TakeOptions options)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(options, nameof(options));
            env.CheckUserArg(options.Count >= 0, nameof(options.Count), NonNegativeMessage);
            return options.Count;
        }

        /// <summary>
        /// Creates a filter with the same skip and take values over a different source.
        /// </summary>
        IDataTransform ITransformTemplate.ApplyToData(IHostEnvironment env, IDataView newSource)
            => new SkipTakeFilter(_skip, _take, env, newSource);

        /// <summary>
        /// Creates a skip-and-take filter. Missing values fall back to
        /// <see cref="Options.DefaultSkip"/> and <see cref="Options.DefaultTake"/>.
        /// </summary>
        /// <param name="env">Host Environment.</param>
        /// <param name="options">Skip and take values; both must be non-negative when specified.</param>
        /// <param name="input">Input <see cref="IDataView"/>.</param>
        public static SkipTakeFilter Create(IHostEnvironment env, Options options, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(options, nameof(options));
            long skip = options.Skip ?? Options.DefaultSkip;
            long take = options.Take ?? Options.DefaultTake;
            env.CheckUserArg(skip >= 0, nameof(options.Skip), NonNegativeMessage);
            env.CheckUserArg(take >= 0, nameof(options.Take), NonNegativeMessage);
            return new SkipTakeFilter(skip, take, env, input);
        }

        /// <summary>
        /// Creates a skip-only filter.
        /// </summary>
        /// <param name="env">Host Environment.</param>
        /// <param name="options">Options for the skip operation. The count must be non-negative.</param>
        /// <param name="input">Input <see cref="IDataView"/>.</param>
        public static SkipTakeFilter Create(IHostEnvironment env, SkipOptions options, IDataView input)
        {
            long skip = GetCheckedSkipCount(env, options);
            return new SkipTakeFilter(skip, Options.DefaultTake, env, input);
        }

        /// <summary>
        /// Creates a take-only filter.
        /// </summary>
        /// <param name="env">Host Environment.</param>
        /// <param name="options">Options for the take operation. The count must be non-negative.</param>
        /// <param name="input">Input <see cref="IDataView"/>.</param>
        public static SkipTakeFilter Create(IHostEnvironment env, TakeOptions options, IDataView input)
        {
            long take = GetCheckedTakeCount(env, options);
            return new SkipTakeFilter(Options.DefaultSkip, take, env, input);
        }

        /// <summary>Creates instance of class from context.</summary>
        /// <param name="env">Host Environment.</param>
        /// <param name="ctx">Model context positioned at a model written by <see cref="SaveModel"/>.</param>
        /// <param name="input">Input <see cref="IDataView"/>.</param>
        public static SkipTakeFilter Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var h = env.Register(RegistrationName);
            h.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // long: skip
            // long: take
            long skip = ctx.Reader.ReadInt64();
            h.CheckDecode(skip >= 0);
            long take = ctx.Reader.ReadInt64();
            h.CheckDecode(take >= 0);
            return h.Apply("Loading Model", ch => new SkipTakeFilter(skip, take, h, input));
        }

        ///<summary>Saves class data to context</summary>
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // long: skip
            // long: take
            Host.Assert(_skip >= 0);
            ctx.Writer.Write(_skip);
            Host.Assert(_take >= 0);
            ctx.Writer.Write(_take);
        }

        /// <summary>
        /// This filter can not shuffle
        /// </summary>
        public override bool CanShuffle { get { return false; } }

        /// <summary>
        /// Returns the computed count of rows remaining after skip and take operation.
        /// Returns null if count is unknown.
        /// </summary>
        public override long? GetRowCount()
        {
            // Taking nothing yields nothing, regardless of whether the source knows its size.
            if (_take == 0)
                return 0;
            long? count = Source.GetRowCount();
            if (!count.HasValue)
                return null;

            // Rows left once the skipped prefix is removed (clamped at zero), capped by the take limit.
            long afterSkip = count.Value - _skip;
            return Math.Min(Math.Max(0, afterSkip), _take);
        }

        /// <summary>
        /// Always answers <c>false</c>: this filter never asks for parallel cursors.
        /// </summary>
        protected override bool? ShouldUseParallelCursors(Func<int, bool> predicate)
        {
            Host.AssertValue(predicate, nameof(predicate));
            return false;
        }

        /// <summary>
        /// Opens a cursor over the source and wraps it in a <see cref="Cursor"/> that applies skip and take.
        /// <paramref name="rand"/> is not forwarded to the source.
        /// </summary>
        protected override DataViewRowCursor GetRowCursorCore(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
        {
            Host.AssertValueOrNull(rand);

            var input = Source.GetRowCursor(columnsNeeded);
            var activeColumns = Utils.BuildArray(OutputSchema.Count, columnsNeeded);
            return new Cursor(Host, input, OutputSchema, activeColumns, _skip, _take);
        }

        /// <summary>
        /// Returns a set containing exactly one cursor; <paramref name="n"/> is ignored.
        /// </summary>
        public override DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
        {
            Host.CheckValueOrNull(rand);
            return new DataViewRowCursor[] { GetRowCursorCore(columnsNeeded) };
        }

        /// <summary>
        /// Cursor that advances past the first <c>skip</c> rows of its input on the first move and
        /// then yields at most <c>take</c> rows.
        /// </summary>
        private sealed class Cursor : LinkedRowRootCursorBase
        {
            private readonly long _skip;
            private readonly long _take;
            // Number of MoveNextCore calls that got past the take-limit check; never exceeds _take.
            private long _rowsTaken;
            // Set on the first MoveNextCore call that gets past the take-limit check, i.e. once skipping has been attempted.
            private bool _started;

            /// <summary>
            /// SkipTakeFilter does not support cursor sets, so this can always be zero.
            /// </summary>
            public override long Batch => 0;

            public Cursor(IChannelProvider provider, DataViewRowCursor input, DataViewSchema schema, bool[] active, long skip, long take)
                : base(provider, input, schema, active)
            {
                Ch.Assert(skip >= 0);
                Ch.Assert(take >= 0);

                _skip = skip;
                _take = take;
            }

            /// <summary>
            /// Row ids are delegated unchanged to the input cursor.
            /// </summary>
            public override ValueGetter<DataViewRowId> GetIdGetter()
            {
                return Input.GetIdGetter();
            }

            protected override bool MoveNextCore()
            {
                // Stop once the take limit has been reached. Since 0 <= _rowsTaken <= _take, the
                // subtraction cannot overflow, and passing this check guarantees that the
                // increment below cannot overflow either.
                if (1 > _take - _rowsTaken)
                {
                    _rowsTaken = _take;
                    return false;
                }

                ++_rowsTaken;

                if (!_started)
                {
                    _started = true;

                    // Exit if 1 + _skip will overflow: the loop below could not count that far.
                    if (1 > long.MaxValue - _skip)
                    {
                        _rowsTaken = _take;
                        return false;
                    }

                    // Move forward _skip + 1 rows to get to the "first" row of the input.
                    for (long i = 0; i <= _skip; ++i)
                    {
                        // The input ran out before (or exactly at) the first row to return.
                        if (!Root.MoveNext())
                            return false;
                    }
                    return true;
                }

                return Root.MoveNext();
            }
        }
    }
}