File: DefaultUserPrincipalProviderTests.cs
Web Access
Project: src\tests\Aspire.Hosting.Azure.Tests\Aspire.Hosting.Azure.Tests.csproj (Aspire.Hosting.Azure.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Text;
using System.Text.Json;
using Aspire.Hosting.Azure.Provisioning.Internal;
using Azure.Core;
 
namespace Aspire.Hosting.Azure.Tests;
 
/// <summary>
/// Tests for <see cref="DefaultUserPrincipalProvider"/>, which resolves a user principal
/// (object id and name) from the access token obtained through an <see cref="ITokenCredentialProvider"/>.
/// </summary>
public class DefaultUserPrincipalProviderTests
{
    [Fact]
    public async Task GetUserPrincipalAsync_ReturnsValidUserPrincipal()
    {
        // Arrange
        // The expected oid/name below must match the token issued by the shared test helper's credential.
        var tokenCredentialProvider = ProvisioningTestHelpers.CreateTokenCredentialProvider();
        var provider = new DefaultUserPrincipalProvider(tokenCredentialProvider);

        // Act
        var principal = await provider.GetUserPrincipalAsync();

        // Assert
        Assert.NotNull(principal);
        Assert.Equal(Guid.Parse("11111111-2222-3333-4444-555555555555"), principal.Id);
        Assert.Equal("test@example.com", principal.Name);
    }

    [Fact]
    public async Task GetUserPrincipalAsync_ParsesTokenCorrectly()
    {
        // Arrange
        var expectedOid = Guid.NewGuid();
        var expectedUpn = "test@example.com";
        var token = CreateTestToken(expectedOid, expectedUpn);
        var tokenCredentialProvider = new TestTokenCredentialProviderWithCustomToken(token);
        var provider = new DefaultUserPrincipalProvider(tokenCredentialProvider);

        // Act
        var principal = await provider.GetUserPrincipalAsync();

        // Assert
        Assert.Equal(expectedOid, principal.Id);
        Assert.Equal(expectedUpn, principal.Name);
    }

    [Fact]
    public async Task GetUserPrincipalAsync_ParsesTokenWithEmail()
    {
        // Arrange
        // The token carries an "email" claim instead of "upn"; the name should still be resolved.
        var expectedOid = Guid.NewGuid();
        var expectedEmail = "user@company.com";
        var token = CreateTestTokenWithEmail(expectedOid, expectedEmail);
        var tokenCredentialProvider = new TestTokenCredentialProviderWithCustomToken(token);
        var provider = new DefaultUserPrincipalProvider(tokenCredentialProvider);

        // Act
        var principal = await provider.GetUserPrincipalAsync();

        // Assert
        Assert.Equal(expectedOid, principal.Id);
        Assert.Equal(expectedEmail, principal.Name);
    }

    [Fact]
    public async Task GetUserPrincipalAsync_HandlesCancellation()
    {
        // Arrange
        var tokenCredentialProvider = ProvisioningTestHelpers.CreateTokenCredentialProvider();
        var provider = new DefaultUserPrincipalProvider(tokenCredentialProvider);
        using var cts = new CancellationTokenSource();
        cts.Cancel();

        // Act & Assert
        await Assert.ThrowsAsync<OperationCanceledException>(
            () => provider.GetUserPrincipalAsync(cts.Token));
    }

    [Fact]
    public async Task GetUserPrincipalAsync_RespectsTokenCredentialProviderDependency()
    {
        // Arrange
        var customTokenCredential = new TestTokenCredential();
        var tokenCredentialProvider = new TestTokenCredentialProviderWithCredential(customTokenCredential);
        var provider = new DefaultUserPrincipalProvider(tokenCredentialProvider);

        // Act
        var principal = await provider.GetUserPrincipalAsync();

        // Assert
        Assert.NotNull(principal);
        // Verify the provider used the injected token credential...
        Assert.True(customTokenCredential.GetTokenCalled);
        // ...and that the principal was parsed from the token that credential issued.
        Assert.Equal(TestTokenCredential.Upn, principal.Name);
    }

    /// <summary>
    /// Builds a test JWT whose claims are <c>oid</c>, <c>upn</c>, <c>exp</c> (now + 1 hour) and <c>iat</c> (now).
    /// </summary>
    /// <param name="oid">Value written to the <c>oid</c> claim.</param>
    /// <param name="upn">Value written to the <c>upn</c> claim.</param>
    /// <returns>A compact <c>header.payload.signature</c> token string.</returns>
    private static string CreateTestToken(Guid oid, string upn)
    {
        var now = DateTimeOffset.UtcNow;
        return CreateJwt(new
        {
            oid = oid.ToString(),
            upn,
            exp = now.AddHours(1).ToUnixTimeSeconds(),
            iat = now.ToUnixTimeSeconds()
        });
    }

    /// <summary>
    /// Builds a test JWT whose claims are <c>oid</c>, <c>email</c>, <c>exp</c> (now + 1 hour) and <c>iat</c> (now).
    /// No <c>upn</c> claim is present.
    /// </summary>
    /// <param name="oid">Value written to the <c>oid</c> claim.</param>
    /// <param name="email">Value written to the <c>email</c> claim.</param>
    /// <returns>A compact <c>header.payload.signature</c> token string.</returns>
    private static string CreateTestTokenWithEmail(Guid oid, string email)
    {
        var now = DateTimeOffset.UtcNow;
        return CreateJwt(new
        {
            oid = oid.ToString(),
            email,
            exp = now.AddHours(1).ToUnixTimeSeconds(),
            iat = now.ToUnixTimeSeconds()
        });
    }

    /// <summary>
    /// Serializes <paramref name="payload"/> as the claims segment of a compact JWT.
    /// Every segment is base64url-encoded (no padding), as compact JWS serialization requires.
    /// The signature segment is a fixed placeholder, not a real RS256 signature.
    /// </summary>
    /// <param name="payload">Object whose public properties become the token's claims.</param>
    private static string CreateJwt(object payload)
    {
        var header = Base64UrlEncode(JsonSerializer.SerializeToUtf8Bytes(new { alg = "RS256", typ = "JWT" }));
        // Serialize with the runtime type so anonymous-type properties are emitted as claims.
        var claims = Base64UrlEncode(JsonSerializer.SerializeToUtf8Bytes(payload, payload.GetType()));
        var signature = Base64UrlEncode(Encoding.UTF8.GetBytes("test-signature"));

        return $"{header}.{claims}.{signature}";
    }

    /// <summary>Encodes <paramref name="bytes"/> as unpadded base64url (RFC 4648 §5).</summary>
    private static string Base64UrlEncode(byte[] bytes) =>
        Convert.ToBase64String(bytes)
            .TrimEnd('=')
            .Replace('+', '-')
            .Replace('/', '_');

    /// <summary>
    /// Provider that returns a new <see cref="TestTokenCredentialWithCustomToken"/> for the given raw token on each access.
    /// </summary>
    private sealed class TestTokenCredentialProviderWithCustomToken(string token) : ITokenCredentialProvider
    {
        public TokenCredential TokenCredential => new TestTokenCredentialWithCustomToken(token);
    }

    /// <summary>Credential that always returns the supplied raw token, expiring one hour from the call.</summary>
    private sealed class TestTokenCredentialWithCustomToken(string token) : TokenCredential
    {
        public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
        {
            // Honor cancellation on the synchronous path too, so both overloads behave alike.
            cancellationToken.ThrowIfCancellationRequested();
            return new AccessToken(token, DateTimeOffset.UtcNow.AddHours(1));
        }

        public override ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
            => ValueTask.FromResult(GetToken(requestContext, cancellationToken));
    }

    /// <summary>Provider that always exposes the same injected <see cref="TestTokenCredential"/> instance.</summary>
    private sealed class TestTokenCredentialProviderWithCredential(TestTokenCredential tokenCredential) : ITokenCredentialProvider
    {
        public TokenCredential TokenCredential => tokenCredential;
    }

    /// <summary>
    /// Credential that records whether a token was requested and issues a token with a fresh random
    /// <c>oid</c> and the <see cref="Upn"/> claim.
    /// </summary>
    private sealed class TestTokenCredential : TokenCredential
    {
        /// <summary>The <c>upn</c> claim value placed in every token this credential issues.</summary>
        public const string Upn = "test@example.com";

        /// <summary>True once either <c>GetToken</c> overload has been invoked.</summary>
        public bool GetTokenCalled { get; private set; }

        public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
        {
            // Record the call before the cancellation check, so a cancelled request still counts as "called".
            GetTokenCalled = true;
            cancellationToken.ThrowIfCancellationRequested();
            return new AccessToken(CreateTestToken(Guid.NewGuid(), Upn), DateTimeOffset.UtcNow.AddHours(1));
        }

        public override ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
            => ValueTask.FromResult(GetToken(requestContext, cancellationToken));
    }
}