File: ComClassGenerator.cs
Web Access
Project: src\src\libraries\System.Runtime.InteropServices\gen\ComInterfaceGenerator\ComInterfaceGenerator.csproj (Microsoft.Interop.ComInterfaceGenerator)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Generic;
using System.Collections.Immutable;
using System.IO;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using static Microsoft.Interop.SyntaxFactoryExtensions;
 
namespace Microsoft.Interop
{
    /// <summary>
    /// Incremental source generator for class declarations marked with <c>GeneratedComClassAttribute</c>.
    /// For every such class that passes validation it emits one source file containing:
    /// <list type="bullet">
    /// <item>a file-local <c>ComClassInformation</c> type that lazily builds the class's
    /// <c>ComWrappers.ComInterfaceEntry</c> table, and</item>
    /// <item>a re-declaration of the user class (same kind, identifier, modifiers and type parameters)
    /// that applies <c>ComExposedClassAttribute&lt;ComClassInformation&gt;</c>.</item>
    /// </list>
    /// </summary>
    [Generator]
    public class ComClassGenerator : IIncrementalGenerator
    {
        /// <summary>
        /// Builds the generator pipeline. It finds class declarations carrying <c>GeneratedComClassAttribute</c>
        /// and validates them through <c>ComClassInfo.From</c>, taking into account whether the compilation
        /// allows unsafe code. Any diagnostics produced along the way are reported. For each valid class, the
        /// generated info type and attribute re-declaration are emitted under a hint name equal to the class name.
        /// </summary>
        public void Initialize(IncrementalGeneratorInitializationContext context)
        {
            var unsafeCodeIsEnabled = context.CompilationProvider.Select((comp, ct) => comp.Options is CSharpCompilationOptions { AllowUnsafe: true }); // Unsafe code enabled
            // Get all types with the [GeneratedComClassAttribute] attribute.
            var attributedClassesOrDiagnostics = context.SyntaxProvider
                .ForAttributeWithMetadataName(
                    TypeNames.GeneratedComClassAttribute,
                    static (node, ct) => node is ClassDeclarationSyntax,
                    static (context, ct) => context)
                .Combine(unsafeCodeIsEnabled)
                .Select(static (data, ct) =>
                    {
                        var context = data.Left;
                        var unsafeCodeIsEnabled = data.Right;
                        var type = (INamedTypeSymbol)context.TargetSymbol;
                        var syntax = (ClassDeclarationSyntax)context.TargetNode;
                        return ComClassInfo.From(type, syntax, unsafeCodeIsEnabled);
                    });

            var attributedClasses = context.FilterAndReportDiagnostics(attributedClassesOrDiagnostics);

            var className = attributedClasses.Select(static (info, ct) => info.ClassName);

            // The info type only depends on the implemented interface list, so project exactly that.
            // The cached step then does not carry unrelated data such as the class name.
            var classInfoType = attributedClasses
                .Select(static (info, ct) => info.ImplementedInterfacesNames)
                .Select(static (interfaceNames, ct) => GenerateClassInfoType(interfaceNames.Array).NormalizeWhitespace());

            var attribute = attributedClasses
                .Select(static (info, ct) => new { info.ContainingSyntaxContext, info.ClassSyntax })
                .Select(static (info, ct) => GenerateClassInfoAttributeOnUserType(info.ContainingSyntaxContext, info.ClassSyntax).NormalizeWhitespace());

            context.RegisterSourceOutput(className.Zip(classInfoType).Zip(attribute), static (context, classInfo) =>
            {
                var ((className, classInfoType), attribute) = classInfo;
                StringWriter writer = new();
                writer.WriteLine("// <auto-generated />");
                writer.WriteLine(classInfoType.ToFullString());
                writer.WriteLine();
                writer.WriteLine(attribute);
                context.AddSource(className, writer.ToString());
            });
        }

        /// <summary>Name of the file-local type generated next to each COM class.</summary>
        private const string ClassInfoTypeName = "ComClassInformation";

        /// <summary>
        /// The <c>[global::...ComExposedClassAttribute&lt;ComClassInformation&gt;]</c> attribute, built once and reused
        /// for every generated class.
        /// </summary>
        private static readonly AttributeSyntax s_comExposedClassAttributeTemplate =
            Attribute(
                GenericName(TypeNames.GlobalAlias + TypeNames.ComExposedClassAttribute)
                    .AddTypeArgumentListArguments(
                        IdentifierName(ClassInfoTypeName)));

        /// <summary>
        /// Builds an empty re-declaration of the user's class that carries
        /// <c>[ComExposedClass&lt;ComClassInformation&gt;]</c>. The re-declaration reuses the original type kind,
        /// identifier, modifiers and type parameter list. It is then wrapped in the class's containing declarations
        /// via <see cref="ContainingSyntaxContext"/>.
        /// </summary>
        /// <param name="containingSyntaxContext">The declarations (namespaces/types) enclosing the user class.</param>
        /// <param name="classSyntax">Kind, identifier, modifiers and type parameters of the user class.</param>
        private static MemberDeclarationSyntax GenerateClassInfoAttributeOnUserType(ContainingSyntaxContext containingSyntaxContext, ContainingSyntax classSyntax) =>
            containingSyntaxContext.WrapMemberInContainingSyntaxWithUnsafeModifier(
                TypeDeclaration(classSyntax.TypeKind, classSyntax.Identifier)
                    .WithModifiers(classSyntax.Modifiers)
                    .WithTypeParameterList(classSyntax.TypeParameters)
                    .AddAttributeLists(AttributeList(SingletonSeparatedList(s_comExposedClassAttributeTemplate))));

        /// <summary>
        /// Builds the file-local <c>ComClassInformation</c> type. Its shape is roughly:
        /// <code>
        /// file sealed unsafe class ComClassInformation : IComExposedClass
        /// {
        ///     private static volatile ComWrappers.ComInterfaceEntry* s_vtables;
        ///     public static ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count)
        ///     {
        ///         count = N;
        ///         if (s_vtables == null) { /* allocate and fill N entries */ s_vtables = vtables; }
        ///         return s_vtables;
        ///     }
        /// }
        /// </code>
        /// </summary>
        /// <param name="implementedInterfaces">
        /// Names of the interfaces to expose, in table order. Each name is parsed with <c>ParseName</c> and used inside
        /// <c>typeof(...)</c>. It presumably needs to be fully qualified; confirm against <c>ComClassInfo</c>.
        /// </param>
        private static ClassDeclarationSyntax GenerateClassInfoType(ImmutableArray<string> implementedInterfaces)
        {
            const string vtablesField = "s_vtables";
            const string vtablesLocal = "vtables";
            const string detailsTempLocal = "details";
            const string countIdentifier = "count";
            var typeDeclaration = ClassDeclaration(ClassInfoTypeName)
                .AddModifiers(
                    Token(SyntaxKind.FileKeyword),
                    Token(SyntaxKind.SealedKeyword),
                    Token(SyntaxKind.UnsafeKeyword))
                .AddBaseListTypes(SimpleBaseType(TypeSyntaxes.IComExposedClass))
                .AddMembers(
                    // private static volatile ComWrappers.ComInterfaceEntry* s_vtables;
                    FieldDeclaration(
                        VariableDeclaration(
                            PointerType(TypeSyntaxes.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry),
                            SingletonSeparatedList(VariableDeclarator(vtablesField))))
                    .AddModifiers(
                        Token(SyntaxKind.PrivateKeyword),
                        Token(SyntaxKind.StaticKeyword),
                        Token(SyntaxKind.VolatileKeyword)));
            List<StatementSyntax> vtableInitializationBlock = new()
            {
                // ComInterfaceEntry* vtables = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(<ClassInfoTypeName>), sizeof(ComInterfaceEntry) * <numInterfaces>);
                Declare(
                    PointerType(TypeSyntaxes.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry),
                    vtablesLocal,
                        CastExpression(
                            PointerType(TypeSyntaxes.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry),
                            MethodInvocation(
                                TypeSyntaxes.System_Runtime_CompilerServices_RuntimeHelpers,
                                IdentifierName("AllocateTypeAssociatedMemory"),
                                Argument(TypeOfExpression(IdentifierName(ClassInfoTypeName))),
                                Argument(
                                    BinaryExpression(
                                        SyntaxKind.MultiplyExpression,
                                        SizeOfExpression(TypeSyntaxes.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry),
                                        LiteralExpression(
                                            SyntaxKind.NumericLiteralExpression,
                                            Literal(implementedInterfaces.Length))))))),

                // IIUnknownDerivedDetails details;
                Declare(TypeSyntaxes.IIUnknownDerivedDetails, detailsTempLocal, initializeToDefault: false)
            };
            for (int i = 0; i < implementedInterfaces.Length; i++)
            {
                string ifaceName = implementedInterfaces[i];

                // details = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(<ifaceName>).TypeHandle);
                vtableInitializationBlock.Add(
                    AssignmentStatement(
                        IdentifierName(detailsTempLocal),
                        MethodInvocation(
                                TypeSyntaxes.StrategyBasedComWrappers
                                    .Dot(IdentifierName("DefaultIUnknownInterfaceDetailsStrategy")),
                            IdentifierName("GetIUnknownDerivedDetails"),
                            Argument(
                                MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
                                    TypeOfExpression(ParseName(ifaceName)),
                                    IdentifierName("TypeHandle"))))));
                // vtables[<i>] = new() { IID = details.Iid, Vtable = (nint)details.ManagedVirtualMethodTable };
                vtableInitializationBlock.Add(
                    AssignmentStatement(
                        IndexExpression(
                            IdentifierName(vtablesLocal),
                            Argument(IntLiteral(i))),
                        ImplicitObjectCreationExpression(
                            ArgumentList(),
                            InitializerExpression(SyntaxKind.ObjectInitializerExpression,
                                SeparatedList(
                                    new ExpressionSyntax[]
                                    {
                                        AssignmentExpression(
                                            SyntaxKind.SimpleAssignmentExpression,
                                            IdentifierName("IID"),
                                            IdentifierName(detailsTempLocal)
                                                .Dot(IdentifierName("Iid"))),
                                        AssignmentExpression(
                                            SyntaxKind.SimpleAssignmentExpression,
                                            IdentifierName("Vtable"),
                                            CastExpression(
                                                IdentifierName("nint"),
                                                    IdentifierName(detailsTempLocal)
                                                    .Dot(IdentifierName("ManagedVirtualMethodTable"))))
                                    })))));
            }

            // s_vtables = vtables;
            // Publish only after the table is fully populated. No lock is generated, so concurrent first calls
            // may each build a table; the volatile field holds whichever write lands last.
            vtableInitializationBlock.Add(
                ExpressionStatement(
                    AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
                        IdentifierName(vtablesField),
                        IdentifierName(vtablesLocal))));

            BlockSyntax getComInterfaceEntriesMethodBody = Block(
                // count = <count>;
                ExpressionStatement(
                    AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
                        IdentifierName(countIdentifier),
                        LiteralExpression(SyntaxKind.NumericLiteralExpression,
                            Literal(implementedInterfaces.Length)))),
                // if (s_vtables == null)
                //   { initializer block }
                IfStatement(
                    BinaryExpression(SyntaxKind.EqualsExpression,
                        IdentifierName(vtablesField),
                        LiteralExpression(SyntaxKind.NullLiteralExpression)),
                    Block(vtableInitializationBlock)),
                // return s_vtables;
                ReturnStatement(IdentifierName(vtablesField)));

            typeDeclaration = typeDeclaration.AddMembers(
                // public static ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count)
                // { body }
                // (No 'unsafe' modifier is needed on the method: the containing type is declared 'unsafe'.)
                MethodDeclaration(
                    PointerType(
                        TypeSyntaxes.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry),
                    "GetComInterfaceEntries")
                    .AddParameterListParameters(
                        Parameter(Identifier(countIdentifier))
                            .WithType(PredefinedType(Token(SyntaxKind.IntKeyword)))
                            .AddModifiers(Token(SyntaxKind.OutKeyword)))
                    .WithBody(getComInterfaceEntriesMethodBody)
                    .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)));

            return typeDeclaration;
        }
    }
}