File: SafeHandleMarshallingInfoProvider.cs
Web Access
Project: src\src\libraries\System.Runtime.InteropServices\gen\Microsoft.Interop.SourceGeneration\Microsoft.Interop.SourceGeneration.csproj (Microsoft.Interop.SourceGeneration)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.DotnetRuntime.Extensions;
 
namespace Microsoft.Interop
{
    /// <summary>
    /// Supplies marshalling info for types that derive from SafeHandle by routing them
    /// through the constructed <c>SafeHandleMarshaller&lt;T&gt;</c> entry-point type.
    /// </summary>
    public sealed class SafeHandleMarshallingInfoProvider(Compilation compilation) : ITypeBasedMarshallingInfoProvider
    {
        // Both lookups may come back null; each consumer below handles the missing-type case explicitly.
        private readonly INamedTypeSymbol? _safeHandleType = compilation.GetBestTypeByMetadataName(TypeNames.System_Runtime_InteropServices_SafeHandle);
        private readonly INamedTypeSymbol? _safeHandleMarshallerType = compilation.GetBestTypeByMetadataName(TypeNames.System_Runtime_InteropServices_Marshalling_SafeHandleMarshaller_Metadata);

        /// <summary>
        /// Determines whether <paramref name="type"/> is the SafeHandle type or has it somewhere in its base-type chain.
        /// </summary>
        /// <param name="type">The type to inspect.</param>
        /// <returns>
        /// <see langword="true"/> when SafeHandle is found while walking from <paramref name="type"/> up through its base types;
        /// <see langword="false"/> otherwise, including when the SafeHandle type could not be resolved in the compilation.
        /// </returns>
        public bool CanProvideMarshallingInfoForType(ITypeSymbol type)
        {
            // The SafeHandle type might not be defined if we're using one of the test CoreLib
            // implementations used for NativeAOT; nothing can derive from it in that case.
            if (_safeHandleType is null)
            {
                return false;
            }

            // Walk the inheritance chain, starting with the type itself.
            ITypeSymbol? candidate = type;
            while (candidate is not null)
            {
                if (candidate.Equals(_safeHandleType, SymbolEqualityComparer.Default))
                {
                    return true;
                }

                candidate = candidate.BaseType;
            }

            return false;
        }

        /// <summary>
        /// Builds the marshalling info for a SafeHandle-derived <paramref name="type"/>.
        /// </summary>
        /// <param name="type">The SafeHandle-derived type to marshal.</param>
        /// <param name="indirectionDepth">Not consulted by this provider.</param>
        /// <param name="useSiteAttributes">Not consulted by this provider.</param>
        /// <param name="marshallingInfoCallback">Not consulted by this provider.</param>
        /// <returns>
        /// A <see cref="NativeMarshallingAttributeInfo"/> describing the constructed marshaller entry-point type,
        /// or <see cref="NoMarshallingInfo.Instance"/> when the marshaller type is unavailable or its
        /// value marshallers cannot be resolved.
        /// </returns>
        public MarshallingInfo GetMarshallingInfo(ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback)
        {
            // Without the SafeHandleMarshaller<T> type we report NoMarshallingInfo, indicating that
            // marshalling SafeHandles with source-generated marshalling is not supported.
            if (_safeHandleMarshallerType is null)
            {
                return NoMarshallingInfo.Instance;
            }

            INamedTypeSymbol entryPointType = _safeHandleMarshallerType.Construct(type);

            bool resolved = ManualTypeMarshallingHelper.TryGetValueMarshallersFromEntryType(
                entryPointType,
                type,
                compilation,
                out CustomTypeMarshallers? marshallers);
            if (!resolved)
            {
                return NoMarshallingInfo.Instance;
            }

            CustomTypeMarshallers effectiveMarshallers = marshallers.Value;

            // An abstract type, or one lacking a public parameterless constructor, only supports
            // managed-to-unmanaged marshalling, so every other mode is dropped.
            if (type.IsAbstract || !HasPublicParameterlessConstructor(type))
            {
                ImmutableDictionary<MarshalMode, CustomTypeMarshallerData> managedToUnmanagedInOnly =
                    ImmutableDictionary<MarshalMode, CustomTypeMarshallerData>.Empty.Add(
                        MarshalMode.ManagedToUnmanagedIn,
                        effectiveMarshallers.GetModeOrDefault(MarshalMode.ManagedToUnmanagedIn));

                effectiveMarshallers = effectiveMarshallers with { Modes = managedToUnmanagedInOnly };
            }

            return new NativeMarshallingAttributeInfo(
                ManagedTypeInfo.CreateTypeInfoForTypeSymbol(entryPointType),
                effectiveMarshallers);
        }

        // Reports whether the type is a non-abstract named type whose parameterless instance constructor is public.
        private static bool HasPublicParameterlessConstructor(ITypeSymbol type)
        {
            if (type is not INamedTypeSymbol { IsAbstract: false } namedType)
            {
                return false;
            }

            foreach (IMethodSymbol constructor in namedType.InstanceConstructors)
            {
                if (constructor.Parameters.IsEmpty)
                {
                    // Only the first parameterless constructor found is considered; its accessibility decides the answer.
                    return constructor.DeclaredAccessibility == Accessibility.Public;
                }
            }

            return false;
        }
    }
}