// File: CodeRefactorings\ConvertLocalFunctionToMethod\CSharpConvertLocalFunctionToMethodCodeRefactoringProvider.cs
// Web Access
// Project: src\src\Features\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.Features.csproj (Microsoft.CodeAnalysis.CSharp.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Composition;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.CSharp.CodeGeneration;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Shared.Utilities;
using Microsoft.CodeAnalysis.Simplification;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.CSharp.CodeRefactorings.ConvertLocalFunctionToMethod;
 
[ExportCodeRefactoringProvider(LanguageNames.CSharp, Name = PredefinedCodeRefactoringProviderNames.ConvertLocalFunctionToMethod), Shared]
[method: ImportingConstructor]
[method: SuppressMessage("RoslynDiagnosticsReliability", "RS0033:Importing constructor should be [Obsolete]", Justification = "Used in test code: https://github.com/dotnet/roslyn/issues/42814")]
internal sealed class CSharpConvertLocalFunctionToMethodCodeRefactoringProvider() : CodeRefactoringProvider
{
    /// <summary>
    /// Marks references to the local function that are not invocations (delegate / method-group usages) and whose
    /// target signature changes, so they can be found again after the first round of edits and wrapped in a lambda.
    /// </summary>
    private static readonly SyntaxAnnotation s_delegateToReplaceAnnotation = new();

    /// <summary>
    /// Registers the "Convert to method" refactoring for the local function at the current location, provided it is
    /// declared directly inside a block of a member that the generated method can be inserted next to.
    /// </summary>
    public override async Task ComputeRefactoringsAsync(CodeRefactoringContext context)
    {
        var (document, _, cancellationToken) = context;
        if (document.Project.Solution.WorkspaceKind == WorkspaceKind.MiscellaneousFiles)
            return;

        var localFunction = await context.TryGetRelevantNodeAsync<LocalFunctionStatementSyntax>().ConfigureAwait(false);
        if (localFunction == null)
            return;

        if (localFunction.Parent is not BlockSyntax parentBlock)
            return;

        var container = localFunction.GetAncestor<MemberDeclarationSyntax>();

        // If the local function is defined in a block within top-level statements, we can't provide the refactoring
        // because there is no class to put the generated method in. It is likewise not offered when the nearest
        // member is a field or field-like event declaration.
        if (container is null or GlobalStatementSyntax or FieldDeclarationSyntax or EventFieldDeclarationSyntax)
            return;

        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
        var containerSymbol = semanticModel.GetDeclaredSymbol(container, cancellationToken);
        if (containerSymbol is null)
            return;

        context.RegisterRefactoring(
            CodeAction.Create(
                CSharpFeaturesResources.Convert_to_method,
                c => UpdateDocumentAsync(document, parentBlock, localFunction, container, c),
                nameof(CSharpFeaturesResources.Convert_to_method)),
            localFunction.Span);
    }

    /// <summary>
    /// Moves <paramref name="localFunction"/> out of <paramref name="parentBlock"/> and inserts it as a private method
    /// directly after <paramref name="container"/>.
    /// </summary>
    /// <remarks>
    /// Captured variables become extra parameters (<c>ref</c> when written inside the function). Enclosing type
    /// parameters that are still referenced become extra type parameters. References inside
    /// <paramref name="parentBlock"/> are renamed and given the extra (type) arguments. Non-invocation references
    /// whose signature changed are wrapped in a lambda in a second pass.
    /// </remarks>
    /// <returns>The updated document.</returns>
    private static async Task<Document> UpdateDocumentAsync(
        Document document,
        BlockSyntax parentBlock,
        LocalFunctionStatementSyntax localFunction,
        MemberDeclarationSyntax container,
        CancellationToken cancellationToken)
    {
        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
        var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
        var declaredSymbol = semanticModel.GetRequiredDeclaredSymbol(localFunction, cancellationToken);

        Contract.ThrowIfTrue(localFunction.Body is null && localFunction.ExpressionBody is null);

        var dataFlow = semanticModel.AnalyzeDataFlow(
            localFunction.Body ?? (SyntaxNode)localFunction.ExpressionBody!.Expression);

        // Exclude local function parameters in case they were captured inside the function body
        var captures = dataFlow.CapturedInside.Except(dataFlow.VariablesDeclared).Except(declaredSymbol.Parameters).ToList();

        // First, create a parameter per capture so that we can pass them as arguments to the final method.
        // `this` is filtered out because it doesn't need a parameter; we just make a non-static method for that.
        // A `ref` parameter is created for each capture that is written to inside the function.
        var capturesAsParameters = captures
            .Where(capture => !capture.IsThisParameter())
            .Select(capture => CodeGenerationSymbolFactory.CreateParameterSymbol(
                attributes: default,
                refKind: dataFlow.WrittenInside.Contains(capture) ? RefKind.Ref : RefKind.None,
                isParams: false,
                type: capture.GetSymbolType() ?? semanticModel.Compilation.ObjectType,
                name: capture.Name)).ToList();

        // Find all enclosing type parameters e.g. from outer local functions and the containing member.
        // We exclude the containing type itself, whose type parameters are accessible to all members.
        var typeParameters = new List<ITypeParameterSymbol>();
        AddCapturedTypeParameters(declaredSymbol, typeParameters);

        // We're going to remove unreferenced type parameters, but we explicitly preserve the captures' types
        // in case they were not spelled out in the function body.
        var captureTypes = captures.SelectMany(capture => capture.GetSymbolType().GetReferencedTypeParameters());
        RemoveUnusedTypeParameters(localFunction, semanticModel, typeParameters, reservedTypeParameters: captureTypes);

        var containerSymbol = semanticModel.GetRequiredDeclaredSymbol(container, cancellationToken);
        var isStatic = containerSymbol.IsStatic || captures.All(capture => !capture.IsThisParameter());

        // GetSymbolModifiers reports whether the local function *needs* to be unsafe, not whether it was declared
        // unsafe, so we don't need to worry about whether the containing method is unsafe; this works either way.
        var needsUnsafe = declaredSymbol.GetSymbolModifiers().IsUnsafe;

        var methodName = GenerateUniqueMethodName(declaredSymbol);
        var parameters = declaredSymbol.Parameters;
        var methodSymbol = CodeGenerationSymbolFactory.CreateMethodSymbol(
            containingType: declaredSymbol.ContainingType,
            attributes: default,
            accessibility: Accessibility.Private,
            modifiers: new DeclarationModifiers(isStatic, isAsync: declaredSymbol.IsAsync, isUnsafe: needsUnsafe),
            returnType: declaredSymbol.ReturnType,
            refKind: declaredSymbol.RefKind,
            explicitInterfaceImplementations: default,
            name: methodName,
            typeParameters: [.. typeParameters],
            parameters: parameters.AddRange(capturesAsParameters));

        var info = (CSharpCodeGenerationContextInfo)await document.GetCodeGenerationInfoAsync(CodeGenerationContext.Default, cancellationToken).ConfigureAwait(false);
        var method = MethodGenerator.GenerateMethodDeclaration(methodSymbol, CodeGenerationDestination.Unspecified, info, cancellationToken);

        if (localFunction.AttributeLists.Count > 0)
            method = method.WithoutLeadingTrivia().WithAttributeLists(localFunction.AttributeLists).WithLeadingTrivia(method.GetLeadingTrivia());

        var generator = CSharpSyntaxGenerator.Instance;
        var editor = new SyntaxEditor(root, generator);

        var needsRename = methodName != declaredSymbol.Name;
        var identifierToken = needsRename ? methodName.ToIdentifierToken() : default;
        var supportsNonTrailing = SupportsNonTrailingNamedArguments(root.SyntaxTree.Options);
        var hasAdditionalArguments = capturesAsParameters.Count > 0;
        var additionalTypeParameters = typeParameters.Except(declaredSymbol.TypeParameters).ToList();
        var hasAdditionalTypeArguments = additionalTypeParameters.Count > 0;
        var additionalTypeArguments = hasAdditionalTypeArguments
            ? additionalTypeParameters.Select(p => (TypeSyntax)p.Name.ToIdentifierName()).ToArray()
            : null;

        var anyDelegatesToReplace = false;

        // Update callers' name, arguments and type arguments
        foreach (var node in parentBlock.DescendantNodes())
        {
            // A local function reference can only be an identifier or a generic name.
            if (node is not (IdentifierNameSyntax or GenericNameSyntax))
                continue;

            // Using the symbol to get type arguments, since they could be inferred and not present in the source
            if (semanticModel.GetSymbolInfo(node, cancellationToken).Symbol is not IMethodSymbol symbol ||
                !Equals(symbol.OriginalDefinition, declaredSymbol))
            {
                continue;
            }

            var currentNode = node;

            if (needsRename)
            {
                currentNode = ((SimpleNameSyntax)currentNode).WithIdentifier(identifierToken);
            }

            if (hasAdditionalTypeArguments)
            {
                var existingTypeArguments = symbol.TypeArguments.Select(s => s.GenerateTypeSyntax());
                // Prepend additional type arguments to preserve lexical order in which they are defined
                Contract.ThrowIfNull(additionalTypeArguments);
                var typeArguments = additionalTypeArguments.Concat(existingTypeArguments);
                currentNode = generator.WithTypeArguments(currentNode, typeArguments);
                currentNode = currentNode.WithAdditionalAnnotations(Simplifier.Annotation);
            }

            if (node.Parent is InvocationExpressionSyntax invocation)
            {
                if (hasAdditionalArguments)
                {
                    var shouldUseNamedArguments =
                        !supportsNonTrailing && invocation.ArgumentList.Arguments.Any(arg => arg.NameColon != null);

                    var additionalArguments = capturesAsParameters.Select(p =>
                        (ArgumentSyntax)GenerateArgument(p, p.Name, shouldUseNamedArguments)).ToArray();

                    // Compute the replacement from the *current* argument list. A replacement built from the
                    // original node would lack the editor's tracking annotations. A later edit to a nested reference
                    // inside these arguments (e.g. `M(M(x))`) could then not be located when the edits are applied.
                    editor.ReplaceNode(
                        invocation.ArgumentList,
                        (currentArgumentList, _) => ((ArgumentListSyntax)currentArgumentList).AddArguments(additionalArguments));
                }
            }
            else if (hasAdditionalArguments || hasAdditionalTypeArguments)
            {
                // Convert local function delegates to lambda if the signature no longer matches
                currentNode = currentNode.WithAdditionalAnnotations(s_delegateToReplaceAnnotation);
                anyDelegatesToReplace = true;
            }

            editor.ReplaceNode(node, currentNode);
        }

        editor.TrackNode(localFunction);
        editor.TrackNode(container);

        root = editor.GetChangedRoot();

        localFunction = root.GetCurrentNode(localFunction) ?? throw ExceptionUtilities.Unreachable();
        container = root.GetCurrentNode(container) ?? throw ExceptionUtilities.Unreachable();

        // The body is taken from the already-updated local function so recursive references are renamed too.
        method = WithBodyFrom(method, localFunction);

        editor = new SyntaxEditor(root, generator);
        editor.InsertAfter(container, method);
        editor.RemoveNode(localFunction, SyntaxRemoveOptions.KeepNoTrivia);

        if (anyDelegatesToReplace)
        {
            // Delegate references need semantic information about the rewritten tree (to pick unique lambda
            // parameter names), so fork the document and re-bind before wrapping them in lambdas.
            document = document.WithSyntaxRoot(editor.GetChangedRoot());
            semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
            root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
            editor = new SyntaxEditor(root, generator);

            foreach (var node in root.GetAnnotatedNodes(s_delegateToReplaceAnnotation))
            {
                var reservedNames = GetReservedNames(node, semanticModel, cancellationToken);
                var parameterNames = GenerateUniqueParameterNames(parameters, reservedNames);
                var lambdaParameters = parameters.Zip(parameterNames, (p, name) => GenerateParameter(p, name));
                var lambdaArguments = parameters.Zip(parameterNames, (p, name) => GenerateArgument(p, name));
                var additionalArguments = capturesAsParameters.Select(p => GenerateArgument(p, p.Name));
                var newNode = generator.ValueReturningLambdaExpression(lambdaParameters,
                    generator.InvocationExpression(node, lambdaArguments.Concat(additionalArguments)));

                newNode = newNode.WithAdditionalAnnotations(Simplifier.Annotation);

                if (node.IsParentKind(SyntaxKind.CastExpression))
                {
                    newNode = ((ExpressionSyntax)newNode).Parenthesize();
                }

                editor.ReplaceNode(node, newNode);
            }
        }

        return document.WithSyntaxRoot(editor.GetChangedRoot());
    }

    /// <summary>
    /// Returns <see langword="true"/> when the parse options' language version is C# 7.2 or later, in which named
    /// arguments may be followed by positional ones. On older versions, arguments appended after a named argument
    /// must themselves be named.
    /// </summary>
    private static bool SupportsNonTrailingNamedArguments(ParseOptions options)
        => options.LanguageVersion() >= LanguageVersion.CSharp7_2;

    /// <summary>
    /// Creates an argument that passes the variable <paramref name="name"/> with the ref kind of <paramref name="p"/>.
    /// When <paramref name="shouldUseNamedArguments"/> is set, the argument is also named <paramref name="name"/>.
    /// </summary>
    private static SyntaxNode GenerateArgument(IParameterSymbol p, string name, bool shouldUseNamedArguments = false)
        => CSharpSyntaxGenerator.Instance.Argument(shouldUseNamedArguments ? name : null, p.RefKind, name.ToIdentifierName());

    /// <summary>
    /// Picks one lambda parameter name per entry of <paramref name="parameters"/>, based on its original name.
    /// No chosen name collides with <paramref name="reservedNames"/> or with a name chosen earlier in the same call.
    /// </summary>
    /// <param name="parameters">The original parameters, in order.</param>
    /// <param name="reservedNames">Names already in use; this list is not modified.</param>
    /// <returns>The chosen names, in the same order as <paramref name="parameters"/>.</returns>
    private static List<string> GenerateUniqueParameterNames(ImmutableArray<IParameterSymbol> parameters, List<string> reservedNames)
    {
        // Reserve every name as it is picked. Otherwise a renamed parameter (e.g. `x` -> `x1`) could clash with
        // another parameter that is already called `x1`, producing a lambda with duplicate parameter names.
        var usedNames = new List<string>(reservedNames);
        var result = new List<string>(parameters.Length);
        foreach (var parameter in parameters)
        {
            var name = NameGenerator.EnsureUniqueness(parameter.Name, usedNames);
            usedNames.Add(name);
            result.Add(name);
        }

        return result;
    }

    /// <summary>
    /// Returns the names of all symbols declared within the member that contains <paramref name="node"/>.
    /// </summary>
    private static List<string> GetReservedNames(SyntaxNode node, SemanticModel semanticModel, CancellationToken cancellationToken)
        => semanticModel.GetAllDeclaredSymbols(node.GetAncestor<MemberDeclarationSyntax>(), cancellationToken).Select(s => s.Name).ToList();

    /// <summary>
    /// Creates a lambda parameter named <paramref name="name"/> that has the modifiers and type of <paramref name="parameter"/>.
    /// </summary>
    private static ParameterSyntax GenerateParameter(IParameterSymbol parameter, string name)
    {
        return SyntaxFactory.Parameter(name.ToIdentifierToken())
            .WithModifiers(CSharpSyntaxGeneratorInternal.GetParameterModifiers(parameter))
            .WithType(parameter.Type.GenerateTypeSyntax());
    }

    /// <summary>
    /// Copies the block body, expression body and semicolon token of <paramref name="localFunction"/> onto <paramref name="method"/>.
    /// </summary>
    private static MethodDeclarationSyntax WithBodyFrom(
        MethodDeclarationSyntax method, LocalFunctionStatementSyntax localFunction)
    {
        return method
            .WithExpressionBody(localFunction.ExpressionBody)
            .WithSemicolonToken(localFunction.SemicolonToken)
            .WithBody(localFunction.Body);
    }

    /// <summary>
    /// Appends to <paramref name="typeParameters"/> the type parameters of <paramref name="symbol"/> and of every
    /// enclosing symbol up to (but excluding) the first containing named type, outermost first.
    /// </summary>
    private static void AddCapturedTypeParameters(ISymbol symbol, List<ITypeParameterSymbol> typeParameters)
    {
        var containingSymbol = symbol.ContainingSymbol;
        if (containingSymbol != null &&
            containingSymbol.Kind != SymbolKind.NamedType)
        {
            // Recurse first so outer type parameters precede inner ones.
            AddCapturedTypeParameters(containingSymbol, typeParameters);
        }

        typeParameters.AddRange(symbol.GetTypeParameters());
    }

    /// <summary>
    /// Removes from <paramref name="typeParameters"/> every type parameter that no identifier inside
    /// <paramref name="localFunction"/> binds to, except those listed in <paramref name="reservedTypeParameters"/>.
    /// </summary>
    private static void RemoveUnusedTypeParameters(
        SyntaxNode localFunction,
        SemanticModel semanticModel,
        List<ITypeParameterSymbol> typeParameters,
        IEnumerable<ITypeParameterSymbol> reservedTypeParameters)
    {
        var unusedTypeParameters = typeParameters.ToList();
        foreach (var id in localFunction.DescendantNodes().OfType<IdentifierNameSyntax>())
        {
            var symbol = semanticModel.GetSymbolInfo(id).Symbol;
            if (symbol != null && symbol.OriginalDefinition is ITypeParameterSymbol typeParameter)
            {
                unusedTypeParameters.Remove(typeParameter);
            }
        }

        typeParameters.RemoveRange(unusedTypeParameters.Except(reservedTypeParameters));
    }

    /// <summary>
    /// Returns the local function's name, made unique against the names of all members of its containing type.
    /// </summary>
    private static string GenerateUniqueMethodName(ISymbol declaredSymbol)
    {
        return NameGenerator.EnsureUniqueness(
            baseName: declaredSymbol.Name,
            reservedNames: declaredSymbol.ContainingType.GetMembers().Select(m => m.Name));
    }
}