// File: GenerateEqualsAndGetHashCodeFromMembers\AbstractGenerateEqualsAndGetHashCodeService.cs
// Web Access
// Project: src\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
 
namespace Microsoft.CodeAnalysis.GenerateEqualsAndGetHashCodeFromMembers;
 
internal abstract partial class AbstractGenerateEqualsAndGetHashCodeService : IGenerateEqualsAndGetHashCodeService
{
    // Name used for the generated hash-code method and for the member invoked on the tuple
    // fallback; nameof keeps it tied to object.GetHashCode.
    private const string GetHashCodeName = nameof(object.GetHashCode);
    // Single shared annotation: handed to the Equals generators in this class and later passed to
    // Formatter.FormatAsync by FormatDocumentAsync, tying generation to the specialized formatting pass.
    private static readonly SyntaxAnnotation s_specializedFormattingAnnotation = new();
 
    /// <summary>
    /// Language-specific hook consulted while generating GetHashCode statements when the
    /// compilation has overflow checking enabled. Presumably wraps the statements in an
    /// unchecked context (per the name); implementations live outside this file.
    /// </summary>
    /// <param name="statements">The hash-computation statements that may overflow.</param>
    /// <param name="wrappedStatements">
    /// When the method returns <see langword="true"/>, the statements the caller emits in place of
    /// <paramref name="statements"/>.
    /// </param>
    /// <returns>
    /// <see langword="true"/> if wrapped statements were produced; <see langword="false"/> makes the
    /// caller fall back to other overflow-safe strategies.
    /// </returns>
    protected abstract bool TryWrapWithUnchecked(
        ImmutableArray<SyntaxNode> statements, out ImmutableArray<SyntaxNode> wrappedStatements);
 
    /// <summary>
    /// Runs the formatter over <paramref name="document"/> for the specialized formatting annotation,
    /// using a <see cref="FormatLargeBinaryExpressionRule"/> placed ahead of the document's default
    /// formatting rules.
    /// </summary>
    /// <param name="document">The document to format.</param>
    /// <param name="options">Formatting options forwarded to the formatter.</param>
    /// <param name="cancellationToken">Token used to cancel the formatting operation.</param>
    /// <returns>The document returned by the formatter.</returns>
    public async Task<Document> FormatDocumentAsync(Document document, SyntaxFormattingOptions options, CancellationToken cancellationToken)
    {
        var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();
        var largeBinaryExpressionRule = new FormatLargeBinaryExpressionRule(syntaxFacts);

        // The custom rule is listed first, followed by every default rule for this document.
        return await Formatter.FormatAsync(
            document,
            s_specializedFormattingAnnotation,
            options,
            [largeBinaryExpressionRule, .. Formatter.GetDefaultFormattingRules(document)],
            cancellationToken).ConfigureAwait(false);
    }
 
    /// <summary>
    /// Creates the symbol for an <c>Equals</c> method on <paramref name="namedType"/> that compares
    /// <paramref name="members"/>, delegating the actual construction to the document's
    /// <see cref="SyntaxGenerator"/>.
    /// </summary>
    /// <param name="document">Document whose project and language services drive the generation.</param>
    /// <param name="namedType">The type the method is generated for.</param>
    /// <param name="members">The members forwarded to the generator for comparison.</param>
    /// <param name="localNameOpt">Optional local name forwarded to the generator; may be <see langword="null"/>.</param>
    /// <param name="cancellationToken">Token used to cancel the asynchronous lookups.</param>
    /// <returns>The generated method symbol, tagged with the specialized formatting annotation.</returns>
    public async Task<IMethodSymbol> GenerateEqualsMethodAsync(
        Document document, INamedTypeSymbol namedType, ImmutableArray<ISymbol> members,
        string? localNameOpt, CancellationToken cancellationToken)
    {
        // Use the GetRequired* helpers (as the rest of this class does) so a missing compilation or
        // language service fails fast here instead of flowing a null into the generator.
        var compilation = await document.Project.GetRequiredCompilationAsync(cancellationToken).ConfigureAwait(false);
        var tree = await document.GetRequiredSyntaxTreeAsync(cancellationToken).ConfigureAwait(false);
        var generator = document.GetRequiredLanguageService<SyntaxGenerator>();
        var generatorInternal = document.GetRequiredLanguageService<SyntaxGeneratorInternal>();

        return generator.CreateEqualsMethod(
            generatorInternal, compilation, tree.Options, namedType, members, localNameOpt, s_specializedFormattingAnnotation);
    }
 
    /// <summary>
    /// Creates the symbol for the <c>Equals</c> method that implements
    /// <paramref name="constructedEquatableType"/> on <paramref name="namedType"/>, comparing
    /// <paramref name="members"/>. Construction is delegated to the document's <see cref="SyntaxGenerator"/>.
    /// </summary>
    /// <param name="document">Document whose semantic model and language services drive the generation.</param>
    /// <param name="namedType">The type the method is generated for.</param>
    /// <param name="members">The members forwarded to the generator for comparison.</param>
    /// <param name="constructedEquatableType">The constructed equatable interface type forwarded to the generator.</param>
    /// <param name="cancellationToken">Token used to cancel the asynchronous lookups.</param>
    /// <returns>The generated method symbol, tagged with the specialized formatting annotation.</returns>
    public async Task<IMethodSymbol> GenerateIEquatableEqualsMethodAsync(
        Document document, INamedTypeSymbol namedType,
        ImmutableArray<ISymbol> members, INamedTypeSymbol constructedEquatableType, CancellationToken cancellationToken)
    {
        // Use the GetRequired* helpers (as the rest of this class does) so a missing semantic model
        // or language service fails fast here instead of flowing a null into the generator.
        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
        var generator = document.GetRequiredLanguageService<SyntaxGenerator>();
        var generatorInternal = document.GetRequiredLanguageService<SyntaxGeneratorInternal>();

        return generator.CreateIEquatableEqualsMethod(
            generatorInternal, semanticModel, namedType, members, constructedEquatableType, s_specializedFormattingAnnotation);
    }
 
    /// <summary>
    /// Creates an <c>Equals</c> method whose body converts <c>obj</c> to
    /// <paramref name="containingType"/> and forwards to <c>this.Equals(...)</c> with the converted value.
    /// </summary>
    /// <remarks>
    /// The generated return statement has one of three shapes:
    /// <c>return obj is T t &amp;&amp; this.Equals(t);</c> (value type, patterns supported),
    /// <c>return obj is T &amp;&amp; this.Equals((T)obj);</c> (value type, no patterns), or
    /// <c>return this.Equals(obj as T);</c> (any other type).
    /// </remarks>
    /// <param name="document">Document whose project, syntax tree and language services drive the generation.</param>
    /// <param name="containingType">The type the method is generated for; also the conversion target for <c>obj</c>.</param>
    /// <param name="cancellationToken">Token used to cancel the asynchronous lookups.</param>
    /// <returns>The generated method symbol.</returns>
    public async Task<IMethodSymbol> GenerateEqualsMethodThroughIEquatableEqualsAsync(
        Document document, INamedTypeSymbol containingType, CancellationToken cancellationToken)
    {
        // GetRequiredCompilationAsync (as used elsewhere in this class) fails fast instead of handing
        // back a null compilation that would only blow up at the CreateEqualsMethod call below.
        var compilation = await document.Project.GetRequiredCompilationAsync(cancellationToken).ConfigureAwait(false);
        var tree = await document.GetRequiredSyntaxTreeAsync(cancellationToken).ConfigureAwait(false);
        var generator = document.GetRequiredLanguageService<SyntaxGenerator>();

        using var _ = ArrayBuilder<SyntaxNode>.GetInstance(out var expressions);
        var objName = generator.IdentifierName("obj");
        if (containingType.IsValueType)
        {
            if (generator.SyntaxGeneratorInternal.SupportsPatterns(tree.Options))
            {
                // return obj is T t && this.Equals(t);
                var localName = containingType.GetLocalName();

                expressions.Add(
                    generator.SyntaxGeneratorInternal.IsPatternExpression(objName,
                        generator.SyntaxGeneratorInternal.DeclarationPattern(containingType, localName)));
                expressions.Add(InvokeThisEquals(generator.IdentifierName(localName)));
            }
            else
            {
                // return obj is T && this.Equals((T)obj);
                expressions.Add(generator.IsTypeExpression(objName, containingType));
                expressions.Add(InvokeThisEquals(generator.CastExpression(containingType, objName)));
            }
        }
        else
        {
            // return this.Equals(obj as T);
            expressions.Add(InvokeThisEquals(generator.TryCastExpression(objName, containingType)));
        }

        // One or two operands were collected above; join them with '&&' (a single operand is
        // returned unchanged by Aggregate).
        var statement = generator.ReturnStatement(
            expressions.Aggregate(generator.LogicalAndExpression));

        return compilation.CreateEqualsMethod(
            [statement]);

        // Builds 'this.Equals(<argument>)'.
        SyntaxNode InvokeThisEquals(SyntaxNode argument)
            => generator.InvocationExpression(
                generator.MemberAccessExpression(
                    generator.ThisExpression(),
                    generator.IdentifierName(nameof(Equals))),
                argument);
    }
 
    /// <summary>
    /// Creates the symbol for a <c>GetHashCode</c> override on <paramref name="namedType"/> computed
    /// from <paramref name="members"/>.
    /// </summary>
    /// <param name="document">Document whose project and language services drive the generation.</param>
    /// <param name="namedType">The type the method is generated for.</param>
    /// <param name="members">The members that contribute to the hash.</param>
    /// <param name="cancellationToken">Token used to cancel the compilation lookup.</param>
    /// <returns>The generated method symbol.</returns>
    public async Task<IMethodSymbol> GenerateGetHashCodeMethodAsync(
        Document document, INamedTypeSymbol namedType,
        ImmutableArray<ISymbol> members, CancellationToken cancellationToken)
    {
        var project = document.Project;
        var compilation = await project.GetRequiredCompilationAsync(cancellationToken).ConfigureAwait(false);

        var syntaxGenerator = document.GetRequiredLanguageService<SyntaxGenerator>();
        var syntaxGeneratorInternal = document.GetRequiredLanguageService<SyntaxGeneratorInternal>();

        // All real work happens in the synchronous helper once the inputs are gathered.
        return CreateGetHashCodeMethod(
            syntaxGenerator, syntaxGeneratorInternal, compilation, namedType, members);
    }
 
    /// <summary>
    /// Builds the method symbol for <c>public override int GetHashCode()</c>, with a body produced by
    /// <see cref="CreateGetHashCodeStatements"/>.
    /// </summary>
    private IMethodSymbol CreateGetHashCodeMethod(
        SyntaxGenerator factory, SyntaxGeneratorInternal generatorInternal, Compilation compilation,
        INamedTypeSymbol namedType, ImmutableArray<ISymbol> members)
    {
        var body = CreateGetHashCodeStatements(factory, generatorInternal, compilation, namedType, members);
        var overrideModifier = new DeclarationModifiers(isOverride: true);
        var int32Type = compilation.GetSpecialType(SpecialType.System_Int32);

        // No attributes, type parameters, parameters or explicit interface implementations.
        return CodeGenerationSymbolFactory.CreateMethodSymbol(
            attributes: default,
            accessibility: Accessibility.Public,
            modifiers: overrideModifier,
            returnType: int32Type,
            refKind: RefKind.None,
            explicitInterfaceImplementations: default,
            name: GetHashCodeName,
            typeParameters: default,
            parameters: default,
            statements: body);
    }
 
    /// <summary>
    /// Chooses how the body of the generated <c>GetHashCode</c> is written. Strategies are tried in
    /// order: <c>System.HashCode</c>, plain 32-bit arithmetic (wrapped as unchecked when overflow
    /// checking is on), a tuple's <c>GetHashCode</c>, and finally 64-bit arithmetic.
    /// </summary>
    private ImmutableArray<SyntaxNode> CreateGetHashCodeStatements(
        SyntaxGenerator factory, SyntaxGeneratorInternal generatorInternal, Compilation compilation,
        INamedTypeSymbol namedType, ImmutableArray<ISymbol> members)
    {
        // System.HashCode is only usable when it exists in the compilation and is accessible
        // from the type we are generating into.
        var hashCodeCandidate = compilation.GetTypeByMetadataName("System.HashCode");
        var hashCodeType = hashCodeCandidate is not null && hashCodeCandidate.IsAccessibleWithin(namedType)
            ? hashCodeCandidate
            : null;

        var components = factory.GetGetHashCodeComponents(
            generatorInternal, compilation, namedType, members, justMemberReference: true);

        if (!components.IsEmpty && hashCodeType is not null)
        {
            return factory.CreateGetHashCodeStatementsUsingSystemHashCode(
                factory.SyntaxGeneratorInternal, hashCodeType, components);
        }

        // No System.HashCode: fall back to a hand-rolled 32-bit hash over the members.
        var int32Statements = factory.CreateGetHashCodeMethodStatements(
            factory.SyntaxGeneratorInternal, compilation, namedType, members, useInt64: false);

        // That hash may overflow, which is only a problem when the compilation checks overflow.
        var overflowIsChecked = compilation.Options.CheckOverflow;
        if (!overflowIsChecked)
        {
            return int32Statements;
        }

        // C# can handle this by adding 'unchecked{}' around the code; VB has to jump through
        // more hoops (the fallbacks below).
        if (TryWrapWithUnchecked(int32Statements, out var uncheckedStatements))
        {
            return uncheckedStatements;
        }

        // If tuples are available, generate: return (a, b, c).GetHashCode();
        var tuplesAvailable = compilation.GetTypeByMetadataName(typeof(ValueTuple).FullName!) is not null;
        if (components.Length >= 2 && tuplesAvailable)
        {
            var tuple = factory.TupleExpression(components);
            var getHashCodeAccess = factory.MemberAccessExpression(tuple, GetHashCodeName);
            var getHashCodeCall = factory.InvocationExpression(getHashCodeAccess);
            return [factory.ReturnStatement(getHashCodeCall)];
        }

        // Last resort: 64-bit math. As long as the running hash is always clamped to 32 bits or
        // less, this cannot overflow in 64 bits:
        //     hashCode = hashCode * -1521134295 + j.GetHashCode()
        //
        // So the generated lines look like:
        //     hashCode = (hashCode * -1521134295 + j.GetHashCode()) & 0x7FFFFFFF
        //
        // All hash codes end up positive, but the overflow problem is avoided.
        return factory.CreateGetHashCodeMethodStatements(
            factory.SyntaxGeneratorInternal, compilation, namedType, members, useInt64: true);
    }
}