File: Lowering\ExtensionMethodReferenceRewriter.cs
Web Access
Project: src\src\Compilers\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.csproj (Microsoft.CodeAnalysis.CSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.PooledObjects;
 
namespace Microsoft.CodeAnalysis.CSharp
{
    /// <summary>
    /// Replaces references to extension methods with references to their implementation methods
    /// </summary>
    internal sealed class ExtensionMethodReferenceRewriter : BoundTreeRewriterWithStackGuardWithoutRecursionOnTheLeftOfBinaryOperator
    {
        private ExtensionMethodReferenceRewriter()
        {
        }
 
        public static BoundStatement Rewrite(BoundStatement statement)
        {
            var rewriter = new ExtensionMethodReferenceRewriter();
            return (BoundStatement)rewriter.Visit(statement);
        }
 
        public override BoundNode VisitCall(BoundCall node)
        {
            return VisitCall(this, node);
        }
 
        public static BoundNode VisitCall(BoundTreeRewriter rewriter, BoundCall node)
        {
            Debug.Assert(node != null);
 
            BoundExpression rewrittenCall;
 
            if (LocalRewriter.TryGetReceiver(node, out BoundCall? receiver1))
            {
                // Handle long call chain of both instance and extension method invocations.
                var calls = ArrayBuilder<BoundCall>.GetInstance();
 
                calls.Push(node);
                node = receiver1;
 
                while (LocalRewriter.TryGetReceiver(node, out BoundCall? receiver2))
                {
                    calls.Push(node);
                    node = receiver2;
                }
 
                // Rewrite the receiver
                BoundExpression? rewrittenReceiver = (BoundExpression?)rewriter.Visit(node.ReceiverOpt);
 
                do
                {
                    rewrittenCall = visitArgumentsAndFinishRewrite(rewriter, node, rewrittenReceiver);
                    rewrittenReceiver = rewrittenCall;
                }
                while (calls.TryPop(out node!));
 
                calls.Free();
            }
            else
            {
                // Rewrite the receiver
                BoundExpression? rewrittenReceiver = (BoundExpression?)rewriter.Visit(node.ReceiverOpt);
                rewrittenCall = visitArgumentsAndFinishRewrite(rewriter, node, rewrittenReceiver);
            }
 
            return rewrittenCall;
 
            static BoundExpression visitArgumentsAndFinishRewrite(BoundTreeRewriter rewriter, BoundCall node, BoundExpression? rewrittenReceiver)
            {
                return updateCall(
                    node,
                    VisitMethodSymbolWithExtensionRewrite(rewriter, node.Method),
                    rewriter.VisitSymbols(node.OriginalMethodsOpt),
                    rewrittenReceiver,
                    rewriter.VisitList(node.Arguments),
                    node.ArgumentRefKindsOpt,
                    node.InvokedAsExtensionMethod,
                    rewriter.VisitType(node.Type));
            }
 
            static BoundExpression updateCall(
                BoundCall boundCall,
                MethodSymbol method,
                ImmutableArray<MethodSymbol> originalMethodsOpt,
                BoundExpression? receiverOpt,
                ImmutableArray<BoundExpression> arguments,
                ImmutableArray<RefKind> argumentRefKinds,
                bool invokedAsExtensionMethod,
                TypeSymbol type)
            {
                if (receiverOpt is not null && arguments.Length == method.ParameterCount - 1)
                {
                    Debug.Assert(boundCall.Method.OriginalDefinition.TryGetCorrespondingExtensionImplementationMethod() == (object)method.OriginalDefinition);
                    Debug.Assert(!boundCall.Method.IsStatic);
 
                    var receiverRefKind = method.Parameters[0].RefKind;
 
                    if (argumentRefKinds.IsDefault)
                    {
                        if (receiverRefKind != RefKind.None)
                        {
                            var builder = ArrayBuilder<RefKind>.GetInstance(method.ParameterCount, RefKind.None);
                            builder[0] = argumentRefKindFromReceiverRefKind(receiverRefKind);
                            argumentRefKinds = builder.ToImmutableAndFree();
                        }
                    }
                    else
                    {
                        argumentRefKinds = argumentRefKinds.Insert(0, argumentRefKindFromReceiverRefKind(receiverRefKind)); // Tracked by https://github.com/dotnet/roslyn/issues/76130 : Test this code path
                    }
 
                    invokedAsExtensionMethod = true;
 
                    Debug.Assert(receiverOpt.Type!.Equals(method.Parameters[0].Type, TypeCompareKind.ConsiderEverything));
 
                    arguments = arguments.Insert(0, receiverOpt);
                    receiverOpt = null;
                }
 
                return boundCall.Update(
                    receiverOpt,
                    boundCall.InitialBindingReceiverIsSubjectToCloning,
                    method,
                    arguments,
                    default,
                    argumentRefKinds,
                    boundCall.IsDelegateCall,
                    boundCall.Expanded,
                    invokedAsExtensionMethod,
                    default,
                    default,
                    boundCall.ResultKind,
                    originalMethodsOpt,
                    type);
 
                static RefKind argumentRefKindFromReceiverRefKind(RefKind receiverRefKind)
                {
                    return SyntheticBoundNodeFactory.ArgumentRefKindFromParameterRefKind(receiverRefKind, useStrictArgumentRefKinds: false);
                }
            }
        }
 
        [return: NotNullIfNotNull(nameof(method))]
        private static MethodSymbol? VisitMethodSymbolWithExtensionRewrite(BoundTreeRewriter rewriter, MethodSymbol? method)
        {
            if (method?.GetIsNewExtensionMember() == true &&
                method.OriginalDefinition.TryGetCorrespondingExtensionImplementationMethod() is MethodSymbol implementationMethod)
            {
                method = implementationMethod.AsMember(method.ContainingSymbol.ContainingType).
                    ConstructIfGeneric(method.ContainingType.TypeArgumentsWithAnnotationsNoUseSiteDiagnostics.Concat(method.TypeArgumentsWithAnnotations));
            }
 
            return rewriter.VisitMethodSymbol(method);
        }
 
        [return: NotNullIfNotNull(nameof(method))]
        public override MethodSymbol? VisitMethodSymbol(MethodSymbol? method)
        {
            Debug.Assert(method?.GetIsNewExtensionMember() != true ||
                         method.OriginalDefinition.TryGetCorrespondingExtensionImplementationMethod() is null);
            // All possibly interesting methods should go through VisitMethodSymbolWithExtensionRewrite first
            Debug.Assert(method is null ||
                         method.ContainingSymbol is not NamedTypeSymbol ||
                         method.MethodKind is (MethodKind.Constructor or MethodKind.StaticConstructor) ||
                         method.OriginalDefinition is ErrorMethodSymbol ||
                         new StackTrace(fNeedFileInfo: false).GetFrame(1)?.GetMethod() switch
                         {
                             { Name: nameof(VisitTypeOfOperator) } => method is { Name: "GetTypeFromHandle", IsExtensionMethod: false }, // GetTypeFromHandle cannot be an extension method
                             { Name: nameof(VisitRefTypeOperator) } => method is { Name: "GetTypeFromHandle", IsExtensionMethod: false }, // GetTypeFromHandle cannot be an extension method
                             { Name: nameof(VisitReadOnlySpanFromArray) } => method is { Name: "op_Implicit", IsExtensionMethod: false }, // Conversion operator from array to span cannot be an extension method
                             { Name: nameof(VisitLoweredConditionalAccess) } => // Nullable.HasValue cannot be an extension method
                                            method.ContainingAssembly.GetSpecialTypeMember(SpecialMember.System_Nullable_T_get_HasValue) == (object)method.OriginalDefinition,
                             { Name: nameof(VisitUnaryOperator) } => !method.IsExtensionMethod, // Expression tree context. At the moment an operator cannot be an extension method
                             { Name: nameof(VisitUserDefinedConditionalLogicalOperator) } => !method.IsExtensionMethod, // Expression tree context. At the moment an operator cannot be an extension method
                             { Name: nameof(VisitCollectionElementInitializer) } => !method.IsExtensionMethod, // Expression tree context. At the moment an extension method cannot be used in expression tree here.
                             { Name: nameof(VisitAwaitableInfo) } => method is { Name: "GetResult", IsExtensionMethod: false }, // Cannot be an extension method
                             { Name: nameof(VisitMethodSymbolWithExtensionRewrite), DeclaringType: { } declaringType } => declaringType == typeof(ExtensionMethodReferenceRewriter),
                             _ => false
                         });
 
            return base.VisitMethodSymbol(method);
        }
 
        public override BoundNode? VisitMethodDefIndex(BoundMethodDefIndex node)
        {
            MethodSymbol method = node.Method;
            Debug.Assert(method.IsDefinition); // Tracked by https://github.com/dotnet/roslyn/issues/76130 : From the code coverage and other instrumentations perspective, should we remap the index to the implementation symbol? 
            TypeSymbol? type = this.VisitType(node.Type);
            return node.Update(method, type);
        }
 
        public override BoundNode? VisitDelegateCreationExpression(BoundDelegateCreationExpression node)
        {
            return VisitDelegateCreationExpression(this, node);
        }
 
        public static BoundNode VisitDelegateCreationExpression(BoundTreeRewriter rewriter, BoundDelegateCreationExpression node)
        {
            var methodOpt = VisitMethodSymbolWithExtensionRewrite(rewriter, node.MethodOpt);
            var argument = (BoundExpression)rewriter.Visit(node.Argument);
            var type = rewriter.VisitType(node.Type);
            bool isExtensionMethod = node.IsExtensionMethod;
 
            if (!isExtensionMethod && argument is not BoundTypeExpression && methodOpt?.IsStatic == true)
            {
                Debug.Assert(node.MethodOpt!.OriginalDefinition.TryGetCorrespondingExtensionImplementationMethod() == (object)methodOpt.OriginalDefinition);
                isExtensionMethod = true;
            }
 
            return node.Update(argument, methodOpt, isExtensionMethod, node.WasTargetTyped, type);
        }
 
        public override BoundNode VisitFunctionPointerLoad(BoundFunctionPointerLoad node)
        {
            return VisitFunctionPointerLoad(this, node);
        }
 
        public static BoundNode VisitFunctionPointerLoad(BoundTreeRewriter rewriter, BoundFunctionPointerLoad node)
        {
            MethodSymbol targetMethod = VisitMethodSymbolWithExtensionRewrite(rewriter, node.TargetMethod);
            TypeSymbol? constrainedToTypeOpt = rewriter.VisitType(node.ConstrainedToTypeOpt);
            TypeSymbol? type = rewriter.VisitType(node.Type);
            return node.Update(targetMethod, constrainedToTypeOpt, type);
        }
 
        protected override BoundBinaryOperator.UncommonData? VisitBinaryOperatorData(BoundBinaryOperator node)
        {
            Debug.Assert(node.Method is null ||
                         (!node.Method.IsExtensionMethod && !node.Method.GetIsNewExtensionMember())); // Expression tree context. At the moment an operator cannot be an extension method
 
            return base.VisitBinaryOperatorData(node);
        }
 
        [return: NotNullIfNotNull(nameof(symbol))]
        public override PropertySymbol? VisitPropertySymbol(PropertySymbol? symbol)
        {
            Debug.Assert(symbol?.GetIsNewExtensionMember() != true);
            return base.VisitPropertySymbol(symbol);
        }
    }
}