File: System\Windows\Forms\SourceGenerators\EnumValidationGenerator.cs
Project: src\src\System.Windows.Forms.PrivateSourceGenerators\src\System.Windows.Forms.PrivateSourceGenerators.csproj (System.Windows.Forms.PrivateSourceGenerators)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Collections.Immutable;
using System.Globalization;
using System.Text;
namespace System.Windows.Forms.PrivateSourceGenerators;
/// <summary>
///  Incremental source generator that finds calls to <c>EnumValidator.Validate(...)</c> and emits a
///  type-specific <c>Validate</c> overload for every distinct enum type passed to it.
/// </summary>
/// <remarks>
///  <para>
///   A <c>System.Enum</c> overload is always emitted (<c>BaseValidator.cs</c>). The type-specific overloads are
///   emitted into <c>Validation.cs</c>. For any value that is not valid for the enum, they call
///   <c>ReportEnumValidationError</c>, which throws <c>System.ComponentModel.InvalidEnumArgumentException</c>.
///  </para>
/// </remarks>
[Generator]
public class EnumValidationGenerator : IIncrementalGenerator
{
    // Emitted unconditionally as "BaseValidator.cs". It declares the System.Enum overload so that
    // EnumValidator.Validate(...) calls compile before any type-specific overload has been generated.
    private const string EnumValidatorStub = @"
// <auto-generated />
namespace SourceGenerated
{
    internal static partial class EnumValidator
    {
        /// <summary>Validates that the enum value passed in is valid for the enum type. Calling this overload will result in a type-specific version being generated.</summary>
        public static void Validate(System.Enum enumToValidate, string parameterName = ""value"")
        {
            // This will be filled in by the generator once you call EnumValidator.Validate()
        }
    }
}
";

    // Shared failure path. It is appended once to the generated EnumValidator class, and every generated
    // Validate overload falls through to a call to it when none of its checks matched.
    private const string ReportErrorMethod = @"
        private static void ReportEnumValidationError(string parameterName, int value, System.Type enumType)
        {
            throw new System.ComponentModel.InvalidEnumArgumentException(parameterName, value, enumType);
        }
";

    /// <inheritdoc/>
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        context.RegisterPostInitializationOutput(static postInitializationContext =>
        {
            // Always generate the System.Enum overload so that EnumValidator.Validate(...) calls
            // compile even before a type-specific overload has been generated.
            postInitializationContext.AddSource("BaseValidator.cs", EnumValidatorStub);
        });

        // The expression passed as enumToValidate in each EnumValidator.Validate(...) call.
        IncrementalValuesProvider<SyntaxNode> argumentsToValidate = context.SyntaxProvider.CreateSyntaxProvider(
            predicate: static (syntaxNode, _) => IsEnumValidatorValidateCall(syntaxNode),
            transform: static (syntaxContext, _) => GetArgumentToValidate((InvocationExpressionSyntax)syntaxContext.Node));

        // Resolving the argument types needs semantic models, hence the combination with the compilation.
        IncrementalValuesProvider<EnumValidationInfo> enumsToValidate = context.CompilationProvider
            .Combine(argumentsToValidate.Collect())
            .SelectMany(static (compilationAndArguments, cancellationToken) =>
            {
                var (compilation, arguments) = compilationAndArguments;
                return GetEnumValidationInfo(compilation, arguments, cancellationToken);
            });

        // All overloads go into a single file, so the infos are collected first.
        context.RegisterSourceOutput(
            enumsToValidate.Collect(),
            static (sourceContext, infos) =>
            {
                if (infos.IsEmpty)
                {
                    return;
                }

                StringBuilder sb = new();
                GenerateValidator(sourceContext, sb, infos);
                sourceContext.AddSource("Validation.cs", sb.ToString());
            });
    }

#pragma warning disable SA1513 // Closing brace should be followed by blank line
    /// <summary>
    ///  Syntax-only filter for <c>EnumValidator.Validate(value)</c> and <c>EnumValidator.Validate(value, name)</c>
    ///  calls, written either fully qualified or via a using directive.
    /// </summary>
    private static bool IsEnumValidatorValidateCall(SyntaxNode syntaxNode)
        => syntaxNode is InvocationExpressionSyntax
        {
            // 1 argument for the enum value, plus 1 optional argument for the parameter name.
            // Calls without arguments are rejected because GetArgumentToValidate needs at least one.
            ArgumentList.Arguments.Count: 1 or 2,
            Expression: MemberAccessExpressionSyntax
            {
                Name.Identifier.ValueText: "Validate",
                Expression: MemberAccessExpressionSyntax  // For: SourceGenerated.EnumValidator.Validate(..)
                {
                    Name.Identifier.ValueText: "EnumValidator"
                }
                or IdentifierNameSyntax                  // For: EnumValidator.Validate(..) with a using statement
                {
                    Identifier.ValueText: "EnumValidator"
                }
            }
        };
#pragma warning restore SA1513

    /// <summary>
    ///  Returns the expression passed as <c>enumToValidate</c> in a call accepted by
    ///  <see cref="IsEnumValidatorValidateCall"/>.
    /// </summary>
    /// <remarks>
    ///  Named arguments may appear in any order, as in <c>Validate(parameterName: "x", enumToValidate: value)</c>.
    ///  The result is therefore the first positional argument, or the argument named <c>enumToValidate</c>,
    ///  rather than simply the first argument.
    /// </remarks>
    private static SyntaxNode GetArgumentToValidate(InvocationExpressionSyntax invocationExpression)
    {
        var arguments = invocationExpression.ArgumentList.Arguments;

        foreach (ArgumentSyntax argument in arguments)
        {
            if (argument.NameColon is null || argument.NameColon.Name.Identifier.ValueText == "enumToValidate")
            {
                return argument.Expression;
            }
        }

        // Only reached for calls that cannot bind anyway (every argument named something else).
        // The predicate guarantees at least one argument, and non-enum types are filtered out later.
        return arguments[0].Expression;
    }

    /// <summary>
    ///  Writes the generated <c>SourceGenerated.EnumValidator</c> partial class into <paramref name="sb"/>.
    ///  It contains one <c>Validate</c> overload per entry in <paramref name="infos"/>, plus the shared
    ///  <c>ReportEnumValidationError</c> helper those overloads call.
    /// </summary>
    private static void GenerateValidator(SourceProductionContext context, StringBuilder sb, IEnumerable<EnumValidationInfo> infos)
    {
        const string indent = "        ";

        sb.AppendLine(
@"// <auto-generated />
namespace SourceGenerated
{
    internal static partial class EnumValidator
    {");

        foreach (EnumValidationInfo info in infos)
        {
            sb.AppendLine($"{indent}/// <summary>Validates that the enum value passed in is valid for the enum type.</summary>");
            sb.AppendLine($"{indent}public static void Validate({info.EnumType} enumToValidate, string parameterName = \"value\")");
            sb.AppendLine($"{indent}{{");

            GenerateValidateMethodBody(context, sb, info, indent + "    ");

            sb.AppendLine($"{indent}}}");
            sb.AppendLine();
        }

        // Every generated overload ends with a call to ReportEnumValidationError, so it must be emitted here.
        sb.AppendLine(ReportErrorMethod);

        sb.AppendLine(@"    }
}");
    }

    /// <summary>
    ///  Emits the statements of one generated <c>Validate</c> overload. The body returns early for valid values
    ///  and otherwise falls through to <c>ReportEnumValidationError</c>.
    /// </summary>
    private static void GenerateValidateMethodBody(SourceProductionContext context, StringBuilder sb, EnumValidationInfo info, string indent)
    {
        sb.AppendLine($"{indent}int intValue = (int)enumToValidate;");

        if (info.IsFlags)
        {
            GenerateFlagsValidationMethodBody(sb, info, indent);
        }
        else
        {
            GenerateSequenceValidationMethodBody(context, sb, info, indent);
        }

        sb.AppendLine($"{indent}ReportEnumValidationError(parameterName, intValue, typeof({info.EnumType}));");
    }

    /// <summary>
    ///  Emits the check for a [Flags] enum. A value is accepted when it sets no bit outside the union of all
    ///  defined values (so 0 is always accepted).
    /// </summary>
    private static void GenerateFlagsValidationMethodBody(StringBuilder sb, EnumValidationInfo info, string indent)
    {
        int total = 0;
        foreach (int value in info.Values)
        {
            total |= value;
        }

        sb.AppendLine($"{indent}if ((intValue & {ToLiteral(total)}) == intValue) return;");
    }

    /// <summary>
    ///  Emits one equality or inclusive range check per run of consecutive defined values.
    /// </summary>
    private static void GenerateSequenceValidationMethodBody(SourceProductionContext context, StringBuilder sb, EnumValidationInfo info, string indent)
    {
        foreach ((int min, int max) in GetElementSets(context, info.Values))
        {
            if (min == max)
            {
                sb.AppendLine($"{indent}if (intValue == {ToLiteral(min)}) return;");
            }
            else
            {
                sb.AppendLine($"{indent}if (intValue >= {ToLiteral(min)} && intValue <= {ToLiteral(max)}) return;");
            }
        }
    }

    /// <summary>
    ///  Groups <paramref name="values"/> into inclusive <c>(min, max)</c> runs. A value that is exactly one
    ///  greater than the previous value extends the current run; any other value starts a new run.
    ///  Unsorted input still covers every value, just with more runs.
    /// </summary>
    /// <remarks>
    ///  If <paramref name="values"/> is empty, error EV1 is reported and nothing is yielded.
    /// </remarks>
    private static IEnumerable<(int min, int max)> GetElementSets(SourceProductionContext context, EquatableArray<int> values)
    {
        int min = 0;
        int? max = null;

        foreach (int value in values)
        {
            // Widen to long so that int.MaxValue + 1 cannot wrap around to int.MinValue.
            bool extendsCurrentSet = max is int currentMax && value == (long)currentMax + 1;
            if (!extendsCurrentSet)
            {
                if (max is int previousMax)
                {
                    yield return (min, previousMax);
                }

                min = value;
            }

            max = value;
        }

        if (max is not int lastMax)
        {
            context.ReportDiagnostic(Diagnostic.Create(
                "EV1",
                nameof(EnumValidationGenerator),
                "Can't validate an enum that has no elements",
                DiagnosticSeverity.Error,
                DiagnosticSeverity.Error,
                isEnabledByDefault: true,
                // Error-severity diagnostics must use warning level 0; Roslyn rejects any other value.
                warningLevel: 0));
            yield break;
        }

        yield return (min, lastMax);
    }

    /// <summary>
    ///  Resolves the type of each validated argument and yields one <see cref="EnumValidationInfo"/> per
    ///  distinct enum type. Arguments whose type is not an enum are skipped.
    /// </summary>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was canceled.</exception>
    private static IEnumerable<EnumValidationInfo> GetEnumValidationInfo(Compilation compilation, ImmutableArray<SyntaxNode> argumentsToValidate, CancellationToken cancellationToken)
    {
        // The compiler doesn't necessarily cache semantic models for a single syntax tree
        // so we will do that here, ensuring we only do the expensive work once per tree.
        // We can't cache this at a higher level because generator lifetime is not to be relied on.
        var semanticModelCache = new Dictionary<SyntaxTree, SemanticModel>();

        INamedTypeSymbol? flagsAttributeType = compilation.GetTypeByMetadataName("System.FlagsAttribute");

        HashSet<ITypeSymbol> foundTypes = new(SymbolEqualityComparer.Default);

        foreach (SyntaxNode argument in argumentsToValidate)
        {
            cancellationToken.ThrowIfCancellationRequested();

            SemanticModel semanticModel = GetSemanticModel(argument.SyntaxTree);
            ITypeSymbol? enumType = semanticModel.GetTypeInfo(argument, cancellationToken).Type;

            // Only concrete enum types get an overload. System.Enum itself is already covered by the stub,
            // and nullable, error or other types cannot be validated. HashSet.Add returns false for a type
            // seen at an earlier call site, so each enum is emitted exactly once (no duplicate overloads).
            if (enumType is not { TypeKind: TypeKind.Enum } || !foundTypes.Add(enumType))
            {
                continue;
            }

            bool isFlags = enumType.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, flagsAttributeType));

            yield return EnumValidationInfo.FromEnumType(enumType, isFlags);
        }

        SemanticModel GetSemanticModel(SyntaxTree syntaxTree)
        {
            if (!semanticModelCache.TryGetValue(syntaxTree, out SemanticModel? model))
            {
                model = compilation.GetSemanticModel(syntaxTree);
                semanticModelCache.Add(syntaxTree, model);
            }

            return model;
        }
    }

    /// <summary>
    ///  Formats <paramref name="value"/> for embedding as a C# integer literal in generated code. The invariant
    ///  culture is used because some cultures format negative numbers with a non-ASCII minus sign
    ///  (e.g. U+2212 under ICU), which would not compile.
    /// </summary>
    private static string ToLiteral(int value) => value.ToString(CultureInfo.InvariantCulture);
}