|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Composition;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics.Analyzers.NamingStyles;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.FindSymbols;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.Host.Mef;
using Microsoft.CodeAnalysis.Indentation;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Shared.Utilities;
using Microsoft.CodeAnalysis.Simplification;
using Microsoft.CodeAnalysis.Text;
using Roslyn.Utilities;
namespace Microsoft.CodeAnalysis.CSharp.ConvertPrimaryToRegularConstructor;
using static CSharpSyntaxTokens;
using static SyntaxFactory;
[ExportCodeRefactoringProvider(LanguageNames.CSharp, Name = PredefinedCodeRefactoringProviderNames.ConvertPrimaryToRegularConstructor), Shared]
[method: ImportingConstructor]
[method: Obsolete(MefConstruction.ImportingConstructorMessage, error: true)]
internal sealed partial class ConvertPrimaryToRegularConstructorCodeRefactoringProvider()
: CodeRefactoringProvider
{
/// <summary>
/// Registers the "convert to regular constructor" refactoring when the relevant node is a non-record type
/// declaration that has a primary constructor parameter list, and the requested span lies between the start of
/// the type declaration and the end of that parameter list.
/// </summary>
public override async Task ComputeRefactoringsAsync(CodeRefactoringContext context)
{
    var (document, span, _) = context;

    var typeDeclaration = await context.TryGetRelevantNodeAsync<TypeDeclarationSyntax>().ConfigureAwait(false);

    // Without a primary constructor parameter list there is nothing to convert.
    if (typeDeclaration is not { ParameterList: { } parameterList })
        return;

    // Records are intentionally not supported. Converting one to a non-primary-constructor form is a lot more
    // work (for example, having to synthesize a Deconstruct method, and figure out how to specify properties,
    // etc.). Support for that scenario can be considered later if desired.
    if (typeDeclaration is RecordDeclarationSyntax)
        return;

    // The refactoring is only offered from the start of the type through the end of the parameter list
    // (trailing trivia of the list included).
    var triggerSpan = TextSpan.FromBounds(typeDeclaration.SpanStart, parameterList.FullSpan.End);
    if (!triggerSpan.Contains(span))
        return;

    var action = CodeAction.Create(
        CSharpFeaturesResources.Convert_to_regular_constructor,
        token => ConvertAsync(document, typeDeclaration, parameterList, token),
        nameof(CSharpFeaturesResources.Convert_to_regular_constructor));

    context.RegisterRefactoring(action, triggerSpan);
}
/// <summary>
/// Produces the solution that results from replacing the primary constructor of
/// <paramref name="typeDeclaration"/> with an explicit constructor declaration. The work is done by the local
/// functions declared below the <c>return</c> statement; this leading section gathers the required state and then
/// invokes them in order.
/// </summary>
/// <param name="document">The document containing the type declaration that has the primary constructor.</param>
/// <param name="typeDeclaration">The type declaration being converted.</param>
/// <param name="parameterList">The primary constructor parameter list of <paramref name="typeDeclaration"/>.</param>
/// <param name="cancellationToken">Token passed through to every asynchronous and semantic query made here.</param>
/// <returns>The changed solution obtained from the solution editor once all edits are queued.</returns>
private static async Task<Solution> ConvertAsync(
Document document,
TypeDeclarationSyntax typeDeclaration,
ParameterListSyntax parameterList,
CancellationToken cancellationToken)
{
// NOTE(review): 'compilation' is never read again in the visible code. Presumably it is requested so the
// compilation is realized before semantic models are fetched -- confirm before removing.
var compilation = await document.Project.GetCompilationAsync(cancellationToken).ConfigureAwait(false);
// Populated by GetSemanticModelAsync with every semantic model it hands out.
var semanticModels = new ConcurrentSet<SemanticModel>();
var semanticModel = await GetSemanticModelAsync(document).ConfigureAwait(false);
var contextInfo = await document.GetCodeGenerationInfoAsync(CodeGenerationContext.Default, cancellationToken).ConfigureAwait(false);
var formattingOptions = await document.GetSyntaxFormattingOptionsAsync(cancellationToken).ConfigureAwait(false);
// The naming rule we need to follow if we synthesize new private fields.
var fieldNameRule = await document.GetApplicableNamingRuleAsync(
new SymbolSpecification.SymbolKindOrTypeKind(SymbolKind.Field),
DeclarationModifiers.None,
Accessibility.Private,
cancellationToken).ConfigureAwait(false);
// Get the named type and all its parameters for use during the rewrite.
var namedType = semanticModel.GetRequiredDeclaredSymbol(typeDeclaration, cancellationToken);
var parameters = parameterList.Parameters.SelectAsArray(p => semanticModel.GetRequiredDeclaredSymbol(p, cancellationToken));
// We may have to update multiple files (in the case of a partial type). Use a solution-editor to make that
// simple. We will insert the regular constructor into the partial part containing the primary constructor.
var solution = document.Project.Solution;
var solutionEditor = new SolutionEditor(solution);
var mainDocumentEditor = await solutionEditor.GetDocumentEditorAsync(document.Id, cancellationToken).ConfigureAwait(false);
// Non-null only when the first entry of the base list is a primary-constructor base type (one that carries
// an argument list).
var baseType = typeDeclaration.BaseList?.Types is [PrimaryConstructorBaseTypeSyntax type, ..] ? type : null;
// Attribute lists on the type that use the 'method' target. This is a deferred query; it is enumerated both
// when the lists are removed from the type and when they are copied onto the new constructor.
var methodTargetingAttributes = typeDeclaration.AttributeLists.Where(list => list.Target?.Identifier.ValueText == "method");
// Find the references to all the parameters. This will help us determine how they're used and what change we
// may need to make.
var parameterReferences = await GetParameterReferencesAsync().ConfigureAwait(false);
// Determine the fields we'll need to synthesize for each parameter.
var parameterToSynthesizedFields = CreateSynthesizedFields();
// Find any field/properties whose initializer references a primary constructor parameter. These initializers
// will have to move inside the constructor we generate.
var initializedFieldsAndProperties = await GetExistingAssignedFieldsOrPropertiesAsync().ConfigureAwait(false);
// Attached to the generated constructor so it can be located again in the formatted tree.
var constructorAnnotation = new SyntaxAnnotation();
// Now go do the entire transformation.
RemovePrimaryConstructorParameterList();
RemovePrimaryConstructorBaseTypeArgumentList();
RemovePrimaryConstructorTargetingAttributes();
RemoveDirectFieldAndPropertyAssignments();
AddNewFields();
AddConstructorDeclaration();
await RewritePrimaryConstructorParameterReferencesAsync().ConfigureAwait(false);
// Runs last: it formats the main document's changed root and replaces the whole original root with the result.
FixParameterAndBaseArgumentListIndentation();
return solutionEditor.GetChangedSolution();
// Retrieves the semantic model for the supplied document and records it in the shared 'semanticModels' set
// before returning it.
async ValueTask<SemanticModel> GetSemanticModelAsync(Document document)
{
    var model = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);

    // The named type may be spread across several documents. Recording each model is intended to ensure that
    // the semantic model for any one of those documents is only produced once.
    semanticModels.Add(model);

    return model;
}
// Builds a map from each primary constructor parameter to the identifier names that reference it. Only the
// documents that declare some part of the named type are searched.
async Task<MultiDictionary<IParameterSymbol, IdentifierNameSyntax>> GetParameterReferencesAsync()
{
    var documentsToSearch = namedType.DeclaringSyntaxReferences
        .Select(r => r.SyntaxTree)
        .Distinct()
        .Select(solution.GetRequiredDocument)
        .ToImmutableHashSet();

    var referencesByParameter = new MultiDictionary<IParameterSymbol, IdentifierNameSyntax>();

    foreach (var parameter in parameters)
    {
        var referencedSymbols = await SymbolFinder.FindReferencesAsync(
            parameter, solution, documentsToSearch, cancellationToken).ConfigureAwait(false);

        // We may hit a location multiple times due to how we do FAR for linked symbols, but each linked symbol
        // is allowed to report the entire set of references it thinks it is compatible with. So de-duplicate
        // the locations reported by each referenced symbol.
        //
        // Note Use DistinctBy (.Net6) once available.
        var referenceLocations = referencedSymbols.SelectMany(
            referencedSymbol => referencedSymbol.Locations.Distinct(LinkedFileReferenceLocationEqualityComparer.Instance));

        foreach (var referenceLocation in referenceLocations)
        {
            if (referenceLocation.IsImplicit)
                continue;

            var node = referenceLocation.Location.FindNode(findInsideTrivia: true, getInnermostNodeForTie: true, cancellationToken);

            // Only identifier names are recorded. References inside the base-type-list are explicitly ignored:
            // they don't need to be rewritten, as they will still reference the parameter in the new
            // constructor when we make the `: base(...)` initializer.
            if (node is IdentifierNameSyntax identifierName &&
                identifierName.GetAncestor<PrimaryConstructorBaseTypeSyntax>() is null)
            {
                referencesByParameter.Add(parameter, identifierName);
            }
        }
    }

    return referencesByParameter;
}
// Maps each primary constructor parameter that has a matching implicitly declared field on the type to the
// description of the real field that will be generated for it.
ImmutableDictionary<IParameterSymbol, IFieldSymbol> CreateSynthesizedFields()
{
    using var _1 = PooledDictionary<Location, IFieldSymbol>.GetInstance(out var locationToField);
    using var _2 = PooledDictionary<IParameterSymbol, IFieldSymbol>.GetInstance(out var result);

    // Compiler already knows which primary constructor parameters ended up becoming fields. So just defer to
    // it: index every implicitly declared field by its first location.
    foreach (var implicitField in namedType.GetMembers().OfType<IFieldSymbol>())
    {
        if (implicitField is { IsImplicitlyDeclared: true, Locations: [var fieldLocation, ..] })
            locationToField[fieldLocation] = implicitField;
    }

    foreach (var parameter in parameters)
    {
        // A parameter gets a real field only if an implicit field was found at the parameter's own location.
        if (parameter.Locations is not [var parameterLocation, ..] ||
            !locationToField.TryGetValue(parameterLocation, out var existingField))
        {
            continue;
        }

        // Start from the name the field naming rule produces, then make it unique against the type's own
        // name and its existing members.
        var compliantName = fieldNameRule.NamingStyle.MakeCompliant(parameter.Name).First();
        var uniqueName = NameGenerator.GenerateUniqueName(
            compliantName,
            candidate => namedType.Name != candidate && !namedType.GetMembers(candidate).Any());

        // The field is made readonly unless some reference to the parameter writes to it.
        var isWrittenTo = parameterReferences[parameter].Any(r => r.IsWrittenTo(semanticModel, cancellationToken));
        var modifiers = existingField.GetSymbolModifiers();
        if (!isWrittenTo)
            modifiers = modifiers.WithIsReadOnly(true);

        result.Add(
            parameter,
            CodeGenerationSymbolFactory.CreateFieldSymbol(existingField, modifiers: modifiers, name: uniqueName));
    }

    return result.ToImmutableDictionary();
}
// Finds the field and property initializers that contain a reference to a primary constructor parameter, and
// pairs each such initializer with the symbol of the member it initializes.
async Task<ImmutableHashSet<(ISymbol fieldOrProperty, EqualsValueClauseSyntax initializer)>> GetExistingAssignedFieldsOrPropertiesAsync()
{
    using var _1 = PooledHashSet<EqualsValueClauseSyntax>.GetInstance(out var initializers);

    foreach (var (_, references) in parameterReferences)
    {
        foreach (var reference in references)
        {
            // Ancestors are walked from the reference outwards, so the last match is the outermost
            // '= value' clause containing the reference.
            var outermostClause = reference.AncestorsAndSelf().OfType<EqualsValueClauseSyntax>().LastOrDefault();

            // Only clauses that directly initialize a property, or a variable declarator belonging to a
            // field declaration, are of interest.
            if (outermostClause is { Parent: PropertyDeclarationSyntax or VariableDeclaratorSyntax { Parent: VariableDeclarationSyntax { Parent: FieldDeclarationSyntax } } })
                initializers.Add(outermostClause);
        }
    }

    using var _2 = PooledHashSet<(ISymbol fieldOrProperty, EqualsValueClauseSyntax initializer)>.GetInstance(out var result);

    // Resolve the declared symbols one syntax tree at a time so each tree's semantic model is fetched once.
    foreach (var initializersInTree in initializers.GroupBy(clause => clause.SyntaxTree))
    {
        var treeDocument = solution.GetRequiredDocument(initializersInTree.Key);
        var treeModel = await GetSemanticModelAsync(treeDocument).ConfigureAwait(false);

        foreach (var initializer in initializersInTree)
        {
            var member = treeModel.GetRequiredDeclaredSymbol(initializer.GetRequiredParent(), cancellationToken);
            result.Add((member, initializer));
        }
    }

    return [.. result];
}
// Queues removal of the primary constructor parameter list from the type declaration in the main document.
void RemovePrimaryConstructorParameterList()
    => mainDocumentEditor.RemoveNode(parameterList);
// If the type has a primary-constructor base type, replaces it with a simple base type naming the same type
// (dropping the argument list) while keeping the original trivia.
void RemovePrimaryConstructorBaseTypeArgumentList()
{
    if (baseType is null)
        return;

    mainDocumentEditor.ReplaceNode(
        baseType,
        (current, _) =>
        {
            var currentBaseType = (PrimaryConstructorBaseTypeSyntax)current;
            return SimpleBaseType(currentBaseType.Type).WithTriviaFrom(baseType);
        });
}
// Queues removal of every 'method'-targeted attribute list from the type declaration. The same lists are
// re-created on the generated constructor by CreateConstructorDeclaration.
void RemovePrimaryConstructorTargetingAttributes()
{
    foreach (var methodAttributeList in methodTargetingAttributes)
    {
        mainDocumentEditor.RemoveNode(methodAttributeList);
    }
}
// Strips the initializer from every existing field/property whose initializer referenced a primary constructor
// parameter. The equivalent assignments are emitted in the generated constructor's body.
//
// NOTE(review): every edit here goes through 'mainDocumentEditor', while the initializers were collected per
// syntax tree. Confirm that initializers located in another partial part's file are handled correctly.
void RemoveDirectFieldAndPropertyAssignments()
{
    foreach (var (_, initializer) in initializedFieldsAndProperties)
    {
        switch (initializer.Parent)
        {
            case PropertyDeclarationSyntax propertyDeclaration:
            {
                // A property initializer is followed by a ';' token, so that is cleared along with the
                // initializer. The property's original trailing trivia is re-applied afterwards.
                var withoutInitializer = propertyDeclaration
                    .WithInitializer(null)
                    .WithSemicolonToken(default)
                    .WithTrailingTrivia(propertyDeclaration.GetTrailingTrivia());

                mainDocumentEditor.ReplaceNode(propertyDeclaration, withoutInitializer);
                break;
            }

            case VariableDeclaratorSyntax:
                // For a field, only the '= value' clause is removed from the declarator.
                mainDocumentEditor.RemoveNode(initializer);
                break;

            default:
                throw ExceptionUtilities.Unreachable();
        }
    }
}
// Queues an edit that adds the synthesized fields to the type declaration through the code generation
// service. Fields are added in primary constructor parameter order.
void AddNewFields()
{
    mainDocumentEditor.ReplaceNode(
        typeDeclaration,
        (current, _) =>
        {
            // Walk the parameters (not the dictionary) so the field order follows the parameter order;
            // parameters without a synthesized field are skipped.
            var fieldsInOrder = new List<IFieldSymbol>();
            foreach (var parameter in parameters)
            {
                if (parameterToSynthesizedFields.TryGetValue(parameter, out var field))
                    fieldsInOrder.Add(field);
            }

            var codeGenService = document.GetRequiredLanguageService<ICodeGenerationService>();
            return codeGenService.AddMembers(
                (TypeDeclarationSyntax)current, fieldsInOrder, contextInfo, cancellationToken);
        });
}
// Queues an edit that inserts the generated constructor into the type declaration. The constructor is tagged
// with 'constructorAnnotation' so it can be located again later.
void AddConstructorDeclaration()
{
    mainDocumentEditor.ReplaceNode(
        typeDeclaration,
        (current, _) =>
        {
            // Move any <param> tags from the type decl to the constructor decl.
            var updatedType = (TypeDeclarationSyntax)current;
            updatedType = RemoveParamXmlElements(updatedType);

            var newConstructor = CreateConstructorDeclaration().WithAdditionalAnnotations(constructorAnnotation);
            var members = updatedType.Members;

            int insertionIndex;
            var firstInstanceConstructor = members.IndexOf(
                m => m is ConstructorDeclarationSyntax c && !c.Modifiers.Any(SyntaxKind.StaticKeyword));

            if (firstInstanceConstructor >= 0)
            {
                // If there is an existing non-static constructor, place the new one before it.
                insertionIndex = firstInstanceConstructor;
            }
            else
            {
                // No constructors. Place after the last field if present, or after the last property if
                // there are no fields.
                var lastFieldOrProperty = members.LastIndexOf(m => m is FieldDeclarationSyntax);
                if (lastFieldOrProperty < 0)
                    lastFieldOrProperty = members.LastIndexOf(m => m is PropertyDeclarationSyntax);

                if (lastFieldOrProperty >= 0)
                {
                    // An elastic line break is prepended to the constructor in this case only.
                    newConstructor = newConstructor
                        .WithPrependedLeadingTrivia(ElasticCarriageReturnLineFeed);
                    insertionIndex = lastFieldOrProperty + 1;
                }
                else
                {
                    // Nothing at all. Just place the constructor at the top of the type.
                    insertionIndex = 0;
                }
            }

            return updatedType.WithMembers(members.Insert(insertionIndex, newConstructor));
        });
}
// Rewrites each recorded reference to a primary constructor parameter so that it refers to the field
// synthesized for that parameter instead. Edits are queued on the editor of whichever document holds the
// reference.
async Task RewritePrimaryConstructorParameterReferencesAsync()
{
    foreach (var (parameter, references) in parameterReferences)
    {
        // Only have to update references if we're synthesizing a field for this parameter.
        if (!parameterToSynthesizedFields.TryGetValue(parameter, out var field))
            continue;

        var fieldName = field.Name.ToIdentifierName();

        foreach (var referencesInTree in references.GroupBy(r => r.SyntaxTree))
        {
            var documentId = solution.GetDocumentId(referencesInTree.Key);
            var editor = await solutionEditor.GetDocumentEditorAsync(documentId, cancellationToken).ConfigureAwait(false);

            foreach (var identifierName in referencesInTree)
            {
                var enclosingXmlElement = identifierName.AncestorsAndSelf().OfType<XmlEmptyElementSyntax>().FirstOrDefault();

                // The ordinary case: swap the identifier for the field name, keeping its trivia.
                if (enclosingXmlElement is not { Name.LocalName.ValueText: "paramref" })
                {
                    editor.ReplaceNode(identifierName, fieldName.WithTriviaFrom(identifierName));
                    continue;
                }

                // A doc-comment <paramref .../> element is turned into a <see cref="field"/> element: the
                // tag name is replaced and the attribute list becomes a single cref to the field.
                var tagName = enclosingXmlElement.Name.LocalName;
                var seeTag = enclosingXmlElement
                    .ReplaceToken(tagName, Identifier("see").WithTriviaFrom(tagName))
                    .WithAttributes([XmlCrefAttribute(TypeCref(fieldName))]);

                editor.ReplaceNode(enclosingXmlElement, seeTag);
            }
        }
    }
}
// Formats the elastic trivia of the main document's changed root, then adjusts the indentation of the
// generated constructor's parameters and base-initializer arguments (see AddElementIndentation). The final
// tree replaces the main document's entire original root.
void FixParameterAndBaseArgumentListIndentation()
{
    var currentRoot = mainDocumentEditor.GetChangedRoot();

    // Only nodes carrying the elastic annotation are formatted here.
    var formattedRoot = Formatter.Format(currentRoot, SyntaxAnnotation.ElasticAnnotation, solution.Services, formattingOptions, cancellationToken);

    // Locate the generated constructor in the formatted tree through the annotation attached when it was
    // created. Exactly one annotated node is expected.
    var constructor = (ConstructorDeclarationSyntax)formattedRoot.GetAnnotatedNodes(constructorAnnotation).Single();

    var rewrittenParameterList = AddElementIndentation(typeDeclaration, constructor, constructor.ParameterList, static list => list.Parameters);

    // The initializer is present only when a `: base(...)` call was generated.
    var initializer = constructor.Initializer;
    var rewrittenInitializer = initializer?.WithArgumentList(AddElementIndentation(typeDeclaration, constructor, initializer.ArgumentList, static list => list.Arguments));

    var rewrittenConstructor = constructor
        .WithParameterList(rewrittenParameterList)
        .WithInitializer(rewrittenInitializer);

    var rewrittenRoot = formattedRoot.ReplaceNode(constructor, rewrittenConstructor);
    mainDocumentEditor.ReplaceNode(mainDocumentEditor.OriginalRoot, rewrittenRoot);
}
// Adds indentation to the elements of 'list' (parameters or arguments) that originally sat on the type
// declaration and now sit on the generated constructor.
//
//   typeDeclaration        - the original type declaration; its leading whitespace is the baseline.
//   constructorDeclaration - the generated constructor; its leading whitespace is the target.
//   list                   - the parameter or argument list whose elements may be re-indented.
//   getElements            - returns the element nodes of 'list'.
//
// Returns 'list' unchanged unless the constructor's leading whitespace is strictly longer than, and begins
// with, the type's leading whitespace.
static TListSyntax AddElementIndentation<TListSyntax>(
    TypeDeclarationSyntax typeDeclaration,
    ConstructorDeclarationSyntax constructorDeclaration,
    TListSyntax list,
    Func<TListSyntax, IEnumerable<SyntaxNode>> getElements)
    where TListSyntax : SyntaxNode
{
    // Since we're moving parameters from the type to the constructor, attempt to indent them further if
    // appropriate.
    var typeLeadingWhitespace = GetLeadingWhitespace(typeDeclaration);
    var constructorLeadingWhitespace = GetLeadingWhitespace(constructorDeclaration);

    // Whitespace trivia is compared ordinally; a culture-sensitive comparison has no place here.
    if (constructorLeadingWhitespace.Length <= typeLeadingWhitespace.Length ||
        !constructorLeadingWhitespace.StartsWith(typeLeadingWhitespace, StringComparison.Ordinal))
    {
        return list;
    }

    // The extra whitespace the constructor has relative to the type.
    var indentation = constructorLeadingWhitespace[typeLeadingWhitespace.Length..];

    return list.ReplaceNodes(
        getElements(list),
        (p, _) =>
        {
            // Only touch elements whose leading trivia ends in whitespace that itself begins with the
            // type's indentation; everything else is left alone.
            var elementLeadingWhitespace = GetLeadingWhitespace(p);
            if (elementLeadingWhitespace.Length > 0 &&
                elementLeadingWhitespace.StartsWith(typeLeadingWhitespace, StringComparison.Ordinal))
            {
                // Append the extra indentation after the element's existing leading trivia.
                var leadingTrivia = p.GetLeadingTrivia();
                return p.WithLeadingTrivia(
                    leadingTrivia.Concat(Whitespace(indentation)));
            }

            return p;
        });
}
// Returns the text of the final piece of the node's leading trivia when that piece is whitespace trivia;
// otherwise returns the empty string.
static string GetLeadingWhitespace(SyntaxNode node)
{
    var leadingTrivia = node.GetLeadingTrivia();

    if (leadingTrivia is [.., (kind: SyntaxKind.WhitespaceTrivia) whitespace])
        return whitespace.ToString();

    return "";
}
// Builds the public constructor that replaces the primary constructor: same parameters, the 'method'-targeted
// attributes from the type, an optional `: base(...)` initializer, and a body that assigns the synthesized
// fields followed by the members whose initializers were removed.
ConstructorDeclarationSyntax CreateConstructorDeclaration()
{
    using var _1 = ArrayBuilder<StatementSyntax>.GetInstance(out var bodyStatements);

    // First, if we're making a real field for a primary constructor parameter, assign the parameter to it.
    foreach (var parameter in parameters)
    {
        if (parameterToSynthesizedFields.TryGetValue(parameter, out var field))
        {
            var fieldName = field.Name.ToIdentifierName();

            // When the field and the parameter share a name, the field has to be written as `this.name`.
            ExpressionSyntax target = fieldName;
            if (parameter.Name == field.Name)
                target = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ThisExpression(), fieldName);

            bodyStatements.Add(ExpressionStatement(
                AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, target, parameter.Name.ToIdentifierName())));
        }
    }

    // Next, actually assign to all the fields/properties that were previously referencing any primary
    // constructor parameters. They are emitted ordered by the start position of their initializers.
    var orderedInitializers = initializedFieldsAndProperties.OrderBy(i => i.initializer.SpanStart);
    foreach (var (fieldOrProperty, initializer) in orderedInitializers)
    {
        // Written as `this.Member`, and marked so the simplifier can reduce it.
        var qualifiedMember = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ThisExpression(), fieldOrProperty.Name.ToIdentifierName())
            .WithAdditionalAnnotations(Simplifier.Annotation);

        // The original '=' token and value expression of the initializer are reused in the assignment.
        bodyStatements.Add(ExpressionStatement(
            AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, qualifiedMember, initializer.EqualsToken, initializer.Value)));
    }

    // The rewrite is computed from the original parameter nodes (the first lambda argument).
    var rewrittenParameters = parameterList.ReplaceNodes(
        parameterList.Parameters,
        (originalParameter, _) => RewriteNestedReferences(originalParameter));

    // A base initializer is generated only when the base type carried an argument list.
    var baseInitializer = baseType?.ArgumentList is null
        ? null
        : ConstructorInitializer(SyntaxKind.BaseConstructorInitializer, baseType.ArgumentList);

    var constructorDeclaration = ConstructorDeclaration(
        [.. methodTargetingAttributes.Select(a => a.WithTarget(null).WithoutTrivia().WithAdditionalAnnotations(Formatter.Annotation))],
        [PublicKeyword.WithAppendedTrailingTrivia(Space)],
        typeDeclaration.Identifier.WithoutTrivia(),
        rewrittenParameters.WithoutTrivia(),
        baseInitializer,
        Block(bodyStatements));

    return WithTypeDeclarationParamDocComments(typeDeclaration, constructorDeclaration);
}
// Within 'parent', shortens every member access or qualified name whose left-hand side binds to the type being
// converted, leaving just the right-hand name (with the trivia of the node it replaces). All other nodes are
// returned unchanged.
TNode RewriteNestedReferences<TNode>(TNode parent) where TNode : SyntaxNode
{
    var candidates = parent.DescendantNodes().Where(n => n is MemberAccessExpressionSyntax or QualifiedNameSyntax);

    return parent.ReplaceNodes(
        candidates,
        (node, _) => node switch
        {
            MemberAccessExpressionSyntax memberAccessExpression
                when namedType.Equals(semanticModel.GetSymbolInfo(memberAccessExpression.Expression).Symbol)
                => memberAccessExpression.Name.WithTriviaFrom(node),

            QualifiedNameSyntax qualifiedName
                when namedType.Equals(semanticModel.GetSymbolInfo(qualifiedName.Left).Symbol)
                => qualifiedName.Right.WithTriviaFrom(node),

            _ => node,
        });
}
}
}
|