// File: TestFilters.cs
// Web Access
// Project: src\src\SignalR\server\SignalR\test\Microsoft.AspNetCore.SignalR.Tests\Microsoft.AspNetCore.SignalR.Tests.csproj (Microsoft.AspNetCore.SignalR.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Diagnostics;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Internal;
 
namespace Microsoft.AspNetCore.SignalR.Tests;
 
/// <summary>
/// Hub filter that signals <c>TcsService.StartedMethod</c> right before, and <c>TcsService.EndMethod</c>
/// right after, the rest of the pipeline runs for connect, invoke, and disconnect.
/// </summary>
/// <remarks>
/// Both signals use <c>TrySetResult(null)</c>, so repeated calls never throw; once a source is completed,
/// later signals have no effect. If <c>next</c> throws, <c>EndMethod</c> is not signalled for that call.
/// </remarks>
public class VerifyMethodFilter : IHubFilter
{
    private readonly TcsService _tcs;

    public VerifyMethodFilter(TcsService tcsService) => _tcs = tcsService;

    public async Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
    {
        MarkStarted();
        await next(context);
        MarkEnded();
    }

    public async ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
    {
        MarkStarted();
        object invocationResult = await next(invocationContext);
        MarkEnded();
        return invocationResult;
    }

    public async Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func<HubLifetimeContext, Exception, Task> next)
    {
        MarkStarted();
        await next(context, exception);
        MarkEnded();
    }

    // Signal that a filtered hub call has entered this filter.
    private void MarkStarted() => _tcs.StartedMethod.TrySetResult(null);

    // Signal that the downstream pipeline completed without throwing.
    private void MarkEnded() => _tcs.EndMethod.TrySetResult(null);
}
 
/// <summary>
/// Hub filter that awaits a <c>SyncPoint</c> before letting each hub call continue down the pipeline.
/// </summary>
/// <remarks>
/// Expects exactly three sync points (checked only by <see cref="Debug.Assert(bool)"/>):
/// index 0 gates OnConnectedAsync, index 1 gates InvokeMethodAsync, index 2 gates OnDisconnectedAsync.
/// Presumably <c>WaitToContinue</c> completes once the test releases the sync point — confirm against SyncPoint.
/// </remarks>
public class SyncPointFilter : IHubFilter
{
    private const int ConnectedIndex = 0;
    private const int InvokeIndex = 1;
    private const int DisconnectedIndex = 2;

    private readonly SyncPoint[] _syncPoints;

    public SyncPointFilter(SyncPoint[] syncPoints)
    {
        Debug.Assert(syncPoints.Length == 3);
        _syncPoints = syncPoints;
    }

    public async Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
    {
        await _syncPoints[ConnectedIndex].WaitToContinue();
        await next(context);
    }

    public async ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
    {
        await _syncPoints[InvokeIndex].WaitToContinue();
        return await next(invocationContext);
    }

    public async Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func<HubLifetimeContext, Exception, Task> next)
    {
        await _syncPoints[DisconnectedIndex].WaitToContinue();
        await next(context, exception);
    }
}
 
/// <summary>
/// Mutable tally of how many times each hub filter hook has run; written by <c>CounterFilter</c>.
/// </summary>
/// <remarks>
/// Plain public fields with no synchronization; callers read and write them directly.
/// </remarks>
public class FilterCounter
{
    /// <summary>Number of OnConnectedAsync calls observed.</summary>
    public int OnConnectedAsyncCount;

    /// <summary>Number of InvokeMethodAsync calls observed.</summary>
    public int InvokeMethodAsyncCount;

    /// <summary>Number of OnDisconnectedAsync calls observed.</summary>
    public int OnDisconnectedAsyncCount;
}
 
/// <summary>
/// Hub filter that counts each hook invocation in a shared <c>FilterCounter</c> and then defers to the
/// rest of the pipeline.
/// </summary>
/// <remarks>
/// Constructing an instance zeroes all three counters, so the counts only reflect calls made since the
/// most recent construction. Increments are plain (non-atomic) field updates.
/// </remarks>
public class CounterFilter : IHubFilter
{
    private readonly FilterCounter _counter;

    public CounterFilter(FilterCounter counter)
    {
        _counter = counter;
        // Reset every tally for this newly created filter instance.
        _counter.OnConnectedAsyncCount = _counter.InvokeMethodAsyncCount = _counter.OnDisconnectedAsyncCount = 0;
    }

    public Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
    {
        _counter.OnConnectedAsyncCount += 1;
        return next(context);
    }

    public ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
    {
        _counter.InvokeMethodAsyncCount += 1;
        return next(invocationContext);
    }

    public Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func<HubLifetimeContext, Exception, Task> next)
    {
        _counter.OnDisconnectedAsyncCount += 1;
        return next(context, exception);
    }
}
 
/// <summary>
/// Hub filter that deliberately swallows every exception thrown by the rest of the pipeline.
/// </summary>
/// <remarks>
/// OnConnectedAsync and OnDisconnectedAsync complete normally even if a later stage throws;
/// InvokeMethodAsync reports <c>null</c> as the result of a failed invocation.
/// </remarks>
public class NoExceptionFilter : IHubFilter
{
    public async Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
    {
        try
        {
            await next(context);
        }
        catch
        {
            // Intentionally ignored: this filter suppresses downstream failures.
        }
    }

    public async Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func<HubLifetimeContext, Exception, Task> next)
    {
        try
        {
            await next(context, exception);
        }
        catch
        {
            // Intentionally ignored: this filter suppresses downstream failures.
        }
    }

    public async ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
    {
        try
        {
            return await next(invocationContext);
        }
        catch
        {
            // Intentionally ignored: surface a null result instead of the failure.
            return null;
        }
    }
}
 
/// <summary>
/// Hub filter that can short-circuit individual hooks by not calling <c>next</c>.
/// </summary>
/// <remarks>
/// A skipped lifetime hook returns an already-completed task. A skipped invocation returns a default
/// <see cref="ValueTask{TResult}"/>, whose result is <c>null</c>. Non-skipped hooks pass straight through.
/// </remarks>
public class SkipNextFilter : IHubFilter
{
    private readonly bool _skipOnConnected;
    private readonly bool _skipInvoke;
    private readonly bool _skipOnDisconnected;

    /// <param name="skipOnConnected">When true, OnConnectedAsync does not call the rest of the pipeline.</param>
    /// <param name="skipInvoke">When true, InvokeMethodAsync does not call the rest of the pipeline.</param>
    /// <param name="skipOnDisconnected">When true, OnDisconnectedAsync does not call the rest of the pipeline.</param>
    public SkipNextFilter(bool skipOnConnected = false, bool skipInvoke = false, bool skipOnDisconnected = false)
    {
        _skipInvoke = skipInvoke;
        _skipOnConnected = skipOnConnected;
        _skipOnDisconnected = skipOnDisconnected;
    }

    public Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
        => _skipOnConnected ? Task.CompletedTask : next(context);

    public Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func<HubLifetimeContext, Exception, Task> next)
        => _skipOnDisconnected ? Task.CompletedTask : next(context, exception);

    public ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
        => _skipInvoke ? default : next(invocationContext);
}
 
/// <summary>
/// Hub filter implementing <see cref="IDisposable"/> that signals <c>TcsService.StartedMethod</c> when
/// disposed, so a test can observe that the filter instance was released.
/// </summary>
public class DisposableFilter : IHubFilter, IDisposable
{
    private readonly TcsService _tcsService;

    public DisposableFilter(TcsService tcsService)
    {
        _tcsService = tcsService;
    }

    /// <summary>
    /// Signals <c>TcsService.StartedMethod</c>.
    /// </summary>
    /// <remarks>
    /// Uses <c>TrySetResult</c> rather than <c>SetResult</c>. The old code threw InvalidOperationException
    /// when Dispose ran twice, or when another instance sharing the same TcsService had already been disposed.
    /// Dispose must be safe to call repeatedly and should not throw.
    /// </remarks>
    public void Dispose()
    {
        _tcsService.StartedMethod.TrySetResult(null);
    }

    /// <summary>Passes the invocation through to the rest of the pipeline unchanged.</summary>
    public ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
    {
        return next(invocationContext);
    }
}
 
/// <summary>
/// Hub filter implementing <see cref="IAsyncDisposable"/> that signals <c>TcsService.StartedMethod</c> when
/// disposed, so a test can observe that the filter instance was released asynchronously.
/// </summary>
public class AsyncDisposableFilter : IHubFilter, IAsyncDisposable
{
    private readonly TcsService _tcsService;

    public AsyncDisposableFilter(TcsService tcsService)
    {
        _tcsService = tcsService;
    }

    /// <summary>
    /// Signals <c>TcsService.StartedMethod</c> and completes synchronously.
    /// </summary>
    /// <remarks>
    /// Uses <c>TrySetResult</c> rather than <c>SetResult</c> so that repeated disposal, or disposal of
    /// another instance sharing the same TcsService, does not throw. DisposeAsync must tolerate multiple calls.
    /// </remarks>
    public ValueTask DisposeAsync()
    {
        _tcsService.StartedMethod.TrySetResult(null);
        return default;
    }

    /// <summary>Passes the invocation through to the rest of the pipeline unchanged.</summary>
    public ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
    {
        return next(invocationContext);
    }
}
 
/// <summary>
/// Hub filter that swaps the invocation's target method for <c>BaseHub.BaseMethod</c> before calling the
/// rest of the pipeline.
/// </summary>
/// <remarks>
/// The replacement context keeps the original connection context, service provider, hub instance, and arguments.
/// </remarks>
public class ChangeMethodFilter : IHubFilter
{
    public ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
    {
        var replacementMethod = typeof(BaseHub).GetMethod(nameof(BaseHub.BaseMethod));

        var rewrittenContext = new HubInvocationContext(
            invocationContext.Context,
            invocationContext.ServiceProvider,
            invocationContext.Hub,
            replacementMethod,
            invocationContext.HubMethodArguments);

        return next(rewrittenContext);
    }
}
 
/// <summary>
/// Hub filter that declares no members of its own, so every hook uses the default implementation that
/// <c>IHubFilter</c> supplies. Presumably those defaults simply call <c>next</c> — confirm against the interface.
/// </summary>
public class EmptyFilter : IHubFilter
{
    // Intentionally empty: exercises IHubFilter's default interface method implementations.
}