File: SweepableEstimator\MultiModelPipeline.cs
Web Access
Project: src\src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj (Microsoft.ML.AutoML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json.Serialization;
 
namespace Microsoft.ML.AutoML
{
    [JsonConverter(typeof(MultiModelPipelineConverter))]
    internal class MultiModelPipeline
    {
        private static readonly StringEntity _nilStringEntity = new StringEntity("Nil");
        private static readonly EstimatorEntity _nilSweepableEntity = new EstimatorEntity(null);
        private readonly Dictionary<string, SweepableEstimator> _estimators;
        private readonly Entity _schema;
 
        internal MultiModelPipeline()
        {
            _estimators = new Dictionary<string, SweepableEstimator>();
            _schema = null;
        }
 
        internal MultiModelPipeline(Dictionary<string, SweepableEstimator> estimators, Entity schema)
        {
            _estimators = estimators;
            _schema = schema;
        }
 
        public Dictionary<string, SweepableEstimator> Estimators { get => _estimators; }
 
        internal Entity Schema { get => _schema; }
 
        /// <summary>
        /// Get the schema of all single model pipelines in the form of strings.
        /// the pipeline id can be used to create a single model pipeline through <see cref="MultiModelPipeline.BuildSweepableEstimatorPipeline(string)"/>.
        /// </summary>
        internal string[] PipelineIds { get => Schema.ToTerms().Select(t => t.ToString()).ToArray(); }
 
        public MultiModelPipeline Append(params SweepableEstimator[] estimators)
        {
            Entity entity = null;
            foreach (var estimator in estimators)
            {
                if (entity == null)
                {
                    entity = new EstimatorEntity(estimator);
                    continue;
                }
 
                entity += estimator;
            }
 
            return Append(entity);
        }
 
        public MultiModelPipeline AppendOrSkip(params SweepableEstimator[] estimators)
        {
            Entity entity = null;
            foreach (var estimator in estimators)
            {
                if (entity == null)
                {
                    entity = new EstimatorEntity(estimator);
                    continue;
                }
 
                entity += estimator;
            }
 
            return AppendOrSkip(entity);
        }
 
        public SweepableEstimatorPipeline BuildSweepableEstimatorPipeline(string schema)
        {
            var pipelineNodes = Entity.FromExpression(schema)
                                      .ValueEntities()
                                      .Where(e => e is StringEntity se && se.Value != "Nil")
                                      .Select((se) => _estimators[((StringEntity)se).Value]);
 
            return new SweepableEstimatorPipeline(pipelineNodes);
        }
 
        internal MultiModelPipeline Append(Entity entity)
        {
            return AppendEntity(false, entity);
        }
 
        internal MultiModelPipeline AppendOrSkip(Entity entity)
        {
            return AppendEntity(true, entity);
        }
 
        internal MultiModelPipeline AppendOrSkip(MultiModelPipeline pipeline)
        {
            return AppendPipeline(true, pipeline);
        }
 
        internal MultiModelPipeline Append(MultiModelPipeline pipeline)
        {
            return AppendPipeline(false, pipeline);
        }
 
        private MultiModelPipeline AppendPipeline(bool allowSkip, MultiModelPipeline pipeline)
        {
            var sweepableEntity = CreateSweepableEntityFromEntity(pipeline.Schema, pipeline.Estimators);
            return AppendEntity(allowSkip, sweepableEntity);
        }
 
        private MultiModelPipeline AppendEntity(bool allowSkip, Entity entity)
        {
            var estimators = _estimators.ToDictionary(x => x.Key, x => x.Value);
            var stringEntity = VisitAndReplaceSweepableEntityWithStringEntity(entity, ref estimators);
            if (allowSkip)
            {
                stringEntity += _nilStringEntity;
            }
 
            var schema = _schema;
            if (schema == null)
            {
                schema = stringEntity;
            }
            else
            {
                schema *= stringEntity;
            }
 
            return new MultiModelPipeline(estimators, schema);
        }
 
        private Entity CreateSweepableEntityFromEntity(Entity entity, Dictionary<string, SweepableEstimator> lookupTable)
        {
            if (entity is null)
            {
                return null;
            }
 
            if (entity is StringEntity stringEntity)
            {
                if (stringEntity == _nilStringEntity)
                {
                    return _nilSweepableEntity;
                }
 
                return new EstimatorEntity(lookupTable[stringEntity.Value]);
            }
            else if (entity is ConcatenateEntity concatenateEntity)
            {
                return new ConcatenateEntity()
                {
                    Left = CreateSweepableEntityFromEntity(concatenateEntity.Left, lookupTable),
                    Right = CreateSweepableEntityFromEntity(concatenateEntity.Right, lookupTable),
                };
            }
            else if (entity is OneOfEntity oneOfEntity)
            {
                return new OneOfEntity()
                {
                    Left = CreateSweepableEntityFromEntity(oneOfEntity.Left, lookupTable),
                    Right = CreateSweepableEntityFromEntity(oneOfEntity.Right, lookupTable),
                };
            }
 
            throw new ArgumentException();
        }
 
        private Entity VisitAndReplaceSweepableEntityWithStringEntity(Entity e, ref Dictionary<string, SweepableEstimator> estimators)
        {
            if (e is null)
            {
                return null;
            }
 
            if (e is EstimatorEntity sweepableEntity0)
            {
                if (sweepableEntity0 == _nilSweepableEntity)
                {
                    return _nilStringEntity;
                }
 
                var id = GetNextId(estimators);
                estimators[id] = (SweepableEstimator)sweepableEntity0.Estimator;
                return new StringEntity(id);
            }
 
            e.Left = VisitAndReplaceSweepableEntityWithStringEntity(e.Left, ref estimators);
            e.Right = VisitAndReplaceSweepableEntityWithStringEntity(e.Right, ref estimators);
 
            return e;
        }
 
        private string GetNextId(Dictionary<string, SweepableEstimator> estimators)
        {
            var count = estimators.Count();
            return "e" + count.ToString();
        }
    }
}