File: AutoMLExperiment\ITrialResultManager.cs
Project: src\src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj (Microsoft.ML.AutoML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.AutoML
{
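    /// <summary>
    /// Manages trial results for an AutoML experiment: exposes all known results,
    /// accepts new or updated results, and persists them to a backing store.
    /// </summary>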
    internal interface ITrialResultManager
    {
        IEnumerable<TrialResult> GetAllTrialResults();
 
        void AddOrUpdateTrialResult(TrialResult result);
 
        void Save();
    }
 
    /// <summary>
    /// Trial result manager that saves and loads trial results in csv format.
    /// </summary>
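    /// <example>
    /// Illustrative usage sketch (assumes <c>searchSpace</c> is the search space used by the
    /// experiment, <c>result</c> is a completed <see cref="TrialResult"/>, and the file path is arbitrary):
    /// <code>
    /// var manager = new CsvTrialResultManager("trialHistory.csv", searchSpace);
    /// manager.AddOrUpdateTrialResult(result);
    /// manager.Save(); // appends the pending result to trialHistory.csv
    /// var allResults = manager.GetAllTrialResults();
    /// </code>
    /// </example>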
    internal class CsvTrialResultManager : ITrialResultManager
    {
        private readonly string _filePath;
        private readonly IChannel _channel;
        private readonly HashSet<TrialResult> _trialResultsHistory;
        private readonly SearchSpace.SearchSpace _searchSpace;
        private readonly DataViewSchema _schema;
        private readonly HashSet<TrialResult> _newTrialResults;
        public CsvTrialResultManager(string filePath, SearchSpace.SearchSpace searchSpace, IChannel channel = null)
        {
            _filePath = filePath;
            _channel = channel;
            _searchSpace = searchSpace;
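
            // The schema below mirrors the csv layout written by Save and read by LoadFromCsvFile:
            // | id | loss | durationInMilliseconds | peakCpu | peakMemoryInMegaByte | parameter (double vector) |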
            var schemaBuilder = new DataViewSchema.Builder();
            schemaBuilder.AddColumn("id", NumberDataViewType.Int32);
            schemaBuilder.AddColumn("loss", NumberDataViewType.Single);
            schemaBuilder.AddColumn("durationInMilliseconds", NumberDataViewType.Single);
            schemaBuilder.AddColumn("peakCpu", NumberDataViewType.Single);
            schemaBuilder.AddColumn("peakMemoryInMegaByte", NumberDataViewType.Single);
            schemaBuilder.AddColumn("parameter", new VectorDataViewType(NumberDataViewType.Double));
            _schema = schemaBuilder.ToSchema();
 
            // load from csv file.
            var trialResults = LoadFromCsvFile(filePath);
            _trialResultsHistory = new HashSet<TrialResult>(trialResults, new TrialResult());
            _newTrialResults = new HashSet<TrialResult>(new TrialResult());
        }
 
        /// <summary>
        /// Add a pending trial result, replacing any equal result that is still pending.
        /// A result that has already been saved to csv can't be added or updated again.
        /// </summary>
        public void AddOrUpdateTrialResult(TrialResult result)
        {
            if (_trialResultsHistory.Contains(result))
            {
                throw new ArgumentException("can't add or update a result that has already been saved to csv");
            }

            // Remove first so an equal pending result is replaced; HashSet<T>.Add alone
            // would be a no-op when an equal element is already present.
            _newTrialResults.Remove(result);
            _newTrialResults.Add(result);
        }
 
        /// <summary>
        /// Return all known trial results: those already saved to csv plus those still pending.
        /// </summary>
        public IEnumerable<TrialResult> GetAllTrialResults()
        {
            return _trialResultsHistory.Concat(_newTrialResults);
        }
 
        /// <summary>
        /// Save pending trial results to csv. This will not overwrite any records already written to the csv file.
        /// </summary>
        public void Save()
        {
            // header (type)
            // | id (int) | loss (float) | durationInMilliseconds (float) | peakCpu (float) | peakMemoryInMegaByte (float) | parameter_i (float) |
            using (var fileStream = new FileStream(_filePath, FileMode.Append, FileAccess.Write))
            using (var writeStream = new StreamWriter(fileStream))
            {
                var sep = ",";
 
                if (_trialResultsHistory.Count == 0)
                {
                    // write header
                    var header = new string[]
                    {
                        "id",
                        "loss",
                        "durationInMilliseconds",
                        "peakCpu",
                        "peakMemoryInMegaByte"
                    }.Concat(Enumerable.Range(0, _searchSpace.FeatureSpaceDim).Select(i => $"parameter_{i}"));
                    writeStream.WriteLine(string.Join(sep, header));
                }
 
                foreach (var trialResult in _newTrialResults.OrderBy(res => res.TrialSettings.TrialId))
                {
                    var parameter = _searchSpace.MappingToFeatureSpace(trialResult.TrialSettings.Parameter);
                    var csvLine = string.Join(
                        sep,
                        new string[]
                        {
                            trialResult.TrialSettings.TrialId.ToString(CultureInfo.InvariantCulture),
                            trialResult.Loss.ToString("F3", CultureInfo.InvariantCulture),
                            trialResult.DurationInMilliseconds.ToString("F3", CultureInfo.InvariantCulture),
                            // Null PeakCpu/PeakMemoryInMegaByte become null strings, which string.Join writes as empty csv fields.
                            trialResult.PeakCpu?.ToString("F3", CultureInfo.InvariantCulture),
                            trialResult.PeakMemoryInMegaByte?.ToString("F3", CultureInfo.InvariantCulture),
                        }.Concat(parameter.Select(p => p.ToString("F3", CultureInfo.InvariantCulture))));
                    writeStream.WriteLine(csvLine);
                }
 
                writeStream.Flush();
                writeStream.Close();
            }
 
            foreach (var result in _newTrialResults)
            {
                _trialResultsHistory.Add(result);
            }
 
            _newTrialResults.Clear();
        }
 
        private IEnumerable<TrialResult> LoadFromCsvFile(string filePath)
        {
            if (!File.Exists(filePath))
            {
                return Array.Empty<TrialResult>();
            }
 
            // header (type)
            // | id (int) | loss (float) | durationInMilliseconds (float) | peakCpu (float) | peakMemoryInMegaByte (float) | parameter_i (float) |
            var context = new MLContext();
            var textLoaderColumns = new TextLoader.Column[]
            {
                new TextLoader.Column("id", DataKind.Int32, 0),
                new TextLoader.Column("loss", DataKind.Single, 1),
                new TextLoader.Column("durationInMilliseconds", DataKind.Single, 2),
                new TextLoader.Column("peakCpu", DataKind.Single, 3),
                new TextLoader.Column("peakMemoryInMegaByte", DataKind.Single, 4),
                new TextLoader.Column("parameter", DataKind.Double, 5, 5 + _searchSpace.FeatureSpaceDim - 1),
            };
            var res = new List<TrialResult>();
            var dataView = context.Data.LoadFromTextFile(filePath, textLoaderColumns, separatorChar: ',', hasHeader: true, allowQuoting: true);
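            // _schema was built to match the loader's output (same column names, types and order),
            // so its columns can be used to request the cursor and the getters below.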
            var rowCursor = dataView.GetRowCursor(_schema);
 
            var idGetter = rowCursor.GetGetter<int>(_schema["id"]);
            var lossGetter = rowCursor.GetGetter<float>(_schema["loss"]);
            var durationGetter = rowCursor.GetGetter<float>(_schema["durationInMilliseconds"]);
            var peakCpuGetter = rowCursor.GetGetter<float>(_schema["peakCpu"]);
            var peakMemoryGetter = rowCursor.GetGetter<float>(_schema["peakMemoryInMegaByte"]);
            var parameterGetter = rowCursor.GetGetter<VBuffer<double>>(_schema["parameter"]);
 
            while (rowCursor.MoveNext())
            {
                int id = 0;
                float loss = 0;
                float duration = 0;
                float peakCpu = 0;
                float peakMemory = 0;
                VBuffer<double> parameter = default;
 
                idGetter(ref id);
                lossGetter(ref loss);
                durationGetter(ref duration);
                peakCpuGetter(ref peakCpu);
                peakMemoryGetter(ref peakMemory);
                parameterGetter(ref parameter);
                var feature = parameter.DenseValues().ToArray();
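                // Map the dense feature vector back to a Parameter through the search space,
                // inverting the MappingToFeatureSpace projection used when the row was saved.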
                var trialResult = new TrialResult
                {
                    TrialSettings = new TrialSettings
                    {
                        TrialId = id,
                        Parameter = _searchSpace.SampleFromFeatureSpace(feature),
                    },
                    DurationInMilliseconds = duration,
                    Loss = loss,
                    PeakCpu = peakCpu,
                    PeakMemoryInMegaByte = peakMemory,
                };
 
                res.Add(trialResult);
            }
 
            _channel?.Trace($"loaded trial history from {filePath} successfully with {res.Count} records");
            return res;
        }
    }
}