File: Windows\Win32\System\Com\ComHelpers.cs
Web Access
Project: src\src\System.Private.Windows.Core\src\System.Private.Windows.Core.csproj (System.Private.Windows.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Runtime.InteropServices;
using Windows.Win32.System.Com;
 
namespace Windows.Win32;
 
internal static unsafe partial class ComHelpers
{
    // Note that ComScope<T> needs to be the return value to facilitate using in a `using`.
    //
    //  using var stream = TryGetComScope<IStream>(obj, out HRESULT hr);

    /// <summary>
    ///  Returns <see langword="true"/> if built-in COM interop is supported. When using AOT or trimming this will
    ///  return <see langword="false"/>.
    /// </summary>
    /// <remarks>
    ///  <para>
    ///   The value comes from the <c>System.Runtime.InteropServices.BuiltInComInterop.IsSupported</c>
    ///   <see cref="AppContext"/> switch and is computed once. If the switch is not set, built-in COM is presumed
    ///   to be supported.
    ///  </para>
    /// </remarks>
    internal static bool BuiltInComSupported { get; }
        // Presume it is supported if we can't get the switch.
        = !AppContext.TryGetSwitch("System.Runtime.InteropServices.BuiltInComInterop.IsSupported", out bool supported)
            || supported;

    /// <summary>
    ///  Gets a pointer for the specified <typeparamref name="T"/> for the given <paramref name="object"/>. Throws if
    ///  the desired pointer can not be obtained.
    /// </summary>
    /// <returns>
    ///  A scope wrapping the requested pointer. Dispose it (typically via <see langword="using"/>) when done.
    /// </returns>
    internal static ComScope<T> GetComScope<T>(object? @object) where T : unmanaged, IComIID =>
        new(GetComPointer<T>(@object));

    /// <summary>
    ///  Attempts to get a pointer for the specified <typeparamref name="T"/> for the given <paramref name="object"/>.
    /// </summary>
    /// <returns>
    ///  A scope wrapping the requested pointer, or wrapping <see langword="null"/> if unsuccessful.
    /// </returns>
    internal static ComScope<T> TryGetComScope<T>(object? @object) where T : unmanaged, IComIID =>
        TryGetComScope<T>(@object, out _);

    /// <summary>
    ///  Attempts to get a pointer for the specified <typeparamref name="T"/> for the given <paramref name="object"/>.
    /// </summary>
    /// <param name="hr">
    ///  The result of the attempt. See <see cref="TryGetComPointer{T}(object?, out HRESULT)"/>.
    /// </param>
    /// <returns>
    ///  A scope wrapping the requested pointer, or wrapping <see langword="null"/> if unsuccessful.
    /// </returns>
    internal static ComScope<T> TryGetComScope<T>(object? @object, out HRESULT hr) where T : unmanaged, IComIID =>
        new(TryGetComPointer<T>(@object, out hr));

    /// <summary>
    ///  Gets the specified <typeparamref name="T"/> interface for the given <paramref name="object"/>. Throws if
    ///  the desired pointer can not be obtained.
    /// </summary>
    /// <returns>The requested pointer. The caller owns the returned reference.</returns>
    internal static T* GetComPointer<T>(object? @object) where T : unmanaged, IComIID
    {
        T* result = TryGetComPointer<T>(@object, out HRESULT hr);
        hr.ThrowOnFailure();
        return result;
    }

    /// <summary>
    ///  Attempts to get the specified <typeparamref name="T"/> interface for the given <paramref name="object"/>.
    /// </summary>
    /// <returns>The requested pointer or <see langword="null"/> if unsuccessful.</returns>
    internal static T* TryGetComPointer<T>(object? @object) where T : unmanaged, IComIID =>
        TryGetComPointer<T>(@object, out _);

    /// <summary>
    ///  Queries for the given interface and releases it.
    ///  Note that this method should only be used for the purposes of checking if the object supports a given interface.
    ///  If that interface is needed, it is best to try to get the ComScope directly to avoid querying twice.
    /// </summary>
    /// <returns><see langword="true"/> if the query for <typeparamref name="T"/> succeeded.</returns>
    internal static bool SupportsInterface<T>(object? @object) where T : unmanaged, IComIID
    {
        // The scope is only held so the queried pointer is released on exit.
        using var scope = TryGetComScope<T>(@object, out HRESULT hr);
        return hr.Succeeded;
    }

    /// <summary>
    ///  Attempts to get the specified <typeparamref name="T"/> interface for the given <paramref name="object"/>.
    /// </summary>
    /// <param name="result">
    ///  Typically <see cref="HRESULT.S_OK"/>, <see cref="HRESULT.E_POINTER"/> (<paramref name="object"/> is
    ///  <see langword="null"/>), or <see cref="HRESULT.E_NOINTERFACE"/>. Check for success, not specific results.
    /// </param>
    /// <returns>The requested pointer or <see langword="null"/> if unsuccessful.</returns>
    internal static T* TryGetComPointer<T>(object? @object, out HRESULT result) where T : unmanaged, IComIID
    {
        if (@object is null)
        {
            result = HRESULT.E_POINTER;
            return null;
        }

        IUnknown* ccw = null;
        if (@object is IManagedWrapper)
        {
            // One of our classes that we can generate a CCW for.
            ccw = (IUnknown*)WinFormsComWrappers.Instance.GetOrCreateComInterfaceForObject(@object, CreateComInterfaceFlags.None);
        }
        else if (ComWrappers.TryGetComInstance(@object, out nint unknown))
        {
            // A ComWrappers generated RCW.
            ccw = (IUnknown*)unknown;
        }
        else
        {
            // Fall back to COM interop if possible. Note that this will use the globally registered ComWrappers
            // if that exists (so it won't always fall into legacy COM interop).
            try
            {
                ccw = (IUnknown*)Marshal.GetIUnknownForObject(@object);
            }
            catch (Exception ex)
            {
                // Best effort: an object that can't be projected is reported as E_NOINTERFACE below.
                Debug.WriteLine($"Did not find IUnknown for {@object.GetType().Name}. {ex.Message}");
            }
        }

        if (ccw is null)
        {
            result = HRESULT.E_NOINTERFACE;
            return null;
        }

        if (typeof(T) == typeof(IUnknown))
        {
            // No need to query if we wanted IUnknown. The reference obtained above is handed to the caller.
            result = HRESULT.S_OK;
            return (T*)ccw;
        }

        // Now query out the requested interface. The QueryInterface result carries its own reference, so the
        // IUnknown reference obtained above is released either way.
        result = ccw->QueryInterface(IID.GetRef<T>(), out void* ppvObject);
        ccw->Release();
        return (T*)ppvObject;
    }

    /// <summary>
    ///  Attempts to unwrap a ComWrapper CCW as a particular managed object.
    /// </summary>
    /// <remarks>
    ///  <para>
    ///   Does not change the reference count of <paramref name="unknown"/>.
    ///  </para>
    /// </remarks>
    private static bool TryUnwrapComWrapperCCW<TWrapper>(
        IUnknown* unknown,
        [NotNullWhen(true)] out TWrapper? @interface) where TWrapper : class
    {
        if (ComWrappers.TryGetObject((nint)unknown, out object? obj))
        {
            if (obj is TWrapper desired)
            {
                @interface = desired;
                return true;
            }
            else
            {
                Debug.WriteLine($"{nameof(TryUnwrapComWrapperCCW)}: Found a manual CCW, but couldn't unwrap to {typeof(TWrapper).Name}");
            }
        }

        @interface = default;
        return false;
    }

    /// <inheritdoc cref="TryGetObjectForIUnknown{TObject}(IUnknown*, bool, out TObject)"/>
    internal static bool TryGetObjectForIUnknown<TObject, TInterface>(
        ComScope<TInterface> comScope,
        [NotNullWhen(true)] out TObject? @object)
        where TObject : class
        where TInterface : unmanaged, IComIID => TryGetObjectForIUnknown(comScope.Value, out @object);

    /// <inheritdoc cref="TryGetObjectForIUnknown{TObject}(IUnknown*, bool, out TObject)"/>
    internal static bool TryGetObjectForIUnknown<TObject, TInterface>(
        TInterface* comPointer,
        [NotNullWhen(true)] out TObject? @object)
        where TObject : class
        where TInterface : unmanaged, IComIID
    {
        if (comPointer is null)
        {
            @object = null;
            return false;
        }

        IUnknown* unknown = (IUnknown*)comPointer;
        if (typeof(TInterface) == typeof(IUnknown))
        {
            return TryGetObjectForIUnknown(unknown, out @object);
        }

        HRESULT hr = unknown->QueryInterface(IID.Get<IUnknown>(), (void**)&unknown);
        if (hr.Failed)
        {
            Debug.Fail("How did we fail to query for IUnknown?");
            @object = null;
            return false;
        }

        // QueryInterface AddRef'd the returned IUnknown; transfer that reference so it gets released.
        return TryGetObjectForIUnknown(unknown, takeOwnership: true, out @object);
    }

    /// <inheritdoc cref="TryGetObjectForIUnknown{TObject}(IUnknown*, bool, out TObject)"/>
    internal static bool TryGetObjectForIUnknown<TObject>(
        IUnknown* unknown,
        [NotNullWhen(true)] out TObject? @object) where TObject : class =>
        TryGetObjectForIUnknown(unknown, takeOwnership: false, out @object);

    /// <summary>
    ///  Attempts to get a managed wrapper of the specified type for the given COM interface.
    /// </summary>
    /// <param name="takeOwnership">
    ///  When <see langword="true"/>, releases the original <paramref name="unknown"/> whether successful or not.
    /// </param>
    /// <returns>
    ///  <see langword="true"/> if a managed object was obtained and is a <typeparamref name="TObject"/>.
    /// </returns>
    internal static bool TryGetObjectForIUnknown<TObject>(
        IUnknown* unknown,
        bool takeOwnership,
        [NotNullWhen(true)] out TObject? @object) where TObject : class
    {
        @object = null;
        if (unknown is null)
        {
            return false;
        }

        try
        {
            // A type test (rather than a cast) avoids using an exception for a simple type mismatch and never
            // reports success with a null result.
            if (GetObjectForIUnknown(unknown) is TObject typed)
            {
                @object = typed;
                return true;
            }

            Debug.WriteLine($"{nameof(TryGetObjectForIUnknown)}: Object is not a {typeof(TObject).Name}.");
            return false;
        }
        catch (Exception ex)
        {
            Debug.WriteLine($"{nameof(TryGetObjectForIUnknown)}: Failed to get object for {typeof(TObject).Name}. {ex.Message}");
            return false;
        }
        finally
        {
            if (takeOwnership)
            {
                uint count = unknown->Release();
                Debug.WriteLineIf(count > 0, $"{nameof(TryGetObjectForIUnknown)}: Count for {typeof(TObject).Name} is {count} after release.");
            }
        }
    }

    /// <summary>
    ///  Returns <see langword="true"/> if the given <paramref name="object"/>
    ///  is projected as the given <paramref name="comPointer"/>.
    /// </summary>
    internal static bool WrapsManagedObject<T>(object @object, T* comPointer)
        where T : unmanaged, IComIID
    {
        if (comPointer is null)
        {
            return false;
        }

        // Identity comparisons in COM must be done on IUnknown.
        using ComScope<IUnknown> unknown = new(null);
        ((IUnknown*)comPointer)->QueryInterface(IID.Get<IUnknown>(), unknown).ThrowOnFailure();

        // If it is a ComWrappers object we need to simply pull out the original object to check.
        if (ComWrappers.TryGetObject((nint)unknown, out object? obj))
        {
            return @object == obj;
        }

        using ComScope<IUnknown> ccw = new((IUnknown*)(void*)Marshal.GetIUnknownForObject(@object));
        return ccw.Value == unknown;
    }

    /// <inheritdoc cref="GetObjectForIUnknown(IUnknown*)"/>
    internal static object GetObjectForIUnknown<TInterface>(TInterface* comPointer)
        where TInterface : unmanaged, IComIID
    {
        if (comPointer is null)
        {
            throw new ArgumentNullException(nameof(comPointer));
        }

        if (typeof(TInterface) == typeof(IUnknown))
        {
            return GetObjectForIUnknown((IUnknown*)comPointer);
        }

        // QueryInterface AddRefs the returned IUnknown; the scope releases that extra reference once the managed
        // object has been obtained (the caller's reference on comPointer is left untouched).
        using ComScope<IUnknown> unknown = new(null);
        ((IUnknown*)comPointer)->QueryInterface(IID.Get<IUnknown>(), unknown).ThrowOnFailure();
        return GetObjectForIUnknown(unknown.Value);
    }

    /// <inheritdoc cref="GetObjectForIUnknown(IUnknown*)"/>
    internal static object GetObjectForIUnknown<TInterface>(ComScope<TInterface> comScope)
        where TInterface : unmanaged, IComIID => GetObjectForIUnknown(comScope.Value);

    /// <summary>
    ///  <see cref="ComWrappers"/> capable wrapper for <see cref="Marshal.GetObjectForIUnknown(nint)"/>.
    /// </summary>
    /// <remarks>
    ///  <para>
    ///   Does not release <paramref name="unknown"/>; the caller keeps ownership of its reference.
    ///  </para>
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="unknown"/> is <see langword="null"/>.</exception>
    internal static object GetObjectForIUnknown(IUnknown* unknown)
    {
        if (unknown is null)
        {
            throw new ArgumentNullException(nameof(unknown));
        }

        // If it is a ComWrappers object we need to simply pull out the original object.
        if (ComWrappers.TryGetObject((nint)unknown, out object? obj))
        {
            return obj;
        }

        if (BuiltInComSupported)
        {
            return Marshal.GetObjectForIUnknown((nint)unknown);
        }
        else
        {
            // Analogous to ComInterfaceMarshaller<object>.ConvertToManaged(unknown), but we need our own strategy.
            return WinFormsComStrategy.Instance.GetOrCreateObjectForComInstance((nint)unknown, CreateObjectFlags.Unwrap);
        }
    }

    /// <summary>
    ///  <see cref="IUnknown"/> vtable population hook for CsWin32's generated <see cref="IVTable"/> implementation.
    /// </summary>
    static partial void PopulateIUnknownImpl<TComInterface>(IUnknown.Vtbl* vtable)
        where TComInterface : unmanaged =>
        WinFormsComWrappers.PopulateIUnknownVTable(vtable);

    /// <summary>
    ///  Find the given interface's <see cref="ITypeInfo"/> from the specified type library.
    /// </summary>
    /// <param name="typeLibrary">The GUID of the registered type library.</param>
    /// <param name="majorVersion">The major version of the type library.</param>
    /// <param name="minorVersion">The minor version of the type library.</param>
    /// <param name="interfaceId">The IID of the interface whose type info is wanted.</param>
    /// <returns>The type info scope. The caller is responsible for disposing it.</returns>
    public static ComScope<ITypeInfo> GetRegisteredTypeInfo(
        Guid typeLibrary,
        ushort majorVersion,
        ushort minorVersion,
        Guid interfaceId)
    {
        // Load the registered type library and get the relevant ITypeInfo for the specified interface.
        //
        // Note that the ITypeLib and ITypeInfo are free to be used on any thread. ITypeInfo add refs the
        // ITypeLib and keeps a reference to it.
        //
        // While type library loading is cached, that is only while it is still referenced (directly or via
        // an ITypeInfo reference) and there is still a fair amount of overhead to look up the right instance. The
        // caching is by the type library path, so the guid needs looked up again in the registry to figure out the
        // path again.
        using ComScope<ITypeLib> typelib = new(null);
        HRESULT hr = PInvokeCore.LoadRegTypeLib(typeLibrary, majorVersion, minorVersion, 0, typelib);
        hr.ThrowOnFailure();

        ComScope<ITypeInfo> typeInfo = new(null);
        typelib.Value->GetTypeInfoOfGuid(interfaceId, typeInfo).ThrowOnFailure();
        return typeInfo;
    }
}