File: InitializeParameter\CSharpInitializeMemberFromPrimaryConstructorParameterCodeRefactoringProvider_Update.cs
Web Access
Project: src\src\Features\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.Features.csproj (Microsoft.CodeAnalysis.CSharp.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.FindSymbols;
using Microsoft.CodeAnalysis.InitializeParameter;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.CSharp.InitializeParameter;
 
using static CSharpSyntaxTokens;
using static InitializeParameterHelpers;
using static InitializeParameterHelpersCore;
using static SyntaxFactory;
 
internal sealed partial class CSharpInitializeMemberFromPrimaryConstructorParameterCodeRefactoringProvider
{
    private static async Task<Solution> AddMultipleMembersAsync(
        Document document,
        TypeDeclarationSyntax typeDeclaration,
        ImmutableArray<IParameterSymbol> parameters,
        ImmutableArray<ISymbol> fieldsOrProperties,
        CancellationToken cancellationToken)
    {
        Debug.Assert(parameters.Length >= 1);
        Debug.Assert(fieldsOrProperties.Length >= 1);
        Debug.Assert(parameters.Length == fieldsOrProperties.Length);
 
        // Process each param+field/prop in order.  Apply the pair to the document getting the updated document.
        // Then find all the current data in that updated document and move onto the next pair.
 
        var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
 
        var trackedRoot = root.TrackNodes(typeDeclaration);
        var currentSolution = document.WithSyntaxRoot(trackedRoot).Project.Solution;
 
        foreach (var (parameter, fieldOrProperty) in parameters.Zip(fieldsOrProperties, static (a, b) => (a, b)))
        {
            var currentDocument = currentSolution.GetRequiredDocument(document.Id);
            var currentCompilation = await currentDocument.Project.GetRequiredCompilationAsync(cancellationToken).ConfigureAwait(false);
            var currentRoot = await currentDocument.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
 
            var currentTypeDeclaration = currentRoot.GetCurrentNode(typeDeclaration);
            if (currentTypeDeclaration == null)
                continue;
 
            var currentParameter = (IParameterSymbol?)parameter.GetSymbolKey(cancellationToken).Resolve(currentCompilation, cancellationToken: cancellationToken).GetAnySymbol();
            if (currentParameter == null)
                continue;
 
            // fieldOrProperty is a new member.  So we don't have to track it to this edit we're making.
 
            currentSolution = await AddSingleMemberAsync(
                currentDocument,
                currentTypeDeclaration,
                currentParameter,
                fieldOrProperty,
                cancellationToken).ConfigureAwait(false);
        }
 
        return currentSolution;
 
        // Intentionally static so that we do not capture outer state.  This function is called in a loop with a fresh
        // fork of the solution after each change that has been made.  This ensures we don't accidentally refer to the
        // original state in any way.
 
        static async Task<Solution> AddSingleMemberAsync(
            Document document,
            TypeDeclarationSyntax typeDeclaration,
            IParameterSymbol parameter,
            ISymbol fieldOrProperty,
            CancellationToken cancellationToken)
        {
            var project = document.Project;
            var solution = project.Solution;
            var services = solution.Services;
 
            var compilation = await project.GetRequiredCompilationAsync(cancellationToken).ConfigureAwait(false);
            var parseOptions = document.DocumentState.ParseOptions!;
 
            var solutionEditor = new SolutionEditor(solution);
            var options = await document.GetCodeGenerationOptionsAsync(cancellationToken).ConfigureAwait(false);
            var codeGenerator = document.GetRequiredLanguageService<ICodeGenerationService>();
 
            // We're assigning the parameter to a new field/prop .  Convert all existing references to this primary
            // constructor parameter (within this type) to refer to the field/prop now instead.
            await UpdateParameterReferencesAsync(
                solutionEditor, parameter, fieldOrProperty, cancellationToken).ConfigureAwait(false);
 
            // We're generating a new field/property.  Place into the containing type, ideally before/after a
            // relevant existing member.
            var (sibling, siblingSyntax, addContext) = fieldOrProperty switch
            {
                IPropertySymbol => GetAddContext<IPropertySymbol>(compilation, parameter, cancellationToken),
                IFieldSymbol => GetAddContext<IFieldSymbol>(compilation, parameter, cancellationToken),
                _ => throw ExceptionUtilities.UnexpectedValue(fieldOrProperty),
            };
 
            var preferredTypeDeclaration = siblingSyntax?.GetAncestorOrThis<TypeDeclarationSyntax>() ?? typeDeclaration;
 
            var editingDocument = solution.GetRequiredDocument(preferredTypeDeclaration.SyntaxTree);
            var editor = await solutionEditor.GetDocumentEditorAsync(editingDocument.Id, cancellationToken).ConfigureAwait(false);
            editor.ReplaceNode(
                preferredTypeDeclaration,
                (currentTypeDecl, _) =>
                {
                    if (fieldOrProperty is IPropertySymbol property)
                    {
                        return codeGenerator.AddProperty(
                            currentTypeDecl, property,
                            codeGenerator.GetInfo(addContext, options, parseOptions),
                            cancellationToken);
                    }
                    else if (fieldOrProperty is IFieldSymbol field)
                    {
                        return codeGenerator.AddField(
                            currentTypeDecl, field,
                            codeGenerator.GetInfo(addContext, options, parseOptions),
                            cancellationToken);
                    }
                    else
                    {
                        throw ExceptionUtilities.Unreachable();
                    }
                });
 
            return solutionEditor.GetChangedSolution();
        }
 
        static (ISymbol? symbol, SyntaxNode? syntax, CodeGenerationContext context) GetAddContext<TSymbol>(
            Compilation compilation, IParameterSymbol parameter, CancellationToken cancellationToken) where TSymbol : class, ISymbol
        {
            foreach (var (sibling, before) in GetSiblingParameters(parameter))
            {
                var (initializer, fieldOrProperty) = TryFindFieldOrPropertyInitializerValue(
                    compilation, sibling, cancellationToken);
 
                if (initializer != null &&
                    fieldOrProperty is TSymbol { DeclaringSyntaxReferences: [var syntaxReference, ..] } symbol)
                {
                    var syntax = syntaxReference.GetSyntax(cancellationToken);
                    return (symbol, syntax, before
                        ? new CodeGenerationContext(afterThisLocation: syntax.GetLocation())
                        : new CodeGenerationContext(beforeThisLocation: syntax.GetLocation()));
                }
            }
 
            return (symbol: null, syntax: null, CodeGenerationContext.Default);
        }
    }
 
    private static async Task UpdateParameterReferencesAsync(
        SolutionEditor solutionEditor,
        IParameterSymbol parameter,
        ISymbol fieldOrProperty,
        CancellationToken cancellationToken)
    {
        var solution = solutionEditor.OriginalSolution;
        var namedType = parameter.ContainingType;
        var documents = namedType.DeclaringSyntaxReferences
            .Select(r => solution.GetRequiredDocument(r.SyntaxTree))
            .ToImmutableHashSet();
 
        var references = await SymbolFinder.FindReferencesAsync(parameter, solution, documents, cancellationToken).ConfigureAwait(false);
        var groups = references.SelectMany(static r => r.Locations.Where(loc => !loc.IsImplicit)).GroupBy(static loc => loc.Document);
 
        foreach (var group in groups)
        {
            var editor = await solutionEditor.GetDocumentEditorAsync(group.Key.Id, cancellationToken).ConfigureAwait(false);
 
            // We may hit a location multiple times due to how we do FAR for linked symbols, but each linked symbol is
            // allowed to report the entire set of references it think it is compatible with.  So ensure we're hitting
            // each location only once.
            foreach (var location in group.Distinct(LinkedFileReferenceLocationEqualityComparer.Instance))
            {
                var node = location.Location.FindNode(getInnermostNodeForTie: true, cancellationToken);
                if (node is IdentifierNameSyntax { Parent: not NameColonSyntax } identifierName &&
                    identifierName.Identifier.ValueText == parameter.Name)
                {
                    // we may have things like `new MyType(x: ...)` we don't want to update `x` there to 'X'
                    // just because we're generating a new property 'X' for the parameter to be assigned to.
                    editor.ReplaceNode(
                        identifierName,
                        IdentifierName(fieldOrProperty.Name.EscapeIdentifier()).WithTriviaFrom(identifierName));
                }
            }
        }
    }
 
    private static async Task<Solution> UpdateExistingMemberAsync(
        Document document,
        IParameterSymbol parameter,
        ISymbol fieldOrProperty,
        CancellationToken cancellationToken)
    {
        var project = document.Project;
        var solution = project.Solution;
 
        var solutionEditor = new SolutionEditor(solution);
        var initializer = EqualsValueClause(IdentifierName(parameter.Name.EscapeIdentifier()));
 
        // We're assigning the parameter to a field/prop.  Convert all existing references to this primary constructor
        // parameter (within this type) to refer to the field/prop now instead.
        await UpdateParameterReferencesAsync(
            solutionEditor, parameter, fieldOrProperty, cancellationToken).ConfigureAwait(false);
 
        // We're updating an exiting field/prop.
        if (fieldOrProperty is IPropertySymbol property)
        {
            var compilation = await project.GetRequiredCompilationAsync(cancellationToken).ConfigureAwait(false);
            var isThrowNotImplementedProperty = IsThrowNotImplementedProperty(compilation, property, cancellationToken);
 
            foreach (var syntaxRef in property.DeclaringSyntaxReferences)
            {
                if (syntaxRef.GetSyntax(cancellationToken) is PropertyDeclarationSyntax propertyDeclaration)
                {
                    var editingDocument = solution.GetRequiredDocument(propertyDeclaration.SyntaxTree);
                    var editor = await solutionEditor.GetDocumentEditorAsync(editingDocument.Id, cancellationToken).ConfigureAwait(false);
 
                    // If the user had a property that has 'throw NotImplementedException' in it, then remove those throws.
                    var newPropertyDeclaration = isThrowNotImplementedProperty ? RemoveThrowNotImplemented(propertyDeclaration) : propertyDeclaration;
                    editor.ReplaceNode(
                        propertyDeclaration,
                        newPropertyDeclaration.WithoutTrailingTrivia()
                            .WithSemicolonToken(SemicolonToken.WithTrailingTrivia(newPropertyDeclaration.GetTrailingTrivia()))
                            .WithInitializer(initializer));
                    break;
                }
            }
        }
        else if (fieldOrProperty is IFieldSymbol field)
        {
            foreach (var syntaxRef in field.DeclaringSyntaxReferences)
            {
                if (syntaxRef.GetSyntax(cancellationToken) is VariableDeclaratorSyntax variableDeclarator)
                {
                    var editingDocument = solution.GetRequiredDocument(variableDeclarator.SyntaxTree);
                    var editor = await solutionEditor.GetDocumentEditorAsync(editingDocument.Id, cancellationToken).ConfigureAwait(false);
                    editor.ReplaceNode(
                        variableDeclarator,
                        variableDeclarator.WithInitializer(initializer));
                    break;
                }
            }
        }
 
        return solutionEditor.GetChangedSolution();
    }
}