File: System\Windows\Forms\SourceGenerators\EnumValidationGenerator.cs
Web Access
Project: src\src\System.Windows.Forms.PrivateSourceGenerators\src\System.Windows.Forms.PrivateSourceGenerators.csproj (System.Windows.Forms.PrivateSourceGenerators)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Immutable;
using System.Text;
 
namespace System.Windows.Forms.PrivateSourceGenerators;
 
[Generator]
public class EnumValidationGenerator : IIncrementalGenerator
{
    // Emitted unconditionally from Initialize so that EnumValidator.Validate(...) always has an
    // overload to bind to, even before a type-specific overload has been generated.
    private const string EnumValidatorStub = @"
// <auto-generated />
namespace SourceGenerated
{
    internal static partial class EnumValidator
    {
        /// <summary>Validates that the enum value passed in is valid for the enum type. Calling this overload will result in a type-specific version being generated.</summary>
        public static void Validate(System.Enum enumToValidate, string parameterName = ""value"")
        {
            // This will be filled in by the generator once you call EnumValidator.Validate()
        }
    }
}
";

    // Shared failure path appended once to the generated EnumValidator class. Every generated
    // Validate overload falls through to it when the value matched none of the accepted values.
    private const string ReportErrorMethod = @"
        private static void ReportEnumValidationError(string parameterName, int value, System.Type enumType)
        {
            throw new System.ComponentModel.InvalidEnumArgumentException(parameterName, value, enumType);
        }";

    /// <summary>
    /// Sets up the generator pipeline. It emits the <c>System.Enum</c> stub overload, finds every
    /// <c>EnumValidator.Validate(...)</c> invocation, and resolves the type of its first argument.
    /// It then emits one type-specific <c>Validate</c> overload per distinct enum type.
    /// </summary>
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        context.RegisterPostInitializationOutput(context =>
        {
            // Always generate the System.Enum overload so that calls to EnumValidator.Validate(...)
            // compile even before (or without) a type-specific overload being generated.
            context.AddSource("BaseValidator.cs", EnumValidatorStub);
        });

        IncrementalValuesProvider<SyntaxNode> argumentsToValidate = context.SyntaxProvider.CreateSyntaxProvider(
            predicate: (syntaxNode, cancellationToken) =>
            {
                // Purely syntactic filter. It matches `<...>.EnumValidator.Validate(...)` and
                // `EnumValidator.Validate(...)` with exactly one argument (the enum value) plus an
                // optional second one (the parameter name). At least one argument is required
                // because the transform reads the first argument unconditionally; a half-typed
                // `EnumValidator.Validate()` would otherwise throw and abort the generator.
#pragma warning disable SA1513 // Closing brace should be followed by blank line
                return syntaxNode is InvocationExpressionSyntax
                {
                    ArgumentList.Arguments.Count: >= 1 and <= 2,
                    Expression: MemberAccessExpressionSyntax
                    {
                        Name.Identifier.ValueText: "Validate",
                        Expression: MemberAccessExpressionSyntax  // For: SourceGenerated.EnumValidator.Validate(..)
                        {
                            Name.Identifier.ValueText: "EnumValidator"
                        } or IdentifierNameSyntax                 // For: EnumValidator.Validate(..) with a using statement
                        {
                            Identifier.ValueText: "EnumValidator"
                        }
                    }
                };
#pragma warning restore SA1513
            },
            transform: (context, cancellationToken) =>
            {
                // The predicate guarantees at least one argument. The first one is the value
                // whose type decides which overload has to be generated.
                var invocationExpression = (InvocationExpressionSyntax)context.Node;
                return (SyntaxNode)invocationExpression.ArgumentList.Arguments[0].Expression;
            });

        // Type resolution needs the Compilation, so every collected argument is paired with it
        // and resolved in one pass (see GetEnumValidationInfo).
        IncrementalValuesProvider<EnumValidationInfo> enumsToValidate = context.CompilationProvider.Combine(argumentsToValidate.Collect()).SelectMany(
            (compilationAndArguments, cancellationToken) =>
            {
                var (compilation, argumentsToValidate) = compilationAndArguments;
                return GetEnumValidationInfo(compilation, argumentsToValidate, cancellationToken);
            });

        context.RegisterSourceOutput(
            enumsToValidate.Collect(),
            static (context, enumsToValidate) =>
            {
                // Nothing to emit when no Validate call resolved to an enum type; the stub
                // overload from the post-initialization output is still present.
                if (!enumsToValidate.IsEmpty)
                {
                    StringBuilder sb = new();
                    GenerateValidator(context, sb, enumsToValidate);
                    context.CancellationToken.ThrowIfCancellationRequested();

                    context.AddSource("Validation.cs", sb.ToString());
                }
            });
    }

    /// <summary>
    /// Writes the generated <c>SourceGenerated.EnumValidator</c> partial class into
    /// <paramref name="sb"/>. It contains one <c>Validate</c> overload per entry in
    /// <paramref name="infos"/>, followed by the shared <c>ReportEnumValidationError</c> helper.
    /// </summary>
    private static void GenerateValidator(SourceProductionContext context, StringBuilder sb, IEnumerable<EnumValidationInfo> infos)
    {
        const string indent = "        ";

        sb.AppendLine(
@"// <auto-generated />
namespace SourceGenerated
{
    internal static partial class EnumValidator
    {");

        foreach (EnumValidationInfo info in infos)
        {
            context.CancellationToken.ThrowIfCancellationRequested();

            sb.AppendLine($"{indent}/// <summary>Validates that the enum value passed in is valid for the enum type.</summary>");
            sb.AppendLine($"{indent}public static void Validate({info.EnumType} enumToValidate, string parameterName = \"value\")");
            sb.AppendLine($"{indent}{{");

            GenerateValidateMethodBody(context, sb, info, indent + "    ");

            sb.AppendLine($"{indent}}}");
            sb.AppendLine();
        }

        sb.AppendLine(ReportErrorMethod);
        sb.AppendLine(@"    }
}");
    }

    /// <summary>
    /// Writes the body of one generated <c>Validate</c> overload. The body has early
    /// <c>return</c>s for accepted values, then an unconditional call to
    /// <c>ReportEnumValidationError</c> for everything else.
    /// </summary>
    private static void GenerateValidateMethodBody(SourceProductionContext context, StringBuilder sb, EnumValidationInfo info, string indent)
    {
        // NOTE(review): the generated code casts to int, so this assumes the enum's values fit in
        // an int. Enums with a wider underlying type (uint/long/ulong) would be truncated; confirm.
        sb.AppendLine($"{indent}int intValue = (int)enumToValidate;");
        if (info.IsFlags)
        {
            GenerateFlagsValidationMethodBody(sb, info, indent);
        }
        else
        {
            GenerateSequenceValidationMethodBody(context, sb, info, indent);
        }

        sb.AppendLine($"{indent}ReportEnumValidationError(parameterName, intValue, typeof({info.EnumType}));");
    }

    /// <summary>
    /// [Flags] enums: a value is accepted when it sets no bits outside the union of all declared
    /// member values.
    /// </summary>
    private static void GenerateFlagsValidationMethodBody(StringBuilder sb, EnumValidationInfo info, string indent)
    {
        int total = 0;
        foreach (int value in info.Values)
        {
            total |= value;
        }

        sb.AppendLine($"{indent}if ((intValue & {total}) == intValue) return;");
    }

    /// <summary>
    /// Non-flags enums: a value is accepted when it equals a declared member value. Runs of
    /// consecutive member values are emitted as one range check instead of one equality check per value.
    /// </summary>
    private static void GenerateSequenceValidationMethodBody(SourceProductionContext context, StringBuilder sb, EnumValidationInfo info, string indent)
    {
        foreach ((int min, int max) in GetElementSets(context, info.Values))
        {
            if (min == max)
            {
                sb.AppendLine($"{indent}if (intValue == {min}) return;");
            }
            else
            {
                sb.AppendLine($"{indent}if (intValue >= {min} && intValue <= {max}) return;");
            }
        }
    }

    /// <summary>
    /// Splits <paramref name="values"/> into maximal runs of consecutive integers and yields each
    /// run as an inclusive (min, max) pair. Ascending input gives the fewest ranges; the ranges
    /// are still exact for other orderings, just more numerous. If there are no values, it
    /// reports an error diagnostic and yields nothing.
    /// </summary>
    private static IEnumerable<(int min, int max)> GetElementSets(SourceProductionContext context, EquatableArray<int> values)
    {
        int min = 0;
        int? max = null;
        foreach (int value in values)
        {
            // Compare in 64-bit arithmetic: with int arithmetic, max == int.MaxValue would wrap to
            // int.MinValue and make two unrelated values look contiguous.
            if (max is null || value != (long)max.Value + 1)
            {
                if (max is not null)
                {
                    yield return (min, max.Value);
                }

                min = value;
                max = value;
            }
            else
            {
                max = value;
            }
        }

        if (max is null)
        {
            // Roslyn requires warningLevel 0 for Error-severity diagnostics; any other value makes
            // Diagnostic.Create throw instead of reporting the diagnostic.
            context.ReportDiagnostic(Diagnostic.Create(
                "EV1",
                nameof(EnumValidationGenerator),
                "Can't validate an enum that has no elements",
                DiagnosticSeverity.Error,
                DiagnosticSeverity.Error,
                isEnabledByDefault: true,
                warningLevel: 0));
            yield break;
        }

        yield return (min, max.Value);
    }

    /// <summary>
    /// Resolves the static type of each collected argument. It yields one
    /// <see cref="EnumValidationInfo"/> per distinct enum type, in first-seen order.
    /// Arguments whose type is not a concrete enum are skipped. That covers unresolved types,
    /// and <c>System.Enum</c> itself, which is already handled by the stub overload;
    /// generating one for it would duplicate that member.
    /// </summary>
    /// <exception cref="OperationCanceledException">When <paramref name="cancellationToken"/> is cancelled.</exception>
    private static IEnumerable<EnumValidationInfo> GetEnumValidationInfo(Compilation compilation, ImmutableArray<SyntaxNode> argumentsToValidate, CancellationToken cancellationToken)
    {
        // The compiler doesn't necessarily cache semantic models for a single syntax tree
        // so we will do that here, ensuring we only do the expensive work once per tree.
        // We can't cache this at a higher level because generator lifetime is not to be relied on.
        var semanticModelCache = new Dictionary<SyntaxTree, SemanticModel>();

        INamedTypeSymbol? flagsAttributeType = compilation.GetTypeByMetadataName("System.FlagsAttribute");

        HashSet<ITypeSymbol> foundTypes = new(SymbolEqualityComparer.Default);

        foreach (SyntaxNode argument in argumentsToValidate)
        {
            // Signal cancellation instead of silently ending the sequence, so a truncated result
            // set is never mistaken for a complete one.
            cancellationToken.ThrowIfCancellationRequested();

            SemanticModel semanticModel = GetSemanticModel(compilation, argument.SyntaxTree);

            ITypeSymbol? enumType = semanticModel.GetTypeInfo(argument, cancellationToken).Type;

            // Only enum types have a non-null EnumUnderlyingType. This also rejects null,
            // error types, and System.Enum itself.
            if (enumType is not INamedTypeSymbol { EnumUnderlyingType: not null })
            {
                continue;
            }

            // Add returns false when the type was already seen, so each enum gets one overload.
            if (!foundTypes.Add(enumType))
            {
                continue;
            }

            bool isFlags = enumType.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, flagsAttributeType));

            yield return EnumValidationInfo.FromEnumType(enumType, isFlags);
        }

        // Memoizes Compilation.GetSemanticModel per syntax tree for the duration of this call.
        SemanticModel GetSemanticModel(Compilation compilation, SyntaxTree syntaxTree)
        {
            if (!semanticModelCache.TryGetValue(syntaxTree, out SemanticModel model))
            {
                model = compilation.GetSemanticModel(syntaxTree);
                semanticModelCache.Add(syntaxTree, model);
            }

            return model;
        }
    }
}