File: System\Runtime\InteropServices\IDynamicInterfaceCastable.cs
Web Access
Project: src\src\runtime\src\coreclr\nativeaot\System.Private.CoreLib\src\System.Private.CoreLib.csproj (System.Private.CoreLib)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

using Internal.Runtime;
using Internal.Runtime.Augments;
using Internal.TypeSystem;

using Debug = System.Diagnostics.Debug;

namespace System.Runtime.InteropServices
{
    public unsafe partial interface IDynamicInterfaceCastable
    {
        // Heap from which instantiating thunks are allocated. It is created over the stub address
        // returned by RuntimeImports.GetInteropCommonStubAddress(); each thunk allocated from it is
        // given a (context, target) data pair via RuntimeAugments.SetThunkData.
        private static readonly object s_thunkPoolHeap = RuntimeAugments.CreateThunksHeap(RuntimeImports.GetInteropCommonStubAddress());

        /// <summary>
        /// Resolves the code address to call for interface method <paramref name="slot"/> of
        /// <paramref name="interfaceType"/> on an object that implements <see cref="IDynamicInterfaceCastable"/>.
        /// </summary>
        /// <param name="instance">The object being dispatched on; asked for its implementation type.</param>
        /// <param name="interfaceType">The interface the call is being dispatched through.</param>
        /// <param name="slot">The interface method slot to resolve.</param>
        /// <returns>
        /// The resolved method address; or, when resolution also produced a generic context,
        /// the address of a cached thunk whose data pairs that context with the resolved address.
        /// </returns>
        /// <exception cref="InvalidCastException">
        /// <see cref="GetInterfaceImplementation"/> returned a handle with no MethodTable.
        /// </exception>
        /// <exception cref="InvalidOperationException">
        /// The returned implementation type is not an interface, lacks a dispatch map, or does not
        /// list <paramref name="interfaceType"/> among its interfaces.
        /// </exception>
        /// <exception cref="EntryPointNotFoundException">
        /// The implementation type does implement the interface but no target was resolved for the slot.
        /// </exception>
        internal static nint GetDynamicInterfaceImplementation(IDynamicInterfaceCastable instance, MethodTable* interfaceType, ushort slot)
        {
            RuntimeTypeHandle handle = instance.GetInterfaceImplementation(new RuntimeTypeHandle(interfaceType));
            MethodTable* implType = handle.ToMethodTable();
            if (implType == null)
            {
                ThrowInvalidCastException(instance, interfaceType);
            }
            if (!implType->IsInterface)
            {
                ThrowInvalidOperationException(implType);
            }

            MethodTable* genericContext = null;
            nint result = RuntimeImports.RhResolveDynamicInterfaceCastableDispatchOnType(implType, interfaceType, slot, &genericContext);
            if (result == nint.Zero)
            {
                // Always throws; picks the most specific diagnostic for why resolution failed.
                GetInterfaceImplementationFailure(instance, interfaceType, implType);
            }

            // When resolution reports a generic context, hand out a thunk bound to that context
            // instead of the raw method address.
            if (genericContext == null)
            {
                return result;
            }

            return GetOrCreateInstantiatingThunk(result, (nint)genericContext);
        }

        // Returns a thunk whose data is (context, target), reusing one already published in
        // s_thunkHashtable when available so that equivalent requests share a single thunk.
        private static nint GetOrCreateInstantiatingThunk(nint target, nint context)
        {
            if (s_thunkHashtable.TryGetValue(new InstantiatingThunkKey(target, context), out nint existing))
            {
                return existing;
            }

            nint thunk = RuntimeAugments.AllocateThunk(s_thunkPoolHeap);
            RuntimeAugments.SetThunkData(s_thunkPoolHeap, thunk, context, target);

            // If an equivalent thunk was published first (presumably by a concurrent caller),
            // AddOrGetExisting hands back that one; ours is then surplus and returned to the pool.
            nint published = s_thunkHashtable.AddOrGetExisting(thunk);
            if (published != thunk)
            {
                RuntimeAugments.FreeThunk(s_thunkPoolHeap, thunk);
            }

            return published;
        }

        // Throws when the instance did not supply any implementation type for the requested interface.
        private static void ThrowInvalidCastException(object instance, MethodTable* interfaceType)
        {
            throw new InvalidCastException(SR.Format(SR.InvalidCast_FromTo, instance.GetType(), Type.GetTypeFromMethodTable(interfaceType)));
        }

        // Throws when the instance supplied an implementation type that is not an interface.
        private static void ThrowInvalidOperationException(MethodTable* resolvedImplType)
        {
            throw new InvalidOperationException(SR.Format(SR.IDynamicInterfaceCastable_NotInterface, Type.GetTypeFromMethodTable(resolvedImplType)));
        }

        // Always throws. Called after slot resolution failed; checks, in order, for a missing dispatch
        // map, for the implementation type not listing the requested interface, and otherwise reports
        // that no entry point exists for the slot.
        private static void GetInterfaceImplementationFailure(object instance, MethodTable* interfaceType, MethodTable* resolvedImplType)
        {
            if (resolvedImplType->DispatchMap == null)
                throw new InvalidOperationException(SR.Format(SR.IDynamicInterfaceCastable_MissingImplementationAttribute, Type.GetTypeFromMethodTable(resolvedImplType), nameof(DynamicInterfaceCastableImplementationAttribute)));

            bool implementsInterface = false;
            var interfaces = resolvedImplType->InterfaceMap;
            for (int i = 0; i < resolvedImplType->NumInterfaces; i++)
            {
                if (interfaces[i] == interfaceType)
                {
                    implementsInterface = true;
                    break;
                }
            }

            if (!implementsInterface)
                throw new InvalidOperationException(SR.Format(SR.IDynamicInterfaceCastable_DoesNotImplementRequested, Type.GetTypeFromMethodTable(resolvedImplType), Type.GetTypeFromMethodTable(interfaceType)));

            throw new EntryPointNotFoundException();
        }

        // Cache of instantiating thunks. Values are thunk addresses; identity is defined by the
        // (context, target) data stored in each thunk rather than by the address itself.
        private static readonly InstantiatingThunkHashtable s_thunkHashtable = new InstantiatingThunkHashtable();

        private sealed class InstantiatingThunkHashtable : LockFreeReaderHashtableOfPointers<InstantiatingThunkKey, nint>
        {
            // A key matches a thunk when the thunk's stored (context, target) equals the key's pair.
            protected override bool CompareKeyToValue(InstantiatingThunkKey key, nint value)
            {
                bool result = RuntimeAugments.TryGetThunkData(s_thunkPoolHeap, value, out nint context, out nint target);
                Debug.Assert(result);
                return key.Target == target && key.Context == context;
            }

            // Two thunks are equivalent when their stored (context, target) pairs are equal.
            protected override bool CompareValueToValue(nint value1, nint value2)
            {
                bool result1 = RuntimeAugments.TryGetThunkData(s_thunkPoolHeap, value1, out nint context1, out nint target1);
                Debug.Assert(result1);

                bool result2 = RuntimeAugments.TryGetThunkData(s_thunkPoolHeap, value2, out nint context2, out nint target2);
                Debug.Assert(result2);
                return context1 == context2 && target1 == target2;
            }

            protected override nint ConvertIntPtrToValue(nint pointer) => pointer;
            protected override nint ConvertValueToIntPtr(nint value) => value;

            // This file only populates the table through AddOrGetExisting with pre-built thunks,
            // so creating a value from a key is never expected.
            protected override nint CreateValueFromKey(InstantiatingThunkKey key) => throw new NotImplementedException();

            // Must agree with GetValueHashCode: both hash (target, context) in the same order.
            protected override int GetKeyHashCode(InstantiatingThunkKey key) => HashCode.Combine(key.Target, key.Context);

            protected override int GetValueHashCode(nint value)
            {
                bool result = RuntimeAugments.TryGetThunkData(s_thunkPoolHeap, value, out nint context, out nint target);
                Debug.Assert(result);
                return HashCode.Combine(target, context);
            }
        }

        // Lookup key for s_thunkHashtable: the resolved method address plus its generic context.
        private readonly struct InstantiatingThunkKey
        {
            public readonly nint Target;
            public readonly nint Context;
            public InstantiatingThunkKey(nint target, nint context) => (Target, Context) = (target, context);
        }
    }
}