File: Extensions\ITypeSymbolExtensions.cs
Web Access
Project: src\src\Http\Http.Extensions\gen\Microsoft.AspNetCore.Http.ValidationsGenerator\Microsoft.AspNetCore.Http.ValidationsGenerator.csproj (Microsoft.AspNetCore.Http.ValidationsGenerator)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
 
namespace Microsoft.AspNetCore.Http.ValidationsGenerator;
 
internal static class ITypeSymbolExtensions
{
    /// <summary>
    /// Determines whether <paramref name="type"/> is considered an enumerable type.
    /// </summary>
    /// <param name="type">The type to inspect.</param>
    /// <param name="enumerable">
    /// The interface symbol that marks a type as enumerable (presumably
    /// <c>System.Collections.IEnumerable</c> — supplied by the caller, confirm at call sites).
    /// </param>
    /// <returns>
    /// <see langword="true"/> if <paramref name="type"/> is <paramref name="enumerable"/> itself or
    /// implements it. Always <see langword="false"/> for <see cref="string"/>, which is deliberately
    /// excluded even though it is enumerable over its characters.
    /// </returns>
    public static bool IsEnumerable(this ITypeSymbol type, INamedTypeSymbol enumerable)
    {
        if (type.SpecialType == SpecialType.System_String)
        {
            return false;
        }

        return type.ImplementsInterface(enumerable) || SymbolEqualityComparer.Default.Equals(type, enumerable);
    }

    /// <summary>
    /// Determines whether <paramref name="typeSymbol"/> derives, directly or indirectly, from
    /// <paramref name="validationAttributeSymbol"/>.
    /// </summary>
    /// <remarks>
    /// Only the base-type chain is walked; <paramref name="typeSymbol"/> itself is not compared,
    /// so passing <paramref name="validationAttributeSymbol"/> as the type returns <see langword="false"/>.
    /// </remarks>
    /// <param name="typeSymbol">The (attribute) type to inspect.</param>
    /// <param name="validationAttributeSymbol">The base attribute type to look for.</param>
    /// <returns><see langword="true"/> if some base type equals <paramref name="validationAttributeSymbol"/>.</returns>
    public static bool ImplementsValidationAttribute(this ITypeSymbol typeSymbol, INamedTypeSymbol validationAttributeSymbol)
    {
        var baseType = typeSymbol.BaseType;
        while (baseType != null)
        {
            if (SymbolEqualityComparer.Default.Equals(baseType, validationAttributeSymbol))
            {
                return true;
            }
            baseType = baseType.BaseType;
        }

        return false;
    }

    /// <summary>
    /// Strips wrapper types from <paramref name="type"/> to reach the underlying element type, in order:
    /// <list type="number">
    /// <item><description><c>Nullable&lt;T&gt;</c> becomes <c>T</c>.</description></item>
    /// <item><description>A nullable-annotated type (<c>Foo?</c>) loses its <c>?</c> annotation.</description></item>
    /// <item><description>
    /// An enumerable named type with exactly one type argument (e.g. <c>List&lt;T&gt;</c>,
    /// <c>IEnumerable&lt;T&gt;</c>) becomes <c>T</c>.
    /// </description></item>
    /// </list>
    /// Each step is applied at most once; the result of the last step is not unwrapped again.
    /// </summary>
    /// <param name="type">The type to unwrap.</param>
    /// <param name="enumerable">The enumerable interface symbol forwarded to <see cref="IsEnumerable"/>.</param>
    /// <returns>The unwrapped type, or <paramref name="type"/> unchanged if no step applied.</returns>
    public static ITypeSymbol UnwrapType(this ITypeSymbol type, INamedTypeSymbol enumerable)
    {
        if (type.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T &&
            type is INamedTypeSymbol { TypeArguments.Length: 1 } nullableType)
        {
            // Extract the T from a Nullable<T>
            type = nullableType.TypeArguments[0];
        }

        if (type.NullableAnnotation == NullableAnnotation.Annotated)
        {
            // Drop the `?` from a nullable reference type. Do NOT use OriginalDefinition here:
            // for a constructed generic such as List<Customer>? it returns the open definition
            // List<T>, discarding the type arguments, so the step below would yield the type
            // parameter T instead of Customer.
            type = type.WithNullableAnnotation(NullableAnnotation.NotAnnotated);
        }

        if (type is INamedTypeSymbol namedType && namedType.IsEnumerable(enumerable) && namedType.TypeArguments.Length == 1)
        {
            // Extract the T from an IEnumerable<T> or List<T>
            type = namedType.TypeArguments[0];
        }

        return type;
    }

    /// <summary>
    /// Determines whether <paramref name="type"/> implements <paramref name="interfaceType"/>,
    /// including interfaces inherited from base types and other interfaces
    /// (<see cref="ITypeSymbol.AllInterfaces"/>).
    /// </summary>
    /// <remarks>
    /// <paramref name="type"/> itself is not compared, so an interface does not "implement" itself here.
    /// </remarks>
    /// <param name="type">The type to inspect.</param>
    /// <param name="interfaceType">The interface to look for.</param>
    /// <returns><see langword="true"/> if <paramref name="interfaceType"/> is among the type's interfaces.</returns>
    internal static bool ImplementsInterface(this ITypeSymbol type, ITypeSymbol interfaceType)
    {
        foreach (var iface in type.AllInterfaces)
        {
            if (SymbolEqualityComparer.Default.Equals(interfaceType, iface))
            {
                return true;
            }
        }
        return false;
    }

    /// <summary>
    /// Collects the derived types declared on <paramref name="type"/> via
    /// <c>[JsonDerivedType(typeof(...))]</c> attributes.
    /// </summary>
    /// <param name="type">The (polymorphic base) type whose attributes are read.</param>
    /// <param name="jsonDerivedTypeAttribute">The symbol for <c>JsonDerivedTypeAttribute</c>.</param>
    /// <returns>
    /// The declared derived types in attribute order, excluding any that refer back to
    /// <paramref name="type"/> itself, or <see langword="null"/> when none are found.
    /// </returns>
    internal static ImmutableArray<INamedTypeSymbol>? GetJsonDerivedTypes(this ITypeSymbol type, INamedTypeSymbol jsonDerivedTypeAttribute)
    {
        var derivedTypes = ImmutableArray.CreateBuilder<INamedTypeSymbol>();
        foreach (var attribute in type.GetAttributes())
        {
            if (!SymbolEqualityComparer.Default.Equals(attribute.AttributeClass, jsonDerivedTypeAttribute))
            {
                continue;
            }

            // Every JsonDerivedTypeAttribute constructor takes the derived Type as its first argument.
            // An attribute that fails to bind (e.g. while the user is still typing) may have no
            // constructor arguments, or a first argument that is not a type; TypedConstant.Value
            // also throws for array-kind constants. Pattern-match instead of indexing and casting
            // blindly so such code is skipped rather than throwing out of the generator.
            var arguments = attribute.ConstructorArguments;
            if (arguments.Length > 0
                && arguments[0] is { Kind: TypedConstantKind.Type, Value: INamedTypeSymbol derivedType }
                && !SymbolEqualityComparer.Default.Equals(derivedType, type))
            {
                derivedTypes.Add(derivedType);
            }
        }

        return derivedTypes.Count == 0 ? null : derivedTypes.ToImmutable();
    }

    /// <summary>
    /// Determines whether <paramref name="type"/> is one of the framework types that are never
    /// treated as validatable.
    /// </summary>
    /// <remarks>
    /// Types exempted here have special binding rules in RDF and RDG (RequestDelegateFactory /
    /// Request Delegate Generator) and are not validatable types themselves, so we short-circuit on them.
    /// </remarks>
    /// <param name="type">The type to inspect.</param>
    /// <param name="requiredSymbols">The resolved well-known symbols to compare against.</param>
    /// <returns><see langword="true"/> if <paramref name="type"/> matches one of the exempt symbols.</returns>
    internal static bool IsExemptType(this ITypeSymbol type, RequiredSymbols requiredSymbols)
    {
        return SymbolEqualityComparer.Default.Equals(type, requiredSymbols.HttpContext)
               || SymbolEqualityComparer.Default.Equals(type, requiredSymbols.HttpRequest)
               || SymbolEqualityComparer.Default.Equals(type, requiredSymbols.HttpResponse)
               || SymbolEqualityComparer.Default.Equals(type, requiredSymbols.CancellationToken)
               || SymbolEqualityComparer.Default.Equals(type, requiredSymbols.IFormCollection)
               || SymbolEqualityComparer.Default.Equals(type, requiredSymbols.IFormFileCollection)
               || SymbolEqualityComparer.Default.Equals(type, requiredSymbols.IFormFile)
               || SymbolEqualityComparer.Default.Equals(type, requiredSymbols.Stream)
               || SymbolEqualityComparer.Default.Equals(type, requiredSymbols.PipeReader);
    }
}