File: Resources\VulnerabilityInfoResourceV3.cs
Web Access
Project: src\src\nuget-client\src\NuGet.Core\NuGet.Protocol\NuGet.Protocol.csproj (NuGet.Protocol)
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using NuGet.Common;
using NuGet.Protocol.Core.Types;
using NuGet.Protocol.Model;
using NuGet.Protocol.Utility;

namespace NuGet.Protocol.Resources
{
    /// <summary>Implementation of <see cref="IVulnerabilityInfoResource"/> for NuGet V3 HTTP feeds.</summary>
    /// <remarks>Not intended to be created directly. Use <see cref="SourceRepository.GetResourceAsync{T}(CancellationToken)"/>
    /// with <see cref="IVulnerabilityInfoResource"/> for T, and typecast to this class.
    /// <para>Implements the <a href="https://learn.microsoft.com/nuget/api/vulnerability-info-resource">VulnerabilityInfo server API resource</a>.</para></remarks>
    public sealed class VulnerabilityInfoResourceV3 : IVulnerabilityInfoResource
    {
        private readonly SourceRepository _sourceRepository;

        internal VulnerabilityInfoResourceV3(SourceRepository sourceRepository)
        {
            _sourceRepository = sourceRepository ?? throw new ArgumentNullException(nameof(sourceRepository));
        }

        /// <summary>Get the vulnerability pages the server contains.</summary>
        /// <param name="cacheContext">The cache settings to use when making HTTP requests.</param>
        /// <param name="log">The logger for any messages.</param>
        /// <param name="cancellationToken">The cancellation token to cancel operation.</param>
        /// <returns>The list of vulnerability data pages the server claims to have. The entries are returned as
        /// deserialized and are not validated by this method.</returns>
        /// <exception cref="FatalProtocolException">When various HTTP or deserialization exceptions occur, or when the
        /// index could not be loaded.</exception>
        public async Task<IReadOnlyList<V3VulnerabilityIndexEntry>> GetVulnerabilityFilesAsync(SourceCacheContext cacheContext, ILogger log, CancellationToken cancellationToken)
        {
            Uri vulnerabilityIndexUrl = await GetIndexUrlAsync(cancellationToken);
            HttpSourceResource httpSourceResource = await _sourceRepository.GetResourceAsync<HttpSourceResource>(cancellationToken);

            HttpSourceCacheContext httpSourceCacheContext = HttpSourceCacheContext.Create(cacheContext, isFirstAttempt: true);
            var request = new HttpSourceCachedRequest(vulnerabilityIndexUrl.OriginalString, "vuln_index", httpSourceCacheContext);
            IReadOnlyList<V3VulnerabilityIndexEntry>? vulnFiles;
            try
            {
                vulnFiles = await httpSourceResource.HttpSource.GetAsync(request,
                    async result =>
                    {
                        IReadOnlyList<V3VulnerabilityIndexEntry>? parsed =
                            await JsonSerializer.DeserializeAsync(result.Stream!, JsonContext.Default.VulnerabilityIndex);
                        return parsed;
                    },
                    log,
                    cancellationToken);
            }
            catch (Exception ex) when (ex is not FatalProtocolException)
            {
                // HTTP failures and System.Text.Json's JsonException both end up here. Wrap them so that the HTTP stack and
                // serializer used are not leaked through this method's documented exception contract.
                throw new FatalProtocolException(message: ex.Message, innerException: ex);
            }

            if (vulnFiles == null)
            {
                string message = string.Format(Strings.VulnerabilityPage_CouldNotLoad, _sourceRepository.PackageSource.Name, vulnerabilityIndexUrl.OriginalString);
                throw new FatalProtocolException(message);
            }

            return vulnFiles;

            // Looks up the VulnerabilityInfo endpoint in the source's service index.
            async Task<Uri> GetIndexUrlAsync(CancellationToken cancellationToken)
            {
                ServiceIndexResourceV3 serviceIndex = await _sourceRepository.GetResourceAsync<ServiceIndexResourceV3>(cancellationToken);
                return serviceIndex.GetServiceEntryUri(ServiceTypes.VulnerabilityInfo);
            }
        }

        /// <summary>Get the known vulnerability data for a single page from the server.</summary>
        /// <param name="vulnerabilityPage">The page to get data from. Its name is embedded in the HTTP cache key, so
        /// callers are expected to pass only validated entries.</param>
        /// <param name="cacheContext">The cache settings if HTTP requests are made.</param>
        /// <param name="logger">The logger for messages.</param>
        /// <param name="cancellationToken">The cancellation token to cancel operation.</param>
        /// <returns>The known vulnerabilities defined in the file.</returns>
        /// <exception cref="FatalProtocolException">If various HTTP or deserialization exceptions occur, or the page
        /// could not be loaded.</exception>
        public async Task<IReadOnlyDictionary<string, IReadOnlyList<PackageVulnerabilityInfo>>> GetVulnerabilityDataAsync(
            V3VulnerabilityIndexEntry vulnerabilityPage,
            SourceCacheContext cacheContext,
            ILogger logger,
            CancellationToken cancellationToken)
        {
            HttpSourceResource httpSourceResource = await _sourceRepository.GetResourceAsync<HttpSourceResource>(cancellationToken);

            HttpSourceCacheContext httpSourceCacheContext = HttpSourceCacheContext.Create(cacheContext, isFirstAttempt: true);
            var request = new HttpSourceCachedRequest(vulnerabilityPage.Url.OriginalString, "vuln_data_" + vulnerabilityPage.Name, httpSourceCacheContext);
            IReadOnlyDictionary<string, IReadOnlyList<PackageVulnerabilityInfo>>? data;
            try
            {
                data = await httpSourceResource.HttpSource.GetAsync(request,
                    async result =>
                    {
                        CaseInsensitiveDictionary<IReadOnlyList<PackageVulnerabilityInfo>>? parsed =
                            await JsonSerializer.DeserializeAsync(result.Stream!, JsonContext.Default.VulnerabilityPage);
                        return parsed;
                    },
                    logger,
                    cancellationToken);
            }
            catch (Exception ex) when (ex is not FatalProtocolException)
            {
                // HTTP failures and System.Text.Json's JsonException both end up here. Wrap them so that the HTTP stack and
                // serializer used are not leaked through this method's documented exception contract.
                throw new FatalProtocolException(message: ex.Message, innerException: ex);
            }

            if (data == null)
            {
                // Use OriginalString to match the URL format reported by GetVulnerabilityFilesAsync.
                string message = string.Format(Strings.VulnerabilityPage_CouldNotLoad, _sourceRepository.PackageSource.Name, vulnerabilityPage.Url.OriginalString);
                throw new FatalProtocolException(message);
            }

            return data;
        }

        /// <inheritdoc cref="IVulnerabilityInfoResource.GetVulnerabilityInfoAsync(SourceCacheContext, ILogger, CancellationToken)"/>
        /// <remarks>This method does not throw. Every problem is recorded and reported through the returned
        /// <see cref="GetVulnerabilityInfoResult"/>, including an unreadable index, too many pages, an invalid index
        /// entry, or a page that fails to download or parse. Every page that did load is still returned.</remarks>
        public async Task<GetVulnerabilityInfoResult> GetVulnerabilityInfoAsync(SourceCacheContext cacheContext, ILogger logger, CancellationToken cancellationToken)
        {
            List<Exception>? exceptions = null;
            List<IReadOnlyDictionary<string, IReadOnlyList<PackageVulnerabilityInfo>>>? knownVulnerabilities = null;

            try
            {
                IReadOnlyList<V3VulnerabilityIndexEntry> indexEntries = await GetVulnerabilityFilesAsync(cacheContext, logger, cancellationToken);

                if (indexEntries.Count == 0)
                {
                    knownVulnerabilities = new();
                    return new GetVulnerabilityInfoResult(knownVulnerabilities, ToAggregateException(exceptions));
                }

                const int maxPages = 16;
                if (indexEntries.Count > maxPages)
                {
                    string message = string.Format(Strings.Vulnerability_TooManyPages, _sourceRepository.PackageSource.Name, indexEntries.Count, maxPages);
                    AddException(new FatalProtocolException(message), ref exceptions);
                    indexEntries = indexEntries.Take(maxPages).ToList();
                }

                // Validate BEFORE downloading anything. GetVulnerabilityDataAsync dereferences the entry's URL and embeds
                // its name in the HTTP cache key. Entries with a missing or relative URL, disallowed name characters, or a
                // duplicate name must therefore never reach it.
                indexEntries = GetValidIndexEntries(indexEntries, ref exceptions);

                // Start every download up front so the pages are fetched concurrently.
                var tasks = new Task<IReadOnlyDictionary<string, IReadOnlyList<PackageVulnerabilityInfo>>>[indexEntries.Count];
                for (int i = 0; i < tasks.Length; i++)
                {
                    tasks[i] = GetVulnerabilityDataAsync(indexEntries[i], cacheContext, logger, cancellationToken);
                }

                // Await each task individually rather than using Task.WhenAll. WhenAll rethrows only the first failure,
                // which would skip collecting the pages that did load and would drop every other failure.
                for (int i = 0; i < tasks.Length; i++)
                {
                    try
                    {
                        IReadOnlyDictionary<string, IReadOnlyList<PackageVulnerabilityInfo>>? data = await tasks[i];
                        if (data != null)
                        {
                            knownVulnerabilities ??= new(indexEntries.Count);
                            knownVulnerabilities.Add(data);
                        }
                    }
                    catch (Exception ex)
                    {
                        AddException(ex, ref exceptions);
                    }
                }
            }
            catch (Exception ex)
            {
                AddException(ex, ref exceptions);
            }

            GetVulnerabilityInfoResult result = new(knownVulnerabilities, ToAggregateException(exceptions));
            return result;

            static AggregateException? ToAggregateException(IEnumerable<Exception>? exceptions)
            {
                return exceptions == null
                    ? null
                    : new AggregateException(exceptions);
            }
        }

        /// <summary>Filters vulnerability index entries down to the ones that are safe to download.</summary>
        /// <param name="indexEntries">The entries as returned by <see cref="GetVulnerabilityFilesAsync"/>.</param>
        /// <param name="exceptions">Receives one <see cref="FatalProtocolException"/> per rejected entry. The list is allocated on first use.</param>
        /// <returns>The accepted entries, in their original order. An entry is accepted when its name is non-blank, at
        /// most 32 characters, made only of ASCII letters, digits, '-' or '_', and not a case-insensitive duplicate of an
        /// earlier accepted name, and when its URL is absolute with an http or https scheme.</returns>
        private IReadOnlyList<V3VulnerabilityIndexEntry> GetValidIndexEntries(IReadOnlyList<V3VulnerabilityIndexEntry> indexEntries, ref List<Exception>? exceptions)
        {
            List<V3VulnerabilityIndexEntry> validIndexEntries = new(indexEntries.Count);
            // Names are restricted to ASCII before they are added, so an ordinal (not culture-aware) case-insensitive
            // comparison is sufficient.
            HashSet<string> pageNames =
#if NETSTANDARD
                new(comparer: StringComparer.OrdinalIgnoreCase);
#else
                new(indexEntries.Count, StringComparer.OrdinalIgnoreCase);
#endif

            for (int i = 0; i < indexEntries.Count; i++)
            {
                V3VulnerabilityIndexEntry entry = indexEntries[i];

                string name = entry.Name;
                if (string.IsNullOrWhiteSpace(name))
                {
                    string message = string.Format(Strings.VulnerabilityPage_HasNoName, i, _sourceRepository.PackageSource.Name);
                    AddException(new FatalProtocolException(message), ref exceptions);
                    continue;
                }

                const int maxNameLength = 32;
                if (name.Length > maxNameLength)
                {
                    string message = string.Format(Strings.VulnerabilityPage_NameTooLong, i, maxNameLength, _sourceRepository.PackageSource.Name);
                    AddException(new FatalProtocolException(message), ref exceptions);
                    continue;
                }

                bool hasOnlyValidChars = true;
                for (int j = 0; j < name.Length; j++)
                {
                    char c = name[j];
                    if (!(c >= 'A' && c <= 'Z')
                        && !(c >= 'a' && c <= 'z')
                        && !(c >= '0' && c <= '9')
                        && !(c == '-' || c == '_'))
                    {
                        hasOnlyValidChars = false;
                        break;
                    }
                }
                if (!hasOnlyValidChars)
                {
                    string message = string.Format(Strings.VulnerabilityPage_NameHasInvalidCharacters, i, _sourceRepository.PackageSource.Name);
                    AddException(new FatalProtocolException(message), ref exceptions);
                    continue;
                }

                if (entry.Url is null)
                {
                    string message = string.Format(Strings.VulnerabilityPage_NoUrl, name, _sourceRepository.PackageSource.Name);
                    AddException(new FatalProtocolException(message), ref exceptions);
                    continue;
                }

                // Uri.Scheme throws InvalidOperationException for relative URIs, so reject those before reading it.
                // Otherwise that exception would escape this method and abort processing of every page.
                if (!entry.Url.IsAbsoluteUri
                    || (entry.Url.Scheme != Uri.UriSchemeHttps && entry.Url.Scheme != Uri.UriSchemeHttp))
                {
                    string message = string.Format(Strings.VulnerabilityPage_UrlNotHttp, name, _sourceRepository.PackageSource.Name);
                    AddException(new FatalProtocolException(message), ref exceptions);
                    continue;
                }

                if (!pageNames.Add(name))
                {
                    string message = string.Format(Strings.VulnerabilityPage_NameNotUnique, name, _sourceRepository.PackageSource.Name);
                    AddException(new FatalProtocolException(message), ref exceptions);
                    continue;
                }

                validIndexEntries.Add(entry);
            }

            return validIndexEntries;
        }

        /// <summary>Appends <paramref name="exception"/> to <paramref name="exceptions"/>, allocating the list on first use.</summary>
        private static void AddException(Exception exception, ref List<Exception>? exceptions)
        {
            exceptions ??= new();
            exceptions.Add(exception);
        }
    }
}