// File: ComClassGenerator.cs
// Web Access
// Project: src\src\libraries\System.Runtime.InteropServices\gen\ComInterfaceGenerator\ComInterfaceGenerator.csproj (Microsoft.Interop.ComInterfaceGenerator)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.CodeDom.Compiler;
using System.IO;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.DotnetRuntime.Extensions;
using Microsoft.Interop.Analyzers;
 
namespace Microsoft.Interop
{
    /// <summary>
    /// Incremental generator that looks for classes carrying the attribute named by
    /// <c>TypeNames.GeneratedComClassAttribute</c>. For each such class that produces no
    /// analyzer diagnostics it emits one source file containing a file-local
    /// <c>ComClassInformation</c> type (implementing <c>IComExposedClass</c>) followed by a
    /// re-declaration of the class annotated with <c>ComExposedClassAttribute</c>.
    /// </summary>
    [Generator]
    public class ComClassGenerator : IIncrementalGenerator
    {
        // Name of the generated file-local type that exposes the COM interface entries.
        private const string ClassInfoTypeName = "ComClassInformation";

        /// <summary>
        /// Sets up the generator pipeline: first collect a model for every attributed class,
        /// then register the callback that turns each model into generated source.
        /// </summary>
        /// <param name="context">Initialization context supplied by the compiler host.</param>
        public void Initialize(IncrementalGeneratorInitializationContext context)
        {
            // Stage 1: one ComClassInfo per class declaration that has the attribute.
            IncrementalValuesProvider<ComClassInfo> comClasses = context.SyntaxProvider
                .ForAttributeWithMetadataName(
                    TypeNames.GeneratedComClassAttribute,
                    static (node, _) => node is ClassDeclarationSyntax,
                    static (attributeContext, _) =>
                    {
                        INamedTypeSymbol classSymbol = (INamedTypeSymbol)attributeContext.TargetSymbol;
                        ClassDeclarationSyntax classDeclaration = (ClassDeclarationSyntax)attributeContext.TargetNode;
                        Compilation compilation = attributeContext.SemanticModel.Compilation;
                        bool allowsUnsafe = compilation.Options is CSharpCompilationOptions csharpOptions && csharpOptions.AllowUnsafe;
                        INamedTypeSymbol? comInterfaceAttribute = compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComInterfaceAttribute);

                        // Every diagnostic the analyzer reports is currently treated as fatal,
                        // so a model is only produced when the analyzer reports nothing.
                        bool hasDiagnostics = ComClassGeneratorDiagnosticsAnalyzer
                            .GetDiagnosticsForAnnotatedClass(classSymbol, allowsUnsafe, comInterfaceAttribute)
                            .Any();
                        if (!hasDiagnostics)
                        {
                            return ComClassInfo.From(classSymbol, classDeclaration, comInterfaceAttribute);
                        }

                        return null;
                    })
                .Where(static info => info is not null);

            // Stage 2: render each model into a generated source file.
            context.RegisterSourceOutput(comClasses, static (productionContext, classInfo) =>
            {
                string name = classInfo.ClassName;
                SequenceEqualImmutableArray<string> interfaceNames = classInfo.ImplementedInterfacesNames;

                using StringWriter buffer = new();
                using IndentedTextWriter writer = new(buffer);

                writer.WriteLine("// <auto-generated />");
                WriteClassInformationType(buffer, writer, interfaceNames);

                // Blank separator line, written directly to the underlying buffer.
                buffer.WriteLine();

                classInfo.ContainingSyntaxContext.WriteToWithUnsafeModifier(writer, classInfo.ClassSyntax, static (output, declaration) =>
                {
                    string modifiers = string.Join(" ", declaration.Modifiers);
                    output.WriteLine($"[global::System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute<{ClassInfoTypeName}>]");
                    output.WriteLine($"{modifiers} class {declaration.Identifier}{declaration.TypeParameters} {{ }}");
                });

                // Angle brackets are swapped for braces so generic type names form valid hint names.
                productionContext.AddSource(name.Replace('<', '{').Replace('>', '}'), buffer.ToString());
            });
        }

        /// <summary>
        /// Emits the file-local class-information type whose <c>GetComInterfaceEntries</c>
        /// method fills and returns a table with one entry per implemented interface.
        /// </summary>
        /// <param name="rawOutput">
        /// The writer underlying <paramref name="writer"/>. Blank lines are written here
        /// directly instead of through the indenting writer (presumably so they carry no
        /// indentation; confirm before changing).
        /// </param>
        /// <param name="writer">Indenting writer that receives all non-blank lines.</param>
        /// <param name="interfaceNames">Names of the interfaces to create entries for, in order.</param>
        private static void WriteClassInformationType(StringWriter rawOutput, IndentedTextWriter writer, SequenceEqualImmutableArray<string> interfaceNames)
        {
            int interfaceCount = interfaceNames.Length;

            writer.WriteLine($"file sealed unsafe class {ClassInfoTypeName} : global::System.Runtime.InteropServices.Marshalling.IComExposedClass");
            BeginBlock(writer);

            writer.WriteLine("private static volatile global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* s_vtables;");
            rawOutput.WriteLine();

            writer.WriteLine("public static global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count)");
            BeginBlock(writer);

            writer.WriteLine($"count = {interfaceCount};");
            writer.WriteLine("if (s_vtables == null)");
            BeginBlock(writer);

            writer.WriteLine($"global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* vtables = (global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry*)global::System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof({ClassInfoTypeName}), sizeof(global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry) * {interfaceCount});");
            writer.WriteLine("global::System.Runtime.InteropServices.Marshalling.IIUnknownDerivedDetails details;");
            rawOutput.WriteLine();

            // One lookup + table assignment per interface, each followed by a blank line.
            for (int slot = 0; slot < interfaceCount; slot++)
            {
                string interfaceName = interfaceNames[slot];

                writer.WriteLine($"details = global::System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof({interfaceName}).TypeHandle);");
                writer.WriteLine($"vtables[{slot}] = new() {{ IID = details.Iid, Vtable = (nint)details.ManagedVirtualMethodTable }};");
                rawOutput.WriteLine();
            }

            writer.WriteLine("s_vtables = vtables;");
            EndBlock(writer); // if (s_vtables == null)

            rawOutput.WriteLine();
            writer.WriteLine("return s_vtables;");
            EndBlock(writer); // GetComInterfaceEntries
            EndBlock(writer); // class
        }

        /// <summary>Writes an opening brace on its own line, then increases the indentation by one level.</summary>
        private static void BeginBlock(IndentedTextWriter writer)
        {
            writer.WriteLine('{');
            writer.Indent++;
        }

        /// <summary>Decreases the indentation by one level, then writes a closing brace on its own line.</summary>
        private static void EndBlock(IndentedTextWriter writer)
        {
            writer.Indent--;
            writer.WriteLine('}');
        }
    }
}