// File: VtableIndexStubGenerator.cs
// Web Access
// Project: src\src\libraries\System.Runtime.InteropServices\gen\ComInterfaceGenerator\ComInterfaceGenerator.csproj (Microsoft.Interop.ComInterfaceGenerator)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.CodeDom.Compiler;
using System.Collections.Immutable;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Threading;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.Interop.Analyzers;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
 
// Tells ResourceManager that the resources embedded in this assembly (e.g. the strings read via SR.ResourceManager) are en-US.
[assembly: System.Resources.NeutralResourcesLanguage("en-US")]
 
namespace Microsoft.Interop
{
    /// <summary>
    /// Incremental source generator for methods marked with <c>[VirtualMethodIndex]</c>. For each valid method it emits
    /// managed-to-unmanaged and/or unmanaged-to-managed stubs (depending on the requested direction) inside a nested
    /// <c>Native</c> interface, declares that <c>Native</c> interface for every containing type, and emits a
    /// <c>PopulateUnmanagedVirtualMethodTable</c> method that writes the unmanaged-to-managed stub addresses into a vtable.
    /// </summary>
    [Generator]
    public sealed class VtableIndexStubGenerator : IIncrementalGenerator
    {
        /// <summary>
        /// Builds the incremental pipeline: finds <c>[VirtualMethodIndex]</c> methods, drops the ones the analyzer reports as
        /// invalid, computes stub information for the rest, and registers the four generated sources
        /// (ManagedToNativeStubs, NativeToManagedStubs, NativeInterfaces, PopulateVTable).
        /// </summary>
        /// <param name="context">The generator initialization context supplied by the compiler.</param>
        public void Initialize(IncrementalGeneratorInitializationContext context)
        {
            // Get all methods with the [VirtualMethodIndex] attribute.
            var attributedMethods = context.SyntaxProvider
                .ForAttributeWithMetadataName(
                    TypeNames.VirtualMethodIndexAttribute,
                    static (node, ct) => node is MethodDeclarationSyntax,
                    static (context, ct) => context.TargetSymbol is IMethodSymbol methodSymbol
                        ? new { Syntax = (MethodDeclarationSyntax)context.TargetNode, Symbol = methodSymbol }
                        : null)
                .Where(
                    static modelData => modelData is not null);

            // Filter out methods that are invalid for generation (diagnostics for invalid methods are reported by the analyzer).
            // Nulls were already removed above; the null check here exists so nullable flow analysis treats 'data' as non-null.
            var methodsToGenerate = attributedMethods.Where(
                static data => data is not null && VtableIndexStubDiagnosticsAnalyzer.GetDiagnosticIfInvalidMethodForGeneration(data.Syntax, data.Symbol) is null);

            // Calculate all of the information needed to generate both managed-to-unmanaged and unmanaged-to-managed stubs
            // for each method.
            IncrementalValuesProvider<SourceAvailableIncrementalMethodStubGenerationContext> generateStubInformation = methodsToGenerate
                .Combine(context.CreateStubEnvironmentProvider())
                .Select(static (data, ct) => CalculateStubInformation(data.Left.Syntax, data.Left.Symbol, data.Right, ct));

            // Filter the list of all stubs to only the stubs that requested managed-to-unmanaged stub generation.
            IncrementalValuesProvider<SourceAvailableIncrementalMethodStubGenerationContext> managedToNativeStubContexts =
                generateStubInformation
                .Where(static data => data.VtableIndexData.Direction is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional);

            context.RegisterSourceOutput(managedToNativeStubContexts.Collect(), static (context, data) =>
            {
                if (data.IsEmpty)
                {
                    return;
                }

                using StringWriter sw = new();
                using IndentedTextWriter writer = new(sw);

                writer.WriteLine("// <auto-generated/>");

                // Generate the code for the managed-to-unmanaged stubs.
                foreach (SourceAvailableIncrementalMethodStubGenerationContext stub in data)
                {
                    // Blank separator line, written through the indented writer like all other output.
                    writer.WriteLine();

                    var (stubSyntax, _) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(stub, VtableIndexStubGeneratorHelpers.GetGeneratorResolver);

                    stub.ContainingSyntaxContext.WriteToWithUnsafeModifier(writer, stubSyntax, static (writer, stubSyntax) =>
                    {
                        writer.WriteLine("internal partial interface Native");
                        writer.WriteLine('{');
                        writer.Indent++;
                        writer.WriteMultilineNode(stubSyntax.NormalizeWhitespace());
                        writer.Indent--;
                        writer.WriteLine('}');
                    });
                }

                context.AddSource("ManagedToNativeStubs.g.cs", sw.ToString());
            });

            // Filter the list of all stubs to only the stubs that requested unmanaged-to-managed stub generation.
            IncrementalValuesProvider<SourceAvailableIncrementalMethodStubGenerationContext> nativeToManagedStubContexts =
                generateStubInformation
                .Where(static data => data.VtableIndexData.Direction is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional);

            context.RegisterSourceOutput(nativeToManagedStubContexts.Collect(), static (context, data) =>
            {
                if (data.IsEmpty)
                {
                    return;
                }

                using StringWriter sw = new();
                using IndentedTextWriter writer = new(sw);

                writer.WriteLine("// <auto-generated/>");

                // Generate the code for the unmanaged-to-managed stubs.
                foreach (SourceAvailableIncrementalMethodStubGenerationContext stub in data)
                {
                    // Blank separator line, written through the indented writer like all other output.
                    writer.WriteLine();

                    var (stubSyntax, _) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(stub, VtableIndexStubGeneratorHelpers.GetGeneratorResolver);

                    stub.ContainingSyntaxContext.WriteToWithUnsafeModifier(writer, stubSyntax, static (writer, stubSyntax) =>
                    {
                        writer.WriteLine("internal partial interface Native");
                        writer.WriteLine('{');
                        writer.Indent++;
                        writer.WriteMultilineNode(stubSyntax.NormalizeWhitespace());
                        writer.Indent--;
                        writer.WriteLine('}');
                    });
                }

                context.AddSource("NativeToManagedStubs.g.cs", sw.ToString());
            });

            IncrementalValueProvider<ImmutableArray<ContainingSyntaxContext>> syntaxContexts = generateStubInformation
                .Select(static (context, ct) => context.ContainingSyntaxContext)
                .Collect();

            context.RegisterSourceOutput(syntaxContexts, static (context, data) =>
            {
                if (data.IsEmpty)
                {
                    return;
                }

                using StringWriter sw = new();
                using IndentedTextWriter writer = new(sw);

                writer.WriteLine("// <auto-generated/>");

                // Generate the native interface metadata for each interface that contains a method with the [VirtualMethodIndex] attribute.
                // Several methods can share a containing type, so duplicates are collapsed.
                foreach (ContainingSyntaxContext syntaxContext in data.Distinct())
                {
                    writer.WriteLine();

                    syntaxContext.WriteToWithUnsafeModifier(writer, syntaxContext.ContainingSyntax[0].Identifier.Text, static (writer, baseTypeName) =>
                    {
                        writer.WriteLine("[global::System.Runtime.InteropServices.DynamicInterfaceCastableImplementationAttribute]");
                        writer.WriteLine($"internal partial interface Native : {baseTypeName} {{ }}");
                    });
                }

                context.AddSource("NativeInterfaces.g.cs", sw.ToString());
            });

            context.RegisterSourceOutput(nativeToManagedStubContexts.Collect(), static (context, data) =>
            {
                if (data.IsEmpty)
                {
                    return;
                }

                using StringWriter sw = new();
                using IndentedTextWriter writer = new(sw);

                writer.WriteLine("// <auto-generated/>");

                // One PopulateUnmanagedVirtualMethodTable method per containing type, covering all of its unmanaged-to-managed stubs.
                foreach (var group in data.GroupBy(static s => s.ContainingSyntaxContext))
                {
                    writer.WriteLine();

                    // Generate a method named PopulateUnmanagedVirtualMethodTable on the native interface implementation
                    // that fills in a span with the addresses of the unmanaged-to-managed stub functions at their correct indices.
                    group.Key.WriteToWithUnsafeModifier(writer, group, static (writer, data) =>
                    {
                        writer.WriteLine("internal unsafe partial interface Native");
                        writer.WriteLine('{');
                        writer.Indent++;
                        writer.WriteLine("internal static unsafe void PopulateUnmanagedVirtualMethodTable(void** vtable)");
                        writer.WriteLine('{');
                        writer.Indent++;

                        foreach (SourceAvailableIncrementalMethodStubGenerationContext method in data)
                        {
                            FunctionPointerTypeSyntax functionPointerType = VirtualMethodPointerStubGenerator.GenerateUnmanagedFunctionPointerTypeForMethod(method, VtableIndexStubGeneratorHelpers.GetGeneratorResolver);
                            writer.WriteLine($"vtable[{method.VtableIndexData.Index}] = (void*)({functionPointerType.NormalizeWhitespace()})&ABI_{method.StubMethodSyntaxTemplate.Identifier};");
                        }

                        writer.Indent--;
                        writer.WriteLine('}');
                        writer.Indent--;
                        writer.WriteLine('}');
                    });
                }

                context.AddSource("PopulateVTable.g.cs", sw.ToString());
            });
        }

        /// <summary>
        /// Reads the index constructor argument and the recognized named arguments of a <c>[VirtualMethodIndex]</c> attribute.
        /// </summary>
        /// <param name="attrData">The attribute instance applied to the method.</param>
        /// <returns>
        /// The parsed compilation data, or <see langword="null"/> when the attribute type failed to bind, the first
        /// constructor argument is missing or not an <see cref="int"/>, or a recognized named argument
        /// (Direction, ImplicitThisParameter, ExceptionMarshalling, ExceptionMarshallingCustomType) has an unexpected value type.
        /// </returns>
        private static VirtualMethodIndexCompilationData? ProcessVirtualMethodIndexAttribute(AttributeData attrData)
        {
            // The attribute type failed to bind (most likely an incorrect TFM), so nothing can be read reliably.
            // No diagnostic is reported here; callers only see a null result.
            if (attrData.AttributeClass?.TypeKind is null or TypeKind.Error)
            {
                return null;
            }

            var namedArguments = ImmutableDictionary.CreateRange(attrData.NamedArguments);

            if (attrData.ConstructorArguments.Length == 0 || attrData.ConstructorArguments[0].Value is not int index)
            {
                return null;
            }

            MarshalDirection direction = MarshalDirection.Bidirectional;
            bool implicitThis = true;
            bool exceptionMarshallingDefined = false;
            ExceptionMarshalling exceptionMarshalling = ExceptionMarshalling.Custom;
            INamedTypeSymbol? exceptionMarshallingCustomType = null;
            if (namedArguments.TryGetValue(nameof(VirtualMethodIndexCompilationData.Direction), out TypedConstant directionValue))
            {
                // TypedConstant's Value property only contains primitive values, so an enum argument arrives as its underlying int.
                if (directionValue.Value is not int directionInt)
                {
                    return null;
                }
                direction = (MarshalDirection)directionInt;
            }
            if (namedArguments.TryGetValue(nameof(VirtualMethodIndexCompilationData.ImplicitThisParameter), out TypedConstant implicitThisValue))
            {
                if (implicitThisValue.Value is not bool implicitThisBool)
                {
                    return null;
                }
                implicitThis = implicitThisBool;
            }
            if (namedArguments.TryGetValue(nameof(VirtualMethodIndexCompilationData.ExceptionMarshalling), out TypedConstant exceptionMarshallingValue))
            {
                exceptionMarshallingDefined = true;
                // TypedConstant's Value property only contains primitive values, so an enum argument arrives as its underlying int.
                if (exceptionMarshallingValue.Value is not int exceptionMarshallingInt)
                {
                    return null;
                }
                exceptionMarshalling = (ExceptionMarshalling)exceptionMarshallingInt;
            }
            if (namedArguments.TryGetValue(nameof(VirtualMethodIndexCompilationData.ExceptionMarshallingCustomType), out TypedConstant exceptionMarshallingCustomTypeValue))
            {
                if (exceptionMarshallingCustomTypeValue.Value is not INamedTypeSymbol customType)
                {
                    return null;
                }
                exceptionMarshallingCustomType = customType;
            }

            return new VirtualMethodIndexCompilationData(index).WithValuesFromNamedArguments(namedArguments) with
            {
                Direction = direction,
                ImplicitThisParameter = implicitThis,
                ExceptionMarshallingDefined = exceptionMarshallingDefined,
                ExceptionMarshalling = exceptionMarshalling,
                ExceptionMarshallingCustomType = exceptionMarshallingCustomType,
            };
        }

        /// <summary>
        /// Computes everything needed to emit the stubs for one <c>[VirtualMethodIndex]</c> method: the signature context,
        /// containing syntax, stub method template, calling conventions, vtable index data, exception marshalling, the
        /// object unwrapper type, and the diagnostics collected along the way.
        /// </summary>
        /// <param name="syntax">The method declaration carrying the attribute.</param>
        /// <param name="symbol">The method symbol for <paramref name="syntax"/>.</param>
        /// <param name="environment">The stub environment (compilation and environment flags).</param>
        /// <param name="ct">Cancellation token; checked once on entry.</param>
        /// <returns>The stub generation context for the method.</returns>
        internal static SourceAvailableIncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, StubEnvironment environment, CancellationToken ct)
        {
            ct.ThrowIfCancellationRequested();
            INamedTypeSymbol? lcidConversionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.LCIDConversionAttribute);
            INamedTypeSymbol? suppressGCTransitionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.SuppressGCTransitionAttribute);
            INamedTypeSymbol? unmanagedCallConvAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.UnmanagedCallConvAttribute);
            INamedTypeSymbol iUnmanagedInterfaceTypeType = environment.Compilation.GetTypeByMetadataName(TypeNames.IUnmanagedInterfaceType_Metadata)!;
            // Get any attributes of interest on the method
            AttributeData? virtualMethodIndexAttr = null;
            AttributeData? lcidConversionAttr = null;
            AttributeData? suppressGCTransitionAttribute = null;
            AttributeData? unmanagedCallConvAttribute = null;
            foreach (AttributeData attr in symbol.GetAttributes())
            {
                if (attr.AttributeClass is not null
                    && attr.AttributeClass.ToDisplayString() == TypeNames.VirtualMethodIndexAttribute)
                {
                    virtualMethodIndexAttr = attr;
                }
                else if (lcidConversionAttrType is not null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, lcidConversionAttrType))
                {
                    lcidConversionAttr = attr;
                }
                else if (suppressGCTransitionAttrType is not null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, suppressGCTransitionAttrType))
                {
                    suppressGCTransitionAttribute = attr;
                }
                else if (unmanagedCallConvAttrType is not null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, unmanagedCallConvAttrType))
                {
                    unmanagedCallConvAttribute = attr;
                }
            }

            // The pipeline only feeds methods found via [VirtualMethodIndex], so the attribute is expected to be present.
            Debug.Assert(virtualMethodIndexAttr is not null);

            var locations = new MethodSignatureDiagnosticLocations(syntax);
            var generatorDiagnostics = new GeneratorDiagnosticsBag(new DiagnosticDescriptorProvider(), locations, SR.ResourceManager, typeof(FxResources.Microsoft.Interop.ComInterfaceGenerator.SR));

            // Process the VirtualMethodIndex attribute
            VirtualMethodIndexCompilationData? virtualMethodIndexData = ProcessVirtualMethodIndexAttribute(virtualMethodIndexAttr!);

            if (virtualMethodIndexData is null)
            {
                // The attribute could not be read; continue with a placeholder index of -1.
                virtualMethodIndexData = new VirtualMethodIndexCompilationData(-1);
            }
            else if (virtualMethodIndexData.Index < 0)
            {
                // TODO: Report missing or invalid index
            }

            if (virtualMethodIndexData.IsUserDefined.HasFlag(InteropAttributeMember.StringMarshalling))
            {
                // User specified StringMarshalling.Custom without specifying StringMarshallingCustomType
                if (virtualMethodIndexData.StringMarshalling == StringMarshalling.Custom && virtualMethodIndexData.StringMarshallingCustomType is null)
                {
                    generatorDiagnostics.ReportInvalidStringMarshallingConfiguration(
                        virtualMethodIndexAttr, symbol.Name, SR.InvalidStringMarshallingConfigurationMissingCustomType);
                }

                // User specified something other than StringMarshalling.Custom while specifying StringMarshallingCustomType
                if (virtualMethodIndexData.StringMarshalling != StringMarshalling.Custom && virtualMethodIndexData.StringMarshallingCustomType is not null)
                {
                    generatorDiagnostics.ReportInvalidStringMarshallingConfiguration(
                        virtualMethodIndexAttr, symbol.Name, SR.InvalidStringMarshallingConfigurationNotCustom);
                }
            }

            if (!virtualMethodIndexData.ImplicitThisParameter && virtualMethodIndexData.Direction is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional)
            {
                // TODO: Report invalid configuration
            }

            if (lcidConversionAttr is not null)
            {
                // Using LCIDConversion with source-generated interop is not supported
                generatorDiagnostics.ReportConfigurationNotSupported(lcidConversionAttr, nameof(TypeNames.LCIDConversionAttribute));
            }

            // Create the stub.
            var signatureContext = SignatureContext.Create(
                symbol,
                DefaultMarshallingInfoParser.Create(environment, generatorDiagnostics, symbol, virtualMethodIndexData, virtualMethodIndexAttr),
                environment,
                new CodeEmitOptions(SkipInit: true),
                typeof(VtableIndexStubGenerator).Assembly);

            var containingSyntaxContext = new ContainingSyntaxContext(syntax);

            // Template for the generated stub: the original modifiers minus 'partial', 'virtual' and accessibility modifiers.
            var methodSyntaxTemplate = new ContainingSyntax(new SyntaxTokenList(syntax.Modifiers.Where(static m => !m.IsKind(SyntaxKind.PartialKeyword) && !m.IsKind(SyntaxKind.VirtualKeyword))).StripAccessibilityModifiers(), SyntaxKind.MethodDeclaration, syntax.Identifier, syntax.TypeParameterList);

            ImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> callConv = VirtualMethodPointerStubGenerator.GenerateCallConvSyntaxFromAttributes(suppressGCTransitionAttribute, unmanagedCallConvAttribute, defaultCallingConventions: ImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax>.Empty);

            var interfaceType = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol.ContainingType);

            INamedTypeSymbol expectedUnmanagedInterfaceType = iUnmanagedInterfaceTypeType;

            bool implementsIUnmanagedInterfaceOfSelf = symbol.ContainingType.AllInterfaces.Any(iface => SymbolEqualityComparer.Default.Equals(iface, expectedUnmanagedInterfaceType));
            if (!implementsIUnmanagedInterfaceOfSelf)
            {
                // TODO: Report invalid configuration
            }

            AttributeData? unmanagedObjectUnwrapper = symbol.ContainingType.GetAttributes().FirstOrDefault(static att => att.AttributeClass.IsOfType(TypeNames.UnmanagedObjectUnwrapperAttribute));
            // NOTE(review): a missing [UnmanagedObjectUnwrapper] attribute is not handled here; the dereference below assumes the
            // analyzer-based filter in Initialize already excluded such methods — TODO confirm, or report a diagnostic instead.
            Debug.Assert(unmanagedObjectUnwrapper is not null);
            var unwrapperSyntax = ParseTypeName(unmanagedObjectUnwrapper!.AttributeClass!.TypeArguments[0].ToDisplayString());

            MarshallingInfo exceptionMarshallingInfo = CreateExceptionMarshallingInfo(virtualMethodIndexAttr, symbol, environment.Compilation, generatorDiagnostics, virtualMethodIndexData);

            return new SourceAvailableIncrementalMethodStubGenerationContext(
                signatureContext,
                containingSyntaxContext,
                methodSyntaxTemplate,
                locations,
                new SequenceEqualImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax>(callConv, SyntaxEquivalentComparer.Instance),
                VirtualMethodIndexData.From(virtualMethodIndexData),
                exceptionMarshallingInfo,
                environment.EnvironmentFlags,
                interfaceType,
                interfaceType,
                new SequenceEqualImmutableArray<DiagnosticInfo>(generatorDiagnostics.Diagnostics.ToImmutableArray()),
                new ObjectUnwrapperInfo(unwrapperSyntax));
        }

        /// <summary>
        /// Chooses the exception marshalling info for a method from its <c>[VirtualMethodIndex]</c> settings, reporting
        /// diagnostics for inconsistent ExceptionMarshalling / ExceptionMarshallingCustomType combinations.
        /// </summary>
        /// <param name="virtualMethodIndexAttr">The attribute instance, used as the diagnostic location.</param>
        /// <param name="symbol">The method whose name is included in diagnostics.</param>
        /// <param name="compilation">Compilation used to resolve <c>System.Exception</c> for custom marshalling.</param>
        /// <param name="diagnostics">Bag that receives any configuration diagnostics.</param>
        /// <param name="virtualMethodIndexData">The parsed attribute data.</param>
        /// <returns>
        /// A <c>ComExceptionMarshalling</c> for <c>ExceptionMarshalling.Com</c>; custom marshalling info when
        /// <c>ExceptionMarshalling.Custom</c> has a custom type; otherwise <c>NoMarshallingInfo.Instance</c>.
        /// </returns>
        private static MarshallingInfo CreateExceptionMarshallingInfo(AttributeData virtualMethodIndexAttr, ISymbol symbol, Compilation compilation, GeneratorDiagnosticsBag diagnostics, VirtualMethodIndexCompilationData virtualMethodIndexData)
        {
            if (virtualMethodIndexData.ExceptionMarshallingDefined)
            {
                // User specified ExceptionMarshalling.Custom without specifying ExceptionMarshallingCustomType
                if (virtualMethodIndexData.ExceptionMarshalling == ExceptionMarshalling.Custom && virtualMethodIndexData.ExceptionMarshallingCustomType is null)
                {
                    diagnostics.ReportInvalidExceptionMarshallingConfiguration(
                        virtualMethodIndexAttr, symbol.Name, SR.InvalidExceptionMarshallingConfigurationMissingCustomType);
                    return NoMarshallingInfo.Instance;
                }

                // User specified something other than ExceptionMarshalling.Custom while specifying ExceptionMarshallingCustomType
                if (virtualMethodIndexData.ExceptionMarshalling != ExceptionMarshalling.Custom && virtualMethodIndexData.ExceptionMarshallingCustomType is not null)
                {
                    diagnostics.ReportInvalidExceptionMarshallingConfiguration(
                        virtualMethodIndexAttr, symbol.Name, SR.InvalidExceptionMarshallingConfigurationNotCustom);
                }
            }

            if (virtualMethodIndexData.ExceptionMarshalling == ExceptionMarshalling.Com)
            {
                return new ComExceptionMarshalling();
            }
            if (virtualMethodIndexData.ExceptionMarshalling == ExceptionMarshalling.Custom)
            {
                return virtualMethodIndexData.ExceptionMarshallingCustomType is null
                    ? NoMarshallingInfo.Instance
                    : CustomMarshallingInfoHelper.CreateNativeMarshallingInfoForNonSignatureElement(
                        compilation.GetTypeByMetadataName(TypeNames.System_Exception),
                        virtualMethodIndexData.ExceptionMarshallingCustomType!,
                        virtualMethodIndexAttr,
                        compilation,
                        diagnostics);
            }
            // This should not be reached in normal usage, but a developer can cast any int to the ExceptionMarshalling enum, so we should handle this case without crashing the generator.
            diagnostics.ReportInvalidExceptionMarshallingConfiguration(
                virtualMethodIndexAttr, symbol.Name, SR.InvalidExceptionMarshallingValue);
            return NoMarshallingInfo.Instance;
        }
    }
}