|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
#nullable disable
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.CSharp.CodeGeneration;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.LanguageService;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics.Analyzers.NamingStyles;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.ExtractMethod;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Simplification;
using Roslyn.Utilities;
namespace Microsoft.CodeAnalysis.CSharp.ExtractMethod;
using static CSharpSyntaxTokens;
using static SyntaxFactory;
internal sealed partial class CSharpMethodExtractor
{
private abstract partial class CSharpCodeGenerator : CodeGenerator<StatementSyntax, SyntaxNode, CSharpCodeGenerationOptions>
{
// Identifier token of the extracted method; the constructor tags it with MethodNameAnnotation.
private readonly SyntaxToken _methodName;
// Default base names for the extracted method, in PascalCase and camelCase respectively.
private const string NewMethodPascalCaseStr = "NewMethod";
private const string NewMethodCamelCaseStr = "newMethod";
/// <summary>
/// Builds the code generator that matches the selection (see <c>Create</c>) and runs it
/// against <paramref name="insertionPoint"/>.
/// </summary>
public static Task<GeneratedCode> GenerateAsync(
    InsertionPoint insertionPoint,
    CSharpSelectionResult selectionResult,
    AnalyzerResult analyzerResult,
    CSharpCodeGenerationOptions options,
    bool localFunction,
    CancellationToken cancellationToken)
    => Create(selectionResult, analyzerResult, options, localFunction)
        .GenerateAsync(insertionPoint, cancellationToken);
/// <summary>
/// Selects the concrete generator for the selection: expression, single statement or
/// multiple statements, checked in that order. Throws an "unexpected value" exception
/// when the selection matches none of them.
/// </summary>
public static CSharpCodeGenerator Create(
    CSharpSelectionResult selectionResult,
    AnalyzerResult analyzerResult,
    CSharpCodeGenerationOptions options,
    bool localFunction)
{
    if (selectionResult.SelectionInExpression)
    {
        return new ExpressionCodeGenerator(selectionResult, analyzerResult, options, localFunction);
    }
    else if (selectionResult.IsExtractMethodOnSingleStatement())
    {
        return new SingleStatementCodeGenerator(selectionResult, analyzerResult, options, localFunction);
    }
    else if (selectionResult.IsExtractMethodOnMultipleStatements())
    {
        return new MultipleStatementsCodeGenerator(selectionResult, analyzerResult, options, localFunction);
    }
    else
    {
        throw ExceptionUtilities.UnexpectedValue(selectionResult);
    }
}
/// <summary>
/// Initializes the generator and computes the annotated name token of the method to be extracted.
/// </summary>
protected CSharpCodeGenerator(
    CSharpSelectionResult selectionResult,
    AnalyzerResult analyzerResult,
    CSharpCodeGenerationOptions options,
    bool localFunction)
    : base(selectionResult, analyzerResult, options, localFunction)
{
    // The base class must be working on the very same document as the selection.
    Contract.ThrowIfFalse(SemanticDocument == selectionResult.SemanticDocument);

    _methodName = CreateMethodName().WithAdditionalAnnotations(MethodNameAnnotation);
}
/// <summary>
/// Produces the statements of the new method body together with a status telling whether
/// they contain any active statement.
/// </summary>
public override OperationStatus<ImmutableArray<SyntaxNode>> GetNewMethodStatements(SyntaxNode insertionPointNode, CancellationToken cancellationToken)
{
    var bodyStatements = CreateMethodBody(insertionPointNode, cancellationToken);
    return CheckActiveStatements(bodyStatements).With(bodyStatements.CastArray<SyntaxNode>());
}
/// <summary>
/// Creates the symbol describing the extracted method (private, with the computed modifiers,
/// return type, type parameters, parameters and body) and tags it with the formatter and
/// method-definition annotations.
/// </summary>
protected override IMethodSymbol GenerateMethodDefinition(
    SyntaxNode insertionPointNode, CancellationToken cancellationToken)
{
    var body = WrapInCheckStatementIfNeeded(CreateMethodBody(insertionPointNode, cancellationToken));

    var extractedMethod = CodeGenerationSymbolFactory.CreateMethodSymbol(
        attributes: [],
        accessibility: Accessibility.Private,
        modifiers: CreateMethodModifiers(),
        returnType: AnalyzerResult.ReturnType,
        refKind: AnalyzerResult.ReturnsByRef ? RefKind.Ref : RefKind.None,
        explicitInterfaceImplementations: default,
        name: _methodName.ToString(),
        typeParameters: CreateMethodTypeParameters(),
        parameters: CreateMethodParameters(),
        statements: body.CastArray<SyntaxNode>(),
        methodKind: this.LocalFunction ? MethodKind.LocalFunction : MethodKind.Ordinary);

    var withFormatting = Formatter.Annotation.AddAnnotationToSymbol(extractedMethod);
    return MethodDefinitionAnnotation.AddAnnotationToSymbol(withFormatting);
}
/// <summary>
/// Rewrites <paramref name="container"/> so that the selected statements (and the declarations
/// of variables moved into the new method) are replaced by the call-site statements.
/// </summary>
protected override async Task<SyntaxNode> GenerateBodyForCallSiteContainerAsync(
    SyntaxNode insertionPointNode,
    SyntaxNode container,
    CancellationToken cancellationToken)
{
    var movedVariables = AnalyzerResult.GetVariablesToMoveIntoMethodDefinition(cancellationToken);
    var declarationsToRemove = CreateVariableDeclarationToRemoveMap(movedVariables, cancellationToken);

    var firstSelected = GetFirstStatementOrInitializerSelectedAtCallSite();
    var lastSelected = GetLastStatementOrInitializerSelectedAtCallSite();

    // Both ends of the selection have to live in the same statement container.
    var inSameContainer = firstSelected.Parent == lastSelected.Parent
        || CSharpSyntaxFacts.Instance.AreStatementsInSameContainer(firstSelected, lastSelected);
    Contract.ThrowIfFalse(inSameContainer);

    var replacementStatements = await CreateStatementsOrInitializerToInsertAtCallSiteAsync(
        insertionPointNode, cancellationToken).ConfigureAwait(false);

    var rewriter = new CallSiteContainerRewriter(
        container,
        declarationsToRemove,
        firstSelected,
        lastSelected,
        replacementStatements);

    var rewrittenContainer = rewriter.Generate();
    return container.CopyAnnotationsTo(rewrittenContainer).WithAdditionalAnnotations(Formatter.Annotation);
}
/// <summary>
/// Computes the nodes that replace the selection at the call site.
/// </summary>
private async Task<ImmutableArray<SyntaxNode>> CreateStatementsOrInitializerToInsertAtCallSiteAsync(
    SyntaxNode insertionPointNode, CancellationToken cancellationToken)
{
    var selected = GetFirstStatementOrInitializerSelectedAtCallSite();

    // Field initializers, constructor initializers and expression-bodied members/accessors
    // are replaced by a single node containing the invocation.
    var isInitializerOrExpressionBody =
        selected is ConstructorInitializerSyntax or FieldDeclarationSyntax
        || IsExpressionBodiedMember(selected)
        || IsExpressionBodiedAccessor(selected);

    if (isInitializerOrExpressionBody)
    {
        var invocationContainer = await GetStatementOrInitializerContainingInvocationToExtractedMethodAsync(cancellationToken).ConfigureAwait(false);
        return [invocationContainer];
    }

    // Regular case: build the call-site statements step by step.
    var postProcessor = new PostProcessor(SemanticDocument.SemanticModel, insertionPointNode.SpanStart);

    var callSiteStatements = AddSplitOrMoveDeclarationOutStatementsToCallSite(cancellationToken);
    callSiteStatements = postProcessor.MergeDeclarationStatements(callSiteStatements);
    callSiteStatements = AddAssignmentStatementToCallSite(callSiteStatements, cancellationToken);
    callSiteStatements = await AddInvocationAtCallSiteAsync(callSiteStatements, cancellationToken).ConfigureAwait(false);
    callSiteStatements = AddReturnIfUnreachable(callSiteStatements);

    return callSiteStatements.CastArray<SyntaxNode>();
}
/// <summary>
/// Returns true when the language version of <paramref name="node"/>'s tree is older than C# 8.
/// </summary>
protected override bool ShouldLocalFunctionCaptureParameter(SyntaxNode node)
{
    var languageVersion = node.SyntaxTree.Options.LanguageVersion();
    return languageVersion < LanguageVersion.CSharp8;
}
/// <summary>
/// True when <paramref name="node"/> is a member declaration that has an expression body.
/// </summary>
private static bool IsExpressionBodiedMember(SyntaxNode node)
{
    if (node is not MemberDeclarationSyntax member)
        return false;

    return member.GetExpressionBody() != null;
}
/// <summary>
/// True when <paramref name="node"/> is an accessor declaration that has an expression body.
/// </summary>
private static bool IsExpressionBodiedAccessor(SyntaxNode node)
{
    if (node is not AccessorDeclarationSyntax accessor)
        return false;

    return accessor.ExpressionBody != null;
}
/// <summary>
/// Builds the name used to invoke the extracted method: a plain identifier, or a generic name
/// carrying type arguments when the method declares type parameters.
/// </summary>
private SimpleNameSyntax CreateMethodNameForInvocation()
{
    if (AnalyzerResult.MethodTypeParametersInDeclaration.Count == 0)
    {
        return IdentifierName(_methodName);
    }

    return GenericName(_methodName, TypeArgumentList(CreateMethodCallTypeVariables()));
}
/// <summary>
/// Creates the type argument list for the invocation from the names of the method's declared
/// type parameters. Must only be called when there is at least one type parameter.
/// </summary>
private SeparatedSyntaxList<TypeSyntax> CreateMethodCallTypeVariables()
{
    var typeParameters = AnalyzerResult.MethodTypeParametersInDeclaration;
    Contract.ThrowIfTrue(typeParameters.Count == 0);

    // Every type variable used in the extracted code is forwarded by name.
    var typeArguments = typeParameters.Select(typeParameter => SyntaxFactory.ParseTypeName(typeParameter.Name));
    return [.. typeArguments];
}
/// <summary>
/// Returns the statement container that holds the local declaration of the outermost variable
/// being moved into the new method, or null when there is no such variable.
/// </summary>
protected override SyntaxNode GetCallSiteContainerFromOutermostMoveInVariable(CancellationToken cancellationToken)
{
    var outermost = GetOutermostVariableToMoveIntoMethodDefinition(cancellationToken);
    if (outermost == null)
    {
        return null;
    }

    var declaration = outermost
        .GetIdentifierTokenAtDeclaration(SemanticDocument)
        .GetAncestor<LocalDeclarationStatementSyntax>();
    Contract.ThrowIfNull(declaration);

    var declarationContainer = declaration.Parent;
    Contract.ThrowIfFalse(declarationContainer.IsStatementContainerNode());
    return declarationContainer;
}
/// <summary>
/// Computes the unsafe/async/static/readonly modifiers of the extracted method.
/// </summary>
private DeclarationModifiers CreateMethodModifiers()
{
    var needsUnsafe = this.SelectionResult.ShouldPutUnsafeModifier();
    var needsAsync = this.SelectionResult.ShouldPutAsyncModifier();
    var makeStatic = !AnalyzerResult.UseInstanceMember;
    var makeReadOnly = AnalyzerResult.ShouldBeReadOnly;

    var languageVersion = SemanticDocument.SyntaxTree.Options.LanguageVersion();

    if (LocalFunction)
    {
        // Static local functions require C# 8.0 or later and are only emitted when the
        // user prefers them.
        var staticLocalFunctionAllowed =
            Options.PreferStaticLocalFunction.Value && languageVersion >= LanguageVersion.CSharp8;
        makeStatic = makeStatic && staticLocalFunctionAllowed;
    }
    else if (IsNonStaticInterfaceMember())
    {
        // UseInstanceMember is false for interface members, yet turning a non-static
        // interface member into a static one changes its meaning, so never do that.
        makeStatic = false;
    }

    return new DeclarationModifiers(
        isUnsafe: needsUnsafe,
        isAsync: needsAsync,
        isStatic: makeStatic,
        isReadOnly: makeReadOnly);
}
/// <summary>
/// True when the selection sits inside an interface declaration, in a member that does not
/// carry the <c>static</c> modifier.
/// </summary>
private bool IsNonStaticInterfaceMember()
{
    var containingType = SelectionResult.GetContainingScopeOf<BaseTypeDeclarationSyntax>();
    if (containingType is null || !containingType.IsKind(SyntaxKind.InterfaceDeclaration))
    {
        return false;
    }

    var containingMember = SelectionResult.GetContainingScopeOf<MemberDeclarationSyntax>();
    return containingMember is not null
        && !containingMember.Modifiers.Any(SyntaxKind.StaticKeyword);
}
/// <summary>
/// Maps a parameter behavior to the keyword kind used at the call site:
/// <c>ref</c>, <c>out</c>, or <see cref="SyntaxKind.None"/> for anything else.
/// </summary>
private static SyntaxKind GetParameterRefSyntaxKind(ParameterBehavior parameterBehavior)
{
    if (parameterBehavior == ParameterBehavior.Ref)
    {
        return SyntaxKind.RefKeyword;
    }

    if (parameterBehavior == ParameterBehavior.Out)
    {
        return SyntaxKind.OutKeyword;
    }

    return SyntaxKind.None;
}
/// <summary>
/// Builds the statements of the extracted method: starts from the initial statements, moves
/// declarations in and out, appends a return statement when needed, then cleans up.
/// </summary>
private ImmutableArray<StatementSyntax> CreateMethodBody(
    SyntaxNode insertionPoint, CancellationToken cancellationToken)
{
    var initial = GetInitialStatementsForMethodDefinitions();
    var withMovedInDeclarations = SplitOrMoveDeclarationIntoMethodDefinition(insertionPoint, initial, cancellationToken);
    var withMovedOutDeclarations = MoveDeclarationOutFromMethodDefinition(withMovedInDeclarations, cancellationToken);
    var withReturn = AppendReturnStatementIfNeeded(withMovedOutDeclarations);
    return CleanupCode(withReturn);
}
/// <summary>
/// When the selection is inside a checked/unchecked statement context, wraps the statements
/// in a matching checked statement; otherwise returns them unchanged.
/// </summary>
private ImmutableArray<StatementSyntax> WrapInCheckStatementIfNeeded(ImmutableArray<StatementSyntax> statements)
{
    var checkedKind = this.SelectionResult.UnderCheckedStatementContext();
    if (checkedKind == SyntaxKind.None)
    {
        return statements;
    }

    // Reuse a lone block as-is instead of nesting it inside another block.
    var wrappedBlock = statements is [BlockSyntax onlyBlock] ? onlyBlock : Block(statements);
    return [CheckedStatement(checkedKind, wrappedBlock)];
}
/// <summary>
/// Runs the post-processing clean-ups in order: redundant block removal, declaration/assignment
/// pattern removal, then initialized-declaration-and-return pattern removal.
/// </summary>
private static ImmutableArray<StatementSyntax> CleanupCode(ImmutableArray<StatementSyntax> statements)
    => PostProcessor.RemoveInitializedDeclarationAndReturnPattern(
        PostProcessor.RemoveDeclarationAssignmentPattern(
            PostProcessor.RemoveRedundantBlock(statements)));
/// <summary>
/// Reports success when the statements contain something active: any statement that is not a
/// local declaration, or a local declaration with at least one initialized variable.
/// An empty list, or a single bare <c>return;</c>, has no active statement.
/// </summary>
private static OperationStatus CheckActiveStatements(ImmutableArray<StatementSyntax> statements)
{
    if (statements.IsEmpty || statements is [ReturnStatementSyntax { Expression: null }])
    {
        return OperationStatus.NoActiveStatement;
    }

    foreach (var candidate in statements)
    {
        var isActive = candidate is not LocalDeclarationStatementSyntax localDeclaration
            || localDeclaration.Declaration.Variables.Any(declarator => declarator.Initializer != null);

        if (isActive)
        {
            return OperationStatus.SucceededStatus;
        }
    }

    return OperationStatus.NoActiveStatement;
}
/// <summary>
/// Removes from the method body the local declarations of variables that are moved out to the
/// call site or deleted. A removed variable that has an initializer is turned into an
/// assignment statement; the trivia of removed variables without initializer is carried over
/// to a surviving declarator, to the declaration's semicolon, or to an empty statement.
/// </summary>
private ImmutableArray<StatementSyntax> MoveDeclarationOutFromMethodDefinition(
ImmutableArray<StatementSyntax> statements, CancellationToken cancellationToken)
{
using var _ = ArrayBuilder<StatementSyntax>.GetInstance(out var result);
var variableToRemoveMap = CreateVariableDeclarationToRemoveMap(
AnalyzerResult.GetVariablesToMoveOutToCallSiteOrDelete(cancellationToken), cancellationToken);
// Rewrite declaration expressions / flag declaration patterns that declare a removed variable.
statements = statements.SelectAsArray(s => FixDeclarationExpressionsAndDeclarationPatterns(s, variableToRemoveMap));
foreach (var statement in statements)
{
if (statement is not LocalDeclarationStatementSyntax declarationStatement || declarationStatement.Declaration.Variables.FullSpan.IsEmpty)
{
// Not a local declaration statement (or one whose variable list is empty): keep it untouched.
result.Add(statement);
continue;
}
// Assignments that replace removed-but-initialized variables.
var expressionStatements = new List<StatementSyntax>();
// Declarators that stay in the declaration.
var list = new List<VariableDeclaratorSyntax>();
// Trivia of removed, uninitialized declarators that still has to be attached somewhere.
var triviaList = new List<SyntaxTrivia>();
// When we modify the declaration to an initialization we have to preserve the leading trivia
var firstVariableToAttachTrivia = true;
// go through each var decls in decl statement, and create new assignment if
// variable is initialized at decl.
foreach (var variableDeclaration in declarationStatement.Declaration.Variables)
{
if (variableToRemoveMap.HasSyntaxAnnotation(variableDeclaration))
{
if (variableDeclaration.Initializer != null)
{
var identifier = ApplyTriviaFromDeclarationToAssignmentIdentifier(declarationStatement, firstVariableToAttachTrivia, variableDeclaration);
// move comments with the variable here
expressionStatements.Add(CreateAssignmentExpressionStatement(identifier, variableDeclaration.Initializer.Value));
}
else
{
// we don't remove trivia around tokens we remove
triviaList.AddRange(variableDeclaration.GetLeadingTrivia());
triviaList.AddRange(variableDeclaration.GetTrailingTrivia());
}
firstVariableToAttachTrivia = false;
continue;
}
// Prepend the trivia from the declarations without initialization to the next persisting variable declaration
if (triviaList.Count > 0)
{
list.Add(variableDeclaration.WithPrependedLeadingTrivia(triviaList));
triviaList.Clear();
firstVariableToAttachTrivia = false;
continue;
}
firstVariableToAttachTrivia = false;
list.Add(variableDeclaration);
}
if (list.Count == 0 && triviaList.Count > 0)
{
// well, there are trivia associated with the node.
// we can't just delete the node since then, we will lose
// the trivia. unfortunately, it is not easy to attach the trivia
// to next token. for now, create an empty statement and associate the
// trivia to the statement
// TODO : think about a way to attach the trivia to the next token
result.Add(EmptyStatement(Token([.. triviaList], SyntaxKind.SemicolonToken, [ElasticMarker])));
triviaList.Clear();
}
// Re-emit the declaration with the surviving declarators; trivia still pending at this
// point is prepended to the semicolon token.
if (list.Count > 0)
{
result.Add(LocalDeclarationStatement(
declarationStatement.Modifiers,
VariableDeclaration(
declarationStatement.Declaration.Type,
[.. list]),
declarationStatement.SemicolonToken.WithPrependedLeadingTrivia(triviaList)));
triviaList.Clear();
}
// The replacement assignments are emitted after the surviving declaration, if any.
result.AddRange(expressionStatements);
}
return result.ToImmutableAndClear();
}
/// <summary>
/// If the statement has an <c>out var</c> declaration expression for a variable which
/// needs to be removed, we need to turn it into a plain <c>out</c> parameter, so that
/// it doesn't declare a duplicate variable.
/// If the statement has a pattern declaration (such as <c>3 is int i</c>) for a variable
/// which needs to be removed, we will annotate it as a conflict, since we don't have
/// a better refactoring.
/// </summary>
/// <param name="statement">The statement whose descendant declarations are inspected.</param>
/// <param name="variablesToRemove">Annotations marking the variables that must not be declared here.</param>
/// <returns>The statement with the affected declaration nodes replaced.</returns>
private static StatementSyntax FixDeclarationExpressionsAndDeclarationPatterns(StatementSyntax statement,
    HashSet<SyntaxAnnotation> variablesToRemove)
{
    var replacements = new Dictionary<SyntaxNode, SyntaxNode>();

    var declarations = statement.DescendantNodes()
        .Where(n => n.Kind() is SyntaxKind.DeclarationExpression or SyntaxKind.DeclarationPattern);

    foreach (var node in declarations)
    {
        switch (node.Kind())
        {
            case SyntaxKind.DeclarationExpression:
                {
                    var declaration = (DeclarationExpressionSyntax)node;

                    // Only single-variable designations are rewritten.
                    if (declaration.Designation is not SingleVariableDesignationSyntax designation)
                    {
                        break;
                    }

                    if (variablesToRemove.HasSyntaxAnnotation(designation))
                    {
                        // The type goes away, so its trivia is moved in front of the identifier
                        // that replaces the whole declaration expression.
                        var newLeadingTrivia = new SyntaxTriviaList();
                        newLeadingTrivia = newLeadingTrivia.AddRange(declaration.Type.GetLeadingTrivia());
                        newLeadingTrivia = newLeadingTrivia.AddRange(declaration.Type.GetTrailingTrivia());
                        newLeadingTrivia = newLeadingTrivia.AddRange(designation.GetLeadingTrivia());

                        replacements.Add(declaration, IdentifierName(designation.Identifier)
                            .WithLeadingTrivia(newLeadingTrivia));
                    }

                    break;
                }

            case SyntaxKind.DeclarationPattern:
                {
                    var pattern = (DeclarationPatternSyntax)node;
                    if (!variablesToRemove.HasSyntaxAnnotation(pattern))
                    {
                        break;
                    }

                    // We don't have a good refactoring for this, so we just annotate the conflict
                    // For instance, when a local declared by a pattern declaration (`3 is int i`) is
                    // used outside the block we're trying to extract.
                    if (pattern.Designation is not SingleVariableDesignationSyntax designation)
                    {
                        break;
                    }

                    var annotation = ConflictAnnotation.Create(CSharpFeaturesResources.Conflict_s_detected);
                    var newIdentifier = designation.Identifier.WithAdditionalAnnotations(annotation);
                    replacements.Add(pattern, pattern.WithDesignation(designation.WithIdentifier(newIdentifier)));

                    break;
                }
        }
    }

    // Each original node is swapped for its precomputed replacement.
    return statement.ReplaceNodes(replacements.Keys, (original, _) => replacements[original]);
}
/// <summary>
/// Returns the identifier of <paramref name="variable"/>. For the first variable of a
/// declaration that has a type, the type's leading trivia is placed in front of the
/// identifier's own leading trivia.
/// </summary>
private static SyntaxToken ApplyTriviaFromDeclarationToAssignmentIdentifier(LocalDeclarationStatementSyntax declarationStatement, bool firstVariableToAttachTrivia, VariableDeclaratorSyntax variable)
{
    var identifier = variable.Identifier;
    var declaredType = declarationStatement.Declaration.Type;

    if (!firstVariableToAttachTrivia || declaredType == null)
    {
        return identifier;
    }

    var combinedTrivia = new SyntaxTriviaList();
    if (declaredType.HasLeadingTrivia)
    {
        combinedTrivia = combinedTrivia.AddRange(declaredType.GetLeadingTrivia());
    }

    combinedTrivia = combinedTrivia.AddRange(identifier.LeadingTrivia);
    return identifier.WithLeadingTrivia(combinedTrivia);
}
/// <summary>
/// Prepends to <paramref name="statements"/> the (merged) declaration statements of the
/// variables that are split or moved into the new method.
/// </summary>
private ImmutableArray<StatementSyntax> SplitOrMoveDeclarationIntoMethodDefinition(
    SyntaxNode insertionPointNode,
    ImmutableArray<StatementSyntax> statements,
    CancellationToken cancellationToken)
{
    var postProcessor = new PostProcessor(SemanticDocument.SemanticModel, insertionPointNode.SpanStart);

    var variablesToDeclare = AnalyzerResult.GetVariablesToSplitOrMoveIntoMethodDefinition(cancellationToken);
    var declarations = CreateDeclarationStatements(variablesToDeclare, cancellationToken);
    var mergedDeclarations = postProcessor.MergeDeclarationStatements(declarations);

    return mergedDeclarations.Concat(statements);
}
/// <summary>
/// Creates the simple assignment <c>identifier = rvalue</c>.
/// </summary>
private static ExpressionSyntax CreateAssignmentExpression(SyntaxToken identifier, ExpressionSyntax rvalue)
    => AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, IdentifierName(identifier), rvalue);
/// <summary>
/// Determines whether the last selected statement is either the final statement of the block
/// body of its enclosing returnable construct, or is immediately followed there by a
/// <c>return</c> statement.
/// </summary>
/// <returns>
/// <c>true</c> in the two cases above; <c>false</c> when there is no returnable construct, no
/// block body, or the last selected node is not a direct statement of that block body.
/// </returns>
protected override bool LastStatementOrHasReturnStatementInReturnableConstruct()
{
    var lastStatement = GetLastStatementOrInitializerSelectedAtCallSite();
    var container = lastStatement.GetAncestorsOrThis<SyntaxNode>().FirstOrDefault(n => n.IsReturnableConstruct());
    if (container == null)
    {
        // case such as field initializer
        return false;
    }

    var blockBody = container.GetBlockBody();
    if (blockBody == null)
    {
        // such as expression lambda. there is no statement
        return false;
    }

    // An initializer (non-statement) can never be a member of the block's statement list.
    if (lastStatement is not StatementSyntax lastSelectedStatement)
    {
        return false;
    }

    var statements = blockBody.Statements;
    var index = statements.IndexOf(lastSelectedStatement);
    if (index < 0)
    {
        // The selection ends in a nested statement (for example inside an inner block), so it is
        // neither the last statement of the body nor directly followed by one of its statements.
        // Without this guard, index + 1 would wrongly inspect the first statement of the body.
        return false;
    }

    // check whether it is last statement except return statement
    if (index == statements.Count - 1)
    {
        return true;
    }

    return statements[index + 1].Kind() == SyntaxKind.ReturnStatement;
}
/// <summary>
/// Creates an identifier token with the given name.
/// </summary>
protected override SyntaxToken CreateIdentifier(string name)
{
    return Identifier(name);
}
/// <summary>
/// Creates <c>return identifierName;</c>, or a bare <c>return;</c> when no name is given.
/// </summary>
protected override StatementSyntax CreateReturnStatement(string identifierName = null)
{
    if (string.IsNullOrEmpty(identifierName))
    {
        return ReturnStatement();
    }

    return ReturnStatement(IdentifierName(identifierName));
}
/// <summary>
/// Builds the expression that invokes the extracted method: name, arguments (with
/// <c>ref</c>/<c>out</c> where needed), optional <c>.ConfigureAwait(false)</c> and <c>await</c>
/// for async extraction, and a <c>ref</c> wrapper for by-ref returns.
/// </summary>
protected override ExpressionSyntax CreateCallSignature()
{
    var invokedName = CreateMethodNameForInvocation().WithAdditionalAnnotations(Simplifier.Annotation);
    var capturesParameters = LocalFunction && ShouldLocalFunctionCaptureParameter(SemanticDocument.Root);

    using var _ = ArrayBuilder<ArgumentSyntax>.GetInstance(out var argumentBuilder);
    foreach (var parameter in AnalyzerResult.MethodParameters)
    {
        // Parameters captured by the local function are not passed explicitly.
        if (capturesParameters && parameter.CanBeCapturedByLocalFunction)
        {
            continue;
        }

        var keywordKind = GetParameterRefSyntaxKind(parameter.ParameterModifier);
        var refOrOutToken = keywordKind == SyntaxKind.None ? default : Token(keywordKind);
        argumentBuilder.Add(Argument(IdentifierName(parameter.Name)).WithRefOrOutKeyword(refOrOutToken));
    }

    ExpressionSyntax call = InvocationExpression(invokedName, ArgumentList([.. argumentBuilder]));

    if (this.SelectionResult.ShouldPutAsyncModifier())
    {
        // Append .ConfigureAwait(false) only when requested and when the return type actually
        // exposes a ConfigureAwait(bool) method.
        var appendConfigureAwait = this.SelectionResult.ShouldCallConfigureAwaitFalse()
            && AnalyzerResult.ReturnType.GetMembers().Any(static x => x is IMethodSymbol
            {
                Name: nameof(Task.ConfigureAwait),
                Parameters: [{ Type.SpecialType: SpecialType.System_Boolean }],
            });

        if (appendConfigureAwait)
        {
            var configureAwaitAccess = MemberAccessExpression(
                SyntaxKind.SimpleMemberAccessExpression,
                call,
                IdentifierName(nameof(Task.ConfigureAwait)));

            call = InvocationExpression(
                configureAwaitAccess,
                ArgumentList([Argument(LiteralExpression(SyntaxKind.FalseLiteralExpression))]));
        }

        call = AwaitExpression(call);
    }

    return AnalyzerResult.ReturnsByRef ? RefExpression(call) : call;
}
/// <summary>
/// Creates the statement <c>identifier = rvalue;</c>.
/// </summary>
protected override StatementSyntax CreateAssignmentExpressionStatement(SyntaxToken identifier, ExpressionSyntax rvalue)
{
    var assignment = CreateAssignmentExpression(identifier, rvalue);
    return ExpressionStatement(assignment);
}
/// <summary>
/// Creates a local declaration for <paramref name="variable"/>, optionally initialized with
/// <paramref name="initialValue"/>. The <c>using</c> keyword is kept when the variable was
/// originally declared in a using declaration.
/// </summary>
protected override StatementSyntax CreateDeclarationStatement(
    VariableInfo variable,
    ExpressionSyntax initialValue,
    CancellationToken cancellationToken)
{
    var variableType = variable.GetVariableType();
    var typeSyntax = variableType.GenerateTypeSyntax();

    var originalIdentifier = variable.GetOriginalIdentifierToken(cancellationToken);

    // Walk Token -> VariableDeclarator -> VariableDeclaration -> LocalDeclaration to find out
    // whether the original declaration carried a using keyword.
    var wasUsingDeclaration = originalIdentifier.Parent?.Parent?.Parent
        is LocalDeclarationStatementSyntax { UsingKeyword.FullSpan.IsEmpty: false };
    SyntaxToken usingKeyword = wasUsingDeclaration ? UsingKeyword : default;

    var initializer = initialValue == null ? null : EqualsValueClause(value: initialValue);
    var declarator = VariableDeclarator(Identifier(variable.Name)).WithInitializer(initializer);

    return LocalDeclarationStatement(VariableDeclaration(typeSyntax).AddVariables(declarator))
        .WithUsingKeyword(usingKeyword);
}
/// <summary>
/// Adjusts the new lines of the generated method definition before handing the document to
/// the base implementation.
/// </summary>
protected override async Task<GeneratedCode> CreateGeneratedCodeAsync(
    SemanticDocument newDocument, CancellationToken cancellationToken)
{
    // In hybrid code such as extract method the formatter struggles to decide where lines break.
    // An explicit new line is therefore inserted after the "{" of the generated declaration so
    // the anchor can work out the indentation of the inserted user statements while keeping the
    // user's code style.
    var root = newDocument.Root;
    var methodDefinition = root.GetAnnotatedNodes<SyntaxNode>(MethodDefinitionAnnotation).First();

    SyntaxNode tweakedDefinition;
    if (methodDefinition is MethodDeclarationSyntax method)
    {
        tweakedDefinition = TweakNewLinesInMethod(method);
    }
    else if (methodDefinition is LocalFunctionStatementSyntax localFunction)
    {
        tweakedDefinition = TweakNewLinesInMethod(localFunction);
    }
    else
    {
        throw new NotSupportedException("SyntaxNode expected to be MethodDeclarationSyntax or LocalFunctionStatementSyntax.");
    }

    var updatedRoot = root.ReplaceNode(methodDefinition, tweakedDefinition);
    newDocument = await newDocument.WithSyntaxRootAsync(updatedRoot, cancellationToken).ConfigureAwait(false);

    return await base.CreateGeneratedCodeAsync(newDocument, cancellationToken).ConfigureAwait(false);
}
/// <summary>
/// Applies the new-line tweak to a method declaration, using its block or expression body.
/// </summary>
private static MethodDeclarationSyntax TweakNewLinesInMethod(MethodDeclarationSyntax method)
{
    return TweakNewLinesInMethod(method, method.Body, method.ExpressionBody);
}
/// <summary>
/// Applies the new-line tweak to a local function, using its block or expression body.
/// </summary>
private static LocalFunctionStatementSyntax TweakNewLinesInMethod(LocalFunctionStatementSyntax method)
{
    return TweakNewLinesInMethod(method, method.Body, method.ExpressionBody);
}
/// <summary>
/// Adds an elastic new line after the open brace of a block body, or before the arrow of an
/// expression body. Returns <paramref name="method"/> unchanged when it has neither.
/// </summary>
private static TDeclarationNode TweakNewLinesInMethod<TDeclarationNode>(TDeclarationNode method, BlockSyntax body, ArrowExpressionClauseSyntax expressionBody) where TDeclarationNode : SyntaxNode
{
    if (body != null)
    {
        var openBrace = body.OpenBraceToken;
        return method.ReplaceToken(
            openBrace,
            openBrace.WithAppendedTrailingTrivia(ElasticCarriageReturnLineFeed));
    }

    if (expressionBody != null)
    {
        var arrow = expressionBody.ArrowToken;
        return method.ReplaceToken(
            arrow,
            arrow.WithPrependedLeadingTrivia(ElasticCarriageReturnLineFeed));
    }

    return method;
}
/// <summary>
/// Wraps the call to the extracted method in a statement: a <c>return</c> statement when the
/// analysis reports a return type, an expression statement otherwise.
/// </summary>
protected StatementSyntax GetStatementContainingInvocationToExtractedMethodWorker()
{
    var invocation = CreateCallSignature();

    if (!AnalyzerResult.HasReturnType)
    {
        return ExpressionStatement(invocation);
    }

    // A returned call must not also need a variable to carry the return value.
    Contract.ThrowIfTrue(AnalyzerResult.HasVariableToUseAsReturnValue);
    return ReturnStatement(invocation);
}
/// <summary>
/// After generation, drops the nullable annotation from the extracted method's return type
/// when none of the method's own return operations has a maybe-null flow state.
/// Returns <paramref name="originalDocument"/> unchanged when the return type is not annotated,
/// when the annotated node is neither a method nor a local function, or when a return may be null.
/// </summary>
protected override async Task<SemanticDocument> UpdateMethodAfterGenerationAsync(
SemanticDocument originalDocument,
IMethodSymbol methodSymbol,
CancellationToken cancellationToken)
{
// Only need to update for nullable reference types in return
if (methodSymbol.ReturnType.NullableAnnotation != NullableAnnotation.Annotated)
return originalDocument;
// Locate the generated definition through its annotation.
var syntaxNode = originalDocument.Root.GetAnnotatedNodesAndTokens(MethodDefinitionAnnotation).FirstOrDefault().AsNode();
var nodeIsMethodOrLocalFunction = syntaxNode is MethodDeclarationSyntax or LocalFunctionStatementSyntax;
if (!nodeIsMethodOrLocalFunction)
return originalDocument;
// A non-null result means some return may be null, so the annotation has to stay.
var nullableReturnOperations = CheckReturnOperations(syntaxNode, originalDocument, cancellationToken);
if (nullableReturnOperations is not null)
return nullableReturnOperations;
var returnType = syntaxNode is MethodDeclarationSyntax method ? method.ReturnType : ((LocalFunctionStatementSyntax)syntaxNode).ReturnType;
var newDocument = await GenerateNewDocumentAsync(methodSymbol, returnType, originalDocument, cancellationToken).ConfigureAwait(false);
return await SemanticDocument.CreateAsync(newDocument, cancellationToken).ConfigureAwait(false);
// True when the nearest enclosing method, anonymous function or local function of the
// return syntax is methodSyntax itself.
static bool ReturnOperationBelongsToMethod(SyntaxNode returnOperationSyntax, SyntaxNode methodSyntax)
{
var enclosingMethod = returnOperationSyntax.FirstAncestorOrSelf<SyntaxNode>(n => n switch
{
BaseMethodDeclarationSyntax _ => true,
AnonymousFunctionExpressionSyntax _ => true,
LocalFunctionStatementSyntax _ => true,
_ => false
});
return enclosingMethod == methodSyntax;
}
// Returns originalDocument as soon as one of the method's own returns has a MaybeNull flow
// state; returns null when no such return is found.
static SemanticDocument CheckReturnOperations(
SyntaxNode node,
SemanticDocument originalDocument,
CancellationToken cancellationToken)
{
var semanticModel = originalDocument.SemanticModel;
// NOTE(review): the result of GetOperation is dereferenced without a null check — confirm it cannot be null here.
var methodOperation = semanticModel.GetOperation(node, cancellationToken);
var returnOperations = methodOperation.DescendantsAndSelf().OfType<IReturnOperation>();
foreach (var returnOperation in returnOperations)
{
// If the return statement is located in a nested local function or lambda it
// shouldn't contribute to the nullability of the extracted method's return type
if (!ReturnOperationBelongsToMethod(returnOperation.Syntax, methodOperation.Syntax))
continue;
// Use the returned value's syntax when there is one, otherwise the return itself.
var syntax = returnOperation.ReturnedValue?.Syntax ?? returnOperation.Syntax;
var returnTypeInfo = semanticModel.GetTypeInfo(syntax, cancellationToken);
if (returnTypeInfo.Nullability.FlowState == NullableFlowState.MaybeNull)
{
// Flow state shows that return is correctly nullable
return originalDocument;
}
}
return null;
}
// Replaces the return type syntax with the not-annotated form of the method's return type.
static async Task<Document> GenerateNewDocumentAsync(
IMethodSymbol methodSymbol,
TypeSyntax returnType,
SemanticDocument originalDocument,
CancellationToken cancellationToken)
{
// Return type can be updated to not be null
var newType = methodSymbol.ReturnType.WithNullableAnnotation(NullableAnnotation.NotAnnotated);
var oldRoot = await originalDocument.Document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var newRoot = oldRoot.ReplaceNode(returnType, newType.GenerateTypeSyntax());
return originalDocument.Document.WithSyntaxRoot(newRoot);
}
}
/// <summary>
/// Generates an identifier for the extracted method that is unique within the relevant scope.
/// </summary>
protected SyntaxToken GenerateMethodNameForStatementGenerators()
{
    var uniqueNames = new UniqueNameGenerator(SemanticDocument.SemanticModel);

    var namingScope = this.SelectionResult.GetContainingScope();
    if (LocalFunction)
    {
        // A local function has to avoid clashing with every local variable, so the scope
        // starts at the parent of the first selected token.
        namingScope = this.SelectionResult.GetFirstTokenInSelection().Parent;
    }

    var preferredName = GenerateMethodNameFromUserPreference();
    return Identifier(uniqueNames.CreateUniqueMethodName(namingScope, preferredName));
}
/// <summary>
/// Picks the base name for the extracted method. Ordinary methods always use the PascalCase
/// name. Local functions use the camelCase name when a naming rule with camelCase
/// capitalization applies to local functions with the computed modifiers; otherwise PascalCase.
/// </summary>
/// <returns>Either <c>NewMethod</c> or <c>newMethod</c>.</returns>
protected string GenerateMethodNameFromUserPreference()
{
    if (!LocalFunction)
    {
        return NewMethodPascalCaseStr;
    }

    // For local functions, pascal case and camel case should be the most common and therefore we only consider those cases.
    var namingRules = Options.NamingStyle.Rules.NamingRules;
    var localFunctionKind = new SymbolSpecification.SymbolKindOrTypeKind(MethodKind.LocalFunction);

    var prefersCamelCase = namingRules.Any(
        static (rule, arg) => rule.NamingStyle.CapitalizationScheme.Equals(Capitalization.CamelCase) && rule.SymbolSpecification.AppliesTo(arg.localFunctionKind, arg.self.CreateMethodModifiers(), null),
        (self: this, localFunctionKind));

    // We default to pascal case.
    return prefersCamelCase ? NewMethodCamelCaseStr : NewMethodPascalCaseStr;
}
}
}
|