File: src\Components\Common\ManagedIdentityTokenCredentialHelpers.cs
Web Access
Project: src\src\Components\Aspire.Azure.Npgsql.EntityFrameworkCore.PostgreSQL\Aspire.Azure.Npgsql.EntityFrameworkCore.PostgreSQL.csproj (Aspire.Azure.Npgsql.EntityFrameworkCore.PostgreSQL)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Text.Json;
using Azure.Core;
using Azure.Identity;
using Npgsql;
 
namespace Aspire;
 
/// <summary>
/// Helpers that configure an <see cref="NpgsqlDataSourceBuilder"/> to authenticate against Azure Database for PostgreSQL
/// using Microsoft Entra ID access tokens obtained from a <see cref="TokenCredential"/>.
/// </summary>
internal static class ManagedIdentityTokenCredentialHelpers
{
    private const string AzureDatabaseForPostgresSqlScope = "https://ossrdbms-aad.database.windows.net/.default";
    private const string AzureManagementScope = "https://management.azure.com/.default";

    private static readonly TokenRequestContext s_databaseForPostgresSqlTokenRequestContext = new([AzureDatabaseForPostgresSqlScope]);
    private static readonly TokenRequestContext s_managementTokenRequestContext = new([AzureManagementScope]);

    /// <summary>
    /// Configures Entra ID authentication on <paramref name="dataSourceBuilder"/>. A missing username is inferred
    /// from an access token's claims, and a missing password is replaced by a token-based password provider.
    /// </summary>
    /// <param name="dataSourceBuilder">The builder whose connection string username and password provider are configured.</param>
    /// <param name="credential">
    /// The credential used to acquire tokens. When <see langword="null"/>, a new <see cref="DefaultAzureCredential"/> is used.
    /// </param>
    /// <returns>
    /// <see langword="true"/> if this method set the username from a token or registered a token-based password provider;
    /// otherwise <see langword="false"/>.
    /// </returns>
    /// <remarks>
    /// When no username is configured, tokens are acquired synchronously to infer it. Exceptions thrown by the credential
    /// while acquiring those tokens propagate to the caller.
    /// </remarks>
    public static bool ConfigureEntraIdAuthentication(this NpgsqlDataSourceBuilder dataSourceBuilder, TokenCredential? credential)
    {
        credential ??= new DefaultAzureCredential();
        var configuredAuth = false;

        // The connection string requires the username to be provided. Since it depends on the managed identity that is used,
        // we attempt to get the username from an access token if it's not defined.
        if (string.IsNullOrEmpty(dataSourceBuilder.ConnectionStringBuilder.Username))
        {
            // Use the management scope first, so the token contains user names for all managed identity types
            // (e.g. user-assigned identities and service principals). Only if that yields no username is a
            // PostgreSQL-scoped token requested (the '||' short-circuits).
            if (TryGetUsernameFromCredential(credential, s_managementTokenRequestContext, out var username) ||
                TryGetUsernameFromCredential(credential, s_databaseForPostgresSqlTokenRequestContext, out username))
            {
                dataSourceBuilder.ConnectionStringBuilder.Username = username;
                configuredAuth = true;
            }

            // If we still don't have a username, we let Npgsql handle the error when trying to connect.
            // The user will be hinted to provide a username by using the configureDataSourceBuilder callback.
        }

        if (string.IsNullOrEmpty(dataSourceBuilder.ConnectionStringBuilder.Password))
        {
            // The token is not cached here since it is refreshed for each new physical connection, or when it has expired.
            dataSourceBuilder.UsePasswordProvider(
                passwordProvider: _ => credential.GetToken(s_databaseForPostgresSqlTokenRequestContext, default).Token,
                // Flow Npgsql's cancellation token so that cancelling a connection attempt also cancels token acquisition.
                passwordProviderAsync: async (_, ct) => (await credential.GetTokenAsync(s_databaseForPostgresSqlTokenRequestContext, ct).ConfigureAwait(false)).Token
            );

            configuredAuth = true;
        }

        return configuredAuth;
    }

    /// <summary>
    /// Synchronously acquires a token for <paramref name="context"/> and tries to extract a username from its claims.
    /// </summary>
    private static bool TryGetUsernameFromCredential(TokenCredential credential, TokenRequestContext context, out string? username)
        => TryGetUsernameFromToken(credential.GetToken(context, default).Token, out username);

    /// <summary>
    /// Tries to read a username from the payload of a JWT access token.
    /// </summary>
    /// <param name="jwtToken">The raw token in compact JWT form (header.payload.signature).</param>
    /// <param name="username">The username when found; otherwise <see langword="null"/>.</param>
    /// <returns>
    /// <see langword="true"/> if a username was found. Returns <see langword="false"/> without throwing when the token is
    /// malformed (wrong number of parts, invalid Base64Url, non-JSON or non-object payload) or has no usable claim.
    /// </returns>
    private static bool TryGetUsernameFromToken(string jwtToken, out string? username)
    {
        username = null;

        // Split the token into its parts (Header, Payload, Signature)
        var tokenParts = jwtToken.Split('.');
        if (tokenParts.Length != 3)
        {
            return false;
        }

        // The payload is the second part, Base64Url encoded
        var decodedBytes = DecodeBase64Url(tokenParts[1]);
        if (decodedBytes is null)
        {
            return false;
        }

        // Parse the decoded payload as JSON
        JsonElement payloadJson;
        try
        {
            var reader = new Utf8JsonReader(decodedBytes);
            payloadJson = JsonElement.ParseValue(ref reader);
        }
        catch (JsonException)
        {
            return false;
        }

        // TryGetProperty throws on anything but an object, so reject other JSON kinds up front.
        if (payloadJson.ValueKind != JsonValueKind.Object)
        {
            return false;
        }

        // Try to get the username from 'xms_mirid', 'upn', 'preferred_username', or 'unique_name' claims, in that order.
        // Claims that are missing, null, empty or not strings are skipped.
        username = GetStringClaim(payloadJson, "xms_mirid") is string xmsMirid && ParsePrincipalName(xmsMirid) is string principalName
            ? principalName
            : GetStringClaim(payloadJson, "upn")
                ?? GetStringClaim(payloadJson, "preferred_username")
                ?? GetStringClaim(payloadJson, "unique_name");

        return username is not null;
    }

    /// <summary>
    /// Returns the value of <paramref name="claimName"/> when it is present as a non-empty JSON string; otherwise <see langword="null"/>.
    /// </summary>
    private static string? GetStringClaim(JsonElement payload, string claimName)
        => payload.TryGetProperty(claimName, out var claim) &&
           claim.ValueKind == JsonValueKind.String &&
           claim.GetString() is { Length: > 0 } value
            ? value
            : null;

    /// <summary>
    /// Extracts the identity name from an 'xms_mirid' claim. The claim looks like
    /// /subscriptions/{subId}/resourcegroups/{resourceGroup}/providers/Microsoft.ManagedIdentity/userAssignedIdentities/{principalName}.
    /// </summary>
    /// <returns>
    /// The trailing segment when the claim identifies a user-assigned managed identity; otherwise <see langword="null"/>.
    /// </returns>
    private static string? ParsePrincipalName(string xms_mirid)
    {
        var lastSlashIndex = xms_mirid.LastIndexOf('/');
        if (lastSlashIndex == -1)
        {
            return null;
        }

        var beginning = xms_mirid.AsSpan(0, lastSlashIndex);
        var principalName = xms_mirid.AsSpan(lastSlashIndex + 1);

        if (principalName.IsEmpty || !beginning.EndsWith("providers/Microsoft.ManagedIdentity/userAssignedIdentities", StringComparison.OrdinalIgnoreCase))
        {
            return null;
        }

        return principalName.ToString();
    }

    /// <summary>
    /// Decodes a Base64Url string (RFC 4648 section 5), in which '-' and '_' replace '+' and '/' and padding is omitted.
    /// </summary>
    /// <returns>The decoded bytes, or <see langword="null"/> if the input is not valid Base64Url.</returns>
    private static byte[]? DecodeBase64Url(string base64Url)
    {
        var base64 = AddBase64Padding(base64Url.Replace('-', '+').Replace('_', '/'));

        // After padding, a valid input's length is a multiple of 4, so this buffer is large enough. Invalid lengths
        // (e.g. length % 4 == 1) are rejected by TryFromBase64String instead of throwing.
        var buffer = new byte[base64.Length / 4 * 3];

        return Convert.TryFromBase64String(base64, buffer, out var bytesWritten)
            ? buffer[..bytesWritten]
            : null;
    }

    /// <summary>
    /// Appends the '=' padding that Base64Url omits so the string length becomes a multiple of 4.
    /// Lengths with remainder 1 are left unchanged, since no padding can make them valid.
    /// </summary>
    private static string AddBase64Padding(string base64) => (base64.Length % 4) switch
    {
        2 => base64 + "==",
        3 => base64 + "=",
        _ => base64,
    };
}