File: MetaAnalyzers\CompareSymbolsCorrectlyAnalyzer.cs
Web Access
Project: src\src\RoslynAnalyzers\Microsoft.CodeAnalysis.Analyzers\Core\Microsoft.CodeAnalysis.Analyzers.csproj (Microsoft.CodeAnalysis.Analyzers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using Analyzer.Utilities;
using Analyzer.Utilities.Extensions;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Operations;
 
namespace Microsoft.CodeAnalysis.Analyzers.MetaAnalyzers
{
    using static CodeAnalysisDiagnosticsResources;
 
    /// <summary>
    /// RS1024: <inheritdoc cref="CompareSymbolsCorrectlyTitle"/>
    /// </summary>
    /// <remarks>
    /// Reports the following patterns involving <c>ISymbol</c> values:
    /// <list type="bullet">
    /// <item><description><c>==</c> / <c>!=</c> comparisons that are not routed through a user-defined operator.</description></item>
    /// <item><description><c>Equals</c> calls (static, or on a symbol instance) whose arguments are all symbols, when <c>SymbolEqualityComparer</c> is available.</description></item>
    /// <item><description><c>GetHashCode()</c> on a symbol instance, and <c>System.HashCode.Combine</c> with a symbol argument.</description></item>
    /// <item><description>Hash-based collection constructions and LINQ / immutable-collection factory calls keyed on symbols that are not passed an <c>IEqualityComparer&lt;T&gt;</c>.</description></item>
    /// </list>
    /// Every diagnostic carries a <see cref="RulePropertyName"/> property whose value is one of
    /// <see cref="EqualityRuleName"/>, <see cref="GetHashCodeRuleName"/> or <see cref="CollectionRuleName"/>.
    /// </remarks>
    [DiagnosticAnalyzer(LanguageNames.CSharp, LanguageNames.VisualBasic)]
    public class CompareSymbolsCorrectlyAnalyzer : DiagnosticAnalyzer
    {
        private static readonly LocalizableString s_localizableTitle = CreateLocalizableResourceString(nameof(CompareSymbolsCorrectlyTitle));
        private static readonly LocalizableString s_localizableMessage = CreateLocalizableResourceString(nameof(CompareSymbolsCorrectlyMessage));
        private static readonly LocalizableString s_localizableDescription = CreateLocalizableResourceString(nameof(CompareSymbolsCorrectlyDescription));

        private static readonly string s_symbolTypeFullName = typeof(ISymbol).FullName;

        // These are used as `case` labels in HandleInvocationOperation, so they must be compile-time constants.
        private const string s_symbolEqualsName = nameof(ISymbol.Equals);
        private const string s_HashCodeCombineName = "Combine";

        public const string SymbolEqualityComparerName = "Microsoft.CodeAnalysis.SymbolEqualityComparer";
        public const string RulePropertyName = "Rule";

        public const string EqualityRuleName = "EqualityRule";
        public const string GetHashCodeRuleName = "GetHashCodeRule";
        public const string CollectionRuleName = "CollectionRule";

        // All three descriptors share CompareSymbolsCorrectlyRuleId; only the GetHashCode one uses a
        // different description resource.
        private static readonly DiagnosticDescriptor s_equalityRule = new(
            DiagnosticIds.CompareSymbolsCorrectlyRuleId,
            s_localizableTitle,
            s_localizableMessage,
            DiagnosticCategory.MicrosoftCodeAnalysisCorrectness,
            DiagnosticSeverity.Warning,
            isEnabledByDefault: true,
            description: s_localizableDescription,
            customTags: WellKnownDiagnosticTagsExtensions.Telemetry);

        private static readonly DiagnosticDescriptor s_getHashCodeRule = new(
            DiagnosticIds.CompareSymbolsCorrectlyRuleId,
            s_localizableTitle,
            s_localizableMessage,
            DiagnosticCategory.MicrosoftCodeAnalysisCorrectness,
            DiagnosticSeverity.Warning,
            isEnabledByDefault: true,
            description: CreateLocalizableResourceString(nameof(CompareSymbolsCorrectlyDescriptionGetHashCode)),
            customTags: WellKnownDiagnosticTagsExtensions.Telemetry);

        private static readonly DiagnosticDescriptor s_collectionRule = new(
            DiagnosticIds.CompareSymbolsCorrectlyRuleId,
            s_localizableTitle,
            s_localizableMessage,
            DiagnosticCategory.MicrosoftCodeAnalysisCorrectness,
            DiagnosticSeverity.Warning,
            isEnabledByDefault: true,
            description: s_localizableDescription,
            customTags: WellKnownDiagnosticTagsExtensions.Telemetry);

        // Diagnostic properties identifying which sub-rule fired.
        private static readonly ImmutableDictionary<string, string?> s_EqualityRuleProperties =
            ImmutableDictionary.CreateRange([new KeyValuePair<string, string?>(RulePropertyName, EqualityRuleName)]);

        private static readonly ImmutableDictionary<string, string?> s_GetHashCodeRuleProperties =
            ImmutableDictionary.CreateRange([new KeyValuePair<string, string?>(RulePropertyName, GetHashCodeRuleName)]);

        private static readonly ImmutableDictionary<string, string?> s_CollectionRuleProperties =
            ImmutableDictionary.CreateRange([new KeyValuePair<string, string?>(RulePropertyName, CollectionRuleName)]);

        // Only s_equalityRule is listed; s_getHashCodeRule and s_collectionRule share its diagnostic ID.
        // NOTE(review): this relies on the host matching reported diagnostics to supported ones by ID — confirm.
        public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get; } = ImmutableArray.Create(s_equalityRule);

        public override void Initialize(AnalysisContext context)
        {
            context.EnableConcurrentExecution();
            context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics);

            context.RegisterCompilationStartAction(context =>
            {
                var compilation = context.Compilation;
                var symbolType = compilation.GetOrCreateTypeByMetadataName(s_symbolTypeFullName);
                if (symbolType is null)
                {
                    // ISymbol cannot be resolved in this compilation, so there is nothing to analyze.
                    return;
                }

                // Check that the EqualityComparer exists and can be used; otherwise the Roslyn version
                // being used is too low to need the change for method references.
                var hasSymbolEqualityComparer = UseSymbolEqualityComparer(compilation);

                context.RegisterOperationAction(
                    context => HandleBinaryOperator(in context, symbolType),
                    OperationKind.BinaryOperator);

                var equalityComparerMethods = GetEqualityComparerMethodsToCheck(compilation);
                var systemHashCode = compilation.GetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemHashCode);
                var iEqualityComparer = compilation.GetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemCollectionsGenericIEqualityComparer1);

                context.RegisterOperationAction(
                    context => HandleInvocationOperation(in context, symbolType, hasSymbolEqualityComparer, equalityComparerMethods, systemHashCode, iEqualityComparer),
                    OperationKind.Invocation);

                if (hasSymbolEqualityComparer && iEqualityComparer is not null)
                {
                    var collectionTypesBuilder = ImmutableHashSet.CreateBuilder<INamedTypeSymbol>(SymbolEqualityComparer.Default);
                    collectionTypesBuilder.AddIfNotNull(compilation.GetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemCollectionsGenericDictionary2));
                    collectionTypesBuilder.AddIfNotNull(compilation.GetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemCollectionsGenericHashSet1));
                    collectionTypesBuilder.AddIfNotNull(compilation.GetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemCollectionsConcurrentConcurrentDictionary2));

                    // Materialize the set once, here. The callback below may run concurrently (see
                    // EnableConcurrentExecution above), and a mutable builder must not be shared across threads.
                    var collectionTypes = collectionTypesBuilder.ToImmutable();

                    // With no known collection types, HandleObjectCreation could never report, so skip the registration.
                    if (!collectionTypes.IsEmpty)
                    {
                        context.RegisterOperationAction(
                            context => HandleObjectCreation(in context, symbolType, iEqualityComparer, collectionTypes),
                            OperationKind.ObjectCreation);
                    }
                }
            });
        }

        /// <summary>
        /// Reports <c>==</c> / <c>!=</c> between operands where at least one is an <c>ISymbol</c>.
        /// </summary>
        /// <remarks>
        /// The following are skipped:
        /// <list type="bullet">
        /// <item><description>User-defined operators declared on a type other than <c>object</c>.</description></item>
        /// <item><description>Comparisons against a null/default constant.</description></item>
        /// <item><description>VB comparisons where either operand is of a class type.</description></item>
        /// <item><description>Comparisons where either side is explicitly cast to <c>object</c>.</description></item>
        /// </list>
        /// </remarks>
        private static void HandleBinaryOperator(in OperationAnalysisContext context, INamedTypeSymbol symbolType)
        {
            var binary = (IBinaryOperation)context.Operation;
            if (binary.OperatorKind is not BinaryOperatorKind.Equals and not BinaryOperatorKind.NotEquals)
            {
                return;
            }

            // Allow user-defined operators
            if (binary.OperatorMethod?.ContainingSymbol is INamedTypeSymbol containingType
                && containingType.SpecialType != SpecialType.System_Object)
            {
                return;
            }

            // If either operand is 'null' or 'default', do not analyze
            if (binary.LeftOperand.HasNullConstantValue() || binary.RightOperand.HasNullConstantValue())
            {
                return;
            }

            if (!IsSymbolType(binary.LeftOperand, symbolType)
                && !IsSymbolType(binary.RightOperand, symbolType))
            {
                return;
            }

            if (binary.Language == LanguageNames.VisualBasic
                && (IsSymbolClassType(binary.LeftOperand) || IsSymbolClassType(binary.RightOperand)))
            {
                return;
            }

            // An explicit (object) cast is treated as an intentional reference comparison.
            if (IsExplicitCastToObject(binary.LeftOperand) || IsExplicitCastToObject(binary.RightOperand))
            {
                return;
            }

            context.ReportDiagnostic(binary.Syntax.GetLocation().CreateDiagnostic(s_equalityRule, s_EqualityRuleProperties));
        }

        /// <summary>
        /// Inspects an invocation and reports one of the RS1024 sub-rules, chosen by the target method's name.
        /// </summary>
        /// <param name="context">The operation context; its operation is an <c>IInvocationOperation</c>.</param>
        /// <param name="symbolType">The resolved <c>ISymbol</c> type.</param>
        /// <param name="hasSymbolEqualityComparer">Whether <c>SymbolEqualityComparer</c> exists in the compilation.</param>
        /// <param name="equalityComparerMethods">Method name to the set of containing types whose method of that name is checked for a missing comparer.</param>
        /// <param name="systemHashCodeType"><c>System.HashCode</c>, or <see langword="null"/> if it is unavailable.</param>
        /// <param name="iEqualityComparer"><c>IEqualityComparer&lt;T&gt;</c>, or <see langword="null"/> if it is unavailable.</param>
        private static void HandleInvocationOperation(
            in OperationAnalysisContext context,
            INamedTypeSymbol symbolType,
            bool hasSymbolEqualityComparer,
            ImmutableDictionary<string, ImmutableHashSet<INamedTypeSymbol>> equalityComparerMethods,
            INamedTypeSymbol? systemHashCodeType,
            INamedTypeSymbol? iEqualityComparer)
        {
            var invocationOperation = (IInvocationOperation)context.Operation;
            var method = invocationOperation.TargetMethod;

            switch (method.Name)
            {
                case WellKnownMemberNames.ObjectGetHashCode:
                    // This is a call for an instance of ISymbol.GetHashCode()
                    // without the correct arguments
                    if (IsSymbolType(invocationOperation.Instance, symbolType))
                    {
                        context.ReportDiagnostic(invocationOperation.CreateDiagnostic(s_getHashCodeRule, s_GetHashCodeRuleProperties));
                    }

                    break;

                case s_symbolEqualsName:
                    if (hasSymbolEqualityComparer && IsNotInstanceInvocationOrNotOnSymbol(invocationOperation, symbolType))
                    {
                        var parameters = invocationOperation.Arguments;

                        // Require at least one argument: All() is vacuously true on an empty list, which would
                        // otherwise flag an unrelated parameterless static method named Equals.
                        if (!parameters.IsEmpty && parameters.All(p => IsSymbolType(p.Value, symbolType)))
                        {
                            context.ReportDiagnostic(invocationOperation.Syntax.GetLocation().CreateDiagnostic(s_equalityRule, s_EqualityRuleProperties));
                        }
                    }

                    break;

                case s_HashCodeCombineName:
                    // A call System.HashCode.Combine(ISymbol) will do the wrong thing and should be avoided
                    if (systemHashCodeType is not null &&
                        invocationOperation.Instance is null &&
                        systemHashCodeType.Equals(method.ContainingType, SymbolEqualityComparer.Default) &&
                        invocationOperation.Arguments.Any(arg => IsSymbolType(arg.Value, symbolType)))
                    {
                        context.ReportDiagnostic(invocationOperation.CreateDiagnostic(s_getHashCodeRule, s_GetHashCodeRuleProperties));
                    }

                    break;

                default:
                    // Collection/LINQ factory methods keyed on ISymbol with no IEqualityComparer<T> argument.
                    if (equalityComparerMethods.TryGetValue(method.Name, out var possibleMethodTypes) &&
                        hasSymbolEqualityComparer &&
                        possibleMethodTypes.Contains(method.ContainingType.OriginalDefinition) &&
                        IsBehavingOnSymbolType(method, symbolType) &&
                        !invocationOperation.Arguments.Any(arg => IsSymbolType(arg.Value, iEqualityComparer)))
                    {
                        context.ReportDiagnostic(invocationOperation.CreateDiagnostic(s_collectionRule, s_CollectionRuleProperties));
                    }

                    break;
            }

            // True for static calls, or for instance calls whose receiver is an ISymbol.
            static bool IsNotInstanceInvocationOrNotOnSymbol(IInvocationOperation invocationOperation, INamedTypeSymbol symbolType)
                => invocationOperation.Instance is null || IsSymbolType(invocationOperation.Instance, symbolType);

            // Decides whether a generic method operates on ISymbol keys. It inspects the type argument bound to
            // the type parameter named "TKey" when there is one; otherwise it falls back to index 0 (the
            // FirstOrDefault default), e.g. TSource for Enumerable.Contains.
            static bool IsBehavingOnSymbolType(IMethodSymbol? method, INamedTypeSymbol symbolType)
            {
                if (method is null)
                {
                    return false;
                }
                else if (!method.TypeArguments.IsEmpty)
                {
                    var destinationTypeIndex = method.TypeParameters
                        .Select((type, index) => type.Name == "TKey" ? index : -1)
                        .FirstOrDefault(x => x >= 0);

                    Debug.Assert(destinationTypeIndex < method.TypeArguments.Length);

                    return IsSymbolType(method.TypeArguments[destinationTypeIndex], symbolType);
                }
                else if (method.ReducedFrom != null && !method.ReducedFrom.TypeArguments.IsEmpty)
                {
                    // We are in the case where the ReducedFrom has TypeArguments but the original method doesn't.
                    // This seems to happen only for VB.NET and the only workaround is to force the construction
                    // of the ReducedFrom.
                    return IsBehavingOnSymbolType(method.GetConstructedReducedFrom(), symbolType);
                }
                else
                {
                    return false;
                }
            }
        }

        /// <summary>
        /// Reports construction of a known hash-based collection when all of the following hold:
        /// <list type="bullet">
        /// <item><description>Its first type argument is an <c>ISymbol</c>.</description></item>
        /// <item><description>No <c>IEqualityComparer&lt;T&gt;</c> argument is passed.</description></item>
        /// </list>
        /// </summary>
        private static void HandleObjectCreation(in OperationAnalysisContext context, INamedTypeSymbol symbolType,
             INamedTypeSymbol iEqualityComparerType, ImmutableHashSet<INamedTypeSymbol> collectionTypes)
        {
            var objectCreation = (IObjectCreationOperation)context.Operation;

            if (objectCreation.Type is INamedTypeSymbol createdType &&
                collectionTypes.Contains(createdType.OriginalDefinition) &&
                !createdType.TypeArguments.IsEmpty &&
                IsSymbolType(createdType.TypeArguments[0], symbolType) &&
                !objectCreation.Arguments.Any(arg => IsSymbolType(arg.Value, iEqualityComparerType)))
            {
                context.ReportDiagnostic(objectCreation.CreateDiagnostic(s_collectionRule, s_CollectionRuleProperties));
            }
        }

        /// <summary>
        /// Returns whether the operation's type (by original definition) is or implements <paramref name="symbolType"/>.
        /// Conversions are unwrapped recursively when the converted type does not match.
        /// </summary>
        private static bool IsSymbolType(IOperation? operation, INamedTypeSymbol? symbolType)
        {
            if (operation?.Type is object && IsSymbolType(operation.Type.OriginalDefinition, symbolType))
            {
                return true;
            }

            if (operation is IConversionOperation conversion)
            {
                return IsSymbolType(conversion.Operand, symbolType);
            }

            return false;
        }

        /// <summary>
        /// Returns whether <paramref name="typeSymbol"/> equals <paramref name="symbolType"/> or lists it among its interfaces.
        /// </summary>
        private static bool IsSymbolType(ITypeSymbol typeSymbol, INamedTypeSymbol? symbolType)
            => typeSymbol != null
                && (SymbolEqualityComparer.Default.Equals(typeSymbol, symbolType)
                    || typeSymbol.AllInterfaces.Any(SymbolEqualityComparer.Default.Equals, symbolType));

        /// <summary>
        /// Returns whether the operation (or, recursively, a converted operand) has a class type other than <c>object</c>.
        /// </summary>
        private static bool IsSymbolClassType(IOperation operation)
        {
            if (operation.Type is object &&
                operation.Type.TypeKind == TypeKind.Class &&
                operation.Type.SpecialType != SpecialType.System_Object)
            {
                return true;
            }

            if (operation is IConversionOperation conversion)
            {
                return IsSymbolClassType(conversion.Operand);
            }

            return false;
        }

        /// <summary>
        /// Returns whether the operation is a non-implicit conversion whose target type is <c>object</c>.
        /// </summary>
        private static bool IsExplicitCastToObject(IOperation operation)
        {
            if (operation is not IConversionOperation conversion)
            {
                return false;
            }

            if (conversion.IsImplicit)
            {
                return false;
            }

            return conversion.Type?.SpecialType == SpecialType.System_Object;
        }

        /// <summary>
        /// Builds a map from method name to the containing types whose method of that name should receive a
        /// symbol equality comparer. It covers <c>ImmutableHashSet</c>, <c>ImmutableDictionary</c> and
        /// <c>Enumerable</c>, and includes only the types that resolve in <paramref name="compilation"/>.
        /// </summary>
        private static ImmutableDictionary<string, ImmutableHashSet<INamedTypeSymbol>> GetEqualityComparerMethodsToCheck(Compilation compilation)
        {
            var builder = ImmutableDictionary.CreateBuilder<string, ImmutableHashSet<INamedTypeSymbol>.Builder>();

            if (compilation.TryGetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemCollectionsImmutableImmutableHashSet, out var immutableHashSetType))
            {
                AddOrUpdate(nameof(ImmutableHashSet.CreateBuilder), immutableHashSetType);
                AddOrUpdate(nameof(ImmutableHashSet.Create), immutableHashSetType);
                AddOrUpdate(nameof(ImmutableHashSet.CreateRange), immutableHashSetType);
                AddOrUpdate(nameof(ImmutableHashSet.ToImmutableHashSet), immutableHashSetType);
            }

            if (compilation.TryGetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemCollectionsImmutableImmutableDictionary, out var immutableDictionaryType))
            {
                AddOrUpdate(nameof(ImmutableDictionary.CreateBuilder), immutableDictionaryType);
                AddOrUpdate(nameof(ImmutableDictionary.Create), immutableDictionaryType);
                AddOrUpdate(nameof(ImmutableDictionary.CreateRange), immutableDictionaryType);
                AddOrUpdate(nameof(ImmutableDictionary.ToImmutableDictionary), immutableDictionaryType);
            }

            if (compilation.TryGetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemLinqEnumerable, out var enumerableType))
            {
                AddOrUpdate(nameof(Enumerable.Contains), enumerableType);
                AddOrUpdate(nameof(Enumerable.Distinct), enumerableType);
                AddOrUpdate(nameof(Enumerable.GroupBy), enumerableType);
                AddOrUpdate(nameof(Enumerable.GroupJoin), enumerableType);
                AddOrUpdate(nameof(Enumerable.Intersect), enumerableType);
                AddOrUpdate(nameof(Enumerable.Join), enumerableType);
                AddOrUpdate(nameof(Enumerable.SequenceEqual), enumerableType);
                AddOrUpdate(nameof(Enumerable.ToDictionary), enumerableType);
                // Spelled as a literal rather than nameof — presumably because Enumerable.ToHashSet is not
                // available on every framework this analyzer builds against (TODO confirm).
                AddOrUpdate("ToHashSet", enumerableType);
                AddOrUpdate(nameof(Enumerable.ToLookup), enumerableType);
                AddOrUpdate(nameof(Enumerable.Union), enumerableType);
            }

            return builder.ToImmutableDictionary(kvp => kvp.Key, kvp => kvp.Value.ToImmutable());

            // Adds typeSymbol to the set for methodName, creating the set on first use.
            void AddOrUpdate(string methodName, INamedTypeSymbol typeSymbol)
            {
                if (!builder.TryGetValue(methodName, out var methodTypeSymbols))
                {
                    methodTypeSymbols = ImmutableHashSet.CreateBuilder<INamedTypeSymbol>(SymbolEqualityComparer.Default);
                    builder.Add(methodName, methodTypeSymbols);
                }

                methodTypeSymbols.Add(typeSymbol);
            }
        }

        /// <summary>
        /// Returns whether <c>Microsoft.CodeAnalysis.SymbolEqualityComparer</c> resolves in <paramref name="compilation"/>.
        /// </summary>
        public static bool UseSymbolEqualityComparer(Compilation compilation)
            => compilation.GetOrCreateTypeByMetadataName(SymbolEqualityComparerName) is object;
    }
}