File: UnmanagedToManagedStubGenerator.cs
Web Access
Project: src\src\libraries\System.Runtime.InteropServices\gen\ComInterfaceGenerator\ComInterfaceGenerator.csproj (Microsoft.Interop.ComInterfaceGenerator)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
 
namespace Microsoft.Interop
{
    /// <summary>
    /// Produces the method body and ABI signature data for an unmanaged-to-managed stub,
    /// using marshalling generators bound to the supplied parameter and return types.
    /// </summary>
    internal sealed class UnmanagedToManagedStubGenerator
    {
        // Identifier used for the return value in the generated stub.
        private const string ReturnIdentifier = "__retVal";

        // Marshalling generators bound to each element of the signature.
        private readonly BoundGenerators _marshallers;

        // Code context handed to every code-generation call made by this type.
        private readonly NativeToManagedStubCodeContext _context;

        /// <summary>
        /// Binds marshalling generators for <paramref name="argTypes"/> and forwards any
        /// binding diagnostics to <paramref name="diagnosticsBag"/>.
        /// </summary>
        /// <param name="argTypes">Type information for the elements of the signature to marshal</param>
        /// <param name="diagnosticsBag">Bag that receives the diagnostics produced while binding generators</param>
        /// <param name="generatorResolver">Resolver used to pick a marshalling generator for each type</param>
        public UnmanagedToManagedStubGenerator(
            ImmutableArray<TypePositionInfo> argTypes,
            GeneratorDiagnosticsBag diagnosticsBag,
            IMarshallingGeneratorResolver generatorResolver)
        {
            // Start with a context in which the managed and native return identifiers are the same.
            NativeToManagedStubCodeContext initialContext = new NativeToManagedStubCodeContext(ReturnIdentifier, ReturnIdentifier);

            _marshallers = BoundGenerators.Create(argTypes, generatorResolver, initialContext, new Forwarder(), out var bindingDiagnostics);

            diagnosticsBag.ReportGeneratorDiagnostics(bindingDiagnostics);

            var nativeReturn = _marshallers.NativeReturnMarshaller;

            // When the return marshaller uses a separate native identifier, switch to a context
            // that carries the suffixed native name before any code is generated.
            _context = nativeReturn.Generator.UsesNativeIdentifier(nativeReturn.TypeInfo, initialContext)
                ? new NativeToManagedStubCodeContext(ReturnIdentifier, $"{ReturnIdentifier}{StubCodeContext.GeneratedNativeIdentifierSuffix}")
                : initialContext;
        }

        /// <summary>
        /// Builds the body of the unmanaged-to-managed ComWrappers-based method stub.
        /// </summary>
        /// <param name="methodToInvoke">Name of the method on the managed type to invoke</param>
        /// <returns>Method body of the stub</returns>
        /// <remarks>
        /// The generated code assumes it will be in an unsafe context.
        /// </remarks>
        public BlockSyntax GenerateStubBody(ExpressionSyntax methodToInvoke)
        {
            GeneratedStatements generated = GeneratedStatements.Create(_marshallers, _context, methodToInvoke);

            // Callee-allocated cleanup statements are never emitted below, so none are expected.
            Debug.Assert(generated.CleanupCalleeAllocated.IsEmpty);

            // Request initialized locals whenever guaranteed-unmarshal, caller-allocated cleanup,
            // or catch-clause statements are present.
            bool initializeLocals = !(generated.GuaranteedUnmarshal.IsEmpty
                && generated.CleanupCallerAllocated.IsEmpty
                && generated.ManagedExceptionCatchClauses.IsEmpty);
            VariableDeclarations locals = VariableDeclarations.GenerateDeclarationsForUnmanagedToManaged(_marshallers, _context, initializeLocals);

            // Prologue: initializations, then variable declarations, then setup statements.
            List<StatementSyntax> body = new();
            body.AddRange(locals.Initializations);
            body.AddRange(locals.Variables);
            body.AddRange(generated.Setup);

            // Main path: unmarshal the arguments, invoke the managed method, marshal the results.
            List<StatementSyntax> guarded = new();
            guarded.AddRange(generated.GuaranteedUnmarshal);
            guarded.AddRange(generated.Unmarshal);
            guarded.Add(generated.InvokeStatement);
            guarded.AddRange(generated.NotifyForSuccessfulInvoke);
            guarded.AddRange(generated.Marshal);
            guarded.AddRange(generated.PinnedMarshal);

            SyntaxList<CatchClauseSyntax> handlers = List(generated.ManagedExceptionCatchClauses);

            List<StatementSyntax> cleanup = new();
            cleanup.AddRange(generated.CleanupCallerAllocated);
            bool hasCleanup = cleanup.Count > 0;

            if (hasCleanup || handlers.Count > 0)
            {
                // Wrap the main path in a try statement; the finally clause is only attached
                // when there are cleanup statements to run.
                BlockSyntax guardedBlock = Block(guarded);
                FinallyClauseSyntax? finallyClause = hasCleanup ? FinallyClause(Block(cleanup)) : null;
                body.Add(TryStatement(guardedBlock, handlers, finallyClause));
            }
            else
            {
                // Nothing to catch or clean up: emit the main path inline.
                body.AddRange(guarded);
            }

            // Return
            if (!_marshallers.IsUnmanagedVoidReturn)
            {
                string nativeReturnName = _context.GetIdentifiers(_marshallers.NativeReturnMarshaller.TypeInfo).native;
                body.Add(ReturnStatement(IdentifierName(nativeReturnName)));
            }

            return Block(body);
        }

        /// <summary>
        /// Gets the parameter list, return type, and return-type attributes for the target
        /// method signature, as computed by the bound marshallers for this generator's context.
        /// </summary>
        /// <returns>The parameter list, return type, and (possibly null) return-type attribute list</returns>
        public (ParameterListSyntax ParameterList, TypeSyntax ReturnType, AttributeListSyntax? ReturnTypeAttributes) GenerateAbiMethodSignatureData()
            => _marshallers.GenerateTargetMethodSignatureData(_context);
    }
}