File: src\Analyzers\CSharp\CodeFixes\UseLocalFunction\CSharpUseLocalFunctionCodeFixProvider.cs
Web Access
Project: src\src\Features\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.Features.csproj (Microsoft.CodeAnalysis.CSharp.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#nullable disable
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Composition;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.CSharp.CodeGeneration;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.Host;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.CSharp.UseLocalFunction;
 
using static CSharpSyntaxTokens;
using static SyntaxFactory;
 
[ExportCodeFixProvider(LanguageNames.CSharp, Name = PredefinedCodeFixProviderNames.UseLocalFunction), Shared]
[method: ImportingConstructor]
[method: SuppressMessage("RoslynDiagnosticsReliability", "RS0033:Importing constructor should be [Obsolete]", Justification = "Used in test code: https://github.com/dotnet/roslyn/issues/42814")]
internal sealed class CSharpUseLocalFunctionCodeFixProvider() : SyntaxEditorBasedCodeFixProvider
{
    /// <summary>
    /// Parameter type (<c>object</c>) used when a lambda parameter has no written type and no
    /// delegate parameter exists at the same position to take the type from.
    /// </summary>
    private static readonly TypeSyntax s_objectType = PredefinedType(ObjectKeyword);

    public override ImmutableArray<string> FixableDiagnosticIds
        => [IDEDiagnosticIds.UseLocalFunctionDiagnosticId];

    /// <summary>Suppressed diagnostics are not fixed as part of a fix-all operation.</summary>
    protected override bool IncludeDiagnosticDuringFixAll(Diagnostic diagnostic)
        => !diagnostic.IsSuppressed;

    public override Task RegisterCodeFixesAsync(CodeFixContext context)
    {
        RegisterCodeFix(context, CSharpAnalyzersResources.Use_local_function, nameof(CSharpAnalyzersResources.Use_local_function));
        return Task.CompletedTask;
    }

    /// <summary>
    /// Rewrites every reported delegate-typed local (initialized with an anonymous function) into a
    /// local function, and rewrites the references to that local accordingly.
    /// </summary>
    /// <remarks>
    /// The additional locations of each diagnostic are read as:
    /// [0] the local declaration statement, [1] the anonymous function, [2..] the references to the local.
    /// All of these nodes are tracked so they can be located again after earlier rewrites change the tree.
    /// </remarks>
    protected override async Task FixAllAsync(
        Document document, ImmutableArray<Diagnostic> diagnostics,
        SyntaxEditor editor, CancellationToken cancellationToken)
    {
        var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);

        var nodesFromDiagnostics = new List<(
            LocalDeclarationStatementSyntax declaration,
            AnonymousFunctionExpressionSyntax function,
            List<ExpressionSyntax> references)>(diagnostics.Length);

        var nodesToTrack = new HashSet<SyntaxNode>();

        foreach (var diagnostic in diagnostics)
        {
            var localDeclaration = (LocalDeclarationStatementSyntax)diagnostic.AdditionalLocations[0].FindNode(cancellationToken);
            var anonymousFunction = (AnonymousFunctionExpressionSyntax)diagnostic.AdditionalLocations[1].FindNode(cancellationToken);

            var references = new List<ExpressionSyntax>(diagnostic.AdditionalLocations.Count - 2);

            for (var i = 2; i < diagnostic.AdditionalLocations.Count; i++)
            {
                references.Add((ExpressionSyntax)diagnostic.AdditionalLocations[i].FindNode(getInnermostNodeForTie: true, cancellationToken));
            }

            nodesFromDiagnostics.Add((localDeclaration, anonymousFunction, references));

            nodesToTrack.Add(localDeclaration);
            nodesToTrack.Add(anonymousFunction);
            nodesToTrack.AddRange(references);
        }

        var root = editor.OriginalRoot;
        var currentRoot = root.TrackNodes(nodesToTrack);

        var languageVersion = semanticModel.SyntaxTree.Options.LanguageVersion();
        var makeStaticIfPossible = false;

        // 'static' local functions only exist from C# 8 on; below that the preference is never consulted.
        if (languageVersion >= LanguageVersion.CSharp8)
        {
            var info = await document.GetCodeGenerationInfoAsync(CodeGenerationContext.Default, cancellationToken).ConfigureAwait(false);

            var options = (CSharpCodeGenerationOptions)info.Options;
            makeStaticIfPossible = options.PreferStaticLocalFunction.Value;
        }

        // Process declarations in reverse order so that we see the effects of nested
        // declarations before processing the outer decls.
        foreach (var (localDeclaration, anonymousFunction, references) in nodesFromDiagnostics.OrderByDescending(nodes => nodes.function.SpanStart))
        {
            // Semantic queries are made against the original nodes; syntax edits use the tracked current nodes.
            var delegateType = (INamedTypeSymbol)semanticModel.GetTypeInfo(anonymousFunction, cancellationToken).ConvertedType;
            var parameterList = GenerateParameterList(anonymousFunction, delegateType.DelegateInvokeMethod);
            var makeStatic = MakeStatic(semanticModel, makeStaticIfPossible, localDeclaration, cancellationToken);

            var currentLocalDeclaration = currentRoot.GetCurrentNode(localDeclaration);
            var currentAnonymousFunction = currentRoot.GetCurrentNode(anonymousFunction);

            currentRoot = ReplaceAnonymousWithLocalFunction(
                document.Project.Solution.Services, currentRoot,
                currentLocalDeclaration, currentAnonymousFunction,
                delegateType.DelegateInvokeMethod, parameterList, makeStatic);

            // these invocations might actually be inside the local function! so we have to do this separately
            currentRoot = ReplaceReferences(
                document, currentRoot,
                delegateType, parameterList,
                [.. references.Select(node => currentRoot.GetCurrentNode(node))]);
        }

        editor.ReplaceNode(root, currentRoot);
    }

    /// <summary>
    /// Decides whether the generated local function should be marked <c>static</c>.
    /// </summary>
    /// <returns>
    /// <see langword="true"/> only when <paramref name="makeStaticIfPossible"/> is set, data-flow analysis
    /// of <paramref name="localDeclaration"/> succeeds, and the only captured variable it reports (if any)
    /// is the declared local itself.
    /// </returns>
    private static bool MakeStatic(
        SemanticModel semanticModel,
        bool makeStaticIfPossible,
        LocalDeclarationStatementSyntax localDeclaration,
        CancellationToken cancellationToken)
    {
        // Determines if we can make the local function 'static'.  We can make it static
        // if the original lambda did not capture any variables (other than the local 
        // variable itself).  it's ok for the lambda to capture itself as a static-local
        // function can reference itself without any problems.
        if (makeStaticIfPossible)
        {
            var localSymbol = semanticModel.GetDeclaredSymbol(
                localDeclaration.Declaration.Variables[0], cancellationToken);

            var dataFlow = semanticModel.AnalyzeDataFlow(localDeclaration);
            if (dataFlow.Succeeded)
            {
                var capturedVariables = dataFlow.Captured.Remove(localSymbol);
                if (capturedVariables.IsEmpty)
                {
                    return true;
                }
            }
        }

        return false;
    }

    /// <summary>
    /// Replaces <paramref name="localDeclaration"/> with an equivalent local function statement.
    /// </summary>
    /// <remarks>
    /// The new statement keeps the declaration's trivia and carries the formatter annotation. When the
    /// anonymous function was assigned in a separate statement (the split <c>T f = null; f = () =&gt; ...;</c>
    /// form), that second statement is removed.
    /// </remarks>
    private static SyntaxNode ReplaceAnonymousWithLocalFunction(
        SolutionServices services, SyntaxNode currentRoot,
        LocalDeclarationStatementSyntax localDeclaration, AnonymousFunctionExpressionSyntax anonymousFunction,
        IMethodSymbol delegateMethod, ParameterListSyntax parameterList, bool makeStatic)
    {
        var newLocalFunctionStatement = CreateLocalFunctionStatement(localDeclaration, anonymousFunction, delegateMethod, parameterList, makeStatic)
            .WithTriviaFrom(localDeclaration)
            .WithAdditionalAnnotations(Formatter.Annotation);

        var editor = new SyntaxEditor(currentRoot, services);
        editor.ReplaceNode(localDeclaration, newLocalFunctionStatement);

        var anonymousFunctionStatement = anonymousFunction.GetAncestor<StatementSyntax>();
        if (anonymousFunctionStatement != localDeclaration)
        {
            // This is the split decl+init form.  Remove the second statement as we're
            // merging into the first one.
            editor.RemoveNode(anonymousFunctionStatement);
        }

        return editor.GetChangedRoot();
    }

    /// <summary>
    /// Rewrites the references to the former delegate local.
    /// </summary>
    /// <remarks>
    /// <list type="bullet">
    /// <item>Invocations drop a member-access receiver (the <c>.Invoke</c> form). Named arguments are
    /// renamed to match the generated parameter list.</item>
    /// <item>Every other reference is wrapped in a cast to the delegate type.</item>
    /// </list>
    /// </remarks>
    private static SyntaxNode ReplaceReferences(
        Document document, SyntaxNode currentRoot,
        INamedTypeSymbol delegateType, ParameterListSyntax parameterList,
        ImmutableArray<ExpressionSyntax> references)
    {
        // The generator only depends on the document, so obtain it once rather than once per reference.
        var generator = SyntaxGenerator.GetGenerator(document);

        return currentRoot.ReplaceNodes(references, (_ /* nested invocations! */, reference) =>
        {
            if (reference is InvocationExpressionSyntax invocation)
            {
                var directInvocation = invocation.Expression is MemberAccessExpressionSyntax memberAccess // it's a .Invoke call
                    ? invocation.WithExpression(memberAccess.Expression).WithTriviaFrom(invocation) // remove it
                    : invocation;

                return WithNewParameterNames(directInvocation, delegateType.DelegateInvokeMethod, parameterList);
            }

            // It's not an invocation. Wrap the identifier in a cast (which will be removed by the simplifier if unnecessary)
            // to ensure we preserve semantics in cases like overload resolution or generic type inference.
            return generator.CastExpression(delegateType, reference);
        });
    }

    /// <summary>
    /// Builds the local function statement from the declared name and the anonymous function's body.
    /// </summary>
    /// <remarks>
    /// The return type comes from the delegate's Invoke method. The function is marked
    /// <c>static</c> when <paramref name="makeStatic"/> is set and <c>async</c> when the anonymous function
    /// was async. A block body is reused as-is. An expression body becomes an arrow clause terminated by
    /// the declaration's semicolon.
    /// </remarks>
    private static LocalFunctionStatementSyntax CreateLocalFunctionStatement(
        LocalDeclarationStatementSyntax localDeclaration,
        AnonymousFunctionExpressionSyntax anonymousFunction,
        IMethodSymbol delegateMethod,
        ParameterListSyntax parameterList,
        bool makeStatic)
    {
        var modifiers = new SyntaxTokenList();
        if (makeStatic)
        {
            modifiers = modifiers.Add(StaticKeyword);
        }

        if (anonymousFunction.AsyncKeyword.IsKind(SyntaxKind.AsyncKeyword))
        {
            modifiers = modifiers.Add(anonymousFunction.AsyncKeyword);
        }

        var returnType = delegateMethod.GenerateReturnTypeSyntax();

        var identifier = localDeclaration.Declaration.Variables[0].Identifier;
        var typeParameterList = (TypeParameterListSyntax)null;

        var constraintClauses = default(SyntaxList<TypeParameterConstraintClauseSyntax>);

        var body = anonymousFunction.Body is BlockSyntax block
            ? block
            : null;

        // Only lambdas can have expression bodies, so the cast to LambdaExpressionSyntax is safe here.
        var expressionBody = anonymousFunction.Body is ExpressionSyntax expression
            ? ArrowExpressionClause(((LambdaExpressionSyntax)anonymousFunction).ArrowToken, expression)
            : null;

        var semicolonToken = anonymousFunction.Body is ExpressionSyntax
            ? localDeclaration.SemicolonToken
            : default;

        return LocalFunctionStatement(
            modifiers, returnType, identifier, typeParameterList, parameterList,
            constraintClauses, body, expressionBody, semicolonToken);
    }

    /// <summary>
    /// Produces an explicitly-typed parameter list for the local function.
    /// </summary>
    /// <remarks>
    /// <list type="bullet">
    /// <item>Written parameters are kept. Each one is paired with the delegate parameter at the same
    /// position, which supplies its type (if it has none) and any explicit default value.</item>
    /// <item>When the anonymous function has no parameter list (<c>delegate { }</c>), one is synthesized
    /// from the delegate's parameters, including their ref-kind.</item>
    /// </list>
    /// </remarks>
    private static ParameterListSyntax GenerateParameterList(
        AnonymousFunctionExpressionSyntax anonymousFunction, IMethodSymbol delegateMethod)
    {
        var parameterList = TryGetOrCreateParameterList(anonymousFunction);

        if (parameterList == null)
        {
            // `delegate { }` form. Preserve ref/out/in so call sites such as `d(ref x)` still bind to the
            // generated local function.
            return ParameterList([.. delegateMethod.Parameters.Select(parameter =>
                PromoteParameter(
                    Parameter(parameter.Name.ToIdentifierToken()).WithModifiers(GenerateRefKindModifiers(parameter.RefKind)),
                    parameter))]);
        }

        // Look up each parameter's position explicitly instead of incrementing a counter inside the
        // callback. This way the pairing does not depend on the order in which ReplaceNodes invokes it.
        var writtenParameters = parameterList.Parameters;
        return parameterList.ReplaceNodes(
            writtenParameters,
            (parameterNode, _) => PromoteParameter(
                parameterNode,
                delegateMethod.Parameters.ElementAtOrDefault(writtenParameters.IndexOf(parameterNode))));

        static ParameterSyntax PromoteParameter(ParameterSyntax parameterNode, IParameterSymbol delegateParameter)
        {
            // delegateParameter may be null, consider this case: Action x = (a, b) => { };
            // we will still fall back to object

            if (parameterNode.Type == null)
            {
                parameterNode = parameterNode.WithType(delegateParameter?.Type.GenerateTypeSyntax() ?? s_objectType);
            }

            if (delegateParameter?.HasExplicitDefaultValue == true)
            {
                parameterNode = parameterNode.WithDefault(GetDefaultValue(delegateParameter));
            }

            return parameterNode;
        }

        // Maps a delegate parameter's ref-kind onto the modifier tokens a local-function parameter needs.
        static SyntaxTokenList GenerateRefKindModifiers(RefKind refKind)
            => refKind switch
            {
                RefKind.Ref => TokenList(Token(SyntaxKind.RefKeyword)),
                RefKind.Out => TokenList(Token(SyntaxKind.OutKeyword)),
                RefKind.In => TokenList(Token(SyntaxKind.InKeyword)),
                _ => default,
            };
    }

    /// <summary>
    /// Returns the parameter list written on the anonymous function. A simple lambda's single parameter
    /// is wrapped in a new list. The result is <see langword="null"/> for an anonymous method declared
    /// without a parameter list.
    /// </summary>
    private static ParameterListSyntax TryGetOrCreateParameterList(AnonymousFunctionExpressionSyntax anonymousFunction)
    {
        switch (anonymousFunction)
        {
            case SimpleLambdaExpressionSyntax simpleLambda:
                return ParameterList([simpleLambda.Parameter]);
            case ParenthesizedLambdaExpressionSyntax parenthesizedLambda:
                return parenthesizedLambda.ParameterList;
            case AnonymousMethodExpressionSyntax anonymousMethod:
                return anonymousMethod.ParameterList; // may be null!
            default:
                throw ExceptionUtilities.UnexpectedValue(anonymousFunction);
        }
    }

    /// <summary>
    /// Renames named arguments (<c>name: value</c>) of <paramref name="invocation"/> so they use the
    /// parameter names of the generated local function.
    /// </summary>
    /// <remarks>
    /// Arguments are matched to <paramref name="method"/>'s parameters by name, then to the parameter at
    /// the same index in <paramref name="newParameterList"/>. An argument is left unchanged when it is
    /// positional, when no match is found, or when the new parameter's identifier is missing.
    /// </remarks>
    private static InvocationExpressionSyntax WithNewParameterNames(InvocationExpressionSyntax invocation, IMethodSymbol method, ParameterListSyntax newParameterList)
    {
        return invocation.ReplaceNodes(invocation.ArgumentList.Arguments, (argumentNode, _) =>
        {
            if (argumentNode.NameColon == null)
            {
                return argumentNode;
            }

            var parameterIndex = TryDetermineParameterIndex(argumentNode.NameColon, method);
            if (parameterIndex == -1)
            {
                return argumentNode;
            }

            var newParameter = newParameterList.Parameters.ElementAtOrDefault(parameterIndex);
            if (newParameter == null || newParameter.Identifier.IsMissing)
            {
                return argumentNode;
            }

            return argumentNode.WithNameColon(argumentNode.NameColon.WithName(IdentifierName(newParameter.Identifier)));
        });
    }

    /// <summary>
    /// Returns the index of the parameter of <paramref name="method"/> whose name matches the argument
    /// name, or -1 if none matches.
    /// </summary>
    private static int TryDetermineParameterIndex(NameColonSyntax argumentNameColon, IMethodSymbol method)
    {
        var name = argumentNameColon.Name.Identifier.ValueText;
        return method.Parameters.IndexOf(p => p.Name == name);
    }

    /// <summary>Builds an <c>= value</c> clause from the parameter's explicit default value.</summary>
    private static EqualsValueClauseSyntax GetDefaultValue(IParameterSymbol parameter)
        => EqualsValueClause(ExpressionGenerator.GenerateExpression(parameter.Type, parameter.ExplicitDefaultValue, canUseFieldReference: true));
}