File: src\RoslynAnalyzers\Microsoft.CodeAnalysis.BannedApiAnalyzers\Core\SymbolIsBannedAnalyzerBase.cs
Web Access
Project: src\src\RoslynAnalyzers\Microsoft.CodeAnalysis.Analyzers\Core\Microsoft.CodeAnalysis.Analyzers.csproj (Microsoft.CodeAnalysis.Analyzers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using Analyzer.Utilities.Extensions;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Text;
 
namespace Microsoft.CodeAnalysis.BannedApiAnalyzers
{
    /// <summary>
    /// Language-agnostic core of the banned-API analyzer. Reports <see cref="SymbolIsBannedRule"/> wherever code
    /// references a banned symbol, a banned type (including through type arguments, array element types, pointer
    /// targets and containing types), or a type that lives in a banned namespace. The banned list comes from
    /// <see cref="ReadBannedApis(CompilationStartAnalysisContext)"/>.
    /// </summary>
    /// <typeparam name="TSyntaxKind">The language-specific syntax kind enumeration.</typeparam>
    public abstract class SymbolIsBannedAnalyzerBase<TSyntaxKind> : DiagnosticAnalyzer
        where TSyntaxKind : struct
    {
        /// <summary>
        /// Reads the banned API entries for the compilation.
        /// </summary>
        /// <param name="compilationContext">The compilation-start context the entries are read for.</param>
        /// <returns>
        /// A lookup keyed by (name of the containing symbol, name of the symbol), or <see langword="null"/>.
        /// When the result is <see langword="null"/> or empty, no analysis actions are registered for the compilation.
        /// </returns>
        protected abstract Dictionary<(string ContainerName, string SymbolName), ImmutableArray<BanFileEntry>>? ReadBannedApis(CompilationStartAnalysisContext compilationContext);

        /// <summary>
        /// Descriptor used for every diagnostic. Diagnostics are created with two message arguments: the display
        /// string of the banned symbol, and either an empty string or <c>": "</c> followed by the entry's message.
        /// </summary>
        protected abstract DiagnosticDescriptor SymbolIsBannedRule { get; }

        /// <summary>
        /// Syntax kind of documentation-comment cref nodes; each such node is checked via
        /// <see cref="GetReferenceSyntaxNodeFromXmlCref(SyntaxNode)"/>.
        /// </summary>
        protected abstract TSyntaxKind XmlCrefSyntaxKind { get; }

        /// <summary>
        /// Returns the node inside a cref node whose bound symbol should be checked.
        /// </summary>
        protected abstract SyntaxNode GetReferenceSyntaxNodeFromXmlCref(SyntaxNode syntaxNode);

        /// <summary>
        /// Syntax kinds whose nodes are passed to <see cref="GetTypeSyntaxNodesFromBaseType(SyntaxNode)"/>.
        /// </summary>
        protected abstract ImmutableArray<TSyntaxKind> BaseTypeSyntaxKinds { get; }

        /// <summary>
        /// Returns the type syntax nodes referenced by a base-type node; each one is bound and checked as a type.
        /// </summary>
        protected abstract IEnumerable<SyntaxNode> GetTypeSyntaxNodesFromBaseType(SyntaxNode syntaxNode);

        /// <summary>
        /// Format used to render banned types and members in diagnostics. Banned namespaces and banned attribute
        /// classes are rendered with the default <c>ToDisplayString()</c> instead.
        /// </summary>
        protected abstract SymbolDisplayFormat SymbolDisplayFormat { get; }

        public override void Initialize(AnalysisContext context)
        {
            context.EnableConcurrentExecution();

            // Analyzer needs to get callbacks for generated code, and might report diagnostics in generated code.
            context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics);

            context.RegisterCompilationStartAction(OnCompilationStart);
        }

        /// <summary>
        /// Loads the banned API list and, if it is non-empty, registers the attribute, operation, cref and
        /// base-type checks for the compilation.
        /// </summary>
        private void OnCompilationStart(CompilationStartAnalysisContext compilationContext)
        {
            var bannedApis = ReadBannedApis(compilationContext);
            if (bannedApis == null || bannedApis.Count == 0)
                return;

            if (ShouldAnalyzeAttributes())
            {
                // Assembly- and module-level attributes are not attached to any symbol callback, so check them once
                // at the end of the compilation.
                compilationContext.RegisterCompilationEndAction(
                    context =>
                    {
                        VerifyAttributes(context.ReportDiagnostic, context.Compilation.Assembly.GetAttributes(), context.CancellationToken);
                        VerifyAttributes(context.ReportDiagnostic, context.Compilation.SourceModule.GetAttributes(), context.CancellationToken);
                    });

                // Attributes applied to declarations of these symbol kinds.
                compilationContext.RegisterSymbolAction(
                    context => VerifyAttributes(context.ReportDiagnostic, context.Symbol.GetAttributes(), context.CancellationToken),
                    SymbolKind.NamedType,
                    SymbolKind.Method,
                    SymbolKind.Field,
                    SymbolKind.Property,
                    SymbolKind.Event);
            }

            compilationContext.RegisterOperationAction(
                context =>
                {
                    context.CancellationToken.ThrowIfCancellationRequested();
                    var syntax = context.Operation.Syntax;
                    switch (context.Operation)
                    {
                        case IObjectCreationOperation objectCreation:
                            // The constructor may be absent; the created type is checked regardless.
                            if (objectCreation.Constructor != null)
                                VerifySymbol(context.ReportDiagnostic, objectCreation.Constructor, syntax);
                            VerifyType(context.ReportDiagnostic, objectCreation.Type, syntax);
                            break;

                        case IInvocationOperation invocation:
                            VerifySymbolAndContainingType(context.ReportDiagnostic, invocation.TargetMethod, syntax);
                            break;

                        case IMemberReferenceOperation memberReference:
                            VerifySymbolAndContainingType(context.ReportDiagnostic, memberReference.Member, syntax);
                            break;

                        case IArrayCreationOperation arrayCreation:
                            VerifyType(context.ReportDiagnostic, arrayCreation.Type, syntax);
                            break;

                        case IAddressOfOperation addressOf:
                            VerifyType(context.ReportDiagnostic, addressOf.Type, syntax);
                            break;

                        // Operator operations are only checked when they bind to an operator method
                        // (OperatorMethod is null otherwise, and the helper ignores null).
                        case IConversionOperation conversion:
                            VerifySymbolAndContainingType(context.ReportDiagnostic, conversion.OperatorMethod, syntax);
                            break;

                        case IUnaryOperation unary:
                            VerifySymbolAndContainingType(context.ReportDiagnostic, unary.OperatorMethod, syntax);
                            break;

                        case IBinaryOperation binary:
                            VerifySymbolAndContainingType(context.ReportDiagnostic, binary.OperatorMethod, syntax);
                            break;

                        case IIncrementOrDecrementOperation incrementOrDecrement:
                            VerifySymbolAndContainingType(context.ReportDiagnostic, incrementOrDecrement.OperatorMethod, syntax);
                            break;

                        case ITypeOfOperation typeOfOperation:
                            VerifyType(context.ReportDiagnostic, typeOfOperation.TypeOperand, syntax);
                            break;
                    }
                },
                OperationKind.ObjectCreation,
                OperationKind.Invocation,
                OperationKind.EventReference,
                OperationKind.FieldReference,
                OperationKind.MethodReference,
                OperationKind.PropertyReference,
                OperationKind.ArrayCreation,
                OperationKind.AddressOf,
                OperationKind.Conversion,
                OperationKind.UnaryOperator,
                OperationKind.BinaryOperator,
                OperationKind.Increment,
                OperationKind.Decrement,
                OperationKind.TypeOf);

            compilationContext.RegisterSyntaxNodeAction(
                context => VerifyDocumentationSyntax(context.ReportDiagnostic, GetReferenceSyntaxNodeFromXmlCref(context.Node), context),
                XmlCrefSyntaxKind);

            compilationContext.RegisterSyntaxNodeAction(
                context => VerifyBaseTypesSyntax(context.ReportDiagnostic, GetTypeSyntaxNodesFromBaseType(context.Node), context),
                BaseTypeSyntaxKinds);

            return;

            // Returns true (and the matching entry) when `symbol` is equal to one of the symbols a banned entry
            // resolved to. The dictionary lookup by (container name, name) avoids comparing against every entry.
            bool IsBannedSymbol([NotNullWhen(true)] ISymbol? symbol, [NotNullWhen(true)] out BanFileEntry? entry)
            {
                if (symbol is { ContainingSymbol.Name: string parentName } &&
                    bannedApis.TryGetValue((parentName, symbol.Name), out var entries))
                {
                    foreach (var bannedFileEntry in entries)
                    {
                        foreach (var bannedSymbol in bannedFileEntry.Symbols)
                        {
                            if (SymbolEqualityComparer.Default.Equals(symbol, bannedSymbol))
                            {
                                entry = bannedFileEntry;
                                return true;
                            }
                        }
                    }
                }

                entry = null;
                return false;
            }

            // Decides whether attribute checks are worth registering at all.
            bool ShouldAnalyzeAttributes()
            {
                // We want to avoid realizing symbols here as that can be very expensive. So we instead use a simple
                // heuristic which works thanks to .NET coding conventions. Specifically, we look to see if the banned
                // api contains a type that ends in 'Attribute'. In that case, we do the work to try to get the real symbol.
                foreach (var kvp in bannedApis)
                {
                    if (!kvp.Key.SymbolName.EndsWith("Attribute", StringComparison.InvariantCulture) &&
                        !kvp.Key.ContainerName.EndsWith("Attribute", StringComparison.InvariantCulture))
                    {
                        continue;
                    }

                    foreach (var entry in kvp.Value)
                    {
                        if (entry.Symbols.Any(ContainsAttributeSymbol))
                        {
                            return true;
                        }
                    }
                }

                return false;
            }

            // True for attribute types and for constructors of attribute types.
            bool ContainsAttributeSymbol(ISymbol symbol)
            {
                return symbol switch
                {
                    INamedTypeSymbol namedType => namedType.IsAttribute(),
                    IMethodSymbol method => method.ContainingType.IsAttribute() && method.IsConstructor(),
                    _ => false
                };
            }

            // Checks each applied attribute's class and constructor, reporting on the attribute's application syntax.
            void VerifyAttributes(Action<Diagnostic> reportDiagnostic, ImmutableArray<AttributeData> attributes, CancellationToken cancellationToken)
            {
                cancellationToken.ThrowIfCancellationRequested();
                foreach (var attribute in attributes)
                {
                    // Without an application syntax there is no location to report on, so nothing is checked.
                    var node = attribute.ApplicationSyntaxReference?.GetSyntax(cancellationToken);
                    if (node == null)
                        continue;

                    if (IsBannedSymbol(attribute.AttributeClass, out var entry))
                    {
                        reportDiagnostic(
                            node.CreateDiagnostic(
                                SymbolIsBannedRule,
                                attribute.AttributeClass.ToDisplayString(),
                                FormatMessageSuffix(entry)));
                    }

                    if (attribute.AttributeConstructor != null)
                    {
                        VerifySymbol(reportDiagnostic, attribute.AttributeConstructor, node);
                    }
                }
            }

            // Verifies `symbol` (and any members it overrides), then its containing type. Does nothing for null.
            void VerifySymbolAndContainingType(Action<Diagnostic> reportDiagnostic, ISymbol? symbol, SyntaxNode syntaxNode)
            {
                if (symbol == null)
                    return;

                VerifySymbol(reportDiagnostic, symbol, syntaxNode);
                VerifyType(reportDiagnostic, symbol.ContainingType, syntaxNode);
            }

            // Verifies a type, its type arguments, array element / pointer target types, its containing types and
            // the namespaces containing it. Returns false as soon as one diagnostic has been reported.
            bool VerifyType(Action<Diagnostic> reportDiagnostic, ITypeSymbol? type, SyntaxNode syntaxNode)
            {
                do
                {
                    if (!VerifyTypeArguments(reportDiagnostic, type, syntaxNode, out type))
                    {
                        return false;
                    }

                    if (type == null)
                    {
                        // Either there was no type, or it was an array/pointer whose element type has already been
                        // verified by VerifyTypeArguments.
                        return true;
                    }

                    // `type` is ConstructedFrom for named types, which resets only the type's own type arguments:
                    // for `Outer<int>.Inner` it is still `Outer<int>.Inner`. Banned symbols are resolved from
                    // declaration IDs, so compare against the original definition (`Outer<T>.Inner`). Keep walking
                    // the constructed `type` below so that the containing type's type arguments are still verified.
                    var definition = type.OriginalDefinition;
                    if (IsBannedSymbol(definition, out var entry))
                    {
                        reportDiagnostic(
                            syntaxNode.CreateDiagnostic(
                                SymbolIsBannedRule,
                                definition.ToDisplayString(SymbolDisplayFormat),
                                FormatMessageSuffix(entry)));
                        return false;
                    }

                    foreach (var currentNamespace in GetContainingNamespaces(type))
                    {
                        if (IsBannedSymbol(currentNamespace, out entry))
                        {
                            reportDiagnostic(
                                syntaxNode.CreateDiagnostic(
                                    SymbolIsBannedRule,
                                    currentNamespace.ToDisplayString(),
                                    FormatMessageSuffix(entry)));
                            return false;
                        }
                    }

                    type = type.ContainingType;
                }
                while (type is not null);

                return true;

                // Yields the constituent namespaces of every non-global namespace enclosing `symbol`, innermost first.
                static IEnumerable<INamespaceSymbol> GetContainingNamespaces(ISymbol symbol)
                {
                    INamespaceSymbol? currentNamespace = symbol.ContainingNamespace;

                    while (currentNamespace is { IsGlobalNamespace: false })
                    {
                        foreach (var constituent in currentNamespace.ConstituentNamespaces)
                            yield return constituent;

                        currentNamespace = currentNamespace.ContainingNamespace;
                    }
                }
            }

            // Verifies the type arguments / element type of `type` and outputs the type to check next:
            // ConstructedFrom for named types, null for arrays and pointers (fully handled here), and
            // OriginalDefinition otherwise. Returns false if a diagnostic was reported.
            bool VerifyTypeArguments(Action<Diagnostic> reportDiagnostic, ITypeSymbol? type, SyntaxNode syntaxNode, out ITypeSymbol? originalDefinition)
            {
                switch (type)
                {
                    case INamedTypeSymbol namedTypeSymbol:
                        originalDefinition = namedTypeSymbol.ConstructedFrom;
                        foreach (var typeArgument in namedTypeSymbol.TypeArguments)
                        {
                            if (typeArgument.TypeKind != TypeKind.TypeParameter &&
                                typeArgument.TypeKind != TypeKind.Error &&
                                !VerifyType(reportDiagnostic, typeArgument, syntaxNode))
                            {
                                return false;
                            }
                        }

                        break;

                    case IArrayTypeSymbol arrayTypeSymbol:
                        originalDefinition = null;
                        return VerifyType(reportDiagnostic, arrayTypeSymbol.ElementType, syntaxNode);

                    case IPointerTypeSymbol pointerTypeSymbol:
                        originalDefinition = null;
                        return VerifyType(reportDiagnostic, pointerTypeSymbol.PointedAtType, syntaxNode);

                    default:
                        originalDefinition = type?.OriginalDefinition;
                        break;
                }

                return true;
            }

            // Reports the first banned symbol among `symbol` and the chain of members it overrides.
            void VerifySymbol(Action<Diagnostic> reportDiagnostic, ISymbol symbol, SyntaxNode syntaxNode)
            {
                foreach (var currentSymbol in GetSymbolAndOverriddenSymbols(symbol))
                {
                    if (IsBannedSymbol(currentSymbol, out var entry))
                    {
                        reportDiagnostic(
                            syntaxNode.CreateDiagnostic(
                                SymbolIsBannedRule,
                                currentSymbol.ToDisplayString(SymbolDisplayFormat),
                                FormatMessageSuffix(entry)));
                        return;
                    }
                }

                static IEnumerable<ISymbol> GetSymbolAndOverriddenSymbols(ISymbol symbol)
                {
                    ISymbol? currentSymbol = symbol.OriginalDefinition;

                    while (currentSymbol != null)
                    {
                        yield return currentSymbol;

                        // It's possible to have `IsOverride` true and yet have `GetOverriddenMember` returning null when
                        // the code is invalid (e.g. the base symbol is not marked as `virtual` or `abstract` and the
                        // current symbol has the `override`/`Overrides` modifier).
                        currentSymbol = currentSymbol.IsOverride
                            ? currentSymbol.GetOverriddenMember()?.OriginalDefinition
                            : null;
                    }
                }
            }

            // Binds the symbol referenced by a cref and verifies it as a type or as a member.
            void VerifyDocumentationSyntax(Action<Diagnostic> reportDiagnostic, SyntaxNode syntaxNode, SyntaxNodeAnalysisContext context)
            {
                var symbol = context.SemanticModel.GetSymbolInfo(syntaxNode, context.CancellationToken).Symbol;

                if (symbol is ITypeSymbol typeSymbol)
                {
                    VerifyType(reportDiagnostic, typeSymbol, syntaxNode);
                }
                else if (symbol != null)
                {
                    VerifySymbol(reportDiagnostic, symbol, syntaxNode);
                }
            }

            // Binds each base-type syntax node and verifies the ones that resolve to types.
            void VerifyBaseTypesSyntax(Action<Diagnostic> reportDiagnostic, IEnumerable<SyntaxNode> typeSyntaxNodes, SyntaxNodeAnalysisContext context)
            {
                foreach (var typeSyntaxNode in typeSyntaxNodes)
                {
                    var symbol = context.SemanticModel.GetSymbolInfo(typeSyntaxNode, context.CancellationToken).Symbol;

                    if (symbol is ITypeSymbol typeSymbol)
                    {
                        VerifyType(reportDiagnostic, typeSymbol, typeSyntaxNode);
                    }
                }
            }

            // Second message argument: empty when the entry has no message, otherwise ": <message>".
            static string FormatMessageSuffix(BanFileEntry entry)
                => string.IsNullOrWhiteSpace(entry.Message) ? "" : ": " + entry.Message;
        }

        /// <summary>
        /// One line of a banned API file: a documentation-comment declaration ID, an optional message, and the
        /// location the line was read from.
        /// </summary>
        protected sealed class BanFileEntry
        {
            /// <summary>Span of this entry within <see cref="SourceText"/>.</summary>
            public TextSpan Span { get; }

            /// <summary>Text of the banned API file this entry was read from.</summary>
            public SourceText SourceText { get; }

            /// <summary>Path of the banned API file this entry was read from.</summary>
            public string Path { get; }

            /// <summary>The trimmed declaration ID (the text before the first ';').</summary>
            public string DeclarationId { get; }

            /// <summary>The trimmed message (the text after the first ';'), or empty when there is none.</summary>
            public string Message { get; }

            private readonly Lazy<ImmutableArray<ISymbol>> _lazySymbols;

            /// <summary>
            /// Symbols <see cref="DeclarationId"/> resolves to in the compilation, with each namespace replaced by its
            /// constituent namespaces. Resolved on first access.
            /// </summary>
            public ImmutableArray<ISymbol> Symbols => _lazySymbols.Value;

            public BanFileEntry(Compilation compilation, string text, TextSpan span, SourceText sourceText, string path)
            {
                // Everything before the first ';' is the declaration ID; everything after it (possibly nothing)
                // is the message.
                var index = text.IndexOf(';');

                if (index == -1)
                {
                    DeclarationId = text.Trim();
                    Message = "";
                }
                else
                {
                    // When ';' is the last character, text[(index + 1)..] is "" so Message is empty.
                    DeclarationId = text[..index].Trim();
                    Message = text[(index + 1)..].Trim();
                }

                Span = span;
                SourceText = sourceText;
                Path = path;

                // Symbol resolution is deferred so entries that are never consulted are never resolved.
                _lazySymbols = new Lazy<ImmutableArray<ISymbol>>(
                    () => DocumentationCommentId.GetSymbolsForDeclarationId(DeclarationId, compilation)
                        .SelectMany(ExpandConstituentNamespaces).ToImmutableArray());

                static IEnumerable<ISymbol> ExpandConstituentNamespaces(ISymbol symbol)
                {
                    if (symbol is not INamespaceSymbol namespaceSymbol)
                    {
                        yield return symbol;
                        yield break;
                    }

                    foreach (var constituent in namespaceSymbol.ConstituentNamespaces)
                        yield return constituent;
                }
            }

            /// <summary>Location of this entry inside the banned API file.</summary>
            public Location Location => Location.Create(Path, Span, SourceText.Lines.GetLinePositionSpan(Span));
        }
    }
}