|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.LanguageService;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.FindSymbols;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.InitializeParameter;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
namespace Microsoft.CodeAnalysis.CSharp.InitializeParameter;
using static CSharpSyntaxTokens;
using static SyntaxFactory;
internal static class InitializeParameterHelpers
{
    /// <summary>
    /// Packages a C# <c>ArgumentSyntax</c> into an <c>Argument</c> value holding its ref kind, its name (the
    /// identifier of its <c>name:</c> label, or <see langword="null"/> when the argument has no such label), and its
    /// expression.
    /// </summary>
    public static Argument<ExpressionSyntax> GetArgument(ArgumentSyntax argument)
        => new(argument.GetRefKind(), argument.NameColon?.Name.Identifier.ValueText, argument.Expression);

    /// <summary>
    /// Makes an existing field or property get its initial value from a primary constructor parameter, e.g. turns
    /// <c>int X { get; }</c> into <c>int X { get; } = x;</c>.
    /// </summary>
    /// <remarks>
    /// Before adding the initializer, every non-implicit reference to <paramref name="parameter"/> inside the
    /// containing type is rewritten to refer to <paramref name="fieldOrProperty"/>. If the property is one that
    /// <c>IInitializeParameterService.IsThrowNotImplementedProperty</c> flags, its accessor bodies are stripped so the
    /// property becomes auto-implemented. Only the first matching property declaration / variable declarator found
    /// among the symbol's declaring syntax references is edited.
    /// </remarks>
    /// <returns>The solution with all edits applied.</returns>
    public static async Task<Solution> AddAssignmentForPrimaryConstructorAsync(
        Document document,
        IParameterSymbol parameter,
        ISymbol fieldOrProperty,
        CancellationToken cancellationToken)
    {
        var project = document.Project;
        var solution = project.Solution;

        var solutionEditor = new SolutionEditor(solution);
        var initializer = EqualsValueClause(IdentifierName(parameter.Name.EscapeIdentifier()));

        // We're assigning the parameter to a field/prop. Convert all existing references to this primary constructor
        // parameter (within this type) to refer to the field/prop now instead.
        await UpdateParameterReferencesAsync(
            solutionEditor, parameter, fieldOrProperty, cancellationToken).ConfigureAwait(false);

        // We're updating an existing field/prop.
        if (fieldOrProperty is IPropertySymbol property)
        {
            var compilation = await project.GetRequiredCompilationAsync(cancellationToken).ConfigureAwait(false);
            var initializeParameterService = document.GetRequiredLanguageService<IInitializeParameterService>();
            var isThrowNotImplementedProperty = initializeParameterService.IsThrowNotImplementedProperty(
                compilation, property, cancellationToken);

            foreach (var syntaxRef in property.DeclaringSyntaxReferences)
            {
                if (syntaxRef.GetSyntax(cancellationToken) is PropertyDeclarationSyntax propertyDeclaration)
                {
                    var editingDocument = solution.GetRequiredDocument(propertyDeclaration.SyntaxTree);
                    var editor = await solutionEditor.GetDocumentEditorAsync(editingDocument.Id, cancellationToken).ConfigureAwait(false);

                    // If the user had a property that has 'throw NotImplementedException' in it, then remove those throws.
                    var newPropertyDeclaration = isThrowNotImplementedProperty ? RemoveThrowNotImplemented(propertyDeclaration) : propertyDeclaration;

                    // Move the declaration's trailing trivia onto the new terminating semicolon so that it ends up
                    // after `= x;` rather than before the initializer.
                    editor.ReplaceNode(
                        propertyDeclaration,
                        newPropertyDeclaration.WithoutTrailingTrivia()
                            .WithSemicolonToken(SemicolonToken.WithTrailingTrivia(newPropertyDeclaration.GetTrailingTrivia()))
                            .WithInitializer(initializer));
                    break;
                }
            }
        }
        else if (fieldOrProperty is IFieldSymbol field)
        {
            foreach (var syntaxRef in field.DeclaringSyntaxReferences)
            {
                if (syntaxRef.GetSyntax(cancellationToken) is VariableDeclaratorSyntax variableDeclarator)
                {
                    var editingDocument = solution.GetRequiredDocument(variableDeclarator.SyntaxTree);
                    var editor = await solutionEditor.GetDocumentEditorAsync(editingDocument.Id, cancellationToken).ConfigureAwait(false);
                    editor.ReplaceNode(
                        variableDeclarator,
                        variableDeclarator.WithInitializer(initializer));
                    break;
                }
            }
        }

        return solutionEditor.GetChangedSolution();
    }

    /// <summary>
    /// Rewrites references to the primary constructor <paramref name="parameter"/> so they name
    /// <paramref name="fieldOrProperty"/> instead.
    /// </summary>
    /// <remarks>
    /// Only documents containing a declaration of the parameter's containing type are searched, and implicit
    /// reference locations are ignored. Identifiers used as named-argument labels (the <c>x</c> in
    /// <c>new MyType(x: ...)</c>) are left alone. The edits are recorded in the document editors obtained from
    /// <paramref name="solutionEditor"/>.
    /// </remarks>
    public static async Task UpdateParameterReferencesAsync(
        SolutionEditor solutionEditor,
        IParameterSymbol parameter,
        ISymbol fieldOrProperty,
        CancellationToken cancellationToken)
    {
        var solution = solutionEditor.OriginalSolution;
        var namedType = parameter.ContainingType;
        var documents = namedType.DeclaringSyntaxReferences
            .Select(r => solution.GetRequiredDocument(r.SyntaxTree))
            .ToImmutableHashSet();

        var references = await SymbolFinder.FindReferencesAsync(parameter, solution, documents, cancellationToken).ConfigureAwait(false);
        var groups = references.SelectMany(static r => r.Locations.Where(loc => !loc.IsImplicit)).GroupBy(static loc => loc.Document);

        foreach (var group in groups)
        {
            var editor = await solutionEditor.GetDocumentEditorAsync(group.Key.Id, cancellationToken).ConfigureAwait(false);

            // We may hit a location multiple times due to how we do FAR for linked symbols, but each linked symbol is
            // allowed to report the entire set of references it thinks it is compatible with. So ensure we're hitting
            // each location only once.
            foreach (var location in group.Distinct(LinkedFileReferenceLocationEqualityComparer.Instance))
            {
                var node = location.Location.FindNode(getInnermostNodeForTie: true, cancellationToken);

                // We may have things like `new MyType(x: ...)`; we don't want to update `x` there to `X` just because
                // we're generating a new property `X` for the parameter to be assigned to.
                if (node is IdentifierNameSyntax { Parent: not NameColonSyntax } identifierName &&
                    identifierName.Identifier.ValueText == parameter.Name)
                {
                    editor.ReplaceNode(
                        identifierName,
                        IdentifierName(fieldOrProperty.Name.EscapeIdentifier()).WithTriviaFrom(identifierName));
                }
            }
        }
    }

    /// <summary>
    /// Whether <paramref name="node"/> is a method-like declaration (any <c>BaseMethodDeclarationSyntax</c>), a local
    /// function, or an anonymous function (lambda / anonymous method).
    /// </summary>
    public static bool IsFunctionDeclaration(SyntaxNode node)
        => node is BaseMethodDeclarationSyntax or LocalFunctionStatementSyntax or AnonymousFunctionExpressionSyntax;

    /// <summary>
    /// Returns the body of a function declaration: for methods and local functions the block body if present,
    /// otherwise the expression body; for anonymous functions their body.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the <c>!</c> suppressions mean a method/local function with neither body (e.g. abstract or
    /// extern) yields <see langword="null"/> despite the non-nullable return type.
    /// </remarks>
    /// <exception cref="System.InvalidOperationException">
    /// Thrown (via <c>ExceptionUtilities.UnexpectedValue</c>) for any other kind of node.
    /// </exception>
    public static SyntaxNode GetBody(SyntaxNode functionDeclaration)
        => functionDeclaration switch
        {
            BaseMethodDeclarationSyntax methodDeclaration => (SyntaxNode?)methodDeclaration.Body ?? methodDeclaration.ExpressionBody!,
            LocalFunctionStatementSyntax localFunction => (SyntaxNode?)localFunction.Body ?? localFunction.ExpressionBody!,
            AnonymousFunctionExpressionSyntax anonymousFunction => anonymousFunction.Body,
            _ => throw ExceptionUtilities.UnexpectedValue(functionDeclaration),
        };

    /// <summary>
    /// Returns the declaration's own semicolon token for methods and local functions, or <see langword="null"/> for
    /// anonymous functions, which have none.
    /// </summary>
    private static SyntaxToken? TryGetSemicolonToken(SyntaxNode functionDeclaration)
        => functionDeclaration switch
        {
            BaseMethodDeclarationSyntax methodDeclaration => methodDeclaration.SemicolonToken,
            LocalFunctionStatementSyntax localFunction => localFunction.SemicolonToken,
            AnonymousFunctionExpressionSyntax _ => null,
            _ => throw ExceptionUtilities.UnexpectedValue(functionDeclaration),
        };

    /// <summary>
    /// Whether <paramref name="compilation"/> classifies the conversion from <paramref name="source"/> to
    /// <paramref name="destination"/> as implicit.
    /// </summary>
    public static bool IsImplicitConversion(Compilation compilation, ITypeSymbol source, ITypeSymbol destination)
        => compilation.ClassifyConversion(source: source, destination: destination).IsImplicit;

    /// <summary>
    /// If <paramref name="blockStatement"/>'s syntax is a <c>BlockSyntax</c>, returns its last statement
    /// (<see langword="null"/> when the block is empty); otherwise returns the operation's own syntax, or
    /// <see langword="null"/> when <paramref name="blockStatement"/> is <see langword="null"/>.
    /// </summary>
    public static SyntaxNode? TryGetLastStatement(IBlockOperation? blockStatement)
        => blockStatement?.Syntax is BlockSyntax block
            ? block.Statements.LastOrDefault()
            : blockStatement?.Syntax;

    /// <summary>
    /// Inserts <paramref name="statement"/> into the body of <paramref name="functionDeclaration"/>.
    /// </summary>
    /// <remarks>
    /// <list type="bullet">
    /// <item>Expression body: the expression is converted to a statement (a return statement when
    /// <paramref name="returnsVoid"/> is <see langword="false"/>) and the function is given a block holding both
    /// statements. The new statement goes first when <paramref name="statementToAddAfterOpt"/> is
    /// <see langword="null"/>, otherwise last. If the conversion fails, nothing is changed.</item>
    /// <item>Block body: the new statement goes after <paramref name="statementToAddAfterOpt"/> if that is one of the
    /// block's statements, otherwise before the first statement, or becomes the only statement of an empty block.
    /// Single-line blocks are annotated for formatting.</item>
    /// <item>Anything else (presumably a declaration without a body): the function's statements are set to just
    /// <paramref name="statement"/>.</item>
    /// </list>
    /// </remarks>
    public static void InsertStatement(
        SyntaxEditor editor,
        SyntaxNode functionDeclaration,
        bool returnsVoid,
        SyntaxNode? statementToAddAfterOpt,
        StatementSyntax statement)
    {
        var body = GetBody(functionDeclaration);

        if (IsExpressionBody(body))
        {
            var semicolonToken = TryGetSemicolonToken(functionDeclaration) ?? SemicolonToken;

            if (!TryConvertExpressionBodyToStatement(body, semicolonToken, !returnsVoid, out var convertedStatement))
            {
                return;
            }

            // Add the new statement as the first/last statement of the new block
            // depending if we were asked to go after something or not.
            editor.SetStatements(functionDeclaration, statementToAddAfterOpt == null
                ? [statement, convertedStatement]
                : [convertedStatement, statement]);
        }
        else if (body is BlockSyntax block)
        {
            // Look for the statement we were asked to go after.
            var indexToAddAfter = block.Statements.IndexOf(s => s == statementToAddAfterOpt);
            if (indexToAddAfter >= 0)
            {
                // If we find it, then insert the new statement after it.
                editor.InsertAfter(block.Statements[indexToAddAfter], statement);
            }
            else if (block.Statements.Count > 0)
            {
                // Otherwise, if we already have at least one statement, then insert ourselves
                // before the first one.
                editor.InsertBefore(block.Statements[0], statement);
            }
            else
            {
                // Otherwise, we have no statements in this block. Add the new statement
                // as the single statement the block will have.
                Debug.Assert(block.Statements.Count == 0);
                editor.ReplaceNode(block, (currentBlock, _) => ((BlockSyntax)currentBlock).AddStatements(statement));
            }

            // If the block was on a single line before, then format it so that the formatting
            // engine will update it to go over multiple lines. Otherwise, we can end up in
            // the strange state where the { and } tokens stay where they were originally,
            // which will look very strange like:
            //
            //          a => {
            //              if (...) {
            //              } };
            if (CSharpSyntaxFacts.Instance.IsOnSingleLine(block, fullSpan: false))
            {
                editor.ReplaceNode(
                    block,
                    (currentBlock, _) => currentBlock.WithAdditionalAnnotations(Formatter.Annotation));
            }
        }
        else
        {
            editor.SetStatements(functionDeclaration, ImmutableArray.Create(statement));
        }
    }

    /// <summary>
    /// Whether <paramref name="body"/> is an expression body, either from an expression lambda (a bare
    /// <c>ExpressionSyntax</c>) or from an expression-bodied member (an <c>ArrowExpressionClauseSyntax</c>).
    /// </summary>
    public static bool IsExpressionBody(SyntaxNode body)
        => body is ExpressionSyntax or ArrowExpressionClauseSyntax;

    /// <summary>
    /// Converts an expression body (see <c>IsExpressionBody</c>) into a statement terminated by
    /// <paramref name="semicolonToken"/>, creating a return statement when
    /// <paramref name="createReturnStatementForExpression"/> is <see langword="true"/>.
    /// </summary>
    /// <returns><see langword="true"/> and a non-null <paramref name="statement"/> on success.</returns>
    public static bool TryConvertExpressionBodyToStatement(
        SyntaxNode body,
        SyntaxToken semicolonToken,
        bool createReturnStatementForExpression,
        [NotNullWhen(true)] out StatementSyntax? statement)
    {
        Debug.Assert(IsExpressionBody(body));

        return body switch
        {
            // If this is a => method, then we'll have to convert the method to have a block body.
            ArrowExpressionClauseSyntax arrowClause => arrowClause.TryConvertToStatement(semicolonToken, createReturnStatementForExpression, out statement),
            // must be an expression lambda
            ExpressionSyntax expression => expression.TryConvertToStatement(semicolonToken, createReturnStatementForExpression, out statement),
            _ => throw ExceptionUtilities.UnexpectedValue(body),
        };
    }

    /// <summary>
    /// Returns the body of a property/indexer accessor. For an explicit accessor declaration this is its expression
    /// body or block body; for an expression-bodied property (<c>int Age => ...;</c>) it is the arrow clause.
    /// </summary>
    /// <returns>
    /// <see langword="null"/> when the accessor has no source declaration, or its declaration has no body (e.g. an
    /// auto-accessor <c>get;</c>).
    /// </returns>
    public static SyntaxNode? GetAccessorBody(IMethodSymbol accessor, CancellationToken cancellationToken)
    {
        // Symbols without source (e.g. from metadata) have no declaring syntax references; indexing [0] blindly
        // would throw IndexOutOfRangeException instead of reporting "no body".
        var syntaxReferences = accessor.DeclaringSyntaxReferences;
        if (syntaxReferences.IsEmpty)
            return null;

        var node = syntaxReferences[0].GetSyntax(cancellationToken);
        if (node is AccessorDeclarationSyntax accessorDeclaration)
            return accessorDeclaration.ExpressionBody ?? (SyntaxNode?)accessorDeclaration.Body;

        // `int Age => ...;`
        if (node is ArrowExpressionClauseSyntax arrowExpression)
            return arrowExpression;

        return null;
    }

    /// <summary>
    /// Applies <c>RemoveThrowNotImplemented(PropertyDeclarationSyntax)</c> when <paramref name="node"/> is a property
    /// declaration; any other node is returned unchanged.
    /// </summary>
    public static SyntaxNode RemoveThrowNotImplemented(SyntaxNode node)
        => node is PropertyDeclarationSyntax propertyDeclaration ? RemoveThrowNotImplemented(propertyDeclaration) : node;

    /// <summary>
    /// Strips the bodies of a property's accessors so it becomes auto-implemented.
    /// </summary>
    /// <remarks>
    /// An expression-bodied property (<c>int X => ...;</c>) becomes <c>int X { get; }</c>. It keeps the trailing
    /// trivia of the original semicolon and is annotated for formatting. A property with an accessor list has every
    /// accessor reduced to <c>get;</c>/<c>set;</c> form. The bodies are not inspected; callers decide whether
    /// stripping is appropriate.
    /// </remarks>
    public static PropertyDeclarationSyntax RemoveThrowNotImplemented(PropertyDeclarationSyntax propertyDeclaration)
    {
        if (propertyDeclaration.ExpressionBody != null)
        {
            var result = propertyDeclaration
                .WithExpressionBody(null)
                .WithSemicolonToken(default)
                .AddAccessorListAccessors(SyntaxFactory
                    .AccessorDeclaration(SyntaxKind.GetAccessorDeclaration)
                    .WithSemicolonToken(SemicolonToken))
                .WithTrailingTrivia(propertyDeclaration.SemicolonToken.TrailingTrivia)
                .WithAdditionalAnnotations(Formatter.Annotation);
            return result;
        }

        if (propertyDeclaration.AccessorList != null)
        {
            var accessors = propertyDeclaration.AccessorList.Accessors.Select(RemoveThrowNotImplemented);
            return propertyDeclaration.WithAccessorList(
                propertyDeclaration.AccessorList.WithAccessors([.. accessors]));
        }

        return propertyDeclaration;
    }

    /// <summary>
    /// Removes an accessor's expression or block body, leaving just the keyword followed by a semicolon. The
    /// trailing trivia of the removed block body, or else of the original semicolon, is carried over.
    /// </summary>
    private static AccessorDeclarationSyntax RemoveThrowNotImplemented(AccessorDeclarationSyntax accessorDeclaration)
    {
        var result = accessorDeclaration
            .WithExpressionBody(null)
            .WithBody(null)
            .WithSemicolonToken(SemicolonToken);

        return result.WithTrailingTrivia(accessorDeclaration.Body?.GetTrailingTrivia() ?? accessorDeclaration.SemicolonToken.TrailingTrivia);
    }
}
|