// File: ConvertAnonymousType\AbstractConvertAnonymousTypeToClassCodeRefactoringProvider.cs
// Web Access
// Project: src\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.GenerateEqualsAndGetHashCodeFromMembers;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Shared.Utilities;
using Microsoft.CodeAnalysis.Text;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.ConvertAnonymousType;
 
internal abstract class AbstractConvertAnonymousTypeToClassCodeRefactoringProvider<
    TExpressionSyntax,
    TNameSyntax,
    TIdentifierNameSyntax,
    TObjectCreationExpressionSyntax,
    TAnonymousObjectCreationExpressionSyntax,
    TNamespaceDeclarationSyntax>
    : AbstractConvertAnonymousTypeCodeRefactoringProvider<TAnonymousObjectCreationExpressionSyntax>
    where TExpressionSyntax : SyntaxNode
    where TNameSyntax : TExpressionSyntax
    where TIdentifierNameSyntax : TNameSyntax
    where TObjectCreationExpressionSyntax : TExpressionSyntax
    where TAnonymousObjectCreationExpressionSyntax : TExpressionSyntax
    where TNamespaceDeclarationSyntax : SyntaxNode
{
    /// <summary>
    /// Builds the language-specific object creation expression that is substituted for
    /// <paramref name="currentAnonymousObject"/> when an anonymous object creation is replaced with a
    /// construction of the generated named type.
    /// </summary>
    /// <param name="nameNode">The (possibly generic) name of the generated type to instantiate.</param>
    /// <param name="currentAnonymousObject">The anonymous object creation being replaced, as it
    /// currently exists in the tree being edited.</param>
    /// <returns>The expression that takes the place of <paramref name="currentAnonymousObject"/>.</returns>
    protected abstract TObjectCreationExpressionSyntax CreateObjectCreationExpression(TNameSyntax nameNode, TAnonymousObjectCreationExpressionSyntax currentAnonymousObject);
 
    /// <summary>
    /// Registers the "convert to record" (only when the parse options support records) and
    /// "convert to class" refactorings for the anonymous object creation found at the context's span.
    /// Nothing is registered when no anonymous object is found, or when one of its properties has a
    /// type that contains an anonymous type.
    /// </summary>
    public override async Task ComputeRefactoringsAsync(CodeRefactoringContext context)
    {
        var (document, textSpan, cancellationToken) = context;
        var (anonymousObject, anonymousType) = await TryGetAnonymousObjectAsync(document, textSpan, cancellationToken).ConfigureAwait(false);

        if (anonymousObject == null || anonymousType == null)
            return;

        // An anonymous type that mentions another anonymous type in one of its property types
        // cannot be converted: the concrete type we would generate has no way to describe it.
        foreach (var property in anonymousType.GetMembers().OfType<IPropertySymbol>())
        {
            if (property.Type.ContainsAnonymousType())
                return;
        }

        var applicableSpan = anonymousObject.Span;

        var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();
        if (syntaxFacts.SupportsRecord(anonymousObject.SyntaxTree.Options))
            Register(FeaturesResources.Convert_to_record, nameof(FeaturesResources.Convert_to_record), isRecord: true);

        Register(FeaturesResources.Convert_to_class, nameof(FeaturesResources.Convert_to_class), isRecord: false);
        return;

        // Registers one conversion action over the span of the anonymous object.
        void Register(string title, string equivalenceKey, bool isRecord)
        {
            context.RegisterRefactoring(
                CodeAction.Create(
                    title,
                    c => ConvertAsync(document, textSpan, isRecord, c),
                    equivalenceKey),
                applicableSpan);
        }
    }
 
    /// <summary>
    /// Performs the conversion: generates a named class or record equivalent to the anonymous type at
    /// <paramref name="span"/>, rewrites the matching anonymous object creations and property
    /// references inside the containing member, inserts the new type, and formats the result.
    /// </summary>
    /// <param name="document">The document containing the anonymous object creation.</param>
    /// <param name="span">The span the refactoring was invoked on.</param>
    /// <param name="isRecord"><see langword="true"/> to generate a record; otherwise a class.</param>
    /// <param name="cancellationToken">Token used to cancel the operation.</param>
    /// <returns>The updated, formatted document.</returns>
    private async Task<Document> ConvertAsync(Document document, TextSpan span, bool isRecord, CancellationToken cancellationToken)
    {
        var (anonymousObject, anonymousType) = await TryGetAnonymousObjectAsync(document, span, cancellationToken).ConfigureAwait(false);

        Debug.Assert(anonymousObject != null);
        Debug.Assert(anonymousType != null);

        var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);

        // Pick a type name for which symbol lookup at the invocation position finds nothing.  The
        // creation node the user started on later receives a rename annotation so the user can
        // choose a better name afterwards.
        var baseName = isRecord ? "NewRecord" : "NewClass";
        var lookupPosition = span.Start;
        var typeName = NameGenerator.GenerateUniqueName(
            baseName,
            candidate => semanticModel.LookupSymbols(lookupPosition, name: candidate).IsEmpty);

        // Mirror the anonymous type's properties onto the new type.  The map records every
        // original property whose generated name differs from its original name.
        var (properties, propertyMap) = GenerateProperties(document, anonymousType);

        // Build the symbol describing the complete class/record that replaces the anonymous type.
        var namedType = await GenerateFinalNamedTypeAsync(
            document, typeName, isRecord, properties, cancellationToken).ConfigureAwait(false);

        var editor = new SyntaxEditor(root, SyntaxGenerator.GetGenerator(document));

        // All rewriting is scoped to the enclosing method-level member; if there is none, fall back
        // to the anonymous object itself.
        var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();
        var containingMember = anonymousObject.FirstAncestorOrSelf<SyntaxNode, ISyntaxFactsService>(
            (node, facts) => facts.IsMethodLevelMember(node), syntaxFacts) ?? anonymousObject;

        // Step 1: rename references to the anonymous type's properties so they match the names
        // chosen for the new type's properties.
        await ReplacePropertyReferencesAsync(
            document, editor, containingMember, propertyMap, cancellationToken).ConfigureAwait(false);

        // Step 2: swap every matching anonymous object creation in the member for a construction
        // of the new named type.
        await ReplaceMatchingAnonymousTypesAsync(
            document, editor, namedType, containingMember,
            anonymousObject, anonymousType, cancellationToken).ConfigureAwait(false);

        var generationContext = new CodeGenerationContext(
            generateMembers: true,
            sortMembers: false,
            autoInsertionLocation: false);

        var info = await document.GetCodeGenerationInfoAsync(generationContext, cancellationToken).ConfigureAwait(false);
        var formattingOptions = await document.GetSyntaxFormattingOptionsAsync(cancellationToken).ConfigureAwait(false);

        // Step 3: add the new type to the enclosing namespace declaration, or to the root when the
        // anonymous object is not inside one.
        var container = anonymousObject.GetAncestor<TNamespaceDeclarationSyntax>() ?? root;
        editor.ReplaceNode(
            container,
            (currentContainer, _) => info.Service.AddNamedType(currentContainer, namedType, info, cancellationToken));

        var updatedDocument = document.WithSyntaxRoot(editor.GetChangedRoot());

        // Step 4: format through the equals+getHashCode service so the generated methods follow
        // any formatting rules that are specific to them.
        var formattingService = document.GetRequiredLanguageService<IGenerateEqualsAndGetHashCodeService>();
        return await formattingService.FormatDocumentAsync(
            updatedDocument, formattingOptions, cancellationToken).ConfigureAwait(false);
    }
 
    /// <summary>
    /// Queues edits on <paramref name="editor"/> that rename, within <paramref name="containingMember"/>,
    /// every member-access or member-binding name that binds to a property listed in
    /// <paramref name="propertyMap"/> to the mapped name.
    /// </summary>
    private static async Task ReplacePropertyReferencesAsync(
        Document document, SyntaxEditor editor, SyntaxNode containingMember,
        ImmutableDictionary<IPropertySymbol, string> propertyMap, CancellationToken cancellationToken)
    {
        var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();
        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);

        foreach (var identifier in containingMember.DescendantNodes().OfType<TIdentifierNameSyntax>())
        {
            // Only the name part of a simple member access or of a member binding expression can
            // refer to one of the anonymous type's properties.
            var isAccessedMemberName =
                syntaxFacts.IsNameOfSimpleMemberAccessExpression(identifier) ||
                syntaxFacts.IsNameOfMemberBindingExpression(identifier);
            if (!isAccessedMemberName)
                continue;

            var symbol = semanticModel.GetSymbolInfo(identifier, cancellationToken).GetAnySymbol();
            if (symbol is IPropertySymbol property &&
                propertyMap.TryGetValue(property, out var newName))
            {
                // Keep the trivia of whatever node is current at the time the edit is applied.
                editor.ReplaceNode(
                    identifier,
                    (currentId, g) => g.IdentifierName(newName).WithTriviaFrom(currentId));
            }
        }
    }
 
    /// <summary>
    /// Queues edits that replace every anonymous object creation in <paramref name="containingMember"/>
    /// whose type equals <paramref name="anonymousType"/> with a construction of
    /// <paramref name="classSymbol"/>.
    /// </summary>
    private async Task ReplaceMatchingAnonymousTypesAsync(
        Document document, SyntaxEditor editor, INamedTypeSymbol classSymbol,
        SyntaxNode containingMember, TAnonymousObjectCreationExpressionSyntax creationNode,
        INamedTypeSymbol anonymousType, CancellationToken cancellationToken)
    {
        // We fix up all creations of the "same" anonymous type within the containing member, where
        // "same" means "they have the same type symbol": the same member names, in the same order,
        // with the same member types.  The user may create several instances of the anonymous type
        // in one method and then combine them (comparing them for equality, putting them in
        // collections, etc.).  The language guarantees that, within a method boundary, these are
        // the same type and can be used together like that, so all of them must change together.
        //
        // Note: this could be expanded in the future (potentially as a separate lightbulb action),
        // e.g. by looking through the containing type and replacing matches in all of its methods.
        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);

        foreach (var candidate in containingMember.DescendantNodesAndSelf().OfType<TAnonymousObjectCreationExpressionSyntax>())
        {
            var candidateType = semanticModel.GetTypeInfo(candidate, cancellationToken).Type;
            if (candidateType is null)
            {
                Debug.Fail("We should always be able to get an anonymous type for any anonymous creation node.");
            }
            else if (anonymousType.Equals(candidateType))
            {
                ReplaceWithObjectCreation(editor, classSymbol, creationNode, candidate);
            }
        }
    }
 
    /// <summary>
    /// Queues an edit that replaces <paramref name="childCreation"/> with an object creation of
    /// <paramref name="classSymbol"/>.  When <paramref name="childCreation"/> is the node the user
    /// started on (<paramref name="startingCreationNode"/>), the type name token also receives a
    /// rename annotation.
    /// </summary>
    private void ReplaceWithObjectCreation(
        SyntaxEditor editor, INamedTypeSymbol classSymbol,
        TAnonymousObjectCreationExpressionSyntax startingCreationNode,
        TAnonymousObjectCreationExpressionSyntax childCreation)
    {
        var isStartingNode = startingCreationNode == childCreation;

        // The callback form is required: anonymous types may be nested, so by the time this edit
        // runs the node may already have been rewritten by another edit.
        editor.ReplaceNode(childCreation, (currentNode, g) =>
        {
            var nameToken = g.Identifier(classSymbol.Name);
            if (isStartingNode)
                nameToken = nameToken.WithAdditionalAnnotations(RenameAnnotation.Create());

            // Reference the type by simple name, or as a generic name that passes along the
            // type parameters of the generated type.
            TNameSyntax nameNode;
            if (classSymbol.TypeParameters.Length == 0)
            {
                nameNode = (TNameSyntax)g.IdentifierName(nameToken);
            }
            else
            {
                var typeArguments = classSymbol.TypeParameters.Select(tp => g.IdentifierName(tp.Name));
                nameNode = (TNameSyntax)g.GenericName(nameToken, typeArguments);
            }

            return CreateObjectCreationExpression(nameNode, (TAnonymousObjectCreationExpressionSyntax)currentNode)
                .WithAdditionalAnnotations(Formatter.Annotation);
        });
    }
 
    /// <summary>
    /// Creates the symbol for the named type that replaces the anonymous type.  A record gets a single
    /// primary constructor with one parameter per property; a class gets the properties, a
    /// constructor assigning them, and <c>Equals</c>/<c>GetHashCode</c> implementations.
    /// </summary>
    /// <param name="document">The document the type is being generated for.</param>
    /// <param name="typeName">The name of the type to generate.</param>
    /// <param name="isRecord"><see langword="true"/> to generate a record; otherwise a class.</param>
    /// <param name="properties">The properties generated from the anonymous type.</param>
    /// <param name="cancellationToken">Token used to cancel the operation.</param>
    /// <returns>The generated named type symbol, generic if any property type references type parameters.</returns>
    private static async Task<INamedTypeSymbol> GenerateFinalNamedTypeAsync(
        Document document, string typeName, bool isRecord,
        ImmutableArray<IPropertySymbol> properties,
        CancellationToken cancellationToken)
    {
        // See if any of the properties ended up using any type parameters from the containing
        // method/named-type.  If so, we'll need to generate a generic type so we can properly pass
        // these along.
        var capturedTypeParameters =
            properties.Select(p => p.Type)
                      .SelectMany(t => t.GetReferencedTypeParameters())
                      .Distinct()
                      .ToImmutableArray();

        using var _ = ArrayBuilder<ISymbol>.GetInstance(out var members);

        if (isRecord)
        {
            // Create a record with a single primary constructor, containing a parameter for all the
            // properties we're generating.
            var constructor = CodeGenerationSymbolFactory.CreateConstructorSymbol(
                attributes: default,
                Accessibility.Public,
                modifiers: default,
                typeName,
                properties.SelectAsArray(prop => CodeGenerationSymbolFactory.CreateParameterSymbol(prop.Type, prop.Name)),
                isPrimaryConstructor: true);

            members.Add(constructor);
        }
        else
        {
            // The generator and semantic model are only needed to build the class constructor, so
            // they are obtained here rather than up front where the record path would pay for them
            // without using them.
            var generator = SyntaxGenerator.GetGenerator(document);
            var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);

            // Now try to generate all the members that will go in the new class. This is a bit
            // circular.  In order to generate some of the members, we need to know about the type.
            // But in order to create the type, we need the members.  To address this we do two
            // passes. First, we create an empty version of the class.  This can then be used to
            // help create members like Equals/GetHashCode.  Then, once we have all the members we
            // create the final type.
            var namedTypeWithoutMembers = CreateNamedType(typeName, isRecord: false, capturedTypeParameters, members: default);

            var constructor = CreateClassConstructor(semanticModel, typeName, properties, generator);

            // Generate Equals/GetHashCode.  Only readonly properties are suitable for these
            // methods.  We can defer to our existing language service for this so that we
            // generate the same Equals/GetHashCode that our other IDE features generate.
            var readonlyProperties = ImmutableArray<ISymbol>.CastUp(
                properties.WhereAsArray(p => p.SetMethod == null));

            var equalsAndGetHashCodeService = document.GetRequiredLanguageService<IGenerateEqualsAndGetHashCodeService>();

            var equalsMethod = await equalsAndGetHashCodeService.GenerateEqualsMethodAsync(
                document, namedTypeWithoutMembers, readonlyProperties,
                localNameOpt: SyntaxGeneratorExtensions.OtherName, cancellationToken).ConfigureAwait(false);
            var getHashCodeMethod = await equalsAndGetHashCodeService.GenerateGetHashCodeMethodAsync(
                document, namedTypeWithoutMembers,
                readonlyProperties, cancellationToken).ConfigureAwait(false);

            members.AddRange(properties);
            members.Add(constructor);
            members.Add(equalsMethod);
            members.Add(getHashCodeMethod);
        }

        return CreateNamedType(typeName, isRecord, capturedTypeParameters, members.ToImmutable());
    }
 
    /// <summary>
    /// Creates an internal class (or record, when <paramref name="isRecord"/> is set) symbol named
    /// <paramref name="className"/> with the given type parameters and members.
    /// </summary>
    private static INamedTypeSymbol CreateNamedType(
        string className, bool isRecord, ImmutableArray<ITypeParameterSymbol> capturedTypeParameters, ImmutableArray<ISymbol> members)
        => CodeGenerationSymbolFactory.CreateNamedTypeSymbol(
            attributes: default,
            Accessibility.Internal,
            modifiers: default,
            isRecord,
            TypeKind.Class,
            className,
            capturedTypeParameters,
            members: members);
 
    /// <summary>
    /// Creates the properties of the generated type from the properties of
    /// <paramref name="anonymousType"/>, along with a map from each original property whose name
    /// changed to its new name.
    /// </summary>
    /// <remarks>
    /// Names are PascalCased (falling back to a dummy name when the result is not a legal
    /// identifier) and then made unique.  Without the uniqueness step, distinct anonymous-type
    /// properties such as <c>a</c> and <c>A</c> would both be generated as <c>A</c>, producing a type
    /// with duplicate members.  Properties that keep their original name take precedence, so only
    /// properties that were going to be renamed anyway can receive a numeric suffix.
    /// </remarks>
    /// <param name="document">The document the properties are generated for.</param>
    /// <param name="anonymousType">The anonymous type being converted.</param>
    /// <returns>The new properties, in the original order, and the rename map.</returns>
    private static (ImmutableArray<IPropertySymbol> properties, ImmutableDictionary<IPropertySymbol, string> propertyMap) GenerateProperties(
        Document document, INamedTypeSymbol anonymousType)
    {
        var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();
        var originalProperties = anonymousType.GetMembers().OfType<IPropertySymbol>().ToImmutableArray();

        // The name each property would get if collisions were not a concern.
        var candidateNames = originalProperties.SelectAsArray(
            p => GetLegalName(p.Name.ToPascalCase(trimLeadingTypePrefix: false), document));

        // Reserve the names of the properties that do not need renaming first, so that a renamed
        // property never steals the name of a property that could otherwise stay untouched.
        var usedNames = new HashSet<string>(syntaxFacts.StringComparer);
        for (var i = 0; i < originalProperties.Length; i++)
        {
            if (candidateNames[i] == originalProperties[i].Name)
                usedNames.Add(candidateNames[i]);
        }

        // If we changed the names of any properties, record that name mapping.  We'll use this to
        // update references to the old anonymous-type properties to the new names.
        var newProperties = ImmutableArray.CreateBuilder<IPropertySymbol>(originalProperties.Length);
        var propertyMap = ImmutableDictionary.CreateBuilder<IPropertySymbol, string>();
        for (var i = 0; i < originalProperties.Length; i++)
        {
            var originalProperty = originalProperties[i];
            var newName = candidateNames[i];

            if (newName != originalProperty.Name)
            {
                newName = NameGenerator.GenerateUniqueName(newName, n => !usedNames.Contains(n));
                usedNames.Add(newName);
                propertyMap[originalProperty] = newName;
            }

            newProperties.Add(GenerateProperty(originalProperty, newName));
        }

        return (newProperties.ToImmutable(), propertyMap.ToImmutable());
    }

    /// <summary>
    /// Creates a public property called <paramref name="name"/> that has the type of
    /// <paramref name="prop"/> and an accessor for each accessor <paramref name="prop"/> has.
    /// </summary>
    private static IPropertySymbol GenerateProperty(IPropertySymbol prop, string name)
    {
        // The actual properties generated by C#/VB are not what we want.  For example, they
        // think of themselves as having real accessors that will read/write into a real field
        // in the type.  Instead, we just want to generate auto-props. So we effectively clone
        // the property, just throwing away anything we don't need for that purpose.

        var getMethod = prop.GetMethod != null ? CreateAccessorSymbol(prop, MethodKind.PropertyGet) : null;
        var setMethod = prop.SetMethod != null ? CreateAccessorSymbol(prop, MethodKind.PropertySet) : null;

        return CodeGenerationSymbolFactory.CreatePropertySymbol(
            attributes: default, Accessibility.Public, modifiers: default,
            prop.Type, refKind: default, explicitInterfaceImplementations: default,
            name,
            parameters: default, getMethod: getMethod, setMethod: setMethod);
    }

    /// <summary>
    /// Returns <paramref name="name"/> when it is a legal identifier in the document's language,
    /// and a fixed dummy name otherwise.
    /// </summary>
    private static string GetLegalName(string name, Document document)
    {
        var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();
        return syntaxFacts.IsLegalIdentifier(name)
            ? name
            : "Item"; // Just a dummy name for the property.  Does not need to be localized.
    }
 
    /// <summary>
    /// Creates a public, nameless accessor method symbol of the given <paramref name="kind"/> whose
    /// return type is the type of <paramref name="prop"/>.
    /// </summary>
    private static IMethodSymbol CreateAccessorSymbol(IPropertySymbol prop, MethodKind kind)
    {
        // NOTE(review): the accessor is declared abstract, presumably so that code generation
        // emits it without a body (i.e. as an auto-property accessor) -- confirm.
        return CodeGenerationSymbolFactory.CreateMethodSymbol(
            attributes: default,
            Accessibility.Public,
            DeclarationModifiers.Abstract,
            prop.Type,
            refKind: default,
            explicitInterfaceImplementations: default,
            name: "",
            typeParameters: default,
            parameters: default,
            methodKind: kind);
    }
 
    /// <summary>
    /// Creates a public constructor for <paramref name="className"/> that takes one camelCased
    /// parameter per property and assigns each parameter to its property.
    /// </summary>
    private static IMethodSymbol CreateClassConstructor(
        SemanticModel semanticModel, string className,
        ImmutableArray<IPropertySymbol> properties, SyntaxGenerator generator)
    {
        // Build one parameter per property, remembering which property each parameter name feeds
        // so the assignment statements can be generated from that mapping.
        var parameterToPropMap = new Dictionary<string, ISymbol>();
        var parameterBuilder = ImmutableArray.CreateBuilder<IParameterSymbol>(properties.Length);
        foreach (var prop in properties)
        {
            var parameter = CodeGenerationSymbolFactory.CreateParameterSymbol(
                prop.Type, prop.Name.ToCamelCase(trimLeadingTypePrefix: false));

            parameterToPropMap[parameter.Name] = prop;
            parameterBuilder.Add(parameter);
        }

        var parameters = parameterBuilder.ToImmutable();

        var assignmentStatements = generator.CreateAssignmentStatements(
            generator.SyntaxGeneratorInternal,
            semanticModel,
            parameters,
            parameterToPropMap,
            ImmutableDictionary<string, string>.Empty,
            addNullChecks: false,
            preferThrowExpression: false);

        return CodeGenerationSymbolFactory.CreateConstructorSymbol(
            attributes: default, Accessibility.Public, modifiers: default,
            className, parameters, assignmentStatements);
    }
}