File: Microsoft.NetCore.Analyzers\InteropServices\CSharpDisableRuntimeMarshalling.Fixer.cs
Web Access
Project: ..\..\..\src\Microsoft.CodeAnalysis.NetAnalyzers\src\Microsoft.CodeAnalysis.CSharp.NetAnalyzers\Microsoft.CodeAnalysis.CSharp.NetAnalyzers.csproj (Microsoft.CodeAnalysis.CSharp.NetAnalyzers)
// Copyright (c) Microsoft.  All Rights Reserved.  Licensed under the MIT license.  See License.txt in the project root for license information.
 
using System.Collections.Immutable;
using System.Composition;
using System.Threading;
using System.Threading.Tasks;
using Analyzer.Utilities;
using Analyzer.Utilities.Extensions;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Operations;
 
namespace Microsoft.NetCore.Analyzers.InteropServices
{
    [ExportCodeFixProvider(LanguageNames.CSharp), Shared]
    public sealed partial class CSharpDisableRuntimeMarshallingFixer : CodeFixProvider
    {
        /// <summary>
        /// The single diagnostic ID this provider can fix. The array is created once by the property
        /// initializer and the same instance is returned on every read.
        /// </summary>
        public override ImmutableArray<string> FixableDiagnosticIds { get; } = ImmutableArray.Create(DisableRuntimeMarshallingAnalyzer.MethodUsesRuntimeMarshallingEvenWhenMarshallingDisabledId);
 
        /// <summary>
        /// Supplies the provider used for "fix all" requests.
        /// </summary>
        /// <returns>The shared <c>CustomFixAllProvider.Instance</c>.</returns>
        public sealed override FixAllProvider GetFixAllProvider() => CustomFixAllProvider.Instance;
 
        /// <summary>
        /// Registers the "use disabled marshalling equivalent" code action for each diagnostic in
        /// <paramref name="context"/> that carries a non-null
        /// <c>CanConvertToDisabledMarshallingEquivalentKey</c> property.
        /// </summary>
        /// <param name="context">The code fix context supplying the document, the diagnostic span and the diagnostics.</param>
        /// <returns>A task that completes once all applicable code actions have been registered.</returns>
        public override async Task RegisterCodeFixesAsync(CodeFixContext context)
        {
            if (context.Document.Project.CompilationOptions is CSharpCompilationOptions { AllowUnsafe: false })
            {
                // We can't code fix if unsafe code isn't allowed.
                return;
            }

            SyntaxNode root = await context.Document.GetRequiredSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false);

            // When several nodes share the requested span (e.g. an invocation passed directly as an argument
            // and the ArgumentSyntax wrapping it), FindNode returns the outermost one by default. The rewrite
            // expects the invocation itself, so ask for the innermost node instead.
            SyntaxNode enclosingNode = root.FindNode(context.Span, getInnermostNodeForTie: true);
            Document document = context.Document;

            foreach (var diagnostic in context.Diagnostics)
            {
                // Use TryGetValue rather than the indexer: the indexer throws when the key is absent.
                if (!diagnostic.Properties.TryGetValue(DisableRuntimeMarshallingAnalyzer.CanConvertToDisabledMarshallingEquivalentKey, out var canConvert)
                    || canConvert is null)
                {
                    continue;
                }

                context.RegisterCodeFix(
                    CodeAction.Create(
                        MicrosoftNetCoreAnalyzersResources.UseDisabledMarshallingEquivalentCodeFix,
                        // Flow the token handed to the code action, not the registration-time token,
                        // so that cancelling the fix itself is honoured.
                        async ct => await UseDisabledMarshallingEquivalentAsync(enclosingNode, document, ct).ConfigureAwait(false),
                        equivalenceKey: nameof(MicrosoftNetCoreAnalyzersResources.UseDisabledMarshallingEquivalentCodeFix)),
                    diagnostic);
            }
        }
 
        /// <summary>
        /// Finds the smallest numeric suffix that turns <paramref name="baseName"/> into an identifier which
        /// does not bind to any symbol when speculatively bound as an expression at <paramref name="docOffset"/>.
        /// </summary>
        /// <param name="model">Semantic model used for speculative binding.</param>
        /// <param name="docOffset">Position in the document at which the candidate names are bound.</param>
        /// <param name="baseName">The preferred identifier, tried first without any suffix.</param>
        /// <returns>
        /// <c>0</c> when <paramref name="baseName"/> itself is unused, a positive <c>i</c> when
        /// <c>baseName + i</c> is the first unused candidate, or <see langword="null"/> when every
        /// candidate is already bound.
        /// </returns>
        private static int? FindFirstUnusedIdentifierIndex(SemanticModel model, int docOffset, string baseName)
        {
            if (IsUnbound(baseName))
            {
                return 0;
            }

            for (int i = 1; i < int.MaxValue; i++)
            {
                if (IsUnbound($"{baseName}{i}"))
                {
                    return i;
                }
            }

            // Every candidate is taken. Returning 0 here would wrongly report the bare base name as free,
            // even though the first check above established that it is in use.
            return null;

            // True when the name does not resolve to a symbol at the requested position.
            bool IsUnbound(string candidateName)
            {
                return model.GetSpeculativeSymbolInfo(docOffset, SyntaxFactory.IdentifierName(candidateName), SpeculativeBindingOption.BindAsExpression).Symbol is null;
            }
        }
 
        /// <summary>
        /// Applies the rewrite to the call at <paramref name="node"/> and, when the call was rewritten,
        /// marks the enclosing method as <c>unsafe</c>.
        /// </summary>
        /// <param name="node">The node located for the diagnostic being fixed.</param>
        /// <param name="document">The document to edit.</param>
        /// <param name="ct">Cancellation token.</param>
        /// <returns>The document produced by the editor after any changes were recorded.</returns>
        private static async Task<Document> UseDisabledMarshallingEquivalentAsync(SyntaxNode node, Document document, CancellationToken ct)
        {
            DocumentEditor documentEditor = await DocumentEditor.CreateAsync(document, ct).ConfigureAwait(false);

            // Generated pointer identifiers are resolved relative to the start of the node being fixed.
            var pointerNames = new IdentifierGenerator(documentEditor.SemanticModel, node.SpanStart);

            bool callWasRewritten = TryRewriteMethodCall(
                node,
                documentEditor,
                pointerNames,
                addRenameAnnotation: true,
                ct);

            if (callWasRewritten)
            {
                AddUnsafeModifierToEnclosingMethod(documentEditor, node);
            }

            return documentEditor.GetChangedDocument();
        }
 
        /// <summary>
        /// Rewrites a call to <c>SizeOf</c>, <c>StructureToPtr</c> or <c>PtrToStructure</c> into an equivalent
        /// expression built from <c>sizeof</c> or a pointer dereference, provided the type involved is unmanaged.
        /// </summary>
        /// <param name="node">The node located for the diagnostic: the invocation itself, or the argument that directly wraps it.</param>
        /// <param name="editor">The editor that records the replacement.</param>
        /// <param name="pointerIdentifierGenerator">Source of names for the pattern variable introduced when rewriting <c>PtrToStructure</c> for a nullable value type.</param>
        /// <param name="addRenameAnnotation">Whether the introduced identifier is tagged with a rename annotation.</param>
        /// <param name="ct">Cancellation token.</param>
        /// <returns>
        /// <see langword="true"/> when the invocation was replaced; <see langword="false"/> when the call
        /// is not one of the supported shapes and nothing was changed.
        /// </returns>
        private static bool TryRewriteMethodCall(SyntaxNode node, DocumentEditor editor, IdentifierGenerator pointerIdentifierGenerator, bool addRenameAnnotation, CancellationToken ct)
        {
            // A node located by span can be the ArgumentSyntax that wraps the invocation, because both
            // share the same span. Unwrap it so the operation lookup sees the invocation.
            SyntaxNode invocationNode = node is ArgumentSyntax argumentSyntax ? argumentSyntax.Expression : node;

            if (editor.SemanticModel.GetOperation(invocationNode, ct) is not IInvocationOperation { Syntax: InvocationExpressionSyntax syntax } operation)
            {
                // Not an invocation: leave the document untouched instead of failing on a cast.
                return false;
            }

            IMethodSymbol targetMethod = operation.TargetMethod;

            if (targetMethod.Name == "SizeOf")
            {
                if (targetMethod.IsGenericMethod)
                {
                    if (targetMethod.TypeArguments[0].IsUnmanagedType)
                    {
                        editor.ReplaceNode(syntax, SyntaxFactory.SizeOfExpression((TypeSyntax)editor.Generator.TypeExpression(targetMethod.TypeArguments[0])));
                        return true;
                    }
                }
                else if (GetArgumentValue(operation, 0) is ITypeOfOperation { TypeOperand.IsUnmanagedType: true } sizeOfTypeOf)
                {
                    // Reuse the type syntax the user wrote inside typeof(...).
                    editor.ReplaceNode(syntax, SyntaxFactory.SizeOfExpression(((TypeOfExpressionSyntax)sizeOfTypeOf.Syntax).Type));
                    return true;
                }
            }

            if (targetMethod.Name == "StructureToPtr"
                && GetArgumentValue(operation, 0) is { Type: { IsUnmanagedType: true } structureType } structure
                && GetArgumentValue(operation, 1) is { } destination)
            {
                // StructureToPtr(value, ptr, ...) becomes `*(T*)ptr = value`.
                editor.ReplaceNode(syntax,
                    editor.Generator.AssignmentStatement(
                        SyntaxFactory.PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression,
                            (ExpressionSyntax)editor.Generator.CastExpression(editor.SemanticModel.Compilation.CreatePointerTypeSymbol(structureType),
                                destination.Syntax)),
                        structure.Syntax));
                return true;
            }

            if (targetMethod.Name == "PtrToStructure")
            {
                ITypeSymbol type;
                if (targetMethod.IsGenericMethod && operation.Arguments.Length == 1)
                {
                    type = targetMethod.TypeArguments[0];
                }
                else if (targetMethod.ReturnType.SpecialType == SpecialType.System_Object
                        && operation.Arguments.Length == 2
                        && GetArgumentValue(operation, 1) is ITypeOfOperation typeOf)
                {
                    type = typeOf.TypeOperand;
                }
                else
                {
                    return false;
                }

                // Both accepted shapes have at least one argument; the first parameter is the source pointer.
                if (GetArgumentValue(operation, 0) is not { } pointer)
                {
                    return false;
                }

                SyntaxNode replacementNode;
                if (type.IsNullableValueType() && type.GetNullableValueTypeUnderlyingType() is ITypeSymbol { IsUnmanagedType: true } underlyingType)
                {
                    var nonNullPtrIdentifier = pointerIdentifierGenerator.NextIdentifier();
                    if (nonNullPtrIdentifier is null)
                    {
                        // We couldn't generate an identifier to use, so don't update the call
                        return false;
                    }

                    IdentifierNameSyntax nonNullPtrIdentifierNode = SyntaxFactory.IdentifierName(nonNullPtrIdentifier);

                    if (addRenameAnnotation)
                    {
                        // Only create the annotation when it is actually attached.
                        nonNullPtrIdentifierNode = nonNullPtrIdentifierNode.WithAdditionalAnnotations(RenameAnnotation.Create());
                    }

                    var pointerCast = editor.Generator.CastExpression(
                            editor.SemanticModel.Compilation.CreatePointerTypeSymbol(underlyingType),
                            pointer.Syntax);

                    // Parse from a string since we're limited in SyntaxFactory methods due to the Roslyn version we build against.
                    // Use a dummy identifier for the expression since we want to replace it with `pointerCast` from above anyway
                    // to preserve the annotations that SyntaxGenerator provides.
                    var nullCheckAndDecl = (IsPatternExpressionSyntax)SyntaxFactory.ParseExpression($"x is not null and var {nonNullPtrIdentifier}");
                    nullCheckAndDecl = nullCheckAndDecl.WithExpression((ExpressionSyntax)pointerCast);

                    // `(T*)ptr is not null and var p ? *p : (TReturn)null`
                    replacementNode = editor.Generator.ConditionalExpression(
                        nullCheckAndDecl,
                        SyntaxFactory.PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, nonNullPtrIdentifierNode),
                        editor.Generator.CastExpression(targetMethod.ReturnType, editor.Generator.NullLiteralExpression()));
                }
                else if (type is { IsUnmanagedType: true })
                {
                    // `(TReturn)(*(T*)ptr)`
                    replacementNode = editor.Generator.CastExpression(targetMethod.ReturnType,
                            SyntaxFactory.ParenthesizedExpression(SyntaxFactory.PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression,
                            (ExpressionSyntax)editor.Generator.CastExpression(
                                editor.SemanticModel.Compilation.CreatePointerTypeSymbol(type),
                                pointer.Syntax))));
                }
                else
                {
                    return false;
                }

                editor.ReplaceNode(syntax, replacementNode);
                return true;
            }

            return false;

            // Returns the value supplied for the parameter with the given ordinal. The Arguments array is
            // in evaluation order, so with out-of-order named arguments the array index does not match the
            // parameter position. Falls back to the positional entry when no argument reports a matching
            // parameter, and to null when there is no such entry.
            static IOperation? GetArgumentValue(IInvocationOperation invocation, int parameterOrdinal)
            {
                foreach (IArgumentOperation argument in invocation.Arguments)
                {
                    if (argument.Parameter?.Ordinal == parameterOrdinal)
                    {
                        return argument.Value;
                    }
                }

                return parameterOrdinal < invocation.Arguments.Length
                    ? invocation.Arguments[parameterOrdinal].Value
                    : null;
            }
        }
 
        /// <summary>
        /// Adds the <c>unsafe</c> modifier to the nearest <see cref="BaseMethodDeclarationSyntax"/> that
        /// is a strict ancestor of <paramref name="syntax"/>. Does nothing when there is no such ancestor.
        /// </summary>
        /// <param name="editor">The editor that records the modifier change.</param>
        /// <param name="syntax">The node whose enclosing method declaration should become unsafe.</param>
        private static void AddUnsafeModifierToEnclosingMethod(DocumentEditor editor, SyntaxNode syntax)
        {
            for (SyntaxNode? ancestor = syntax.Parent; ancestor is not null; ancestor = ancestor.Parent)
            {
                if (ancestor is BaseMethodDeclarationSyntax enclosingMethod)
                {
                    editor.SetModifiers(enclosingMethod, editor.Generator.GetModifiers(enclosingMethod).WithIsUnsafe(true));
                    return;
                }
            }

            // No enclosing method declaration (for example a field initializer or a property accessor):
            // there is nothing to mark, and passing a null declaration to the editor would fail.
        }
    }
}