File: Lowering\AsyncRewriter\RuntimeAsyncRewriter.cs
Web Access
Project: src\src\Compilers\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.csproj (Microsoft.CodeAnalysis.CSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis.Collections;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.PooledObjects;
 
namespace Microsoft.CodeAnalysis.CSharp;
 
internal sealed class RuntimeAsyncRewriter : BoundTreeRewriterWithStackGuard
{
    public static BoundStatement Rewrite(
        BoundStatement node,
        MethodSymbol method,
        TypeCompilationState compilationState,
        BindingDiagnosticBag diagnostics)
    {
        if (!method.IsAsync)
        {
            return node;
        }
 
        var variablesToHoist = IteratorAndAsyncCaptureWalker.Analyze(compilationState.Compilation, method, node, isRuntimeAsync: true, diagnostics.DiagnosticBag);
        var hoistedLocals = ArrayBuilder<LocalSymbol>.GetInstance();
        var factory = new SyntheticBoundNodeFactory(method, node.Syntax, compilationState, diagnostics);
        var rewriter = new RuntimeAsyncRewriter(factory, variablesToHoist, hoistedLocals);
        var thisStore = hoistThisIfNeeded(rewriter);
        var result = (BoundStatement)rewriter.Visit(node);
 
        if (thisStore is not null)
        {
            result = factory.Block(hoistedLocals.ToImmutableAndFree(),
                factory.HiddenSequencePoint(),
                factory.ExpressionStatement(thisStore),
                result);
        }
        else if (hoistedLocals.Count > 0)
        {
            result = factory.Block(hoistedLocals.ToImmutableAndFree(), result);
        }
        else
        {
            hoistedLocals.Free();
        }
 
        return SpillSequenceSpiller.Rewrite(result, method, compilationState, diagnostics);
 
        static BoundAssignmentOperator? hoistThisIfNeeded(RuntimeAsyncRewriter rewriter)
        {
            Debug.Assert(rewriter._factory.CurrentFunction is not null);
            var thisParameter = rewriter._factory.CurrentFunction.ThisParameter;
            if (thisParameter is { Type.IsValueType: true, RefKind: not RefKind.None })
            {
                // This is a struct or a type parameter. We need to replace it with a hoisted local to preserve behavior from
                // compiler-generated state machines; `this` is a ref, but results are not observable outside of the method.
                // We do this regardless of whether `this` is captured to a ref local, because any usage of `ldarg.0` in these
                // scenarios is illegal after the first await. We could be more precise and only do this if `this` is actually
                // used after the first await, but at the moment we don't feel that is worth the complexity.
                var hoistedThis = rewriter._factory.StoreToTemp(rewriter._factory.This(), out BoundAssignmentOperator store, kind: SynthesizedLocalKind.AwaitByRefSpill);
                rewriter._hoistedLocals.Add(hoistedThis.LocalSymbol);
                rewriter._proxies.Add(thisParameter, new CapturedToExpressionSymbolReplacement<ParameterSymbol>(hoistedThis, hoistedSymbols: [], isReusable: true));
                return store;
            }
 
            return null;
        }
    }
 
    private readonly SyntheticBoundNodeFactory _factory;
    private readonly Dictionary<BoundAwaitableValuePlaceholder, BoundExpression> _placeholderMap;
    private readonly IReadOnlySet<Symbol> _variablesToHoist;
    private readonly RefInitializationHoister<LocalSymbol, BoundLocal> _refInitializationHoister;
    private readonly ArrayBuilder<LocalSymbol> _hoistedLocals;
    private readonly Dictionary<Symbol, CapturedSymbolReplacement> _proxies = [];
 
    private RuntimeAsyncRewriter(SyntheticBoundNodeFactory factory, IReadOnlySet<Symbol> variablesToHoist, ArrayBuilder<LocalSymbol> hoistedLocals)
    {
        Debug.Assert(factory.CurrentFunction != null);
        _factory = factory;
        _placeholderMap = [];
        _variablesToHoist = variablesToHoist;
        _refInitializationHoister = new RefInitializationHoister<LocalSymbol, BoundLocal>(_factory, _factory.CurrentFunction, TypeMap.Empty);
        _hoistedLocals = hoistedLocals;
    }
 
    [return: NotNullIfNotNull(nameof(node))]
    public override BoundNode? Visit(BoundNode? node)
    {
        if (node == null) return node;
        var oldSyntax = _factory.Syntax;
        _factory.Syntax = node.Syntax;
        var result = base.Visit(node);
        _factory.Syntax = oldSyntax;
        return result;
    }
 
    [return: NotNullIfNotNull(nameof(node))]
    public BoundExpression? VisitExpression(BoundExpression? node)
    {
        var result = Visit(node);
        return (BoundExpression?)result;
    }
 
    public override BoundNode? VisitAwaitExpression(BoundAwaitExpression node)
    {
        var nodeType = node.Expression.Type;
        Debug.Assert(nodeType is not null);
 
        var awaitableInfo = node.AwaitableInfo;
 
        if (awaitableInfo.IsDynamic)
        {
            // https://github.com/dotnet/roslyn/issues/79762: await dynamic will need runtime checks, see AsyncMethodToStateMachine.GenerateAwaitOnCompletedDynamic
            Debug.Assert(_factory.CurrentFunction is not null);
            // Method '{0}' uses a feature that is not supported by runtime async currently. Opt the method out of runtime async by attributing it with 'System.Runtime.CompilerServices.RuntimeAsyncMethodGenerationAttribute(false)'.
            _factory.Diagnostics.Add(ErrorCode.ERR_UnsupportedFeatureInRuntimeAsync,
                node.Syntax.Location,
                _factory.CurrentFunction);
            return node;
        }
 
        var runtimeAsyncAwaitCall = awaitableInfo.RuntimeAsyncAwaitCall;
        Debug.Assert(runtimeAsyncAwaitCall is not null);
        Debug.Assert(awaitableInfo.RuntimeAsyncAwaitCallPlaceholder is not null);
        var runtimeAsyncAwaitMethod = runtimeAsyncAwaitCall.Method;
        Debug.Assert(runtimeAsyncAwaitMethod is not null);
        Debug.Assert(ReferenceEquals(
            runtimeAsyncAwaitMethod.ContainingType.OriginalDefinition,
            _factory.Compilation.GetSpecialType(InternalSpecialType.System_Runtime_CompilerServices_AsyncHelpers)));
        Debug.Assert(runtimeAsyncAwaitMethod.Name is "Await" or "UnsafeAwaitAwaiter" or "AwaitAwaiter");
 
        if (runtimeAsyncAwaitMethod.Name == "Await")
        {
            // This is the direct await case, with no need for the full pattern.
            // System.Runtime.CompilerServices.RuntimeHelpers.Await(awaitedExpression)
            var expr = VisitExpression(node.Expression);
            _placeholderMap.Add(awaitableInfo.RuntimeAsyncAwaitCallPlaceholder, expr);
            var call = Visit(awaitableInfo.RuntimeAsyncAwaitCall);
            _placeholderMap.Remove(awaitableInfo.RuntimeAsyncAwaitCallPlaceholder);
            return call;
        }
        else
        {
            return RewriteCustomAwaiterAwait(node);
        }
    }
 
    private BoundExpression RewriteCustomAwaiterAwait(BoundAwaitExpression node)
    {
        // await expr
        // becomes
        // var _tmp = expr.GetAwaiter();
        // if (!_tmp.IsCompleted)
        //    UnsafeAwaitAwaiter(_tmp) OR AwaitAwaiter(_tmp);
        // _tmp.GetResult()
 
        var expr = VisitExpression(node.Expression);
 
        var awaitableInfo = node.AwaitableInfo;
        var awaitablePlaceholder = awaitableInfo.AwaitableInstancePlaceholder;
        if (awaitablePlaceholder is not null)
        {
            _placeholderMap.Add(awaitablePlaceholder, expr);
        }
 
        // expr.GetAwaiter()
        var getAwaiter = VisitExpression(awaitableInfo.GetAwaiter);
        Debug.Assert(getAwaiter is not null);
 
        if (awaitablePlaceholder is not null)
        {
            _placeholderMap.Remove(awaitablePlaceholder);
        }
 
        // var _tmp = expr.GetAwaiter();
        var tmp = _factory.StoreToTemp(getAwaiter, out BoundAssignmentOperator store, kind: SynthesizedLocalKind.Awaiter);
 
        // _tmp.IsCompleted
        Debug.Assert(awaitableInfo.IsCompleted is not null);
        var isCompletedMethod = awaitableInfo.IsCompleted.GetMethod;
        Debug.Assert(isCompletedMethod is not null);
        var isCompletedCall = _factory.Call(tmp, isCompletedMethod);
 
        // UnsafeAwaitAwaiter(_tmp) OR AwaitAwaiter(_tmp)
        Debug.Assert(awaitableInfo.RuntimeAsyncAwaitCall is not null);
        Debug.Assert(awaitableInfo.RuntimeAsyncAwaitCallPlaceholder is not null);
        _placeholderMap.Add(awaitableInfo.RuntimeAsyncAwaitCallPlaceholder, tmp);
        var awaitCall = (BoundCall)Visit(awaitableInfo.RuntimeAsyncAwaitCall);
        _placeholderMap.Remove(awaitableInfo.RuntimeAsyncAwaitCallPlaceholder);
 
        // if (!_tmp.IsCompleted) awaitCall
        var ifNotCompleted = _factory.If(_factory.Not(isCompletedCall), _factory.ExpressionStatement(awaitCall));
 
        // _tmp.GetResult()
        var getResultMethod = awaitableInfo.GetResult;
        Debug.Assert(getResultMethod is not null);
        var getResultCall = _factory.Call(tmp, getResultMethod);
 
        // final sequence
        return _factory.SpillSequence(
            locals: [tmp.LocalSymbol],
            sideEffects: [_factory.ExpressionStatement(store), ifNotCompleted],
            result: getResultCall);
    }
 
    public override BoundNode VisitAwaitableValuePlaceholder(BoundAwaitableValuePlaceholder node)
    {
        return _placeholderMap[node];
    }
 
    public override BoundNode? VisitAssignmentOperator(BoundAssignmentOperator node)
    {
        if (node.Left is not BoundLocal leftLocal)
        {
            return base.VisitAssignmentOperator(node);
        }
 
        BoundExpression visitedRight;
 
        if (_variablesToHoist.Contains(leftLocal.LocalSymbol) && !_proxies.ContainsKey(leftLocal.LocalSymbol))
        {
            Debug.Assert(leftLocal.LocalSymbol.SynthesizedKind == SynthesizedLocalKind.Spill ||
                         (leftLocal.LocalSymbol.SynthesizedKind == SynthesizedLocalKind.ForEachArray && leftLocal.LocalSymbol.Type.HasInlineArrayAttribute(out _) && leftLocal.LocalSymbol.Type.TryGetInlineArrayElementField() is object));
            Debug.Assert(node.IsRef);
            visitedRight = VisitExpression(node.Right);
            return _refInitializationHoister.HoistRefInitialization(
                leftLocal.LocalSymbol,
                visitedRight,
                _proxies,
                createHoistedLocal,
                createHoistedAccess,
                this,
                isRuntimeAsync: true);
        }
 
        var visitedLeftOrProxy = VisitExpression(leftLocal);
        visitedRight = VisitExpression(node.Right);
 
        if (visitedLeftOrProxy is not BoundLocal visitLeftLocal)
        {
            // Proxy replacement occurred. We need to reassign the proxy into our local as a sequence.
            // ref leftLocal = ref proxy;
            // leftLocal = visitedRight;
            var assignment = _factory.AssignmentExpression(leftLocal, visitedLeftOrProxy, isRef: true);
            return _factory.Sequence([assignment], node.Update(leftLocal, visitedRight, node.IsRef, node.Type));
        }
 
        return node.Update(visitedLeftOrProxy, visitedRight, node.IsRef, node.Type);
 
        static LocalSymbol createHoistedLocal(TypeSymbol type, RuntimeAsyncRewriter @this, LocalSymbol local)
        {
            var hoistedLocal = @this._factory.SynthesizedLocal(type, syntax: local.GetDeclaratorSyntax(), kind: SynthesizedLocalKind.AwaitByRefSpill);
            @this._hoistedLocals.Add(hoistedLocal);
            return hoistedLocal;
        }
 
        static BoundLocal createHoistedAccess(LocalSymbol local, RuntimeAsyncRewriter @this)
            => @this._factory.Local(local);
    }
 
    private bool TryReplaceWithProxy(Symbol localOrParameter, SyntaxNode syntax, [NotNullWhen(true)] out BoundNode? replacement)
    {
        if (_proxies.TryGetValue(localOrParameter, out CapturedSymbolReplacement? proxy))
        {
            replacement = proxy.Replacement(syntax, makeFrame: null, this);
            return true;
        }
 
        replacement = null;
        return false;
    }
 
    public override BoundNode VisitLocal(BoundLocal node)
    {
        if (TryReplaceWithProxy(node.LocalSymbol, node.Syntax, out BoundNode? replacement))
        {
            return replacement;
        }
 
        Debug.Assert(!_variablesToHoist.Contains(node.LocalSymbol));
        return base.VisitLocal(node)!;
    }
 
    public override BoundNode? VisitParameter(BoundParameter node)
    {
        if (TryReplaceWithProxy(node.ParameterSymbol, node.Syntax, out BoundNode? replacement))
        {
            // Currently, the only parameter we expect to be replaced is `this`, which is handled through VisitThisReference.
            // Any other ref to a parameter should have either already been hoisted to a local during local rewriting, or should
            // be an illegal ref to a parameter across an await.
            throw ExceptionUtilities.Unreachable();
        }
 
        Debug.Assert(!_variablesToHoist.Contains(node.ParameterSymbol));
        return base.VisitParameter(node);
    }
 
    public override BoundNode? VisitThisReference(BoundThisReference node)
    {
        Debug.Assert(_factory.CurrentFunction is not null);
        var thisParameter = this._factory.CurrentFunction.ThisParameter;
        if (TryReplaceWithProxy(thisParameter, node.Syntax, out BoundNode? replacement))
        {
            return replacement;
        }
 
        Debug.Assert(thisParameter is not { Type.IsValueType: true, RefKind: RefKind.Ref });
        return base.VisitThisReference(node);
    }
 
    public override BoundNode? VisitExpressionStatement(BoundExpressionStatement node)
    {
        var expr = VisitExpression(node.Expression);
        if (expr is null)
        {
            // Happens when the node is a hoisted expression that has no side effects.
            // The generated proxy will have the original content from this node and we can drop it.
            return _factory.StatementList();
        }
 
        return node.Update(expr);
    }
}