// File: Commands\Package\Download\PackageDownloadRunner.cs
// Web Access
// Project: src\src\nuget-client\src\NuGet.Core\NuGet.CommandLine.XPlat\NuGet.CommandLine.XPlat.csproj (NuGet.CommandLine.XPlat)
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

#nullable enable

using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using NuGet.CommandLine.XPlat.Utility;
using NuGet.Commands;
using NuGet.Configuration;
using NuGet.Credentials;
using NuGet.Packaging;
using NuGet.Packaging.Core;
using NuGet.Packaging.PackageExtraction;
using NuGet.Packaging.Signing;
using NuGet.Protocol;
using NuGet.Protocol.Core.Types;
using NuGet.Repositories;
using NuGet.Versioning;

namespace NuGet.CommandLine.XPlat.Commands.Package.PackageDownload
{
    internal static class PackageDownloadRunner
    {
        /// <summary>
        /// Value returned by <c>RunAsync</c> when a disallowed insecure source is detected
        /// or at least one requested package could not be downloaded.
        /// </summary>
        internal const int ExitCodeError = 1;

        /// <summary>
        /// Value returned by <c>RunAsync</c> when every requested package was handled successfully.
        /// </summary>
        internal const int ExitCodeSuccess = 0;

        /// <summary>
        /// Builds the logger, configures the protocol and credential service, loads settings and
        /// package sources from <paramref name="args"/>, then delegates to the overload that
        /// performs the actual downloads.
        /// </summary>
        /// <param name="args">Parsed command arguments.</param>
        /// <param name="token">Cancellation token forwarded to the download overload.</param>
        /// <returns>The exit code produced by the download overload.</returns>
        public static async Task<int> RunAsync(PackageDownloadArgs args, CancellationToken token)
        {
            ILoggerWithColor commandLogger = new CommandOutputLogger(args.LogLevel)
            {
                HidePrefixForInfoAndMinimal = true
            };

            XPlatUtility.ConfigureProtocol();

            // The credential service receives the negated Interactive flag.
            bool notInteractive = !args.Interactive;
            DefaultCredentialServiceUtility.SetupDefaultCredentialService(commandLogger, notInteractive);

            // Settings are resolved relative to the current working directory,
            // optionally overridden by an explicit config file.
            string workingDirectory = Directory.GetCurrentDirectory();
            ISettings loadedSettings = Settings.LoadDefaultSettings(
                workingDirectory,
                args.ConfigFile,
                new XPlatMachineWideSetting());

            IReadOnlyList<PackageSource> sourcesToUse = GetPackageSources(
                args.Sources,
                new PackageSourceProvider(loadedSettings));

            return await RunAsync(args, commandLogger, sourcesToUse, loadedSettings, token);
        }

        /// <summary>
        /// Downloads every package listed in <paramref name="args"/> into the output directory,
        /// honoring package source mapping unless sources were passed explicitly.
        /// </summary>
        /// <param name="args">Parsed command arguments (packages, sources, output directory, flags).</param>
        /// <param name="logger">Logger that receives progress and error messages.</param>
        /// <param name="packageSources">Package sources to consider for the downloads.</param>
        /// <param name="settings">Settings used for package source mapping and the client policy.</param>
        /// <param name="token">Token used to cancel the operation.</param>
        /// <returns>
        /// <see cref="ExitCodeSuccess"/> when every package was handled successfully;
        /// otherwise <see cref="ExitCodeError"/>.
        /// </returns>
        public static async Task<int> RunAsync(PackageDownloadArgs args, ILoggerWithColor logger, IReadOnlyList<PackageSource> packageSources, ISettings settings, CancellationToken token)
        {
            bool hasSourcesArg = args.Sources?.Count > 0;

            // Package source mapping is only consulted when no sources were passed explicitly.
            PackageSourceMapping? packageSourceMapping = null;
            if (!hasSourcesArg)
            {
                packageSourceMapping = PackageSourceMapping.GetPackageSourceMapping(settings);
            }

            bool ignorePackageSourceMapping =
                hasSourcesArg
                || packageSourceMapping is null
                || !packageSourceMapping.IsEnabled;

            // When package source mapping is disabled, validate all configured sources upfront.
            // When mapping is enabled, source validation is deferred to the per-package resolution step,
            // since each package may map to a different subset of sources.
            if (ignorePackageSourceMapping && DetectAndReportInsecureSources(args.AllowInsecureConnections, packageSources, logger))
            {
                return ExitCodeError;
            }

            string outputDirectory = args.OutputDirectory ?? Directory.GetCurrentDirectory();

            // SourceCacheContext is disposable; dispose it when the method exits.
            using var cache = new SourceCacheContext();
            IReadOnlyList<SourceRepository> allRepositories = GetSourceRepositories(packageSources);
            bool downloadedAllSuccessfully = true;

            foreach (var package in args.Packages ?? [])
            {
                // Show the requested version, or the "latest" placeholder when none was specified.
                string? requestedVersion = package.NuGetVersion?.ToNormalizedString();
                logger.LogMinimal(string.Format(
                    CultureInfo.CurrentCulture,
                    Strings.PackageDownloadCommand_Starting,
                    package.Id,
                    string.IsNullOrEmpty(requestedVersion) ? Strings.PackageDownloadCommand_LatestVersion : requestedVersion));

                // Resolve which repositories to use for this package
                IReadOnlyList<SourceRepository> sourceRepositories;
                if (ignorePackageSourceMapping)
                {
                    sourceRepositories = allRepositories;
                }
                else
                {
                    var mappedNames = packageSourceMapping!.GetConfiguredPackageSources(package.Id);

                    if (mappedNames.Count == 0)
                    {
                        // fail, no sources mapped for this package
                        var notConsideredSources = string.Join(
                            ", ",
                            allRepositories.Select(repository => repository.PackageSource));

                        logger.LogError(string.Format(
                            CultureInfo.CurrentCulture,
                            Strings.PackageDownloadCommand_PackageSourceMapping_NoSourcesMapped,
                            package.Id,
                            notConsideredSources));

                        downloadedAllSuccessfully = false;
                        continue;
                    }

                    sourceRepositories = GetMappedRepositories(mappedNames, allRepositories, package.Id, logger);

                    // With mapping enabled, only the sources mapped to this package are validated.
                    if (DetectAndReportInsecureSources(args.AllowInsecureConnections, sourceRepositories.Select(r => r.PackageSource), logger))
                    {
                        downloadedAllSuccessfully = false;
                        continue;
                    }
                }

                try
                {
                    (NuGetVersion? version, SourceRepository? downloadRepository) =
                        await ResolvePackageDownloadVersion(
                            package,
                            sourceRepositories,
                            cache,
                            logger,
                            args.IncludePrerelease,
                            token);

                    if (version == null)
                    {
                        // Unable to find a valid version
                        downloadedAllSuccessfully = false;
                        continue;
                    }

                    bool success = await DownloadPackageAsync(
                        package.Id,
                        version,
                        downloadRepository!,
                        cache,
                        settings,
                        outputDirectory,
                        logger,
                        token);

                    if (success)
                    {
                        logger.LogMinimal(string.Format(
                            CultureInfo.CurrentCulture,
                            Strings.PackageDownloadCommand_Succeeded,
                            package.Id,
                            version,
                            outputDirectory));
                    }
                    else
                    {
                        logger.LogError(string.Format(
                            CultureInfo.CurrentCulture,
                            Strings.PackageDownloadCommand_Failed,
                            package.Id,
                            version));

                        downloadedAllSuccessfully = false;
                    }
                }
#pragma warning disable CA1031 // Do not catch general exception types
                catch (Exception ex)
                {
                    // A failure for one package must not prevent the remaining packages from being attempted.
                    logger.LogError(ex.ToString());
                    downloadedAllSuccessfully = false;

                    if (token.IsCancellationRequested)
                    {
                        // Once cancellation is requested every remaining package would fail the same way,
                        // so stop instead of logging one error per package.
                        break;
                    }
                }
#pragma warning restore CA1031 // Do not catch general exception types
            }

            return downloadedAllSuccessfully ? ExitCodeSuccess : ExitCodeError;
        }

        /// <summary>
        /// Determines which version of a package to download and which repository to download it from.
        /// </summary>
        /// <remarks>
        /// When an exact version is requested, the first repository that lists that version wins.
        /// Otherwise the highest version found across all repositories is selected.
        /// </remarks>
        /// <param name="packageWithNuGetVersion">Package id and optional exact version.</param>
        /// <param name="sourceRepositories">Repositories to query, in order.</param>
        /// <param name="cache">Source cache context passed to the metadata lookup.</param>
        /// <param name="logger">Logger that receives the "version not found" error.</param>
        /// <param name="includePrerelease">Whether prerelease versions are requested from the metadata lookup.</param>
        /// <param name="token">Token used to cancel the operation.</param>
        /// <returns>
        /// The resolved version and its repository, or <c>(null, null)</c> when no matching version was found
        /// (an error is logged in that case).
        /// </returns>
        internal static async Task<(NuGetVersion?, SourceRepository?)> ResolvePackageDownloadVersion(
            PackageWithNuGetVersion packageWithNuGetVersion,
            IReadOnlyList<SourceRepository> sourceRepositories,
            SourceCacheContext cache,
            ILoggerWithColor logger,
            bool includePrerelease,
            CancellationToken token)
        {
            NuGetVersion? versionToDownload = null;
            SourceRepository? downloadSourceRepository = null;
            NuGetVersion? requestedVersion = packageWithNuGetVersion.NuGetVersion;
            bool versionSpecified = requestedVersion != null;

            foreach (var repo in sourceRepositories)
            {
                var finder = await repo.GetResourceAsync<PackageMetadataResource>(token);
                if (finder == null)
                {
                    // This source does not offer a metadata resource; skip it instead of
                    // failing the whole package with a NullReferenceException.
                    continue;
                }

                var packages = await finder.GetMetadataAsync(
                    packageWithNuGetVersion.Id,
                    includePrerelease,
                    includeUnlisted: versionSpecified, // only load unlisted if an exact version is specified
                    sourceCacheContext: cache,
                    logger,
                    token);

                if (packages == null)
                {
                    continue;
                }

                if (versionSpecified)
                {
                    // If an exact version is specified, check if it exists at this source
                    if (packages.Any(package => package?.Identity?.Version == requestedVersion))
                    {
                        return (requestedVersion, repo);
                    }

                    continue;
                }

                foreach (var package in packages)
                {
                    // Same null tolerance as the exact-version path: ignore entries without a version.
                    NuGetVersion? candidate = package?.Identity?.Version;
                    if (candidate is null)
                    {
                        continue;
                    }

                    if (versionToDownload == null || candidate > versionToDownload)
                    {
                        versionToDownload = candidate;
                        downloadSourceRepository = repo;
                    }
                }
            }

            if (versionToDownload == null)
            {
                logger.LogError(Strings.Error_PackageDownload_VersionNotFound);
            }

            return (versionToDownload, downloadSourceRepository);
        }

        /// <summary>
        /// Translates the source names mapped to a package into the matching repositories.
        /// Names without a matching repository are reported at verbose level and skipped.
        /// </summary>
        /// <param name="mappedNames">Source names mapped to the package.</param>
        /// <param name="allRepos">All available repositories to match against.</param>
        /// <param name="packageId">Package id, used only for the verbose message.</param>
        /// <param name="logger">Logger that receives the verbose message for unknown names.</param>
        /// <returns>The matched repositories, in the order of <paramref name="mappedNames"/>.</returns>
        internal static IReadOnlyList<SourceRepository> GetMappedRepositories(
            IReadOnlyList<string> mappedNames,
            IReadOnlyList<SourceRepository> allRepos,
            string packageId,
            ILoggerWithColor logger)
        {
            var resolved = new List<SourceRepository>(mappedNames.Count);

            for (int index = 0; index < mappedNames.Count; index++)
            {
                string sourceName = mappedNames[index];
                SourceRepository? match = FindRepositoryByName(sourceName, allRepos);

                if (match == null)
                {
                    // The mapping refers to a source that is not among the available repositories.
                    string message = string.Format(
                        CultureInfo.CurrentCulture,
                        Strings.PackageDownloadCommand_PackageSourceMapping_NoSuchSource,
                        sourceName,
                        packageId);
                    logger.LogVerbose(message);
                    continue;
                }

                resolved.Add(match);
            }

            return resolved;
        }

        /// <summary>
        /// Returns the first repository whose package source name equals <paramref name="mappedName"/>
        /// (ordinal, case-insensitive), or <c>null</c> when there is none.
        /// </summary>
        private static SourceRepository? FindRepositoryByName(
            string mappedName,
            IReadOnlyList<SourceRepository> allRepos)
        {
            return allRepos.FirstOrDefault(candidate =>
                string.Equals(candidate.PackageSource.Name, mappedName, StringComparison.OrdinalIgnoreCase));
        }

        /// <summary>
        /// Installs a single package version from <paramref name="repo"/> into
        /// <paramref name="outputDirectory"/>, unless it already exists there.
        /// </summary>
        /// <param name="id">Package id.</param>
        /// <param name="version">Exact version to install.</param>
        /// <param name="repo">Repository to download from.</param>
        /// <param name="cache">Source cache context used for the download.</param>
        /// <param name="settings">Settings used to build the client policy.</param>
        /// <param name="outputDirectory">Root folder the package is installed into.</param>
        /// <param name="logger">Logger for progress and error messages.</param>
        /// <param name="token">Token used to cancel the operation.</param>
        /// <returns>
        /// <c>true</c> when the package is already present or was installed;
        /// <c>false</c> when installation failed (an error is logged).
        /// </returns>
        private static async Task<bool> DownloadPackageAsync(
            string id,
            NuGetVersion version,
            SourceRepository repo,
            SourceCacheContext cache,
            ISettings settings,
            string outputDirectory,
            Common.ILogger logger,
            CancellationToken token)
        {
            var extraction = new PackageExtractionContext(
                PackageSaveMode.Defaultv3,
                PackageExtractionBehavior.XmlDocFileSaveMode,
                ClientPolicyContext.GetClientPolicy(settings, logger),
                logger);

            var pathResolver = new VersionFolderPathResolver(outputDirectory);
            var localRepository = new NuGetv3LocalRepository(outputDirectory);

            // Nothing to do when the package is already present in the output folder.
            bool alreadyInstalled = localRepository.Exists(id, version);
            if (alreadyInstalled)
            {
                string alreadyInstalledMessage = string.Format(
                    CultureInfo.CurrentCulture,
                    Strings.PackageDownloadCommand_AlreadyInstalled,
                    id,
                    version.ToNormalizedString(),
                    outputDirectory);
                logger.LogMinimal(alreadyInstalledMessage);

                return true;
            }

            var identity = new PackageIdentity(id, version);
            var dependencyProvider = new SourceRepositoryDependencyProvider(
                sourceRepository: repo,
                logger: logger,
                cacheContext: cache,
                ignoreFailedSources: false,
                ignoreWarning: false);

            using var packageDownloader = await dependencyProvider.GetPackageDownloaderAsync(identity, cache, logger, token);
            bool installed = await PackageExtractor.InstallFromSourceAsync(identity, packageDownloader, pathResolver, extraction, token);

            if (installed)
            {
                return true;
            }

            string failureMessage = string.Format(
                CultureInfo.CurrentCulture,
                Strings.PackageDownloadCommand_UnableToDownload,
                id,
                version.ToNormalizedString(),
                repo.PackageSource.Source);
            logger.LogError(failureMessage);

            return false;
        }

        /// <summary>
        /// Determines the package sources to use: the sources given on the command line,
        /// each resolved against the enabled configured sources, or all enabled configured
        /// sources when none were given.
        /// </summary>
        /// <param name="sources">Sources specified on the command line; may be null or empty.</param>
        /// <param name="sourceProvider">Provider used to load the configured package sources.</param>
        /// <returns>The package sources to use for the download.</returns>
        private static IReadOnlyList<PackageSource> GetPackageSources(IList<string>? sources, IPackageSourceProvider sourceProvider)
        {
            // Materialize once: the filtered sequence is otherwise re-enumerated
            // for every command-line source that gets resolved against it.
            List<PackageSource> configuredSources = [.. sourceProvider.LoadPackageSources().Where(s => s.IsEnabled)];

            if (sources is not { Count: > 0 })
            {
                return configuredSources;
            }

            // Use sources specified on command line
            return [.. sources.Select(s => PackageSourceProviderExtensions.ResolveSource(configuredSources, s))];
        }

        /// <summary>
        /// Logs an error and returns <c>true</c> when insecure connections are not allowed and
        /// <paramref name="packageSources"/> contains disallowed insecure HTTP sources.
        /// </summary>
        /// <param name="allowInsecureConnections">When <c>true</c>, the check is skipped entirely.</param>
        /// <param name="packageSources">Sources to inspect.</param>
        /// <param name="logger">Logger that receives the error message.</param>
        /// <returns><c>true</c> when a disallowed insecure source was reported; otherwise <c>false</c>.</returns>
        private static bool DetectAndReportInsecureSources(
            bool allowInsecureConnections,
            IEnumerable<PackageSource> packageSources,
            ILoggerWithColor logger)
        {
            if (allowInsecureConnections)
            {
                return false;
            }

            var disallowedSources = HttpSourcesUtility.GetDisallowedInsecureHttpSources([.. packageSources]);
            if (!disallowedSources.Any())
            {
                return false;
            }

            logger.LogError(HttpSourcesUtility.BuildHttpSourceErrorMessage(disallowedSources, "package download"));
            return true;
        }

        /// <summary>
        /// Creates one <see cref="SourceRepository"/> per package source, using the core V3
        /// resource providers and an undefined feed type.
        /// </summary>
        /// <param name="packageSources">Package sources to create repositories for.</param>
        /// <returns>The repositories, in the same order as <paramref name="packageSources"/>.</returns>
        private static IReadOnlyList<SourceRepository> GetSourceRepositories(IReadOnlyList<PackageSource> packageSources)
        {
            IEnumerable<Lazy<INuGetResourceProvider>> resourceProviders = Repository.Provider.GetCoreV3();

            return packageSources
                .Select(packageSource => Repository.CreateSource(resourceProviders, packageSource, FeedType.Undefined))
                .ToList();
        }
    }
}