|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using Microsoft.AspNetCore.Identity;
using Microsoft.AspNetCore.Identity.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection.Extensions;
namespace Microsoft.Extensions.DependencyInjection;
/// <summary>
/// Contains extension methods to <see cref="IdentityBuilder"/> for adding entity framework stores.
/// </summary>
public static class IdentityEntityFrameworkBuilderExtensions
{
    /// <summary>
    /// Adds an Entity Framework implementation of identity information stores.
    /// </summary>
    /// <typeparam name="TContext">The Entity Framework database context to use.</typeparam>
    /// <param name="builder">The <see cref="IdentityBuilder"/> instance this method extends.</param>
    /// <returns>The <see cref="IdentityBuilder"/> instance this method extends.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
    /// <exception cref="InvalidOperationException">
    /// The builder's user type does not derive from <see cref="IdentityUser{TKey}"/>, or a role type is
    /// configured that does not derive from <see cref="IdentityRole{TKey}"/>.
    /// </exception>
    public static IdentityBuilder AddEntityFrameworkStores<TContext>(this IdentityBuilder builder)
        where TContext : DbContext
    {
        // Fail with a descriptive argument error rather than a NullReferenceException on builder.Services.
        if (builder is null)
        {
            throw new ArgumentNullException(nameof(builder));
        }

        AddStores(builder.Services, builder.UserType, builder.RoleType, typeof(TContext));
        return builder;
    }

    /// <summary>
    /// Validates the user entity type and registers the store implementations for it.
    /// When <paramref name="roleType"/> is non-null, role stores are also registered.
    /// </summary>
    private static void AddStores(IServiceCollection services, Type userType, Type? roleType, Type contextType)
    {
        var identityUserType = FindGenericBaseType(userType, typeof(IdentityUser<>));
        if (identityUserType is null)
        {
            throw new InvalidOperationException(Resources.NotIdentityUser);
        }

        // The key type comes from the IdentityUser<TKey> base of the user entity. It is used
        // when the context does not expose its own entity types.
        var keyType = identityUserType.GenericTypeArguments[0];

        if (roleType is null)
        {
            AddUserOnlyStore(services, userType, contextType, keyType);
        }
        else
        {
            AddUserAndRoleStores(services, userType, roleType, contextType, keyType);
        }
    }

    /// <summary>
    /// Registers scoped <see cref="IUserStore{TUser}"/> and <see cref="IRoleStore{TRole}"/> implementations
    /// for a model that includes roles. TryAddScoped keeps any store already registered.
    /// </summary>
    private static void AddUserAndRoleStores(IServiceCollection services, Type userType, Type roleType, Type contextType, Type keyType)
    {
        if (FindGenericBaseType(roleType, typeof(IdentityRole<>)) is null)
        {
            throw new InvalidOperationException(Resources.NotIdentityRole);
        }

        Type userStoreType;
        Type roleStoreType;
        var identityContext = FindGenericBaseType(contextType, typeof(IdentityDbContext<,,,,,,,>));
        if (identityContext is null)
        {
            // If it's a custom DbContext, the remaining entity types are unknown,
            // so only the default POCOs (closed over keyType) can be added.
            userStoreType = typeof(UserStore<,,,>).MakeGenericType(userType, roleType, contextType, keyType);
            roleStoreType = typeof(RoleStore<,,>).MakeGenericType(roleType, contextType, keyType);
        }
        else
        {
            // Generic argument order of
            // IdentityDbContext<TUser, TRole, TKey, TUserClaim, TUserRole, TUserLogin, TRoleClaim, TUserToken>.
            var contextArgs = identityContext.GenericTypeArguments;
            var contextKeyType = contextArgs[2];
            var userClaimType = contextArgs[3];
            var userRoleType = contextArgs[4];
            var userLoginType = contextArgs[5];
            var roleClaimType = contextArgs[6];
            var userTokenType = contextArgs[7];

            // UserStore lists TUserToken before TRoleClaim, which is the reverse of the context's order.
            userStoreType = typeof(UserStore<,,,,,,,,>).MakeGenericType(userType, roleType, contextType,
                contextKeyType,
                userClaimType,
                userRoleType,
                userLoginType,
                userTokenType,
                roleClaimType);
            roleStoreType = typeof(RoleStore<,,,,>).MakeGenericType(roleType, contextType,
                contextKeyType,
                userRoleType,
                roleClaimType);
        }

        services.TryAddScoped(typeof(IUserStore<>).MakeGenericType(userType), userStoreType);
        services.TryAddScoped(typeof(IRoleStore<>).MakeGenericType(roleType), roleStoreType);
    }

    /// <summary>
    /// Registers a scoped <see cref="IUserStore{TUser}"/> implementation for a model without roles.
    /// TryAddScoped keeps any store already registered.
    /// </summary>
    private static void AddUserOnlyStore(IServiceCollection services, Type userType, Type contextType, Type keyType)
    {
        Type userStoreType;
        var identityContext = FindGenericBaseType(contextType, typeof(IdentityUserContext<,,,,>));
        if (identityContext is null)
        {
            // If it's a custom DbContext, the remaining entity types are unknown,
            // so only the default POCOs (closed over keyType) can be added.
            userStoreType = typeof(UserOnlyStore<,,>).MakeGenericType(userType, contextType, keyType);
        }
        else
        {
            // Generic argument order of
            // IdentityUserContext<TUser, TKey, TUserClaim, TUserLogin, TUserToken>.
            var contextArgs = identityContext.GenericTypeArguments;
            userStoreType = typeof(UserOnlyStore<,,,,,>).MakeGenericType(userType, contextType,
                contextArgs[1],
                contextArgs[2],
                contextArgs[3],
                contextArgs[4]);
        }

        services.TryAddScoped(typeof(IUserStore<>).MakeGenericType(userType), userStoreType);
    }

    /// <summary>
    /// Walks <paramref name="currentType"/> and its base classes, starting with the type itself.
    /// Returns the first generic type whose generic type definition equals <paramref name="genericBaseType"/>.
    /// Implemented interfaces are not examined.
    /// </summary>
    /// <returns>The matching type from the inheritance chain, or <see langword="null"/> if there is none.</returns>
    private static Type? FindGenericBaseType(Type currentType, Type genericBaseType)
    {
        for (Type? type = currentType; type is not null; type = type.BaseType)
        {
            if (type.IsGenericType && type.GetGenericTypeDefinition() == genericBaseType)
            {
                return type;
            }
        }

        return null;
    }
}
|