File: BoundTree\NullabilityRewriter.cs
Web Access
Project: src\src\Compilers\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.csproj (Microsoft.CodeAnalysis.CSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Immutable;
using System.Diagnostics;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.PooledObjects;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.CSharp
{
    internal sealed partial class NullabilityRewriter : BoundTreeRewriter
    {
        protected override BoundNode? VisitExpressionOrPatternWithoutStackGuard(BoundNode node)
        {
            return Visit(node);
        }
 
        public override BoundNode? VisitBinaryOperator(BoundBinaryOperator node)
        {
            return VisitBinaryOperatorBase(node);
        }
 
        public override BoundNode? VisitUserDefinedConditionalLogicalOperator(BoundUserDefinedConditionalLogicalOperator node)
        {
            return VisitBinaryOperatorBase(node);
        }
 
        public override BoundNode? VisitIfStatement(BoundIfStatement node)
        {
            var stack = ArrayBuilder<(BoundIfStatement, BoundExpression, BoundStatement)>.GetInstance();
 
            BoundStatement? rewrittenAlternative;
            while (true)
            {
                var rewrittenCondition = (BoundExpression)Visit(node.Condition);
                var rewrittenConsequence = (BoundStatement)Visit(node.Consequence);
                Debug.Assert(rewrittenConsequence is { });
                stack.Push((node, rewrittenCondition, rewrittenConsequence));
 
                var alternative = node.AlternativeOpt;
                if (alternative is null)
                {
                    rewrittenAlternative = null;
                    break;
                }
 
                if (alternative is BoundIfStatement elseIfStatement)
                {
                    node = elseIfStatement;
                }
                else
                {
                    rewrittenAlternative = (BoundStatement)Visit(alternative);
                    break;
                }
            }
 
            BoundStatement result;
            do
            {
                var (ifStatement, rewrittenCondition, rewrittenConsequence) = stack.Pop();
                result = ifStatement.Update(rewrittenCondition, rewrittenConsequence, rewrittenAlternative);
                rewrittenAlternative = result;
            }
            while (stack.Any());
 
            stack.Free();
            return result;
        }
 
        private BoundNode VisitBinaryOperatorBase(BoundBinaryOperatorBase binaryOperator)
        {
            // Use an explicit stack to avoid blowing the managed stack when visiting deeply-recursive
            // binary nodes
            var stack = ArrayBuilder<BoundBinaryOperatorBase>.GetInstance();
            BoundBinaryOperatorBase? currentBinary = binaryOperator;
 
            do
            {
                stack.Push(currentBinary);
                currentBinary = currentBinary.Left as BoundBinaryOperatorBase;
            }
            while (currentBinary is not null);
 
            Debug.Assert(stack.Count > 0);
            var leftChild = (BoundExpression)Visit(stack.Peek().Left);
 
            do
            {
                currentBinary = stack.Pop();
 
                bool foundInfo = _updatedNullabilities.TryGetValue(currentBinary, out (NullabilityInfo Info, TypeSymbol? Type) infoAndType);
                var right = (BoundExpression)Visit(currentBinary.Right);
                var type = foundInfo ? infoAndType.Type : currentBinary.Type;
 
                currentBinary = currentBinary switch
                {
                    BoundBinaryOperator binary => binary.Update(
                        binary.OperatorKind,
                        binary.Data?.WithUpdatedMethod(GetUpdatedSymbol(binary, binary.Method)),
                        binary.ResultKind,
                        leftChild,
                        right,
                        type!),
                    // https://github.com/dotnet/roslyn/issues/35031: We'll need to update logical.LogicalOperator
                    BoundUserDefinedConditionalLogicalOperator logical => logical.Update(logical.OperatorKind, logical.LogicalOperator, logical.TrueOperator, logical.FalseOperator, logical.ConstrainedToTypeOpt, logical.ResultKind, logical.OriginalUserDefinedOperatorsOpt, leftChild, right, type!),
                    _ => throw ExceptionUtilities.UnexpectedValue(currentBinary.Kind),
                };
 
                if (foundInfo)
                {
                    currentBinary.TopLevelNullability = infoAndType.Info;
                }
 
                leftChild = currentBinary;
            }
            while (stack.Count > 0);
 
            Debug.Assert(currentBinary != null);
            return currentBinary!;
        }
 
        private T GetUpdatedSymbol<T>(BoundNode expr, T sym) where T : Symbol?
        {
            if (sym is null) return sym;
 
            Symbol? updatedSymbol = null;
            if (_snapshotManager?.TryGetUpdatedSymbol(expr, sym, out updatedSymbol) != true)
            {
                updatedSymbol = sym;
            }
            RoslynDebug.Assert(updatedSymbol is object);
 
            switch (updatedSymbol)
            {
                case LambdaSymbol lambda:
                    return (T)remapLambda((BoundLambda)expr, lambda);
 
                case SourceLocalSymbol local:
                    return (T)remapLocal(local);
 
                case ParameterSymbol param:
                    if (_remappedSymbols.TryGetValue(param, out var updatedParam))
                    {
                        return (T)updatedParam;
                    }
                    break;
            }
 
            return (T)updatedSymbol;
 
            Symbol remapLambda(BoundLambda boundLambda, LambdaSymbol lambda)
            {
                var updatedDelegateType = _snapshotManager?.GetUpdatedDelegateTypeForLambda(lambda);
 
                if (!_remappedSymbols.TryGetValue(lambda.ContainingSymbol, out Symbol? updatedContaining) && updatedDelegateType is null)
                {
                    return lambda;
                }
 
                LambdaSymbol updatedLambda;
                if (updatedDelegateType is null)
                {
                    Debug.Assert(updatedContaining is object);
                    updatedLambda = boundLambda.CreateLambdaSymbol(updatedContaining, lambda.ReturnTypeWithAnnotations, lambda.ParameterTypesWithAnnotations, lambda.ParameterRefKinds, lambda.RefKind);
                }
                else
                {
                    Debug.Assert(updatedDelegateType is object);
                    updatedLambda = boundLambda.CreateLambdaSymbol(updatedDelegateType, updatedContaining ?? lambda.ContainingSymbol);
                }
 
                _remappedSymbols.Add(lambda, updatedLambda);
 
                Debug.Assert(lambda.ParameterCount == updatedLambda.ParameterCount);
                for (int i = 0; i < lambda.ParameterCount; i++)
                {
                    _remappedSymbols.Add(lambda.Parameters[i], updatedLambda.Parameters[i]);
                }
 
                return updatedLambda;
            }
 
            Symbol remapLocal(SourceLocalSymbol local)
            {
                if (_remappedSymbols.TryGetValue(local, out var updatedLocal))
                {
                    return updatedLocal;
                }
 
                var updatedType = _snapshotManager?.GetUpdatedTypeForLocalSymbol(local);
 
                if (!_remappedSymbols.TryGetValue(local.ContainingSymbol, out Symbol? updatedContaining) && !updatedType.HasValue)
                {
                    // Map the local to itself so we don't have to search again in the future
                    _remappedSymbols.Add(local, local);
                    return local;
                }
 
                updatedLocal = new UpdatedContainingSymbolAndNullableAnnotationLocal(local, updatedContaining ?? local.ContainingSymbol, updatedType ?? local.TypeWithAnnotations);
                _remappedSymbols.Add(local, updatedLocal);
                return updatedLocal;
            }
        }
 
        public override BoundNode? VisitImplicitIndexerAccess(BoundImplicitIndexerAccess node)
        {
            BoundExpression receiver = (BoundExpression)this.Visit(node.Receiver);
            BoundExpression argument = (BoundExpression)this.Visit(node.Argument);
            BoundExpression lengthOrCountAccess = node.LengthOrCountAccess;
            BoundExpression indexerAccess = (BoundExpression)this.Visit(node.IndexerOrSliceAccess);
            BoundImplicitIndexerAccess updatedNode;
 
            if (_updatedNullabilities.TryGetValue(node, out (NullabilityInfo Info, TypeSymbol? Type) infoAndType))
            {
                updatedNode = node.Update(receiver, argument, lengthOrCountAccess, node.ReceiverPlaceholder, indexerAccess, node.ArgumentPlaceholders, infoAndType.Type!);
                updatedNode.TopLevelNullability = infoAndType.Info;
            }
            else
            {
                updatedNode = node.Update(receiver, argument, lengthOrCountAccess, node.ReceiverPlaceholder, indexerAccess, node.ArgumentPlaceholders, node.Type);
            }
            return updatedNode;
        }
 
        private ImmutableArray<T> GetUpdatedArray<T>(BoundNode expr, ImmutableArray<T> symbols) where T : Symbol?
        {
            if (symbols.IsDefaultOrEmpty)
            {
                return symbols;
            }
 
            var builder = ArrayBuilder<T>.GetInstance(symbols.Length);
            bool foundUpdate = false;
            foreach (var originalSymbol in symbols)
            {
                T updatedSymbol = null!;
                if (originalSymbol is object)
                {
                    updatedSymbol = GetUpdatedSymbol(expr, originalSymbol);
                    Debug.Assert(updatedSymbol is object);
                    if ((object)originalSymbol != updatedSymbol)
                    {
                        foundUpdate = true;
                    }
                }
 
                builder.Add(updatedSymbol);
            }
 
            if (foundUpdate)
            {
                return builder.ToImmutableAndFree();
            }
            else
            {
                builder.Free();
                return symbols;
            }
        }
    }
}