File: Utils\TorchUtils.cs
Web Access
Project: src\src\Microsoft.ML.TorchSharp\Microsoft.ML.TorchSharp.csproj (Microsoft.ML.TorchSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.ML.Runtime;
using TorchSharp;
using static TorchSharp.torch;
 
namespace Microsoft.ML.TorchSharp.Utils
{
    /// <summary>
    /// Shared helpers for TorchSharp-based components: tensor cleanup and device selection.
    /// </summary>
    internal static class TorchUtils
    {
        /// <summary>
        /// Disposes every <c>torch.Tensor</c> stored as a key or a value of <paramref name="dictionary"/>,
        /// recursing into nested dictionaries that appear as keys or values.
        /// </summary>
        /// <remarks>
        /// Nested dictionaries are recognized through the non-generic <c>System.Collections.IDictionary</c>
        /// interface, so any <c>Dictionary&lt;K, V&gt;</c> is handled. A check against
        /// <c>Dictionary&lt;dynamic, dynamic&gt;</c> would only match <c>Dictionary&lt;object, object&gt;</c>,
        /// because <c>dynamic</c> is <c>object</c> at runtime.
        /// The dictionary itself is not cleared; only the tensors it references are disposed.
        /// </remarks>
        /// <param name="dictionary">Dictionary to scan; <c>null</c> is ignored.</param>
        public static void DisposeDictionaryWithTensor<TKey, TResult>(Dictionary<TKey, TResult> dictionary)
        {
            if (dictionary == null)
                return;

            // Dictionary<TKey, TValue> implements the non-generic IDictionary.
            DisposeTensorsInDictionary(dictionary);
        }

        /// <summary>Disposes the tensors (or nested dictionaries) held by each entry, value first, then key.</summary>
        private static void DisposeTensorsInDictionary(System.Collections.IDictionary dictionary)
        {
            foreach (System.Collections.DictionaryEntry entry in dictionary)
            {
                DisposeTensorOrNestedDictionary(entry.Value);
                DisposeTensorOrNestedDictionary(entry.Key);
            }
        }

        /// <summary>Disposes <paramref name="item"/> if it is a tensor, or recurses if it is a dictionary.</summary>
        private static void DisposeTensorOrNestedDictionary(object item)
        {
            if (item is torch.Tensor tensor)
                tensor.Dispose();
            else if (item is System.Collections.IDictionary nested)
                DisposeTensorsInDictionary(nested);
        }

        /// <summary>
        /// Picks the TorchSharp device based on the host environment's GPU settings.
        /// </summary>
        /// <param name="env">Host environment; it is cast to <see cref="IHostEnvironmentInternal"/>.</param>
        /// <returns>
        /// <c>CUDA</c> when a GPU device id is set and <c>cuda.is_available()</c> is true; otherwise <c>CPU</c>.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="env"/> is <c>null</c>.</exception>
        /// <exception cref="InvalidOperationException">
        /// A GPU device id is set, CUDA is not available, and <c>FallbackToCpu</c> is <c>false</c>.
        /// </exception>
        public static torch.Device InitializeDevice(IHostEnvironment env)
        {
            if (env == null)
                throw new ArgumentNullException(nameof(env));

            var hostEnv = (IHostEnvironmentInternal)env;
            bool gpuRequested = hostEnv.GpuDeviceId != null;
            bool useGpu = gpuRequested && cuda.is_available();

            // The user asked for a GPU, none is usable, and silently running on CPU was disallowed.
            if (gpuRequested && !useGpu && hostEnv.FallbackToCpu == false)
                throw new InvalidOperationException("Fallback to CPU is false but no GPU detected");

            // NOTE(review): the numeric value of GpuDeviceId is not used to select a specific CUDA
            // device index; the shared CUDA device object is returned. Confirm this is intended.
            return useGpu ? CUDA : CPU;
        }
    }
}