// File: src\Analyzers\CSharp\CodeFixes\UsePatternMatching\CSharpAsAndNullCheckCodeFixProvider.cs
// Web Access
// Project: src\src\CodeStyle\CSharp\CodeFixes\Microsoft.CodeAnalysis.CSharp.CodeStyle.Fixes.csproj (Microsoft.CodeAnalysis.CSharp.CodeStyle.Fixes)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Immutable;
using System.Composition;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
 
namespace Microsoft.CodeAnalysis.CSharp.UsePatternMatching;
 
using static CSharpSyntaxTokens;
using static SyntaxFactory;
 
/// <summary>
/// Code fix for <see cref="IDEDiagnosticIds.InlineAsTypeCheckId"/>. It folds an <c>as</c> cast stored in a local,
/// followed by a null check of that local, into a single declaration pattern. For example,
/// <c>var s = o as string; if (s != null) ...</c> becomes <c>if (o is string s) ...</c>.
/// </summary>
[ExportCodeFixProvider(LanguageNames.CSharp, Name = PredefinedCodeFixProviderNames.UsePatternMatchingAsAndNullCheck), Shared]
[method: ImportingConstructor]
[method: SuppressMessage("RoslynDiagnosticsReliability", "RS0033:Importing constructor should be [Obsolete]", Justification = "Used in test code: https://github.com/dotnet/roslyn/issues/42814")]
internal sealed partial class CSharpAsAndNullCheckCodeFixProvider() : SyntaxEditorBasedCodeFixProvider
{
    public override ImmutableArray<string> FixableDiagnosticIds
        => [IDEDiagnosticIds.InlineAsTypeCheckId];

    public override Task RegisterCodeFixesAsync(CodeFixContext context)
    {
        RegisterCodeFix(context, CSharpAnalyzersResources.Use_pattern_matching, nameof(CSharpAnalyzersResources.Use_pattern_matching));
        return Task.CompletedTask;
    }

    /// <summary>
    /// Applies the fix for every diagnostic in <paramref name="diagnostics"/>. It then removes leading blank
    /// lines from the first statement of each block or switch section that had a local declaration removed.
    /// </summary>
    protected override async Task FixAllAsync(
        Document document, ImmutableArray<Diagnostic> diagnostics,
        SyntaxEditor editor, CancellationToken cancellationToken)
    {
        using var _1 = PooledHashSet<Location>.GetInstance(out var declaratorLocations);
        using var _2 = PooledHashSet<SyntaxNode>.GetInstance(out var statementParentScopes);

        var tree = await document.GetRequiredSyntaxTreeAsync(cancellationToken).ConfigureAwait(false);
        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);

        var languageVersion = tree.Options.LanguageVersion();

        foreach (var diagnostic in diagnostics)
        {
            cancellationToken.ThrowIfCancellationRequested();

            // AdditionalLocations[0] is the declarator. Each declarator is edited at most once, even when
            // several diagnostics point at it.
            if (declaratorLocations.Add(diagnostic.AdditionalLocations[0]))
                AddEdits(editor, semanticModel, diagnostic, languageVersion, RemoveStatement, cancellationToken);
        }

        foreach (var parentScope in statementParentScopes)
        {
            // A removed declaration's trivia is prepended to the statement after it (see AddEdits). That
            // statement may now be first in its scope, so drop any leading blank lines the scope starts with.
            editor.ReplaceNode(parentScope, static (newParentScope, generator) =>
            {
                var statements = newParentScope switch
                {
                    BlockSyntax block => block.Statements,
                    SwitchSectionSyntax switchSection => switchSection.Statements,
                    _ => default,
                };

                // Leave the scope untouched, rather than throwing, if it has no statements or is no longer
                // a block or switch section.
                var firstStatement = statements.FirstOrDefault();
                return firstStatement is null
                    ? newParentScope
                    : generator.ReplaceNode(newParentScope, firstStatement, firstStatement.WithoutLeadingBlankLinesInTrivia());
            });
        }

        return;

        // Removes a whole local declaration statement and remembers its enclosing block/switch section so
        // the blank-line cleanup above can run on it.
        void RemoveStatement(StatementSyntax statement)
        {
            editor.RemoveNode(statement, SyntaxRemoveOptions.KeepNoTrivia);
            if (statement.Parent is BlockSyntax or SwitchSectionSyntax)
            {
                statementParentScopes.Add(statement.Parent);
            }
        }
    }

    /// <summary>
    /// Registers the edits for one diagnostic:
    /// <list type="bullet">
    /// <item>the null-check comparison is replaced by an <c>is</c> declaration pattern (negated if needed);</item>
    /// <item>the local that held the <c>as</c> result is removed (the whole statement when it declares only
    /// that variable, otherwise just the declarator).</item>
    /// </list>
    /// </summary>
    /// <param name="diagnostic">
    /// The diagnostic. Its <see cref="Diagnostic.AdditionalLocations"/> hold, in order, the declarator, the
    /// comparison and the <c>as</c> expression.
    /// </param>
    /// <param name="removeStatement">Callback that removes the whole local declaration statement.</param>
    private static void AddEdits(
        SyntaxEditor editor,
        SemanticModel semanticModel,
        Diagnostic diagnostic,
        LanguageVersion languageVersion,
        Action<StatementSyntax> removeStatement,
        CancellationToken cancellationToken)
    {
        var declaratorLocation = diagnostic.AdditionalLocations[0];
        var comparisonLocation = diagnostic.AdditionalLocations[1];
        var asExpressionLocation = diagnostic.AdditionalLocations[2];

        var declarator = (VariableDeclaratorSyntax)declaratorLocation.FindNode(cancellationToken);
        var comparison = (ExpressionSyntax)comparisonLocation.FindNode(cancellationToken);
        var asExpression = (BinaryExpressionSyntax)asExpressionLocation.FindNode(cancellationToken);

        // The new designation takes the trailing trivia of the comparison's right-hand side, since that
        // side is replaced by the pattern.
        var rightSideOfComparison = comparison is BinaryExpressionSyntax binaryExpression
            ? (SyntaxNode)binaryExpression.Right
            : ((IsPatternExpressionSyntax)comparison).Pattern;
        var newIdentifier = declarator.Identifier
            .WithoutTrivia().WithTrailingTrivia(rightSideOfComparison.GetTrailingTrivia());

        var declarationPattern = DeclarationPattern(
            GetPatternType().WithoutTrivia().WithTrailingTrivia(ElasticMarker),
            SingleVariableDesignation(newIdentifier));

        var condition = GetCondition(languageVersion, comparison, asExpression, declarationPattern);

        if (declarator.Parent is VariableDeclarationSyntax declaration &&
            declaration.Parent is LocalDeclarationStatementSyntax localDeclaration &&
            declaration.Variables.Count == 1)
        {
            // Trivia on the local declaration will move to the next statement.
            // Use the callback form: the next statement may be where the declaration is being inlined,
            // so the callback must see the effects of that change.
            editor.ReplaceNode(
                localDeclaration.GetNextStatement()!,
                (s, g) => s.WithPrependedNonIndentationTriviaFrom(localDeclaration)
                           .WithAdditionalAnnotations(Formatter.Annotation));

            removeStatement(localDeclaration);
        }
        else
        {
            // Other variables share this declaration, so only this declarator goes away.
            editor.RemoveNode(declarator, SyntaxRemoveOptions.KeepUnbalancedDirectives);
        }

        editor.ReplaceNode(comparison, condition.WithTriviaFrom(comparison));

        return;

        TypeSyntax GetPatternType()
        {
            // Complex case: object?[]? arr = obj as object[];
            //
            // Because of array variance, the above is legal. We want the `object?[]` from the LHS here.
            if (semanticModel.GetDeclaredSymbol(declarator, cancellationToken) is ILocalSymbol local)
            {
                var asExpressionTypeInfo = semanticModel.GetTypeInfo(asExpression, cancellationToken);
                if (asExpressionTypeInfo.Type != null)
                {
                    // Strip off the outer ? if present. The inner ? will still be there.
                    var localType = local.Type.WithNullableAnnotation(NullableAnnotation.NotAnnotated);
                    var asType = asExpressionTypeInfo.Type.WithNullableAnnotation(NullableAnnotation.NotAnnotated);

                    // If the types are the same except for the inner ?, use the local's type here.
                    if (SymbolEqualityComparer.Default.Equals(localType, asType) &&
                        !SymbolEqualityComparer.IncludeNullability.Equals(localType, asType))
                    {
                        return localType.GenerateTypeSyntax(allowVar: false);
                    }
                }
            }

            // Otherwise use the type written on the right of the `as` expression.
            return (TypeSyntax)asExpression.Right;
        }
    }

    /// <summary>
    /// Builds the expression that replaces <paramref name="comparison"/>. The result is <c>x is T t</c> when
    /// the comparison checks that the cast succeeded. When it checks that the cast failed, the result is the
    /// negation: <c>x is not T t</c> on C# 9+, or <c>!(x is T t)</c> on earlier versions.
    /// </summary>
    private static ExpressionSyntax GetCondition(
        LanguageVersion languageVersion,
        ExpressionSyntax comparison,
        BinaryExpressionSyntax asExpression,
        DeclarationPatternSyntax declarationPattern)
    {
        var isPatternExpression = IsPatternExpression(asExpression.Left, declarationPattern);

        // We should negate the is-expression if we have something like "x == null" or "x is null".
        // "x is not null" is also an IsPatternExpression but checks for success. So for pattern expressions,
        // inspect the pattern itself rather than relying on the node kind alone.
        var checksForFailure = comparison switch
        {
            BinaryExpressionSyntax binary => binary.Kind() == SyntaxKind.EqualsExpression,
            IsPatternExpressionSyntax isPattern => isPattern.Pattern.Kind() != SyntaxKind.NotPattern,
            _ => false,
        };

        if (!checksForFailure)
            return isPatternExpression;

        if (languageVersion >= LanguageVersion.CSharp9)
        {
            // In C# 9 and higher, convert to `x is not string s`.
            return isPatternExpression.WithPattern(
                UnaryPattern(NotKeyword, isPatternExpression.Pattern));
        }

        // In C# 8 and lower, convert to `!(x is string s)`
        return PrefixUnaryExpression(SyntaxKind.LogicalNotExpression, isPatternExpression.Parenthesize());
    }
}