|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Collections;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
namespace Microsoft.CodeAnalysis.AddParameter;
internal abstract class AbstractAddParameterCodeFixProvider<
TArgumentSyntax,
TAttributeArgumentSyntax,
TArgumentListSyntax,
TAttributeArgumentListSyntax,
TExpressionSyntax,
TInvocationExpressionSyntax,
TObjectCreationExpressionSyntax> : CodeFixProvider
where TArgumentSyntax : SyntaxNode
where TArgumentListSyntax : SyntaxNode
where TAttributeArgumentListSyntax : SyntaxNode
where TExpressionSyntax : SyntaxNode
where TInvocationExpressionSyntax : TExpressionSyntax
where TObjectCreationExpressionSyntax : TExpressionSyntax
{
/// <summary>
/// Display format applied to each parameter of a candidate method when a code fix title is built
/// (see <c>GetCodeFixTitle</c>). It requests name-only type qualification, generic type parameters,
/// special-type names, and the parameter's params/ref/out modifier together with its type.
/// </summary>
private static readonly SymbolDisplayFormat SimpleFormat = new(
typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameOnly,
genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters,
parameterOptions: SymbolDisplayParameterOptions.IncludeParamsRefOut | SymbolDisplayParameterOptions.IncludeType,
miscellaneousOptions: SymbolDisplayMiscellaneousOptions.UseSpecialTypes);
/// <summary>
/// Diagnostic IDs for which <c>TryGetRelevantArgument</c> does not derive the argument from the
/// diagnostic location and returns null instead.
/// NOTE(review): by name these are the language's "too many arguments" diagnostics - confirm in the subclasses.
/// </summary>
protected abstract ImmutableArray<string> TooManyArgumentsDiagnosticIds { get; }
/// <summary>
/// Diagnostic IDs for which <c>TryGetRelevantArgument</c> likewise returns null.
/// NOTE(review): by name these are the language's "cannot convert argument" diagnostics - confirm in the subclasses.
/// </summary>
protected abstract ImmutableArray<string> CannotConvertDiagnosticIds { get; }
/// <summary>
/// Returns the type for the given argument node. The result is passed on as the type of the
/// parameter that <c>FixAsync</c> asks <c>AddParameterService</c> to add.
/// </summary>
protected abstract ITypeSymbol GetArgumentType(SyntaxNode argumentNode, SemanticModel semanticModel, CancellationToken cancellationToken);
/// <summary>
/// Converts a language-specific argument node into the <c>Argument</c> value that
/// <c>FixAsync</c> hands to <c>AddParameterService.AddParameterAsync</c>.
/// </summary>
protected abstract Argument<TExpressionSyntax> GetArgument(TArgumentSyntax argument);
/// <summary>
/// Fix All is not supported for this code fix, so no provider is returned.
/// </summary>
/// <returns>Always null.</returns>
public override FixAllProvider? GetFixAllProvider()
    => null;
/// <summary>
/// Hook for language-specific call shapes. <c>RegisterCodeFixesAsync</c> consults it for a node only
/// after the invocation and object-creation checks both returned null.
/// </summary>
/// <returns>The base implementation recognizes nothing and always returns null.</returns>
protected virtual RegisterFixData<TArgumentSyntax>? TryGetLanguageSpecificFixInfo(
    SemanticModel semanticModel,
    SyntaxNode node,
    CancellationToken cancellationToken)
{
    return null;
}
/// <summary>
/// Starts at the node covering the first diagnostic's span and walks up through its ancestors until
/// one of them yields fix data (invocation, object creation, or a language-specific construct).
/// Fixes are then registered for the candidate methods of that node and the walk stops.
/// </summary>
/// <param name="context">The code fix context supplying the document, diagnostics and cancellation token.</param>
public override async Task RegisterCodeFixesAsync(CodeFixContext context)
{
    var cancellationToken = context.CancellationToken;
    var diagnostic = context.Diagnostics.First();
    var document = context.Document;

    var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
    var initialNode = root.FindNode(diagnostic.Location.SourceSpan);

    var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
    var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();

    for (var node = initialNode; node != null; node = node.Parent)
    {
        var fixData =
            TryGetInvocationExpressionFixInfo(semanticModel, syntaxFacts, node, cancellationToken) ??
            TryGetObjectCreationFixInfo(semanticModel, syntaxFacts, node, cancellationToken) ??
            TryGetLanguageSpecificFixInfo(semanticModel, node, cancellationToken);

        if (fixData == null)
        {
            // Nothing recognizable at this level; try the parent.
            continue;
        }

        var candidates = fixData.MethodCandidates;

        // For a :this() or :base() style call the method containing the diagnostic must not be a
        // candidate itself: in the 'this' case the fix could otherwise produce a constructor that
        // calls itself, which is forbidden.
        if (fixData.IsConstructorInitializer &&
            semanticModel.GetEnclosingSymbol(node.SpanStart, cancellationToken) is IMethodSymbol methodWithDiagnostic)
        {
            candidates = candidates.Remove(methodWithDiagnostic);
        }

        var argumentOpt = TryGetRelevantArgument(initialNode, node, diagnostic);
        var insertPositions = GetArgumentInsertPositionForMethodCandidates(
            argumentOpt, semanticModel, syntaxFacts, fixData.Arguments, candidates);

        RegisterFixForMethodOverloads(context, fixData.Arguments, insertPositions);
        return;
    }
}
/// <summary>
/// Determines which argument the fix should target from the location of the diagnostic.
/// When the diagnostic ID is listed in <c>TooManyArgumentsDiagnosticIds</c> or
/// <c>CannotConvertDiagnosticIds</c>, null is returned and the fixer has to find the relevant
/// argument by itself. Otherwise the result is the last argument in the ancestors-or-self chain of
/// <paramref name="initialNode"/> that has <paramref name="node"/> among its own ancestors-or-self.
/// </summary>
/// <returns>The argument to fix, or null when it must be determined by the caller.</returns>
private TArgumentSyntax? TryGetRelevantArgument(
    SyntaxNode initialNode, SyntaxNode node, Diagnostic diagnostic)
{
    var diagnosticId = diagnostic.Id;
    if (TooManyArgumentsDiagnosticIds.Contains(diagnosticId) ||
        CannotConvertDiagnosticIds.Contains(diagnosticId))
    {
        return null;
    }

    var enclosingArguments = initialNode.GetAncestorsOrThis<TArgumentSyntax>();
    return enclosingArguments.LastOrDefault(
        candidate => candidate.AncestorsAndSelf().Contains(node));
}
/// <summary>
/// Produces fix data when <paramref name="node"/> is an invocation expression: the invocation's
/// arguments plus every method in the member group of the invoked expression.
/// </summary>
/// <returns>The fix data, or null when <paramref name="node"/> is not an invocation expression.</returns>
private static RegisterFixData<TArgumentSyntax>? TryGetInvocationExpressionFixInfo(
    SemanticModel semanticModel,
    ISyntaxFactsService syntaxFacts,
    SyntaxNode node,
    CancellationToken cancellationToken)
{
    if (node is not TInvocationExpressionSyntax invocationExpression)
    {
        return null;
    }

    var invokedExpression = syntaxFacts.GetExpressionOfInvocationExpression(invocationExpression);
    var methodGroup = semanticModel
        .GetMemberGroup(invokedExpression, cancellationToken)
        .OfType<IMethodSymbol>()
        .ToImmutableArray();
    var arguments = (SeparatedSyntaxList<TArgumentSyntax>)syntaxFacts.GetArgumentsOfInvocationExpression(invocationExpression);

    // In VB a constructor calls other constructor overloads via a Me.New(..) invocation.
    // If all candidates are MethodKind.Constructor, the invocation is the equivalent of a
    // C# constructor initializer.
    var isConstructorInitializer = methodGroup.All(m => m.MethodKind == MethodKind.Constructor);

    return new RegisterFixData<TArgumentSyntax>(arguments, methodGroup, isConstructorInitializer);
}
/// <summary>
/// Produces fix data when <paramref name="node"/> is an object creation expression: the creation's
/// arguments plus the instance constructors of the created type.
/// </summary>
/// <returns>
/// Null when <paramref name="node"/> is not an object creation expression; a default-constructed
/// fix data instance when the created type cannot be determined or is not a non-implicit type from
/// source; otherwise fix data for the created type's instance constructors.
/// </returns>
private static RegisterFixData<TArgumentSyntax>? TryGetObjectCreationFixInfo(
    SemanticModel semanticModel,
    ISyntaxFactsService syntaxFacts,
    SyntaxNode node,
    CancellationToken cancellationToken)
{
    if (node is not TObjectCreationExpressionSyntax objectCreation)
    {
        return null;
    }

    // Not supported if this is "new { ... }" (as there are no parameters at all).
    var typeNode = syntaxFacts.IsImplicitObjectCreationExpression(node)
        ? node
        : syntaxFacts.GetTypeOfObjectCreationExpression(objectCreation);
    if (typeNode == null)
    {
        return new RegisterFixData<TArgumentSyntax>();
    }

    var symbol = semanticModel.GetSymbolInfo(typeNode, cancellationToken).GetAnySymbol();

    INamedTypeSymbol? createdType = null;
    if (symbol is IMethodSymbol constructor)
    {
        // Implicit object creation expressions bind to the constructor.
        createdType = constructor.ContainingType;
    }
    else if (symbol is INamedTypeSymbol namedType)
    {
        // Standard object creation expressions bind to the type.
        createdType = namedType;
    }

    // If we can't figure out the type being created, or the type isn't in source,
    // then there's nothing we can do.
    if (createdType == null || !createdType.IsNonImplicitAndFromSource())
    {
        return new RegisterFixData<TArgumentSyntax>();
    }

    var arguments = (SeparatedSyntaxList<TArgumentSyntax>)syntaxFacts.GetArgumentsOfObjectCreationExpression(objectCreation);
    return new RegisterFixData<TArgumentSyntax>(
        arguments, createdType.InstanceConstructors, isConstructorInitializer: false);
}
/// <summary>
/// For every candidate method that is non-implicit and from source, determines the first argument
/// of the call for which a parameter would have to be added, and records it together with its
/// index in <paramref name="arguments"/>.
/// </summary>
/// <param name="argumentOpt">
/// The argument the diagnostic points at, or null when the argument has to be discovered.
/// When it is non-null, candidates whose first problematic argument is a different one are skipped.
/// </param>
/// <param name="semanticModel">Semantic model used to compare argument and parameter types.</param>
/// <param name="syntaxFacts">Language syntax facts; also supplies the identifier comparer.</param>
/// <param name="arguments">The arguments of the call being fixed.</param>
/// <param name="methodCandidates">The methods that could receive the new parameter.</param>
/// <returns>One entry per viable candidate, visited in order of ascending parameter count.</returns>
private static ImmutableArray<ArgumentInsertPositionData<TArgumentSyntax>> GetArgumentInsertPositionForMethodCandidates(
    TArgumentSyntax? argumentOpt,
    SemanticModel semanticModel,
    ISyntaxFactsService syntaxFacts,
    SeparatedSyntaxList<TArgumentSyntax> arguments,
    ImmutableArray<IMethodSymbol> methodCandidates)
{
    var comparer = syntaxFacts.StringComparer;

    // Whether the targeted argument is named depends only on argumentOpt, not on the candidate,
    // so it is computed once rather than once per candidate.
    var isNamedArgument = !string.IsNullOrWhiteSpace(syntaxFacts.GetNameForArgument(argumentOpt));

    using var _ = ArrayBuilder<ArgumentInsertPositionData<TArgumentSyntax>>.GetInstance(out var methodsAndArgumentToAdd);

    foreach (var method in methodCandidates.OrderBy(m => m.Parameters.Length))
    {
        if (!method.IsNonImplicitAndFromSource())
        {
            continue;
        }

        // A positional argument can only need a new parameter when the call passes more arguments
        // than the method has non-params parameters. Named arguments are always examined.
        if (!isNamedArgument && NonParamsParameterCount(method) >= arguments.Count)
        {
            continue;
        }

        var argumentToAdd = DetermineFirstArgumentToAdd(
            semanticModel, syntaxFacts, comparer, method, arguments);
        if (argumentToAdd == null)
        {
            continue;
        }

        if (argumentOpt != null && argumentToAdd != argumentOpt)
        {
            // We were trying to fix a specific argument, but the argument we want
            // to fix is something different. That means there was an error earlier
            // than this argument. Which means we're looking at a non-viable
            // constructor or method. Skip this one.
            continue;
        }

        methodsAndArgumentToAdd.Add(new ArgumentInsertPositionData<TArgumentSyntax>(
            method, argumentToAdd, arguments.IndexOf(argumentToAdd)));
    }

    return methodsAndArgumentToAdd.ToImmutableAndClear();
}
/// <summary>
/// Returns the number of parameters of <paramref name="method"/>, reduced by one when
/// <c>method.IsParams()</c> reports true.
/// </summary>
private static int NonParamsParameterCount(IMethodSymbol method)
{
    var parameterCount = method.Parameters.Length;
    if (method.IsParams())
    {
        return parameterCount - 1;
    }

    return parameterCount;
}
/// <summary>
/// Builds the code actions for all viable overloads and registers them on
/// <paramref name="context"/> for the context's diagnostics.
/// </summary>
/// <remarks>
/// To keep the list of offered fixes short, one menu entry per overload is created only as long as
/// there are two or fewer overloads. With more overloads two menu entries are created: one holding
/// the non-cascading fixes and one holding the cascading fixes.
/// </remarks>
private void RegisterFixForMethodOverloads(
    CodeFixContext context,
    SeparatedSyntaxList<TArgumentSyntax> arguments,
    ImmutableArray<ArgumentInsertPositionData<TArgumentSyntax>> methodsAndArgumentsToAdd)
{
    var codeFixData = PrepareCreationOfCodeActions(context.Document, arguments, methodsAndArgumentsToAdd);

    ImmutableArray<CodeAction> fixes;
    if (codeFixData.Length <= 2)
    {
        fixes = CreateOneActionPerOverload();
    }
    else
    {
        fixes = CreateActionsGroupedByCascading();
    }

    context.RegisterFixes(fixes, context.Diagnostics);
    return;

    // One action per overload; an overload that also has a cascading fix gets both fixes nested.
    ImmutableArray<CodeAction> CreateOneActionPerOverload()
    {
        var actions = new FixedSizeArrayBuilder<CodeAction>(codeFixData.Length);
        foreach (var fixData in codeFixData)
        {
            // The non-cascading fix is mandatory and is always created first.
            var nonCascadingTitle = GetCodeFixTitle(CodeFixesResources.Add_parameter_to_0, fixData.Method, includeParameters: true);
            var nonCascadingAction = CodeAction.Create(
                nonCascadingTitle,
                fixData.CreateChangedSolutionNonCascading,
                equivalenceKey: nonCascadingTitle);

            if (fixData.CreateChangedSolutionCascading == null)
            {
                // Only a single fix to offer for this overload.
                actions.Add(nonCascadingAction);
                continue;
            }

            // Two fixes to offer. They are nested in an inlinable CodeAction so the IDE is free
            // to either show both at once or to create a sub-menu.
            var titleForNesting = GetCodeFixTitle(CodeFixesResources.Add_parameter_to_0, fixData.Method, includeParameters: true);
            var titleCascading = GetCodeFixTitle(
                CodeFixesResources.Add_parameter_to_0_and_overrides_implementations, fixData.Method, includeParameters: true);
            var cascadingAction = CodeAction.Create(
                titleCascading,
                fixData.CreateChangedSolutionCascading,
                equivalenceKey: titleCascading);

            actions.Add(CodeAction.Create(
                title: titleForNesting,
                [
                    nonCascadingAction,
                    cascadingAction,
                ],
                isInlinable: true));
        }

        return actions.MoveToImmutable();
    }

    // Two sub-menus at most: all non-cascading fixes, and (if any exist) all cascading fixes.
    ImmutableArray<CodeAction> CreateActionsGroupedByCascading()
    {
        using var builder = TemporaryArray<CodeAction>.Empty;

        var nonCascadingActions = codeFixData.SelectAsArray(fixData =>
        {
            var entryTitle = GetCodeFixTitle(CodeFixesResources.Add_to_0, fixData.Method, includeParameters: true);
            return CodeAction.Create(entryTitle, fixData.CreateChangedSolutionNonCascading, equivalenceKey: entryTitle);
        });

        var cascadingActions = codeFixData.SelectAsArray(
            fixData => fixData.CreateChangedSolutionCascading != null,
            fixData =>
            {
                var entryTitle = GetCodeFixTitle(CodeFixesResources.Add_to_0, fixData.Method, includeParameters: true);
                return CodeAction.Create(entryTitle, fixData.CreateChangedSolutionCascading!, equivalenceKey: entryTitle);
            });

        // The sub-menu titles name the method group, so any one of the methods will do.
        var aMethod = codeFixData.First().Method;

        // Sub-menu with all the non-cascading CodeActions. Inlining is disabled, otherwise the
        // context menu gets flooded with our fixes.
        var nestedNonCascadingTitle = GetCodeFixTitle(CodeFixesResources.Add_parameter_to_0, aMethod, includeParameters: false);
        builder.Add(CodeAction.Create(nestedNonCascadingTitle, nonCascadingActions, isInlinable: false));

        if (cascadingActions.Length > 0)
        {
            // Second sub-menu, only when there are cascading CodeActions.
            var nestedCascadingTitle = GetCodeFixTitle(
                CodeFixesResources.Add_parameter_to_0_and_overrides_implementations, aMethod, includeParameters: false);
            builder.Add(CodeAction.Create(nestedCascadingTitle, cascadingActions, isInlinable: false));
        }

        return builder.ToImmutableAndClear();
    }
}
/// <summary>
/// Creates one <c>CodeFixData</c> per method/argument pair. Each entry carries the method, a
/// delegate applying the non-cascading fix, and - only when
/// <c>AddParameterService.HasCascadingDeclarations</c> is true for the method - a delegate applying
/// the cascading fix.
/// </summary>
/// <returns>The entries ordered from the largest argument insertion index to the smallest.</returns>
private ImmutableArray<CodeFixData> PrepareCreationOfCodeActions(
    Document document,
    SeparatedSyntaxList<TArgumentSyntax> arguments,
    ImmutableArray<ArgumentInsertPositionData<TArgumentSyntax>> methodsAndArgumentsToAdd)
{
    var result = new FixedSizeArrayBuilder<CodeFixData>(methodsAndArgumentsToAdd.Length);

    // Furthest argument index first: a larger index means more of the earlier arguments matched,
    // so that method is more likely to be the correct match.
    var orderedInsertPositions = methodsAndArgumentsToAdd.OrderByDescending(t => t.ArgumentInsertionIndex);

    foreach (var insertPosition in orderedInsertPositions)
    {
        var methodToUpdate = insertPosition.MethodToUpdate;
        var argumentToInsert = insertPosition.ArgumentToInsert;

        Func<CancellationToken, Task<Solution>>? cascadingFix = null;
        if (AddParameterService.HasCascadingDeclarations(methodToUpdate))
        {
            cascadingFix = cancellationToken =>
                FixAsync(document, methodToUpdate, argumentToInsert, arguments, fixAllReferences: true, cancellationToken);
        }

        result.Add(new CodeFixData(
            methodToUpdate,
            cancellationToken => FixAsync(document, methodToUpdate, argumentToInsert, arguments, fixAllReferences: false, cancellationToken),
            cascadingFix));
    }

    return result.MoveToImmutable();
}
/// <summary>
/// Formats <paramref name="resourceString"/> with a display of <paramref name="methodToUpdate"/>.
/// </summary>
/// <param name="resourceString">A format string whose <c>{0}</c> receives the method display.</param>
/// <param name="methodToUpdate">The method named in the title.</param>
/// <param name="includeParameters">
/// When true, a parenthesized, comma-separated parameter list (rendered with <c>SimpleFormat</c>)
/// is appended to the method display.
/// </param>
/// <returns>The formatted title.</returns>
private static string GetCodeFixTitle(string resourceString, IMethodSymbol methodToUpdate, bool includeParameters)
{
    // NOTE(review): extensionMethodStyle is set to StaticMethod while memberOptions is None, so no
    // containing type is requested for the title - confirm that omitting it is intentional, since
    // the title is also used as the equivalence key by the callers.
    var methodFormat = new SymbolDisplayFormat(
        typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameOnly,
        extensionMethodStyle: SymbolDisplayExtensionMethodStyle.StaticMethod,
        parameterOptions: SymbolDisplayParameterOptions.None,
        memberOptions: SymbolDisplayMemberOptions.None);
    var methodDisplay = methodToUpdate.ToDisplayString(methodFormat);

    if (!includeParameters)
    {
        return string.Format(resourceString, methodDisplay);
    }

    var parameterDisplays = methodToUpdate.Parameters.Select(p => p.ToDisplayString(SimpleFormat));
    var parameterList = string.Join(", ", parameterDisplays);
    return string.Format(resourceString, $"{methodDisplay}({parameterList})");
}
/// <summary>
/// Applies the fix for one method: determines the argument's type, ref kind and a name suggestion,
/// then asks <c>AddParameterService</c> to add a matching parameter to <paramref name="method"/>.
/// </summary>
/// <param name="invocationDocument">The document containing the call being fixed.</param>
/// <param name="method">The method that receives the new parameter.</param>
/// <param name="argument">The argument the new parameter is created for.</param>
/// <param name="argumentList">All arguments of the call; supplies the index of <paramref name="argument"/>.</param>
/// <param name="fixAllReferences">Passed through to <c>AddParameterService.AddParameterAsync</c>.</param>
/// <param name="cancellationToken">Token used to cancel the operation.</param>
/// <returns>The solution returned by <c>AddParameterService.AddParameterAsync</c>.</returns>
private async Task<Solution> FixAsync(
    Document invocationDocument,
    IMethodSymbol method,
    TArgumentSyntax argument,
    SeparatedSyntaxList<TArgumentSyntax> argumentList,
    bool fixAllReferences,
    CancellationToken cancellationToken)
{
    var (argumentType, refKind) = await GetArgumentTypeAndRefKindAsync(
        invocationDocument, argument, cancellationToken).ConfigureAwait(false);

    // The argumentNameSuggestion is the base for the parameter name.
    // For each method declaration the name is made unique to avoid name collisions.
    var (argumentNameSuggestion, isNamedArgument) = await GetNameSuggestionForArgumentAsync(
        invocationDocument, argument, method.ContainingType, cancellationToken).ConfigureAwait(false);

    // A named argument carries no positional index; a positional one uses its index in the list.
    int? newParameterIndex = null;
    if (!isNamedArgument)
    {
        newParameterIndex = argumentList.IndexOf(argument);
    }

    var parameterName = new ParameterName(
        argumentNameSuggestion, isNamedArgument, tryMakeCamelCase: !method.ContainingType.IsRecord);
    var argumentToAdd = GetArgument(argument);

    return await AddParameterService.AddParameterAsync<TExpressionSyntax>(
        invocationDocument,
        method,
        argumentType,
        refKind,
        parameterName,
        argumentToAdd,
        newParameterIndex,
        fixAllReferences,
        cancellationToken).ConfigureAwait(false);
}
/// <summary>
/// Returns the type of <paramref name="argument"/> (via <c>GetArgumentType</c>) and its ref kind
/// (via the document's syntax facts service).
/// </summary>
private async Task<(ITypeSymbol, RefKind)> GetArgumentTypeAndRefKindAsync(Document invocationDocument, TArgumentSyntax argument, CancellationToken cancellationToken)
{
    var syntaxFacts = invocationDocument.GetRequiredLanguageService<ISyntaxFactsService>();
    var semanticModel = await invocationDocument.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);

    return (
        GetArgumentType(argument, semanticModel, cancellationToken),
        syntaxFacts.GetRefKindOfArgument(argument));
}
/// <summary>
/// Suggests a name for the parameter created for <paramref name="argument"/>.
/// A named argument contributes its own name; otherwise a name is generated from the argument's
/// expression, capitalized when <paramref name="containingType"/> is a record.
/// </summary>
/// <returns>The suggested name, and whether it came from a named argument.</returns>
private static async Task<(string argumentNameSuggestion, bool isNamed)> GetNameSuggestionForArgumentAsync(
    Document invocationDocument, TArgumentSyntax argument, INamedTypeSymbol containingType, CancellationToken cancellationToken)
{
    var syntaxFacts = invocationDocument.GetRequiredLanguageService<ISyntaxFactsService>();

    var explicitName = syntaxFacts.GetNameForArgument(argument);
    if (!string.IsNullOrWhiteSpace(explicitName))
    {
        return (explicitName, true);
    }

    var semanticModel = await invocationDocument.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
    var expression = syntaxFacts.GetExpressionOfArgument(argument);
    var semanticFacts = invocationDocument.GetRequiredLanguageService<ISemanticFactsService>();

    var generatedName = semanticFacts.GenerateNameForExpression(
        semanticModel, expression, capitalize: containingType.IsRecord, cancellationToken: cancellationToken);
    return (generatedName, false);
}
/// <summary>
/// Scans <paramref name="arguments"/> in order and returns the first one that
/// <paramref name="method"/> cannot accept: a named argument with no same-named parameter, a
/// positional argument beyond the parameter list, or a positional argument whose type does not
/// match the parameter at the same position.
/// </summary>
/// <returns>
/// The first argument a parameter should be added for, or null when every argument is accepted,
/// when an argument has no expression, or when a params parameter already absorbs the argument.
/// </returns>
private static TArgumentSyntax? DetermineFirstArgumentToAdd(
    SemanticModel semanticModel,
    ISyntaxFactsService syntaxFacts,
    StringComparer comparer,
    IMethodSymbol method,
    SeparatedSyntaxList<TArgumentSyntax> arguments)
{
    var compilation = semanticModel.Compilation;
    var parameters = method.Parameters;
    var knownParameterNames = new HashSet<string>(parameters.Select(p => p.Name), comparer);

    for (var index = 0; index < arguments.Count; index++)
    {
        var argument = arguments[index];
        var argumentName = syntaxFacts.GetNameForArgument(argument);

        if (!string.IsNullOrWhiteSpace(argumentName))
        {
            // Named argument: if no parameter carries that name, this is the argument we want
            // to add a parameter for. Otherwise it is fine and we move on.
            if (!knownParameterNames.Contains(argumentName))
            {
                return argument;
            }

            continue;
        }

        // Positional argument. If the position is beyond what the method supports,
        // then this definitely is an argument we could add.
        if (index >= parameters.Length)
        {
            if (parameters.LastOrDefault()?.IsParams == true)
            {
                // Last parameter is a params. We can't place any parameters past it.
                return null;
            }

            return argument;
        }

        // Compare the type of the argument with the type of the parameter at the same position.
        var expressionOfArgument = syntaxFacts.GetExpressionOfArgument(argument);
        if (expressionOfArgument is null)
        {
            return null;
        }

        var argumentTypeInfo = semanticModel.GetTypeInfo(expressionOfArgument);
        var isNullLiteral = syntaxFacts.IsNullLiteralExpression(expressionOfArgument);
        var isDefaultLiteral = syntaxFacts.IsDefaultLiteralExpression(expressionOfArgument);

        // Didn't know the type of the argument. We shouldn't assume it doesn't match a
        // parameter. However, if the user wrote 'null' (or 'default') and it didn't match
        // anything, then this is the problem argument, so those are still checked below.
        var typeIsUnknown = argumentTypeInfo.Type == null && argumentTypeInfo.ConvertedType == null;
        if (typeIsUnknown && !isNullLiteral && !isDefaultLiteral)
        {
            continue;
        }

        var parameter = parameters[index];
        if (TypeInfoMatchesType(
                compilation, argumentTypeInfo, parameter.Type,
                isNullLiteral, isDefaultLiteral))
        {
            continue;
        }

        if (TypeInfoMatchesWithParamsExpansion(
                compilation, argumentTypeInfo, parameter,
                isNullLiteral, isDefaultLiteral))
        {
            // The argument matched if we expanded out the params-parameter.
            // As the params-parameter has to be last, there's nothing else to do here.
            return null;
        }

        // Types don't match: this is the argument we should make the parameter for.
        return argument;
    }

    return null;
}
/// <summary>
/// Returns true when <paramref name="parameter"/> is a params parameter of array type and the
/// argument's type information matches the array's element type.
/// </summary>
private static bool TypeInfoMatchesWithParamsExpansion(
    Compilation compilation, TypeInfo argumentTypeInfo, IParameterSymbol parameter,
    bool isNullLiteral, bool isDefaultLiteral)
    => parameter.IsParams
        && parameter.Type is IArrayTypeSymbol arrayType
        && TypeInfoMatchesType(
            compilation, argumentTypeInfo, arrayType.ElementType,
            isNullLiteral, isDefaultLiteral);
/// <summary>
/// Decides whether an argument with the given type information can be passed for a parameter of
/// type <paramref name="parameterType"/>.
/// </summary>
/// <returns>
/// True when the parameter type equals the argument's type or converted type, when the argument is
/// a default literal, when it is a null literal and the parameter type is a reference type or
/// nullable, when the parameter type is a type parameter, or when the argument's type converts
/// implicitly to the parameter type; otherwise false.
/// </returns>
private static bool TypeInfoMatchesType(
    Compilation compilation, TypeInfo argumentTypeInfo, ITypeSymbol parameterType,
    bool isNullLiteral, bool isDefaultLiteral)
{
    var argumentType = argumentTypeInfo.Type;

    if (parameterType.Equals(argumentType) || parameterType.Equals(argumentTypeInfo.ConvertedType))
    {
        return true;
    }

    if (isDefaultLiteral)
    {
        return true;
    }

    if (isNullLiteral)
    {
        return parameterType.IsReferenceType || parameterType.IsNullable();
    }

    // Overload resolution couldn't resolve the actual type of the type parameter. We assume
    // that the type parameter can be the argument's type (ignoring any type parameter constraints).
    if (parameterType.Kind == SymbolKind.TypeParameter)
    {
        return true;
    }

    // An implicit conversion from the argument type to the parameter type counts as a match.
    // This happens commonly with cases like:
    //
    //      `Goo(derivedType)`
    //      `void Goo(BaseType baseType)`.
    //
    // We want this simple case to match.
    return argumentType != null
        && compilation.ClassifyCommonConversion(argumentType, parameterType).IsImplicit;
}
}
|