|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.CodeDom.Compiler;
using System.IO;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.DotnetRuntime.Extensions;
using Microsoft.Interop.Analyzers;
namespace Microsoft.Interop
{
/// <summary>
/// Incremental source generator for classes annotated with <c>[GeneratedComClass]</c>.
/// </summary>
/// <remarks>
/// For each annotated class that produces no diagnostics from
/// <see cref="ComClassGeneratorDiagnosticsAnalyzer"/>, one source file is added containing:
/// <list type="bullet">
/// <item>a file-local <c>ComClassInformation</c> type implementing <c>IComExposedClass</c>, whose
/// <c>GetComInterfaceEntries</c> builds and caches one <c>ComInterfaceEntry</c> per implemented interface; and</item>
/// <item>a re-declaration of the annotated class, using its original modifiers and type parameters, that applies
/// <c>ComExposedClassAttribute&lt;ComClassInformation&gt;</c>.</item>
/// </list>
/// </remarks>
[Generator]
public class ComClassGenerator : IIncrementalGenerator
{
    // Name of the helper type emitted into every generated file. The emitted type is declared `file`,
    // so reusing the same name in each generated file does not cause conflicts.
    private const string ClassInfoTypeName = "ComClassInformation";

    /// <summary>
    /// Registers the pipeline that finds <c>[GeneratedComClass]</c> classes and emits source for each one.
    /// </summary>
    /// <param name="context">The generator initialization context supplied by the compiler.</param>
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        // Get all types with the [GeneratedComClassAttribute] attribute.
        IncrementalValuesProvider<ComClassInfo> attributedClasses = context.SyntaxProvider
            .ForAttributeWithMetadataName(
                TypeNames.GeneratedComClassAttribute,
                static (node, _) => node is ClassDeclarationSyntax,
                static (context, _) =>
                {
                    var type = (INamedTypeSymbol)context.TargetSymbol;
                    var syntax = (ClassDeclarationSyntax)context.TargetNode;
                    Compilation compilation = context.SemanticModel.Compilation;
                    bool unsafeCodeIsEnabled = compilation.Options is CSharpCompilationOptions { AllowUnsafe: true };
                    INamedTypeSymbol? generatedComInterfaceAttributeType = compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComInterfaceAttribute);

                    // Currently all reported diagnostics are fatal to the generator
                    if (ComClassGeneratorDiagnosticsAnalyzer.GetDiagnosticsForAnnotatedClass(type, unsafeCodeIsEnabled, generatedComInterfaceAttributeType).Any())
                    {
                        return null;
                    }

                    return ComClassInfo.From(type, syntax, generatedComInterfaceAttributeType);
                })
            // The transform above returns null for classes with diagnostics, so the provider's element type is
            // inferred as ComClassInfo?. Where drops those nulls at runtime but cannot narrow the static type;
            // the null-forgiving operator records that every remaining element is non-null.
            .Where(static info => info is not null)!;

        context.RegisterSourceOutput(attributedClasses, static (context, data) =>
            context.AddSource(GetHintName(data.ClassName), GenerateSource(data)));
    }

    /// <summary>
    /// Produces the full text of the generated source file for a single annotated class.
    /// </summary>
    /// <param name="data">Information collected about the annotated class.</param>
    /// <returns>The generated C# source text.</returns>
    private static string GenerateSource(ComClassInfo data)
    {
        using StringWriter sw = new();
        using IndentedTextWriter writer = new(sw);

        writer.WriteLine("// <auto-generated />");
        WriteClassInformationType(writer, data.ImplementedInterfacesNames);
        writer.WriteLineNoTabs(string.Empty);

        // Re-declare the annotated class inside its containing types/namespace and attach the attribute that
        // links it to the file-local information type written above.
        data.ContainingSyntaxContext.WriteToWithUnsafeModifier(writer, data.ClassSyntax, static (writer, classSyntax) =>
        {
            writer.WriteLine($"[global::System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute<{ClassInfoTypeName}>]");
            writer.WriteLine($"{string.Join(" ", classSyntax.Modifiers)} class {classSyntax.Identifier}{classSyntax.TypeParameters} {{ }}");
        });

        return sw.ToString();
    }

    /// <summary>
    /// Writes the file-local <c>ComClassInformation</c> type. Its <c>GetComInterfaceEntries</c> method reports one
    /// entry per interface in <paramref name="implementedInterfaces"/>, building the table on first use and caching it.
    /// </summary>
    /// <param name="writer">Writer positioned at the start of a line in the generated file.</param>
    /// <param name="implementedInterfaces">Type names, each emitted inside a <c>typeof(...)</c> expression.</param>
    private static void WriteClassInformationType(IndentedTextWriter writer, SequenceEqualImmutableArray<string> implementedInterfaces)
    {
        writer.WriteLine($"file sealed unsafe class {ClassInfoTypeName} : global::System.Runtime.InteropServices.Marshalling.IComExposedClass");
        writer.WriteLine('{');
        writer.Indent++;
        writer.WriteLine("private static volatile global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* s_vtables;");
        writer.WriteLineNoTabs(string.Empty);

        writer.WriteLine("public static global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count)");
        writer.WriteLine('{');
        writer.Indent++;
        writer.WriteLine($"count = {implementedInterfaces.Length};");
        writer.WriteLine("if (s_vtables == null)");
        writer.WriteLine('{');
        writer.Indent++;
        writer.WriteLine($"global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* vtables = (global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry*)global::System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof({ClassInfoTypeName}), sizeof(global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry) * {implementedInterfaces.Length});");
        writer.WriteLine("global::System.Runtime.InteropServices.Marshalling.IIUnknownDerivedDetails details;");
        writer.WriteLineNoTabs(string.Empty);

        for (int i = 0; i < implementedInterfaces.Length; i++)
        {
            string ifaceName = implementedInterfaces[i];
            writer.WriteLine($"details = global::System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof({ifaceName}).TypeHandle);");
            writer.WriteLine($"vtables[{i}] = new() {{ IID = details.Iid, Vtable = (nint)details.ManagedVirtualMethodTable }};");
            writer.WriteLineNoTabs(string.Empty);
        }

        // The emitted code fills the local table completely before the single store to the volatile field.
        // Readers therefore never observe a partially populated table, although concurrent first calls may
        // each allocate one.
        writer.WriteLine("s_vtables = vtables;");
        writer.Indent--;
        writer.WriteLine('}');
        writer.WriteLineNoTabs(string.Empty);

        writer.WriteLine("return s_vtables;");
        writer.Indent--;
        writer.WriteLine('}');
        writer.Indent--;
        writer.WriteLine('}');
    }

    /// <summary>
    /// Converts a class name into a valid hint name for <c>AddSource</c>. The angle brackets of generic types are
    /// replaced with <c>{</c> and <c>}</c>.
    /// </summary>
    /// <param name="className">The annotated class's name as recorded in <see cref="ComClassInfo"/>.</param>
    /// <returns>The hint name for the generated file.</returns>
    private static string GetHintName(string className) => className.Replace('<', '{').Replace('>', '}');
}
}
|