File: ServiceLookup\CallSiteRuntimeResolver.cs
Web Access
Project: src\src\libraries\Microsoft.Extensions.DependencyInjection\src\Microsoft.Extensions.DependencyInjection.csproj (Microsoft.Extensions.DependencyInjection)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Runtime.ExceptionServices;
using System.Threading;
 
namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
{
    /// <summary>
    /// Resolves services by interpreting a <see cref="ServiceCallSite"/> graph at runtime: constructors are
    /// invoked through reflection, factories are called directly, and cached call sites are looked up in
    /// (or added to) the appropriate scope. The resolver keeps no per-resolution fields, so one shared
    /// <see cref="Instance"/> is used; per-resolution state travels in <see cref="RuntimeResolverContext"/>.
    /// </summary>
    internal sealed class CallSiteRuntimeResolver : CallSiteVisitor<RuntimeResolverContext, object?>
    {
        /// <summary>Gets the shared resolver instance.</summary>
        public static CallSiteRuntimeResolver Instance { get; } = new();

        private CallSiteRuntimeResolver()
        {
        }

        /// <summary>
        /// Resolves the service described by <paramref name="callSite"/> in <paramref name="scope"/>.
        /// </summary>
        /// <param name="callSite">The call site describing how to produce the service.</param>
        /// <param name="scope">The scope to resolve in; it becomes <see cref="RuntimeResolverContext.Scope"/>.</param>
        /// <returns>The resolved service instance, which may be <see langword="null"/>.</returns>
        public object? Resolve(ServiceCallSite callSite, ServiceProviderEngineScope scope)
        {
            // Fast path to avoid virtual calls if we already have the cached value in the root scope
            if (scope.IsRootScope && callSite.Value is object cached)
            {
                return cached;
            }

            return VisitCallSite(callSite, new RuntimeResolverContext
            {
                Scope = scope
            });
        }

        /// <summary>
        /// Builds the service for <paramref name="transientCallSite"/> and passes it to the current scope's
        /// <c>CaptureDisposable</c>, returning that call's result.
        /// </summary>
        protected override object? VisitDisposeCache(ServiceCallSite transientCallSite, RuntimeResolverContext context)
        {
            // NOTE(review): presumably CaptureDisposable records disposable instances so the scope can dispose them later.
            return context.Scope.CaptureDisposable(VisitCallSiteMain(transientCallSite, context));
        }

        /// <summary>
        /// Resolves each entry of <c>ParameterCallSites</c> (in array order) and invokes the constructor through reflection.
        /// </summary>
        /// <remarks>
        /// Exceptions thrown by the constructor itself surface unwrapped on every target framework. Newer
        /// frameworks use <c>BindingFlags.DoNotWrapExceptions</c>. .NET Framework and .NET Standard 2.0 rethrow
        /// the inner exception of the <see cref="TargetInvocationException"/> and preserve its original stack trace.
        /// </remarks>
        protected override object VisitConstructor(ConstructorCallSite constructorCallSite, RuntimeResolverContext context)
        {
            object?[] parameterValues;
            if (constructorCallSite.ParameterCallSites.Length == 0)
            {
                parameterValues = Array.Empty<object?>();
            }
            else
            {
                parameterValues = new object?[constructorCallSite.ParameterCallSites.Length];
                for (int index = 0; index < parameterValues.Length; index++)
                {
                    parameterValues[index] = VisitCallSite(constructorCallSite.ParameterCallSites[index], context);
                }
            }

#if NETFRAMEWORK || NETSTANDARD2_0
            try
            {
                return constructorCallSite.ConstructorInfo.Invoke(parameterValues);
            }
            catch (TargetInvocationException ex) when (ex.InnerException != null)
            {
                // Only TargetInvocationException is the reflection wrapper around the constructor's own exception.
                // Any other exception (even one carrying an InnerException) is not a wrapper, so it propagates as-is.
                // This matches the DoNotWrapExceptions behavior of the branch below.
                ExceptionDispatchInfo.Capture(ex.InnerException).Throw();
                // The above line will always throw, but the compiler requires we throw explicitly.
                throw;
            }
#else
            return constructorCallSite.ConstructorInfo.Invoke(BindingFlags.DoNotWrapExceptions, binder: null, parameters: parameterValues, culture: null);
#endif
        }

        /// <summary>
        /// Returns the value cached on <paramref name="callSite"/>, creating it in the root scope on first use.
        /// </summary>
        /// <remarks>
        /// Creation happens under a lock on the call site object, and the cached value is re-checked after
        /// entering it. A newly created value is passed to the root scope's <c>CaptureDisposable</c> and then
        /// stored in <c>callSite.Value</c>.
        /// </remarks>
        protected override object? VisitRootCache(ServiceCallSite callSite, RuntimeResolverContext context)
        {
            if (callSite.Value is object value)
            {
                // Value already calculated, return it directly
                return value;
            }

            var lockType = RuntimeResolverLock.Root;
            ServiceProviderEngineScope serviceProviderEngine = context.Scope.RootProvider.Root;

            lock (callSite)
            {
                // Lock the callsite and check if another thread already cached the value
                if (callSite.Value is object callSiteValue)
                {
                    return callSiteValue;
                }

                object? resolved = VisitCallSiteMain(callSite, new RuntimeResolverContext
                {
                    Scope = serviceProviderEngine,
                    AcquiredLocks = context.AcquiredLocks | lockType
                });
                serviceProviderEngine.CaptureDisposable(resolved);
                callSite.Value = resolved;
                return resolved;
            }
        }

        /// <summary>
        /// Resolves a scoped call site. When the current scope is the root scope, this goes through
        /// <see cref="VisitRootCache"/>. Otherwise the current scope's own cache is used.
        /// </summary>
        protected override object? VisitScopeCache(ServiceCallSite callSite, RuntimeResolverContext context)
        {
            // Check if we are in the situation where scoped service was promoted to singleton
            // and we need to lock the root
            return context.Scope.IsRootScope ?
                VisitRootCache(callSite, context) :
                VisitCache(callSite, context, context.Scope, RuntimeResolverLock.Scope);
        }

        /// <summary>
        /// Looks up <c>callSite.Cache.Key</c> in <paramref name="serviceProviderEngine"/>'s <c>ResolvedServices</c>.
        /// If the key is missing, it creates the service, captures it, and adds it under that key.
        /// </summary>
        /// <param name="callSite">The call site to resolve.</param>
        /// <param name="context">The current resolution context; its <c>AcquiredLocks</c> tells whether the lock is already held.</param>
        /// <param name="serviceProviderEngine">The scope whose cache and <c>Sync</c> object are used.</param>
        /// <param name="lockType">The lock flag that represents holding <c>serviceProviderEngine.Sync</c>.</param>
        private object? VisitCache(ServiceCallSite callSite, RuntimeResolverContext context, ServiceProviderEngineScope serviceProviderEngine, RuntimeResolverLock lockType)
        {
            bool lockTaken = false;
            object sync = serviceProviderEngine.Sync;
            Dictionary<ServiceCacheKey, object?> resolvedServices = serviceProviderEngine.ResolvedServices;

            try
            {
                // Taking locks only once allows us to fork resolution process
                // on another thread without causing the deadlock because we
                // always know that we are going to wait the other thread to finish before
                // releasing the lock
                if ((context.AcquiredLocks & lockType) == 0)
                {
                    // Entered inside the try so that lockTaken and the finally below always agree,
                    // even if an asynchronous exception interrupts us right after acquisition.
                    Monitor.Enter(sync, ref lockTaken);
                }

                // At this point sync is held, either acquired above or already held by the caller
                // (signalled through AcquiredLocks). It guards both the dictionary access and the resolution.
                if (resolvedServices.TryGetValue(callSite.Cache.Key, out object? resolved))
                {
                    return resolved;
                }

                resolved = VisitCallSiteMain(callSite, new RuntimeResolverContext
                {
                    Scope = serviceProviderEngine,
                    AcquiredLocks = context.AcquiredLocks | lockType
                });
                serviceProviderEngine.CaptureDisposable(resolved);
                resolvedServices.Add(callSite.Cache.Key, resolved);
                return resolved;
            }
            finally
            {
                if (lockTaken)
                {
                    Monitor.Exit(sync);
                }
            }
        }

        /// <summary>Returns the call site's <c>DefaultValue</c>.</summary>
        protected override object? VisitConstant(ConstantCallSite constantCallSite, RuntimeResolverContext context)
        {
            return constantCallSite.DefaultValue;
        }

        /// <summary>Returns the current scope (<c>context.Scope</c>) itself.</summary>
        protected override object VisitServiceProvider(ServiceProviderCallSite serviceProviderCallSite, RuntimeResolverContext context)
        {
            return context.Scope;
        }

        /// <summary>
        /// Resolves every entry of <c>ServiceCallSites</c> in order into a new array whose element type is <c>ItemType</c>.
        /// </summary>
        protected override object VisitIEnumerable(IEnumerableCallSite enumerableCallSite, RuntimeResolverContext context)
        {
            Array array = CreateArray(
                enumerableCallSite.ItemType,
                enumerableCallSite.ServiceCallSites.Length);

            for (int index = 0; index < enumerableCallSite.ServiceCallSites.Length; index++)
            {
                object? value = VisitCallSite(enumerableCallSite.ServiceCallSites[index], context);
                array.SetValue(value, index);
            }
            return array;

            [UnconditionalSuppressMessage("AotAnalysis", "IL3050:RequiresDynamicCode",
                Justification = "VerifyAotCompatibility ensures elementType is not a ValueType")]
            static Array CreateArray(Type elementType, int length)
            {
                Debug.Assert(!ServiceProvider.VerifyAotCompatibility || !elementType.IsValueType, "VerifyAotCompatibility=true will throw during building the IEnumerableCallSite if elementType is a ValueType.");

                return Array.CreateInstance(elementType, length);
            }
        }

        /// <summary>Invokes the call site's factory delegate with the current scope.</summary>
        protected override object VisitFactory(FactoryCallSite factoryCallSite, RuntimeResolverContext context)
        {
            return factoryCallSite.Factory(context.Scope);
        }
    }
 
    /// <summary>
    /// Per-resolution state threaded through the visitor methods of <see cref="CallSiteRuntimeResolver"/>.
    /// </summary>
    internal struct RuntimeResolverContext
    {
        private ServiceProviderEngineScope _scope;
        private RuntimeResolverLock _acquiredLocks;

        /// <summary>Gets or sets the scope that services are currently being resolved in.</summary>
        public ServiceProviderEngineScope Scope
        {
            readonly get => _scope;
            set => _scope = value;
        }

        /// <summary>Gets or sets the resolver lock flags already held by the current resolution chain.</summary>
        public RuntimeResolverLock AcquiredLocks
        {
            readonly get => _acquiredLocks;
            set => _acquiredLocks = value;
        }
    }
 
    /// <summary>
    /// Bit flags recording which resolver locks the current resolution chain already holds.
    /// Values can be combined with <c>|</c> and tested with <c>&amp;</c>.
    /// </summary>
    [Flags]
    internal enum RuntimeResolverLock
    {
        /// <summary>No resolver lock is held. This is also the default value of the type.</summary>
        None = 0,

        /// <summary>Set while a scope's synchronization lock is held to resolve a scoped service.</summary>
        Scope = 1,

        /// <summary>Set while a root-cached call site is being resolved under its call-site lock.</summary>
        Root = 2
    }
}