File: Commands\Workload\Restore\WorkloadRestoreCommand.cs
Web Access
Project: ..\..\..\src\Cli\dotnet\dotnet.csproj (dotnet)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
#nullable disable
 
using System.CommandLine;
using Microsoft.Build.Execution;
using Microsoft.Build.Logging;
using Microsoft.DotNet.Cli.Commands.Restore;
using Microsoft.DotNet.Cli.Commands.Workload.Install;
using Microsoft.DotNet.Cli.Commands.Workload.Update;
using Microsoft.DotNet.Cli.Extensions;
using Microsoft.DotNet.Cli.Utils;
using Microsoft.NET.Sdk.WorkloadManifestReader;
 
namespace Microsoft.DotNet.Cli.Commands.Workload.Restore;
 
/// <summary>
/// Implements the workload <c>restore</c> command: discovers the projects named on the command line
/// (or found in the current directory), updates workload manifests, asks each project which workloads
/// it requires, and installs them.
/// </summary>
internal class WorkloadRestoreCommand(
    ParseResult result,
    IReporter reporter = null) : WorkloadCommandBase(result, reporter: reporter)
{
    private readonly ParseResult _result = result;
    private readonly IEnumerable<string> _slnOrProjectArgument =
            result.GetValue(WorkloadRestoreCommandParser.SlnOrProjectArgument);

    /// <summary>
    /// Runs the restore: project discovery, manifest update, required-workload evaluation and install,
    /// all inside a workload-history recording.
    /// </summary>
    /// <returns>
    /// Always 0. NOTE(review): the exit codes returned by the nested update/install commands are
    /// discarded; confirm that failures are expected to surface as exceptions instead.
    /// </returns>
    public override int Execute()
    {
        var workloadResolverFactory = new WorkloadResolverFactory();
        var creationResult = workloadResolverFactory.Create();
        var workloadInstaller = WorkloadInstallerFactory.GetWorkloadInstaller(NullReporter.Instance, new SdkFeatureBand(creationResult.SdkVersion),
                                    creationResult.WorkloadResolver, Verbosity, creationResult.UserProfileDir, VerifySignatures, PackageDownloader,
                                    creationResult.DotnetPath, TempDirectoryPath, null, RestoreActionConfiguration, elevationRequired: true);
        var recorder = new WorkloadHistoryRecorder(
                           creationResult.WorkloadResolver,
                           workloadInstaller,
                           () => workloadResolverFactory.CreateForWorkloadSet(
                               creationResult.DotnetPath,
                               creationResult.SdkVersion.ToString(),
                               creationResult.UserProfileDir,
                               null));
        recorder.HistoryRecord.CommandName = "restore";

        try
        {
            recorder.Run(() =>
            {
                // First discover projects. This may return an error if no projects are found, and we shouldn't delay until after Update if that's the case.
                var allProjects = DiscoverAllProjects(Directory.GetCurrentDirectory(), _slnOrProjectArgument).Distinct();

                // Then update manifests and install a workload set as necessary
                new WorkloadUpdateCommand(_result, recorder: recorder, isRestoring: true).Execute();

                List<WorkloadId> allWorkloadId = RunTargetToGetWorkloadIds(allProjects);
                Reporter.WriteLine(string.Format(CliCommandStrings.InstallingWorkloads, string.Join(" ", allWorkloadId)));

                new WorkloadInstallCommand(_result,
                    workloadIds: allWorkloadId.Select(a => a.ToString()).ToList().AsReadOnly(),
                    skipWorkloadManifestUpdate: true)
                {
                    IsRunningRestore = true
                }.Execute();
            });
        }
        finally
        {
            // Shut the installer down on every path, not only on success, so an exception thrown by
            // discovery, update, or install does not skip this cleanup.
            workloadInstaller.Shutdown();
        }

        return 0;
    }

    /// <summary>Name of the MSBuild target whose output items are the workload IDs a project requires.</summary>
    private const string GetRequiredWorkloadsTargetName = "_GetRequiredWorkloads";

    /// <summary>
    /// Evaluates each project and, when it defines the <c>_GetRequiredWorkloads</c> target, builds that
    /// target and collects the item specs it outputs as workload IDs. Projects without the target are skipped.
    /// </summary>
    /// <param name="allProjects">Project file paths to evaluate.</param>
    /// <returns>The distinct workload IDs reported by all projects.</returns>
    /// <exception cref="GracefulException">The target failed to build for one of the projects.</exception>
    private List<WorkloadId> RunTargetToGetWorkloadIds(IEnumerable<string> allProjects)
    {
        // NOTE(review): presumably set so the target can run without resolved package assets
        // (i.e. before a package restore has happened) — confirm.
        var globalProperties = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase)
        {
            {"SkipResolvePackageAssets", "true"}
        };

        var allWorkloadId = new List<WorkloadId>();
        foreach (string projectFile in allProjects)
        {
            var project = new ProjectInstance(projectFile, globalProperties, null);
            if (!project.Targets.ContainsKey(GetRequiredWorkloadsTargetName))
            {
                continue;
            }

            bool buildResult = project.Build([GetRequiredWorkloadsTargetName],
                loggers: [
                    new ConsoleLogger(Verbosity.ToLoggerVerbosity())
                ],
                remoteLoggers: [],
                targetOutputs: out var targetOutputs);

            if (!buildResult)
            {
                throw new GracefulException(
                    string.Format(
                        CliCommandStrings.FailedToRunTarget,
                        projectFile),
                    isUserError: false);
            }

            var targetResult = targetOutputs[GetRequiredWorkloadsTargetName];
            allWorkloadId.AddRange(targetResult.Items.Select(item => new WorkloadId(item.ItemSpec)));
        }

        allWorkloadId = [.. allWorkloadId.Distinct()];
        return allWorkloadId;
    }

    /// <summary>
    /// Resolves the set of project files to restore.
    /// </summary>
    /// <param name="currentDirectory">
    /// Directory searched for solution (<c>.sln</c>/<c>.slnx</c>) and project (<c>*.*proj</c>) files when
    /// no arguments are given; also used in the "not found" error message.
    /// </param>
    /// <param name="slnOrProjectArgument">
    /// Explicit solution/project paths. Entries that are neither a solution nor a project file are ignored.
    /// </param>
    /// <returns>
    /// The project files found directly plus every project listed in the discovered solutions. The list may
    /// contain duplicates.
    /// </returns>
    /// <exception cref="GracefulException">No project files could be found.</exception>
    internal static List<string> DiscoverAllProjects(string currentDirectory,
        IEnumerable<string> slnOrProjectArgument = null)
    {
        // Materialize once: the argument is inspected several times below and may be a lazy sequence.
        List<string> arguments = slnOrProjectArgument?.ToList() ?? new List<string>();

        List<string> slnFiles;
        List<string> projectFiles;
        if (arguments.Count == 0)
        {
            slnFiles = [.. SlnFileFactory.ListSolutionFilesInDirectory(currentDirectory, false)];
            projectFiles = [.. Directory.GetFiles(currentDirectory, "*.*proj")];
        }
        else
        {
            slnFiles = [.. arguments.Where(IsSolutionFile).Select(Path.GetFullPath)];
            projectFiles = [.. arguments.Where(IsProjectFile).Select(Path.GetFullPath)];
        }

        foreach (string solutionFilePath in slnFiles)
        {
            var solutionFile = SlnFileFactory.CreateFromFileOrDirectory(solutionFilePath);

            // Path.GetFullPath(path, basePath) requires a fully-qualified base path, so normalize the
            // solution path first (it can be relative when currentDirectory is relative).
            string solutionDirectory = Path.GetDirectoryName(Path.GetFullPath(solutionFilePath));
            projectFiles.AddRange(solutionFile.SolutionProjects.Select(
                p => Path.GetFullPath(p.FilePath, solutionDirectory)));
        }

        if (projectFiles.Count == 0)
        {
            throw new GracefulException(
                CliCommandStrings.CouldNotFindAProject,
                currentDirectory, "--project");
        }

        return projectFiles;

        // True for .sln and .slnx paths (case-insensitive).
        static bool IsSolutionFile(string path)
        {
            string extension = Path.GetExtension(path);
            return extension.Equals(".sln", StringComparison.OrdinalIgnoreCase)
                || extension.Equals(".slnx", StringComparison.OrdinalIgnoreCase);
        }

        // True for any path whose extension ends in "proj" (.csproj, .fsproj, .vbproj, ...), case-insensitive.
        static bool IsProjectFile(string path) =>
            Path.GetExtension(path).EndsWith("proj", StringComparison.OrdinalIgnoreCase);
    }
}