// File: Infrastructure\ControllerActionInvoker.cs
// Web Access
// Project: src\src\Mvc\Mvc.Core\src\Microsoft.AspNetCore.Mvc.Core.csproj (Microsoft.AspNetCore.Mvc.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
#nullable enable
 
using System.Diagnostics;
using System.Globalization;
using System.Runtime.ExceptionServices;
using Microsoft.AspNetCore.Mvc.Abstractions;
using Microsoft.AspNetCore.Mvc.Filters;
using Microsoft.AspNetCore.Mvc.ModelBinding;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging;
using Resources = Microsoft.AspNetCore.Mvc.Core.Resources;
 
namespace Microsoft.AspNetCore.Mvc.Infrastructure;
 
#pragma warning disable CA1852 // Seal internal types
internal partial class ControllerActionInvoker : ResourceInvoker, IActionInvoker
#pragma warning restore CA1852 // Seal internal types
{
    // Per-action cache entry; this class reads its controller factory/releaser, binder delegate,
    // and method executors from it.
    private readonly ControllerActionInvokerCacheEntry _cacheEntry;
    // Context of the current invocation; also exposed through the ControllerContext property.
    private readonly ControllerContext _controllerContext;

    // Action arguments keyed by name (case-insensitive); allocated in the ActionBegin state.
    private Dictionary<string, object?>? _arguments;

    // Created lazily the first time an action filter is reached in the ActionNext state.
    private ActionExecutingContextSealed? _actionExecutingContext;
    // Assigned when a filter short-circuits, when the nested pipeline throws, or at ActionEnd
    // while running in Scope.Action.
    private ActionExecutedContextSealed? _actionExecutedContext;
 
    /// <summary>
    /// Creates an invoker for a single controller action invocation.
    /// </summary>
    /// <param name="logger">Forwarded to the <see cref="ResourceInvoker"/> constructor.</param>
    /// <param name="diagnosticListener">Forwarded to the <see cref="ResourceInvoker"/> constructor.</param>
    /// <param name="actionContextAccessor">Forwarded to the <see cref="ResourceInvoker"/> constructor.</param>
    /// <param name="mapper">Forwarded to the <see cref="ResourceInvoker"/> constructor.</param>
    /// <param name="controllerContext">The context of the action being invoked; its value provider factories are also handed to the base constructor.</param>
    /// <param name="cacheEntry">The cached delegates and executors for the action.</param>
    /// <param name="filters">Forwarded to the <see cref="ResourceInvoker"/> constructor.</param>
    /// <exception cref="ArgumentNullException"><paramref name="cacheEntry"/> is <see langword="null"/>.</exception>
    internal ControllerActionInvoker(
        ILogger logger,
        DiagnosticListener diagnosticListener,
        IActionContextAccessor actionContextAccessor,
        IActionResultTypeMapper mapper,
        ControllerContext controllerContext,
        ControllerActionInvokerCacheEntry cacheEntry,
        IFilterMetadata[] filters)
        : base(
            diagnosticListener,
            logger,
            actionContextAccessor,
            mapper,
            controllerContext,
            filters,
            controllerContext.ValueProviderFactories)
    {
        ArgumentNullException.ThrowIfNull(cacheEntry);

        _controllerContext = controllerContext;
        _cacheEntry = cacheEntry;
    }
 
    // Internal for testing: exposes the context this invoker was constructed with.
    internal ControllerContext ControllerContext
    {
        get => _controllerContext;
    }
 
    /// <summary>
    /// Hands the controller instance to the cached releaser, when both are available.
    /// </summary>
    /// <returns>
    /// The <see cref="ValueTask"/> produced by the releaser, or a default (completed)
    /// <see cref="ValueTask"/> when there is nothing to release.
    /// </returns>
    protected override ValueTask ReleaseResources()
    {
        var controller = _instance;
        var releaser = _cacheEntry.ControllerReleaser;

        // Nothing to do if no controller was created or the cache entry carries no releaser.
        if (controller == null || releaser == null)
        {
            return default;
        }

        return releaser(_controllerContext, controller);
    }
 
    /// <summary>
    /// Runs the action-filter state machine starting at <paramref name="next"/> and continues
    /// through as many states as can complete synchronously.
    /// </summary>
    /// <param name="next">The state to execute. Before a task that has not run to completion is returned, this is set to the state to resume at.</param>
    /// <param name="scope"><c>Scope.Invoker</c> when driven by <c>InvokeInnerFilterAsync</c>; <c>Scope.Action</c> when driven by <c>InvokeNextActionFilterAsync</c>.</param>
    /// <param name="state">The filter currently being executed; set in <c>ActionNext</c> and read back in the begin/end states.</param>
    /// <param name="isCompleted">Set to <see langword="true"/> when <c>ActionEnd</c> finishes.</param>
    /// <returns>
    /// The task of an operation whose status was not <see cref="TaskStatus.RanToCompletion"/> (the caller
    /// awaits it and calls this method again), or <see cref="Task.CompletedTask"/> once <c>ActionEnd</c> has run.
    /// </returns>
    /// <exception cref="InvalidOperationException"><paramref name="next"/> is not a known state.</exception>
    private Task Next(ref State next, ref Scope scope, ref object? state, ref bool isCompleted)
    {
        switch (next)
        {
            // Create the controller, allocate the argument dictionary, and bind arguments.
            case State.ActionBegin:
                {
                    var controllerContext = _controllerContext;

                    _cursor.Reset();
                    Log.ExecutingControllerFactory(_logger, controllerContext);

                    _instance = _cacheEntry.ControllerFactory(controllerContext);
                    Log.ExecutedControllerFactory(_logger, controllerContext);

                    _arguments = new Dictionary<string, object?>(StringComparer.OrdinalIgnoreCase);

                    var task = BindArgumentsAsync();
                    if (task.Status != TaskStatus.RanToCompletion)
                    {
                        // Binding is still running (or failed): let the caller await it, then resume at ActionNext.
                        next = State.ActionNext;
                        return task;
                    }

                    goto case State.ActionNext;
                }

            // Dispatch on the next filter from the cursor: async filter, sync filter, or,
            // when neither is present, the action method itself.
            case State.ActionNext:
                {
                    var current = _cursor.GetNextFilter<IActionFilter, IAsyncActionFilter>();
                    if (current.FilterAsync != null)
                    {
                        // The executing context is created once and reused for every filter.
                        if (_actionExecutingContext == null)
                        {
                            _actionExecutingContext = new ActionExecutingContextSealed(_controllerContext, _filters, _arguments!, _instance!);
                        }

                        state = current.FilterAsync;
                        goto case State.ActionAsyncBegin;
                    }
                    else if (current.Filter != null)
                    {
                        if (_actionExecutingContext == null)
                        {
                            _actionExecutingContext = new ActionExecutingContextSealed(_controllerContext, _filters, _arguments!, _instance!);
                        }

                        state = current.Filter;
                        goto case State.ActionSyncBegin;
                    }
                    else
                    {
                        goto case State.ActionInside;
                    }
                }

            // Invoke an async action filter, passing InvokeNextActionFilterAwaitedAsync as its 'next' delegate.
            case State.ActionAsyncBegin:
                {
                    Debug.Assert(state != null);
                    Debug.Assert(_actionExecutingContext != null);

                    var filter = (IAsyncActionFilter)state;
                    var actionExecutingContext = _actionExecutingContext;

                    _diagnosticListener.BeforeOnActionExecution(actionExecutingContext, filter);
                    _logger.BeforeExecutingMethodOnFilter(
                        MvcCoreLoggerExtensions.ActionFilter,
                        nameof(IAsyncActionFilter.OnActionExecutionAsync),
                        filter);

                    var task = filter.OnActionExecutionAsync(actionExecutingContext, InvokeNextActionFilterAwaitedAsync);
                    if (task.Status != TaskStatus.RanToCompletion)
                    {
                        next = State.ActionAsyncEnd;
                        return task;
                    }

                    goto case State.ActionAsyncEnd;
                }

            // The async filter has finished; synthesize a canceled executed context if it never called 'next'.
            case State.ActionAsyncEnd:
                {
                    Debug.Assert(state != null);
                    Debug.Assert(_actionExecutingContext != null);

                    var filter = (IAsyncActionFilter)state;

                    if (_actionExecutedContext == null)
                    {
                        // If we get here then the filter didn't call 'next' indicating a short circuit.
                        _logger.ActionFilterShortCircuited(filter);

                        _actionExecutedContext = new ActionExecutedContextSealed(
                            _controllerContext,
                            _filters,
                            _instance!)
                        {
                            Canceled = true,
                            Result = _actionExecutingContext.Result,
                        };
                    }

                    _diagnosticListener.AfterOnActionExecution(_actionExecutedContext, filter);
                    _logger.AfterExecutingMethodOnFilter(
                        MvcCoreLoggerExtensions.ActionFilter,
                        nameof(IAsyncActionFilter.OnActionExecutionAsync),
                        filter);

                    goto case State.ActionEnd;
                }

            // Call OnActionExecuting on a sync filter, then run the rest of the pipeline unless it set a result.
            case State.ActionSyncBegin:
                {
                    Debug.Assert(state != null);
                    Debug.Assert(_actionExecutingContext != null);

                    var filter = (IActionFilter)state;
                    var actionExecutingContext = _actionExecutingContext;

                    _diagnosticListener.BeforeOnActionExecuting(actionExecutingContext, filter);
                    _logger.BeforeExecutingMethodOnFilter(
                        MvcCoreLoggerExtensions.ActionFilter,
                        nameof(IActionFilter.OnActionExecuting),
                        filter);

                    filter.OnActionExecuting(actionExecutingContext);

                    _diagnosticListener.AfterOnActionExecuting(actionExecutingContext, filter);
                    _logger.AfterExecutingMethodOnFilter(
                        MvcCoreLoggerExtensions.ActionFilter,
                        nameof(IActionFilter.OnActionExecuting),
                        filter);

                    if (actionExecutingContext.Result != null)
                    {
                        // Short-circuited by setting a result.
                        _logger.ActionFilterShortCircuited(filter);

                        _actionExecutedContext = new ActionExecutedContextSealed(
                            _actionExecutingContext,
                            _filters,
                            _instance!)
                        {
                            Canceled = true,
                            Result = _actionExecutingContext.Result,
                        };

                        // Skips ActionSyncEnd, so OnActionExecuted is not called on this filter.
                        goto case State.ActionEnd;
                    }

                    var task = InvokeNextActionFilterAsync();
                    if (task.Status != TaskStatus.RanToCompletion)
                    {
                        next = State.ActionSyncEnd;
                        return task;
                    }

                    goto case State.ActionSyncEnd;
                }

            // The rest of the pipeline has run; call OnActionExecuted on the sync filter.
            case State.ActionSyncEnd:
                {
                    Debug.Assert(state != null);
                    Debug.Assert(_actionExecutingContext != null);
                    Debug.Assert(_actionExecutedContext != null);

                    var filter = (IActionFilter)state;
                    var actionExecutedContext = _actionExecutedContext;

                    _diagnosticListener.BeforeOnActionExecuted(actionExecutedContext, filter);
                    _logger.BeforeExecutingMethodOnFilter(
                        MvcCoreLoggerExtensions.ActionFilter,
                        nameof(IActionFilter.OnActionExecuted),
                        filter);

                    filter.OnActionExecuted(actionExecutedContext);

                    _diagnosticListener.AfterOnActionExecuted(actionExecutedContext, filter);
                    _logger.AfterExecutingMethodOnFilter(
                        MvcCoreLoggerExtensions.ActionFilter,
                        nameof(IActionFilter.OnActionExecuted),
                        filter);

                    goto case State.ActionEnd;
                }

            // No filters remain: invoke the action method itself.
            case State.ActionInside:
                {
                    var task = InvokeActionMethodAsync();
                    if (task.Status != TaskStatus.RanToCompletion)
                    {
                        next = State.ActionEnd;
                        return task;
                    }

                    goto case State.ActionEnd;
                }

            case State.ActionEnd:
                {
                    if (scope == Scope.Action)
                    {
                        // Running on behalf of an enclosing filter: make sure an executed context
                        // exists for it to observe, carrying the action's result.
                        if (_actionExecutedContext == null)
                        {
                            _actionExecutedContext = new ActionExecutedContextSealed(_controllerContext, _filters, _instance!)
                            {
                                Result = _result,
                            };
                        }

                        isCompleted = true;
                        return Task.CompletedTask;
                    }

                    // Outermost (invoker) scope: surface any unhandled exception, then publish the
                    // result held by the executed context, if there is one.
                    var actionExecutedContext = _actionExecutedContext;
                    Rethrow(actionExecutedContext);

                    if (actionExecutedContext != null)
                    {
                        _result = actionExecutedContext.Result;
                    }

                    isCompleted = true;
                    return Task.CompletedTask;
                }

            default:
                throw new InvalidOperationException();
        }
    }
 
    /// <summary>
    /// Drives the state machine from <c>State.ActionNext</c> in <c>Scope.Action</c>, i.e. runs the
    /// remaining action filters and the action on behalf of the filter that is currently executing.
    /// </summary>
    /// <remarks>
    /// Exceptions thrown while running the remaining pipeline are not propagated. They are captured
    /// into a fresh <c>_actionExecutedContext</c> instead.
    /// </remarks>
    /// <returns>A task that completes when the remaining pipeline has finished.</returns>
    private Task InvokeNextActionFilterAsync()
    {
        try
        {
            State next = State.ActionNext;
            Scope scope = Scope.Action;
            object? state = null;
            bool isCompleted = false;

            // isCompleted starts out false, so the body always runs at least once.
            do
            {
                Task pending = Next(ref next, ref scope, ref state, ref isCompleted);
                if (!pending.IsCompletedSuccessfully)
                {
                    // Could not finish synchronously: the async helper takes over from here.
                    return Awaited(this, pending, next, scope, state, isCompleted);
                }
            }
            while (!isCompleted);
        }
        catch (Exception ex)
        {
            _actionExecutedContext = new ActionExecutedContextSealed(_controllerContext, _filters, _instance!)
            {
                ExceptionDispatchInfo = ExceptionDispatchInfo.Capture(ex),
            };
        }

        Debug.Assert(_actionExecutedContext != null);
        return Task.CompletedTask;

        // Asynchronous continuation of the loop above, with the same exception capture.
        static async Task Awaited(ControllerActionInvoker invoker, Task pending, State next, Scope scope, object? state, bool isCompleted)
        {
            try
            {
                await pending;

                while (!isCompleted)
                {
                    await invoker.Next(ref next, ref scope, ref state, ref isCompleted);
                }
            }
            catch (Exception ex)
            {
                invoker._actionExecutedContext = new ActionExecutedContextSealed(invoker._controllerContext, invoker._filters, invoker._instance!)
                {
                    ExceptionDispatchInfo = ExceptionDispatchInfo.Capture(ex),
                };
            }

            Debug.Assert(invoker._actionExecutedContext != null);
        }
    }
 
    /// <summary>
    /// The 'next' delegate handed to async action filters: runs the remaining pipeline and
    /// returns the resulting <see cref="ActionExecutedContext"/>.
    /// </summary>
    /// <returns>
    /// A task yielding <c>_actionExecutedContext</c>, or a faulted task carrying an
    /// <see cref="InvalidOperationException"/> when the filter had already set a result.
    /// </returns>
    private Task<ActionExecutedContext> InvokeNextActionFilterAwaitedAsync()
    {
        Debug.Assert(_actionExecutingContext != null);
        if (_actionExecutingContext.Result != null)
        {
            // If we get here, it means that an async filter set a result AND called next(). This is forbidden.
            return Throw();
        }

        var pending = InvokeNextActionFilterAsync();
        if (pending.IsCompletedSuccessfully)
        {
            // Synchronous completion: no async state machine is needed.
            Debug.Assert(_actionExecutedContext != null);
            return Task.FromResult<ActionExecutedContext>(_actionExecutedContext);
        }

        return Awaited(this, pending);

        static async Task<ActionExecutedContext> Awaited(ControllerActionInvoker invoker, Task task)
        {
            await task;

            Debug.Assert(invoker._actionExecutedContext != null);
            return invoker._actionExecutedContext;
        }

        // Deliberately 'async' without an await, so the exception is delivered through the returned task.
#pragma warning disable CS1998
        static async Task<ActionExecutedContext> Throw()
        {
            var errorMessage = Resources.FormatAsyncActionFilter_InvalidShortCircuit(
                typeof(IAsyncActionFilter).Name,
                nameof(ActionExecutingContext.Result),
                typeof(ActionExecutingContext).Name,
                typeof(ActionExecutionDelegate).Name);

            throw new InvalidOperationException(errorMessage);
        }
#pragma warning restore CS1998
    }
 
    /// <summary>
    /// Executes the action method with the prepared arguments and stores the outcome in <c>_result</c>.
    /// </summary>
    /// <remarks>
    /// When the diagnostic listener is enabled or Trace logging is on, the instrumented
    /// <c>Logged</c> path is used. Otherwise the method avoids an async state machine
    /// whenever the action completes synchronously.
    /// </remarks>
    /// <returns>A task that completes once <c>_result</c> has been assigned.</returns>
    private Task InvokeActionMethodAsync()
    {
        // NOTE(review): this gate checks Trace, while Log.ActionMethodExecuting/ActionMethodExecuted
        // emit at Information. Confirm that skipping those messages above Trace is intentional.
        if (_diagnosticListener.IsEnabled() || _logger.IsEnabled(LogLevel.Trace))
        {
            return Logged(this);
        }

        var objectMethodExecutor = _cacheEntry.ObjectMethodExecutor;
        var actionMethodExecutor = _cacheEntry.ActionMethodExecutor;
        var orderedArguments = PrepareArguments(_arguments, objectMethodExecutor);

        var actionResultValueTask = actionMethodExecutor.Execute(ControllerContext, _mapper, objectMethodExecutor, _instance!, orderedArguments);
        if (actionResultValueTask.IsCompletedSuccessfully)
        {
            _result = actionResultValueTask.Result;
        }
        else
        {
            return Awaited(this, actionResultValueTask);
        }

        return Task.CompletedTask;

        // Slow path for actions that did not complete synchronously.
        static async Task Awaited(ControllerActionInvoker invoker, ValueTask<IActionResult> actionResultValueTask)
        {
            invoker._result = await actionResultValueTask;
        }

        // Instrumented path: emits diagnostic events and log messages around the action call.
        static async Task Logged(ControllerActionInvoker invoker)
        {
            var controllerContext = invoker._controllerContext;
            var objectMethodExecutor = invoker._cacheEntry.ObjectMethodExecutor;
            var controller = invoker._instance;
            var arguments = invoker._arguments;
            var actionMethodExecutor = invoker._cacheEntry.ActionMethodExecutor;
            var orderedArguments = PrepareArguments(arguments, objectMethodExecutor);

            var diagnosticListener = invoker._diagnosticListener;
            var logger = invoker._logger;

            IActionResult? result = null;
            try
            {
                diagnosticListener.BeforeControllerActionMethod(
                    controllerContext,
                    arguments,
                    controller);
                Log.ActionMethodExecuting(logger, controllerContext, orderedArguments);
                var stopwatch = ValueStopwatch.StartNew();
                var actionResultValueTask = actionMethodExecutor.Execute(controllerContext, invoker._mapper, objectMethodExecutor, controller!, orderedArguments);
                if (actionResultValueTask.IsCompletedSuccessfully)
                {
                    result = actionResultValueTask.Result;
                }
                else
                {
                    result = await actionResultValueTask;
                }

                invoker._result = result;
                Log.ActionMethodExecuted(logger, controllerContext, result, stopwatch.GetElapsedTime());
            }
            finally
            {
                // The 'After' event fires even when the action throws; result is then still null.
                // Pass the controller instance (not the context) as the controller argument,
                // mirroring the 'Before' event above.
                diagnosticListener.AfterControllerActionMethod(
                    controllerContext,
                    arguments,
                    controller!,
                    result);
            }
        }
    }
 
    /// <summary>
    /// Entry point for the action-filter stage: drives the state machine from
    /// <c>State.ActionBegin</c> in <c>Scope.Invoker</c> until it reports completion.
    /// </summary>
    /// <remarks><see cref="ResourceInvoker.InvokeFilterPipelineAsync"/> for details on what the
    /// variables in this method represent.</remarks>
    /// <returns>
    /// A task for the whole stage. Exceptions thrown synchronously are returned as a faulted task
    /// instead of being thrown to the caller.
    /// </returns>
    protected override Task InvokeInnerFilterAsync()
    {
        try
        {
            State next = State.ActionBegin;
            Scope scope = Scope.Invoker;
            object? state = null;
            bool isCompleted = false;

            // isCompleted starts out false, so the body always runs at least once.
            do
            {
                Task pending = Next(ref next, ref scope, ref state, ref isCompleted);
                if (!pending.IsCompletedSuccessfully)
                {
                    // Could not finish synchronously: the async helper takes over from here.
                    return Awaited(this, pending, next, scope, state, isCompleted);
                }
            }
            while (!isCompleted);

            return Task.CompletedTask;
        }
        catch (Exception failure)
        {
            // Wrap non task-wrapped exceptions in a Task,
            // as this isn't done automatically since the method is not async.
            return Task.FromException(failure);
        }

        static async Task Awaited(ControllerActionInvoker invoker, Task pending, State next, Scope scope, object? state, bool isCompleted)
        {
            await pending;

            while (!isCompleted)
            {
                await invoker.Next(ref next, ref scope, ref state, ref isCompleted);
            }
        }
    }
 
    /// <summary>
    /// Throws the exception recorded on <paramref name="context"/>, unless there is none or it
    /// has been marked as handled.
    /// </summary>
    /// <param name="context">The executed context to inspect; may be <see langword="null"/>.</param>
    private static void Rethrow(ActionExecutedContextSealed? context)
    {
        // Nothing to surface without a context, or when the exception was marked as handled.
        if (context is null || context.ExceptionHandled)
        {
            return;
        }

        // The captured dispatch info is tried first; the plain Exception is the fallback.
        if (context.ExceptionDispatchInfo is { } dispatchInfo)
        {
            dispatchInfo.Throw();
        }

        if (context.Exception is { } exception)
        {
            throw exception;
        }
    }
 
    /// <summary>
    /// Runs the cached binder delegate for the controller instance and the argument dictionary,
    /// when the action declares anything to bind.
    /// </summary>
    /// <returns>
    /// <see cref="Task.CompletedTask"/> when there are no bound properties and no parameters;
    /// otherwise the task returned by the binder delegate.
    /// </returns>
    private Task BindArgumentsAsync()
    {
        // Perf: Avoid allocating async state machines where possible. We only need the state
        // machine if you need to bind properties or arguments.
        var descriptor = _controllerContext.ActionDescriptor;
        var hasBindingWork = descriptor.BoundProperties.Count != 0 || descriptor.Parameters.Count != 0;
        if (!hasBindingWork)
        {
            return Task.CompletedTask;
        }

        var binder = _cacheEntry.ControllerBinderDelegate;
        Debug.Assert(binder != null);
        return binder(_controllerContext, _instance!, _arguments!);
    }
 
    /// <summary>
    /// Builds the positional argument array for the action method from the named argument dictionary.
    /// </summary>
    /// <param name="actionParameters">Arguments keyed by parameter name.</param>
    /// <param name="actionMethodExecutor">Supplies the method's declared parameters and their default values.</param>
    /// <returns>
    /// <see langword="null"/> when the method declares no parameters; otherwise one entry per declared
    /// parameter, using the parameter's default value where the dictionary has no entry or a
    /// <see langword="null"/> entry.
    /// </returns>
    private static object?[]? PrepareArguments(
        IDictionary<string, object?>? actionParameters,
        ObjectMethodExecutor actionMethodExecutor)
    {
        var parameters = actionMethodExecutor.MethodParameters;
        if (parameters.Length == 0)
        {
            return null;
        }

        Debug.Assert(actionParameters != null, "Expect arguments to be initialized.");

        var ordered = new object?[parameters.Length];
        for (var position = 0; position < ordered.Length; position++)
        {
            var name = parameters[position].Name!;

            // A missing key and an explicit null are treated alike: fall back to the declared default.
            ordered[position] = actionParameters.TryGetValue(name, out var bound) && bound is not null
                ? bound
                : actionMethodExecutor.GetDefaultValueForParameter(position);
        }

        return ordered;
    }
 
    // Identifies who is driving the state machine; checked in the ActionEnd state.
    private enum Scope
    {
        // Used by InvokeInnerFilterAsync: ActionEnd rethrows unhandled exceptions and publishes _result.
        Invoker,
        // Used by InvokeNextActionFilterAsync: ActionEnd only ensures _actionExecutedContext exists.
        Action,
    }
 
    // States of the action-filter state machine implemented by Next.
    private enum State
    {
        // Create the controller and bind arguments.
        ActionBegin,
        // Fetch the next action filter, or fall through to the action method.
        ActionNext,
        // Call IAsyncActionFilter.OnActionExecutionAsync.
        ActionAsyncBegin,
        // Async filter finished; detect a short circuit.
        ActionAsyncEnd,
        // Call IActionFilter.OnActionExecuting, then run the rest of the pipeline.
        ActionSyncBegin,
        // Call IActionFilter.OnActionExecuted.
        ActionSyncEnd,
        // Invoke the action method.
        ActionInside,
        // Finish the current scope and mark the run as completed.
        ActionEnd,
    }
 
    // Sealed subclass of ActionExecutingContext used for the contexts this invoker creates.
    private sealed class ActionExecutingContextSealed : ActionExecutingContext
    {
        public ActionExecutingContextSealed(
            ActionContext actionContext,
            IList<IFilterMetadata> filters,
            IDictionary<string, object?> actionArguments,
            object controller)
            : base(actionContext, filters, actionArguments, controller)
        {
        }
    }
 
    // Sealed subclass of ActionExecutedContext used for the contexts this invoker creates.
    private sealed class ActionExecutedContextSealed : ActionExecutedContext
    {
        public ActionExecutedContextSealed(
            ActionContext actionContext,
            IList<IFilterMetadata> filters,
            object controller)
            : base(actionContext, filters, controller)
        {
        }
    }
 
    // Internal for unit testing
    internal static new partial class Log
    {
        /// <summary>Logs, at Debug level, that the controller factory is about to run.</summary>
        public static void ExecutingControllerFactory(ILogger logger, ControllerContext context)
        {
            // The type name lookup is only worth doing when Debug is enabled.
            if (logger.IsEnabled(LogLevel.Debug))
            {
                var type = context.ActionDescriptor.ControllerTypeInfo.AsType();
                var displayName = TypeNameHelper.GetTypeDisplayName(type);
                ExecutingControllerFactory(logger, displayName, type.Assembly.GetName().Name);
            }
        }

        [LoggerMessage(1, LogLevel.Debug, "Executing controller factory for controller {Controller} ({AssemblyName})", EventName = "ControllerFactoryExecuting", SkipEnabledCheck = true)]
        private static partial void ExecutingControllerFactory(ILogger logger, string controller, string? assemblyName);

        /// <summary>Logs, at Debug level, that the controller factory has run.</summary>
        public static void ExecutedControllerFactory(ILogger logger, ControllerContext context)
        {
            if (logger.IsEnabled(LogLevel.Debug))
            {
                var type = context.ActionDescriptor.ControllerTypeInfo.AsType();
                var displayName = TypeNameHelper.GetTypeDisplayName(type);
                ExecutedControllerFactory(logger, displayName, type.Assembly.GetName().Name);
            }
        }

        [LoggerMessage(2, LogLevel.Debug, "Executed controller factory for controller {Controller} ({AssemblyName})", EventName = "ControllerFactoryExecuted", SkipEnabledCheck = true)]
        private static partial void ExecutedControllerFactory(ILogger logger, string controller, string? assemblyName);

        /// <summary>
        /// Logs the start of the action method at Information level and, when Trace is also
        /// enabled and there are arguments, the invariant-culture string form of each argument.
        /// </summary>
        public static void ActionMethodExecuting(ILogger logger, ControllerContext context, object?[]? arguments)
        {
            if (!logger.IsEnabled(LogLevel.Information))
            {
                return;
            }

            var actionName = context.ActionDescriptor.DisplayName;
            var validationState = context.ModelState.ValidationState;
            ActionMethodExecuting(logger, actionName, validationState);

            if (arguments == null || !logger.IsEnabled(LogLevel.Trace))
            {
                return;
            }

            var formatted = new string?[arguments.Length];
            for (var position = 0; position < formatted.Length; position++)
            {
                formatted[position] = Convert.ToString(arguments[position], CultureInfo.InvariantCulture);
            }

            ActionMethodExecutingWithArguments(logger, actionName, formatted);
        }

        [LoggerMessage(101, LogLevel.Information, "Executing action method {ActionName} - Validation state: {ValidationState}", EventName = "ActionMethodExecuting", SkipEnabledCheck = true)]
        private static partial void ActionMethodExecuting(ILogger logger, string? actionName, ModelValidationState validationState);

        [LoggerMessage(102, LogLevel.Trace, "Executing action method {ActionName} with arguments ({Arguments})", EventName = "ActionMethodExecutingWithArguments", SkipEnabledCheck = true)]
        private static partial void ActionMethodExecutingWithArguments(ILogger logger, string? actionName, string?[] arguments);

        /// <summary>Logs, at Information level, the action's result and elapsed time in milliseconds.</summary>
        public static void ActionMethodExecuted(ILogger logger, ControllerContext context, IActionResult result, TimeSpan timeSpan)
        {
            if (!logger.IsEnabled(LogLevel.Information))
            {
                return;
            }

            var actionName = context.ActionDescriptor.DisplayName;
            var resultText = Convert.ToString(result, CultureInfo.InvariantCulture);
            ActionMethodExecuted(logger, actionName, resultText, timeSpan.TotalMilliseconds);
        }

        [LoggerMessage(103, LogLevel.Information, "Executed action method {ActionName}, returned result {ActionResult} in {ElapsedMilliseconds}ms.", EventName = "ActionMethodExecuted", SkipEnabledCheck = true)]
        private static partial void ActionMethodExecuted(ILogger logger, string? actionName, string? actionResult, double elapsedMilliseconds);
    }
}