File: Lowering\AsyncRewriter\RuntimeAsyncRewriter.cs
Web Access
Project: src\src\Compilers\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.csproj (Microsoft.CodeAnalysis.CSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Symbols;
 
namespace Microsoft.CodeAnalysis.CSharp;
 
/// <summary>
/// Rewrites the <c>await</c> expressions of an async method body into calls to the
/// <c>System.Runtime.CompilerServices.AsyncHelpers</c> methods (<c>Await</c>, <c>AwaitAwaiter</c> or
/// <c>UnsafeAwaitAwaiter</c>) that binding recorded on each await's awaitable info.
/// </summary>
internal sealed class RuntimeAsyncRewriter : BoundTreeRewriterWithStackGuard
{
    private readonly SyntheticBoundNodeFactory _factory;

    /// <summary>
    /// Replacement expressions for placeholders that are currently being substituted.
    /// An entry is present only while the bound node mentioning that placeholder is being visited.
    /// </summary>
    private readonly Dictionary<BoundAwaitableValuePlaceholder, BoundExpression> _placeholderMap;

    private RuntimeAsyncRewriter(SyntheticBoundNodeFactory factory)
    {
        _factory = factory;
        _placeholderMap = [];
    }

    /// <summary>
    /// Entry point of the rewriter.
    /// </summary>
    /// <param name="node">Body of <paramref name="method"/>.</param>
    /// <param name="method">Method being lowered; when it is not async, <paramref name="node"/> is returned unchanged.</param>
    /// <param name="compilationState">Compilation state passed through to the capture analysis, the node factory and the spiller.</param>
    /// <param name="diagnostics">Bag that receives <see cref="ErrorCode.ERR_UnsupportedFeatureInRuntimeAsync"/> reports.</param>
    /// <returns>The rewritten body after it has been run through <see cref="SpillSequenceSpiller"/>.</returns>
    public static BoundStatement Rewrite(
        BoundStatement node,
        MethodSymbol method,
        TypeCompilationState compilationState,
        BindingDiagnosticBag diagnostics)
    {
        if (!method.IsAsync)
        {
            return node;
        }

        // https://github.com/dotnet/roslyn/issues/79763: struct lifting
        var hoistedVariables = IteratorAndAsyncCaptureWalker.Analyze(compilationState.Compilation, method, node, isRuntimeAsync: true, diagnostics.DiagnosticBag);

        // Every variable the analysis returns is reported as unsupported; an empty result reports nothing.
        foreach (var hoisted in hoistedVariables)
        {
            // Method '{0}' uses a feature that is not supported by runtime async currently. Opt the method out of runtime async by attributing it with 'System.Runtime.CompilerServices.RuntimeAsyncMethodGenerationAttribute(false)'.
            diagnostics.Add(ErrorCode.ERR_UnsupportedFeatureInRuntimeAsync, hoisted.GetFirstLocation(), method);
        }

        var factory = new SyntheticBoundNodeFactory(method, node.Syntax, compilationState, diagnostics);
        var rewrittenBody = (BoundStatement)new RuntimeAsyncRewriter(factory).Visit(node);
        return SpillSequenceSpiller.Rewrite(rewrittenBody, method, compilationState, diagnostics);
    }

    /// <summary>
    /// Visits <paramref name="node"/> and returns the result typed as an expression.
    /// </summary>
    [return: NotNullIfNotNull(nameof(node))]
    public BoundExpression? VisitExpression(BoundExpression? node)
        => (BoundExpression?)Visit(node);

    /// <summary>
    /// Lowers a single await. A dynamic await is reported as unsupported and returned unchanged;
    /// an await bound to the <c>Await</c> helper becomes a direct call to it; any other await
    /// goes through <see cref="RewriteCustomAwaiterAwait"/>.
    /// </summary>
    public override BoundNode? VisitAwaitExpression(BoundAwaitExpression node)
    {
        Debug.Assert(node.Expression.Type is not null);

        var info = node.AwaitableInfo;

        if (info.IsDynamic)
        {
            // https://github.com/dotnet/roslyn/issues/79762: await dynamic will need runtime checks, see AsyncMethodToStateMachine.GenerateAwaitOnCompletedDynamic
            Debug.Assert(_factory.CurrentFunction is not null);
            // Method '{0}' uses a feature that is not supported by runtime async currently. Opt the method out of runtime async by attributing it with 'System.Runtime.CompilerServices.RuntimeAsyncMethodGenerationAttribute(false)'.
            _factory.Diagnostics.Add(
                ErrorCode.ERR_UnsupportedFeatureInRuntimeAsync,
                node.Syntax.Location,
                _factory.CurrentFunction);
            return node;
        }

        var helperCall = info.RuntimeAsyncAwaitCall;
        Debug.Assert(helperCall is not null);
        var helperArgumentPlaceholder = info.RuntimeAsyncAwaitCallPlaceholder;
        Debug.Assert(helperArgumentPlaceholder is not null);
        var helperMethod = helperCall.Method;
        Debug.Assert(helperMethod is not null);
        Debug.Assert(ReferenceEquals(
            helperMethod.ContainingType.OriginalDefinition,
            _factory.Compilation.GetSpecialType(InternalSpecialType.System_Runtime_CompilerServices_AsyncHelpers)));
        Debug.Assert(helperMethod.Name is "Await" or "UnsafeAwaitAwaiter" or "AwaitAwaiter");

        if (helperMethod.Name != "Await")
        {
            // AwaitAwaiter / UnsafeAwaitAwaiter: the awaiter pattern has to be spelled out.
            return RewriteCustomAwaiterAwait(node);
        }

        // Direct await, no need for the full pattern:
        // System.Runtime.CompilerServices.AsyncHelpers.Await(awaitedExpression)
        var awaitedOperand = VisitExpression(node.Expression);
        return VisitWithSubstitution(helperArgumentPlaceholder, awaitedOperand, helperCall);
    }

    /// <summary>
    /// Expands an await that uses the awaiter pattern into a spill sequence.
    /// </summary>
    private BoundExpression RewriteCustomAwaiterAwait(BoundAwaitExpression node)
    {
        // Shape of the result for `await expr`:
        //
        //   var _tmp = expr.GetAwaiter();
        //   if (!_tmp.IsCompleted)
        //      UnsafeAwaitAwaiter(_tmp) OR AwaitAwaiter(_tmp);
        //   _tmp.GetResult()

        var awaitedOperand = VisitExpression(node.Expression);
        var info = node.AwaitableInfo;

        // expr.GetAwaiter(): the bound GetAwaiter expression refers to `expr` through the
        // awaitable-instance placeholder when there is one, so map it while visiting.
        var instancePlaceholder = info.AwaitableInstancePlaceholder;
        if (instancePlaceholder is not null)
        {
            _placeholderMap.Add(instancePlaceholder, awaitedOperand);
        }

        var getAwaiterCall = VisitExpression(info.GetAwaiter);
        Debug.Assert(getAwaiterCall is not null);

        if (instancePlaceholder is not null)
        {
            _placeholderMap.Remove(instancePlaceholder);
        }

        // var _tmp = expr.GetAwaiter();
        var awaiterTemp = _factory.StoreToTemp(getAwaiterCall, out BoundAssignmentOperator awaiterStore, kind: SynthesizedLocalKind.Awaiter);

        // _tmp.IsCompleted
        Debug.Assert(info.IsCompleted is not null);
        var isCompletedGetter = info.IsCompleted.GetMethod;
        Debug.Assert(isCompletedGetter is not null);
        var isCompleted = _factory.Call(awaiterTemp, isCompletedGetter);

        // UnsafeAwaitAwaiter(_tmp) OR AwaitAwaiter(_tmp)
        var helperCall = info.RuntimeAsyncAwaitCall;
        Debug.Assert(helperCall is not null);
        var helperArgumentPlaceholder = info.RuntimeAsyncAwaitCallPlaceholder;
        Debug.Assert(helperArgumentPlaceholder is not null);
        var suspendCall = (BoundCall)VisitWithSubstitution(helperArgumentPlaceholder, awaiterTemp, helperCall);

        // if (!_tmp.IsCompleted) suspendCall
        var notCompleted = _factory.Not(isCompleted);
        var suspendIfNotCompleted = _factory.If(notCompleted, _factory.ExpressionStatement(suspendCall));

        // _tmp.GetResult()
        var getResultMethod = info.GetResult;
        Debug.Assert(getResultMethod is not null);
        var getResult = _factory.Call(awaiterTemp, getResultMethod);

        // Store, conditional suspend, then the GetResult() call as the value of the sequence.
        return _factory.SpillSequence(
            locals: [awaiterTemp.LocalSymbol],
            sideEffects: [_factory.ExpressionStatement(awaiterStore), suspendIfNotCompleted],
            result: getResult);
    }

    /// <summary>
    /// Visits <paramref name="target"/> while <paramref name="placeholder"/> is mapped to
    /// <paramref name="replacement"/>, then removes the mapping again.
    /// </summary>
    private BoundNode VisitWithSubstitution(
        BoundAwaitableValuePlaceholder placeholder,
        BoundExpression replacement,
        BoundNode target)
    {
        _placeholderMap.Add(placeholder, replacement);
        var visited = Visit(target);
        _placeholderMap.Remove(placeholder);
        return visited;
    }

    /// <summary>
    /// Substitutes a placeholder with the expression currently registered for it.
    /// The dictionary indexer throws when no mapping is registered for <paramref name="node"/>.
    /// </summary>
    public override BoundNode VisitAwaitableValuePlaceholder(BoundAwaitableValuePlaceholder node)
        => _placeholderMap[node];
}