File: SymbolExtensions.cs
Web Access
Project: ..\..\..\src\Compatibility\Microsoft.DotNet.ApiSymbolExtensions\Microsoft.DotNet.ApiSymbolExtensions.csproj (Microsoft.DotNet.ApiSymbolExtensions)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis;
 
namespace Microsoft.DotNet.ApiSymbolExtensions
{
    public static class SymbolExtensions
    {
        private static readonly SymbolDisplayFormat s_comparisonFormat;
        public static readonly SymbolDisplayFormat DisplayFormat;
 
        static SymbolExtensions()
        {
            // This is the default format for symbol.ToDisplayString;
            SymbolDisplayFormat format = SymbolDisplayFormat.CSharpErrorMessageFormat;
            format = format.WithMemberOptions(format.MemberOptions | SymbolDisplayMemberOptions.IncludeType);
 
            DisplayFormat = format.WithParameterOptions(format.ParameterOptions | SymbolDisplayParameterOptions.IncludeExtensionThis);
 
            // Remove ? annotations from reference types as we want to map the APIs without nullable annotations
            // and have a special rule to catch those differences.
            // Also don't use keyword names for special types. This makes the comparison more accurate when no
            // references are running or if one side has references and the other doesn't.
            format = format.WithMiscellaneousOptions(format.MiscellaneousOptions &
                ~SymbolDisplayMiscellaneousOptions.UseSpecialTypes &
                ~SymbolDisplayMiscellaneousOptions.UseErrorTypeSymbolName &
                ~SymbolDisplayMiscellaneousOptions.IncludeNullableReferenceTypeModifier);
 
            // Remove ref/out from parameters to compare APIs when building the mappers.
            s_comparisonFormat = format.WithParameterOptions((format.ParameterOptions | SymbolDisplayParameterOptions.IncludeExtensionThis) & ~SymbolDisplayParameterOptions.IncludeParamsRefOut);
        }
 
        public static string ToComparisonDisplayString(this ISymbol symbol) =>
            symbol.ToDisplayString(s_comparisonFormat)
                  .Replace("System.IntPtr", "nint") // Treat IntPtr and nint as the same
                  .Replace("System.UIntPtr", "nuint"); // Treat UIntPtr and nuint as the same
 
        public static IEnumerable<ITypeSymbol> GetAllBaseTypes(this ITypeSymbol type)
        {
            if (type.TypeKind == TypeKind.Interface)
            {
                foreach (ITypeSymbol @interface in type.Interfaces)
                {
                    yield return @interface;
                    foreach (ITypeSymbol baseInterface in @interface.GetAllBaseTypes())
                        yield return baseInterface;
                }
            }
            else if (type.BaseType != null)
            {
                yield return type.BaseType;
                foreach (ITypeSymbol baseType in type.BaseType.GetAllBaseTypes())
                    yield return baseType;
            }
        }
 
        /// <summary>
        /// Determines if a symbol was generated by the compiler by checking for the presence of the CompilerGeneratedAttribute.
        /// </summary>
        /// <param name="symbol">Symbol to check</param>
        /// <returns>True if the attribute was found.</returns>
        public static bool IsCompilerGenerated(this ISymbol symbol)
        {
            return symbol.GetAttributes().Any(attribute =>
                attribute?.AttributeClass?.Name == "CompilerGeneratedAttribute" &&
                attribute.AttributeClass.ContainingNamespace.ToDisplayString().Equals("System.Runtime.CompilerServices", StringComparison.Ordinal));
        }
 
        public static bool IsEffectivelySealed(this ITypeSymbol type, bool includeInternalSymbols) =>
            type.IsSealed || !HasVisibleConstructor(type, includeInternalSymbols);
 
        /// <summary>
        /// Determines where the symbol is the explicit interface implementation method or property.
        /// </summary>
        /// <param name="symbol"><see cref="ISymbol"/>  Represents a symbol (namespace, class, method, parameter, etc.) exposed by the compiler.</param>
        /// <returns>true if the symbol is the explicit interface implementation method</returns>
        public static bool IsExplicitInterfaceImplementation(this ISymbol symbol) =>
            symbol is IMethodSymbol method && method.MethodKind == MethodKind.ExplicitInterfaceImplementation ||
            symbol is IPropertySymbol property && !property.ExplicitInterfaceImplementations.IsEmpty;
 
        private static bool HasVisibleConstructor(ITypeSymbol type, bool includeInternalSymbols)
        {
            if (type is INamedTypeSymbol namedType)
            {
                foreach (IMethodSymbol constructor in namedType.Constructors)
                {
                    if (!constructor.IsStatic && constructor.IsVisibleOutsideOfAssembly(includeInternalSymbols, includeEffectivelyPrivateSymbols: true))
                        return true;
                }
            }
 
            return false;
        }
 
        public static IEnumerable<ITypeSymbol> GetAllBaseInterfaces(this ITypeSymbol type)
        {
            foreach (ITypeSymbol @interface in type.Interfaces)
            {
                yield return @interface;
                foreach (ITypeSymbol baseInterface in @interface.GetAllBaseInterfaces())
                    yield return baseInterface;
            }
 
            foreach (ITypeSymbol baseType in type.GetAllBaseTypes())
                foreach (ITypeSymbol baseInterface in baseType.GetAllBaseInterfaces())
                    yield return baseInterface;
        }
 
        public static bool IsVisibleOutsideOfAssembly(this ISymbol symbol,
            bool includeInternalSymbols,
            bool includeEffectivelyPrivateSymbols = false,
            bool includeExplicitInterfaceImplementationSymbols = false) =>
            symbol.DeclaredAccessibility switch
            {
                Accessibility.Public => true,
                Accessibility.Protected => includeEffectivelyPrivateSymbols || symbol.ContainingType == null || !IsEffectivelySealed(symbol.ContainingType, includeInternalSymbols),
                Accessibility.ProtectedOrInternal => includeEffectivelyPrivateSymbols || includeInternalSymbols || symbol.ContainingType == null || !IsEffectivelySealed(symbol.ContainingType, includeInternalSymbols),
                Accessibility.ProtectedAndInternal => includeInternalSymbols && (includeEffectivelyPrivateSymbols || symbol.ContainingType == null || !IsEffectivelySealed(symbol.ContainingType, includeInternalSymbols)),
                Accessibility.Private => includeExplicitInterfaceImplementationSymbols && IsExplicitInterfaceImplementation(symbol),
                _ => includeInternalSymbols,
            };
 
        public static bool IsEventAdderOrRemover(this IMethodSymbol method) =>
            method.MethodKind == MethodKind.EventAdd ||
            method.MethodKind == MethodKind.EventRemove ||
            method.Name.StartsWith("add_", StringComparison.Ordinal) ||
            method.Name.StartsWith("remove_", StringComparison.Ordinal);
 
        /// <summary>
        /// Attempts to locate and return the constructor generated by the compiler for `record Foo(...)` syntax.
        /// The compiler will not generate a constructor in the case where the user defined it themself without using an argument list
        /// in the record declaration, or if the record has no parameters.
        /// </summary>
        /// <param name="type">The type to check for a compiler generated constructor.</param>
        /// <param name="recordConstructor">When this method returns <see langword="true"/>, then the compiler generated constructor for the record is stored in this out parameter, otherwise it becomes <see langword="null" />.</param>
        /// <returns><see langword="true" /> if the type is a record and the compiler generated constructor is found, otherwise <see langword="false"/>.</returns>
        public static bool TryGetRecordConstructor(this INamedTypeSymbol type, [NotNullWhen(true)] out IMethodSymbol? recordConstructor)
        {
            if (!type.IsRecord)
            {
                recordConstructor = null;
                return false;
            }
 
            // Locate the compiler generated Deconstruct method.
            var deconstructMethod = (IMethodSymbol?)type.GetMembers("Deconstruct")
                .FirstOrDefault(m => m is IMethodSymbol && m.IsCompilerGenerated());
 
            // Locate the compiler generated constructor by matching parameters to Deconstruct - since we cannot locate it with an attribute.
            recordConstructor = (IMethodSymbol?)type.GetMembers(".ctor")
                .FirstOrDefault(m => m is IMethodSymbol method &&
                    method.MethodKind == MethodKind.Constructor &&
                    (deconstructMethod == null ?
                        method.Parameters.IsEmpty :
                        method.Parameters.Select(p => p.Type).SequenceEqual(
                            deconstructMethod.Parameters.Select(p => p.Type), SymbolEqualityComparer.Default)));
 
            return recordConstructor != null;
        }
    }
}