|
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using NuGet.Common;
using NuGet.Protocol.Core.Types;
using NuGet.Protocol.Model;
using NuGet.Protocol.Utility;
namespace NuGet.Protocol.Resources
{
/// <summary>Implementation of <see cref="IVulnerabilityInfoResource"/> for NuGet V3 HTTP feeds.</summary>
/// <remarks>Not intended to be created directly. Use <see cref="SourceRepository.GetResourceAsync{T}(CancellationToken)"/>
/// with <see cref="IVulnerabilityInfoResource"/> for T, and typecast to this class.
/// <para>Implements the <a href="https://learn.microsoft.com/nuget/api/vulnerability-info-resource">VulnerabilityInfo server API resource</a>.</para></remarks>
public sealed class VulnerabilityInfoResourceV3 : IVulnerabilityInfoResource
{
    private readonly SourceRepository _sourceRepository;

    /// <summary>Creates the resource for a source repository.</summary>
    /// <param name="sourceRepository">The source whose service index and HTTP resource are used for all requests.</param>
    /// <exception cref="ArgumentNullException"><paramref name="sourceRepository"/> is <see langword="null"/>.</exception>
    internal VulnerabilityInfoResourceV3(SourceRepository sourceRepository)
    {
        _sourceRepository = sourceRepository ?? throw new ArgumentNullException(nameof(sourceRepository));
    }

    /// <summary>Get the vulnerability pages the server contains.</summary>
    /// <param name="cacheContext">The cache settings to use when making HTTP requests.</param>
    /// <param name="log">The logger for any messages.</param>
    /// <param name="cancellationToken">The cancellation token to cancel operation.</param>
    /// <returns>The list of vulnerability data pages the server claims to have.</returns>
    /// <exception cref="FatalProtocolException">When various HTTP or deserialization exceptions occur.</exception>
    public async Task<IReadOnlyList<V3VulnerabilityIndexEntry>> GetVulnerabilityFilesAsync(SourceCacheContext cacheContext, ILogger log, CancellationToken cancellationToken)
    {
        Uri vulnerabilityIndexUrl = await GetIndexUrlAsync(cancellationToken);
        HttpSourceResource httpSourceResource = await _sourceRepository.GetResourceAsync<HttpSourceResource>(cancellationToken);
        HttpSourceCacheContext httpSourceCacheContext = HttpSourceCacheContext.Create(cacheContext, isFirstAttempt: true);
        var request = new HttpSourceCachedRequest(vulnerabilityIndexUrl.OriginalString, "vuln_index", httpSourceCacheContext);

        IReadOnlyList<V3VulnerabilityIndexEntry>? vulnFiles;
        try
        {
            vulnFiles = await httpSourceResource.HttpSource.GetAsync(request,
                async result =>
                {
                    IReadOnlyList<V3VulnerabilityIndexEntry>? parsed =
                        await JsonSerializer.DeserializeAsync(result.Stream!, JsonContext.Default.VulnerabilityIndex, cancellationToken);
                    return parsed;
                },
                log,
                cancellationToken);
        }
        catch (Exception ex) when (ex is not (FatalProtocolException or OperationCanceledException))
        {
            // Wrap HTTP and System.Text.Json deserialization failures (e.g. JsonException) so the JSON library used
            // internally does not leak through this API. Cancellation is deliberately not wrapped, so a caller that
            // cancels still observes an OperationCanceledException rather than a protocol error.
            throw new FatalProtocolException(message: ex.Message, innerException: ex);
        }

        if (vulnFiles == null)
        {
            // The server returned a JSON "null" document rather than an index.
            string message = string.Format(Strings.VulnerabilityPage_CouldNotLoad, _sourceRepository.PackageSource.Name, vulnerabilityIndexUrl.OriginalString);
            throw new FatalProtocolException(message);
        }

        return vulnFiles;

        // Looks up the VulnerabilityInfo entry in the source's service index.
        async Task<Uri> GetIndexUrlAsync(CancellationToken cancellationToken)
        {
            ServiceIndexResourceV3 serviceIndex = await _sourceRepository.GetResourceAsync<ServiceIndexResourceV3>(cancellationToken);
            return serviceIndex.GetServiceEntryUri(ServiceTypes.VulnerabilityInfo);
        }
    }

    /// <summary>Get the known vulnerability data for a single page from the server.</summary>
    /// <param name="vulnerabilityPage">The page to get data from. Its name becomes part of the HTTP cache key.</param>
    /// <param name="cacheContext">The cache settings if HTTP requests are made.</param>
    /// <param name="logger">The logger for messages.</param>
    /// <param name="cancellationToken">The cancellation token to cancel operation.</param>
    /// <returns>The known vulnerabilities defined in the file.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="vulnerabilityPage"/> is <see langword="null"/>.</exception>
    /// <exception cref="FatalProtocolException">If various HTTP or deserialization exceptions occur.</exception>
    public async Task<IReadOnlyDictionary<string, IReadOnlyList<PackageVulnerabilityInfo>>> GetVulnerabilityDataAsync(
        V3VulnerabilityIndexEntry vulnerabilityPage,
        SourceCacheContext cacheContext,
        ILogger logger,
        CancellationToken cancellationToken)
    {
        if (vulnerabilityPage == null)
        {
            throw new ArgumentNullException(nameof(vulnerabilityPage));
        }

        HttpSourceResource httpSourceResource = await _sourceRepository.GetResourceAsync<HttpSourceResource>(cancellationToken);
        HttpSourceCacheContext httpSourceCacheContext = HttpSourceCacheContext.Create(cacheContext, isFirstAttempt: true);
        var request = new HttpSourceCachedRequest(vulnerabilityPage.Url.OriginalString, "vuln_data_" + vulnerabilityPage.Name, httpSourceCacheContext);

        IReadOnlyDictionary<string, IReadOnlyList<PackageVulnerabilityInfo>>? data;
        try
        {
            data = await httpSourceResource.HttpSource.GetAsync(request,
                async result =>
                {
                    CaseInsensitiveDictionary<IReadOnlyList<PackageVulnerabilityInfo>>? parsed =
                        await JsonSerializer.DeserializeAsync(result.Stream!, JsonContext.Default.VulnerabilityPage, cancellationToken);
                    return parsed;
                },
                logger,
                cancellationToken);
        }
        catch (Exception ex) when (ex is not (FatalProtocolException or OperationCanceledException))
        {
            // Wrap HTTP and System.Text.Json deserialization failures (e.g. JsonException) so the JSON library used
            // internally does not leak through this API. Cancellation is deliberately not wrapped, so a caller that
            // cancels still observes an OperationCanceledException rather than a protocol error.
            throw new FatalProtocolException(message: ex.Message, innerException: ex);
        }

        if (data == null)
        {
            // The server returned a JSON "null" document rather than a page.
            string message = string.Format(Strings.VulnerabilityPage_CouldNotLoad, _sourceRepository.PackageSource.Name, vulnerabilityPage.Url.OriginalString);
            throw new FatalProtocolException(message);
        }

        return data;
    }

    /// <inheritdoc cref="IVulnerabilityInfoResource.GetVulnerabilityInfoAsync(SourceCacheContext, ILogger, CancellationToken)"/>
    public async Task<GetVulnerabilityInfoResult> GetVulnerabilityInfoAsync(SourceCacheContext cacheContext, ILogger logger, CancellationToken cancellationToken)
    {
        List<Exception>? exceptions = null;
        List<IReadOnlyDictionary<string, IReadOnlyList<PackageVulnerabilityInfo>>>? knownVulnerabilities = null;

        try
        {
            IReadOnlyList<V3VulnerabilityIndexEntry> indexEntries = await GetVulnerabilityFilesAsync(cacheContext, logger, cancellationToken);
            if (indexEntries.Count == 0)
            {
                knownVulnerabilities = new();
                return new GetVulnerabilityInfoResult(knownVulnerabilities, ToAggregateException(exceptions));
            }

            const int maxPages = 16;
            if (indexEntries.Count > maxPages)
            {
                string message = string.Format(Strings.Vulnerability_TooManyPages, _sourceRepository.PackageSource.Name, indexEntries.Count, maxPages);
                AddException(new FatalProtocolException(message), ref exceptions);
                indexEntries = indexEntries.Take(maxPages).ToList();
            }

            // Validate BEFORE downloading anything. Entries with a missing or non-HTTP(S) URL, or with an invalid or
            // duplicate name, must never be requested. The page name is also part of the HTTP cache key.
            indexEntries = GetValidIndexEntries(indexEntries, ref exceptions);

            // Start every page download up front so they run concurrently.
            var tasks = new Task<IReadOnlyDictionary<string, IReadOnlyList<PackageVulnerabilityInfo>>>[indexEntries.Count];
            for (int i = 0; i < tasks.Length; i++)
            {
                tasks[i] = GetVulnerabilityDataAsync(indexEntries[i], cacheContext, logger, cancellationToken);
            }

            // Await each page individually instead of using Task.WhenAll. WhenAll rethrows the first failure, which
            // would discard the data of pages that succeeded and hide the remaining errors. Awaiting every task here
            // also ensures no faulted task goes unobserved.
            for (int i = 0; i < tasks.Length; i++)
            {
                try
                {
                    IReadOnlyDictionary<string, IReadOnlyList<PackageVulnerabilityInfo>>? data = await tasks[i];
                    if (data != null)
                    {
                        knownVulnerabilities ??= new(tasks.Length);
                        knownVulnerabilities.Add(data);
                    }
                }
                catch (Exception ex)
                {
                    AddException(ex, ref exceptions);
                }
            }
        }
        catch (Exception ex)
        {
            AddException(ex, ref exceptions);
        }

        GetVulnerabilityInfoResult result = new(knownVulnerabilities, ToAggregateException(exceptions));
        return result;

        // Returns null when there were no errors, otherwise an AggregateException that holds all of them.
        static AggregateException? ToAggregateException(IEnumerable<Exception>? exceptions)
        {
            AggregateException? aggregateException =
                exceptions == null
                    ? null
                    : new AggregateException(exceptions);
            return aggregateException;
        }
    }

    /// <summary>
    /// Filters out index entries that must not be downloaded, recording a <see cref="FatalProtocolException"/> for each
    /// rejected entry.
    /// </summary>
    /// <param name="indexEntries">The entries read from the server's vulnerability index.</param>
    /// <param name="exceptions">Error list that is created on demand and appended to for each rejected entry.</param>
    /// <returns>The entries that pass validation, in their original order.</returns>
    private IReadOnlyList<V3VulnerabilityIndexEntry> GetValidIndexEntries(IReadOnlyList<V3VulnerabilityIndexEntry> indexEntries, ref List<Exception>? exceptions)
    {
        List<V3VulnerabilityIndexEntry> validIndexEntries = new(indexEntries.Count);

        // Valid names are restricted to ASCII letters, digits, '-' and '_' (checked below). An ordinal
        // case-insensitive comparison is therefore exact. It also matches case-insensitive file systems, where the
        // cache keys derived from these names would otherwise collide.
        HashSet<string> pageNames =
#if NETSTANDARD
            new(comparer: StringComparer.OrdinalIgnoreCase);
#else
            new(indexEntries.Count, StringComparer.OrdinalIgnoreCase);
#endif

        for (int i = 0; i < indexEntries.Count; i++)
        {
            V3VulnerabilityIndexEntry entry = indexEntries[i];
            string name = entry.Name;

            if (string.IsNullOrWhiteSpace(name))
            {
                string message = string.Format(Strings.VulnerabilityPage_HasNoName, i, _sourceRepository.PackageSource.Name);
                AddException(new FatalProtocolException(message), ref exceptions);
                continue;
            }

            const int maxNameLength = 32;
            if (name.Length > maxNameLength)
            {
                string message = string.Format(Strings.VulnerabilityPage_NameTooLong, i, maxNameLength, _sourceRepository.PackageSource.Name);
                AddException(new FatalProtocolException(message), ref exceptions);
                continue;
            }

            bool hasOnlyValidChars = true;
            for (int j = 0; j < name.Length; j++)
            {
                char c = name[j];
                if (!(c >= 'A' && c <= 'Z')
                    && !(c >= 'a' && c <= 'z')
                    && !(c >= '0' && c <= '9')
                    && !(c == '-' || c == '_'))
                {
                    hasOnlyValidChars = false;
                    break;
                }
            }

            if (!hasOnlyValidChars)
            {
                string message = string.Format(Strings.VulnerabilityPage_NameHasInvalidCharacters, i, _sourceRepository.PackageSource.Name);
                AddException(new FatalProtocolException(message), ref exceptions);
                continue;
            }

            if (entry.Url == null)
            {
                string message = string.Format(Strings.VulnerabilityPage_NoUrl, name, _sourceRepository.PackageSource.Name);
                AddException(new FatalProtocolException(message), ref exceptions);
                continue;
            }

            // Uri.Scheme throws InvalidOperationException for relative URIs, so check IsAbsoluteUri first. A relative
            // URL is rejected the same way as a non-HTTP(S) one.
            if (!entry.Url.IsAbsoluteUri
                || (entry.Url.Scheme != "https" && entry.Url.Scheme != "http"))
            {
                string message = string.Format(Strings.VulnerabilityPage_UrlNotHttp, name, _sourceRepository.PackageSource.Name);
                AddException(new FatalProtocolException(message), ref exceptions);
                continue;
            }

            if (!pageNames.Add(name))
            {
                string message = string.Format(Strings.VulnerabilityPage_NameNotUnique, name, _sourceRepository.PackageSource.Name);
                AddException(new FatalProtocolException(message), ref exceptions);
                continue;
            }

            validIndexEntries.Add(entry);
        }

        return validIndexEntries;
    }

    /// <summary>Appends <paramref name="exception"/> to <paramref name="exceptions"/>, creating the list on first use.</summary>
    private static void AddException(Exception exception, ref List<Exception>? exceptions)
    {
        exceptions ??= new();
        exceptions.Add(exception);
    }
}
}
|