// File: Lowering\LocalRewriter\DelegateCacheRewriter.cs
// Web Access
// Project: src\src\Compilers\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.csproj (Microsoft.CodeAnalysis.CSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.PooledObjects;
using Roslyn.Utilities;
using ReferenceEqualityComparer = Roslyn.Utilities.ReferenceEqualityComparer;
 
namespace Microsoft.CodeAnalysis.CSharp;
 
/// <summary>
/// Rewrites delegate creations that target static method groups so that, instead of allocating a new
/// delegate each time, they reuse a delegate instance cached in a synthesized static field.
/// </summary>
internal sealed class DelegateCacheRewriter
{
    // Callback handed to VisitType: records every TypeParameterSymbol it is shown and always answers false.
    // NOTE(review): presumably a false answer means "keep visiting", so the whole type tree is walked — confirm
    // against VisitType's contract.
    private static readonly Func<TypeSymbol, HashSet<TypeParameterSymbol>, bool, bool> s_typeParameterSymbolCollector =
        static (visitedType, collected, _) =>
        {
            if (visitedType is TypeParameterSymbol typeParameter)
            {
                collected.Add(typeParameter);
            }

            return false;
        };

    private readonly SyntheticBoundNodeFactory _factory;
    private readonly int _topLevelMethodOrdinal;

    // Containers keyed (by reference identity) on the generic function whose type parameters the cached
    // delegates depend on. Allocated lazily, the first time such a delegate is rewritten.
    private Dictionary<MethodSymbol, DelegateCacheContainer>? _genericCacheContainers;

    internal DelegateCacheRewriter(SyntheticBoundNodeFactory factory, int topLevelMethodOrdinal)
    {
        Debug.Assert(factory.TopLevelMethod is { });

        _topLevelMethodOrdinal = topLevelMethodOrdinal;
        _factory = factory;
    }

    /// <summary>
    /// Returns <see langword="true"/> when <paramref name="boundDelegateCreation"/> targets a static method
    /// that is not bound as an extension method, so no receiver is captured and the delegate can be cached.
    /// </summary>
    internal static bool CanRewrite(BoundDelegateCreationExpression boundDelegateCreation)
    {
        MethodSymbol? method = boundDelegateCreation.MethodOpt;
        Debug.Assert(method is { });

        if (!method.IsStatic)
        {
            return false;
        }

        return !boundDelegateCreation.IsExtensionMethod;
    }

    /// <summary>
    /// Produces <c>cacheField ?? (cacheField = &lt;delegate creation&gt;)</c>, where <c>cacheField</c> is a
    /// static field of a synthesized cache container type.
    /// </summary>
    internal BoundExpression Rewrite(BoundDelegateCreationExpression boundDelegateCreation)
    {
        Debug.Assert(boundDelegateCreation.MethodOpt is { });

        // Attribute every synthesized node to the delegate creation's syntax; the previous syntax is restored below.
        var savedSyntax = _factory.Syntax;
        _factory.Syntax = boundDelegateCreation.Syntax;

        var container = GetOrAddCacheContainer(boundDelegateCreation);
        var field = container.GetOrAddCacheField(_factory, boundDelegateCreation);

        // A null receiver makes this an access to the static cache field.
        var fieldAccess = _factory.Field(receiver: null, field);
        var populateCache = _factory.AssignmentExpression(fieldAccess, boundDelegateCreation);
        var result = _factory.Coalesce(fieldAccess, populateCache);

        _factory.Syntax = savedSyntax;
        return result;
    }

    /// <summary>
    /// Finds (or creates and registers as a nested type) the container that should hold the cache field for
    /// <paramref name="boundDelegateCreation"/>.
    /// </summary>
    private DelegateCacheContainer GetOrAddCacheContainer(BoundDelegateCreationExpression boundDelegateCreation)
    {
        Debug.Assert(_factory.ModuleBuilderOpt is { });
        Debug.Assert(_factory.CurrentFunction is { });

        var generation = _factory.ModuleBuilderOpt.CurrentGenerationOrdinal;

        // Containers are shared: a delegate is cached in the container owned by the innermost enclosing
        // function whose type parameters it actually uses, not one container per function. For instance:
        //
        //   void LF1<T>()
        //   {
        //       void LF2<G>()
        //       {
        //           void LF3()
        //           {
        //               Func<T> d = SomeMethod<T>;
        //               static void LF4() { Func<T> d = SomeMethod<T>; }
        //           }
        //
        //           void LF5()
        //           {
        //               Func<T> d = SomeMethod<T>;
        //           }
        //       }
        //   }
        //
        // Every delegate above depends only on LF1's T, so a single cached delegate in LF1's container suffices.
        DelegateCacheContainer created;

        if (TryGetOwnerFunction(_factory.CurrentFunction, boundDelegateCreation, out var ownerFunction))
        {
            var containersByOwner = _genericCacheContainers ??= new Dictionary<MethodSymbol, DelegateCacheContainer>(ReferenceEqualityComparer.Instance);

            if (containersByOwner.TryGetValue(ownerFunction, out var existing))
            {
                return existing;
            }

            // The count is read before the Add, so each generic container made by this rewriter gets a distinct index.
            created = new DelegateCacheContainer(ownerFunction, _topLevelMethodOrdinal, containersByOwner.Count, generation);
            containersByOwner.Add(ownerFunction, created);
        }
        else
        {
            // No generic owner: use the single container shared by the whole type being compiled.
            var compilationState = _factory.CompilationState;

            if (compilationState.ConcreteDelegateCacheContainer is { } sharedContainer)
            {
                return sharedContainer;
            }

            created = new DelegateCacheContainer(compilationState.Type, generation);
            compilationState.ConcreteDelegateCacheContainer = created;
        }

        _factory.AddNestedType(created);
        return created;
    }

    /// <summary>
    /// Walks outward from <paramref name="currentFunction"/> through enclosing methods looking for the function
    /// that should own the cache container. Returns <see langword="false"/> (and a null owner) when the delegate
    /// does not depend on any enclosing method's type parameters.
    /// </summary>
    private static bool TryGetOwnerFunction(MethodSymbol currentFunction, BoundDelegateCreationExpression boundDelegateCreation, [NotNullWhen(true)] out MethodSymbol? ownerFunction)
    {
        var targetMethod = boundDelegateCreation.MethodOpt;
        Debug.Assert(targetMethod is { });

        if (targetMethod.MethodKind == MethodKind.LocalFunction)
        {
            // A local function may silently use type parameters of its enclosing methods, e.g.
            //
            //   void Test<T>()
            //   {
            //       var t = Target<int>;
            //       static object Target<V>() => default(T);
            //   }
            //
            // Rather than analyzing the body, conservatively pick the nearest generic enclosing method as owner.
            Symbol? scope = currentFunction;
            while (scope is MethodSymbol scopeMethod)
            {
                if (scopeMethod.Arity > 0)
                {
                    ownerFunction = scopeMethod;
                    return true;
                }

                scope = scopeMethod.ContainingSymbol;
            }

            ownerFunction = null;
            return false;
        }

        // Other method kinds could legitimately be targeted too. The real question here is whether the target
        // can use type parameters that are not discoverable from its signature; as of C# 10 only local functions
        // can, hence the special case above. If this assert fires, revisit the logic and add tests.
        Debug.Assert(targetMethod.MethodKind == MethodKind.Ordinary);

        var referencedTypeParameters = PooledHashSet<TypeParameterSymbol>.GetInstance();
        try
        {
            // For abstract/virtual targets (presumably static interface members, since CanRewrite only admits
            // static targets) the type the member is accessed through can also contribute type parameters.
            if ((targetMethod.IsAbstract || targetMethod.IsVirtual) && boundDelegateCreation.Argument is BoundTypeExpression accessedThrough)
            {
                FindTypeParameters(accessedThrough.Type, referencedTypeParameters);
            }

            FindTypeParameters(boundDelegateCreation.Type, referencedTypeParameters);
            FindTypeParameters(targetMethod, referencedTypeParameters);

            Symbol? enclosing = currentFunction;
            while (enclosing is MethodSymbol candidateOwner)
            {
                if (referencesAny(referencedTypeParameters, candidateOwner.TypeParameters))
                {
                    ownerFunction = candidateOwner;
                    return true;
                }

                enclosing = candidateOwner.ContainingSymbol;
            }

            ownerFunction = null;
            return false;
        }
        finally
        {
            referencedTypeParameters.Free();
        }

        static bool referencesAny(HashSet<TypeParameterSymbol> referenced, ImmutableArray<TypeParameterSymbol> candidates)
        {
            for (int i = 0; i < candidates.Length; i++)
            {
                if (referenced.Contains(candidates[i]))
                {
                    return true;
                }
            }

            return false;
        }
    }

    // Adds every type parameter occurring in 'type' (custom modifiers included) to 'result'.
    private static void FindTypeParameters(TypeSymbol type, HashSet<TypeParameterSymbol> result)
    {
        type.VisitType(s_typeParameterSymbolCollector, result, visitCustomModifiers: true);
    }

    // Adds the type parameters used by 'method' through its containing type and its own type arguments.
    private static void FindTypeParameters(MethodSymbol method, HashSet<TypeParameterSymbol> result)
    {
        FindTypeParameters(method.ContainingType, result);

        var typeArguments = method.TypeArgumentsWithAnnotations;
        for (int i = 0; i < typeArguments.Length; i++)
        {
            typeArguments[i].VisitType(type: null, typeWithAnnotationsPredicate: null, s_typeParameterSymbolCollector, result, visitCustomModifiers: true);
        }
    }
}