File: Marshallers\IidParameterIndexMarshallerResolver.cs
Web Access
Project: src\src\runtime\src\libraries\System.Runtime.InteropServices\gen\ComInterfaceGenerator\ComInterfaceGenerator.csproj (Microsoft.Interop.ComInterfaceGenerator)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using static Microsoft.Interop.SyntaxFactoryExtensions;

namespace Microsoft.Interop
{
    /// <summary>
    /// Generator resolver for elements whose marshalling info is an
    /// <see cref="IidParameterIndexNativeMarshallingInfo"/>. It only resolves a generator when the stub
    /// direction is <see cref="MarshalDirection.UnmanagedToManaged"/>; every other combination is
    /// reported as unresolved.
    /// </summary>
    internal sealed class IidParameterIndexMarshallerResolver : IMarshallingGeneratorResolver
    {
        /// <summary>
        /// Produces a bound generator for <paramref name="info"/> when it carries IID-parameter-index
        /// marshalling info and the stub is unmanaged-to-managed.
        /// </summary>
        /// <param name="info">The element to find a generator for.</param>
        /// <param name="context">The stub context supplying the marshalling direction.</param>
        /// <returns>
        /// A resolved generator wrapping the nested marshaller bound to <paramref name="info"/> and
        /// <paramref name="context"/>, or <see cref="ResolvedGenerator.UnresolvedGenerator"/> otherwise.
        /// </returns>
        public ResolvedGenerator Create(TypePositionInfo info, StubCodeContext context)
        {
            if (info.MarshallingAttributeInfo is IidParameterIndexNativeMarshallingInfo iidInfo
                && context.Direction == MarshalDirection.UnmanagedToManaged)
            {
                Marshaller unbound = new(iidInfo.IidParameterIndexInfo);
                return ResolvedGenerator.Resolved(unbound.Bind(info, context));
            }

            return ResolvedGenerator.UnresolvedGenerator;
        }

        /// <summary>
        /// Unbound generator that emits, during the Marshal stage, code converting the managed value to
        /// an unmanaged pointer and then calling <c>Marshal.QueryInterface</c> on it with the IID
        /// expression derived from the element passed to the constructor.
        /// </summary>
        private sealed class Marshaller(TypePositionInfo iidParameterIndexInfo) : IUnboundMarshallingGenerator
        {
            // Fully qualified name of the marshaller type whose ConvertToUnmanaged/Free are invoked by the emitted code.
            private const string ObjectComInterfaceMarshallerTypeName = "global::System.Runtime.InteropServices.Marshalling.ComInterfaceMarshaller<object>";

            /// <summary>The native representation is always <c>void*</c>.</summary>
            public ManagedTypeInfo AsNativeType(TypePositionInfo info)
            {
                return new PointerTypeInfo("void*", "void*", false);
            }

            /// <summary>By-ref elements appear in the native signature as a pointer to the native type; others as the native type itself.</summary>
            public SignatureBehavior GetNativeSignatureBehavior(TypePositionInfo info)
            {
                if (info.IsByRef)
                {
                    return SignatureBehavior.PointerToNativeType;
                }

                return SignatureBehavior.NativeType;
            }

            /// <summary>By-ref elements cross the boundary as the address of the native identifier; others as the native identifier itself.</summary>
            public ValueBoundaryBehavior GetValueBoundaryBehavior(TypePositionInfo info, StubCodeContext context)
            {
                if (info.IsByRef)
                {
                    return ValueBoundaryBehavior.AddressOfNativeIdentifier;
                }

                return ValueBoundaryBehavior.NativeIdentifier;
            }

            /// <summary>Defers to the default by-value marshal kind support descriptor.</summary>
            public ByValueMarshalKindSupport SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, TypePositionInfo info, out GeneratorDiagnostic? diagnostic)
            {
                return ByValueMarshalKindSupportDescriptor.Default.GetSupport(marshalKind, info, out diagnostic);
            }

            /// <summary>This generator always uses a separate native identifier.</summary>
            public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context)
            {
                return true;
            }

            /// <summary>
            /// Emits the marshalling statements for <paramref name="info"/>. Nothing is produced outside
            /// the Marshal stage. In the Marshal stage two statements are produced, shaped like:
            /// <code>
            /// void* unknown = (void*)ComInterfaceMarshaller&lt;object&gt;.ConvertToUnmanaged(managed);
            /// if (unknown != null)
            /// {
            ///     IntPtr queriedInterface = 0;
            ///     int queryInterfaceHResult = Marshal.QueryInterface((IntPtr)unknown, in iid, out queriedInterface);
            ///     ComInterfaceMarshaller&lt;object&gt;.Free(unknown);
            ///     if (queryInterfaceHResult &lt; 0)
            ///     {
            ///         if (queriedInterface != 0)
            ///             Marshal.Release(queriedInterface);
            ///         native = null;
            ///         Marshal.ThrowExceptionForHR(queryInterfaceHResult);
            ///     }
            ///     native = (void*)queriedInterface;
            /// }
            /// else
            /// {
            ///     native = null;
            /// }
            /// </code>
            /// </summary>
            /// <param name="info">The element being marshalled.</param>
            /// <param name="codeContext">Stub code context, forwarded when building the IID expression.</param>
            /// <param name="context">Identifier context supplying the current stage and identifier names.</param>
            /// <returns>The statements to emit for the current stage (lazily enumerated).</returns>
            public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeContext codeContext, StubIdentifierContext context)
            {
                if (context.CurrentStage == StubIdentifierContext.Stage.Marshal)
                {
                    (string managedName, string nativeName) = context.GetIdentifiers(info);
                    string unknownName = context.GetAdditionalIdentifier(info, "unknown");
                    string hresultName = context.GetAdditionalIdentifier(info, "queryInterfaceHResult");
                    string queriedName = context.GetAdditionalIdentifier(info, "queriedInterface");

                    ExpressionSyntax iid = MarshallerHelpers.GetIndexedManagedElementExpression(iidParameterIndexInfo, codeContext, context);

                    // void* <unknown> = (void*)ComInterfaceMarshaller<object>.ConvertToUnmanaged(<managed>);
                    InvocationExpressionSyntax convertToUnmanaged = InvokeStatic(
                        ParseTypeName(ObjectComInterfaceMarshallerTypeName),
                        "ConvertToUnmanaged",
                        Argument(IdentifierName(managedName)));
                    yield return DeclareInitializedLocal(
                        VoidPointerType(),
                        unknownName,
                        CastExpression(VoidPointerType(), convertToUnmanaged));

                    // IntPtr <queriedInterface> = 0;
                    StatementSyntax declareQueried = DeclareInitializedLocal(
                        TypeSyntaxes.System_IntPtr,
                        queriedName,
                        ZeroLiteral());

                    // int <hr> = Marshal.QueryInterface((IntPtr)<unknown>, in <iid>, out <queriedInterface>);
                    StatementSyntax declareHResult = DeclareInitializedLocal(
                        PredefinedType(Token(SyntaxKind.IntKeyword)),
                        hresultName,
                        InvokeStatic(
                            TypeSyntaxes.System_Runtime_InteropServices_Marshal,
                            "QueryInterface",
                            Argument(CastExpression(TypeSyntaxes.System_IntPtr, IdentifierName(unknownName))),
                            Argument(iid).WithRefKindKeyword(Token(SyntaxKind.InKeyword)),
                            Argument(IdentifierName(queriedName)).WithRefKindKeyword(Token(SyntaxKind.OutKeyword))));

                    // ComInterfaceMarshaller<object>.Free(<unknown>);
                    StatementSyntax freeUnknown = ExpressionStatement(
                        InvokeStatic(
                            ParseTypeName(ObjectComInterfaceMarshallerTypeName),
                            "Free",
                            Argument(IdentifierName(unknownName))));

                    // if (<queriedInterface> != 0) Marshal.Release(<queriedInterface>);
                    StatementSyntax releaseIfQueried = IfStatement(
                        BinaryExpression(
                            SyntaxKind.NotEqualsExpression,
                            IdentifierName(queriedName),
                            ZeroLiteral()),
                        ExpressionStatement(
                            InvokeStatic(
                                TypeSyntaxes.System_Runtime_InteropServices_Marshal,
                                "Release",
                                Argument(IdentifierName(queriedName)))));

                    // The failure path ends by throwing a managed exception built from the failing
                    // HRESULT. The stub's existing exception-to-HRESULT infrastructure (see
                    // ManagedHResultExceptionGeneratorResolver) catches it on the way out, hands the
                    // HRESULT back to the unmanaged caller, and still runs the normal cleanup stages.
                    StatementSyntax onFailure = IfStatement(
                        BinaryExpression(
                            SyntaxKind.LessThanExpression,
                            IdentifierName(hresultName),
                            ZeroLiteral()),
                        Block(
                            releaseIfQueried,
                            AssignIdentifier(nativeName, NullLiteral()),
                            MethodInvocationStatement(
                                TypeSyntaxes.System_Runtime_InteropServices_Marshal,
                                IdentifierName("ThrowExceptionForHR"),
                                Argument(IdentifierName(hresultName)))));

                    // <native> = (void*)<queriedInterface>;
                    StatementSyntax storeQueried = AssignIdentifier(
                        nativeName,
                        CastExpression(VoidPointerType(), IdentifierName(queriedName)));

                    yield return IfStatement(
                        BinaryExpression(
                            SyntaxKind.NotEqualsExpression,
                            IdentifierName(unknownName),
                            NullLiteral()),
                        Block(
                            declareQueried,
                            declareHResult,
                            freeUnknown,
                            onFailure,
                            storeQueried),
                        ElseClause(
                            Block(
                                AssignIdentifier(nativeName, NullLiteral()))));
                }
            }

            // Builds the type syntax `void*`.
            private static PointerTypeSyntax VoidPointerType()
                => PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)));

            // Builds the literal `null`.
            private static LiteralExpressionSyntax NullLiteral()
                => LiteralExpression(SyntaxKind.NullLiteralExpression);

            // Builds the numeric literal `0`.
            private static LiteralExpressionSyntax ZeroLiteral()
                => LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0));

            // Builds `<type> <name> = <initializer>;`.
            private static LocalDeclarationStatementSyntax DeclareInitializedLocal(TypeSyntax type, string name, ExpressionSyntax initializer)
                => LocalDeclarationStatement(
                    VariableDeclaration(type)
                        .AddVariables(
                            VariableDeclarator(Identifier(name))
                                .WithInitializer(EqualsValueClause(initializer))));

            // Builds `<receiver>.<method>(<arguments>)`.
            private static InvocationExpressionSyntax InvokeStatic(TypeSyntax receiver, string method, params ArgumentSyntax[] arguments)
                => InvocationExpression(
                        MemberAccessExpression(
                            SyntaxKind.SimpleMemberAccessExpression,
                            receiver,
                            IdentifierName(method)))
                    .AddArgumentListArguments(arguments);

            // Builds `<target> = <value>;`.
            private static ExpressionStatementSyntax AssignIdentifier(string target, ExpressionSyntax value)
                => ExpressionStatement(
                    AssignmentExpression(
                        SyntaxKind.SimpleAssignmentExpression,
                        IdentifierName(target),
                        value));
        }
    }
}