// File: RequiresUnsafeCodeFixProvider.cs
// Web Access
// Project: src\src\tools\illink\src\ILLink.CodeFix\ILLink.CodeFixProvider.csproj (ILLink.CodeFixProvider)
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
 
#if DEBUG
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Composition;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using ILLink.CodeFixProvider;
using ILLink.RoslynAnalyzer;
using ILLink.Shared;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
 
namespace ILLink.CodeFix
{
    [ExportCodeFixProvider(LanguageNames.CSharp, Name = nameof(RequiresUnsafeCodeFixProvider)), Shared]
    public sealed class RequiresUnsafeCodeFixProvider : BaseAttributeCodeFixProvider
    {
        /// <summary>
        /// Title of the "wrap in unsafe block" code action. The same string is passed as the
        /// action's equivalence key wherever that action is registered in this class.
        /// </summary>
        private const string WrapInUnsafeBlockTitle = "Wrap in unsafe block";

        /// <summary>
        /// The descriptor for <see cref="DiagnosticId.RequiresUnsafe"/>, the only diagnostic listed here.
        /// A new array is created on every access.
        /// </summary>
        public static ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics => ImmutableArray.Create(DiagnosticDescriptors.GetDiagnosticDescriptor(DiagnosticId.RequiresUnsafe));

        /// <summary>
        /// The ids of <see cref="SupportedDiagnostics"/>, recomputed on every access.
        /// </summary>
        public sealed override ImmutableArray<string> FixableDiagnosticIds => SupportedDiagnostics.Select(dd => dd.Id).ToImmutableArray();

        /// <summary>
        /// Localizable title read from the <c>RequiresUnsafeCodeFixTitle</c> resource.
        /// NOTE(review): presumably consumed by the base provider for the add-attribute fix - confirm in BaseAttributeCodeFixProvider.
        /// </summary>
        private protected override LocalizableString CodeFixTitle => new LocalizableResourceString(nameof(Resources.RequiresUnsafeCodeFixTitle), Resources.ResourceManager, typeof(Resources));

        /// <summary>
        /// Fully qualified name of the attribute, taken from <see cref="RequiresUnsafeAnalyzer"/>.
        /// </summary>
        private protected override string FullyQualifiedAttributeName => RequiresUnsafeAnalyzer.FullyQualifiedRequiresUnsafeAttribute;

        /// <summary>
        /// Restricts the attributable parent targets to <see cref="AttributeableParentTargets.MethodOrConstructor"/>.
        /// </summary>
        private protected override AttributeableParentTargets AttributableParentTargets => AttributeableParentTargets.MethodOrConstructor;
 
        /// <summary>
        /// Registers the base attribute code fix and, when the diagnostic location allows it,
        /// an additional "wrap in unsafe block" code action.
        /// </summary>
        /// <param name="context">The code fix context; only its first diagnostic is inspected here.</param>
        /// <remarks>
        /// The wrap action is chosen from the syntactic position of the diagnostic:
        /// an expression-bodied local function or member is converted to a block body,
        /// a statement in a block or switch section is wrapped in place, and an embedded
        /// statement (for example an unbraced <c>if</c> body) is wrapped in a new block.
        /// Nothing extra is registered when none of these shapes applies.
        /// </remarks>
        public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context)
        {
            // The attribute-based fix from the base provider is always registered first.
            await BaseRegisterCodeFixesAsync(context).ConfigureAwait(false);

            Document document = context.Document;
            Diagnostic diagnostic = context.Diagnostics.First();

            SyntaxNode? root = await document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false);
            if (root is null)
                return;

            SyntaxNode targetNode = root.FindNode(diagnostic.Location.SourceSpan, getInnermostNodeForTie: true);

            // Closest enclosing statement (the node itself counts).
            var containingStatement = targetNode.AncestorsAndSelf().OfType<StatementSyntax>().FirstOrDefault();

            // An expression-bodied local function is handled like an expression-bodied member.
            if (containingStatement is LocalFunctionStatementSyntax { ExpressionBody: not null } expressionBodiedLocalFunction)
            {
                if (!HasDirectiveTrivia(expressionBodiedLocalFunction.ExpressionBody))
                {
                    context.RegisterCodeFix(
                        CodeAction.Create(
                            title: WrapInUnsafeBlockTitle,
                            createChangedDocument: ct => ConvertExpressionBodyToUnsafeBlockAsync(document, expressionBodiedLocalFunction.ExpressionBody, ct),
                            equivalenceKey: WrapInUnsafeBlockTitle),
                        diagnostic);
                }

                return;
            }

            // No usable statement: the only remaining candidate is an expression-bodied member.
            if (containingStatement is null || containingStatement is BlockSyntax)
            {
                var arrowExpr = targetNode.AncestorsAndSelf().OfType<ArrowExpressionClauseSyntax>().FirstOrDefault();
                if (arrowExpr is null || HasDirectiveTrivia(arrowExpr))
                    return;

                context.RegisterCodeFix(
                    CodeAction.Create(
                        title: WrapInUnsafeBlockTitle,
                        createChangedDocument: ct => ConvertExpressionBodyToUnsafeBlockAsync(document, arrowExpr, ct),
                        equivalenceKey: WrapInUnsafeBlockTitle),
                    diagnostic);
                return;
            }

            // Pick the wrapping strategy from the kind of node that owns the statement.
            switch (containingStatement.Parent)
            {
                case BlockSyntax parentBlock:
                    context.RegisterCodeFix(
                        CodeAction.Create(
                            title: WrapInUnsafeBlockTitle,
                            createChangedDocument: ct => WrapStatementsInUnsafeBlockAsync(document, parentBlock, containingStatement, ct),
                            equivalenceKey: WrapInUnsafeBlockTitle),
                        diagnostic);
                    break;

                case SwitchSectionSyntax switchSection:
                    context.RegisterCodeFix(
                        CodeAction.Create(
                            title: WrapInUnsafeBlockTitle,
                            createChangedDocument: ct => WrapSwitchSectionStatementInUnsafeBlockAsync(document, switchSection, containingStatement, ct),
                            equivalenceKey: WrapInUnsafeBlockTitle),
                        diagnostic);
                    break;

                default:
                    // Unbraced body of if/else/while/for/... needs an extra block around the unsafe statement.
                    if (IsEmbeddedStatement(containingStatement))
                    {
                        context.RegisterCodeFix(
                            CodeAction.Create(
                                title: WrapInUnsafeBlockTitle,
                                createChangedDocument: ct => WrapEmbeddedStatementInUnsafeBlockAsync(document, containingStatement, ct),
                                equivalenceKey: WrapInUnsafeBlockTitle),
                            diagnostic);
                    }

                    break;
            }
        }
 
        /// <summary>
        /// Determines whether <paramref name="statement"/> is the direct child of a control-flow
        /// construct rather than of a block, e.g. <c>if (x) return;</c> instead of <c>if (x) { return; }</c>.
        /// </summary>
        /// <param name="statement">The statement whose parent node is inspected.</param>
        /// <returns>
        /// <see langword="true"/> when the parent is an if, else, while, for, foreach (either form),
        /// do, using, lock or fixed construct; otherwise <see langword="false"/>.
        /// </returns>
        private static bool IsEmbeddedStatement(StatementSyntax statement)
        {
            // CommonForEachStatementSyntax is the shared base of ForEachStatementSyntax and
            // ForEachVariableStatementSyntax, so the deconstructing form
            // "foreach (var (a, b) in items) Use(a);" is recognized as well.
            return statement.Parent is IfStatementSyntax
                or ElseClauseSyntax
                or WhileStatementSyntax
                or ForStatementSyntax
                or CommonForEachStatementSyntax
                or DoStatementSyntax
                or UsingStatementSyntax
                or LockStatementSyntax
                or FixedStatementSyntax;
        }
 
        /// <summary>
        /// Wraps <paramref name="triggerStatement"/>, a direct child of <paramref name="parentBlock"/>,
        /// in an <c>unsafe</c> block preceded by a TODO comment.
        /// </summary>
        /// <param name="document">The document to change.</param>
        /// <param name="parentBlock">The block whose statement list contains the trigger statement.</param>
        /// <param name="triggerStatement">The statement containing the reported code.</param>
        /// <param name="cancellationToken">Token used for the semantic queries and editor creation.</param>
        /// <returns>
        /// The changed document, or <paramref name="document"/> unchanged when the rewrite cannot be
        /// done safely (no semantic model, statement not found in the block, preprocessor directives
        /// in the statement's leading trivia, or a <c>using</c> declaration).
        /// </returns>
        /// <remarks>
        /// Strategies, in order:
        /// <list type="bullet">
        /// <item>A single-variable, non-const, non-ref local with an ordinary initializer is split into a
        /// forward declaration outside the unsafe block and an assignment inside it, so the local stays in scope.</item>
        /// <item>A ref/scoped local cannot be forward-declared; the wrapped range grows until no local declared
        /// inside it is referenced by a later statement of the block.</item>
        /// <item>Anything else is wrapped as-is.</item>
        /// </list>
        /// </remarks>
        private static async Task<Document> WrapStatementsInUnsafeBlockAsync(
            Document document,
            BlockSyntax parentBlock,
            StatementSyntax triggerStatement,
            CancellationToken cancellationToken)
        {
            var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
            if (semanticModel == null)
                return document;

            // Find the trigger statement in the block
            var statements = parentBlock.Statements;
            int triggerIndex = statements.IndexOf(triggerStatement);
            if (triggerIndex < 0)
                return document;

            // Check if any statement has directive trivia that would be lost or mangled
            var leadingTrivia = triggerStatement.GetLeadingTrivia();
            if (leadingTrivia.Any(t => t.IsDirective))
            {
                // Skip the fix - directives in statement trivia would be lost or malformed
                return document;
            }

            var triggerLocalDecl = triggerStatement as LocalDeclarationStatementSyntax;

            // A using declaration ("using var x = ...;" / "await using var x = ...;") can neither be
            // forward-declared (the rebuilt declaration would lose 'using', so the resource would never be
            // disposed) nor wrapped on its own (disposal would move to the end of the unsafe block).
            // Skip the fix rather than silently change when the resource is disposed.
            if (triggerLocalDecl is not null && triggerLocalDecl.UsingKeyword.IsKind(SyntaxKind.UsingKeyword))
                return document;

            var editor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false);

            // Create the TODO comment
            var todoComment = SyntaxFactory.Comment("// TODO(unsafe): Baselining unsafe usage");
            var newLine = SyntaxFactory.ElasticCarriageReturnLineFeed;

            // Check if we can use forward declaration strategy
            // This applies when the trigger is a local declaration with a single variable
            // Ref locals (including ref readonly and scoped ref) cannot be forward-declared
            LocalDeclarationStatementSyntax? forwardDecl = null;
            StatementSyntax statementToWrap = triggerStatement;
            int endIndex = triggerIndex;  // For block expansion when forward decl not possible

            bool isRefOrScopedLocal = triggerLocalDecl is not null &&
                (triggerLocalDecl.Declaration.Type is RefTypeSyntax || triggerLocalDecl.Declaration.Type is ScopedTypeSyntax);

            if (triggerLocalDecl is not null &&
                !triggerLocalDecl.IsConst &&
                !isRefOrScopedLocal &&
                triggerLocalDecl.Declaration.Variables.Count == 1 &&
                triggerLocalDecl.Declaration.Variables[0].Initializer is { } initializer &&
                // "int[] a = { ... };" is only legal in a declaration; "a = { ... };" is not valid C#,
                // so such declarations fall through to being wrapped whole.
                initializer.Value is not InitializerExpressionSyntax)
            {
                var variable = triggerLocalDecl.Declaration.Variables[0];
                TypeSyntax? typeSyntax = triggerLocalDecl.Declaration.Type;

                // If using 'var', resolve to explicit type
                if (typeSyntax.IsVar)
                {
                    var typeInfo = semanticModel.GetTypeInfo(typeSyntax, cancellationToken);

                    // Anonymous types have no name that can be written in source, and error types
                    // are unresolved; neither can be spelled out in a forward declaration.
                    if (typeInfo.Type is { IsAnonymousType: false } inferredType && inferredType is not IErrorTypeSymbol)
                    {
                        // ToMinimalDisplayString takes the declaration position into account, so the name is
                        // qualified when the type's namespace is not imported at this location.
                        typeSyntax = SyntaxFactory.ParseTypeName(inferredType.ToMinimalDisplayString(semanticModel, triggerLocalDecl.SpanStart))
                            .WithTrailingTrivia(SyntaxFactory.Space);
                    }
                    else
                    {
                        // Can't spell the type, fall back to wrapping the whole declaration
                        typeSyntax = null;
                    }
                }

                if (typeSyntax != null)
                {
                    // Create forward declaration: Type varName;
                    var forwardDeclVariable = SyntaxFactory.VariableDeclarator(variable.Identifier);
                    var forwardDeclDeclaration = SyntaxFactory.VariableDeclaration(typeSyntax)
                        .AddVariables(forwardDeclVariable);
                    forwardDecl = SyntaxFactory.LocalDeclarationStatement(forwardDeclDeclaration)
                        .WithLeadingTrivia(triggerStatement.GetLeadingTrivia())
                        .WithTrailingTrivia(SyntaxFactory.TriviaList(SyntaxFactory.ElasticCarriageReturnLineFeed));

                    // Create assignment: varName = initializer;
                    var assignment = SyntaxFactory.AssignmentExpression(
                        SyntaxKind.SimpleAssignmentExpression,
                        SyntaxFactory.IdentifierName(variable.Identifier),
                        initializer.Value);
                    statementToWrap = SyntaxFactory.ExpressionStatement(assignment)
                        .WithTrailingTrivia(SyntaxFactory.TriviaList(SyntaxFactory.ElasticCarriageReturnLineFeed));
                }
            }
            else if (isRefOrScopedLocal)
            {
                // For ref/scoped locals, we must expand the block until no variables
                // declared inside are used outside. This handles chains of dependencies
                // where ref locals lead to regular locals that are also used later.
                var localDeclStmt = (LocalDeclarationStatementSyntax)triggerStatement;
                if (localDeclStmt.Declaration.Variables.Count == 1)
                {
                    // Iteratively expand until no variables declared inside escape outside
                    bool expanded = true;
                    while (expanded && endIndex < statements.Count - 1)
                    {
                        expanded = false;

                        // Collect all variables declared in the current range (using symbols for correct scoping)
                        var declaredVariableSymbols = new HashSet<ISymbol>(SymbolEqualityComparer.Default);
                        for (int i = triggerIndex; i <= endIndex; i++)
                        {
                            if (statements[i] is LocalDeclarationStatementSyntax rangeLocalDecl)
                            {
                                foreach (var variable in rangeLocalDecl.Declaration.Variables)
                                {
                                    var symbol = semanticModel.GetDeclaredSymbol(variable, cancellationToken);
                                    if (symbol is not null)
                                        declaredVariableSymbols.Add(symbol);
                                }
                            }
                        }

                        // Check ALL remaining statements - does any use a declared variable?
                        for (int nextIndex = endIndex + 1; nextIndex < statements.Count; nextIndex++)
                        {
                            var stmt = statements[nextIndex];

                            bool usesAnyDeclaredVariable = stmt.DescendantNodes()
                                .OfType<IdentifierNameSyntax>()
                                .Any(id =>
                                {
                                    var symbolInfo = semanticModel.GetSymbolInfo(id, cancellationToken);
                                    return symbolInfo.Symbol is not null && declaredVariableSymbols.Contains(symbolInfo.Symbol);
                                });

                            if (usesAnyDeclaredVariable)
                            {
                                // Check for directive trivia that would be mangled
                                if (stmt.GetLeadingTrivia().Any(t => t.IsDirective))
                                {
                                    // Stop expansion here - don't include statements with directives
                                    break;
                                }

                                // Expand to include this statement
                                endIndex = nextIndex;
                                expanded = true;
                                break; // Restart the outer loop to recollect variables
                            }
                        }
                    }
                }
            }

            // Build the statements to wrap
            List<StatementSyntax> statementsToWrap;
            if (forwardDecl != null)
            {
                statementsToWrap = new List<StatementSyntax> { statementToWrap };
            }
            else
            {
                statementsToWrap = statements.Skip(triggerIndex).Take(endIndex - triggerIndex + 1).ToList();
            }

            // Create the unsafe block
            var wrappedStatements = statementsToWrap.Select(s =>
                s.WithoutTrivia()
                 .WithTrailingTrivia(SyntaxFactory.TriviaList(SyntaxFactory.ElasticCarriageReturnLineFeed)));
            var unsafeBlock = SyntaxFactory.UnsafeStatement(
                SyntaxFactory.Block(wrappedStatements))
                .WithLeadingTrivia(forwardDecl != null
                    ? SyntaxFactory.TriviaList(todoComment, newLine)
                    : triggerStatement.GetLeadingTrivia().InsertRange(0, new[] { todoComment, newLine }))
                .WithTrailingTrivia(forwardDecl != null
                    ? triggerStatement.GetTrailingTrivia()
                    : statementsToWrap.Last().GetTrailingTrivia());

            // Build the new list of statements
            var newStatements = new List<StatementSyntax>();
            for (int i = 0; i < statements.Count; i++)
            {
                if (i == triggerIndex)
                {
                    if (forwardDecl != null)
                    {
                        newStatements.Add(forwardDecl);
                    }
                    newStatements.Add(unsafeBlock);
                }
                else if (i > triggerIndex && i <= endIndex)
                {
                    // Skip - already included in unsafe block
                }
                else
                {
                    newStatements.Add(statements[i]);
                }
            }

            var newBlock = parentBlock.WithStatements(SyntaxFactory.List(newStatements));
            editor.ReplaceNode(parentBlock, newBlock);

            return editor.GetChangedDocument();
        }
 
        /// <summary>
        /// Replaces a statement that sits directly in a switch section with an <c>unsafe</c> block
        /// containing only that statement, preceded by a TODO comment.
        /// </summary>
        /// <param name="document">The document to change.</param>
        /// <param name="_">The owning switch section; not used by the rewrite.</param>
        /// <param name="triggerStatement">The statement to wrap.</param>
        /// <param name="cancellationToken">Token used when creating the document editor.</param>
        /// <returns>The document with the statement replaced by the unsafe block.</returns>
        private static async Task<Document> WrapSwitchSectionStatementInUnsafeBlockAsync(
            Document document,
            SwitchSectionSyntax _,
            StatementSyntax triggerStatement,
            CancellationToken cancellationToken)
        {
            DocumentEditor documentEditor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false);

            // The TODO comment and a line break go in front of the statement's own leading trivia.
            SyntaxTriviaList leadingWithTodo = triggerStatement.GetLeadingTrivia().InsertRange(0, new[]
            {
                SyntaxFactory.Comment("// TODO(unsafe): Baselining unsafe usage"),
                SyntaxFactory.CarriageReturnLineFeed,
            });

            // The statement moves inside without trivia; its trivia is carried by the unsafe statement.
            BlockSyntax body = SyntaxFactory.Block(triggerStatement.WithoutTrivia());
            UnsafeStatementSyntax replacement = SyntaxFactory.UnsafeStatement(body)
                .WithLeadingTrivia(leadingWithTodo)
                .WithTrailingTrivia(triggerStatement.GetTrailingTrivia());

            documentEditor.ReplaceNode(triggerStatement, replacement);
            return documentEditor.GetChangedDocument();
        }
 
        /// <summary>
        /// Replaces an embedded statement (for example the unbraced body of an <c>if</c>) with a block
        /// that contains a single <c>unsafe</c> statement wrapping the original statement.
        /// </summary>
        /// <param name="document">The document to change.</param>
        /// <param name="embeddedStatement">The embedded statement to wrap.</param>
        /// <param name="cancellationToken">Token used when creating the document editor.</param>
        /// <returns>The document with the embedded statement replaced by <c>{ unsafe { ... } }</c>.</returns>
        private static async Task<Document> WrapEmbeddedStatementInUnsafeBlockAsync(
            Document document,
            StatementSyntax embeddedStatement,
            CancellationToken cancellationToken)
        {
            DocumentEditor documentEditor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false);

            // Inner node: "// TODO ..." + line break + unsafe { <statement without trivia> }
            BlockSyntax unsafeBody = SyntaxFactory.Block(embeddedStatement.WithoutTrivia());
            UnsafeStatementSyntax unsafeStatement = SyntaxFactory.UnsafeStatement(unsafeBody)
                .WithLeadingTrivia(
                    SyntaxFactory.Comment("// TODO(unsafe): Baselining unsafe usage"),
                    SyntaxFactory.CarriageReturnLineFeed);

            // Outer braces take over the trivia that surrounded the original embedded statement.
            BlockSyntax replacement = SyntaxFactory.Block(unsafeStatement)
                .WithLeadingTrivia(embeddedStatement.GetLeadingTrivia())
                .WithTrailingTrivia(embeddedStatement.GetTrailingTrivia());

            documentEditor.ReplaceNode(embeddedStatement, replacement);
            return documentEditor.GetChangedDocument();
        }
 
        /// <summary>
        /// Converts an expression-bodied member or local function into a block body whose only
        /// statement is an <c>unsafe</c> block (preceded by a TODO comment) containing the expression.
        /// </summary>
        /// <param name="document">The document to change.</param>
        /// <param name="arrowExpr">The <c>=&gt; expression</c> clause to convert.</param>
        /// <param name="cancellationToken">Token used for editor creation and symbol lookup.</param>
        /// <returns>
        /// The changed document. The document is returned without changes when the clause's parent is
        /// not one of the handled declaration kinds.
        /// </returns>
        /// <remarks>
        /// The expression becomes an expression statement when the owner yields no value (void methods
        /// and local functions, constructors, destructors, non-<c>get</c> accessors, and async members
        /// whose return type is a non-generic type such as <c>Task</c>); otherwise it becomes a
        /// <c>return</c> statement.
        /// </remarks>
        private static async Task<Document> ConvertExpressionBodyToUnsafeBlockAsync(
            Document document,
            ArrowExpressionClauseSyntax arrowExpr,
            CancellationToken cancellationToken)
        {
            var editor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false);

            // Create the TODO comment
            var todoComment = SyntaxFactory.Comment("// TODO(unsafe): Baselining unsafe usage");
            var newLine = SyntaxFactory.CarriageReturnLineFeed;

            static bool IsVoidKeyword(TypeSyntax returnType) =>
                returnType is PredefinedTypeSyntax predefined && predefined.Keyword.IsKind(SyntaxKind.VoidKeyword);

            // "async Task M() => await X();" must become "await X();" - returning a value from an async
            // method whose return type carries no result does not compile. A non-generic return type on
            // an async method (Task, ValueTask, custom task-like) means there is no result.
            bool IsAsyncWithoutResult(SyntaxNode declaration) =>
                editor.SemanticModel.GetDeclaredSymbol(declaration, cancellationToken)
                    is IMethodSymbol { IsAsync: true, ReturnType: INamedTypeSymbol { Arity: 0 } };

            // Decide whether the expression is a statement or a returned value.
            var parent = arrowExpr.Parent;
            bool isVoid = parent switch
            {
                MethodDeclarationSyntax methodOwner => IsVoidKeyword(methodOwner.ReturnType) || IsAsyncWithoutResult(methodOwner),
                LocalFunctionStatementSyntax localFunctionOwner => IsVoidKeyword(localFunctionOwner.ReturnType) || IsAsyncWithoutResult(localFunctionOwner),
                ConstructorDeclarationSyntax or DestructorDeclarationSyntax => true,
                // Only 'get' produces a value; set/init/add/remove accessors are void.
                AccessorDeclarationSyntax accessorOwner => !accessorOwner.IsKind(SyntaxKind.GetAccessorDeclaration),
                // Properties and indexers with "=> expr" are getters; operators always return a value.
                _ => false,
            };

            // Create the statement that goes inside the unsafe block
            StatementSyntax innerStatement = isVoid
                ? SyntaxFactory.ExpressionStatement(arrowExpr.Expression.WithoutTrivia())
                : SyntaxFactory.ReturnStatement(arrowExpr.Expression.WithoutTrivia());

            // Create the unsafe block
            var unsafeBlock = SyntaxFactory.UnsafeStatement(
                SyntaxFactory.Block(innerStatement))
                .WithLeadingTrivia(todoComment, newLine);

            // Create the block body
            var blockBody = SyntaxFactory.Block(unsafeBlock);

            // Replace based on parent type
            switch (parent)
            {
                case MethodDeclarationSyntax methodDecl:
                    var newMethod = methodDecl
                        .WithExpressionBody(null)
                        .WithSemicolonToken(default)
                        .WithBody(blockBody);
                    editor.ReplaceNode(methodDecl, newMethod);
                    break;

                case LocalFunctionStatementSyntax localFunc:
                    var newLocalFunc = localFunc
                        .WithExpressionBody(null)
                        .WithSemicolonToken(default)
                        .WithBody(blockBody);
                    editor.ReplaceNode(localFunc, newLocalFunc);
                    break;

                case ConstructorDeclarationSyntax constructorDecl:
                    var newConstructor = constructorDecl
                        .WithExpressionBody(null)
                        .WithSemicolonToken(default)
                        .WithBody(blockBody);
                    editor.ReplaceNode(constructorDecl, newConstructor);
                    break;

                case DestructorDeclarationSyntax destructorDecl:
                    var newDestructor = destructorDecl
                        .WithExpressionBody(null)
                        .WithSemicolonToken(default)
                        .WithBody(blockBody);
                    editor.ReplaceNode(destructorDecl, newDestructor);
                    break;

                case OperatorDeclarationSyntax operatorDecl:
                    var newOperator = operatorDecl
                        .WithExpressionBody(null)
                        .WithSemicolonToken(default)
                        .WithBody(blockBody);
                    editor.ReplaceNode(operatorDecl, newOperator);
                    break;

                case ConversionOperatorDeclarationSyntax conversionDecl:
                    var newConversion = conversionDecl
                        .WithExpressionBody(null)
                        .WithSemicolonToken(default)
                        .WithBody(blockBody);
                    editor.ReplaceNode(conversionDecl, newConversion);
                    break;

                case PropertyDeclarationSyntax propDecl:
                    // "T P => expr;" becomes "T P { get { unsafe { return expr; } } }"
                    var getter = SyntaxFactory.AccessorDeclaration(SyntaxKind.GetAccessorDeclaration)
                        .WithBody(SyntaxFactory.Block(unsafeBlock));
                    var accessorList = SyntaxFactory.AccessorList(SyntaxFactory.SingletonList(getter));
                    var newProp = propDecl
                        .WithExpressionBody(null)
                        .WithSemicolonToken(default)
                        .WithAccessorList(accessorList);
                    editor.ReplaceNode(propDecl, newProp);
                    break;

                case AccessorDeclarationSyntax accessor:
                    var newAccessor = accessor
                        .WithExpressionBody(null)
                        .WithSemicolonToken(default)
                        .WithBody(SyntaxFactory.Block(unsafeBlock));
                    editor.ReplaceNode(accessor, newAccessor);
                    break;

                case IndexerDeclarationSyntax indexerDecl:
                    var indexerGetter = SyntaxFactory.AccessorDeclaration(SyntaxKind.GetAccessorDeclaration)
                        .WithBody(SyntaxFactory.Block(unsafeBlock));
                    var indexerAccessorList = SyntaxFactory.AccessorList(SyntaxFactory.SingletonList(indexerGetter));
                    var newIndexer = indexerDecl
                        .WithExpressionBody(null)
                        .WithSemicolonToken(default)
                        .WithAccessorList(indexerAccessorList);
                    editor.ReplaceNode(indexerDecl, newIndexer);
                    break;
            }

            return editor.GetChangedDocument();
        }
 
        /// <summary>
        /// Builds the argument nodes for the attribute added by the base code fix.
        /// </summary>
        /// <param name="attributableSymbol">Symbol that will receive the attribute; only its public accessibility is consulted.</param>
        /// <param name="targetSymbol">Symbol forwarded to the shared Requires helper.</param>
        /// <param name="syntaxGenerator">Generator forwarded to the shared Requires helper.</param>
        /// <param name="diagnostic">Not used by this override.</param>
        /// <returns>The nodes produced by <c>RequiresHelpers.GetAttributeArgumentsForRequires</c>.</returns>
        protected override SyntaxNode[] GetAttributeArguments(ISymbol? attributableSymbol, ISymbol targetSymbol, SyntaxGenerator syntaxGenerator, Diagnostic diagnostic)
        {
            bool hasPublicAccessibility = HasPublicAccessibility(attributableSymbol);
            return RequiresHelpers.GetAttributeArgumentsForRequires(targetSymbol, syntaxGenerator, hasPublicAccessibility);
        }
 
        /// <summary>
        /// Reports whether preprocessor directive trivia is attached to the arrow expression clause,
        /// its arrow token, its expression, or the part of the parent declaration just before the arrow.
        /// Converting such expression bodies to block bodies is error-prone, so callers skip the fix.
        /// </summary>
        /// <param name="arrowExpr">The expression-body clause to inspect.</param>
        /// <returns><see langword="true"/> when any inspected trivia list contains a directive.</returns>
        private static bool HasDirectiveTrivia(ArrowExpressionClauseSyntax arrowExpr)
        {
            static bool ContainsDirective(SyntaxTriviaList triviaList)
            {
                foreach (SyntaxTrivia trivia in triviaList)
                {
                    if (trivia.IsDirective)
                        return true;
                }

                return false;
            }

            // Trivia around the clause itself: before "=>", right after "=>",
            // and on either side of the expression.
            bool clauseHasDirective =
                ContainsDirective(arrowExpr.GetLeadingTrivia())
                || ContainsDirective(arrowExpr.ArrowToken.TrailingTrivia)
                || ContainsDirective(arrowExpr.Expression.GetLeadingTrivia())
                || ContainsDirective(arrowExpr.Expression.GetTrailingTrivia())
                || ContainsDirective(arrowExpr.GetTrailingTrivia());
            if (clauseHasDirective)
                return true;

            // Directives between the signature and the arrow,
            // e.g., method() #if FOO => expr1; #else => expr2; #endif
            switch (arrowExpr.Parent)
            {
                case MethodDeclarationSyntax methodDeclaration:
                    if (ContainsDirective(methodDeclaration.ParameterList.GetTrailingTrivia()))
                        return true;

                    foreach (var constraintClause in methodDeclaration.ConstraintClauses)
                    {
                        if (ContainsDirective(constraintClause.GetTrailingTrivia()))
                            return true;
                    }

                    return false;

                case LocalFunctionStatementSyntax localFunction:
                    return ContainsDirective(localFunction.ParameterList.GetTrailingTrivia());

                case PropertyDeclarationSyntax property:
                    return ContainsDirective(property.Identifier.TrailingTrivia);

                default:
                    return false;
            }
        }
    }
}
#endif