// File: AcrLoginService.cs
// Web Access
// Project: src\src\Aspire.Hosting.Azure\Aspire.Hosting.Azure.csproj (Aspire.Hosting.Azure)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
#pragma warning disable ASPIRECONTAINERRUNTIME001
 
using System.Text.Json;
using System.Text.Json.Serialization;
using Aspire.Hosting.Publishing;
using Azure.Core;
using Microsoft.Extensions.Logging;
 
namespace Aspire.Hosting.Azure;
 
/// <summary>
/// Default implementation of <see cref="IAcrLoginService"/> that handles ACR authentication
/// using Azure credentials and OAuth2 token exchange.
/// </summary>
/// <remarks>
/// Login is a three-step flow: acquire an AAD access token for the ACR scope, exchange it at the
/// registry's <c>/oauth2/exchange</c> endpoint for an ACR refresh token, then hand that refresh
/// token to the container runtime as the registry password.
/// </remarks>
internal sealed class AcrLoginService : IAcrLoginService
{
    // Username passed to the container runtime together with the ACR refresh token as the password.
    private const string AcrUsername = "00000000-0000-0000-0000-000000000000";

    // Scope requested from the credential when acquiring the AAD access token.
    private const string AcrScope = "https://containerregistry.azure.net/.default";

    // Name of the HTTP client requested from the factory. Its log category
    // ("System.Net.Http.HttpClient.AcrLogin") can be raised to Debug via configuration.
    private const string HttpClientName = "AcrLogin";

    // Upper bound on how much of a failed response body is echoed into the exception message.
    private const int MaxErrorBodyLength = 1000;

    // Timeout applied to the token exchange request.
    private static readonly TimeSpan s_exchangeTimeout = TimeSpan.FromSeconds(30);

    private static readonly JsonSerializerOptions s_jsonOptions = new(JsonSerializerDefaults.Web)
    {
        PropertyNameCaseInsensitive = true
    };

    private readonly IHttpClientFactory _httpClientFactory;
    private readonly IContainerRuntime _containerRuntime;
    private readonly ILogger<AcrLoginService> _logger;

    /// <summary>
    /// Shape of the JSON payload returned by the <c>/oauth2/exchange</c> endpoint.
    /// Only <see cref="RefreshToken"/> is consumed by this service.
    /// </summary>
    private sealed class AcrRefreshTokenResponse
    {
        [JsonPropertyName("refresh_token")]
        public string? RefreshToken { get; set; }

        [JsonPropertyName("expires_in")]
        public int? ExpiresIn { get; set; }
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="AcrLoginService"/> class.
    /// </summary>
    /// <param name="httpClientFactory">The HTTP client factory for making OAuth2 exchange requests.</param>
    /// <param name="containerRuntime">The container runtime for performing registry login.</param>
    /// <param name="logger">The logger for diagnostic output.</param>
    /// <exception cref="ArgumentNullException">Any argument is <see langword="null"/>.</exception>
    public AcrLoginService(IHttpClientFactory httpClientFactory, IContainerRuntime containerRuntime, ILogger<AcrLoginService> logger)
    {
        _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory));
        _containerRuntime = containerRuntime ?? throw new ArgumentNullException(nameof(containerRuntime));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
    }

    /// <inheritdoc/>
    public async Task LoginAsync(
        string registryEndpoint,
        string tenantId,
        TokenCredential credential,
        CancellationToken cancellationToken = default)
    {
        ArgumentException.ThrowIfNullOrWhiteSpace(registryEndpoint);
        ArgumentException.ThrowIfNullOrWhiteSpace(tenantId);
        ArgumentNullException.ThrowIfNull(credential);

        // Step 1: Acquire AAD access token for ACR audience
        var tokenRequestContext = new TokenRequestContext([AcrScope]);
        var aadToken = await credential.GetTokenAsync(tokenRequestContext, cancellationToken).ConfigureAwait(false);

        // Only the token length is logged; the token itself is a secret.
        _logger.LogDebug("AAD access token acquired for ACR audience, registry: {RegistryEndpoint}, token length: {TokenLength}",
            registryEndpoint, aadToken.Token.Length);

        // Step 2: Exchange AAD token for ACR refresh token
        var refreshToken = await ExchangeAadTokenForAcrRefreshTokenAsync(
            registryEndpoint, tenantId, aadToken.Token, cancellationToken).ConfigureAwait(false);

        _logger.LogDebug("ACR refresh token acquired, length: {TokenLength}", refreshToken.Length);

        // Step 3: Login to the registry using container runtime
        await _containerRuntime.LoginToRegistryAsync(registryEndpoint, AcrUsername, refreshToken, cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Posts the AAD access token to the registry's <c>/oauth2/exchange</c> endpoint and returns
    /// the <c>refresh_token</c> value from the JSON response.
    /// </summary>
    /// <param name="registryEndpoint">Registry host name; used both to build the URL and as the <c>service</c> form field.</param>
    /// <param name="tenantId">Tenant identifier sent as the <c>tenant</c> form field.</param>
    /// <param name="aadAccessToken">AAD access token sent as the <c>access_token</c> form field.</param>
    /// <param name="cancellationToken">Token used to cancel the HTTP request and body read.</param>
    /// <returns>The non-empty refresh token from the response.</returns>
    /// <exception cref="HttpRequestException">The endpoint returned a non-success status code.</exception>
    /// <exception cref="InvalidOperationException">The response did not contain a <c>refresh_token</c>.</exception>
    private async Task<string> ExchangeAadTokenForAcrRefreshTokenAsync(
        string registryEndpoint,
        string tenantId,
        string aadAccessToken,
        CancellationToken cancellationToken)
    {
        // Use named HTTP client "AcrLogin" which can be configured for debug-level logging
        // via configuration: "Logging": { "LogLevel": { "System.Net.Http.HttpClient.AcrLogin": "Debug" } }
        var httpClient = _httpClientFactory.CreateClient(HttpClientName);
        httpClient.Timeout = s_exchangeTimeout;

        // ACR OAuth2 exchange endpoint
        var exchangeUrl = $"https://{registryEndpoint}/oauth2/exchange";

        _logger.LogDebug("Exchanging AAD token for ACR refresh token at {ExchangeUrl} (tenant: {TenantId})",
            exchangeUrl,
            tenantId);

        var formData = new Dictionary<string, string>
        {
            ["grant_type"] = "access_token",
            ["service"] = registryEndpoint,
            ["tenant"] = tenantId,
            ["access_token"] = aadAccessToken
        };

        using var content = new FormUrlEncodedContent(formData);

        // Dispose the response so its content stream and underlying connection are released
        // deterministically on every path, including the exception paths below.
        using var response = await httpClient.PostAsync(exchangeUrl, content, cancellationToken).ConfigureAwait(false);

        // Read response body as string once; it is used for both the error message and deserialization.
        var responseBody = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);

        if (!response.IsSuccessStatusCode)
        {
            // Cap the echoed body so an unexpectedly large error page does not bloat the exception message.
            var truncatedBody = responseBody.Length <= MaxErrorBodyLength
                ? responseBody
                : responseBody[..MaxErrorBodyLength] + "…";

            throw new HttpRequestException(
                $"POST /oauth2/exchange failed {(int)response.StatusCode} {response.ReasonPhrase}. Body: {truncatedBody}",
                null,
                response.StatusCode);
        }

        // Deserialize from the string we already read
        var tokenResponse = JsonSerializer.Deserialize<AcrRefreshTokenResponse>(responseBody, s_jsonOptions);

        if (string.IsNullOrEmpty(tokenResponse?.RefreshToken))
        {
            throw new InvalidOperationException("Response missing refresh_token.");
        }

        return tokenResponse.RefreshToken;
    }
}