File: SweepableEstimator\SweepablePipeline.cs
Web Access
Project: src\src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj (Microsoft.ML.AutoML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json.Serialization;
using Microsoft.ML.Data;
using Microsoft.ML.SearchSpace;
using Microsoft.ML.SearchSpace.Option;
 
namespace Microsoft.ML.AutoML
{
    [JsonConverter(typeof(SweepablePipelineConverter))]
    public class SweepablePipeline : ISweepable<EstimatorChain<ITransformer>>
    {
        private readonly Entity _schema;
        private const string SchemaOption = "_SCHEMA_";
        private readonly Dictionary<string, SweepableEstimator> _estimators = new Dictionary<string, SweepableEstimator>();
        private static readonly StringEntity _nilStringEntity = new StringEntity("Nil");
        private static readonly EstimatorEntity _nilSweepableEntity = new EstimatorEntity(null);
        private string _currentSchema;
 
        public SearchSpace.SearchSpace SearchSpace
        {
            get
            {
                var searchSpace = new SearchSpace.SearchSpace();
                foreach (var kv in _estimators)
                {
                    if (kv.Value != null)
                    {
                        searchSpace.Add(kv.Key, kv.Value.SearchSpace);
                    }
                }
 
                var schemaOptions = _schema.ToTerms().Select(t => t.ToString()).ToArray();
                var choiceOption = new ChoiceOption(schemaOptions);
                searchSpace.Add(SchemaOption, choiceOption);
 
                return searchSpace;
            }
        }
 
        public Parameter CurrentParameter
        {
            get
            {
                var parameter = Parameter.CreateNestedParameter();
                var kvPairs = _estimators.Select((e, i) => new KeyValuePair<string, Parameter>(i.ToString(), e.Value.Parameter));
                foreach (var kv in kvPairs)
                {
                    if (kv.Value != null)
                    {
                        parameter[kv.Key] = kv.Value;
                    }
                }
 
                parameter[SchemaOption] = Parameter.FromString(_currentSchema);
                return parameter;
            }
        }
 
        internal SweepablePipeline()
        {
            _estimators = new Dictionary<string, SweepableEstimator>();
            _schema = null;
        }
 
        internal SweepablePipeline(Dictionary<string, SweepableEstimator> estimators, Entity schema, string currentSchema = null)
        {
            _estimators = estimators;
            _schema = schema;
            _currentSchema = currentSchema ?? schema.ToTerms().First().ToString();
        }
 
        public Dictionary<string, SweepableEstimator> Estimators { get => _estimators; }
 
 
        internal Entity Schema { get => _schema; }
 
        public EstimatorChain<ITransformer> BuildFromOption(MLContext context, Parameter parameter)
        {
            _currentSchema = parameter[SchemaOption].AsType<string>();
            var pipeline = new EstimatorChain<ITransformer>();
            var estimatorParameterPair = Entity.FromExpression(_currentSchema)
                                   .ValueEntities()
                                   .Where(e => e is StringEntity se && se.Value != "Nil")
                                   .Select((se) =>
                                   {
                                       var key = ((StringEntity)se).Value;
                                       var estimator = _estimators[key];
                                       var param = parameter[key];
                                       return (estimator, param);
                                   });
 
            foreach (var kv in estimatorParameterPair)
            {
                pipeline = pipeline.Append(kv.estimator.BuildFromOption(context, kv.param));
            }
 
            return pipeline;
        }
 
        public SweepablePipeline BuildSweepableEstimatorPipeline(string schema)
        {
            var entity = Entity.FromExpression(schema);
            var pipelineNodes = entity.ValueEntities()
                                      .Where(e => e is StringEntity se && se.Value != "Nil")
                                      .ToDictionary((se) => se.ToString(), (se) => _estimators[((StringEntity)se).Value]);
 
            return new SweepablePipeline(pipelineNodes, entity, schema);
        }
 
        public SweepablePipeline Append(params ISweepable<IEstimator<ITransformer>>[] sweepables)
        {
            Entity entity = null;
            foreach (var sweepable in sweepables)
            {
                if (sweepable is SweepableEstimator estimator)
                {
                    if (entity == null)
                    {
                        entity = new EstimatorEntity(estimator);
                        continue;
                    }
                    else
                    {
                        entity += estimator;
                    }
                }
                else if (sweepable is SweepablePipeline pipeline)
                {
                    if (entity == null)
                    {
                        entity = CreateSweepableEntityFromEntity(pipeline._schema, pipeline._estimators);
                        continue;
                    }
                    else
                    {
                        entity += CreateSweepableEntityFromEntity(pipeline._schema, pipeline._estimators);
                    }
                }
            }
 
            return AppendEntity(false, entity);
        }
 
        public string ToString(Parameter parameter)
        {
            if (parameter.TryGetValue(AutoMLExperiment.PipelineSearchspaceName, out var pipelineParameter))
            {
                var schema = pipelineParameter["_SCHEMA_"].AsType<string>();
                var estimatorStrings = Entity.FromExpression(schema)
                                   .ValueEntities()
                                   .Where(e => e is StringEntity se && se.Value != "Nil")
                                   .Select((se) =>
                                   {
                                       var key = ((StringEntity)se).Value;
                                       var estimator = _estimators[key];
                                       return estimator.EstimatorType.ToString();
                                   });
 
                return string.Join("=>", estimatorStrings);
            }
 
            return string.Empty;
        }
 
        private SweepablePipeline AppendEntity(bool allowSkip, Entity entity)
        {
            var estimators = _estimators.ToDictionary(x => x.Key, x => x.Value);
            var stringEntity = VisitAndReplaceSweepableEntityWithStringEntity(entity, ref estimators);
            if (allowSkip)
            {
                stringEntity += _nilStringEntity;
            }
 
            var schema = _schema;
            if (schema == null)
            {
                schema = stringEntity;
            }
            else
            {
                schema *= stringEntity;
            }
 
            return new SweepablePipeline(estimators, schema);
        }
 
        private Entity CreateSweepableEntityFromEntity(Entity entity, Dictionary<string, SweepableEstimator> lookupTable)
        {
            if (entity is null)
            {
                return null;
            }
 
            if (entity is StringEntity stringEntity)
            {
                if (stringEntity == _nilStringEntity)
                {
                    return _nilSweepableEntity;
                }
 
                return new EstimatorEntity(lookupTable[stringEntity.Value]);
            }
            else if (entity is ConcatenateEntity concatenateEntity)
            {
                return new ConcatenateEntity()
                {
                    Left = CreateSweepableEntityFromEntity(concatenateEntity.Left, lookupTable),
                    Right = CreateSweepableEntityFromEntity(concatenateEntity.Right, lookupTable),
                };
            }
            else if (entity is OneOfEntity oneOfEntity)
            {
                return new OneOfEntity()
                {
                    Left = CreateSweepableEntityFromEntity(oneOfEntity.Left, lookupTable),
                    Right = CreateSweepableEntityFromEntity(oneOfEntity.Right, lookupTable),
                };
            }
 
            throw new ArgumentException();
        }
 
        private Entity VisitAndReplaceSweepableEntityWithStringEntity(Entity e, ref Dictionary<string, SweepableEstimator> estimators)
        {
            if (e is null)
            {
                return null;
            }
 
            if (e is EstimatorEntity sweepableEntity0)
            {
                if (sweepableEntity0 == _nilSweepableEntity)
                {
                    return _nilStringEntity;
                }
 
                var id = GetNextId(estimators);
                estimators[id] = (SweepableEstimator)sweepableEntity0.Estimator;
                return new StringEntity(id);
            }
 
            e.Left = VisitAndReplaceSweepableEntityWithStringEntity(e.Left, ref estimators);
            e.Right = VisitAndReplaceSweepableEntityWithStringEntity(e.Right, ref estimators);
 
            return e;
        }
 
        private string GetNextId(Dictionary<string, SweepableEstimator> estimators)
        {
            var count = estimators.Count();
            return "e" + count.ToString();
        }
    }
}