|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
#nullable disable
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
namespace Microsoft.CodeAnalysis.ExtractMethod;
internal abstract partial class AbstractExtractMethodService<
    TStatementSyntax,
    TExecutableStatementSyntax,
    TExpressionSyntax>
{
    internal abstract partial class MethodExtractor
    {
        // NOTE(review): MethodNameAnnotation and MethodDefinitionAnnotation are not referenced in this part of the
        // class. Going by their names they tag the generated method's name and its definition — confirm at the
        // places that apply them.
        public static readonly SyntaxAnnotation MethodNameAnnotation = new();
        public static readonly SyntaxAnnotation MethodDefinitionAnnotation = new();

        /// <summary>
        /// Attached to the statement at the call site that receives the result of invoking the extracted method. In
        /// this class it is added to the declaration/assignment statements produced by
        /// <c>AddAssignmentStatementToCallSite</c>.
        /// </summary>
        public static readonly SyntaxAnnotation CallSiteAnnotation = new();

        /// <summary>
        /// Marks the node that anchors insertion of the new method or local function. Exactly one node in a document
        /// is expected to carry this annotation; <c>GetInsertionPoint</c> looks it up with <c>Single()</c>.
        /// </summary>
        public static readonly SyntaxAnnotation InsertionPointAnnotation = new();

        /// <summary>
        /// Marks nodes that cause control flow to leave the extracted selection. These are commonly constructs such as
        /// <see langword="return"/>, <see langword="break"/>, <see langword="continue"/> and the like. We mark these
        /// with annotations at the start of the extraction process so that we can find these nodes again later, after
        /// they have been extracted, and rewrite them as needed. Specifically, constructs like <see langword="break"/>
        /// and <see langword="continue"/> cannot cross a method boundary. As such, they must be translated to a <see
        /// langword="return"/> statement that returns a value indicating the flow control construct that should be
        /// executed at the callsite after the extracted method is called.
        /// </summary>
        public static readonly SyntaxAnnotation ExitPointAnnotation = new();

        /// <summary>
        /// Base type for the component that produces the extracted method and rewrites the original call site.
        /// </summary>
        protected abstract class CodeGenerator
        {
            /// <summary>
            /// Used to produce the set of statements that will go into the generated method.
            /// </summary>
            public abstract OperationStatus<ImmutableArray<SyntaxNode>> GetNewMethodStatements(
                SyntaxNode insertionPointNode, CancellationToken cancellationToken);

            /// <summary>
            /// Produces the final document, containing the generated method (or local function) and the rewritten
            /// call site.
            /// </summary>
            public abstract Task<SemanticDocument> GenerateAsync(CancellationToken cancellationToken);
        }

        /// <summary>
        /// Shared implementation of the code generator. Language-specific subclasses supply the syntax through the
        /// abstract members; this class orchestrates method generation, insertion and call-site rewriting.
        /// </summary>
        /// <typeparam name="TNodeUnderContainer">
        /// Node type returned for the statements/initializers selected at the call site.
        /// </typeparam>
        /// <typeparam name="TCodeGenerationOptions">
        /// Language-specific code generation options; the <c>CodeGenerationOptions</c> of the
        /// <c>ExtractMethodGenerationOptions</c> passed to the constructor must be of this type.
        /// </typeparam>
        protected abstract partial class CodeGenerator<TNodeUnderContainer, TCodeGenerationOptions> : CodeGenerator
            where TNodeUnderContainer : SyntaxNode
            where TCodeGenerationOptions : CodeGenerationOptions
        {
            /// <summary>Shared code generation context; it never adds imports.</summary>
            private static readonly CodeGenerationContext s_codeGenerationContext = new(addImports: false);

            // In this class these are the tuple element names used when the generated method must return both a
            // flow-control value and a data value (see GetFinalReturnType).
            // TODO: Check if these names are already in scope and if so, generate non-colliding ones.
            protected const string FlowControlName = "flowControl";
            protected const string ReturnValueName = "value";

            protected readonly SelectionResult SelectionResult;
            protected readonly AnalyzerResult AnalyzerResult;
            protected readonly ExtractMethodGenerationOptions ExtractMethodGenerationOptions;

            /// <summary>
            /// The language-specific code generation options taken from <c>ExtractMethodGenerationOptions</c>.
            /// </summary>
            protected readonly TCodeGenerationOptions Options;

            /// <summary>
            /// <see langword="true"/> to insert the new code as a local function at the insertion point;
            /// <see langword="false"/> to insert a normal method next to the insertion point's member.
            /// </summary>
            protected readonly bool LocalFunction;

            /// <summary>Lazily computed result of <c>GetFinalReturnType</c>.</summary>
            private ITypeSymbol _finalReturnType;

            /// <summary>
            /// Captures the selection and analysis results along with the generation options.
            /// </summary>
            /// <exception cref="InvalidCastException">
            /// The <c>CodeGenerationOptions</c> of <paramref name="options"/> is not a <c>TCodeGenerationOptions</c>.
            /// </exception>
            protected CodeGenerator(
                SelectionResult selectionResult,
                AnalyzerResult analyzerResult,
                ExtractMethodGenerationOptions options,
                bool localFunction)
            {
                SelectionResult = selectionResult;
                AnalyzerResult = analyzerResult;

                ExtractMethodGenerationOptions = options;
                Options = (TCodeGenerationOptions)options.CodeGenerationOptions;
                LocalFunction = localFunction;
            }

            /// <summary>The document the selection was made in.</summary>
            protected SemanticDocument SemanticDocument => SelectionResult.SemanticDocument;

            #region method to be implemented in sub classes

            /// <summary>
            /// Returns the call-site container to rewrite, or <see langword="null"/> to fall back to
            /// <c>SelectionResult.GetOutermostCallSiteContainerToProcess</c>.
            /// </summary>
            protected abstract SyntaxNode GetCallSiteContainerFromOutermostMoveInVariable();

            /// <summary>
            /// Produces the node that replaces <paramref name="outermostCallSiteContainer"/>, i.e. the container
            /// rewritten to call the new method.
            /// </summary>
            protected abstract Task<SyntaxNode> GenerateBodyForCallSiteContainerAsync(SyntaxNode insertionPointNode, SyntaxNode outermostCallSiteContainer, CancellationToken cancellationToken);

            /// <summary>Builds the symbol for the new method that is later inserted into the document.</summary>
            protected abstract IMethodSymbol GenerateMethodDefinition(SyntaxNode insertionPointNode, CancellationToken cancellationToken);

            /// <summary>
            /// When generating a local function, decides whether parameters that can be captured are captured (and so
            /// left out of the parameter list). Called with the document root.
            /// </summary>
            protected abstract bool ShouldLocalFunctionCaptureParameter(SyntaxNode node);

            protected abstract SyntaxToken CreateMethodName();

            /// <summary>
            /// Consulted by <c>AddReturnIfUnreachable</c>: when <see langword="true"/>, no trailing
            /// <see langword="return"/> is appended even though the end of the selection is unreachable.
            /// </summary>
            protected abstract bool LastStatementOrHasReturnStatementInReturnableConstruct();

            protected abstract TNodeUnderContainer GetFirstStatementOrInitializerSelectedAtCallSite();
            protected abstract TNodeUnderContainer GetLastStatementOrInitializerSelectedAtCallSite();

            /// <summary>
            /// Produces the statement (or initializer) that invokes the extracted method. When used by
            /// <c>AddInvocationAtCallSiteAsync</c> the result is cast to <c>TStatementSyntax</c>.
            /// </summary>
            protected abstract Task<TNodeUnderContainer> GetStatementOrInitializerContainingInvocationToExtractedMethodAsync(CancellationToken cancellationToken);

            /// <summary>
            /// Creates the expression that calls the extracted method; used as the initializer or right-hand side of
            /// the statement emitted at the call site.
            /// </summary>
            protected abstract TExpressionSyntax CreateCallSignature();

            /// <summary>
            /// Statement we create when we are assigning variables and at least one of the variables is in a new
            /// declaration that is being created. <paramref name="variables"/> can be empty. This can happen
            /// if we are creating a new declaration for a flow control variable.
            /// </summary>
            protected abstract TStatementSyntax CreateDeclarationStatement(
                ImmutableArray<VariableInfo> variables, TExpressionSyntax initialValue, ExtractMethodFlowControlInformation flowControlInformation, CancellationToken cancellationToken);

            /// <summary>
            /// Statement we create when we are assigning variables and all of the variables already exist and are just
            /// being assigned to. <paramref name="variables"/> must be non-empty.
            /// </summary>
            protected abstract TStatementSyntax CreateAssignmentExpressionStatement(
                ImmutableArray<VariableInfo> variables, TExpressionSyntax right);

            protected abstract TExecutableStatementSyntax CreateBreakStatement();
            protected abstract TExecutableStatementSyntax CreateContinueStatement();

            /// <summary>
            /// Creates the expression the extracted method returns to report <paramref name="flowValue"/> to its caller.
            /// </summary>
            protected abstract TExpressionSyntax CreateFlowControlReturnExpression(
                ExtractMethodFlowControlInformation flowControlInformation, object flowValue);

            protected abstract ImmutableArray<TStatementSyntax> GetInitialStatementsForMethodDefinitions();

            /// <summary>
            /// Post-processes the generated method once it has been inserted (e.g. nullable-annotation cleanup; see the
            /// comment in <c>GenerateAsync</c>).
            /// </summary>
            protected abstract Task<SemanticDocument> UpdateMethodAfterGenerationAsync(
                SemanticDocument originalDocument, IMethodSymbol methodSymbolResult, CancellationToken cancellationToken);

            /// <summary>Last generation step; its result is what <c>GenerateAsync</c> returns.</summary>
            protected abstract Task<SemanticDocument> PerformFinalTriviaFixupAsync(
                SemanticDocument newDocument, CancellationToken cancellationToken);

            #endregion

            /// <summary>
            /// Returns the single node in <paramref name="document"/> annotated with <c>InsertionPointAnnotation</c>.
            /// Throws (via <c>Single()</c>) if there is not exactly one such node.
            /// </summary>
            private static SyntaxNode GetInsertionPoint(SemanticDocument document)
                => document.Root.GetAnnotatedNodes(InsertionPointAnnotation).Single();

            /// <summary>
            /// Generates the new method definition, inserts it together with the rewritten call site, then runs the
            /// language-specific post-generation update and the final trivia fixup.
            /// </summary>
            public sealed override async Task<SemanticDocument> GenerateAsync(CancellationToken cancellationToken)
            {
                var semanticDocument = SelectionResult.SemanticDocument;
                var insertionPoint = GetInsertionPoint(semanticDocument);
                var newMethodDefinition = GenerateMethodDefinition(insertionPoint, cancellationToken);

                var callSiteDocument = await InsertMethodAndUpdateCallSiteAsync(semanticDocument, newMethodDefinition, cancellationToken).ConfigureAwait(false);

                // For nullable reference types, we can provide a better experience by reducing use of nullable
                // reference types after a method is done being generated. If we can determine that the method never
                // returns null, for example, then we can make the signature into a non-null reference type even though
                // the original type was nullable. This allows our code generation to follow our recommendation of only
                // using nullable when necessary. This is done after method generation instead of at analyzer time
                // because it's purely based on the resulting code, which the generator can modify as needed. If return
                // statements are added, the flow analysis could change to indicate something different. It's cleaner to
                // rely on flow analysis of the final resulting code than to try and predict from the analyzer what will
                // happen in the generator.
                var finalDocument = await UpdateMethodAfterGenerationAsync(callSiteDocument, newMethodDefinition, cancellationToken).ConfigureAwait(false);

                return await PerformFinalTriviaFixupAsync(finalDocument, cancellationToken).ConfigureAwait(false);
            }

            /// <summary>
            /// Rewrites the outermost call-site container so it calls the new method, then inserts
            /// <paramref name="newMethodDefinition"/> either as a local function (when <c>LocalFunction</c> is set) or
            /// as a normal method.
            /// </summary>
            private async Task<SemanticDocument> InsertMethodAndUpdateCallSiteAsync(
                SemanticDocument document, IMethodSymbol newMethodDefinition, CancellationToken cancellationToken)
            {
                var codeGenerationService = document.GetRequiredLanguageService<ICodeGenerationService>();

                // First, update the callsite with the call to the new method.
                var outermostCallSiteContainer = GetOutermostCallSiteContainerToProcess(cancellationToken);
                var rootWithUpdatedCallSite = this.SemanticDocument.Root.ReplaceNode(
                    outermostCallSiteContainer,
                    await GenerateBodyForCallSiteContainerAsync(
                        GetInsertionPoint(document), outermostCallSiteContainer, cancellationToken).ConfigureAwait(false));

                // Then insert the local-function/method into the updated document that contains the updated callsite.
                var documentWithUpdatedCallSite = await this.SemanticDocument.WithSyntaxRootAsync(rootWithUpdatedCallSite, cancellationToken).ConfigureAwait(false);
                var finalRoot = LocalFunction
                    ? InsertLocalFunction()
                    : InsertNormalMethod();

                return await documentWithUpdatedCallSite.WithSyntaxRootAsync(finalRoot, cancellationToken).ConfigureAwait(false);

                SyntaxNode InsertLocalFunction()
                {
                    // Local functions cannot carry accessibility modifiers, so don't generate a default one.
                    var info = codeGenerationService.GetInfo(
                        s_codeGenerationContext.With(generateDefaultAccessibility: false),
                        Options,
                        document.Project.ParseOptions);

                    var localMethod = codeGenerationService.CreateMethodDeclaration(newMethodDefinition, CodeGenerationDestination.Unspecified, info, cancellationToken);

                    // Find the destination for the local function after the callsite has been fixed up, and add the
                    // local function to it as a statement.
                    var destination = GetInsertionPoint(documentWithUpdatedCallSite);
                    var updatedDestination = codeGenerationService.AddStatements(destination, [localMethod], info, cancellationToken);
                    return documentWithUpdatedCallSite.Root.ReplaceNode(destination, updatedDestination);
                }

                SyntaxNode InsertNormalMethod()
                {
                    var syntaxKinds = document.GetRequiredLanguageService<ISyntaxKindsService>();

                    // Find the destination for the new method after the callsite has been fixed up.
                    var mappedMember = GetInsertionPoint(documentWithUpdatedCallSite);

                    // If the insertion point sits directly inside a global statement (top-level code), anchor on that
                    // global statement.
                    mappedMember = mappedMember.Parent?.RawKind == syntaxKinds.GlobalStatement
                        ? mappedMember.Parent
                        : mappedMember;

                    // If the insertion point is a primary-constructor base type, anchor on its parent instead.
                    mappedMember = mappedMember.RawKind == syntaxKinds.PrimaryConstructorBaseType
                        ? mappedMember.Parent
                        : mappedMember;

                    // It is possible, in a script file, that there is no previous member. In that case, insert the new
                    // text into the top-level script.
                    var destination = mappedMember.Parent ?? mappedMember;

                    var info = codeGenerationService.GetInfo(
                        s_codeGenerationContext.With(
                            afterThisLocation: mappedMember.GetLocation(),
                            generateDefaultAccessibility: true,
                            generateMethodBodies: true),
                        Options,
                        documentWithUpdatedCallSite.Project.ParseOptions);

                    var newContainer = codeGenerationService.AddMethod(destination, newMethodDefinition, info, cancellationToken);
                    return documentWithUpdatedCallSite.Root.ReplaceNode(destination, newContainer);
                }
            }

            /// <summary>
            /// The container to rewrite at the call site: the subclass-provided one when available, otherwise the one
            /// computed from the selection.
            /// </summary>
            private SyntaxNode GetOutermostCallSiteContainerToProcess(CancellationToken cancellationToken)
            {
                var callSiteContainer = GetCallSiteContainerFromOutermostMoveInVariable();
                return callSiteContainer ?? this.SelectionResult.GetOutermostCallSiteContainerToProcess(cancellationToken);
            }

            /// <summary>Forwards to the analyzer result.</summary>
            protected VariableInfo GetOutermostVariableToMoveIntoMethodDefinition()
                => this.AnalyzerResult.GetOutermostVariableToMoveIntoMethodDefinition();

            /// <summary>
            /// When the end of the selection is unreachable, appends the statement needed to preserve the original
            /// control flow after the call:
            /// a <see langword="break"/>/<see langword="continue"/> when all flow control is uniformly of that kind,
            /// otherwise a bare <see langword="return"/> unless the selection has a non-void return type or
            /// <c>LastStatementOrHasReturnStatementInReturnableConstruct</c> is <see langword="true"/>.
            /// </summary>
            protected ImmutableArray<TStatementSyntax> AddReturnIfUnreachable(
                ImmutableArray<TStatementSyntax> statements, CancellationToken cancellationToken)
            {
                if (AnalyzerResult.FlowControlInformation.EndPointIsReachable)
                    return statements;

                // All the flow control in the analyzed block is the same (for example, all breaks/continues/returns).
                // In this case add a specific instance of that same flow control construct after the call to the new
                // method to ensure we preserve original control flow.
                if (AnalyzerResult.FlowControlInformation.HasUniformControlFlow())
                {
                    if (AnalyzerResult.FlowControlInformation.BreakStatementCount > 0)
                        return statements.Concat(this.CreateBreakStatement());
                    else if (AnalyzerResult.FlowControlInformation.ContinueStatementCount > 0)
                        return statements.Concat(this.CreateContinueStatement());
                }

                var returnType = SelectionResult.GetReturnType(cancellationToken);
                if (returnType != null && returnType.SpecialType != SpecialType.System_Void)
                    return statements;

                // no return type + end of selection not reachable
                if (LastStatementOrHasReturnStatementInReturnableConstruct())
                    return statements;

                return statements.Concat(CreateReturnStatement([]));
            }

            /// <summary>
            /// Creates <c>return;</c> for no expressions, <c>return e;</c> for one, and <c>return (e1, e2, ...);</c> for
            /// several.
            /// </summary>
            private TExecutableStatementSyntax CreateReturnStatement(
                ImmutableArray<TExpressionSyntax> expressions)
            {
                var generator = this.SemanticDocument.GetRequiredLanguageService<SyntaxGenerator>();
                return (TExecutableStatementSyntax)generator.ReturnStatement(CreateReturnExpression(expressions));
            }

            /// <summary>
            /// Combines <paramref name="expressions"/> into a single returnable expression: <see langword="null"/> when
            /// empty, the expression itself when there is one, otherwise a tuple of all of them.
            /// </summary>
            private TExpressionSyntax CreateReturnExpression(ImmutableArray<TExpressionSyntax> expressions)
            {
                var generator = this.SemanticDocument.GetRequiredLanguageService<SyntaxGenerator>();
                return
                    expressions.Length == 0 ? null :
                    expressions.Length == 1 ? expressions[0] :
                    (TExpressionSyntax)generator.TupleExpression(expressions.Select(generator.Argument));
            }

            /// <summary>
            /// Appends the plain invocation of the extracted method when it neither returns variables nor needs a
            /// flow-control value; otherwise returns <paramref name="statements"/> unchanged.
            /// </summary>
            protected async Task<ImmutableArray<TStatementSyntax>> AddInvocationAtCallSiteAsync(
                ImmutableArray<TStatementSyntax> statements, CancellationToken cancellationToken)
            {
                // If the newly extracted method isn't returning any data, and doesn't have complex flow control, then
                // we want to handle that here. The case where we do need to pass data out is in AddAssignmentStatementToCallSite.
                if (AnalyzerResult.VariablesToUseAsReturnValue.IsEmpty &&
                    !AnalyzerResult.FlowControlInformation.NeedsControlFlowValue())
                {
                    Contract.ThrowIfTrue(AnalyzerResult.GetVariablesToSplitOrMoveOutToCallSite().Any(v => v.UseAsReturnValue));

                    // add invocation expression
                    return statements.Concat(
                        (TStatementSyntax)(SyntaxNode)await GetStatementOrInitializerContainingInvocationToExtractedMethodAsync(cancellationToken).ConfigureAwait(false));
                }

                return statements;
            }

            /// <summary>
            /// Appends the statement that stores the extracted method's result at the call site. A declaration is used
            /// when any returned variable is initialized by the call or a flow-control value must be captured;
            /// otherwise an assignment to existing variables. The new statement carries <c>CallSiteAnnotation</c>.
            /// Returns <paramref name="statements"/> unchanged when there is nothing to receive.
            /// </summary>
            protected ImmutableArray<TStatementSyntax> AddAssignmentStatementToCallSite(
                ImmutableArray<TStatementSyntax> statements,
                CancellationToken cancellationToken)
            {
                if (AnalyzerResult.VariablesToUseAsReturnValue.IsEmpty &&
                    !AnalyzerResult.FlowControlInformation.NeedsControlFlowValue())
                {
                    return statements;
                }

                var flowControlInformation = AnalyzerResult.FlowControlInformation;
                var variables = AnalyzerResult.VariablesToUseAsReturnValue;
                if (variables.Any(v => v.ReturnBehavior == ReturnBehavior.Initialization) ||
                    flowControlInformation.NeedsControlFlowValue())
                {
                    var declarationStatement = CreateDeclarationStatement(
                        variables, CreateCallSignature(), flowControlInformation, cancellationToken);
                    return statements.Concat(declarationStatement.WithAdditionalAnnotations(CallSiteAnnotation));
                }

                return statements.Concat(
                    CreateAssignmentExpressionStatement(variables, CreateCallSignature()).WithAdditionalAnnotations(CallSiteAnnotation));
            }

            /// <summary>Creates one declaration (without initializer) per variable.</summary>
            protected ImmutableArray<TStatementSyntax> CreateDeclarationStatements(
                ImmutableArray<VariableInfo> variables, CancellationToken cancellationToken)
            {
                return variables.SelectAsArray(
                    v => CreateDeclarationStatement([v], initialValue: null, flowControlInformation: null, cancellationToken));
            }

            /// <summary>
            /// Creates declarations (without initializers) at the call site for the variables that are split or moved
            /// out, skipping variables flagged <c>UseAsReturnValue</c>.
            /// </summary>
            protected ImmutableArray<TStatementSyntax> AddSplitOrMoveDeclarationOutStatementsToCallSite(
                CancellationToken cancellationToken)
            {
                using var _ = ArrayBuilder<TStatementSyntax>.GetInstance(out var list);

                foreach (var variable in AnalyzerResult.GetVariablesToSplitOrMoveOutToCallSite())
                {
                    if (variable.UseAsReturnValue)
                        continue;

                    list.Add(CreateDeclarationStatement(
                        [variable], initialValue: null, flowControlInformation: null, cancellationToken));
                }

                return list.ToImmutableAndClear();
            }

            /// <summary>
            /// Ensures the generated method body ends with a <see langword="return"/> when one is needed: it returns the
            /// fall-through flow-control value if there is one, otherwise the returned variables (a tuple when there
            /// are several). Nothing is appended if the last statement is already a return or there is nothing to
            /// return.
            /// </summary>
            protected ImmutableArray<TStatementSyntax> AppendReturnStatementIfNeeded(ImmutableArray<TStatementSyntax> statements)
            {
                // No need to add a return statement if we already have one.
                var syntaxFacts = this.SemanticDocument.GetRequiredLanguageService<ISyntaxFactsService>();
                if (statements is [.., var lastStatement] &&
                    syntaxFacts.IsReturnStatement(lastStatement))
                {
                    return statements;
                }

                var flowControlInformation = this.AnalyzerResult.FlowControlInformation;
                if (flowControlInformation.TryGetFallThroughFlowValue(out var fallthroughValue))
                {
                    // Falling off the end must still tell the caller which flow-control path to take.
                    return statements.Concat(CreateReturnStatement([CreateFlowControlReturnExpression(flowControlInformation, fallthroughValue)]));
                }

                if (!this.AnalyzerResult.VariablesToUseAsReturnValue.IsEmpty)
                {
                    var generator = this.SemanticDocument.GetRequiredLanguageService<SyntaxGenerator>();
                    return statements.Concat(CreateReturnStatement([
                        CreateReturnExpression(AnalyzerResult.VariablesToUseAsReturnValue.SelectAsArray(
                            static (v, generator) => (TExpressionSyntax)generator.IdentifierName(v.Name),
                            generator))]));
                }

                return statements;
            }

            /// <summary>
            /// Collects the annotations attached to the identifier tokens of <paramref name="variables"/>. Every
            /// variable must have a <c>MoveOut</c> or <c>MoveIn</c> declaration behavior; otherwise a contract check
            /// fails.
            /// </summary>
            protected static HashSet<SyntaxAnnotation> CreateVariableDeclarationToRemoveMap(
                IEnumerable<VariableInfo> variables, CancellationToken cancellationToken)
            {
                var annotations = new MultiDictionary<SyntaxToken, SyntaxAnnotation>();

                foreach (var variable in variables)
                {
                    Contract.ThrowIfFalse(variable.GetDeclarationBehavior() is
                        DeclarationBehavior.MoveOut or
                        DeclarationBehavior.MoveIn);

                    variable.AddIdentifierTokenAnnotationPair(annotations, cancellationToken);
                }

                return [.. annotations.Values.SelectMany(v => v)];
            }

            /// <summary>
            /// Builds the type parameters of the new method. Type parameters that also appear in the constraint list
            /// are reused as-is; the others are recreated with the same attributes, variance, name, nullability,
            /// special constraints and ordinal, but with <c>[]</c> passed for their constraint types.
            /// </summary>
            protected ImmutableArray<ITypeParameterSymbol> CreateMethodTypeParameters()
            {
                if (AnalyzerResult.MethodTypeParametersInDeclaration.IsEmpty)
                    return [];

                var set = new HashSet<ITypeParameterSymbol>(AnalyzerResult.MethodTypeParametersInConstraintList);

                // Use the pooled-builder disposal pattern (as elsewhere in this class) so the builder is always
                // returned to the pool, even if an exception escapes the loop.
                using var _ = ArrayBuilder<ITypeParameterSymbol>.GetInstance(out var typeParameters);
                foreach (var parameter in AnalyzerResult.MethodTypeParametersInDeclaration)
                {
                    if (parameter != null && set.Contains(parameter))
                    {
                        typeParameters.Add(parameter);
                        continue;
                    }

                    // NOTE(review): a null parameter would still fail on the dereferences below.
                    typeParameters.Add(CodeGenerationSymbolFactory.CreateTypeParameter(
                        parameter.GetAttributes(), parameter.Variance, parameter.Name, [], parameter.NullableAnnotation,
                        parameter.HasConstructorConstraint, parameter.HasReferenceTypeConstraint, parameter.HasUnmanagedTypeConstraint,
                        parameter.HasValueTypeConstraint, parameter.HasNotNullConstraint, parameter.AllowsRefLikeType, parameter.Ordinal));
                }

                return typeParameters.ToImmutableAndClear();
            }

            /// <summary>
            /// Builds the parameters of the new method from the analyzer's method parameters. When generating a local
            /// function that captures parameters, parameters that can be captured are omitted.
            /// </summary>
            protected ImmutableArray<IParameterSymbol> CreateMethodParameters()
            {
                using var _ = ArrayBuilder<IParameterSymbol>.GetInstance(out var parameters);

                var captureParametersInLocalFunction = LocalFunction && ShouldLocalFunctionCaptureParameter(SemanticDocument.Root);
                foreach (var parameter in AnalyzerResult.MethodParameters)
                {
                    if (!captureParametersInLocalFunction || !parameter.CanBeCapturedByLocalFunction)
                    {
                        var refKind = GetRefKind(parameter.ParameterModifier);

                        parameters.Add(CodeGenerationSymbolFactory.CreateParameterSymbol(
                            attributes: [],
                            refKind: refKind,
                            isParams: false,
                            type: parameter.SymbolType,
                            name: parameter.Name));
                    }
                }

                return parameters.ToImmutableAndClear();
            }

            /// <summary>Maps a parameter behavior to its ref kind (<c>Ref</c>, <c>Out</c>, otherwise <c>None</c>).</summary>
            private static RefKind GetRefKind(ParameterBehavior parameterBehavior)
                => parameterBehavior switch
                {
                    ParameterBehavior.Ref => RefKind.Ref,
                    ParameterBehavior.Out => RefKind.Out,
                    _ => RefKind.None
                };

            /// <summary>
            /// Wraps the call to the extracted method in a statement: <c>return call;</c> when the core return type is
            /// not void, otherwise an expression statement.
            /// </summary>
            protected TExecutableStatementSyntax GetStatementContainingInvocationToExtractedMethodWorker()
            {
                var callSignature = CreateCallSignature();

                var generator = this.SemanticDocument.GetRequiredLanguageService<SyntaxGenerator>();
                return AnalyzerResult.CoreReturnType.SpecialType != SpecialType.System_Void
                    ? (TExecutableStatementSyntax)generator.ReturnStatement(callSignature)
                    : (TExecutableStatementSyntax)generator.ExpressionStatement(callSignature);
            }

            /// <summary>
            /// Computes (once, then caches) the return type of the generated method: the core return type, combined
            /// with the flow-control value type when one is needed (as a <c>(flowControl, value)</c> tuple when both are
            /// present), then wrapped in <c>Task</c>/<c>Task&lt;T&gt;</c> when the selection contains an
            /// <see langword="await"/>.
            /// </summary>
            public ITypeSymbol GetFinalReturnType()
            {
                return _finalReturnType ??= WrapWithTaskIfNecessary(AddFlowControlTypeIfNecessary(this.AnalyzerResult.CoreReturnType));

                ITypeSymbol AddFlowControlTypeIfNecessary(ITypeSymbol coreReturnType)
                {
                    var controlFlowValueType = this.AnalyzerResult.FlowControlInformation.ControlFlowValueType;

                    // If we don't need to report complex flow control to the caller, just return whatever the inner
                    // method originally wanted to return.
                    if (controlFlowValueType.SpecialType == SpecialType.System_Void)
                        return coreReturnType;

                    // We need to report complex flow control to the caller.

                    // If the method wasn't going to return any values to begin with, then all we have to do is
                    // return the control flow value to the caller to indicate what flow control path to take.
                    if (coreReturnType.SpecialType == SpecialType.System_Void)
                        return controlFlowValueType;

                    // We need to report both the control flow data and the original data.
                    var compilation = this.SemanticDocument.SemanticModel.Compilation;
                    return compilation.CreateTupleTypeSymbol(
                        [controlFlowValueType, coreReturnType],
                        [FlowControlName, ReturnValueName]);
                }

                ITypeSymbol WrapWithTaskIfNecessary(ITypeSymbol type)
                {
                    if (!this.SelectionResult.ContainsAwaitExpression())
                        return type;

                    // If we're awaiting, then we're going to be returning a task of some sort. Convert `void` to
                    // `Task` and any other T to `Task<T>`.
                    var compilation = this.SemanticDocument.SemanticModel.Compilation;
                    return type.SpecialType == SpecialType.System_Void
                        ? compilation.TaskType()
                        : compilation.TaskOfTType().Construct(type);
                }
            }
        }
    }
}
|