File: Utils\ThreadTaskManager.cs
Web Access
Project: src\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Trainers.FastTree
{
    internal static class ThreadTaskManager
    {
        private static readonly object _lockObject = new object();
 
        // REVIEW: Should this bother with number of threads? What should it do?
        public static int NumThreads { get; private set; }
 
        public static void Initialize(int numThreads)
        {
            lock (_lockObject)
            {
                if (NumThreads == 0)
                {
                    Contracts.Assert(numThreads > 0);
                    Contracts.Assert(NumThreads == 0);
                    NumThreads = numThreads;
                    BlockingThreadPool.Initialize(numThreads);
                }
            }
        }
 
        private class ThreadTask : IThreadTask
        {
            private readonly int _num;
            private readonly Action[] _actions;
 
            public ThreadTask(int num, IEnumerable<Action> actions)
            {
                _actions = actions.ToArray();
                _num = num;
            }
 
            public void RunTask()
            {
                // Special case one thread. We treat all other values the same - let the Task
                // library determine the best plan.
                if (_num == 1)
                {
                    for (int i = 0; i < _actions.Length; i++)
                    {
                        // REVIEW: Should this simply invoke the action on this thread?
                        var task = Task.Run(_actions[i]);
                        task.Wait();
                    }
                }
                else
                    Parallel.Invoke(new ParallelOptions() { MaxDegreeOfParallelism = _num }, _actions);
            }
        }
 
        /// <summary>
        /// Makes a new task using the subtasks
        /// </summary>
        /// <param name="subTasks">subtasks composing the task</param>
        /// <returns>An IThreadTask to run the tasks</returns>
        public static IThreadTask MakeTask(IEnumerable<Action> subTasks)
        {
            return new ThreadTask(NumThreads, subTasks);
        }
 
        /// <summary>
        /// Makes a new task from the supplied action that takes an integer argument, from 0...max
        /// </summary>
        /// <param name="subTaskAction">Action to run</param>
        /// <param name="maxArgument">The max range of the argument</param>
        /// <returns>A task that runs the action using each value of the argument from 0...max</returns>
        public static IThreadTask MakeTask(Action<int> subTaskAction, int maxArgument)
        {
            IEnumerable<Action> subTasks =
                Enumerable.Range(0, maxArgument)
                .Select<int, Action>(arg => (() => subTaskAction(arg)));
            return MakeTask(subTasks);
        }
    }
 
    /// <summary>
    /// Interface for a decomposable task that runs on many threads
    /// </summary>
    internal interface IThreadTask
    {
        void RunTask();
    }
}