|
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Newtonsoft.Json.Linq;
using NuGet.Common;
using NuGet.Configuration;
using NuGet.Protocol.Core.Types;
namespace NuGet.Protocol
{
public class HttpSource : IDisposable
{
private readonly Func<Task<HttpHandlerResource>> _messageHandlerFactory;
private readonly Uri _sourceUri;
private HttpClient? _httpClient;
private string? _httpCacheDirectory;
private readonly PackageSource _packageSource;
private readonly IThrottle _throttle;
private bool _disposed = false;
// Only one thread may re-create the http client at a time.
private readonly SemaphoreSlim _httpClientLock = new SemaphoreSlim(1, 1);
/// <summary>The retry handler to use for all HTTP requests.</summary>
/// <summary>This API is intended only for testing purposes and should not be used in product code.</summary>
public IHttpRetryHandler RetryHandler { get; set; } = new HttpRetryHandler();
public string PackageSource => _packageSource.Source;
public HttpSource(
PackageSource packageSource,
Func<Task<HttpHandlerResource>> messageHandlerFactory,
IThrottle throttle)
{
if (packageSource == null)
{
throw new ArgumentNullException(nameof(packageSource));
}
if (messageHandlerFactory == null)
{
throw new ArgumentNullException(nameof(messageHandlerFactory));
}
if (throttle == null)
{
throw new ArgumentNullException(nameof(throttle));
}
_packageSource = packageSource;
_sourceUri = packageSource.SourceUri;
_messageHandlerFactory = messageHandlerFactory;
_throttle = throttle;
}
/// <summary>
/// Caching Get request.
/// </summary>
public virtual async Task<T> GetAsync<T>(
HttpSourceCachedRequest request,
Func<HttpSourceResult, Task<T>> processAsync,
ILogger log,
CancellationToken token)
{
ThrowIfHttpUriAndInsecureConnectionsNotAllowed(request.Uri);
var cacheResult = HttpCacheUtility.InitializeHttpCacheResult(
HttpCacheDirectory,
_sourceUri,
request.CacheKey,
request.CacheContext);
return await ConcurrencyUtilities.ExecuteWithFileLockedAsync(
cacheResult.CacheFile,
action: async lockedToken =>
{
cacheResult.Stream = TryReadCacheFile(request.Uri, cacheResult.MaxAge, cacheResult.CacheFile);
try
{
if (cacheResult.Stream != null)
{
log.LogInformation(string.Format(CultureInfo.InvariantCulture, " " + Strings.Http_RequestLog, "CACHE", request.Uri));
// Validate the content fetched from the cache.
try
{
request.EnsureValidContents?.Invoke(cacheResult.Stream);
cacheResult.Stream.Seek(0, SeekOrigin.Begin);
var httpSourceResult = new HttpSourceResult(
HttpSourceResultStatus.OpenedFromDisk,
cacheResult.CacheFile,
cacheResult.Stream);
return await processAsync(httpSourceResult);
}
catch (Exception e)
{
cacheResult.Stream.Dispose();
cacheResult.Stream = null;
string message = string.Format(CultureInfo.CurrentCulture, Strings.Log_InvalidCacheEntry, request.Uri)
+ Environment.NewLine
+ ExceptionUtilities.DisplayMessage(e);
log.LogWarning(message);
}
}
Func<HttpRequestMessage> requestFactory = () =>
{
var requestMessage = HttpRequestMessageFactory.Create(HttpMethod.Get, request.Uri, log);
foreach (var acceptHeaderValue in request.AcceptHeaderValues)
{
requestMessage.Headers.Accept.Add(acceptHeaderValue);
}
return requestMessage;
};
Func<Task<ThrottledResponse>> throttledResponseFactory = () => GetThrottledResponse(
requestFactory,
request.RequestTimeout,
request.DownloadTimeout,
request.MaxTries,
request.IsRetry,
request.IsLastAttempt,
request.CacheContext.SourceCacheContext.SessionId,
log,
lockedToken);
using (var throttledResponse = await throttledResponseFactory())
{
if (request.IgnoreNotFounds && throttledResponse.Response.StatusCode == HttpStatusCode.NotFound)
{
var httpSourceResult = new HttpSourceResult(HttpSourceResultStatus.NotFound);
return await processAsync(httpSourceResult);
}
if (throttledResponse.Response.StatusCode == HttpStatusCode.NoContent)
{
// Ignore reading and caching the empty stream.
var httpSourceResult = new HttpSourceResult(HttpSourceResultStatus.NoContent);
return await processAsync(httpSourceResult);
}
throttledResponse.Response.EnsureSuccessStatusCode();
if (!request.CacheContext.DirectDownload)
{
await HttpCacheUtility.CreateCacheFileAsync(
cacheResult,
throttledResponse.Response,
request.EnsureValidContents,
lockedToken);
using (var httpSourceResult = new HttpSourceResult(
HttpSourceResultStatus.OpenedFromDisk,
cacheResult.CacheFile,
cacheResult.Stream!)) // Stream is set by CreateCacheFileAsync above
{
return await processAsync(httpSourceResult);
}
}
else
{
// Note that we do not execute the content validator on the response stream when skipping
// the cache. We cannot seek on the network stream and it is not valuable to download the
// content twice just to validate the first time (considering that the second download could
// be different from the first thus rendering the first validation meaningless).
#if NETCOREAPP2_0_OR_GREATER
using (var stream = await throttledResponse.Response.Content.ReadAsStreamAsync(lockedToken))
#else
using (var stream = await throttledResponse.Response.Content.ReadAsStreamAsync())
#endif
using (var httpSourceResult = new HttpSourceResult(
HttpSourceResultStatus.OpenedFromNetwork,
cacheFileName: null,
stream: stream))
{
return await processAsync(httpSourceResult);
}
}
}
}
finally
{
if (cacheResult.Stream != null)
{
cacheResult.Stream.Dispose();
}
}
},
token: token);
}
public Task<T> ProcessStreamAsync<T>(
HttpSourceRequest request,
Func<Stream?, Task<T>> processAsync,
ILogger log,
CancellationToken token)
{
return ProcessStreamAsync<T>(request, processAsync, cacheContext: null, log: log, token: token);
}
internal async Task<T> ProcessHttpStreamAsync<T>(
HttpSourceRequest request,
Func<HttpResponseMessage?, Task<T>> processAsync,
ILogger log,
CancellationToken token)
{
// RequestUri is always set for NuGet HTTP requests
ThrowIfHttpUriAndInsecureConnectionsNotAllowed(request.RequestFactory().RequestUri!.AbsoluteUri);
return await ProcessResponseAsync(
request,
async response =>
{
if ((request.IgnoreNotFounds && response.StatusCode == HttpStatusCode.NotFound) ||
response.StatusCode == HttpStatusCode.NoContent)
{
return await processAsync(null);
}
response.EnsureSuccessStatusCode();
return await processAsync(response);
},
cacheContext: null,
log,
token);
}
public async Task<T> ProcessStreamAsync<T>(
HttpSourceRequest request,
Func<Stream?, Task<T>> processAsync,
SourceCacheContext? cacheContext,
ILogger log,
CancellationToken token)
{
return await ProcessResponseAsync(
request,
async response =>
{
if ((request.IgnoreNotFounds && response.StatusCode == HttpStatusCode.NotFound) ||
response.StatusCode == HttpStatusCode.NoContent)
{
return await processAsync(null);
}
response.EnsureSuccessStatusCode();
var networkStream = await response.Content.ReadAsStreamAsync();
return await processAsync(networkStream);
},
cacheContext,
log,
token);
}
public Task<T> ProcessResponseAsync<T>(
HttpSourceRequest request,
Func<HttpResponseMessage, Task<T>> processAsync,
ILogger log,
CancellationToken token)
{
return ProcessResponseAsync(request, processAsync, cacheContext: null, log: log, token: token);
}
public async Task<T> ProcessResponseAsync<T>(
HttpSourceRequest request,
Func<HttpResponseMessage, Task<T>> processAsync,
SourceCacheContext? cacheContext,
ILogger log,
CancellationToken token)
{
// Generate a new session id if no cache context was provided.
var sessionId = cacheContext?.SessionId ?? Guid.NewGuid();
Task<ThrottledResponse> throttledResponseFactory() => GetThrottledResponse(
request.RequestFactory,
request.RequestTimeout,
request.DownloadTimeout,
request.MaxTries,
request.IsRetry,
request.IsLastAttempt,
sessionId,
log,
token);
using (var throttledResponse = await throttledResponseFactory())
{
return await processAsync(throttledResponse.Response);
}
}
public async Task<JObject?> GetJObjectAsync(HttpSourceRequest request, ILogger log, CancellationToken token)
{
return await ProcessStreamAsync(
request,
processAsync: stream =>
{
if (stream == null)
{
return TaskResult.Null<JObject>();
}
return stream.AsJObjectAsync(token);
},
log: log,
token: token);
}
private async Task<ThrottledResponse> GetThrottledResponse(
Func<HttpRequestMessage> requestFactory,
TimeSpan requestTimeout,
TimeSpan downloadTimeout,
int maxTries,
bool isRetry,
bool isLastAttempt,
Guid sessionId,
ILogger log,
CancellationToken cancellationToken)
{
HttpClient httpClient = await GetHttpClientAsync();
var request = new HttpRetryHandlerRequest(httpClient, requestFactory)
{
RequestTimeout = requestTimeout,
DownloadTimeout = downloadTimeout,
MaxTries = maxTries,
IsRetry = isRetry,
IsLastAttempt = isLastAttempt
};
// Add X-NuGet-Session-Id to all outgoing requests. This allows feeds to track nuget operations.
request.AddHeaders.Add(new KeyValuePair<string, IEnumerable<string>>(ProtocolConstants.SessionId, new[] { sessionId.ToString() }));
// Acquire the semaphore.
await _throttle.WaitAsync();
HttpResponseMessage response;
try
{
response = await RetryHandler.SendAsync(request, _packageSource.SourceUri.OriginalString, log, cancellationToken);
}
catch
{
// If the request fails, release the semaphore. If no exception is thrown by
// SendAsync, then the semaphore is released when the HTTP response message is
// disposed.
_throttle.Release();
throw;
}
return new ThrottledResponse(_throttle, response);
}
private async Task<HttpClient> GetHttpClientAsync()
{
// Create the http client on the first call
if (_httpClient == null)
{
await _httpClientLock.WaitAsync();
try
{
// Double check
if (_httpClient == null)
{
_httpClient = await CreateHttpClientAsync();
}
}
finally
{
_httpClientLock.Release();
}
}
return _httpClient;
}
private async Task<HttpClient> CreateHttpClientAsync()
{
var httpHandler = await _messageHandlerFactory();
var httpClient = new HttpClient(httpHandler.MessageHandler)
{
Timeout = Timeout.InfiniteTimeSpan
};
// Set user agent
UserAgent.SetUserAgent(httpClient);
// Set accept-language header
string acceptLanguage = CultureInfo.CurrentUICulture.ToString();
if (!string.IsNullOrEmpty(acceptLanguage))
{
httpClient.DefaultRequestHeaders.AcceptLanguage.ParseAdd(acceptLanguage);
}
return httpClient;
}
public string HttpCacheDirectory
{
get
{
if (_httpCacheDirectory == null)
{
_httpCacheDirectory = SettingsUtility.GetHttpCacheFolder();
}
return _httpCacheDirectory;
}
set { _httpCacheDirectory = value; }
}
protected virtual Stream? TryReadCacheFile(string uri, TimeSpan maxAge, string cacheFile)
{
// Do not need the uri here
return CachingUtility.ReadCacheFile(maxAge, cacheFile);
}
public static HttpSource Create(SourceRepository source)
{
return Create(source, NullThrottle.Instance);
}
public static HttpSource Create(SourceRepository source, IThrottle throttle)
{
if (source == null)
{
throw new ArgumentNullException(nameof(source));
}
if (throttle == null)
{
throw new ArgumentNullException(nameof(throttle));
}
Func<Task<HttpHandlerResource>> factory = () => source.GetResourceAsync<HttpHandlerResource>(CancellationToken.None);
return new HttpSource(source.PackageSource, factory, throttle);
}
public void Dispose()
{
Dispose(true);
GC.SuppressFinalize(this);
}
protected virtual void Dispose(bool disposing)
{
if (_disposed)
{
return;
}
if (disposing)
{
if (_httpClient != null)
{
_httpClient.Dispose();
}
_httpClientLock.Dispose();
}
_disposed = true;
}
private class ThrottledResponse : IDisposable
{
private IThrottle? _throttle;
public ThrottledResponse(IThrottle throttle, HttpResponseMessage response)
{
if (throttle == null)
{
throw new ArgumentNullException(nameof(throttle));
}
if (response == null)
{
throw new ArgumentNullException(nameof(response));
}
_throttle = throttle;
Response = response;
}
public HttpResponseMessage Response { get; }
public void Dispose()
{
try
{
Response.Dispose();
}
finally
{
Interlocked.Exchange(ref _throttle, null)?.Release();
}
}
}
private void ThrowIfHttpUriAndInsecureConnectionsNotAllowed(string uri)
{
if (uri.StartsWith("http://", StringComparison.OrdinalIgnoreCase))
{
if (_packageSource.IsHttps && !_packageSource.AllowInsecureConnections)
{
throw new HttpSourceException(
string.Format(
CultureInfo.CurrentCulture,
Strings.Error_Insecure_HTTP,
_sourceUri.AbsoluteUri ?? "<unknown>",
uri));
}
}
}
}
}
|