File: Commands\Package\Update\PackageUpdateIO.cs
Web Access
Project: src\src\nuget-client\src\NuGet.Core\NuGet.CommandLine.XPlat\NuGet.CommandLine.XPlat.csproj (NuGet.CommandLine.XPlat)
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

#nullable enable

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using NuGet.CommandLine.XPlat.Utility;
using NuGet.Commands;
using NuGet.Common;
using NuGet.Configuration;
using NuGet.Packaging.Core;
using NuGet.ProjectModel;
using NuGet.Protocol;
using NuGet.Protocol.Core.Types;
using NuGet.Protocol.Model;
using NuGet.Shared;
using NuGet.Versioning;
using static NuGet.CommandLine.XPlat.Commands.Package.Update.PackageUpdateCommandRunner;

namespace NuGet.CommandLine.XPlat.Commands.Package.Update;

/// <summary>
/// Implementation of IPackageUpdateIO that handles package updates by performing restore operations.
/// </summary>
internal class PackageUpdateIO : IPackageUpdateIO, IDisposable
{
    private readonly MSBuildAPIUtility _msbuildUtility;
    private readonly IEnvironmentVariableReader _environmentVariableReader;
    private readonly ISettings _settings;
    private readonly IPackageSourceProvider _sourceProvider;
    private readonly CachingSourceProvider _cachingSourceProvider;
    private readonly IReadOnlyList<PackageSource> _enabledSources;

    // Shared by every source query and restore made through this instance; released in Dispose.
    private readonly SourceCacheContext _sourceCacheContext;

    /// <summary>
    /// Creates an instance whose NuGet settings are loaded relative to <paramref name="solutionDirectory"/>.
    /// </summary>
    /// <param name="solutionDirectory">
    /// An existing directory, or an existing file whose containing directory is used as the settings root.
    /// </param>
    /// <param name="msbuildUtility">Used to edit project files and to map file-based app (virtual project) paths.</param>
    /// <param name="environmentVariableReader">Used to read <c>DOTNET_HOST_PATH</c>.</param>
    public PackageUpdateIO(string solutionDirectory, MSBuildAPIUtility msbuildUtility, IEnvironmentVariableReader environmentVariableReader)
    {
        _msbuildUtility = msbuildUtility;
        _environmentVariableReader = environmentVariableReader;

        // The CommandLine option validates that an existing filesystem object is provided, so we can be confident that
        // we either have a directory or a file here. Settings discovery starts from a directory, so for a file we use
        // its parent directory.
        string settingsRoot = Directory.Exists(solutionDirectory) ? solutionDirectory : Path.GetDirectoryName(solutionDirectory)!;
        _settings = Settings.LoadDefaultSettings(settingsRoot);

        _sourceProvider = new PackageSourceProvider(_settings);
        _cachingSourceProvider = new CachingSourceProvider(_sourceProvider);
        _enabledSources = SettingsUtility.GetEnabledSources(_settings).AsList();
        _sourceCacheContext = new SourceCacheContext();
    }

    /// <summary>Disposes the shared <see cref="SourceCacheContext"/>.</summary>
    public void Dispose()
    {
        _sourceCacheContext.Dispose();
        GC.SuppressFinalize(this);
    }

    /// <inheritdoc cref="IPackageUpdateIO.GetDependencyGraphSpec(string)"/>
    /// <remarks>
    /// Runs the <c>GenerateRestoreGraphFile</c> target in a child <c>dotnet</c> process, writing to a temporary file
    /// that is always deleted afterwards. Returns <see langword="null"/> when the child process exits with a non-zero code.
    /// </remarks>
    /// <exception cref="InvalidOperationException">The child process could not be started.</exception>
    public DependencyGraphSpec? GetDependencyGraphSpec(string project)
    {
        string tempFile = Path.GetTempFileName();
        try
        {
            if (!RunMsbuildTarget(project, tempFile))
            {
                return null;
            }

            DependencyGraphSpec result = DependencyGraphSpec.Load(tempFile);

            // File-based apps are evaluated through a generated virtual project; point those specs back at the
            // path the user actually supplied.
            if (_msbuildUtility.VirtualProjectBuilder?.GetVirtualProjectPath(project) is { } virtualProjectPath)
            {
                foreach (var packageSpec in result.Projects)
                {
                    if (packageSpec.FilePath == virtualProjectPath)
                    {
                        packageSpec.FilePath = project;
                    }
                }
            }

            return result;
        }
        finally
        {
            File.Delete(tempFile);
        }

        // Runs the restore-graph target for one project and reports whether the child process succeeded.
        bool RunMsbuildTarget(string project, string tempFile)
        {
            // When being run from the dotnet CLI, use the same dotnet executable, just in case the dotnet on the PATH is different
            // But when NuGet.CommandLine.XPlat is being called directly, call dotnet on the path, so this code is debuggable.
            string dotnetPath = _environmentVariableReader.GetEnvironmentVariable("DOTNET_HOST_PATH") ?? "dotnet";

            bool isFileBasedApp = _msbuildUtility.VirtualProjectBuilder?.IsValidEntryPointPath(project) == true;

            // don't redirect stdout or stderr, so errors are output. But use quiet verbosity, so that success has no output.
            ProcessStartInfo processStartInfo = new ProcessStartInfo(dotnetPath)
            {
                Arguments = (isFileBasedApp ? "build " : "msbuild ") +
                $"\"{project}\" " +
                (isFileBasedApp ? "--no-restore " : "-restore:false ") +
                "-target:GenerateRestoreGraphFile " +
                $"-property:RestoreGraphOutputPath=\"{tempFile}\" " +
                "-property:RestoreRecursive=false " +
                "-nologo " +
                "-verbosity:quiet " +
                (!isFileBasedApp ? "-noautoresponse" : null), // currently not supported for file-based apps
                UseShellExecute = false,
                Environment =
                {
                    { "MSBUILDTERMINALLOGGER", "off" },
                },
            };

            using var process = Process.Start(processStartInfo);
            if (process is null)
            {
                throw new InvalidOperationException("Unexpected error starting child process. Process.Start returned null.");
            }

            process.WaitForExit();

            return process.ExitCode == 0;
        }
    }

    /// <inheritdoc cref="IPackageUpdateIO.PreviewUpdatePackageReferenceAsync(DependencyGraphSpec, ILogger, CancellationToken)"/>
    /// <remarks>
    /// Runs restore for every project in <paramref name="dgSpec"/> without committing anything to disk; pass the
    /// result to <see cref="CommitAsync"/> to persist it.
    /// </remarks>
    public async Task<IPackageUpdateIO.RestoreResult> PreviewUpdatePackageReferenceAsync(
        DependencyGraphSpec dgSpec,
        ILogger logger,
        CancellationToken cancellationToken)
    {
        var providerCache = new RestoreCommandProvidersCache();

        // Restore outputs a lot of messages at normal verbosity, which update doesn't want.
        var restoreLogger = new RemappedLevelLogger(
            logger,
            new RemappedLevelLogger.Mapping
            {
                Information = LogLevel.Verbose,
                Minimal = LogLevel.Verbose,
            });

        // Pre-loaded request provider containing the graph file
        var providers = new List<IPreLoadedRestoreRequestProvider>
            {
                new DependencyGraphSpecRequestProvider(providerCache, dgSpec)
            };

        var restoreContext = new RestoreArgs()
        {
            CacheContext = _sourceCacheContext,
            Log = restoreLogger,
            MachineWideSettings = new XPlatMachineWideSetting(),
            PreLoadedRequestProviders = providers
            // Sources : No need to pass it, because SourceRepositories contains the already built SourceRepository objects
        };

        var restoreRequests = await RestoreRunner.GetRequests(restoreContext);
        var restoreResult = await RestoreRunner.RunWithoutCommitAsync(restoreRequests, restoreContext, cancellationToken);

        return new RestoreResult
        {
            RestoreResultPairs = restoreResult
        };
    }

    /// <inheritdoc cref="IPackageUpdateIO.CommitAsync(IPackageUpdateIO.RestoreResult, CancellationToken)"/>
    /// <remarks>
    /// <paramref name="restorePreviewResult"/> must have been produced by this class's
    /// <see cref="PreviewUpdatePackageReferenceAsync"/>; any other implementation fails the cast.
    /// </remarks>
    public async Task CommitAsync(IPackageUpdateIO.RestoreResult restorePreviewResult, CancellationToken none)
    {
        var restoreResult = (RestoreResult)restorePreviewResult;
        foreach (var restoreResultPair in restoreResult.RestoreResultPairs)
        {
            // NOTE(review): the supplied token is not observed; CancellationToken.None is passed, presumably so a
            // commit is never interrupted part-way through — confirm this is intentional.
            await RestoreRunner.CommitAsync(restoreResultPair, CancellationToken.None);
        }
    }

    /// <inheritdoc cref="IPackageUpdateIO.UpdatePackageReference(PackageSpec, IPackageUpdateIO.RestoreResult, List{string}, PackageToUpdate, ILogger)"/>
    /// <remarks>
    /// Does nothing when the resolved version of the package cannot be found in the preview restore result.
    /// </remarks>
    public void UpdatePackageReference(PackageSpec updatedPackageSpec, IPackageUpdateIO.RestoreResult restorePreviewResult, List<string> packageTfmAliases, PackageToUpdate packageToUpdate, ILogger logger)
    {
        PackageDependency packageDependency = new PackageDependency(packageToUpdate.Id, packageToUpdate.NewVersion);

        var restoreResult = (RestoreResult)restorePreviewResult;
        var restoreResultPair = restoreResult.RestoreResultPairs.Single(pair =>
            string.Equals(pair.SummaryRequest.Request.Project.FilePath, updatedPackageSpec.FilePath, StringComparison.OrdinalIgnoreCase));

        if (!AddPackageReferenceCommandRunner.TryFindResolvedVersion(packageTfmAliases,
            packageDependency.Id,
            restoreResultPair.Result,
            logger,
            out NuGetVersion resolvedVersion))
        {
            return;
        }

        // Generate the LibraryDependency using the same logic as AddPackageReferenceCommandRunner
        var libraryDependency = AddPackageReferenceCommandRunner.GenerateLibraryDependency(
            updatedPackageSpec,
            customPackagesPath: null,
            packageDependency,
            resolvedVersion);

        // MSBuildUtility only updates CPM Directory.Packages.props when "noVersion" is false.
        const bool noVersion = false;

        // Determine whether to add package reference conditionally or unconditionally
        if (packageTfmAliases.Count == updatedPackageSpec.TargetFrameworks.Count)
        {
            // package is used by all project TFMs (no condition)
            _msbuildUtility.AddPackageReference(updatedPackageSpec.FilePath, libraryDependency, noVersion);
        }
        else
        {
            _msbuildUtility.AddPackageReferencePerTFM(updatedPackageSpec.FilePath, libraryDependency, packageTfmAliases, noVersion);
        }
    }

    /// <inheritdoc cref="IPackageUpdateIO.GetLatestVersionAsync(string, bool, IReadOnlyList{string}?, ILogger, CancellationToken)"/>
    /// <remarks>
    /// Queries every candidate source in parallel and returns the highest version any of them reports, or
    /// <see langword="null"/> when none of them has the package.
    /// </remarks>
    public async Task<NuGetVersion?> GetLatestVersionAsync(
        string packageId,
        bool includePrerelease,
        IReadOnlyList<string>? allowedSources,
        ILogger logger,
        CancellationToken cancellationToken)
    {
        var sources = GetSourcesForPackage(packageId, allowedSources);
        var lookups = new Task<NuGetVersion?>[sources.Count];
        for (int source = 0; source < sources.Count; source++)
        {
            SourceRepository sourceRepository = sources[source];
            // If package source is a local folder feed, it might not actually be async
            lookups[source] = Task.Run(() => FindHighestPackageVersionAsync(sourceRepository, packageId, includePrerelease, logger, cancellationToken));
        }

        NuGetVersion?[] perSourceVersions = await Task.WhenAll(lookups);

        NuGetVersion? highestVersion = null;
        foreach (NuGetVersion? version in perSourceVersions)
        {
            if (version is not null && (highestVersion is null || version > highestVersion))
            {
                highestVersion = version;
            }
        }

        return highestVersion;
    }

    /// <inheritdoc cref="IPackageUpdateIO.GetKnownVulnerabilitiesAsync(ILogger, CancellationToken)"/>
    /// <remarks>
    /// Uses the enabled audit sources when any are configured, otherwise falls back to the enabled package sources.
    /// Sources that do not expose a vulnerability resource are skipped.
    /// </remarks>
    public async Task<IReadOnlyList<IReadOnlyDictionary<string, IReadOnlyList<PackageVulnerabilityInfo>>>> GetKnownVulnerabilitiesAsync(ILogger logger, CancellationToken cancellationToken)
    {
        IReadOnlyList<PackageSource>? auditSources = _sourceProvider.LoadAuditSources()?.Where(s => s.IsEnabled).ToList();
        if (auditSources is null || auditSources.Count == 0)
        {
            auditSources = _enabledSources;
        }

        var tasks = new List<Task<GetVulnerabilityInfoResult?>>(auditSources.Count);
        foreach (var auditSource in auditSources)
        {
            // Build the repository from the full PackageSource (not just its URL) through the same caching provider
            // used for package lookups, so repositories and their resources are shared. Created here, on the
            // calling thread, before fanning out.
            SourceRepository sourceRepository = _cachingSourceProvider.CreateRepository(auditSource);
            tasks.Add(Task.Run(async () =>
            {
                var vulnerabilityResource = await sourceRepository.GetResourceAsync<IVulnerabilityInfoResource>(cancellationToken);
                if (vulnerabilityResource is not null)
                {
                    var vulnerabilities = await vulnerabilityResource.GetVulnerabilityInfoAsync(_sourceCacheContext, logger, cancellationToken);
                    return vulnerabilities;
                }
                return null;
            }, cancellationToken));
        }

        List<IReadOnlyDictionary<string, IReadOnlyList<PackageVulnerabilityInfo>>> allVulnerabilities = new();
        foreach (var task in tasks)
        {
            var result = await task;
            if (result?.KnownVulnerabilities?.Count > 0)
            {
                allVulnerabilities.AddRange(result.KnownVulnerabilities);
            }
        }

        return allVulnerabilities;
    }

    /// <inheritdoc cref="IPackageUpdateIO.GetNonVulnerableAsync(string, IReadOnlyList{string}?, NuGetVersion, ILogger, IReadOnlyList{IReadOnlyDictionary{string, IReadOnlyList{PackageVulnerabilityInfo}}}, CancellationToken)"/>
    /// <remarks>
    /// Queries every candidate source in parallel and returns the lowest version, at or above
    /// <paramref name="minVersion"/>, that no entry in <paramref name="knownVulnerabilities"/> matches; or
    /// <see langword="null"/> when no source has such a version.
    /// </remarks>
    public async Task<NuGetVersion?> GetNonVulnerableAsync(
        string packageId,
        IReadOnlyList<string>? allowedSources,
        NuGetVersion minVersion,
        ILogger logger,
        IReadOnlyList<IReadOnlyDictionary<string, IReadOnlyList<PackageVulnerabilityInfo>>> knownVulnerabilities,
        CancellationToken cancellationToken)
    {
        var sources = GetSourcesForPackage(packageId, allowedSources);
        var lookups = new Task<NuGetVersion?>[sources.Count];
        for (int source = 0; source < sources.Count; source++)
        {
            SourceRepository sourceRepository = sources[source];
            // If package source is a local folder feed, it might not actually be async
            lookups[source] = Task.Run(() => FindLowestNonVulnerablePackageVersionAsync(sourceRepository, packageId, minVersion, knownVulnerabilities, logger, cancellationToken));
        }

        NuGetVersion?[] perSourceVersions = await Task.WhenAll(lookups);

        NuGetVersion? lowestNonVulnerableVersion = null;
        foreach (NuGetVersion? version in perSourceVersions)
        {
            if (version is not null && (lowestNonVulnerableVersion is null || version < lowestNonVulnerableVersion))
            {
                lowestNonVulnerableVersion = version;
            }
        }

        return lowestNonVulnerableVersion;
    }

    /// <summary>Returns the package source mapping configuration from the loaded settings.</summary>
    public PackageSourceMapping GetPackageSourceMapping()
    {
        return PackageSourceMapping.GetPackageSourceMapping(_settings);
    }

    /// <summary>
    /// Returns repositories for the enabled sources, restricted (by case-insensitive source name) to
    /// <paramref name="allowedSources"/> when it is not <see langword="null"/>.
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="allowedSources"/> is non-null but empty.</exception>
    private List<SourceRepository> GetSourcesForPackage(string packageId, IReadOnlyList<string>? allowedSources)
    {
        IReadOnlyList<PackageSource> packageSources;

        // Apply package source mapping if enabled
        if (allowedSources is not null)
        {
            if (allowedSources.Count == 0)
            {
                throw new ArgumentException("The allowedSources list must contain at least one source if specified.", nameof(allowedSources));
            }

            // Preserve the order of the enabled sources; only filter them.
            var allowedNames = new HashSet<string>(allowedSources, StringComparer.OrdinalIgnoreCase);
            packageSources = _enabledSources.Where(ps => allowedNames.Contains(ps.Name)).ToList();
        }
        else
        {
            packageSources = _enabledSources;
        }

        var sources = new List<SourceRepository>(packageSources.Count);
        foreach (PackageSource packageSource in packageSources)
        {
            sources.Add(_cachingSourceProvider.CreateRepository(packageSource));
        }
        return sources;
    }

    /// <summary>
    /// Returns the lowest listed, stable version of <paramref name="packageId"/> on <paramref name="source"/> that is
    /// at least <paramref name="minVersion"/> and not covered by any known vulnerability; <see langword="null"/> if none.
    /// </summary>
    private async Task<NuGetVersion?> FindLowestNonVulnerablePackageVersionAsync(
        SourceRepository source,
        string packageId,
        NuGetVersion minVersion,
        IReadOnlyList<IReadOnlyDictionary<string, IReadOnlyList<PackageVulnerabilityInfo>>> knownVulnerabilities,
        ILogger logger,
        CancellationToken cancellationToken)
    {
        var packageMetadataResource = await source.GetResourceAsync<PackageMetadataResource>(cancellationToken);
        if (packageMetadataResource is null)
        {
            return null;
        }

        var packageDetails = await packageMetadataResource.GetMetadataAsync(
            packageId,
            includePrerelease: false,
            includeUnlisted: false,
            _sourceCacheContext,
            logger,
            cancellationToken);

        if (packageDetails is null)
        {
            return null;
        }

        var versions = packageDetails
            .Select(p => p.Identity.Version)
            .Where(v => v >= minVersion && !HasKnownVulnerability(v));

        // A range with only a lower bound makes FindBestMatch pick the lowest satisfying version; it returns null
        // when the sequence is empty, so no separate emptiness check is needed.
        VersionRange versionRange = new VersionRange(minVersion, includeMinVersion: true, maxVersion: null, includeMaxVersion: true);
        return versionRange.FindBestMatch(versions);

        // True when any source's vulnerability data for this package id covers the given version.
        bool HasKnownVulnerability(NuGetVersion version)
        {
            foreach (var sourceVulnerabilities in knownVulnerabilities)
            {
                if (sourceVulnerabilities.TryGetValue(packageId, out var vulnerabilities))
                {
                    foreach (var vulnerability in vulnerabilities)
                    {
                        if (vulnerability.Versions.Satisfies(version))
                        {
                            return true;
                        }
                    }
                }
            }
            return false;
        }
    }

    /// <summary>
    /// Returns the highest listed version of <paramref name="packageId"/> on <paramref name="source"/>, or
    /// <see langword="null"/> when the source has none.
    /// </summary>
    private async Task<NuGetVersion?> FindHighestPackageVersionAsync(
        SourceRepository source,
        string packageId,
        bool includePrerelease,
        ILogger logger,
        CancellationToken cancellationToken)
    {
        var packageMetadataResource = await source.GetResourceAsync<PackageMetadataResource>(cancellationToken);
        if (packageMetadataResource is null)
        {
            return null;
        }

        var packageDetails = await packageMetadataResource.GetMetadataAsync(
            packageId,
            includePrerelease: includePrerelease,
            includeUnlisted: false,
            _sourceCacheContext,
            logger,
            cancellationToken);

        // Enumerable.Max over a reference-type selector returns null for an empty sequence, so this is a single pass.
        return packageDetails?.Max(p => p.Identity.Version);
    }

    /// <inheritdoc cref="IPackageUpdateIO.GetProjectAssetsFileAsync(DependencyGraphSpec, string, ILogger, CancellationToken)"/>
    /// <remarks>
    /// Runs a (non-committing) preview restore. Uses the in-memory lock file from that restore when available,
    /// otherwise reads the assets file from the project's restore output path.
    /// </remarks>
    /// <exception cref="NotSupportedException">The preview restore failed.</exception>
    public async Task<LockFile> GetProjectAssetsFileAsync(
        DependencyGraphSpec dgSpec,
        string projectPath,
        ILogger logger,
        CancellationToken cancellationToken)
    {
        var previewRestoreResult = (RestoreResult)await PreviewUpdatePackageReferenceAsync(dgSpec, NullLogger.Instance, cancellationToken);
        if (!previewRestoreResult.Success)
        {
            logger.LogError("Restore failed");
            throw new NotSupportedException();
        }

        var restoreResultPair = previewRestoreResult.RestoreResultPairs.Single(pair =>
            string.Equals(pair.SummaryRequest.Request.Project.FilePath, projectPath, StringComparison.OrdinalIgnoreCase));
        LockFile? assetsFile = restoreResultPair.Result.LockFile;
        if (assetsFile is null)
        {
            var packageSpec = dgSpec.GetProjectSpec(projectPath);
            var assetsFilePath = Path.Combine(packageSpec.RestoreMetadata.OutputPath, LockFileFormat.AssetsFileName);
            assetsFile = new LockFileFormat().Read(assetsFilePath);
        }

        return assetsFile;
    }

    /// <summary>
    /// Uncommitted restore output for every project in a dependency graph; successful only if every project restored.
    /// </summary>
    internal class RestoreResult : IPackageUpdateIO.RestoreResult
    {
        internal required IReadOnlyList<RestoreResultPair> RestoreResultPairs { get; init; }

        public override bool Success => RestoreResultPairs.All(pair => pair.Result.Success);
    }
}