|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.CodeDom.Compiler;
using System.Collections.Immutable;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Threading;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.Interop.Analyzers;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
[assembly: System.Resources.NeutralResourcesLanguage("en-US")]
namespace Microsoft.Interop
{
[Generator]
public sealed class VtableIndexStubGenerator : IIncrementalGenerator
{
    /// <summary>
    /// Builds the incremental pipeline for methods annotated with [VirtualMethodIndex] and registers the generated
    /// sources: managed-to-unmanaged stubs, unmanaged-to-managed stubs, the nested <c>Native</c> interface
    /// declarations, and the <c>PopulateUnmanagedVirtualMethodTable</c> helpers.
    /// </summary>
    /// <param name="context">The generator initialization context supplied by the compiler.</param>
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        // Get all methods with the [VirtualMethodIndex] attribute.
        var attributedMethods = context.SyntaxProvider
            .ForAttributeWithMetadataName(
                TypeNames.VirtualMethodIndexAttribute,
                static (node, ct) => node is MethodDeclarationSyntax,
                static (context, ct) => context.TargetSymbol is IMethodSymbol methodSymbol
                    ? new { Syntax = (MethodDeclarationSyntax)context.TargetNode, Symbol = methodSymbol }
                    : null)
            .Where(
                static modelData => modelData is not null);

        // Filter out methods that are invalid for generation (diagnostics for invalid methods are reported by the analyzer).
        var methodsToGenerate = attributedMethods.Where(
            static data => data is not null && VtableIndexStubDiagnosticsAnalyzer.GetDiagnosticIfInvalidMethodForGeneration(data.Syntax, data.Symbol) is null);

        // Calculate all of the information needed to generate both managed-to-unmanaged and unmanaged-to-managed stubs
        // for each method.
        IncrementalValuesProvider<SourceAvailableIncrementalMethodStubGenerationContext> generateStubInformation = methodsToGenerate
            .Combine(context.CreateStubEnvironmentProvider())
            .Select(static (data, ct) => CalculateStubInformation(data.Left.Syntax, data.Left.Symbol, data.Right, ct));

        // Filter the list of all stubs to only the stubs that requested managed-to-unmanaged stub generation.
        IncrementalValuesProvider<SourceAvailableIncrementalMethodStubGenerationContext> managedToNativeStubContexts =
            generateStubInformation
            .Where(static data => data.VtableIndexData.Direction is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional);

        // Generate the code for the managed-to-unmanaged stubs.
        RegisterStubSourceOutput(
            context,
            managedToNativeStubContexts,
            "ManagedToNativeStubs.g.cs",
            static stub =>
            {
                var (stubSyntax, _) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(stub, VtableIndexStubGeneratorHelpers.GetGeneratorResolver);
                return stubSyntax;
            });

        // Filter the list of all stubs to only the stubs that requested unmanaged-to-managed stub generation.
        IncrementalValuesProvider<SourceAvailableIncrementalMethodStubGenerationContext> nativeToManagedStubContexts =
            generateStubInformation
            .Where(static data => data.VtableIndexData.Direction is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional);

        // Generate the code for the unmanaged-to-managed stubs.
        RegisterStubSourceOutput(
            context,
            nativeToManagedStubContexts,
            "NativeToManagedStubs.g.cs",
            static stub =>
            {
                var (stubSyntax, _) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(stub, VtableIndexStubGeneratorHelpers.GetGeneratorResolver);
                return stubSyntax;
            });

        IncrementalValueProvider<ImmutableArray<ContainingSyntaxContext>> syntaxContexts = generateStubInformation
            .Select(static (stub, ct) => stub.ContainingSyntaxContext)
            .Collect();
        context.RegisterSourceOutput(syntaxContexts, static (productionContext, data) =>
        {
            if (data.IsEmpty)
            {
                return;
            }

            using StringWriter sw = new();
            using IndentedTextWriter writer = new(sw);
            writer.WriteLine("// <auto-generated/>");
            // Generate the native interface metadata for each interface that contains a method with the [VirtualMethodIndex] attribute.
            // Distinct() ensures a containing type that declares several such methods is emitted only once.
            foreach (ContainingSyntaxContext syntaxContext in data.Distinct())
            {
                sw.WriteLine();
                syntaxContext.WriteToWithUnsafeModifier(writer, syntaxContext.ContainingSyntax[0].Identifier.Text, static (innerWriter, baseTypeName) =>
                {
                    innerWriter.WriteLine("[global::System.Runtime.InteropServices.DynamicInterfaceCastableImplementationAttribute]");
                    innerWriter.WriteLine($"internal partial interface Native : {baseTypeName} {{ }}");
                });
            }

            productionContext.AddSource("NativeInterfaces.g.cs", sw.ToString());
        });

        context.RegisterSourceOutput(nativeToManagedStubContexts.Collect(), static (productionContext, data) =>
        {
            if (data.IsEmpty)
            {
                return;
            }

            using StringWriter sw = new();
            using IndentedTextWriter writer = new(sw);
            writer.WriteLine("// <auto-generated/>");
            foreach (var group in data.GroupBy(static s => s.ContainingSyntaxContext))
            {
                sw.WriteLine();
                // Generate a method named PopulateUnmanagedVirtualMethodTable on the native interface implementation
                // that fills in a span with the addresses of the unmanaged-to-managed stub functions at their correct indices.
                group.Key.WriteToWithUnsafeModifier(writer, group, static (innerWriter, methods) =>
                {
                    innerWriter.WriteLine("internal unsafe partial interface Native");
                    innerWriter.WriteLine('{');
                    innerWriter.Indent++;
                    innerWriter.WriteLine("internal static unsafe void PopulateUnmanagedVirtualMethodTable(void** vtable)");
                    innerWriter.WriteLine('{');
                    innerWriter.Indent++;
                    foreach (SourceAvailableIncrementalMethodStubGenerationContext method in methods)
                    {
                        FunctionPointerTypeSyntax functionPointerType = VirtualMethodPointerStubGenerator.GenerateUnmanagedFunctionPointerTypeForMethod(method, VtableIndexStubGeneratorHelpers.GetGeneratorResolver);
                        innerWriter.WriteLine($"vtable[{method.VtableIndexData.Index}] = (void*)({functionPointerType.NormalizeWhitespace()})&ABI_{method.StubMethodSyntaxTemplate.Identifier};");
                    }
                    innerWriter.Indent--;
                    innerWriter.WriteLine('}');
                    innerWriter.Indent--;
                    innerWriter.WriteLine('}');
                });
            }

            productionContext.AddSource("PopulateVTable.g.cs", sw.ToString());
        });
    }

    /// <summary>
    /// Registers a source output that wraps each generated stub in an <c>internal partial interface Native</c>
    /// declaration nested inside the stub's containing types, and writes them all to a single file.
    /// </summary>
    /// <param name="context">The generator initialization context used to register the output.</param>
    /// <param name="stubContexts">The stubs to emit; nothing is added when this is empty.</param>
    /// <param name="hintName">The hint name of the generated file.</param>
    /// <param name="generateStub">Produces the stub syntax for a single method.</param>
    private static void RegisterStubSourceOutput(
        IncrementalGeneratorInitializationContext context,
        IncrementalValuesProvider<SourceAvailableIncrementalMethodStubGenerationContext> stubContexts,
        string hintName,
        Func<SourceAvailableIncrementalMethodStubGenerationContext, SyntaxNode> generateStub)
    {
        context.RegisterSourceOutput(stubContexts.Collect(), (productionContext, data) =>
        {
            if (data.IsEmpty)
            {
                return;
            }

            using StringWriter sw = new();
            using IndentedTextWriter writer = new(sw);
            writer.WriteLine("// <auto-generated/>");
            foreach (SourceAvailableIncrementalMethodStubGenerationContext stub in data)
            {
                sw.WriteLine();
                SyntaxNode stubSyntax = generateStub(stub);
                stub.ContainingSyntaxContext.WriteToWithUnsafeModifier(writer, stubSyntax, static (innerWriter, node) =>
                {
                    innerWriter.WriteLine("internal partial interface Native");
                    innerWriter.WriteLine('{');
                    innerWriter.Indent++;
                    innerWriter.WriteMultilineNode(node.NormalizeWhitespace());
                    innerWriter.Indent--;
                    innerWriter.WriteLine('}');
                });
            }

            productionContext.AddSource(hintName, sw.ToString());
        });
    }

    /// <summary>
    /// Reads the constructor and named arguments of a [VirtualMethodIndex] attribute into a
    /// <see cref="VirtualMethodIndexCompilationData"/>.
    /// </summary>
    /// <param name="attrData">The [VirtualMethodIndex] attribute applied to the method.</param>
    /// <returns>
    /// The parsed data. Returns <see langword="null"/> in any of these cases:
    /// the attribute type failed to bind; the index constructor argument is missing or is not an <see cref="int"/>;
    /// or a recognized named argument has a value of the wrong type.
    /// </returns>
    private static VirtualMethodIndexCompilationData? ProcessVirtualMethodIndexAttribute(AttributeData attrData)
    {
        // The attribute type failed to bind, most likely because an incorrect TFM is targeted.
        // There is nothing meaningful to parse, so bail out.
        if (attrData.AttributeClass?.TypeKind is null or TypeKind.Error)
        {
            return null;
        }

        var namedArguments = ImmutableDictionary.CreateRange(attrData.NamedArguments);

        if (attrData.ConstructorArguments.Length == 0 || attrData.ConstructorArguments[0].Value is not int index)
        {
            return null;
        }

        MarshalDirection direction = MarshalDirection.Bidirectional;
        bool implicitThis = true;
        bool exceptionMarshallingDefined = false;
        ExceptionMarshalling exceptionMarshalling = ExceptionMarshalling.Custom;
        INamedTypeSymbol? exceptionMarshallingCustomType = null;

        // TypedConstant's Value property only contains primitive values, so enum-typed arguments
        // arrive as their underlying int and are converted back to the enum below.
        if (namedArguments.TryGetValue(nameof(VirtualMethodIndexCompilationData.Direction), out TypedConstant directionValue))
        {
            if (directionValue.Value is not int rawDirection)
            {
                return null;
            }

            direction = (MarshalDirection)rawDirection;
        }

        if (namedArguments.TryGetValue(nameof(VirtualMethodIndexCompilationData.ImplicitThisParameter), out TypedConstant implicitThisValue))
        {
            if (implicitThisValue.Value is not bool rawImplicitThis)
            {
                return null;
            }

            implicitThis = rawImplicitThis;
        }

        if (namedArguments.TryGetValue(nameof(VirtualMethodIndexCompilationData.ExceptionMarshalling), out TypedConstant exceptionMarshallingValue))
        {
            exceptionMarshallingDefined = true;
            if (exceptionMarshallingValue.Value is not int rawExceptionMarshalling)
            {
                return null;
            }

            exceptionMarshalling = (ExceptionMarshalling)rawExceptionMarshalling;
        }

        if (namedArguments.TryGetValue(nameof(VirtualMethodIndexCompilationData.ExceptionMarshallingCustomType), out TypedConstant exceptionMarshallingCustomTypeValue))
        {
            if (exceptionMarshallingCustomTypeValue.Value is not INamedTypeSymbol customType)
            {
                return null;
            }

            exceptionMarshallingCustomType = customType;
        }

        return new VirtualMethodIndexCompilationData(index).WithValuesFromNamedArguments(namedArguments) with
        {
            Direction = direction,
            ImplicitThisParameter = implicitThis,
            ExceptionMarshallingDefined = exceptionMarshallingDefined,
            ExceptionMarshalling = exceptionMarshalling,
            ExceptionMarshallingCustomType = exceptionMarshallingCustomType,
        };
    }

    /// <summary>
    /// Computes everything needed to generate the managed-to-unmanaged and unmanaged-to-managed stubs for one
    /// [VirtualMethodIndex] method, collecting configuration diagnostics into the returned context.
    /// </summary>
    /// <param name="syntax">The method declaration.</param>
    /// <param name="symbol">The method symbol corresponding to <paramref name="syntax"/>.</param>
    /// <param name="environment">The compilation and environment flags used for marshalling decisions.</param>
    /// <param name="ct">Cancellation token; checked once on entry.</param>
    /// <returns>The stub generation context, including any diagnostics gathered while processing the method.</returns>
    internal static SourceAvailableIncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, StubEnvironment environment, CancellationToken ct)
    {
        ct.ThrowIfCancellationRequested();
        INamedTypeSymbol? lcidConversionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.LCIDConversionAttribute);
        INamedTypeSymbol? suppressGCTransitionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.SuppressGCTransitionAttribute);
        INamedTypeSymbol? unmanagedCallConvAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.UnmanagedCallConvAttribute);

        // Get any attributes of interest on the method
        AttributeData? virtualMethodIndexAttr = null;
        AttributeData? lcidConversionAttr = null;
        AttributeData? suppressGCTransitionAttribute = null;
        AttributeData? unmanagedCallConvAttribute = null;
        foreach (AttributeData attr in symbol.GetAttributes())
        {
            if (attr.AttributeClass is not null
                && attr.AttributeClass.ToDisplayString() == TypeNames.VirtualMethodIndexAttribute)
            {
                virtualMethodIndexAttr = attr;
            }
            else if (lcidConversionAttrType is not null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, lcidConversionAttrType))
            {
                lcidConversionAttr = attr;
            }
            else if (suppressGCTransitionAttrType is not null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, suppressGCTransitionAttrType))
            {
                suppressGCTransitionAttribute = attr;
            }
            else if (unmanagedCallConvAttrType is not null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, unmanagedCallConvAttrType))
            {
                unmanagedCallConvAttribute = attr;
            }
        }

        // The pipeline only selects methods carrying [VirtualMethodIndex], so the attribute is expected to be present.
        Debug.Assert(virtualMethodIndexAttr is not null);

        var locations = new MethodSignatureDiagnosticLocations(syntax);
        var generatorDiagnostics = new GeneratorDiagnosticsBag(new DiagnosticDescriptorProvider(), locations, SR.ResourceManager, typeof(FxResources.Microsoft.Interop.ComInterfaceGenerator.SR));

        // Process the VirtualMethodIndex attribute. If it cannot be parsed, fall back to an index of -1.
        VirtualMethodIndexCompilationData virtualMethodIndexData = ProcessVirtualMethodIndexAttribute(virtualMethodIndexAttr)
            ?? new VirtualMethodIndexCompilationData(-1);
        // TODO: report a diagnostic when the index is missing or negative.

        if (virtualMethodIndexData.IsUserDefined.HasFlag(InteropAttributeMember.StringMarshalling))
        {
            // User specified StringMarshalling.Custom without specifying StringMarshallingCustomType
            if (virtualMethodIndexData.StringMarshalling == StringMarshalling.Custom && virtualMethodIndexData.StringMarshallingCustomType is null)
            {
                generatorDiagnostics.ReportInvalidStringMarshallingConfiguration(
                    virtualMethodIndexAttr, symbol.Name, SR.InvalidStringMarshallingConfigurationMissingCustomType);
            }

            // User specified something other than StringMarshalling.Custom while specifying StringMarshallingCustomType
            if (virtualMethodIndexData.StringMarshalling != StringMarshalling.Custom && virtualMethodIndexData.StringMarshallingCustomType is not null)
            {
                generatorDiagnostics.ReportInvalidStringMarshallingConfiguration(
                    virtualMethodIndexAttr, symbol.Name, SR.InvalidStringMarshallingConfigurationNotCustom);
            }
        }

        // TODO: report invalid configuration when ImplicitThisParameter is false but the direction is
        // UnmanagedToManaged or Bidirectional.

        if (lcidConversionAttr is not null)
        {
            // Using LCIDConversion with source-generated interop is not supported
            generatorDiagnostics.ReportConfigurationNotSupported(lcidConversionAttr, nameof(TypeNames.LCIDConversionAttribute));
        }

        // Create the stub.
        var signatureContext = SignatureContext.Create(
            symbol,
            DefaultMarshallingInfoParser.Create(environment, generatorDiagnostics, symbol, virtualMethodIndexData, virtualMethodIndexAttr),
            environment,
            new CodeEmitOptions(SkipInit: true),
            typeof(VtableIndexStubGenerator).Assembly);

        var containingSyntaxContext = new ContainingSyntaxContext(syntax);

        // The stub keeps the original method's modifiers except 'partial', 'virtual' and accessibility modifiers.
        var methodSyntaxTemplate = new ContainingSyntax(new SyntaxTokenList(syntax.Modifiers.Where(static m => !m.IsKind(SyntaxKind.PartialKeyword) && !m.IsKind(SyntaxKind.VirtualKeyword))).StripAccessibilityModifiers(), SyntaxKind.MethodDeclaration, syntax.Identifier, syntax.TypeParameterList);

        ImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> callConv = VirtualMethodPointerStubGenerator.GenerateCallConvSyntaxFromAttributes(suppressGCTransitionAttribute, unmanagedCallConvAttribute, defaultCallingConventions: ImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax>.Empty);

        var interfaceType = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol.ContainingType);

        // TODO: validate that the containing type implements the interface named by
        // TypeNames.IUnmanagedInterfaceType_Metadata and report invalid configuration otherwise.

        // The unwrapper type is the type argument of the UnmanagedObjectUnwrapper attribute on the containing type.
        INamedTypeSymbol? unwrapperAttributeClass = symbol.ContainingType.GetAttributes()
            .FirstOrDefault(static att => att.AttributeClass.IsOfType(TypeNames.UnmanagedObjectUnwrapperAttribute))
            ?.AttributeClass;
        TypeSyntax unwrapperSyntax;
        if (unwrapperAttributeClass is { TypeArguments: { Length: > 0 } unwrapperTypeArguments })
        {
            unwrapperSyntax = ParseTypeName(unwrapperTypeArguments[0].ToDisplayString());
        }
        else
        {
            // TODO: report invalid configuration. The attribute is missing or has no type argument.
            // Previously this dereferenced null and crashed the generator. The placeholder keeps generation
            // running, but stubs that rely on the unwrapper presumably will not compile against it.
            unwrapperSyntax = PredefinedType(Token(SyntaxKind.ObjectKeyword));
        }

        MarshallingInfo exceptionMarshallingInfo = CreateExceptionMarshallingInfo(virtualMethodIndexAttr, symbol, environment.Compilation, generatorDiagnostics, virtualMethodIndexData);

        return new SourceAvailableIncrementalMethodStubGenerationContext(
            signatureContext,
            containingSyntaxContext,
            methodSyntaxTemplate,
            locations,
            new SequenceEqualImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax>(callConv, SyntaxEquivalentComparer.Instance),
            VirtualMethodIndexData.From(virtualMethodIndexData),
            exceptionMarshallingInfo,
            environment.EnvironmentFlags,
            interfaceType,
            interfaceType,
            new SequenceEqualImmutableArray<DiagnosticInfo>(generatorDiagnostics.Diagnostics.ToImmutableArray()),
            new ObjectUnwrapperInfo(unwrapperSyntax));
    }

    /// <summary>
    /// Chooses the marshalling info used for exceptions thrown by the stub, based on the attribute's
    /// ExceptionMarshalling / ExceptionMarshallingCustomType values. Invalid combinations are reported
    /// to <paramref name="diagnostics"/>.
    /// </summary>
    /// <param name="virtualMethodIndexAttr">The [VirtualMethodIndex] attribute, used as the diagnostic location.</param>
    /// <param name="symbol">The method being generated; its name is used in diagnostics.</param>
    /// <param name="compilation">The compilation used to resolve <c>System.Exception</c>.</param>
    /// <param name="diagnostics">Receives configuration diagnostics.</param>
    /// <param name="virtualMethodIndexData">The parsed attribute data.</param>
    /// <returns>
    /// <see cref="ComExceptionMarshalling"/> for <c>ExceptionMarshalling.Com</c>.
    /// Custom marshalling info when ExceptionMarshalling is Custom and a custom type is provided.
    /// Otherwise <see cref="NoMarshallingInfo.Instance"/>.
    /// </returns>
    private static MarshallingInfo CreateExceptionMarshallingInfo(AttributeData virtualMethodIndexAttr, ISymbol symbol, Compilation compilation, GeneratorDiagnosticsBag diagnostics, VirtualMethodIndexCompilationData virtualMethodIndexData)
    {
        if (virtualMethodIndexData.ExceptionMarshallingDefined)
        {
            // User specified ExceptionMarshalling.Custom without specifying ExceptionMarshallingCustomType
            if (virtualMethodIndexData.ExceptionMarshalling == ExceptionMarshalling.Custom && virtualMethodIndexData.ExceptionMarshallingCustomType is null)
            {
                diagnostics.ReportInvalidExceptionMarshallingConfiguration(
                    virtualMethodIndexAttr, symbol.Name, SR.InvalidExceptionMarshallingConfigurationMissingCustomType);
                return NoMarshallingInfo.Instance;
            }

            // User specified something other than ExceptionMarshalling.Custom while specifying ExceptionMarshallingCustomType
            if (virtualMethodIndexData.ExceptionMarshalling != ExceptionMarshalling.Custom && virtualMethodIndexData.ExceptionMarshallingCustomType is not null)
            {
                diagnostics.ReportInvalidExceptionMarshallingConfiguration(
                    virtualMethodIndexAttr, symbol.Name, SR.InvalidExceptionMarshallingConfigurationNotCustom);
            }
        }

        if (virtualMethodIndexData.ExceptionMarshalling == ExceptionMarshalling.Com)
        {
            return new ComExceptionMarshalling();
        }

        if (virtualMethodIndexData.ExceptionMarshalling == ExceptionMarshalling.Custom)
        {
            return virtualMethodIndexData.ExceptionMarshallingCustomType is null
                ? NoMarshallingInfo.Instance
                : CustomMarshallingInfoHelper.CreateNativeMarshallingInfoForNonSignatureElement(
                    compilation.GetTypeByMetadataName(TypeNames.System_Exception),
                    virtualMethodIndexData.ExceptionMarshallingCustomType,
                    virtualMethodIndexAttr,
                    compilation,
                    diagnostics);
        }

        // This should not be reached in normal usage, but a developer can cast any int to the ExceptionMarshalling enum, so we should handle this case without crashing the generator.
        diagnostics.ReportInvalidExceptionMarshallingConfiguration(
            virtualMethodIndexAttr, symbol.Name, SR.InvalidExceptionMarshallingValue);
        return NoMarshallingInfo.Instance;
    }
}
}
|