|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using Analyzer.Utilities.Extensions;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Text;
namespace Microsoft.CodeAnalysis.BannedApiAnalyzers
{
/// <summary>
/// Language-agnostic base for the analyzer that reports usages of banned symbols. On compilation
/// start it reads the banned API list and registers callbacks that check attributes, operations,
/// XML documentation cref references and base type lists against that list.
/// </summary>
/// <typeparam name="TSyntaxKind">Syntax kind type used when registering syntax node actions.</typeparam>
public abstract class SymbolIsBannedAnalyzerBase<TSyntaxKind> : DiagnosticAnalyzer
where TSyntaxKind : struct
{
/// <summary>
/// Reads the banned APIs for the compilation, keyed by the name of the symbol's container and the
/// symbol's own name. When the result is <see langword="null"/> or empty, no analysis callbacks
/// are registered.
/// </summary>
protected abstract Dictionary<(string ContainerName, string SymbolName), ImmutableArray<BanFileEntry>>? ReadBannedApis(CompilationStartAnalysisContext compilationContext);
/// <summary>Descriptor used for every diagnostic this analyzer reports.</summary>
protected abstract DiagnosticDescriptor SymbolIsBannedRule { get; }
/// <summary>Syntax kind for which the XML cref syntax node action is registered.</summary>
protected abstract TSyntaxKind XmlCrefSyntaxKind { get; }
/// <summary>
/// Maps a node of kind <see cref="XmlCrefSyntaxKind"/> to the node whose symbol is looked up in the
/// semantic model and checked against the banned list.
/// </summary>
protected abstract SyntaxNode GetReferenceSyntaxNodeFromXmlCref(SyntaxNode syntaxNode);
/// <summary>Syntax kinds for which the base type syntax node action is registered.</summary>
protected abstract ImmutableArray<TSyntaxKind> BaseTypeSyntaxKinds { get; }
/// <summary>
/// Maps a node of one of the <see cref="BaseTypeSyntaxKinds"/> to the type syntax nodes that are
/// each resolved through the semantic model and checked when they bind to a type.
/// </summary>
protected abstract IEnumerable<SyntaxNode> GetTypeSyntaxNodesFromBaseType(SyntaxNode syntaxNode);
/// <summary>Display format used when putting banned types and members into diagnostic messages.</summary>
protected abstract SymbolDisplayFormat SymbolDisplayFormat { get; }
/// <summary>
/// Configures the analysis context and hooks the per-compilation setup.
/// </summary>
/// <param name="context">Analysis context supplied by the analyzer driver.</param>
public override void Initialize(AnalysisContext context)
{
    // Generated code must produce callbacks too, and diagnostics may be reported inside it.
    const GeneratedCodeAnalysisFlags generatedCodeFlags =
        GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics;

    context.ConfigureGeneratedCodeAnalysis(generatedCodeFlags);
    context.EnableConcurrentExecution();
    context.RegisterCompilationStartAction(OnCompilationStart);
}
/// <summary>
/// Reads the banned API list for the compilation and, when it is non-empty, registers the
/// attribute, operation, XML cref and base type callbacks that report banned symbol usages.
/// </summary>
/// <param name="compilationContext">Compilation start context the callbacks are registered on.</param>
private void OnCompilationStart(CompilationStartAnalysisContext compilationContext)
{
    var bannedApis = ReadBannedApis(compilationContext);
    if (bannedApis is null || bannedApis.Count == 0)
    {
        // Nothing is banned, so there is nothing to register.
        return;
    }

    if (ShouldAnalyzeAttributes())
    {
        // Assembly-level and module-level attributes are checked once the compilation ends.
        compilationContext.RegisterCompilationEndAction(
            endContext =>
            {
                var compilation = compilationContext.Compilation;
                VerifyAttributes(endContext.ReportDiagnostic, compilation.Assembly.GetAttributes(), endContext.CancellationToken);
                VerifyAttributes(endContext.ReportDiagnostic, compilation.SourceModule.GetAttributes(), endContext.CancellationToken);
            });

        // Attributes applied to declared symbols are checked per symbol.
        compilationContext.RegisterSymbolAction(
            symbolContext => VerifyAttributes(
                symbolContext.ReportDiagnostic,
                symbolContext.Symbol.GetAttributes(),
                symbolContext.CancellationToken),
            SymbolKind.NamedType,
            SymbolKind.Method,
            SymbolKind.Field,
            SymbolKind.Property,
            SymbolKind.Event);
    }

    compilationContext.RegisterOperationAction(
        AnalyzeOperation,
        OperationKind.ObjectCreation,
        OperationKind.Invocation,
        OperationKind.EventReference,
        OperationKind.FieldReference,
        OperationKind.MethodReference,
        OperationKind.PropertyReference,
        OperationKind.ArrayCreation,
        OperationKind.AddressOf,
        OperationKind.Conversion,
        OperationKind.UnaryOperator,
        OperationKind.BinaryOperator,
        OperationKind.Increment,
        OperationKind.Decrement,
        OperationKind.TypeOf);

    compilationContext.RegisterSyntaxNodeAction(
        crefContext => VerifyDocumentationSyntax(
            crefContext.ReportDiagnostic,
            GetReferenceSyntaxNodeFromXmlCref(crefContext.Node),
            crefContext),
        XmlCrefSyntaxKind);

    compilationContext.RegisterSyntaxNodeAction(
        baseTypeContext => VerifyBaseTypesSyntax(
            baseTypeContext.ReportDiagnostic,
            GetTypeSyntaxNodesFromBaseType(baseTypeContext.Node),
            baseTypeContext),
        BaseTypeSyntaxKinds);

    return;

    // Checks the symbols and types referenced by a single operation against the banned list.
    void AnalyzeOperation(OperationAnalysisContext operationContext)
    {
        operationContext.CancellationToken.ThrowIfCancellationRequested();

        var operation = operationContext.Operation;
        var operationSyntax = operation.Syntax;
        Action<Diagnostic> report = operationContext.ReportDiagnostic;

        switch (operation)
        {
            case IObjectCreationOperation objectCreation:
                if (objectCreation.Constructor != null)
                {
                    VerifySymbol(report, objectCreation.Constructor, operationSyntax);
                }

                VerifyType(report, objectCreation.Type, operationSyntax);
                break;

            case IInvocationOperation invocation:
                VerifySymbol(report, invocation.TargetMethod, operationSyntax);
                VerifyType(report, invocation.TargetMethod.ContainingType, operationSyntax);
                break;

            case IMemberReferenceOperation memberReference:
                VerifySymbol(report, memberReference.Member, operationSyntax);
                VerifyType(report, memberReference.Member.ContainingType, operationSyntax);
                break;

            case IArrayCreationOperation arrayCreation:
                VerifyType(report, arrayCreation.Type, operationSyntax);
                break;

            case IAddressOfOperation addressOf:
                VerifyType(report, addressOf.Type, operationSyntax);
                break;

            case IConversionOperation conversion:
                VerifyOperatorMethod(report, conversion.OperatorMethod, operationSyntax);
                break;

            case IUnaryOperation unary:
                VerifyOperatorMethod(report, unary.OperatorMethod, operationSyntax);
                break;

            case IBinaryOperation binary:
                VerifyOperatorMethod(report, binary.OperatorMethod, operationSyntax);
                break;

            case IIncrementOrDecrementOperation incrementOrDecrement:
                VerifyOperatorMethod(report, incrementOrDecrement.OperatorMethod, operationSyntax);
                break;

            case ITypeOfOperation typeOfOperation:
                VerifyType(report, typeOfOperation.TypeOperand, operationSyntax);
                break;
        }
    }

    // Operations without an operator method have nothing to check; otherwise both the
    // operator itself and its containing type are verified.
    void VerifyOperatorMethod(Action<Diagnostic> report, IMethodSymbol? operatorMethod, SyntaxNode operationSyntax)
    {
        if (operatorMethod == null)
        {
            return;
        }

        VerifySymbol(report, operatorMethod, operationSyntax);
        VerifyType(report, operatorMethod.ContainingType, operationSyntax);
    }
// Looks the symbol up by (container name, symbol name) and, among the entries registered under
// that key, returns the first one whose resolved symbols contain the symbol itself.
bool IsBannedSymbol([NotNullWhen(true)] ISymbol? symbol, [NotNullWhen(true)] out BanFileEntry? entry)
{
    entry = null;

    // A null symbol, or one without a named container, can never match a key.
    if (symbol is not { ContainingSymbol.Name: string containerName })
    {
        return false;
    }

    if (!bannedApis.TryGetValue((containerName, symbol.Name), out var candidates))
    {
        return false;
    }

    foreach (var candidate in candidates)
    {
        foreach (var candidateSymbol in candidate.Symbols)
        {
            if (!SymbolEqualityComparer.Default.Equals(symbol, candidateSymbol))
            {
                continue;
            }

            entry = candidate;
            return true;
        }
    }

    return false;
}
// Decides whether the attribute callbacks need to be registered at all.
//
// Realizing symbols can be very expensive, so a cheap name-based heuristic runs first, relying
// on the .NET convention that attribute types end in 'Attribute': only keys whose symbol name or
// container name carries that suffix have their entries resolved to real symbols.
bool ShouldAnalyzeAttributes()
{
    const string attributeSuffix = "Attribute";

    foreach (var kvp in bannedApis)
    {
        // Identifiers are compared ordinally: a culture-sensitive comparison is slower and
        // treats ignorable code points as if they were absent, which is wrong for symbol names.
        var mayNameAttribute =
            kvp.Key.SymbolName.EndsWith(attributeSuffix, StringComparison.Ordinal) ||
            kvp.Key.ContainerName.EndsWith(attributeSuffix, StringComparison.Ordinal);

        if (!mayNameAttribute)
        {
            continue;
        }

        // The name looks promising; confirm against the resolved symbols.
        foreach (var entry in kvp.Value)
        {
            if (entry.Symbols.Any(ContainsAttributeSymbol))
            {
                return true;
            }
        }
    }

    return false;
}
// True when the symbol is an attribute type, or a constructor declared on an attribute type.
bool ContainsAttributeSymbol(ISymbol symbol)
{
    if (symbol is INamedTypeSymbol typeSymbol)
    {
        return typeSymbol.IsAttribute();
    }

    if (symbol is IMethodSymbol methodSymbol)
    {
        return methodSymbol.ContainingType.IsAttribute() && methodSymbol.IsConstructor();
    }

    return false;
}
// Reports every attribute application whose attribute class or attribute constructor is banned.
// Diagnostics are located on the attribute's application syntax, so attributes without one
// are skipped.
void VerifyAttributes(Action<Diagnostic> reportDiagnostic, ImmutableArray<AttributeData> attributes, CancellationToken cancellationToken)
{
    cancellationToken.ThrowIfCancellationRequested();

    foreach (var attribute in attributes)
    {
        // Resolve the syntax once; both checks below report on the same node.
        var node = attribute.ApplicationSyntaxReference?.GetSyntax(cancellationToken);
        if (node == null)
        {
            continue;
        }

        if (IsBannedSymbol(attribute.AttributeClass, out var entry))
        {
            reportDiagnostic(
                node.CreateDiagnostic(
                    SymbolIsBannedRule,
                    attribute.AttributeClass.ToDisplayString(),
                    string.IsNullOrWhiteSpace(entry.Message) ? "" : ": " + entry.Message));
        }

        if (attribute.AttributeConstructor != null)
        {
            VerifySymbol(reportDiagnostic, attribute.AttributeConstructor, node);
        }
    }
}
// Checks a type, its type arguments, its containing namespaces and then each containing type in
// turn. Returns false as soon as one banned symbol has been reported, true when nothing was banned.
bool VerifyType(Action<Diagnostic> reportDiagnostic, ITypeSymbol? type, SyntaxNode syntaxNode)
{
    var current = type;
    while (true)
    {
        if (!VerifyTypeArguments(reportDiagnostic, current, syntaxNode, out var definition))
        {
            return false;
        }

        if (definition == null)
        {
            // No definition comes back for arrays and pointers; their element / pointed-at
            // type has already been verified by VerifyTypeArguments.
            return true;
        }

        if (IsBannedSymbol(definition, out var typeEntry))
        {
            return ReportBanned(definition.ToDisplayString(SymbolDisplayFormat), typeEntry);
        }

        foreach (var containingNamespace in EnumerateContainingNamespaces(definition))
        {
            if (IsBannedSymbol(containingNamespace, out var namespaceEntry))
            {
                return ReportBanned(containingNamespace.ToDisplayString(), namespaceEntry);
            }
        }

        // Continue with the enclosing type of a nested type, if there is one.
        current = definition.ContainingType;
        if (current is null)
        {
            return true;
        }
    }

    // Reports the banned symbol on the syntax node; always yields false so callers can return it.
    bool ReportBanned(string displayName, BanFileEntry bannedEntry)
    {
        reportDiagnostic(
            syntaxNode.CreateDiagnostic(
                SymbolIsBannedRule,
                displayName,
                string.IsNullOrWhiteSpace(bannedEntry.Message) ? "" : ": " + bannedEntry.Message));
        return false;
    }

    // Yields the constituent namespaces of every non-global namespace enclosing the symbol,
    // innermost first.
    static IEnumerable<INamespaceSymbol> EnumerateContainingNamespaces(ISymbol symbol)
    {
        for (INamespaceSymbol? enclosing = symbol.ContainingNamespace;
             enclosing is { IsGlobalNamespace: false };
             enclosing = enclosing.ContainingNamespace)
        {
            foreach (var constituent in enclosing.ConstituentNamespaces)
            {
                yield return constituent;
            }
        }
    }
}
// Verifies the types nested inside 'type' (type arguments, array element type, pointed-at type)
// and hands back the definition the caller should check next. Returns false when a banned type
// was reported. For arrays and pointers 'originalDefinition' is null.
bool VerifyTypeArguments(Action<Diagnostic> reportDiagnostic, ITypeSymbol? type, SyntaxNode syntaxNode, out ITypeSymbol? originalDefinition)
{
    if (type is IArrayTypeSymbol arrayType)
    {
        originalDefinition = null;
        return VerifyType(reportDiagnostic, arrayType.ElementType, syntaxNode);
    }

    if (type is IPointerTypeSymbol pointerType)
    {
        originalDefinition = null;
        return VerifyType(reportDiagnostic, pointerType.PointedAtType, syntaxNode);
    }

    if (type is not INamedTypeSymbol namedType)
    {
        originalDefinition = type?.OriginalDefinition;
        return true;
    }

    originalDefinition = namedType.ConstructedFrom;
    foreach (var argument in namedType.TypeArguments)
    {
        // Type parameters and error types are not checked.
        if (argument.TypeKind is TypeKind.TypeParameter or TypeKind.Error)
        {
            continue;
        }

        if (!VerifyType(reportDiagnostic, argument, syntaxNode))
        {
            return false;
        }
    }

    return true;
}
// Reports the first banned symbol found among the symbol's original definition and the members
// it overrides, walking up the override chain.
void VerifySymbol(Action<Diagnostic> reportDiagnostic, ISymbol symbol, SyntaxNode syntaxNode)
{
    foreach (var candidate in GetSymbolAndOverriddenSymbols(symbol))
    {
        if (!IsBannedSymbol(candidate, out var entry))
        {
            continue;
        }

        reportDiagnostic(
            syntaxNode.CreateDiagnostic(
                SymbolIsBannedRule,
                candidate.ToDisplayString(SymbolDisplayFormat),
                string.IsNullOrWhiteSpace(entry.Message) ? "" : ": " + entry.Message));
        return;
    }

    static IEnumerable<ISymbol> GetSymbolAndOverriddenSymbols(ISymbol start)
    {
        ISymbol? candidate = start.OriginalDefinition;
        while (candidate is not null)
        {
            yield return candidate;

            if (candidate.IsOverride)
            {
                // `IsOverride` can be true while `GetOverriddenMember` returns null when the code is
                // invalid (e.g. the base symbol is not marked `virtual` or `abstract` and the current
                // symbol has the `override` modifier); the walk simply ends there.
                candidate = candidate.GetOverriddenMember()?.OriginalDefinition;
            }
            else
            {
                candidate = null;
            }
        }
    }
}
// Resolves the symbol a documentation reference binds to and verifies it: as a type when it is
// one, as a plain symbol otherwise. Unresolved references are ignored.
void VerifyDocumentationSyntax(Action<Diagnostic> reportDiagnostic, SyntaxNode syntaxNode, SyntaxNodeAnalysisContext context)
{
    var referenced = context.SemanticModel.GetSymbolInfo(syntaxNode, context.CancellationToken).Symbol;
    switch (referenced)
    {
        case null:
            break;

        case ITypeSymbol referencedType:
            VerifyType(reportDiagnostic, referencedType, syntaxNode);
            break;

        default:
            VerifySymbol(reportDiagnostic, referenced, syntaxNode);
            break;
    }
}
// Verifies each base type syntax node that binds to a type; nodes that bind to anything else,
// or to nothing, are skipped.
void VerifyBaseTypesSyntax(Action<Diagnostic> reportDiagnostic, IEnumerable<SyntaxNode> typeSyntaxNodes, SyntaxNodeAnalysisContext context)
{
    foreach (var baseTypeNode in typeSyntaxNodes)
    {
        var bound = context.SemanticModel.GetSymbolInfo(baseTypeNode, context.CancellationToken).Symbol;
        if (bound is not ITypeSymbol baseType)
        {
            continue;
        }

        VerifyType(reportDiagnostic, baseType, baseTypeNode);
    }
}
}
/// <summary>
/// A single entry of a banned API list: a declaration ID with an optional message, plus the
/// source text, span and path it came from. The symbols the declaration ID refers to are
/// resolved lazily on first access to <see cref="Symbols"/>.
/// </summary>
protected sealed class BanFileEntry
{
    private readonly Lazy<ImmutableArray<ISymbol>> _lazySymbols;

    /// <summary>
    /// Parses <paramref name="text"/> into a declaration ID and an optional message and records
    /// where the entry came from.
    /// </summary>
    /// <param name="compilation">Compilation the declaration ID is resolved against.</param>
    /// <param name="text">Entry text: a declaration ID, optionally followed by ';' and a message.</param>
    /// <param name="span">Span of the entry within <paramref name="sourceText"/>.</param>
    /// <param name="sourceText">Text the entry was read from.</param>
    /// <param name="path">Path used when creating <see cref="Location"/>.</param>
    public BanFileEntry(Compilation compilation, string text, TextSpan span, SourceText sourceText, string path)
    {
        // Everything before the first ';' is the declaration ID, everything after it is the
        // message. A missing or trailing ';' leaves the message empty.
        var separator = text.IndexOf(';');
        var (declarationPart, messagePart) = separator < 0
            ? (text, "")
            : (text[..separator], text[(separator + 1)..]);

        DeclarationId = declarationPart.Trim();
        Message = messagePart.Trim();
        Path = path;
        SourceText = sourceText;
        Span = span;
        _lazySymbols = new Lazy<ImmutableArray<ISymbol>>(ResolveSymbols);

        // Resolves the declaration ID, replacing each namespace symbol by its constituent namespaces.
        ImmutableArray<ISymbol> ResolveSymbols()
        {
            var resolved = ImmutableArray.CreateBuilder<ISymbol>();
            foreach (var symbol in DocumentationCommentId.GetSymbolsForDeclarationId(DeclarationId, compilation))
            {
                if (symbol is INamespaceSymbol namespaceSymbol)
                {
                    foreach (var constituent in namespaceSymbol.ConstituentNamespaces)
                    {
                        resolved.Add(constituent);
                    }
                }
                else
                {
                    resolved.Add(symbol);
                }
            }

            return resolved.ToImmutable();
        }
    }

    /// <summary>Span of the entry within <see cref="SourceText"/>.</summary>
    public TextSpan Span { get; }

    /// <summary>Text the entry was read from.</summary>
    public SourceText SourceText { get; }

    /// <summary>Path passed to the constructor; used to build <see cref="Location"/>.</summary>
    public string Path { get; }

    /// <summary>Trimmed declaration ID part of the entry text.</summary>
    public string DeclarationId { get; }

    /// <summary>Trimmed message part of the entry text; empty when none was given.</summary>
    public string Message { get; }

    /// <summary>Symbols the declaration ID resolves to; computed on first access.</summary>
    public ImmutableArray<ISymbol> Symbols => _lazySymbols.Value;

    /// <summary>Location of the entry, built from <see cref="Path"/>, <see cref="Span"/> and the line positions of the span.</summary>
    public Location Location
    {
        get
        {
            var lineSpan = SourceText.Lines.GetLinePositionSpan(Span);
            return Location.Create(Path, Span, lineSpan);
        }
    }
}
}
}
|