|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using Microsoft.CodeAnalysis.CodeStyle;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Shared.Extensions;
namespace Microsoft.CodeAnalysis.UseNullPropagation;
/// <summary>
/// Constants and helpers shared by the "use null propagation" analysis.
/// </summary>
internal static class UseNullPropagationHelpers
{
    /// <summary>
    /// Diagnostic property key; its value is simply its own name.
    /// </summary>
    public const string IsTrivialNullableValueAccess = nameof(IsTrivialNullableValueAccess);

    /// <summary>
    /// Diagnostic property key; its value is simply its own name.
    /// </summary>
    public const string WhenPartIsNullable = nameof(WhenPartIsNullable);

    /// <summary>
    /// Returns <see langword="true"/> when <paramref name="symbol"/> is a member named <c>Value</c> whose
    /// containing type's original definition is <c>System.Nullable&lt;T&gt;</c>.
    /// </summary>
    public static bool IsSystemNullableValueProperty([NotNullWhen(true)] ISymbol? symbol)
    {
        // A missing symbol, or one with a different name, can never be Nullable<T>.Value.
        if (symbol is null || symbol.Name != nameof(Nullable<>.Value))
            return false;

        // Symbols without a containing type fall out here because the comparison sees null.
        return symbol.ContainingType?.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T;
    }
}
/// <summary>
/// Looks for code snippets similar to <c>x == null ? null : x.Y()</c> and converts it to <c>x?.Y()</c>. This form is also supported:
/// <code>
/// if (x != null)
/// x.Y();
/// </code>
/// </summary>
internal abstract partial class AbstractUseNullPropagationDiagnosticAnalyzer<
TSyntaxKind,
TExpressionSyntax,
TStatementSyntax,
TConditionalExpressionSyntax,
TBinaryExpressionSyntax,
TInvocationExpressionSyntax,
TConditionalAccessExpressionSyntax,
TElementAccessExpressionSyntax,
TMemberAccessExpressionSyntax,
TIfStatementSyntax,
TExpressionStatementSyntax>() : AbstractBuiltInCodeStyleDiagnosticAnalyzer(
IDEDiagnosticIds.UseNullPropagationDiagnosticId,
EnforceOnBuildValues.UseNullPropagation,
CodeStyleOptions2.PreferNullPropagation,
new LocalizableResourceString(nameof(AnalyzersResources.Use_null_propagation), AnalyzersResources.ResourceManager, typeof(AnalyzersResources)),
new LocalizableResourceString(nameof(AnalyzersResources.Null_check_can_be_simplified), AnalyzersResources.ResourceManager, typeof(AnalyzersResources)))
where TSyntaxKind : struct
where TExpressionSyntax : SyntaxNode
where TStatementSyntax : SyntaxNode
where TConditionalExpressionSyntax : TExpressionSyntax
where TBinaryExpressionSyntax : TExpressionSyntax
where TInvocationExpressionSyntax : TExpressionSyntax
where TConditionalAccessExpressionSyntax : TExpressionSyntax
where TElementAccessExpressionSyntax : TExpressionSyntax
where TMemberAccessExpressionSyntax : TExpressionSyntax
where TIfStatementSyntax : TStatementSyntax
where TExpressionStatementSyntax : TStatementSyntax
{
/// <summary>
/// Shared, pre-built property bag containing only the <see cref="UseNullPropagationHelpers.WhenPartIsNullable"/>
/// key (mapped to the empty string). Built once so it can be reused instead of allocating a new dictionary
/// for every diagnostic that needs this marker.
/// </summary>
private static readonly ImmutableDictionary<string, string?> s_whenPartIsNullableProperties =
ImmutableDictionary<string, string?>.Empty.Add(UseNullPropagationHelpers.WhenPartIsNullable, "");
/// <summary>
/// Reports this analyzer's category, which is always
/// <see cref="DiagnosticAnalyzerCategory.SemanticSpanAnalysis"/>.
/// </summary>
public override DiagnosticAnalyzerCategory GetAnalyzerCategory()
{
    return DiagnosticAnalyzerCategory.SemanticSpanAnalysis;
}
/// <summary>
/// Language-specific gate consulted once per compilation. When it returns <see langword="false"/>,
/// no syntax node actions are registered for that compilation.
/// </summary>
protected abstract bool ShouldAnalyze(Compilation compilation);
/// <summary>
/// The language-specific syntax kind under which the if-statement analysis is registered.
/// </summary>
protected abstract TSyntaxKind IfStatementSyntaxKind { get; }
/// <summary>
/// Language-specific semantic facts service; also the source of <see cref="SyntaxFacts"/>.
/// </summary>
protected abstract ISemanticFacts SemanticFacts { get; }
/// <summary>
/// Syntax facts service, obtained from <see cref="SemanticFacts"/>.
/// </summary>
protected ISyntaxFacts SyntaxFacts => SemanticFacts.SyntaxFacts;
/// <summary>
/// Language-specific fallback used for conditions that are neither a
/// <typeparamref name="TBinaryExpressionSyntax"/> nor a <typeparamref name="TInvocationExpressionSyntax"/>.
/// </summary>
/// <param name="syntaxFacts">Syntax facts service for the language.</param>
/// <param name="conditionNode">The condition, with outer parentheses (and one logical-not) already removed by the caller.</param>
/// <param name="conditionPartToCheck">On success, the expression being tested for null.</param>
/// <param name="isEquals">Whether the condition is an "equals null" style check (as opposed to "not equals null").</param>
/// <returns><see langword="true"/> when the condition was recognized.</returns>
protected abstract bool TryAnalyzePatternCondition(
ISyntaxFacts syntaxFacts, TExpressionSyntax conditionNode,
[NotNullWhen(true)] out TExpressionSyntax? conditionPartToCheck, out bool isEquals);
/// <summary>
/// Looks up the two symbols the analysis depends on: the compilation's <c>Expression&lt;T&gt;</c> type and
/// the public, two-parameter <c>ReferenceEquals</c> method declared on <c>System.Object</c>.
/// </summary>
/// <param name="compilation">The compilation to resolve the symbols in.</param>
/// <returns>The resolved symbols; either element may be <see langword="null"/> when it cannot be found.</returns>
public (INamedTypeSymbol? expressionType, IMethodSymbol? referenceEqualsMethod) GetAnalysisSymbols(Compilation compilation)
{
    var expressionOfT = compilation.ExpressionOfTType();
    var systemObject = compilation.GetSpecialType(SpecialType.System_Object);

    // Only the public overload taking exactly two parameters is of interest.
    var publicReferenceEquals = systemObject
        ?.GetMembers(nameof(ReferenceEquals))
        .OfType<IMethodSymbol>()
        .FirstOrDefault(static method =>
            method.DeclaredAccessibility == Accessibility.Public &&
            method.Parameters.Length == 2);

    return (expressionOfT, publicReferenceEquals);
}
/// <summary>
/// Registers the analysis. At compilation start, and only when <see cref="ShouldAnalyze"/> allows it,
/// two syntax node actions are registered: one for ternary conditional expressions and one for
/// if statements (using <see cref="IfStatementSyntaxKind"/>).
/// </summary>
/// <param name="context">The analysis context to register actions on.</param>
protected override void InitializeWorker(AnalysisContext context)
{
    context.RegisterCompilationStartAction(context =>
    {
        if (!ShouldAnalyze(context.Compilation))
            return;

        // Resolve both symbols in a single call, once per compilation, so every node action can reuse them.
        var (expressionType, referenceEqualsMethod) = GetAnalysisSymbols(context.Compilation);
        var syntaxKinds = this.SyntaxFacts.SyntaxKinds;

        context.RegisterSyntaxNodeAction(
            context => AnalyzeTernaryConditionalExpressionAndReportDiagnostic(context, expressionType, referenceEqualsMethod),
            syntaxKinds.Convert<TSyntaxKind>(syntaxKinds.TernaryConditionalExpression));

        context.RegisterSyntaxNodeAction(
            context => AnalyzeIfStatementAndReportDiagnostic(context, referenceEqualsMethod),
            IfStatementSyntaxKind);
    });
}
/// <summary>
/// Runs the ternary analysis on <paramref name="conditionalExpression"/> and, when it succeeds, returns
/// the expression that is null-checked together with the branch that is evaluated when it is not null.
/// </summary>
/// <param name="semanticModel">Semantic model for the tree containing the expression.</param>
/// <param name="conditionalExpression">The ternary conditional expression to inspect.</param>
/// <param name="cancellationToken">Token used to cancel the analysis.</param>
/// <returns>The two parts, or <see langword="null"/> when the expression does not qualify.</returns>
public (TExpressionSyntax conditionalPart, SyntaxNode whenPart)? GetPartsOfConditionalExpression(
    SemanticModel semanticModel,
    TConditionalExpressionSyntax conditionalExpression,
    CancellationToken cancellationToken)
{
    // The first tuple element is the Expression<T> symbol (not System.Object), so name it accordingly.
    var (expressionType, referenceEqualsMethod) = GetAnalysisSymbols(semanticModel.Compilation);

    var analysisResult = AnalyzeTernaryConditionalExpression(
        semanticModel, expressionType, referenceEqualsMethod, conditionalExpression, cancellationToken);
    if (analysisResult is null)
        return null;

    return (analysisResult.Value.ConditionPartToCheck, analysisResult.Value.WhenPartToCheck);
}
/// <summary>
/// Syntax node action for ternary conditional expressions. Honors the user's null-propagation
/// preference, runs the ternary analysis and reports a diagnostic over the whole expression on success.
/// </summary>
/// <param name="context">The node analysis context; its node is the ternary expression.</param>
/// <param name="expressionType">The <c>Expression&lt;T&gt;</c> symbol, if available.</param>
/// <param name="referenceEqualsMethod">The <c>object.ReferenceEquals</c> symbol, if available.</param>
private void AnalyzeTernaryConditionalExpressionAndReportDiagnostic(
    SyntaxNodeAnalysisContext context,
    INamedTypeSymbol? expressionType,
    IMethodSymbol? referenceEqualsMethod)
{
    var ternary = (TConditionalExpressionSyntax)context.Node;

    // Nothing to do when the preference is turned off or analysis is skipped for this notification level.
    var preference = context.GetAnalyzerOptions().PreferNullPropagation;
    if (!preference.Value)
        return;

    if (ShouldSkipAnalysis(context, preference.Notification))
        return;

    var outcome = AnalyzeTernaryConditionalExpression(
        context.SemanticModel, expressionType, referenceEqualsMethod, ternary, context.CancellationToken);
    if (outcome is not { } result)
        return;

    // The ternary's own location serves both as the primary and as the single additional location.
    var diagnostic = DiagnosticHelper.Create(
        Descriptor,
        ternary.GetLocation(),
        preference.Notification,
        context.Options,
        additionalLocations: [ternary.GetLocation()],
        result.Properties);

    context.ReportDiagnostic(diagnostic);
}
/// <summary>
/// Decides whether a ternary of the shape <c>x == null ? null : x.Y</c> (or <c>x != null ? x.Y : null</c>)
/// can be expressed with null propagation.
/// </summary>
/// <param name="semanticModel">Semantic model for the tree containing the expression.</param>
/// <param name="expressionType">The <c>Expression&lt;T&gt;</c> symbol, used for the expression-tree check.</param>
/// <param name="referenceEqualsMethod">The <c>object.ReferenceEquals</c> symbol, used when the condition is an invocation.</param>
/// <param name="conditionalExpression">The ternary conditional expression to analyze.</param>
/// <param name="cancellationToken">Token used to cancel the analysis.</param>
/// <returns>
/// The null-checked expression, the non-null branch and the diagnostic properties; or
/// <see langword="null"/> when the expression does not qualify.
/// </returns>
public ConditionalExpressionAnalysisResult? AnalyzeTernaryConditionalExpression(
    SemanticModel semanticModel,
    INamedTypeSymbol? expressionType,
    IMethodSymbol? referenceEqualsMethod,
    TConditionalExpressionSyntax conditionalExpression,
    CancellationToken cancellationToken)
{
    var facts = this.SyntaxFacts;
    facts.GetPartsOfConditionalExpression(
        conditionalExpression, out var rawCondition, out var rawWhenTrue, out var rawWhenFalse);

    var conditionNode = (TExpressionSyntax)rawCondition;
    var trueBranch = (TExpressionSyntax)facts.WalkDownParentheses(rawWhenTrue);
    var falseBranch = (TExpressionSyntax)facts.WalkDownParentheses(rawWhenFalse);

    if (!TryAnalyzeCondition(
            semanticModel, referenceEqualsMethod, conditionNode,
            out var conditionPartToCheck, out var isEquals, cancellationToken))
    {
        return null;
    }

    // Required shape:
    //   x == null ? null : <access>     (null literal in the true branch)
    //   x != null ? <access> : null     (null literal in the false branch)
    var branchTakenWhenNull = isEquals ? trueBranch : falseBranch;
    if (!facts.IsNullLiteralExpression(branchTakenWhenNull))
        return null;

    var whenPartToCheck = isEquals ? falseBranch : trueBranch;

    // The non-null branch has to be an access chain rooted at the expression that was null-checked.
    var matchedReceiver = GetWhenPartMatch(facts, semanticModel, conditionPartToCheck, whenPartToCheck, cancellationToken);
    if (matchedReceiver == null)
        return null;

    // `?.` cannot be applied to a pointer.
    var receiverType = semanticModel.GetTypeInfo(matchedReceiver, cancellationToken).Type;
    if (receiverType is IPointerTypeSymbol)
        return null;

    // When the whole ternary is a non-nullable value type (e.g. the access yields an int), rewriting
    // to `?.` would change the type to its nullable form. A ternary that is already Nullable<T> keeps
    // its type, so only that case is allowed through.
    var ternaryType = semanticModel.GetTypeInfo(conditionalExpression, cancellationToken).Type;
    if (ternaryType?.IsValueType == true)
    {
        var alreadyNullable =
            ternaryType is INamedTypeSymbol named &&
            named.ConstructedFrom.SpecialType == SpecialType.System_Nullable_T;
        if (!alreadyNullable)
            return null;
    }

    var isTrivialNullableValueAccess = false;
    if (facts.IsSimpleMemberAccessExpression(whenPartToCheck))
    {
        // `x == null ? null : x.M` cannot become `x?.M` when M binds to a method.
        var accessedMember = semanticModel.GetSymbolInfo(whenPartToCheck, cancellationToken).GetAnySymbol();
        if (accessedMember is IMethodSymbol)
            return null;

        // Going from `x.M` to `x?.M` is not legal when the type of M is unknown or is a type parameter
        // known to be neither a reference type nor a value type: the compiler cannot pick a result type.
        var accessedType = semanticModel.GetTypeInfo(whenPartToCheck, cancellationToken).Type;
        switch (accessedType)
        {
            case null:
            case ITypeParameterSymbol { IsReferenceType: false, IsValueType: false }:
                return null;
        }

        // Accessing Nullable<T>.Value on the checked expression collapses to the expression itself.
        isTrivialNullableValueAccess = UseNullPropagationHelpers.IsSystemNullableValueProperty(accessedMember);
    }

    // `?.` is not available inside expression trees, so no fix can be offered there.
    if (this.SemanticFacts.IsInExpressionTree(semanticModel, conditionNode, expressionType, cancellationToken))
        return null;

    // Record facts about the match as diagnostic properties.
    var properties = ImmutableDictionary<string, string?>.Empty;
    if (receiverType?.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T)
        properties = s_whenPartIsNullableProperties;

    if (isTrivialNullableValueAccess)
    {
        properties = properties.Add(
            UseNullPropagationHelpers.IsTrivialNullableValueAccess,
            UseNullPropagationHelpers.IsTrivialNullableValueAccess);
    }

    return new(
        conditionPartToCheck,
        whenPartToCheck,
        properties);
}
/// <summary>
/// Recognizes a null check in <paramref name="condition"/>. Outer parentheses and a single logical-not
/// are stripped first; the remaining expression is handed to the binary, invocation or pattern analysis
/// depending on its syntax type.
/// </summary>
/// <param name="semanticModel">Semantic model for the tree containing the condition.</param>
/// <param name="referenceEqualsMethod">The <c>object.ReferenceEquals</c> symbol, used for invocation conditions.</param>
/// <param name="condition">The condition expression to analyze.</param>
/// <param name="conditionPartToCheck">On success, the expression being tested for null.</param>
/// <param name="isEquals">Whether the condition is an "equals null" check, after accounting for any negation.</param>
/// <param name="cancellationToken">Token used to cancel the analysis.</param>
/// <returns><see langword="true"/> when the condition was recognized.</returns>
private bool TryAnalyzeCondition(
    SemanticModel semanticModel,
    IMethodSymbol? referenceEqualsMethod,
    TExpressionSyntax condition,
    [NotNullWhen(true)] out TExpressionSyntax? conditionPartToCheck,
    out bool isEquals,
    CancellationToken cancellationToken)
{
    var facts = this.SyntaxFacts;
    var stripped = (TExpressionSyntax)facts.WalkDownParentheses(condition);

    // Peel one logical-not and remember to flip the answer afterwards.
    var negated = facts.IsLogicalNotExpression(stripped);
    if (negated)
    {
        var operand = facts.GetOperandOfPrefixUnaryExpression(stripped);
        stripped = (TExpressionSyntax)facts.WalkDownParentheses(operand);
    }

    bool recognized;
    if (stripped is TBinaryExpressionSyntax binary)
    {
        recognized = TryAnalyzeBinaryExpressionCondition(
            facts, binary, out conditionPartToCheck, out isEquals);
    }
    else if (stripped is TInvocationExpressionSyntax invocation)
    {
        recognized = TryAnalyzeInvocationCondition(
            semanticModel, facts, referenceEqualsMethod, invocation, out conditionPartToCheck, out isEquals, cancellationToken);
    }
    else
    {
        recognized = TryAnalyzePatternCondition(facts, stripped, out conditionPartToCheck, out isEquals);
    }

    // XOR with the negation flag inverts the sense exactly when a logical-not was removed.
    isEquals ^= negated;
    return recognized;
}
/// <summary>
/// Recognizes a binary comparison against a null literal. Only the "reference equals" and
/// "reference not equals" kinds are accepted, and exactly one operand must be the null literal.
/// </summary>
/// <param name="syntaxFacts">Syntax facts service for the language.</param>
/// <param name="condition">The binary expression to analyze.</param>
/// <param name="conditionPartToCheck">On success, the operand that is not the null literal.</param>
/// <param name="isEquals"><see langword="true"/> when the expression kind is "reference equals".</param>
/// <returns><see langword="true"/> when the expression is a recognized null comparison.</returns>
private static bool TryAnalyzeBinaryExpressionCondition(
    ISyntaxFacts syntaxFacts, TBinaryExpressionSyntax condition,
    [NotNullWhen(true)] out TExpressionSyntax? conditionPartToCheck, out bool isEquals)
{
    var kinds = syntaxFacts.SyntaxKinds;
    var rawKind = condition.RawKind;

    isEquals = rawKind == kinds.ReferenceEqualsExpression;

    // Any operator other than the two reference comparisons is of no interest.
    if (!isEquals && rawKind != kinds.ReferenceNotEqualsExpression)
    {
        conditionPartToCheck = null;
        return false;
    }

    syntaxFacts.GetPartsOfBinaryExpression(condition, out var left, out var right);
    conditionPartToCheck = GetConditionPartToCheck(syntaxFacts, (TExpressionSyntax)left, (TExpressionSyntax)right);
    return conditionPartToCheck is not null;
}
/// <summary>
/// Recognizes a two-argument call named <c>ReferenceEquals</c> in which exactly one argument is the
/// null literal and which binds to <paramref name="referenceEqualsMethod"/>.
/// </summary>
/// <param name="semanticModel">Semantic model used to bind the invocation.</param>
/// <param name="syntaxFacts">Syntax facts service for the language.</param>
/// <param name="referenceEqualsMethod">The expected method symbol; when <see langword="null"/> nothing is recognized.</param>
/// <param name="invocation">The invocation expression to analyze.</param>
/// <param name="conditionPartToCheck">On success, the argument that is not the null literal.</param>
/// <param name="isEquals">Always set to <see langword="true"/>.</param>
/// <param name="cancellationToken">Token used to cancel the analysis.</param>
/// <returns><see langword="true"/> when the invocation is a recognized null check.</returns>
private static bool TryAnalyzeInvocationCondition(
    SemanticModel semanticModel,
    ISyntaxFacts syntaxFacts,
    IMethodSymbol? referenceEqualsMethod,
    TInvocationExpressionSyntax invocation,
    [NotNullWhen(true)] out TExpressionSyntax? conditionPartToCheck,
    out bool isEquals,
    CancellationToken cancellationToken)
{
    conditionPartToCheck = null;
    isEquals = true;

    if (referenceEqualsMethod is null)
        return false;

    // The callee is either a bare identifier or the name part of a simple member access.
    var callee = syntaxFacts.GetExpressionOfInvocationExpression(invocation);
    SyntaxNode? calleeName = null;
    if (syntaxFacts.IsIdentifierName(callee))
        calleeName = callee;
    else if (syntaxFacts.IsSimpleMemberAccessExpression(callee))
        calleeName = syntaxFacts.GetNameOfMemberAccessExpression(callee);

    if (!syntaxFacts.IsIdentifierName(calleeName))
        return false;

    // Cheap syntactic name test first, using the language's own string comparer.
    syntaxFacts.GetNameAndArityOfSimpleName(calleeName, out var calleeText, out _);
    if (!syntaxFacts.StringComparer.Equals(calleeText, nameof(ReferenceEquals)))
        return false;

    var args = syntaxFacts.GetArgumentsOfInvocationExpression(invocation);
    if (args.Count != 2)
        return false;

    var firstArgument = (TExpressionSyntax)syntaxFacts.GetExpressionOfArgument(args[0]);
    var secondArgument = (TExpressionSyntax)syntaxFacts.GetExpressionOfArgument(args[1]);
    if (firstArgument == null || secondArgument == null)
        return false;

    conditionPartToCheck = GetConditionPartToCheck(syntaxFacts, firstArgument, secondArgument);
    if (conditionPartToCheck == null)
        return false;

    // Bind the call only after every syntactic check has passed.
    var boundSymbol = semanticModel.GetSymbolInfo(invocation, cancellationToken).Symbol;
    return referenceEqualsMethod.Equals(boundSymbol);
}
/// <summary>
/// Given the two operands of a comparison, returns the one that is not the null literal.
/// </summary>
/// <param name="syntaxFacts">Syntax facts service for the language.</param>
/// <param name="conditionLeft">The left operand.</param>
/// <param name="conditionRight">The right operand.</param>
/// <returns>
/// The non-null-literal operand, or <see langword="null"/> when both or neither operand is the null literal.
/// </returns>
private static TExpressionSyntax? GetConditionPartToCheck(
    ISyntaxFacts syntaxFacts, TExpressionSyntax conditionLeft, TExpressionSyntax conditionRight)
{
    var leftIsNullLiteral = syntaxFacts.IsNullLiteralExpression(conditionLeft);
    var rightIsNullLiteral = syntaxFacts.IsNullLiteralExpression(conditionRight);

    // `null == null` and `a == b` both fail to single out one expression to check.
    if (leftIsNullLiteral == rightIsNullLiteral)
        return null;

    if (rightIsNullLiteral)
        return conditionLeft;

    return conditionRight;
}
#pragma warning disable CA1822 // Mark members as static. Helper method that doesn't want to call through generic form.
/// <summary>
/// Walks inward through <paramref name="whenPart"/>, one unwrapping step at a time, looking for a simple
/// member access or element access whose inner expression is syntactically equivalent to
/// <paramref name="expressionToMatch"/> (after a cast to <c>object</c>, if any, is removed from it).
/// </summary>
/// <param name="syntaxFacts">Syntax facts service for the language.</param>
/// <param name="semanticModel">Semantic model used to inspect a cast on <paramref name="expressionToMatch"/>.</param>
/// <param name="expressionToMatch">The expression that was null-checked.</param>
/// <param name="whenPart">The branch evaluated when the expression is not null.</param>
/// <param name="cancellationToken">Token used to cancel the analysis.</param>
/// <returns>The matching inner expression, or <see langword="null"/> when the walk ends without a match.</returns>
public TExpressionSyntax? GetWhenPartMatch(
    ISyntaxFacts syntaxFacts,
    SemanticModel semanticModel,
    TExpressionSyntax expressionToMatch,
    TExpressionSyntax whenPart,
    CancellationToken cancellationToken)
{
    var target = RemoveObjectCastIfAny(syntaxFacts, semanticModel, expressionToMatch, cancellationToken);

    var outer = whenPart;
    var inner = Unwrap(syntaxFacts, outer);

    // Each pass looks at one (outer, inner) pair; the walk stops once nothing more can be unwrapped.
    while (inner != null)
    {
        var outerIsAccess =
            syntaxFacts.IsSimpleMemberAccessExpression(outer) ||
            outer is TElementAccessExpressionSyntax;

        if (outerIsAccess && syntaxFacts.AreEquivalent(inner, target))
            return inner;

        outer = inner;
        inner = Unwrap(syntaxFacts, outer);
    }

    return null;
}
#pragma warning restore CA1822 // Mark members as static
/// <summary>
/// Returns the operand of <paramref name="node"/> when it is a cast whose target type is
/// <c>System.Object</c>; otherwise returns <paramref name="node"/> unchanged.
/// </summary>
/// <param name="syntaxFacts">Syntax facts service for the language.</param>
/// <param name="semanticModel">Semantic model used to resolve the cast's target type.</param>
/// <param name="node">The expression to inspect.</param>
/// <param name="cancellationToken">Token used to cancel the analysis.</param>
private static TExpressionSyntax RemoveObjectCastIfAny(
    ISyntaxFacts syntaxFacts, SemanticModel semanticModel, TExpressionSyntax node, CancellationToken cancellationToken)
{
    if (!syntaxFacts.IsCastExpression(node))
        return node;

    syntaxFacts.GetPartsOfCastExpression(node, out var castTarget, out var operand);

    // Only a cast to object is stripped; any other cast is kept as part of the expression.
    var castTargetSymbol = semanticModel.GetTypeInfo(castTarget, cancellationToken).Type;
    return castTargetSymbol?.SpecialType == SpecialType.System_Object
        ? (TExpressionSyntax)operand
        : node;
}
/// <summary>
/// Performs one unwrapping step on <paramref name="node"/> after removing outer parentheses: the
/// invoked expression of an invocation, the receiver of a simple member access, conditional access or
/// element access, or the left side of an assignment (only when the node's parent is an assignment
/// statement and the parse options support null-conditional assignment).
/// </summary>
/// <param name="syntaxFacts">Syntax facts service for the language.</param>
/// <param name="node">The expression to unwrap.</param>
/// <returns>The inner expression, or <see langword="null"/> when the node is none of the supported forms.</returns>
private static TExpressionSyntax? Unwrap(ISyntaxFacts syntaxFacts, TExpressionSyntax node)
{
    var inner = (TExpressionSyntax)syntaxFacts.WalkDownParentheses(node);

    if (inner is TInvocationExpressionSyntax call)
    {
        return (TExpressionSyntax)syntaxFacts.GetExpressionOfInvocationExpression(call);
    }

    if (syntaxFacts.IsSimpleMemberAccessExpression(inner))
    {
        return (TExpressionSyntax?)syntaxFacts.GetExpressionOfMemberAccessExpression(inner);
    }

    if (inner is TConditionalAccessExpressionSyntax conditional)
    {
        return (TExpressionSyntax)syntaxFacts.GetExpressionOfConditionalAccessExpression(conditional);
    }

    if (inner is TElementAccessExpressionSyntax element)
    {
        return (TExpressionSyntax?)syntaxFacts.GetExpressionOfElementAccessExpression(element);
    }

    // Assignments are only unwrapped when they form a statement and the language version allows
    // null-conditional assignment.
    var assignmentCanBeUnwrapped =
        syntaxFacts.IsAnyAssignmentStatement(inner.Parent) &&
        syntaxFacts.SupportsNullConditionalAssignment(inner.SyntaxTree.Options);
    if (!assignmentCanBeUnwrapped)
        return null;

    syntaxFacts.GetPartsOfAssignmentExpressionOrStatement(inner, out var assignmentTarget, out _, out _);
    return (TExpressionSyntax)assignmentTarget;
}
}
|