File: Lowering\ExtensionMethodBodyRewriter.cs
Web Access
Project: src\src\Compilers\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.csproj (Microsoft.CodeAnalysis.CSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Roslyn.Utilities;
using ReferenceEqualityComparer = Roslyn.Utilities.ReferenceEqualityComparer;
 
namespace Microsoft.CodeAnalysis.CSharp
{
    /// <summary>
    /// Rewrites method body lowered in context of an extension method to a body
    /// that corresponds to its implementation form.
    /// </summary>
    internal sealed class ExtensionMethodBodyRewriter : BoundTreeToDifferentEnclosingContextRewriter
    {
        private readonly SourceExtensionImplementationMethodSymbol _implementationMethod;
 
        /// <summary>
        /// Maps parameters and local functions from original enclosing context to corresponding rewritten symbols for rewritten context.
        /// </summary>
        private ImmutableDictionary<Symbol, Symbol> _symbolMap;
 
        private RewrittenMethodSymbol _rewrittenContainingMethod;
 
        public ExtensionMethodBodyRewriter(MethodSymbol sourceMethod, SourceExtensionImplementationMethodSymbol implementationMethod)
        {
            Debug.Assert(sourceMethod is not null);
            Debug.Assert(implementationMethod is not null);
            Debug.Assert(sourceMethod == (object)implementationMethod.UnderlyingMethod);
 
            _implementationMethod = implementationMethod;
            _symbolMap = ImmutableDictionary<Symbol, Symbol>.Empty.WithComparers(ReferenceEqualityComparer.Instance, ReferenceEqualityComparer.Instance);
 
            bool haveExtraParameter = sourceMethod.ParameterCount != implementationMethod.ParameterCount;
            if (haveExtraParameter)
            {
                Debug.Assert(implementationMethod.ParameterCount - 1 == sourceMethod.ParameterCount);
                var receiverParameter = (WrappedParameterSymbol)implementationMethod.Parameters[0];
                _symbolMap = _symbolMap.Add(receiverParameter.UnderlyingParameter, receiverParameter);
            }
            EnterMethod(sourceMethod, implementationMethod, implementationMethod.Parameters.AsSpan()[(haveExtraParameter ? 1 : 0)..]);
            Debug.Assert(_rewrittenContainingMethod is not null);
        }
 
        private (RewrittenMethodSymbol, ImmutableDictionary<Symbol, Symbol>) EnterMethod(MethodSymbol symbol, RewrittenMethodSymbol rewritten, ReadOnlySpan<ParameterSymbol> rewrittenParameters)
        {
            ImmutableDictionary<Symbol, Symbol> saveSymbolMap = _symbolMap;
            RewrittenMethodSymbol savedContainer = _rewrittenContainingMethod;
 
            Debug.Assert(symbol.Parameters.Length == rewrittenParameters.Length);
 
            if (!rewrittenParameters.IsEmpty)
            {
                var builder = _symbolMap.ToBuilder();
                foreach (var parameter in symbol.Parameters)
                {
                    builder.Add(parameter, rewrittenParameters[parameter.Ordinal]);
                }
                _symbolMap = builder.ToImmutable();
            }
 
            _rewrittenContainingMethod = rewritten;
 
            return (savedContainer, saveSymbolMap);
        }
 
        private (RewrittenMethodSymbol, ImmutableDictionary<Symbol, Symbol>) EnterMethod(MethodSymbol symbol, RewrittenLambdaOrLocalFunctionSymbol rewritten)
        {
            return EnterMethod(symbol, rewritten, rewritten.Parameters.AsSpan());
        }
 
        protected override MethodSymbol CurrentMethod => _rewrittenContainingMethod;
 
        protected override TypeMap TypeMap => _rewrittenContainingMethod.TypeMap;
 
        public override BoundNode? VisitThisReference(BoundThisReference node)
        {
            throw ExceptionUtilities.Unreachable();
        }
 
        public override ParameterSymbol VisitParameterSymbol(ParameterSymbol symbol)
        {
            return (ParameterSymbol)_symbolMap[symbol];
        }
 
        public override BoundNode? VisitLambda(BoundLambda node)
        {
            var rewritten = new RewrittenLambdaOrLocalFunctionSymbol(node.Symbol, _rewrittenContainingMethod);
 
            var savedState = EnterMethod(node.Symbol, rewritten);
            BoundBlock body = (BoundBlock)this.Visit(node.Body);
            (_rewrittenContainingMethod, _symbolMap) = savedState;
 
            TypeSymbol? type = this.VisitType(node.Type);
            return node.Update(node.UnboundLambda, rewritten, body, node.Diagnostics, node.Binder, type);
        }
 
        public override BoundNode? VisitLocalFunctionStatement(BoundLocalFunctionStatement node)
        {
            MethodSymbol symbol = this.VisitMethodSymbol(node.Symbol);
            var savedState = EnterMethod(node.Symbol, (RewrittenLambdaOrLocalFunctionSymbol)symbol);
 
            BoundBlock? blockBody = (BoundBlock?)this.Visit(node.BlockBody);
            BoundBlock? expressionBody = (BoundBlock?)this.Visit(node.ExpressionBody);
 
            (_rewrittenContainingMethod, _symbolMap) = savedState;
 
            return node.Update(symbol, blockBody, expressionBody);
        }
 
        public override BoundNode VisitBlock(BoundBlock node)
        {
            ImmutableDictionary<Symbol, Symbol> saveSymbolMap = _symbolMap;
 
            if (!node.LocalFunctions.IsEmpty)
            {
                var builder = _symbolMap.ToBuilder();
 
                foreach (var localFunction in node.LocalFunctions)
                {
                    builder.Add(localFunction, new RewrittenLambdaOrLocalFunctionSymbol(localFunction, _rewrittenContainingMethod));
                }
 
                _symbolMap = builder.ToImmutable();
            }
 
            var result = base.VisitBlock(node);
 
            _symbolMap = saveSymbolMap;
            return result;
        }
 
        protected override ImmutableArray<MethodSymbol> VisitDeclaredLocalFunctions(ImmutableArray<MethodSymbol> localFunctions)
        {
            return localFunctions.SelectAsArray(static (l, map) => (MethodSymbol)map[l], _symbolMap);
        }
 
        [return: NotNullIfNotNull(nameof(symbol))]
        public override MethodSymbol? VisitMethodSymbol(MethodSymbol? symbol)
        {
            switch (symbol?.MethodKind)
            {
                case MethodKind.LambdaMethod:
                    throw ExceptionUtilities.Unreachable();
 
                case MethodKind.LocalFunction:
                    if (symbol.IsDefinition)
                    {
                        return (MethodSymbol)_symbolMap[symbol];
                    }
 
                    return ((MethodSymbol)_symbolMap[symbol.OriginalDefinition]).ConstructIfGeneric(TypeMap.SubstituteTypes(symbol.TypeArgumentsWithAnnotations));
 
                default:
                    return base.VisitMethodSymbol(symbol);
            }
        }
 
        [return: NotNullIfNotNull(nameof(symbol))]
        public override FieldSymbol? VisitFieldSymbol(FieldSymbol? symbol)
        {
            if (symbol is null)
            {
                return null;
            }
 
            return symbol.OriginalDefinition
                .AsMember((NamedTypeSymbol)TypeMap.SubstituteType(symbol.ContainingType).AsTypeSymbolOnly());
        }
 
        public override BoundNode? VisitCall(BoundCall node)
        {
            return ExtensionMethodReferenceRewriter.VisitCall(this, node);
        }
 
        public override BoundNode? VisitDelegateCreationExpression(BoundDelegateCreationExpression node)
        {
            return ExtensionMethodReferenceRewriter.VisitDelegateCreationExpression(this, node);
        }
 
        public override BoundNode VisitFunctionPointerLoad(BoundFunctionPointerLoad node)
        {
            return ExtensionMethodReferenceRewriter.VisitFunctionPointerLoad(this, node);
        }
 
        [return: NotNullIfNotNull(nameof(symbol))]
        public override PropertySymbol? VisitPropertySymbol(PropertySymbol? symbol)
        {
            Debug.Assert(symbol?.GetIsNewExtensionMember() != true);
            return base.VisitPropertySymbol(symbol);
        }
    }
}