|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reflection;
using System.Security.Claims;
using System.Threading.Channels;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Internal;
using Microsoft.AspNetCore.Shared;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging;
using Log = Microsoft.AspNetCore.SignalR.Internal.DefaultHubDispatcherLog;
namespace Microsoft.AspNetCore.SignalR.Internal;
internal sealed partial class DefaultHubDispatcher<[DynamicallyAccessedMembers(Hub.DynamicallyAccessedMembers)] THub> : HubDispatcher<THub> where THub : Hub
{
private static readonly string _fullHubName = typeof(THub).FullName ?? typeof(THub).Name;
private readonly Dictionary<string, HubMethodDescriptor> _methods = new(StringComparer.OrdinalIgnoreCase);
private readonly Utf8HashLookup _cachedMethodNames = new();
private readonly IServiceScopeFactory _serviceScopeFactory;
private readonly IHubContext<THub> _hubContext;
private readonly ILogger<HubDispatcher<THub>> _logger;
private readonly bool _enableDetailedErrors;
private readonly Func<HubInvocationContext, ValueTask<object?>>? _invokeMiddleware;
private readonly Func<HubLifetimeContext, Task>? _onConnectedMiddleware;
private readonly Func<HubLifetimeContext, Exception?, Task>? _onDisconnectedMiddleware;
private readonly HubLifetimeManager<THub> _hubLifetimeManager;
[FeatureSwitchDefinition("Microsoft.AspNetCore.SignalR.Hub.IsCustomAwaitableSupported")]
[FeatureGuard(typeof(RequiresDynamicCodeAttribute))]
[FeatureGuard(typeof(RequiresUnreferencedCodeAttribute))]
private static bool IsCustomAwaitableSupported { get; } =
AppContext.TryGetSwitch("Microsoft.AspNetCore.SignalR.Hub.IsCustomAwaitableSupported", out bool customAwaitableSupport) ? customAwaitableSupport : true;
public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContext<THub> hubContext, bool enableDetailedErrors,
bool disableImplicitFromServiceParameters, ILogger<DefaultHubDispatcher<THub>> logger, List<IHubFilter>? hubFilters, HubLifetimeManager<THub> lifetimeManager)
{
_serviceScopeFactory = serviceScopeFactory;
_hubContext = hubContext;
_enableDetailedErrors = enableDetailedErrors;
_logger = logger;
_hubLifetimeManager = lifetimeManager;
DiscoverHubMethods(disableImplicitFromServiceParameters);
var count = hubFilters?.Count ?? 0;
if (count != 0)
{
_invokeMiddleware = (invocationContext) =>
{
var arguments = invocationContext.HubMethodArguments as object?[] ?? invocationContext.HubMethodArguments.ToArray();
if (invocationContext.ObjectMethodExecutor != null)
{
return ExecuteMethod(invocationContext.ObjectMethodExecutor, invocationContext.Hub, arguments);
}
return ExecuteMethod(invocationContext.HubMethod.Name, invocationContext.Hub, arguments);
};
_onConnectedMiddleware = (context) => context.Hub.OnConnectedAsync();
_onDisconnectedMiddleware = (context, exception) => context.Hub.OnDisconnectedAsync(exception);
for (var i = count - 1; i > -1; i--)
{
var resolvedFilter = hubFilters![i];
var nextFilter = _invokeMiddleware;
_invokeMiddleware = (context) => resolvedFilter.InvokeMethodAsync(context, nextFilter);
var connectedFilter = _onConnectedMiddleware;
_onConnectedMiddleware = (context) => resolvedFilter.OnConnectedAsync(context, connectedFilter);
var disconnectedFilter = _onDisconnectedMiddleware;
_onDisconnectedMiddleware = (context, exception) => resolvedFilter.OnDisconnectedAsync(context, exception, disconnectedFilter);
}
}
}
public override async Task OnConnectedAsync(HubConnectionContext connection)
{
await using var scope = _serviceScopeFactory.CreateAsyncScope();
var hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub>>();
var hub = hubActivator.Create();
Activity? activity = null;
try
{
// OnConnectedAsync won't work with client results (ISingleClientProxy.InvokeAsync)
InitializeHub(hub, connection, invokeAllowed: false);
activity = StartActivity(SignalRServerActivitySource.OnConnected, ActivityKind.Internal, linkedActivity: null, scope.ServiceProvider, nameof(hub.OnConnectedAsync), headers: null, _logger);
if (_onConnectedMiddleware != null)
{
var context = new HubLifetimeContext(connection.HubCallerContext, scope.ServiceProvider, hub);
await _onConnectedMiddleware(context);
}
else
{
await hub.OnConnectedAsync();
}
}
catch (Exception ex)
{
SetActivityError(activity, ex);
throw;
}
finally
{
activity?.Stop();
hubActivator.Release(hub);
}
}
public override async Task OnDisconnectedAsync(HubConnectionContext connection, Exception? exception)
{
await using var scope = _serviceScopeFactory.CreateAsyncScope();
var hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub>>();
var hub = hubActivator.Create();
Activity? activity = null;
try
{
InitializeHub(hub, connection);
activity = StartActivity(SignalRServerActivitySource.OnDisconnected, ActivityKind.Internal, linkedActivity: null, scope.ServiceProvider, nameof(hub.OnDisconnectedAsync), headers: null, _logger);
if (_onDisconnectedMiddleware != null)
{
var context = new HubLifetimeContext(connection.HubCallerContext, scope.ServiceProvider, hub);
await _onDisconnectedMiddleware(context, exception);
}
else
{
await hub.OnDisconnectedAsync(exception);
}
}
catch (Exception ex)
{
SetActivityError(activity, ex);
throw;
}
finally
{
activity?.Stop();
hubActivator.Release(hub);
}
}
public override Task DispatchMessageAsync(HubConnectionContext connection, HubMessage hubMessage)
{
// Messages are dispatched sequentially and will stop other messages from being processed until they complete.
// Streaming methods will run sequentially until they start streaming, then they will fire-and-forget allowing other messages to run.
// With parallel invokes enabled, messages run sequentially until they go async and then the next message will be allowed to start running.
if (!connection.ShouldProcessMessage(hubMessage))
{
Log.DroppingMessage(_logger, hubMessage.GetType().Name, (hubMessage as HubInvocationMessage)?.InvocationId ?? "(null)");
return Task.CompletedTask;
}
switch (hubMessage)
{
case InvocationBindingFailureMessage bindingFailureMessage:
return ProcessInvocationBindingFailure(connection, bindingFailureMessage);
case StreamBindingFailureMessage bindingFailureMessage:
return ProcessStreamBindingFailure(connection, bindingFailureMessage);
case InvocationMessage invocationMessage:
Log.ReceivedHubInvocation(_logger, invocationMessage);
return ProcessInvocation(connection, invocationMessage, isStreamResponse: false);
case StreamInvocationMessage streamInvocationMessage:
Log.ReceivedStreamHubInvocation(_logger, streamInvocationMessage);
return ProcessInvocation(connection, streamInvocationMessage, isStreamResponse: true);
case CancelInvocationMessage cancelInvocationMessage:
// Check if there is an associated active stream and cancel it if it exists.
// The cts will be removed when the streaming method completes executing
if (connection.ActiveRequestCancellationSources.TryGetValue(cancelInvocationMessage.InvocationId!, out var cts))
{
Log.CancelStream(_logger, cancelInvocationMessage.InvocationId!);
cts.Cancel();
}
else
{
// Stream can be canceled on the server while client is canceling stream.
Log.UnexpectedCancel(_logger);
}
break;
case PingMessage _:
connection.StartClientTimeout();
break;
case StreamItemMessage streamItem:
return ProcessStreamItem(connection, streamItem);
case CompletionMessage completionMessage:
// closes channels, removes from Lookup dict
// user's method can see the channel is complete and begin wrapping up
if (connection.StreamTracker.TryComplete(completionMessage))
{
Log.CompletingStream(_logger, completionMessage);
}
// InvocationId is always required on CompletionMessage, it's nullable because of the base type
else if (_hubLifetimeManager.TryGetReturnType(completionMessage.InvocationId!, out _))
{
return _hubLifetimeManager.SetConnectionResultAsync(connection.ConnectionId, completionMessage);
}
else
{
Log.UnexpectedCompletion(_logger, completionMessage.InvocationId!);
}
break;
case AckMessage ackMessage:
Log.ReceivedAckMessage(_logger, ackMessage.SequenceId);
return connection.AckAsync(ackMessage);
case SequenceMessage sequenceMessage:
Log.ReceivedSequenceMessage(_logger, sequenceMessage.SequenceId);
break;
case CloseMessage closeMessage:
connection.CloseMessage = closeMessage;
connection.Abort();
break;
// Other kind of message we weren't expecting
default:
Log.UnsupportedMessageReceived(_logger, hubMessage.GetType().FullName!);
throw new NotSupportedException($"Received unsupported message: {hubMessage}");
}
return Task.CompletedTask;
}
private Task ProcessInvocationBindingFailure(HubConnectionContext connection, InvocationBindingFailureMessage bindingFailureMessage)
{
Log.InvalidHubParameters(_logger, bindingFailureMessage.Target, bindingFailureMessage.BindingFailure.SourceException);
var errorMessage = ErrorMessageHelper.BuildErrorMessage($"Failed to invoke '{bindingFailureMessage.Target}' due to an error on the server.",
bindingFailureMessage.BindingFailure.SourceException, _enableDetailedErrors);
return SendInvocationError(bindingFailureMessage.InvocationId, connection, errorMessage);
}
private Task ProcessStreamBindingFailure(HubConnectionContext connection, StreamBindingFailureMessage bindingFailureMessage)
{
var errorString = ErrorMessageHelper.BuildErrorMessage(
"Failed to bind Stream message.",
bindingFailureMessage.BindingFailure.SourceException, _enableDetailedErrors);
var message = CompletionMessage.WithError(bindingFailureMessage.Id, errorString);
Log.ClosingStreamWithBindingError(_logger, message);
// ignore failure, it means the client already completed the stream or the stream never existed on the server
connection.StreamTracker.TryComplete(message);
// TODO: Send stream completion message to client when we add it
return Task.CompletedTask;
}
private Task ProcessStreamItem(HubConnectionContext connection, StreamItemMessage message)
{
if (!connection.StreamTracker.TryProcessItem(message, out var processTask))
{
Log.UnexpectedStreamItem(_logger);
return Task.CompletedTask;
}
Log.ReceivedStreamItem(_logger, message);
return processTask;
}
private Task ProcessInvocation(HubConnectionContext connection,
HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamResponse)
{
if (!_methods.TryGetValue(hubMethodInvocationMessage.Target, out var descriptor))
{
Log.UnknownHubMethod(_logger, hubMethodInvocationMessage.Target);
if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
{
// Send an error to the client. Then let the normal completion process occur
return connection.WriteAsync(CompletionMessage.WithError(
hubMethodInvocationMessage.InvocationId, $"Unknown hub method '{hubMethodInvocationMessage.Target}'")).AsTask();
}
else
{
return Task.CompletedTask;
}
}
else
{
bool isStreamCall = descriptor.StreamingParameters != null;
if (!isStreamCall && !isStreamResponse)
{
return connection.ActiveInvocationLimit.RunAsync(static state =>
{
var (dispatcher, descriptor, connection, invocationMessage) = state;
return dispatcher.Invoke(descriptor, connection, invocationMessage, isStreamResponse: false, isStreamCall: false);
}, (this, descriptor, connection, hubMethodInvocationMessage)).AsTask();
}
else
{
return Invoke(descriptor, connection, hubMethodInvocationMessage, isStreamResponse, isStreamCall);
}
}
}
private async Task<bool> Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection,
HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamResponse, bool isStreamCall)
{
var methodExecutor = descriptor.MethodExecutor;
var wasSemaphoreReleased = false;
var disposeScope = true;
var scope = _serviceScopeFactory.CreateAsyncScope();
IHubActivator<THub>? hubActivator = null;
THub? hub = null;
try
{
hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub>>();
hub = hubActivator.Create();
if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection, descriptor, hubMethodInvocationMessage.Arguments, hub))
{
Log.HubMethodNotAuthorized(_logger, hubMethodInvocationMessage.Target);
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
$"Failed to invoke '{hubMethodInvocationMessage.Target}' because user is unauthorized");
return true;
}
if (!await ValidateInvocationMode(descriptor, isStreamResponse, hubMethodInvocationMessage, connection))
{
return true;
}
try
{
var clientStreamLength = hubMethodInvocationMessage.StreamIds?.Length ?? 0;
var serverStreamLength = descriptor.StreamingParameters?.Count ?? 0;
if (clientStreamLength != serverStreamLength)
{
var ex = new HubException($"Client sent {clientStreamLength} stream(s), Hub method expects {serverStreamLength}.");
Log.InvalidHubParameters(_logger, hubMethodInvocationMessage.Target, ex);
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors));
return true;
}
InitializeHub(hub, connection);
Task? invocation = null;
var arguments = hubMethodInvocationMessage.Arguments;
CancellationTokenSource? cts = null;
if (descriptor.HasSyntheticArguments)
{
ReplaceArguments(descriptor, hubMethodInvocationMessage, isStreamCall, connection, scope, ref arguments, out cts);
}
if (isStreamCall || isStreamResponse)
{
Debug.Assert(hub.Clients is HubCallerClients);
// Streaming invocations aren't involved with the semaphore.
// Setting the semaphore released flag avoids potential client result calls from the streaming hub method
// releasing the semaphore which would cause a SemaphoreFullException.
((HubCallerClients)hub.Clients).TrySetSemaphoreReleased();
}
if (isStreamResponse)
{
_ = StreamAsync(hubMethodInvocationMessage.InvocationId!, connection, arguments, scope, hubActivator, hub, cts, hubMethodInvocationMessage, descriptor);
}
else
{
// Invoke or Send
static async Task ExecuteInvocation(DefaultHubDispatcher<THub> dispatcher,
ObjectMethodExecutor methodExecutor,
THub hub,
object?[] arguments,
AsyncServiceScope scope,
IHubActivator<THub> hubActivator,
HubConnectionContext connection,
HubMethodInvocationMessage hubMethodInvocationMessage,
bool isStreamCall)
{
var logger = dispatcher._logger;
var enableDetailedErrors = dispatcher._enableDetailedErrors;
// Hub invocation gets its parent from a remote source. Clear any current activity and restore it later.
var previousActivity = Activity.Current;
if (previousActivity != null)
{
Activity.Current = null;
}
// Use hubMethodInvocationMessage.Target instead of methodExecutor.MethodInfo.Name
// We want to take HubMethodNameAttribute into account which will be the same as what the invocation target is
var activity = StartActivity(SignalRServerActivitySource.InvocationIn, ActivityKind.Server, connection.OriginalActivity, scope.ServiceProvider, hubMethodInvocationMessage.Target, hubMethodInvocationMessage.Headers, logger);
object? result;
try
{
result = await dispatcher.ExecuteHubMethod(methodExecutor, hub, arguments, connection, scope.ServiceProvider);
Log.SendingResult(logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
}
catch (Exception ex)
{
SetActivityError(activity, ex);
Log.FailedInvokingHubMethod(logger, hubMethodInvocationMessage.Target, ex);
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, enableDetailedErrors));
return;
}
finally
{
activity?.Stop();
if (Activity.Current != previousActivity)
{
Activity.Current = previousActivity;
}
// Stream response handles cleanup in StreamResultsAsync
// And normal invocations handle cleanup below in the finally
if (isStreamCall)
{
await CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope);
}
}
// No InvocationId - Send Async, no response expected
if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
{
// Invoke Async, one response expected
await connection.WriteAsync(CompletionMessage.WithResult(hubMethodInvocationMessage.InvocationId, result));
}
}
invocation = ExecuteInvocation(this, methodExecutor, hub, arguments, scope, hubActivator, connection, hubMethodInvocationMessage, isStreamCall);
}
if (isStreamCall || isStreamResponse)
{
// don't await streaming invocations
// leave them running in the background, allowing dispatcher to process other messages between streaming items
disposeScope = false;
}
else
{
// complete the non-streaming calls now
await invocation!;
}
}
catch (TargetInvocationException ex)
{
Log.FailedInvokingHubMethod(_logger, hubMethodInvocationMessage.Target, ex);
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex.InnerException ?? ex, _enableDetailedErrors));
}
catch (Exception ex)
{
Log.FailedInvokingHubMethod(_logger, hubMethodInvocationMessage.Target, ex);
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors));
}
}
finally
{
if (disposeScope)
{
if (hub?.Clients is HubCallerClients hubCallerClients)
{
wasSemaphoreReleased = !hubCallerClients.TrySetSemaphoreReleased();
}
await CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope);
}
}
return !wasSemaphoreReleased;
}
private static ValueTask CleanupInvocation(HubConnectionContext connection, HubMethodInvocationMessage hubMessage, IHubActivator<THub>? hubActivator,
THub? hub, AsyncServiceScope scope)
{
if (hubMessage.StreamIds != null)
{
foreach (var stream in hubMessage.StreamIds)
{
connection.StreamTracker.TryComplete(CompletionMessage.Empty(stream));
}
}
if (hub != null)
{
hubActivator?.Release(hub);
}
return scope.DisposeAsync();
}
private async Task StreamAsync(string invocationId, HubConnectionContext connection, object?[] arguments, AsyncServiceScope scope,
IHubActivator<THub> hubActivator, THub hub, CancellationTokenSource? streamCts, HubMethodInvocationMessage hubMethodInvocationMessage, HubMethodDescriptor descriptor)
{
string? error = null;
streamCts ??= CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
// Hub invocation gets its parent from a remote source. Clear any current activity and restore it later.
var previousActivity = Activity.Current;
if (previousActivity != null)
{
Activity.Current = null;
}
var activity = StartActivity(SignalRServerActivitySource.InvocationIn, ActivityKind.Server, connection.OriginalActivity, scope.ServiceProvider, hubMethodInvocationMessage.Target, hubMethodInvocationMessage.Headers, _logger);
try
{
if (!connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts))
{
Log.InvocationIdInUse(_logger, invocationId);
error = $"Invocation ID '{invocationId}' is already in use.";
return;
}
object? result;
try
{
result = await ExecuteHubMethod(descriptor.MethodExecutor, hub, arguments, connection, scope.ServiceProvider);
}
catch (Exception ex)
{
SetActivityError(activity, ex);
Log.FailedInvokingHubMethod(_logger, hubMethodInvocationMessage.Target, ex);
error = ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors);
return;
}
if (result == null)
{
Log.InvalidReturnValueFromStreamingMethod(_logger, descriptor.MethodExecutor.MethodInfo.Name);
error = $"The value returned by the streaming method '{descriptor.MethodExecutor.MethodInfo.Name}' is not a ChannelReader<> or IAsyncEnumerable<>.";
return;
}
await using var enumerator = descriptor.FromReturnedStream(result, streamCts.Token);
Log.StreamingResult(_logger, invocationId, descriptor.MethodExecutor);
var streamItemMessage = new StreamItemMessage(invocationId, null);
while (await enumerator.MoveNextAsync())
{
streamItemMessage.Item = enumerator.Current;
// Send the stream item
await connection.WriteAsync(streamItemMessage);
}
}
catch (ChannelClosedException ex)
{
// If the channel closes from an exception in the streaming method, grab the innerException for the error from the streaming method
var exception = ex.InnerException ?? ex;
SetActivityError(activity, exception);
Log.FailedStreaming(_logger, invocationId, descriptor.MethodExecutor.MethodInfo.Name, exception);
error = ErrorMessageHelper.BuildErrorMessage("An error occurred on the server while streaming results.", exception, _enableDetailedErrors);
}
catch (Exception ex)
{
// If the streaming method was canceled we don't want to send a HubException message - this is not an error case
if (!(ex is OperationCanceledException && streamCts.IsCancellationRequested))
{
SetActivityError(activity, ex);
Log.FailedStreaming(_logger, invocationId, descriptor.MethodExecutor.MethodInfo.Name, ex);
error = ErrorMessageHelper.BuildErrorMessage("An error occurred on the server while streaming results.", ex, _enableDetailedErrors);
}
}
finally
{
activity?.Stop();
if (Activity.Current != previousActivity)
{
Activity.Current = previousActivity;
}
await CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope);
streamCts.Dispose();
connection.ActiveRequestCancellationSources.TryRemove(invocationId, out _);
await connection.WriteAsync(CompletionMessage.WithError(invocationId, error));
}
}
private ValueTask<object?> ExecuteHubMethod(ObjectMethodExecutor methodExecutor, THub hub, object?[] arguments, HubConnectionContext connection, IServiceProvider serviceProvider)
{
if (_invokeMiddleware != null)
{
var invocationContext = new HubInvocationContext(methodExecutor, connection.HubCallerContext, serviceProvider, hub, arguments);
return _invokeMiddleware(invocationContext);
}
// If no Hub filters are registered
return ExecuteMethod(methodExecutor, hub, arguments);
}
private ValueTask<object?> ExecuteMethod(string hubMethodName, Hub hub, object?[] arguments)
{
if (!_methods.TryGetValue(hubMethodName, out var methodDescriptor))
{
throw new HubException($"Unknown hub method '{hubMethodName}'");
}
var methodExecutor = methodDescriptor.MethodExecutor;
return ExecuteMethod(methodExecutor, hub, arguments);
}
private static async ValueTask<object?> ExecuteMethod(ObjectMethodExecutor methodExecutor, Hub hub, object?[] arguments)
{
if (methodExecutor.IsMethodAsync)
{
if (methodExecutor.MethodReturnType == typeof(Task))
{
await (Task)methodExecutor.Execute(hub, arguments)!;
return null;
}
else
{
return await methodExecutor.ExecuteAsync(hub, arguments);
}
}
else
{
return methodExecutor.Execute(hub, arguments);
}
}
private static async Task SendInvocationError(string? invocationId, HubConnectionContext connection, string errorMessage)
{
if (string.IsNullOrEmpty(invocationId))
{
return;
}
await connection.WriteAsync(CompletionMessage.WithError(invocationId, errorMessage));
}
private void InitializeHub(THub hub, HubConnectionContext connection, bool invokeAllowed = true)
{
hub.Clients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId, connection.ActiveInvocationLimit) { InvokeAllowed = invokeAllowed };
hub.Context = connection.HubCallerContext;
hub.Groups = _hubContext.Groups;
}
private static Task<bool> IsHubMethodAuthorized(IServiceProvider provider, HubConnectionContext hubConnectionContext, HubMethodDescriptor descriptor, object?[] hubMethodArguments, Hub hub)
{
// If there are no policies we don't need to run auth
if (descriptor.Policies.Count == 0)
{
return TaskCache.True;
}
return IsHubMethodAuthorizedSlow(provider, hubConnectionContext.User, descriptor.Policies, new HubInvocationContext(hubConnectionContext.HubCallerContext, provider, hub, descriptor.MethodExecutor.MethodInfo, hubMethodArguments));
}
private static async Task<bool> IsHubMethodAuthorizedSlow(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies, HubInvocationContext resource)
{
var authService = provider.GetRequiredService<IAuthorizationService>();
var policyProvider = provider.GetRequiredService<IAuthorizationPolicyProvider>();
var authorizePolicy = await AuthorizationPolicy.CombineAsync(policyProvider, policies);
// AuthorizationPolicy.CombineAsync only returns null if there are no policies and we check that above
Debug.Assert(authorizePolicy != null);
var authorizationResult = await authService.AuthorizeAsync(principal, resource, authorizePolicy);
// Only check authorization success, challenge or forbid wouldn't make sense from a hub method invocation
return authorizationResult.Succeeded;
}
private async Task<bool> ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamResponse,
HubMethodInvocationMessage hubMethodInvocationMessage, HubConnectionContext connection)
{
if (hubMethodDescriptor.IsStreamResponse && !isStreamResponse)
{
// Non-null/empty InvocationId? Blocking
if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
{
Log.StreamingMethodCalledWithInvoke(_logger, hubMethodInvocationMessage);
await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId,
$"The client attempted to invoke the streaming '{hubMethodInvocationMessage.Target}' method with a non-streaming invocation."));
}
return false;
}
if (!hubMethodDescriptor.IsStreamResponse && isStreamResponse)
{
Log.NonStreamingMethodCalledWithStream(_logger, hubMethodInvocationMessage);
await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId!,
$"The client attempted to invoke the non-streaming '{hubMethodInvocationMessage.Target}' method with a streaming invocation."));
return false;
}
return true;
}
private void ReplaceArguments(HubMethodDescriptor descriptor, HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamCall,
HubConnectionContext connection, AsyncServiceScope scope, ref object?[] arguments, out CancellationTokenSource? cts)
{
cts = null;
// In order to add the synthetic arguments we need a new array because the invocation array is too small (it doesn't know about synthetic arguments)
arguments = new object?[descriptor.OriginalParameterTypes!.Count];
var streamPointer = 0;
var hubInvocationArgumentPointer = 0;
for (var parameterPointer = 0; parameterPointer < arguments.Length; parameterPointer++)
{
if (hubMethodInvocationMessage.Arguments?.Length > hubInvocationArgumentPointer &&
(hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer] == null ||
descriptor.OriginalParameterTypes[parameterPointer].IsAssignableFrom(hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer]?.GetType())))
{
// The types match so it isn't a synthetic argument, just copy it into the arguments array
arguments[parameterPointer] = hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer];
hubInvocationArgumentPointer++;
}
else
{
if (descriptor.OriginalParameterTypes[parameterPointer] == typeof(CancellationToken))
{
cts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
arguments[parameterPointer] = cts.Token;
}
else if (descriptor.IsServiceArgument(parameterPointer))
{
arguments[parameterPointer] = descriptor.GetService(scope.ServiceProvider, parameterPointer, descriptor.OriginalParameterTypes[parameterPointer]);
}
else if (isStreamCall && ReflectionHelper.IsStreamingType(descriptor.OriginalParameterTypes[parameterPointer], mustBeDirectType: true))
{
Log.StartingParameterStream(_logger, hubMethodInvocationMessage.StreamIds![streamPointer]);
var itemType = descriptor.StreamingParameters![streamPointer];
arguments[parameterPointer] = connection.StreamTracker.AddStream(hubMethodInvocationMessage.StreamIds[streamPointer],
itemType, descriptor.OriginalParameterTypes[parameterPointer]);
streamPointer++;
}
else
{
// This should never happen
Debug.Assert(false, $"Failed to bind argument of type '{descriptor.OriginalParameterTypes[parameterPointer].Name}' for hub method '{descriptor.MethodExecutor.MethodInfo.Name}'.");
}
}
}
}
private void DiscoverHubMethods(bool disableImplicitFromServiceParameters)
{
var hubType = typeof(THub);
var hubTypeInfo = hubType.GetTypeInfo();
var hubName = hubType.Name;
using var scope = _serviceScopeFactory.CreateScope();
IServiceProviderIsService? serviceProviderIsService = null;
if (!disableImplicitFromServiceParameters)
{
serviceProviderIsService = scope.ServiceProvider.GetService<IServiceProviderIsService>();
}
foreach (var methodInfo in HubReflectionHelper.GetHubMethods(hubType))
{
if (methodInfo.IsGenericMethod)
{
throw new NotSupportedException($"Method '{methodInfo.Name}' is a generic method which is not supported on a Hub.");
}
var methodName =
methodInfo.GetCustomAttribute<HubMethodNameAttribute>()?.Name ??
methodInfo.Name;
if (_methods.ContainsKey(methodName))
{
throw new NotSupportedException($"Duplicate definitions of '{methodName}'. Overloading is not supported.");
}
var executor = IsCustomAwaitableSupported
? CreateObjectMethodExecutor(methodInfo, hubTypeInfo)
: ObjectMethodExecutor.CreateTrimAotCompatible(methodInfo, hubTypeInfo);
var authorizeAttributes = methodInfo.GetCustomAttributes<AuthorizeAttribute>(inherit: true);
_methods[methodName] = new HubMethodDescriptor(executor, serviceProviderIsService, authorizeAttributes);
_cachedMethodNames.Add(methodName);
Log.HubMethodBound(_logger, hubName, methodName);
}
}
[RequiresUnreferencedCode("Using SignalR with 'Microsoft.AspNetCore.SignalR.Hub.IsCustomAwaitableSupported=true' is not trim compatible.")]
[RequiresDynamicCode("Using SignalR with 'Microsoft.AspNetCore.SignalR.Hub.IsCustomAwaitableSupported=true' is not native AOT compatible.")]
private static ObjectMethodExecutor CreateObjectMethodExecutor(MethodInfo methodInfo, TypeInfo targetType)
=> ObjectMethodExecutor.Create(methodInfo, targetType);
public override IReadOnlyList<Type> GetParameterTypes(string methodName)
{
if (!_methods.TryGetValue(methodName, out var descriptor))
{
throw new HubException("Method does not exist.");
}
return descriptor.ParameterTypes;
}
public override string? GetTargetName(ReadOnlySpan<byte> targetUtf8Bytes)
{
if (_cachedMethodNames.TryGetValue(targetUtf8Bytes, out var targetName))
{
return targetName;
}
return null;
}
// Starts an Activity for a Hub method invocation and sets up all the tags and other state.
// Make sure to call Activity.Stop() once the Hub method completes, and consider calling SetActivityError on exception.
private static Activity? StartActivity(string operationName, ActivityKind kind, Activity? linkedActivity, IServiceProvider serviceProvider, string methodName, IDictionary<string, string>? headers, ILogger logger)
{
var activitySource = serviceProvider.GetService<SignalRServerActivitySource>()?.ActivitySource;
if (activitySource is null)
{
return null;
}
var loggingEnabled = logger.IsEnabled(LogLevel.Critical);
if (!activitySource.HasListeners() && !loggingEnabled)
{
return null;
}
IEnumerable<KeyValuePair<string, object?>> tags =
[
new("rpc.method", methodName),
new("rpc.system", "signalr"),
new("rpc.service", _fullHubName),
// See https://github.com/dotnet/aspnetcore/blob/027c60168383421750f01e427e4f749d0684bc02/src/Servers/Kestrel/Core/src/Internal/Infrastructure/KestrelMetrics.cs#L308
// And https://github.com/dotnet/aspnetcore/issues/43786
//new("server.address", ...),
];
IEnumerable<ActivityLink>? links = (linkedActivity is not null) ? [new ActivityLink(linkedActivity.Context)] : null;
Activity? activity;
if (headers != null)
{
var propagator = serviceProvider.GetService<DistributedContextPropagator>() ?? DistributedContextPropagator.Current;
activity = ActivityCreator.CreateFromRemote(
activitySource,
propagator,
headers,
static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable<string>? fieldValues) =>
{
fieldValues = default;
var headers = (IDictionary<string, string>)carrier!;
headers.TryGetValue(fieldName, out fieldValue);
},
operationName,
kind,
tags,
links,
loggingEnabled);
}
else
{
activity = activitySource.CreateActivity(operationName, kind, parentId: null, tags: tags, links: links);
}
if (activity is not null)
{
activity.DisplayName = $"{_fullHubName}/{methodName}";
activity.Start();
}
return activity;
}
private static void SetActivityError(Activity? activity, Exception ex)
{
activity?.SetTag("error.type", ex.GetType().FullName);
activity?.SetStatus(ActivityStatusCode.Error);
}
}
|