File: StateHasChangedAnalyzer.cs
Web Access
Project: src\aspnetcore\src\Components\Analyzers\src\Microsoft.AspNetCore.Components.Analyzers.csproj (Microsoft.AspNetCore.Components.Analyzers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Operations;
 
#nullable enable
 
namespace Microsoft.AspNetCore.Components.Analyzers;
 
/// <summary>
/// Flags <c>StateHasChanged()</c> calls made from lifecycle method overrides
/// (<c>OnInitialized</c>, <c>OnInitializedAsync</c>, <c>OnParametersSet</c>, <c>OnParametersSetAsync</c>)
/// or from methods passed to an <c>EventCallbackFactory</c> call, when the call cannot run between two
/// suspension points of the method.
/// </summary>
/// <remarks>
/// Candidate call locations are gathered per method while the members of a component type are analyzed and
/// are reported only at the end of the type, because whether a method is an event handler is only known
/// once every invocation in the type has been inspected.
/// </remarks>
[DiagnosticAnalyzer(LanguageNames.CSharp)]
public sealed class StateHasChangedAnalyzer : DiagnosticAnalyzer
{
    private const string EventCallbackFactoryTypeName = "Microsoft.AspNetCore.Components.EventCallbackFactory";

    /// <inheritdoc />
    public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics =>
        ImmutableArray.Create(DiagnosticDescriptors.UnnecessaryStateHasChangedCall);

    /// <inheritdoc />
    public override void Initialize(AnalysisContext context)
    {
        context.EnableConcurrentExecution();
        context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics);

        context.RegisterSymbolStartAction(symbolStartContext =>
        {
            if (!ComponentSymbols.TryCreate(symbolStartContext.Compilation, out var symbols))
            {
                // Types we need are not defined.
                return;
            }

            if (symbols.ComponentBaseType is null)
            {
                // ComponentBase availability guard.
                return;
            }

            var type = (INamedTypeSymbol)symbolStartContext.Symbol;
            if (!ComponentFacts.IsComponentBase(symbols, type))
            {
                // Only applies to ComponentBase derived types.
                return;
            }

            var eventCallbackFactoryType = symbolStartContext.Compilation.GetTypeByMetadataName(EventCallbackFactoryTypeName);
            var eventHandlerMethods = new ConcurrentDictionary<IMethodSymbol, byte>(SymbolEqualityComparer.Default);
            var redundantCallLocationsByMethod = new ConcurrentDictionary<IMethodSymbol, ImmutableArray<Location>>(SymbolEqualityComparer.Default);

            // Collect event handler methods. Without the factory type no invocation can match,
            // so the per-invocation callback is not registered at all.
            if (eventCallbackFactoryType is not null)
            {
                symbolStartContext.RegisterOperationAction(
                    operationContext => CollectEventHandlerMethods(operationContext, eventCallbackFactoryType, eventHandlerMethods),
                    OperationKind.Invocation);
            }

            // Collect unnecessary StateHasChanged calls.
            symbolStartContext.RegisterSyntaxNodeAction(
                syntaxContext => AnalyzeMethodDeclaration(syntaxContext, redundantCallLocationsByMethod),
                Microsoft.CodeAnalysis.CSharp.SyntaxKind.MethodDeclaration);

            symbolStartContext.RegisterSymbolEndAction(endContext =>
            {
                foreach (var methodAndLocations in redundantCallLocationsByMethod)
                {
                    var method = methodAndLocations.Key;
                    if (!IsTargetMethod(method, eventHandlerMethods))
                    {
                        continue;
                    }

                    foreach (var location in methodAndLocations.Value)
                    {
                        endContext.ReportDiagnostic(Diagnostic.Create(
                            DiagnosticDescriptors.UnnecessaryStateHasChangedCall,
                            location,
                            method.Name));
                    }
                }
            });
        }, SymbolKind.NamedType);
    }

    /// <summary>
    /// Records every method that is passed (as a method group, possibly wrapped in a delegate creation or a
    /// conversion) to a method declared on <paramref name="eventCallbackFactoryType"/>.
    /// </summary>
    private static void CollectEventHandlerMethods(
        OperationAnalysisContext context,
        INamedTypeSymbol eventCallbackFactoryType,
        ConcurrentDictionary<IMethodSymbol, byte> eventHandlerMethods)
    {
        var invocation = (IInvocationOperation)context.Operation;
        if (!SymbolEqualityComparer.Default.Equals(invocation.TargetMethod.ContainingType, eventCallbackFactoryType))
        {
            return;
        }

        foreach (var argument in invocation.Arguments)
        {
            if (TryGetMethodFromOperation(argument.Value) is { } method)
            {
                eventHandlerMethods.TryAdd(method, 0);
            }
        }
    }

    /// <summary>
    /// Finds the <c>StateHasChanged()</c> calls of a single method declaration that are candidates for the
    /// diagnostic and stores their locations in <paramref name="redundantCallLocationsByMethod"/>.
    /// </summary>
    /// <remarks>
    /// A call is a candidate when it is located before the first or after the last suspension point of the
    /// method and is not part of a loop that also suspends. Local functions and lambdas are not inspected.
    /// </remarks>
    private static void AnalyzeMethodDeclaration(
        SyntaxNodeAnalysisContext context,
        ConcurrentDictionary<IMethodSymbol, ImmutableArray<Location>> redundantCallLocationsByMethod)
    {
        var methodDeclaration = (MethodDeclarationSyntax)context.Node;

        if (context.SemanticModel.GetDeclaredSymbol(methodDeclaration, context.CancellationToken) is not IMethodSymbol methodSymbol)
        {
            return;
        }

        var body = methodDeclaration.Body;
        if (body is null)
        {
            // Handle expression-bodied methods like: void OnInitialized() => StateHasChanged();
            if (methodDeclaration.ExpressionBody?.Expression is InvocationExpressionSyntax expressionBodyInvocation &&
                IsStateHasChangedCall(context, expressionBodyInvocation))
            {
                redundantCallLocationsByMethod.TryAdd(methodSymbol, ImmutableArray.Create(expressionBodyInvocation.GetLocation()));
            }

            return;
        }

        // Single pre-order walk: nodes arrive in document order, so no sorting is required afterwards.
        var stateCalls = new List<InvocationExpressionSyntax>();
        var awaitRegions = new List<(int Start, int End)>();
        foreach (var node in body.DescendantNodes(static candidate => !IsNestedFunctionLike(candidate)))
        {
            if (node is InvocationExpressionSyntax invocation)
            {
                if (IsStateHasChangedCall(context, invocation))
                {
                    stateCalls.Add(invocation);
                }
            }
            else if (TryGetAwaitRegion(node, out var region))
            {
                awaitRegions.Add(region);
            }
        }

        if (stateCalls.Count == 0)
        {
            // No call, no problems.
            return;
        }

        // With no suspension point at all the sentinels make every call count as "outside".
        var firstAwaitStart = int.MaxValue;
        var lastAwaitEnd = int.MinValue;
        foreach (var (start, end) in awaitRegions)
        {
            if (start < firstAwaitStart)
            {
                firstAwaitStart = start;
            }

            if (end > lastAwaitEnd)
            {
                lastAwaitEnd = end;
            }
        }

        var redundantLocations = ImmutableArray.CreateBuilder<Location>();
        foreach (var stateCall in stateCalls)
        {
            var callStart = stateCall.SpanStart;

            // Any calls before the first await or after the last one are redundant, because ComponentBase calls StateHasChanged afterwards.
            var isOutsideAwaits = callStart < firstAwaitStart || callStart > lastAwaitEnd;

            // A call that shares a loop with a suspension point runs between two awaits on later
            // iterations, regardless of where it sits textually.
            if (isOutsideAwaits && !IsInsideLoopWithAwait(body, stateCall, awaitRegions))
            {
                redundantLocations.Add(stateCall.GetLocation());
            }
        }

        if (redundantLocations.Count == 0)
        {
            return;
        }

        redundantCallLocationsByMethod.TryAdd(methodSymbol, redundantLocations.ToImmutable());
    }

    /// <summary>
    /// Maps a syntax node that suspends the method to the inclusive range of source positions in which the
    /// suspension can happen.
    /// </summary>
    /// <returns><see langword="true"/> when <paramref name="node"/> is a suspension point.</returns>
    private static bool TryGetAwaitRegion(SyntaxNode node, out (int Start, int End) region)
    {
        switch (node)
        {
            case AwaitExpressionSyntax:
                region = (node.SpanStart, node.SpanStart);
                return true;

            // 'await foreach' awaits on every iteration and once more when the enumerator is disposed.
            case CommonForEachStatementSyntax forEach when forEach.AwaitKeyword != default:
                region = (node.SpanStart, node.Span.End - 1);
                return true;

            // 'await using (...) { }' awaits the disposal when the statement is left.
            case UsingStatementSyntax usingStatement when usingStatement.AwaitKeyword != default:
                region = (node.Span.End - 1, node.Span.End - 1);
                return true;

            // 'await using var x = ...;' awaits the disposal at the end of the enclosing scope.
            case LocalDeclarationStatementSyntax declaration when declaration.AwaitKeyword != default:
                var scopeEnd = (node.Parent ?? node).Span.End - 1;
                region = (scopeEnd, scopeEnd);
                return true;

            default:
                region = default;
                return false;
        }
    }

    /// <summary>
    /// Determines whether <paramref name="stateCall"/> is in the repeated part of a loop (inside
    /// <paramref name="body"/>) whose repeated part also contains one of <paramref name="awaitRegions"/>.
    /// </summary>
    private static bool IsInsideLoopWithAwait(SyntaxNode body, InvocationExpressionSyntax stateCall, List<(int Start, int End)> awaitRegions)
    {
        if (awaitRegions.Count == 0)
        {
            return false;
        }

        var callStart = stateCall.SpanStart;
        for (var current = stateCall.Parent; current is not null && current != body; current = current.Parent)
        {
            if (current is not (ForStatementSyntax or CommonForEachStatementSyntax or WhileStatementSyntax or DoStatementSyntax))
            {
                continue;
            }

            var repeatedStart = GetRepeatedPartStart(current);
            var loopEnd = current.Span.End;
            if (callStart >= repeatedStart &&
                awaitRegions.Any(region => region.Start < loopEnd && region.End >= repeatedStart))
            {
                return true;
            }
        }

        return false;
    }

    // Returns the position from which the code of a loop can execute more than once.
    private static int GetRepeatedPartStart(SyntaxNode loop) => loop switch
    {
        // The collection expression of a plain foreach is evaluated once, before the first iteration.
        CommonForEachStatementSyntax forEach when forEach.AwaitKeyword == default => forEach.Statement.SpanStart,

        // The declaration/initializers of a for loop run once; condition, incrementors and body repeat.
        ForStatementSyntax forLoop => forLoop.FirstSemicolonToken.SpanStart,

        _ => loop.SpanStart,
    };

    // Targets of this analyzer are lifecycle methods (OnInitialized, OnParametersSet) and event handlers.
    private static bool IsTargetMethod(IMethodSymbol method, ConcurrentDictionary<IMethodSymbol, byte> eventHandlerMethods)
    {
        if (method.MethodKind != MethodKind.Ordinary)
        {
            return false;
        }

        if (method.OverriddenMethod is { } overridden && IsTargetLifecycleMethod(overridden))
        {
            return true;
        }

        return eventHandlerMethods.ContainsKey(method);
    }

    // A parameterless ordinary method carrying one of the four lifecycle method names.
    private static bool IsTargetLifecycleMethod(IMethodSymbol method)
    {
        return method.MethodKind == MethodKind.Ordinary &&
            method.Parameters.Length == 0 &&
            method.Name is ("OnInitialized" or "OnInitializedAsync" or "OnParametersSet" or "OnParametersSetAsync");
    }

    // True when the invocation binds to the parameterless StateHasChanged method declared on ComponentBase.
    private static bool IsStateHasChangedCall(SyntaxNodeAnalysisContext context, InvocationExpressionSyntax invocation)
    {
        return context.SemanticModel.GetSymbolInfo(invocation, context.CancellationToken).Symbol is IMethodSymbol method &&
            method.MethodKind == MethodKind.Ordinary &&
            method.Parameters.Length == 0 &&
            method.Name == "StateHasChanged" &&
            method.ContainingType.ToDisplayString() == ComponentsApi.ComponentBase.FullTypeName;
    }

    // Exclude local functions/lambdas because those have their own flow and execution timing.
    private static bool IsNestedFunctionLike(SyntaxNode node)
    {
        return node is LocalFunctionStatementSyntax or AnonymousFunctionExpressionSyntax;
    }

    // Unwraps delegate creations and conversions until a method reference is found; null for anything else.
    private static IMethodSymbol? TryGetMethodFromOperation(IOperation operation) => operation switch
    {
        IMethodReferenceOperation methodReference => methodReference.Method,
        IDelegateCreationOperation { Target: { } target } => TryGetMethodFromOperation(target),
        IConversionOperation conversion => TryGetMethodFromOperation(conversion.Operand),
        _ => null,
    };
}