File: ExtractMethod\MethodExtractor.CodeGenerator.cs
Web Access
Project: src\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#nullable disable
 
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.ExtractMethod;
 
internal abstract partial class MethodExtractor<TSelectionResult, TStatementSyntax, TExpressionSyntax>
{
    /// <summary>
    /// Non-generic base for the code generators used by extract method.
    /// </summary>
    protected abstract class CodeGenerator
    {
        /// <summary>
        /// Used to produce the set of statements that will go into the generated method.
        /// </summary>
        public abstract OperationStatus<ImmutableArray<SyntaxNode>> GetNewMethodStatements(
            SyntaxNode insertionPointNode, CancellationToken cancellationToken);
    }

#pragma warning disable CS0693 // Intentionally hiding the outer TStatementSyntax
    protected abstract partial class CodeGenerator<TStatementSyntax, TNodeUnderContainer, TCodeGenerationOptions> : CodeGenerator
#pragma warning restore CS0693
        where TStatementSyntax : SyntaxNode
        where TNodeUnderContainer : SyntaxNode
        where TCodeGenerationOptions : CodeGenerationOptions
    {
        protected readonly SyntaxAnnotation MethodNameAnnotation;
        protected readonly SyntaxAnnotation MethodDefinitionAnnotation;
        protected readonly SyntaxAnnotation CallSiteAnnotation;

        protected readonly TSelectionResult SelectionResult;
        protected readonly AnalyzerResult AnalyzerResult;

        protected readonly TCodeGenerationOptions Options;

        /// <summary>
        /// When <see langword="true"/>, the extracted code is emitted as a local function rather than a normal method.
        /// </summary>
        protected readonly bool LocalFunction;

        protected CodeGenerator(TSelectionResult selectionResult, AnalyzerResult analyzerResult, TCodeGenerationOptions options, bool localFunction)
        {
            SelectionResult = selectionResult;
            AnalyzerResult = analyzerResult;

            Options = options;
            LocalFunction = localFunction;

            // Fresh annotations per generator so the generated method name, definition and call site can be
            // located again in the final document (they are handed to GeneratedCode in CreateGeneratedCodeAsync).
            MethodNameAnnotation = new SyntaxAnnotation();
            CallSiteAnnotation = new SyntaxAnnotation();
            MethodDefinitionAnnotation = new SyntaxAnnotation();
        }

        protected SemanticDocument SemanticDocument => SelectionResult.SemanticDocument;

        #region method to be implemented in sub classes

        /// <summary>
        /// Returns the call-site container derived from the outermost variable being moved into the new method,
        /// or <see langword="null"/> to fall back to the selection's own outermost call-site container.
        /// </summary>
        protected abstract SyntaxNode GetCallSiteContainerFromOutermostMoveInVariable(CancellationToken cancellationToken);

        /// <summary>
        /// Produces the replacement for <paramref name="outermostCallSiteContainer"/> with the extracted code replaced by a call.
        /// </summary>
        protected abstract Task<SyntaxNode> GenerateBodyForCallSiteContainerAsync(SyntaxNode insertionPointNode, SyntaxNode outermostCallSiteContainer, CancellationToken cancellationToken);

        /// <summary>
        /// Builds the symbol for the method (or local function) that will be inserted into the document.
        /// </summary>
        protected abstract IMethodSymbol GenerateMethodDefinition(SyntaxNode insertionPointNode, CancellationToken cancellationToken);

        protected abstract bool ShouldLocalFunctionCaptureParameter(SyntaxNode node);

        protected abstract SyntaxToken CreateIdentifier(string name);
        protected abstract SyntaxToken CreateMethodName();
        protected abstract bool LastStatementOrHasReturnStatementInReturnableConstruct();

        protected abstract TNodeUnderContainer GetFirstStatementOrInitializerSelectedAtCallSite();
        protected abstract TNodeUnderContainer GetLastStatementOrInitializerSelectedAtCallSite();
        protected abstract Task<TNodeUnderContainer> GetStatementOrInitializerContainingInvocationToExtractedMethodAsync(CancellationToken cancellationToken);

        protected abstract TExpressionSyntax CreateCallSignature();
        protected abstract TStatementSyntax CreateDeclarationStatement(VariableInfo variable, TExpressionSyntax initialValue, CancellationToken cancellationToken);
        protected abstract TStatementSyntax CreateAssignmentExpressionStatement(SyntaxToken identifier, TExpressionSyntax rvalue);
        protected abstract TStatementSyntax CreateReturnStatement(string identifierName = null);

        protected abstract ImmutableArray<TStatementSyntax> GetInitialStatementsForMethodDefinitions();

        /// <summary>
        /// Post-processes the document once the method has been inserted (see the comment in
        /// <see cref="GenerateAsync"/> about nullable reference type adjustments).
        /// </summary>
        protected abstract Task<SemanticDocument> UpdateMethodAfterGenerationAsync(
            SemanticDocument originalDocument, IMethodSymbol methodSymbolResult, CancellationToken cancellationToken);

        #endregion

        /// <summary>
        /// Generates the new method, rewrites the call site, runs post-generation fix-ups, and packages the
        /// result together with the annotations used to locate the generated pieces.
        /// </summary>
        public async Task<GeneratedCode> GenerateAsync(InsertionPoint insertionPoint, CancellationToken cancellationToken)
        {
            var newMethodDefinition = GenerateMethodDefinition(insertionPoint.GetContext(), cancellationToken);
            var callSiteDocument = await InsertMethodAndUpdateCallSiteAsync(insertionPoint, newMethodDefinition, cancellationToken).ConfigureAwait(false);

            // For nullable reference types, we can provide a better experience by reducing use of nullable
            // reference types after a method is done being generated. If we can determine that the method never
            // returns null, for example, then we can make the signature into a non-null reference type even though
            // the original type was nullable. This allows our code generation to follow our recommendation of only
            // using nullable when necessary. This is done after method generation instead of at analyzer time
            // because it's purely based on the resulting code, which the generator can modify as needed. If return
            // statements are added, the flow analysis could change to indicate something different. It's cleaner to
            // rely on flow analysis of the final resulting code than to try and predict from the analyzer what will
            // happen in the generator.
            var finalDocument = await UpdateMethodAfterGenerationAsync(callSiteDocument, newMethodDefinition, cancellationToken).ConfigureAwait(false);

            return await CreateGeneratedCodeAsync(finalDocument, cancellationToken).ConfigureAwait(false);
        }

        /// <summary>
        /// Replaces the outermost call-site container with its rewritten form, then inserts the new method
        /// (or local function) into that updated document.
        /// </summary>
        private async Task<SemanticDocument> InsertMethodAndUpdateCallSiteAsync(
            InsertionPoint insertionPoint, IMethodSymbol newMethodDefinition, CancellationToken cancellationToken)
        {
            var document = this.SemanticDocument.Document;
            var codeGenerationService = document.GetLanguageService<ICodeGenerationService>();

            // First, update the callsite with the call to the new method.
            var outermostCallSiteContainer = GetOutermostCallSiteContainerToProcess(cancellationToken);

            var rootWithUpdatedCallSite = this.SemanticDocument.Root.ReplaceNode(
                outermostCallSiteContainer,
                await GenerateBodyForCallSiteContainerAsync(
                    insertionPoint.GetContext(), outermostCallSiteContainer, cancellationToken).ConfigureAwait(false));

            // Then insert the local-function/method into the updated document that contains the updated callsite.
            var documentWithUpdatedCallSite = await this.SemanticDocument.WithSyntaxRootAsync(rootWithUpdatedCallSite, cancellationToken).ConfigureAwait(false);
            var finalRoot = LocalFunction
                ? InsertLocalFunction()
                : InsertNormalMethod();

            return await documentWithUpdatedCallSite.WithSyntaxRootAsync(finalRoot, cancellationToken).ConfigureAwait(false);

            SyntaxNode InsertLocalFunction()
            {
                // Local functions are emitted without an accessibility modifier.
                var info = codeGenerationService.GetInfo(
                    new CodeGenerationContext(generateDefaultAccessibility: false),
                    Options,
                    documentWithUpdatedCallSite.Project.ParseOptions);

                var localMethod = codeGenerationService.CreateMethodDeclaration(newMethodDefinition, CodeGenerationDestination.Unspecified, info, cancellationToken);

                // Find the destination for the local function after the callsite has been fixed up.
                var destination = insertionPoint.With(documentWithUpdatedCallSite).GetContext();
                var updatedDestination = codeGenerationService.AddStatements(destination, [localMethod], info, cancellationToken);

                var rootWithLocalFunction = documentWithUpdatedCallSite.Root.ReplaceNode(destination, updatedDestination);
                return rootWithLocalFunction;
            }

            SyntaxNode InsertNormalMethod()
            {
                var syntaxKinds = document.GetLanguageService<ISyntaxKindsService>();

                // Find the destination for the new method after the callsite has been fixed up.
                var mappedMember = insertionPoint.With(documentWithUpdatedCallSite).GetContext();

                // If the member is wrapped in a global statement, anchor on the wrapper instead.
                mappedMember = mappedMember.Parent?.RawKind == syntaxKinds.GlobalStatement
                    ? mappedMember.Parent
                    : mappedMember;

                // it is possible in a script file case where there is no previous member. in that case, insert new text into top level script
                var destination = mappedMember.Parent ?? mappedMember;

                var info = codeGenerationService.GetInfo(
                    new CodeGenerationContext(
                        afterThisLocation: mappedMember.GetLocation(),
                        generateDefaultAccessibility: true,
                        generateMethodBodies: true),
                    Options,
                    documentWithUpdatedCallSite.Project.ParseOptions);

                var newContainer = codeGenerationService.AddMethod(destination, newMethodDefinition, info, cancellationToken);
                var rootWithMethod = documentWithUpdatedCallSite.Root.ReplaceNode(destination, newContainer);
                return rootWithMethod;
            }
        }

        /// <summary>
        /// Prefers the container computed from the outermost moved-in variable; falls back to the selection's
        /// outermost call-site container when the subclass returns <see langword="null"/>.
        /// </summary>
        private SyntaxNode GetOutermostCallSiteContainerToProcess(CancellationToken cancellationToken)
            => GetCallSiteContainerFromOutermostMoveInVariable(cancellationToken)
                ?? this.SelectionResult.GetOutermostCallSiteContainerToProcess(cancellationToken);

        protected virtual Task<GeneratedCode> CreateGeneratedCodeAsync(SemanticDocument newDocument, CancellationToken cancellationToken)
        {
            return Task.FromResult(new GeneratedCode(
                newDocument,
                MethodNameAnnotation,
                CallSiteAnnotation,
                MethodDefinitionAnnotation));
        }

        protected VariableInfo GetOutermostVariableToMoveIntoMethodDefinition(CancellationToken cancellationToken)
        {
            return this.AnalyzerResult.GetOutermostVariableToMoveIntoMethodDefinition(cancellationToken);
        }

        /// <summary>
        /// Appends a bare <c>return</c> when the end of the selection is unreachable, the containing scope has no
        /// (or a void) return type, and the subclass reports that no return is already in a returnable position.
        /// </summary>
        protected ImmutableArray<TStatementSyntax> AddReturnIfUnreachable(ImmutableArray<TStatementSyntax> statements)
        {
            if (AnalyzerResult.EndOfSelectionReachable)
            {
                return statements;
            }

            var type = SelectionResult.GetContainingScopeType();
            if (type != null && type.SpecialType != SpecialType.System_Void)
            {
                return statements;
            }

            // no return type + end of selection not reachable
            if (LastStatementOrHasReturnStatementInReturnableConstruct())
            {
                return statements;
            }

            return statements.Concat(CreateReturnStatement());
        }

        /// <summary>
        /// Appends the statement containing the invocation of the extracted method, unless a variable is used as
        /// the return value (in which case <see cref="AddAssignmentStatementToCallSite"/> produces the call).
        /// </summary>
        protected async Task<ImmutableArray<TStatementSyntax>> AddInvocationAtCallSiteAsync(
            ImmutableArray<TStatementSyntax> statements, CancellationToken cancellationToken)
        {
            if (AnalyzerResult.HasVariableToUseAsReturnValue)
            {
                return statements;
            }

            Contract.ThrowIfTrue(AnalyzerResult.GetVariablesToSplitOrMoveOutToCallSite(cancellationToken).Any(v => v.UseAsReturnValue));

            // add invocation expression
            return statements.Concat(
                (TStatementSyntax)(SyntaxNode)await GetStatementOrInitializerContainingInvocationToExtractedMethodAsync(cancellationToken).ConfigureAwait(false));
        }

        /// <summary>
        /// When a variable receives the extracted method's return value, appends either a declaration initialized
        /// with the call or an assignment of the call, annotated with <see cref="CallSiteAnnotation"/>.
        /// </summary>
        protected ImmutableArray<TStatementSyntax> AddAssignmentStatementToCallSite(
            ImmutableArray<TStatementSyntax> statements,
            CancellationToken cancellationToken)
        {
            if (!AnalyzerResult.HasVariableToUseAsReturnValue)
            {
                return statements;
            }

            var variable = AnalyzerResult.VariableToUseAsReturnValue;
            if (variable.ReturnBehavior == ReturnBehavior.Initialization)
            {
                // there must be exactly one decl behavior when there is "return value and initialize" variable
                Contract.ThrowIfFalse(AnalyzerResult.GetVariablesToSplitOrMoveOutToCallSite(cancellationToken).Count(v => v.ReturnBehavior == ReturnBehavior.Initialization) == 1);

                var declarationStatement = CreateDeclarationStatement(
                    variable, CreateCallSignature(), cancellationToken);
                declarationStatement = declarationStatement.WithAdditionalAnnotations(CallSiteAnnotation);

                return statements.Concat(declarationStatement);
            }

            Contract.ThrowIfFalse(variable.ReturnBehavior == ReturnBehavior.Assignment);
            return statements.Concat(
                CreateAssignmentExpressionStatement(CreateIdentifier(variable.Name), CreateCallSignature()).WithAdditionalAnnotations(CallSiteAnnotation));
        }

        /// <summary>
        /// Creates an uninitialized declaration statement for each of <paramref name="variables"/>.
        /// </summary>
        protected ImmutableArray<TStatementSyntax> CreateDeclarationStatements(
            ImmutableArray<VariableInfo> variables, CancellationToken cancellationToken)
        {
            return variables.SelectAsArray(v => CreateDeclarationStatement(v, initialValue: null, cancellationToken));
        }

        /// <summary>
        /// Creates uninitialized declarations at the call site for variables split or moved out of the selection,
        /// skipping the variable used as the return value (it is declared together with the call instead).
        /// </summary>
        protected ImmutableArray<TStatementSyntax> AddSplitOrMoveDeclarationOutStatementsToCallSite(
            CancellationToken cancellationToken)
        {
            using var _ = ArrayBuilder<TStatementSyntax>.GetInstance(out var list);

            foreach (var variable in AnalyzerResult.GetVariablesToSplitOrMoveOutToCallSite(cancellationToken))
            {
                if (variable.UseAsReturnValue)
                    continue;

                var declaration = CreateDeclarationStatement(
                    variable, initialValue: null, cancellationToken: cancellationToken);
                list.Add(declaration);
            }

            return list.ToImmutableAndClear();
        }

        /// <summary>
        /// Appends <c>return &lt;variable&gt;;</c> to the new method's body when a variable is used as its return value.
        /// </summary>
        protected ImmutableArray<TStatementSyntax> AppendReturnStatementIfNeeded(ImmutableArray<TStatementSyntax> statements)
        {
            if (!AnalyzerResult.HasVariableToUseAsReturnValue)
            {
                return statements;
            }

            var variableToUseAsReturnValue = AnalyzerResult.VariableToUseAsReturnValue;

            Contract.ThrowIfFalse(variableToUseAsReturnValue.ReturnBehavior is ReturnBehavior.Assignment or
                                  ReturnBehavior.Initialization);

            return statements.Concat(CreateReturnStatement(variableToUseAsReturnValue.Name));
        }

        /// <summary>
        /// Collects the annotations attached (via <c>AddIdentifierTokenAnnotationPair</c>) to the identifier tokens
        /// of <paramref name="variables"/>, whose declarations are being moved out, moved in, or deleted.
        /// </summary>
        protected static HashSet<SyntaxAnnotation> CreateVariableDeclarationToRemoveMap(
            IEnumerable<VariableInfo> variables, CancellationToken cancellationToken)
        {
            var annotations = new List<(SyntaxToken, SyntaxAnnotation)>();

            foreach (var variable in variables)
            {
                Contract.ThrowIfFalse(variable.GetDeclarationBehavior(cancellationToken) is DeclarationBehavior.MoveOut or
                                      DeclarationBehavior.MoveIn or
                                      DeclarationBehavior.Delete);

                variable.AddIdentifierTokenAnnotationPair(annotations, cancellationToken);
            }

            return new HashSet<SyntaxAnnotation>(annotations.Select(t => t.Item2));
        }

        /// <summary>
        /// Builds the type parameters for the new method. Type parameters that appear in
        /// <c>MethodTypeParametersInConstraintList</c> are reused as-is; every other one is re-created with the
        /// same flags but an empty list of constraint types.
        /// </summary>
        protected ImmutableArray<ITypeParameterSymbol> CreateMethodTypeParameters()
        {
            if (AnalyzerResult.MethodTypeParametersInDeclaration.Count == 0)
            {
                return [];
            }

            var set = new HashSet<ITypeParameterSymbol>(AnalyzerResult.MethodTypeParametersInConstraintList);

            // 'using' returns the pooled builder even if symbol creation throws.
            using var _ = ArrayBuilder<ITypeParameterSymbol>.GetInstance(out var typeParameters);
            foreach (var parameter in AnalyzerResult.MethodTypeParametersInDeclaration)
            {
                // A null entry would be dereferenced below; fail explicitly instead of with a NullReferenceException.
                Contract.ThrowIfNull(parameter);

                if (set.Contains(parameter))
                {
                    typeParameters.Add(parameter);
                    continue;
                }

                typeParameters.Add(CodeGenerationSymbolFactory.CreateTypeParameter(
                    parameter.GetAttributes(), parameter.Variance, parameter.Name, [], parameter.NullableAnnotation,
                    parameter.HasConstructorConstraint, parameter.HasReferenceTypeConstraint, parameter.HasUnmanagedTypeConstraint,
                    parameter.HasValueTypeConstraint, parameter.HasNotNullConstraint, parameter.AllowsRefLikeType, parameter.Ordinal));
            }

            return typeParameters.ToImmutableAndClear();
        }

        /// <summary>
        /// Builds the parameter list for the new method. When generating a local function that may capture, any
        /// parameter the analyzer marks as capturable is omitted (the local function reads it from the enclosing scope).
        /// </summary>
        protected ImmutableArray<IParameterSymbol> CreateMethodParameters()
        {
            // 'using' returns the pooled builder even if symbol creation throws.
            using var _ = ArrayBuilder<IParameterSymbol>.GetInstance(out var parameters);
            var canCaptureParameters = LocalFunction && ShouldLocalFunctionCaptureParameter(SemanticDocument.Root);
            foreach (var parameter in AnalyzerResult.MethodParameters)
            {
                if (canCaptureParameters && parameter.CanBeCapturedByLocalFunction)
                {
                    continue;
                }

                var refKind = GetRefKind(parameter.ParameterModifier);
                var type = parameter.GetVariableType();

                parameters.Add(
                    CodeGenerationSymbolFactory.CreateParameterSymbol(
                        attributes: [],
                        refKind: refKind,
                        isParams: false,
                        type: type,
                        name: parameter.Name));
            }

            return parameters.ToImmutableAndClear();
        }

        /// <summary>
        /// Maps the analyzer's parameter behavior to a <see cref="RefKind"/>; anything other than ref/out is passed by value.
        /// </summary>
        private static RefKind GetRefKind(ParameterBehavior parameterBehavior)
            => parameterBehavior switch
            {
                ParameterBehavior.Ref => RefKind.Ref,
                ParameterBehavior.Out => RefKind.Out,
                _ => RefKind.None,
            };
    }
}