Training\TrainerUtils.cs (82)
224private static IEnumerable<DataViewSchema.Column> CreatePredicate(RoleMappedData data, CursOpt opt, IEnumerable<int> extraCols)
233if ((opt & CursOpt.Label) != 0 && data.Schema.Label.HasValue)
235if ((opt & CursOpt.Features) != 0 && data.Schema.Feature.HasValue)
237if ((opt & CursOpt.Weight) != 0 && data.Schema.Weight.HasValue)
239if ((opt & CursOpt.Group) != 0 && data.Schema.Group.HasValue)
248public static DataViewRowCursor CreateRowCursor(this RoleMappedData data, CursOpt opt, Random rand, IEnumerable<int> extraCols = null)
256CursOpt opt, int n, Random rand, IEnumerable<int> extraCols = null)
496private readonly Action<CursOpt> _signal;
506protected TrainingCursorBase(DataViewRowCursor input, Action<CursOpt> signal)
514protected static DataViewRowCursor CreateCursor(RoleMappedData data, CursOpt opt, Random rand, params int[] extraCols)
529/// return the default <see cref="CursOpt"/>, in which case the flags will not ever change.
534protected virtual CursOpt CursoringCompleteFlags()
536return SkippedRowCount == 0 ? CursOpt.AllowBadEverything : default(CursOpt);
594private readonly CursOpt _initOpts;
597private CursOpt _opts;
601protected FactoryBase(RoleMappedData data, CursOpt opt)
610private void SignalCore(CursOpt opt)
625CursOpt opt;
646CursOpt opt;
655Action<CursOpt> signal;
680protected abstract TCurs CreateCursorCore(DataViewRowCursor input, RoleMappedData data, CursOpt opt, Action<CursOpt> signal);
689private readonly Action<CursOpt> _signal;
692private CursOpt _opts;
694public AndAccumulator(Action<CursOpt> signal, int lim)
700_opts = ~default(CursOpt);
703public void Signal(CursOpt opt)
736public StandardScalarCursor(RoleMappedData data, CursOpt opt, Random rand = null, params int[] extraCols)
741protected StandardScalarCursor(DataViewRowCursor input, RoleMappedData data, CursOpt opt, Action<CursOpt> signal = null)
746if ((opt & CursOpt.Weight) != 0)
749_keepBadWeight = (opt & CursOpt.AllowBadWeights) != 0;
751if ((opt & CursOpt.Group) != 0)
754_keepBadGroup = (opt & CursOpt.AllowBadGroups) != 0;
756if ((opt & CursOpt.Id) != 0)
762protected override CursOpt CursoringCompleteFlags()
764CursOpt opt = base.CursoringCompleteFlags();
766opt |= CursOpt.AllowBadWeights;
768opt |= CursOpt.AllowBadGroups;
801public Factory(RoleMappedData data, CursOpt opt)
806protected override StandardScalarCursor CreateCursorCore(DataViewRowCursor input, RoleMappedData data, CursOpt opt, Action<CursOpt> signal)
825public FeatureFloatVectorCursor(RoleMappedData data, CursOpt opt = CursOpt.Features,
831protected FeatureFloatVectorCursor(DataViewRowCursor input, RoleMappedData data, CursOpt opt, Action<CursOpt> signal = null)
834if ((opt & CursOpt.Features) != 0 && data.Schema.Feature != null)
837_keepBad = (opt & CursOpt.AllowBadFeatures) != 0;
841protected override CursOpt CursoringCompleteFlags()
843var opt = base.CursoringCompleteFlags();
845opt |= CursOpt.AllowBadFeatures;
867public Factory(RoleMappedData data, CursOpt opt = CursOpt.Features)
872protected override FeatureFloatVectorCursor CreateCursorCore(DataViewRowCursor input, RoleMappedData data, CursOpt opt, Action<CursOpt> signal)
892public FloatLabelCursor(RoleMappedData data, CursOpt opt = CursOpt.Label,
898protected FloatLabelCursor(DataViewRowCursor input, RoleMappedData data, CursOpt opt, Action<CursOpt> signal = null)
901if ((opt & CursOpt.Label) != 0 && data.Schema.Label != null)
904_keepBad = (opt & CursOpt.AllowBadLabels) != 0;
908protected override CursOpt CursoringCompleteFlags()
910var opt = base.CursoringCompleteFlags();
912opt |= CursOpt.AllowBadLabels;
933public Factory(RoleMappedData data, CursOpt opt = CursOpt.Label)
938protected override FloatLabelCursor CreateCursorCore(DataViewRowCursor input, RoleMappedData data, CursOpt opt, Action<CursOpt> signal)
961public MulticlassLabelCursor(int classCount, RoleMappedData data, CursOpt opt = CursOpt.Label,
967protected MulticlassLabelCursor(int classCount, DataViewRowCursor input, RoleMappedData data, CursOpt opt, Action<CursOpt> signal = null)
973if ((opt & CursOpt.Label) != 0 && data.Schema.Label != null)
976_keepBad = (opt & CursOpt.AllowBadLabels) != 0;
980protected override CursOpt CursoringCompleteFlags()
982var opt = base.CursoringCompleteFlags();
984opt |= CursOpt.AllowBadLabels;
1008public Factory(int classCount, RoleMappedData data, CursOpt opt = CursOpt.Label)
1016protected override MulticlassLabelCursor CreateCursorCore(DataViewRowCursor input, RoleMappedData data, CursOpt opt, Action<CursOpt> signal)