|  | 
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.LanguageService;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.FindSymbols;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.InitializeParameter;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Shared.Extensions;
 
namespace Microsoft.CodeAnalysis.CSharp.InitializeParameter;
 
using static CSharpSyntaxTokens;
using static SyntaxFactory;
 
internal static class InitializeParameterHelpers
{
    public static Argument<ExpressionSyntax> GetArgument(ArgumentSyntax argument)
        => new(argument.GetRefKind(), argument.NameColon?.Name.Identifier.ValueText, argument.Expression);
 
    public static async Task<Solution> AddAssignmentForPrimaryConstructorAsync(
        Document document,
        IParameterSymbol parameter,
        ISymbol fieldOrProperty,
        CancellationToken cancellationToken)
    {
        var project = document.Project;
        var solution = project.Solution;
 
        var solutionEditor = new SolutionEditor(solution);
        var initializer = EqualsValueClause(IdentifierName(parameter.Name.EscapeIdentifier()));
 
        // We're assigning the parameter to a field/prop.  Convert all existing references to this primary constructor
        // parameter (within this type) to refer to the field/prop now instead.
        await UpdateParameterReferencesAsync(
            solutionEditor, parameter, fieldOrProperty, cancellationToken).ConfigureAwait(false);
 
        // We're updating an exiting field/prop.
        if (fieldOrProperty is IPropertySymbol property)
        {
            var compilation = await project.GetRequiredCompilationAsync(cancellationToken).ConfigureAwait(false);
            var initializeParameterService = document.GetRequiredLanguageService<IInitializeParameterService>();
            var isThrowNotImplementedProperty = initializeParameterService.IsThrowNotImplementedProperty(
                compilation, property, cancellationToken);
 
            foreach (var syntaxRef in property.DeclaringSyntaxReferences)
            {
                if (syntaxRef.GetSyntax(cancellationToken) is PropertyDeclarationSyntax propertyDeclaration)
                {
                    var editingDocument = solution.GetRequiredDocument(propertyDeclaration.SyntaxTree);
                    var editor = await solutionEditor.GetDocumentEditorAsync(editingDocument.Id, cancellationToken).ConfigureAwait(false);
 
                    // If the user had a property that has 'throw NotImplementedException' in it, then remove those throws.
                    var newPropertyDeclaration = isThrowNotImplementedProperty ? RemoveThrowNotImplemented(propertyDeclaration) : propertyDeclaration;
                    editor.ReplaceNode(
                        propertyDeclaration,
                        newPropertyDeclaration.WithoutTrailingTrivia()
                            .WithSemicolonToken(SemicolonToken.WithTrailingTrivia(newPropertyDeclaration.GetTrailingTrivia()))
                            .WithInitializer(initializer));
                    break;
                }
            }
        }
        else if (fieldOrProperty is IFieldSymbol field)
        {
            foreach (var syntaxRef in field.DeclaringSyntaxReferences)
            {
                if (syntaxRef.GetSyntax(cancellationToken) is VariableDeclaratorSyntax variableDeclarator)
                {
                    var editingDocument = solution.GetRequiredDocument(variableDeclarator.SyntaxTree);
                    var editor = await solutionEditor.GetDocumentEditorAsync(editingDocument.Id, cancellationToken).ConfigureAwait(false);
                    editor.ReplaceNode(
                        variableDeclarator,
                        variableDeclarator.WithInitializer(initializer));
                    break;
                }
            }
        }
 
        return solutionEditor.GetChangedSolution();
    }
 
    public static async Task UpdateParameterReferencesAsync(
        SolutionEditor solutionEditor,
        IParameterSymbol parameter,
        ISymbol fieldOrProperty,
        CancellationToken cancellationToken)
    {
        var solution = solutionEditor.OriginalSolution;
        var namedType = parameter.ContainingType;
        var documents = namedType.DeclaringSyntaxReferences
            .Select(r => solution.GetRequiredDocument(r.SyntaxTree))
            .ToImmutableHashSet();
 
        var references = await SymbolFinder.FindReferencesAsync(parameter, solution, documents, cancellationToken).ConfigureAwait(false);
        var groups = references.SelectMany(static r => r.Locations.Where(loc => !loc.IsImplicit)).GroupBy(static loc => loc.Document);
 
        foreach (var group in groups)
        {
            var editor = await solutionEditor.GetDocumentEditorAsync(group.Key.Id, cancellationToken).ConfigureAwait(false);
 
            // We may hit a location multiple times due to how we do FAR for linked symbols, but each linked symbol is
            // allowed to report the entire set of references it think it is compatible with.  So ensure we're hitting
            // each location only once.
            foreach (var location in group.Distinct(LinkedFileReferenceLocationEqualityComparer.Instance))
            {
                var node = location.Location.FindNode(getInnermostNodeForTie: true, cancellationToken);
                if (node is IdentifierNameSyntax { Parent: not NameColonSyntax } identifierName &&
                    identifierName.Identifier.ValueText == parameter.Name)
                {
                    // we may have things like `new MyType(x: ...)` we don't want to update `x` there to 'X'
                    // just because we're generating a new property 'X' for the parameter to be assigned to.
                    editor.ReplaceNode(
                        identifierName,
                        IdentifierName(fieldOrProperty.Name.EscapeIdentifier()).WithTriviaFrom(identifierName));
                }
            }
        }
    }
 
    public static bool IsFunctionDeclaration(SyntaxNode node)
        => node is BaseMethodDeclarationSyntax or LocalFunctionStatementSyntax or AnonymousFunctionExpressionSyntax;
 
    public static SyntaxNode GetBody(SyntaxNode functionDeclaration)
        => functionDeclaration switch
        {
            BaseMethodDeclarationSyntax methodDeclaration => (SyntaxNode?)methodDeclaration.Body ?? methodDeclaration.ExpressionBody!,
            LocalFunctionStatementSyntax localFunction => (SyntaxNode?)localFunction.Body ?? localFunction.ExpressionBody!,
            AnonymousFunctionExpressionSyntax anonymousFunction => anonymousFunction.Body,
            _ => throw ExceptionUtilities.UnexpectedValue(functionDeclaration),
        };
 
    private static SyntaxToken? TryGetSemicolonToken(SyntaxNode functionDeclaration)
        => functionDeclaration switch
        {
            BaseMethodDeclarationSyntax methodDeclaration => methodDeclaration.SemicolonToken,
            LocalFunctionStatementSyntax localFunction => localFunction.SemicolonToken,
            AnonymousFunctionExpressionSyntax _ => null,
            _ => throw ExceptionUtilities.UnexpectedValue(functionDeclaration),
        };
 
    public static bool IsImplicitConversion(Compilation compilation, ITypeSymbol source, ITypeSymbol destination)
        => compilation.ClassifyConversion(source: source, destination: destination).IsImplicit;
 
    public static SyntaxNode? TryGetLastStatement(IBlockOperation? blockStatement)
        => blockStatement?.Syntax is BlockSyntax block
            ? block.Statements.LastOrDefault()
            : blockStatement?.Syntax;
 
    public static void InsertStatement(
        SyntaxEditor editor,
        SyntaxNode functionDeclaration,
        bool returnsVoid,
        SyntaxNode? statementToAddAfterOpt,
        StatementSyntax statement)
    {
        var body = GetBody(functionDeclaration);
 
        if (IsExpressionBody(body))
        {
            var semicolonToken = TryGetSemicolonToken(functionDeclaration) ?? SemicolonToken;
 
            if (!TryConvertExpressionBodyToStatement(body, semicolonToken, !returnsVoid, out var convertedStatement))
            {
                return;
            }
 
            // Add the new statement as the first/last statement of the new block 
            // depending if we were asked to go after something or not.
            editor.SetStatements(functionDeclaration, statementToAddAfterOpt == null
                ? [statement, convertedStatement]
                : [convertedStatement, statement]);
        }
        else if (body is BlockSyntax block)
        {
            // Look for the statement we were asked to go after.
            var indexToAddAfter = block.Statements.IndexOf(s => s == statementToAddAfterOpt);
            if (indexToAddAfter >= 0)
            {
                // If we find it, then insert the new statement after it.
                editor.InsertAfter(block.Statements[indexToAddAfter], statement);
            }
            else if (block.Statements.Count > 0)
            {
                // Otherwise, if we have multiple statements already, then insert ourselves
                // before the first one.
                editor.InsertBefore(block.Statements[0], statement);
            }
            else
            {
                // Otherwise, we have no statements in this block.  Add the new statement
                // as the single statement the block will have.
                Debug.Assert(block.Statements.Count == 0);
                editor.ReplaceNode(block, (currentBlock, _) => ((BlockSyntax)currentBlock).AddStatements(statement));
            }
 
            // If the block was on a single line before, the format it so that the formatting
            // engine will update it to go over multiple lines. Otherwise, we can end up in
            // the strange state where the { and } tokens stay where they were originally,
            // which will look very strange like:
            //
            //          a => {
            //              if (...) {
            //              } };
            if (CSharpSyntaxFacts.Instance.IsOnSingleLine(block, fullSpan: false))
            {
                editor.ReplaceNode(
                    block,
                    (currentBlock, _) => currentBlock.WithAdditionalAnnotations(Formatter.Annotation));
            }
        }
        else
        {
            editor.SetStatements(functionDeclaration, ImmutableArray.Create(statement));
        }
    }
 
    // either from an expression lambda or expression bodied member
    public static bool IsExpressionBody(SyntaxNode body)
        => body is ExpressionSyntax or ArrowExpressionClauseSyntax;
 
    public static bool TryConvertExpressionBodyToStatement(
        SyntaxNode body,
        SyntaxToken semicolonToken,
        bool createReturnStatementForExpression,
        [NotNullWhen(true)] out StatementSyntax? statement)
    {
        Debug.Assert(IsExpressionBody(body));
 
        return body switch
        {
            // If this is a => method, then we'll have to convert the method to have a block body.
            ArrowExpressionClauseSyntax arrowClause => arrowClause.TryConvertToStatement(semicolonToken, createReturnStatementForExpression, out statement),
            // must be an expression lambda
            ExpressionSyntax expression => expression.TryConvertToStatement(semicolonToken, createReturnStatementForExpression, out statement),
            _ => throw ExceptionUtilities.UnexpectedValue(body),
        };
    }
 
    public static SyntaxNode? GetAccessorBody(IMethodSymbol accessor, CancellationToken cancellationToken)
    {
        var node = accessor.DeclaringSyntaxReferences[0].GetSyntax(cancellationToken);
        if (node is AccessorDeclarationSyntax accessorDeclaration)
            return accessorDeclaration.ExpressionBody ?? (SyntaxNode?)accessorDeclaration.Body;
 
        // `int Age => ...;`
        if (node is ArrowExpressionClauseSyntax arrowExpression)
            return arrowExpression;
 
        return null;
    }
 
    public static SyntaxNode RemoveThrowNotImplemented(SyntaxNode node)
        => node is PropertyDeclarationSyntax propertyDeclaration ? RemoveThrowNotImplemented(propertyDeclaration) : node;
 
    public static PropertyDeclarationSyntax RemoveThrowNotImplemented(PropertyDeclarationSyntax propertyDeclaration)
    {
        if (propertyDeclaration.ExpressionBody != null)
        {
            var result = propertyDeclaration
                .WithExpressionBody(null)
                .WithSemicolonToken(default)
                .AddAccessorListAccessors(SyntaxFactory
                    .AccessorDeclaration(SyntaxKind.GetAccessorDeclaration)
                    .WithSemicolonToken(SemicolonToken))
                .WithTrailingTrivia(propertyDeclaration.SemicolonToken.TrailingTrivia)
                .WithAdditionalAnnotations(Formatter.Annotation);
            return result;
        }
 
        if (propertyDeclaration.AccessorList != null)
        {
            var accessors = propertyDeclaration.AccessorList.Accessors.Select(RemoveThrowNotImplemented);
            return propertyDeclaration.WithAccessorList(
                propertyDeclaration.AccessorList.WithAccessors([.. accessors]));
        }
 
        return propertyDeclaration;
    }
 
    private static AccessorDeclarationSyntax RemoveThrowNotImplemented(AccessorDeclarationSyntax accessorDeclaration)
    {
        var result = accessorDeclaration
            .WithExpressionBody(null)
            .WithBody(null)
            .WithSemicolonToken(SemicolonToken);
 
        return result.WithTrailingTrivia(accessorDeclaration.Body?.GetTrailingTrivia() ?? accessorDeclaration.SemicolonToken.TrailingTrivia);
    }
}
 |