File: Lowering\BoundTreeToDifferentEnclosingContextRewriter.cs
Web Access
Project: src\src\Compilers\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.csproj (Microsoft.CodeAnalysis.CSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.PooledObjects;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.CSharp
{
    /// <summary>
    /// A bound node rewriter that rewrites types properly (which, in some cases, the automatically-generated
    /// base class does not). It is used by the lambda rewriter, the iterator rewriter, and the async rewriter.
    /// </summary>
    internal abstract class BoundTreeToDifferentEnclosingContextRewriter : BoundTreeRewriterWithStackGuardWithoutRecursionOnTheLeftOfBinaryOperator
    {
        /// <summary>
        /// Replacement locals, keyed by the original local. An entry is recorded only when substituting through
        /// <see cref="TypeMap"/> changes the local's type (for example because it lives inside a generic method).
        /// Locals whose type is unchanged are reused as-is, even though their containing method is no longer
        /// accurate once the code has been moved into a different method.
        /// </summary>
        private readonly Dictionary<LocalSymbol, LocalSymbol> _localMap = new();

        /// <summary>
        /// Maps types in the original method to types in its replacement. This is mainly needed when the
        /// original method was generic, because its type parameters are mapped onto type parameters of the
        /// resulting class.
        /// </summary>
        protected abstract TypeMap TypeMap { get; }

        /// <summary>
        /// The method into which code is being moved. Replacement locals created by
        /// <see cref="TryRewriteLocal"/> are constructed with this method.
        /// </summary>
        protected abstract MethodSymbol CurrentMethod { get; }

        public override BoundNode DefaultVisit(BoundNode node)
        {
            // Every node kind reaching this rewriter is expected to have an explicit override;
            // falling through to the default visitor indicates a missing one.
            Debug.Fail($"Override the visitor for {node.Kind}");
            return base.DefaultVisit(node);
        }

        /// <summary>
        /// Appends the replacement for each local in <paramref name="locals"/> to <paramref name="newLocals"/>,
        /// omitting any local for which <see cref="TryRewriteLocal"/> returns <c>false</c>.
        /// </summary>
        protected void RewriteLocals(ImmutableArray<LocalSymbol> locals, ArrayBuilder<LocalSymbol> newLocals)
        {
            foreach (var original in locals)
            {
                if (!TryRewriteLocal(original, out LocalSymbol? rewritten))
                {
                    continue;
                }

                newLocals.Add(rewritten);
            }
        }

        /// <summary>
        /// Produces the local that should be used in place of <paramref name="local"/>.
        /// A replacement created earlier is reused. Otherwise, if <see cref="TypeMap"/> changes the local's type,
        /// a <see cref="TypeSubstitutedLocalSymbol"/> is created with <see cref="CurrentMethod"/> and cached.
        /// If the type is unchanged, the original local itself is returned.
        /// This implementation always returns <c>true</c>. Overrides may return <c>false</c> to drop a local
        /// from declaration lists (see <see cref="RewriteLocals"/>).
        /// </summary>
        protected virtual bool TryRewriteLocal(LocalSymbol local, [NotNullWhen(true)] out LocalSymbol? newLocal)
        {
            // Hand back the same replacement every time this local is encountered.
            if (_localMap.TryGetValue(local, out newLocal))
            {
                return true;
            }

            var substitutedType = VisitType(local.Type);
            if (!TypeSymbol.Equals(substitutedType, local.Type, TypeCompareKind.ConsiderEverything2))
            {
                newLocal = new TypeSubstitutedLocalSymbol(local, TypeWithAnnotations.Create(substitutedType), CurrentMethod);
                _localMap.Add(local, newLocal);
                return true;
            }

            // Substitution left the type alone, so the original local can be used directly (and is not cached).
            newLocal = local;
            return true;
        }

        /// <summary>
        /// Rewrites a list of declared locals. An empty list is returned unchanged without renting a builder.
        /// </summary>
        protected sealed override ImmutableArray<LocalSymbol> VisitLocals(ImmutableArray<LocalSymbol> locals)
        {
            if (locals.IsEmpty)
            {
                return locals;
            }

            var builder = ArrayBuilder<LocalSymbol>.GetInstance();
            RewriteLocals(locals, builder);
            return builder.ToImmutableAndFree();
        }

        /// <summary>
        /// Maps a referenced local to its replacement. Throws if <see cref="TryRewriteLocal"/> declines the local.
        /// </summary>
        public sealed override LocalSymbol VisitLocalSymbol(LocalSymbol local)
        {
            if (TryRewriteLocal(local, out var rewritten))
            {
                return rewritten;
            }

            throw ExceptionUtilities.UnexpectedValue(local);
        }

        /// <summary>
        /// Looks up a replacement previously recorded by <see cref="TryRewriteLocal"/>. Only locals whose type
        /// changed under substitution are recorded, so this returns <c>false</c> for every other local.
        /// </summary>
        protected bool TryGetRewrittenLocal(LocalSymbol local, [NotNullWhen(true)] out LocalSymbol? localToUse)
            => _localMap.TryGetValue(local, out localToUse);

        public override BoundNode VisitBlock(BoundBlock node)
            => VisitBlock(node, removeInstrumentation: false);

        /// <summary>
        /// Rewrites a block's locals, local functions, statements and (unless
        /// <paramref name="removeInstrumentation"/> is <c>true</c>) its instrumentation, in that order.
        /// When <paramref name="removeInstrumentation"/> is <c>true</c>, the instrumentation is dropped instead of visited.
        /// </summary>
        protected BoundBlock VisitBlock(BoundBlock node, bool removeInstrumentation)
        {
            // Note: the instrumentation variable is intentionally not rewritten; it should never be lifted.

            var rewrittenLocals = this.VisitLocals(node.Locals);
            var rewrittenLocalFunctions = this.VisitDeclaredLocalFunctions(node.LocalFunctions);
            var rewrittenStatements = VisitList(node.Statements);

            BoundBlockInstrumentation? rewrittenInstrumentation = null;
            if (!removeInstrumentation)
            {
                rewrittenInstrumentation = (BoundBlockInstrumentation?)Visit(node.Instrumentation);
            }

            return node.Update(rewrittenLocals, rewrittenLocalFunctions, node.HasUnsafeModifier, rewrittenInstrumentation, rewrittenStatements);
        }

        /// <summary>
        /// Applies <see cref="TypeMap"/> to <paramref name="type"/>; <c>null</c> maps to <c>null</c>.
        /// </summary>
        [return: NotNullIfNotNull(nameof(type))]
        public sealed override TypeSymbol? VisitType(TypeSymbol? type)
            => TypeMap.SubstituteType(type).Type;

        protected override BoundBinaryOperator.UncommonData? VisitBinaryOperatorData(BoundBinaryOperator node)
        {
            // The local rewriter should already have lowered interpolated strings into their final form of calls and gotos.
            Debug.Assert(node.InterpolatedStringHandlerData is null);

            var rewrittenMethod = VisitMethodSymbol(node.Method);
            var rewrittenConstrainedToType = VisitType(node.ConstrainedToType);
            return BoundBinaryOperator.UncommonData.CreateIfNeeded(
                node.ConstantValueOpt,
                rewrittenMethod,
                rewrittenConstrainedToType,
                node.OriginalUserDefinedOperatorsOpt);
        }

        /// <summary>
        /// Rewrites the conversion's method (when it has one), then the operand, then the result type.
        /// </summary>
        public override BoundNode? VisitConversion(BoundConversion node)
        {
            var rewrittenConversion = node.Conversion;
            if (rewrittenConversion.Method is { } conversionMethod)
            {
                rewrittenConversion = rewrittenConversion.SetConversionMethod(VisitMethodSymbol(conversionMethod));
            }

            var rewrittenOperand = (BoundExpression)Visit(node.Operand);
            var rewrittenType = VisitType(node.Type);

            return node.Update(
                rewrittenOperand,
                rewrittenConversion,
                node.IsBaseConversion,
                node.Checked,
                node.ExplicitCastInCode,
                node.ConstantValueOpt,
                node.ConversionGroupOpt,
                rewrittenType);
        }

        /// <summary>
        /// Maps a method symbol into the new context by substituting its containing type (and, for
        /// non-anonymous types, its type arguments) through <see cref="TypeMap"/>. <c>null</c> maps to <c>null</c>.
        /// </summary>
        [return: NotNullIfNotNull(nameof(method))]
        public override MethodSymbol? VisitMethodSymbol(MethodSymbol? method)
        {
            if (method is null)
            {
                return null;
            }

            var containingType = method.ContainingType;

            if (!containingType.IsAnonymousType)
            {
                // Method of a regular type: re-home the original definition on the substituted containing type,
                // then re-apply the substituted type arguments (a no-op for non-generic methods).
                var definition = (MethodSymbol)method.OriginalDefinition;
                var substitutedContainingType = (NamedTypeSymbol)TypeMap.SubstituteType(containingType).AsTypeSymbolOnly();
                var member = definition.AsMember(substitutedContainingType);
                return member.ConstructIfGeneric(TypeMap.SubstituteTypes(method.TypeArgumentsWithAnnotations));
            }

            // Method of an anonymous type.
            var substitutedAnonymousType = (NamedTypeSymbol)TypeMap.SubstituteType(containingType).AsTypeSymbolOnly();
            if (ReferenceEquals(substitutedAnonymousType, containingType))
            {
                // The anonymous type symbol was not rewritten, so the method can be used unchanged.
                return method;
            }

            // Otherwise locate the corresponding method on the new anonymous type by name; the first member of
            // kind Method is taken. NOTE(review): this relies on method names being unique on anonymous types.
            foreach (var candidate in substitutedAnonymousType.GetMembers(method.Name))
            {
                if (candidate.Kind != SymbolKind.Method)
                {
                    continue;
                }

                return (MethodSymbol)candidate;
            }

            throw ExceptionUtilities.Unreachable();
        }
    }
}