// File: BlindAggregatorFactory.cs
// Web Access
// Project: src\src\EditorFeatures\TestUtilities\Microsoft.CodeAnalysis.EditorFeatures.Test.Utilities.csproj (Microsoft.CodeAnalysis.EditorFeatures.Test.Utilities)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#nullable disable
 
using System;
using System.Runtime.InteropServices;
using System.Threading;
 
namespace Microsoft.CodeAnalysis.Test.Utilities
{
    /// <summary>
    /// This factory creates COM "blind aggregator" instances in managed code.
    /// A blind aggregator is a small COM object living in native memory. It answers
    /// <c>IUnknown</c> and an internal wrapper interface itself and refuses <c>IMarshal</c>.
    /// It forwards every other <c>QueryInterface</c> request to an inner object supplied
    /// through <see cref="SetInnerObject"/>.
    /// </summary>
    public static class BlindAggregatorFactory
    {
        /// <summary>
        /// Allocates a new blind aggregator with a reference count of one and no inner object.
        /// </summary>
        /// <returns>The aggregator's <c>IUnknown</c> pointer. The caller owns the initial reference.</returns>
        public static unsafe IntPtr CreateWrapper()
            => (IntPtr)BlindAggregator.CreateInstance();

        /// <summary>
        /// Sets (or replaces) the object that the aggregator forwards unrecognized interface requests to.
        /// </summary>
        /// <param name="wrapperUnknown">A pointer previously returned by <see cref="CreateWrapper"/>.</param>
        /// <param name="innerUnknown">
        /// The inner object's <c>IUnknown</c>. The aggregator takes its own reference on it, and releases
        /// any previously set inner object.
        /// </param>
        /// <param name="managedObjectGCHandlePtr">
        /// A handle produced by <see cref="GCHandle.ToIntPtr"/>, or <see cref="IntPtr.Zero"/>. The aggregator
        /// frees this handle when its reference count drops to zero.
        /// </param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="wrapperUnknown"/> or <paramref name="innerUnknown"/> is <see cref="IntPtr.Zero"/>.
        /// </exception>
        public static unsafe void SetInnerObject(IntPtr wrapperUnknown, IntPtr innerUnknown, IntPtr managedObjectGCHandlePtr)
        {
            if (wrapperUnknown == IntPtr.Zero)
            {
                throw new ArgumentNullException(nameof(wrapperUnknown));
            }

            var pWrapper = (BlindAggregator*)wrapperUnknown;
            pWrapper->SetInnerObject(innerUnknown, managedObjectGCHandlePtr);
        }

        /// <summary>
        /// A blind aggregator instance. It is allocated in native memory.
        /// </summary>
        [StructLayout(LayoutKind.Sequential)]
        private struct BlindAggregator
        {
            private IntPtr _vfPtr;           // Pointer to the virtual function table
            private int _refCount;           // COM reference count
            private IntPtr _innerUnknown;    // CCW for the managed object supporting aggregation
            private IntPtr _gcHandle;        // The GC Handle to the managed object (the non aggregated object)

            /// <summary>
            /// Allocates and initializes an aggregator in CoTaskMem memory.
            /// </summary>
            public static unsafe BlindAggregator* CreateInstance()
            {
                var pResult = (BlindAggregator*)Marshal.AllocCoTaskMem(sizeof(BlindAggregator));

                // Defensive only: AllocCoTaskMem reports failure by throwing OutOfMemoryException.
                if (pResult != null)
                {
                    pResult->Construct();
                }

                return pResult;
            }

            private void Construct()
            {
                _vfPtr = VTable.AddressOfVTable;
                _refCount = 1;
                _innerUnknown = IntPtr.Zero;
                _gcHandle = IntPtr.Zero;
            }

            /// <summary>
            /// Stores the inner object (taking a COM reference on it) and the owned GC handle.
            /// Replacing an existing inner object releases the reference taken on the previous one.
            /// </summary>
            public void SetInnerObject(IntPtr innerUnknown, IntPtr gcHandle)
            {
                if (innerUnknown == IntPtr.Zero)
                {
                    throw new ArgumentNullException(nameof(innerUnknown));
                }

                // AddRef the new object before releasing the old one. If the same object is set
                // again, this keeps its count above zero throughout.
                Marshal.AddRef(innerUnknown);
                if (_innerUnknown != IntPtr.Zero)
                {
                    Marshal.Release(_innerUnknown);
                }

                _innerUnknown = innerUnknown;

                // FinalRelease frees _gcHandle, so the aggregator owns it. Free a previously
                // stored, different handle rather than leaking it.
                if (_gcHandle != IntPtr.Zero && _gcHandle != gcHandle)
                {
                    GCHandle.FromIntPtr(_gcHandle).Free();
                }

                _gcHandle = gcHandle;
            }

            /// <summary>
            /// Releases everything the aggregator owns. Called when the reference count reaches zero.
            /// </summary>
            private void FinalRelease()
            {
                // This runs inside the native-callable Release, so exceptions must not escape.
                // Marshal.Release throws ArgumentNullException for IntPtr.Zero, which is the value
                // here when SetInnerObject was never called.
                if (_innerUnknown != IntPtr.Zero)
                {
                    Marshal.Release(_innerUnknown);
                    _innerUnknown = IntPtr.Zero;
                }

                if (_gcHandle != IntPtr.Zero)
                {
                    GCHandle.FromIntPtr(_gcHandle).Free();
                    _gcHandle = IntPtr.Zero;
                }
            }

            private unsafe delegate int QueryInterfaceDelegateType(BlindAggregator* pThis, [In] ref Guid riid, out IntPtr pvObject);
            private unsafe delegate uint AddRefDelegateType(BlindAggregator* pThis);
            private unsafe delegate uint ReleaseDelegateType(BlindAggregator* pThis);
            private unsafe delegate int GetGCHandlePtrDelegateType(BlindAggregator* pThis, out IntPtr pResult);

            [StructLayout(LayoutKind.Sequential)]
            private struct VTable
            {
                // Need these to keep the delegates alive.
                // NOTE: they must stay declared before s_instance. Static initializers run in textual
                // order, and s_instance's constructor reads these fields.
                private static readonly unsafe QueryInterfaceDelegateType s_queryInterface = BlindAggregator.QueryInterface;
                private static readonly unsafe AddRefDelegateType s_addRef = BlindAggregator.AddRef;
                private static readonly unsafe ReleaseDelegateType s_release = BlindAggregator.Release;
                private static readonly unsafe GetGCHandlePtrDelegateType s_get_GCHandlePtr = BlindAggregator.GetGCHandlePtr;

                // Slot order matches the vtable layout: IUnknown (QueryInterface, AddRef, Release),
                // then GetGCHandlePtr.
                private IntPtr _queryInterfacePtr;
                private IntPtr _addRefPtr;
                private IntPtr _releasePtr;
                private IntPtr _getGCHandlePtr;

                private void Construct()
                {
                    _queryInterfacePtr = Marshal.GetFunctionPointerForDelegate(VTable.s_queryInterface);
                    _addRefPtr = Marshal.GetFunctionPointerForDelegate(VTable.s_addRef);
                    _releasePtr = Marshal.GetFunctionPointerForDelegate(VTable.s_release);
                    _getGCHandlePtr = Marshal.GetFunctionPointerForDelegate(VTable.s_get_GCHandlePtr);
                }

                /// <summary>
                /// A 'holder' for a native memory allocation. The allocation is freed in the finalizer.
                /// </summary>
                private class CoTaskMemPtr
                {
                    public readonly IntPtr VTablePtr;

                    public unsafe CoTaskMemPtr()
                    {
                        var ptr = Marshal.AllocCoTaskMem(sizeof(VTable));
                        this.VTablePtr = ptr;
                        ((VTable*)ptr)->Construct();
                    }

                    ~CoTaskMemPtr()
                        => Marshal.FreeCoTaskMem(this.VTablePtr);
                }

                // Singleton instance of the VTable allocated in native memory. Since it's static, the
                // underlying native memory will be freed when finalizers run at shutdown.
                private static readonly CoTaskMemPtr s_instance = new CoTaskMemPtr();

                public static IntPtr AddressOfVTable => s_instance.VTablePtr;
            }

            private const int S_OK = 0;
            private const int E_NOINTERFACE = unchecked((int)0x80004002);

            // 00000000-0000-0000-C000-000000000046
            private static readonly Guid s_IUnknownInterfaceGuid = new Guid(0x00000000, 0x0000, 0x0000, 0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46);

            // 00000003-0000-0000-C000-000000000046
            private static readonly Guid s_IMarshalInterfaceGuid = new Guid(0x00000003, 0x0000, 0x0000, 0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46);

            // CBD71F2C-6BC5-4932-B851-B93EB3151386
            private static readonly Guid s_IComWrapperGuid = new Guid("CBD71F2C-6BC5-4932-B851-B93EB3151386");

            /// <summary>
            /// Handles the following requests:
            /// <list type="bullet">
            /// <item><c>IUnknown</c> and the wrapper interface: answered by the aggregator itself.</item>
            /// <item><c>IMarshal</c>: refused.</item>
            /// <item>Any other interface: forwarded to the inner object, or refused if none has been set yet.</item>
            /// </list>
            /// </summary>
            private static unsafe int QueryInterface(BlindAggregator* pThis, [In] ref Guid riid, out IntPtr pvObject)
            {
                if (riid == s_IUnknownInterfaceGuid || riid == s_IComWrapperGuid)
                {
                    AddRef(pThis);
                    pvObject = (IntPtr)pThis;
                    return S_OK;
                }
                else if (riid == s_IMarshalInterfaceGuid)
                {
                    pvObject = IntPtr.Zero;
                    return E_NOINTERFACE;
                }
                else if (pThis->_innerUnknown == IntPtr.Zero)
                {
                    // No inner object to forward to. Marshal.QueryInterface would throw on a null
                    // pointer, and an exception must not escape this native callback.
                    pvObject = IntPtr.Zero;
                    return E_NOINTERFACE;
                }
                else
                {
                    // We don't know what the interface is, so aggregate blindly from here
                    return Marshal.QueryInterface(pThis->_innerUnknown, ref riid, out pvObject);
                }
            }

            private static unsafe uint AddRef(BlindAggregator* pThis)
                => unchecked((uint)Interlocked.Increment(ref pThis->_refCount));

            /// <summary>
            /// Decrements the reference count. When the count reaches zero, releases owned resources
            /// and frees the aggregator's native memory.
            /// </summary>
            private static unsafe uint Release(BlindAggregator* pThis)
            {
                var result = unchecked((uint)Interlocked.Decrement(ref pThis->_refCount));
                if (result == 0u)
                {
                    pThis->FinalRelease();
                    Marshal.FreeCoTaskMem((IntPtr)pThis);
                }

                return result;
            }

            private static unsafe int GetGCHandlePtr(BlindAggregator* pThis, out IntPtr pResult)
            {
                pResult = pThis->_gcHandle;
                return S_OK;
            }
        }
    }
}