File: Microsoft.NetCore.Analyzers\Performance\CSharpPreferDictionaryTryMethodsOverContainsKeyGuardFixer.cs
Web Access
Project: ..\..\..\src\Microsoft.CodeAnalysis.NetAnalyzers\src\Microsoft.CodeAnalysis.CSharp.NetAnalyzers\Microsoft.CodeAnalysis.CSharp.NetAnalyzers.csproj (Microsoft.CodeAnalysis.CSharp.NetAnalyzers)
// Copyright (c) Microsoft.  All Rights Reserved.  Licensed under the MIT license.  See License.txt in the project root for license information.
 
using System.Collections.Generic;
using System.Composition;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Analyzer.Utilities;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Simplification;
using Microsoft.NetCore.Analyzers.Performance;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
 
namespace Microsoft.NetCore.CSharp.Analyzers.Performance
{
    [ExportCodeFixProvider(LanguageNames.CSharp), Shared]
    public sealed class CSharpPreferDictionaryTryMethodsOverContainsKeyGuardFixer : PreferDictionaryTryMethodsOverContainsKeyGuardFixer
    {
        private const string Var = "var";
 
        public override async Task RegisterCodeFixesAsync(CodeFixContext context)
        {
            var diagnostic = context.Diagnostics.FirstOrDefault();
            if (diagnostic is not { AdditionalLocations.Count: > 0 })
            {
                return;
            }
 
            Document document = context.Document;
            SyntaxNode root = await document.GetRequiredSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false);
 
            if (root.FindNode(context.Span) is not InvocationExpressionSyntax
                {
                    Expression: MemberAccessExpressionSyntax containsKeyAccess
                } containsKeyInvocation)
            {
                return;
            }
 
            CodeAction? action = diagnostic.Id == PreferDictionaryTryMethodsOverContainsKeyGuardAnalyzer.PreferTryGetValueRuleId
                ? await GetTryGetValueActionAsync(diagnostic, root, document, containsKeyAccess, containsKeyInvocation, context.CancellationToken).ConfigureAwait(false)
                : GetTryAddAction(diagnostic, root, document, containsKeyInvocation, containsKeyAccess);
            if (action is null)
            {
                return;
            }
 
            context.RegisterCodeFix(action, context.Diagnostics);
        }
 
        private static async Task<CodeAction?> GetTryGetValueActionAsync(Diagnostic diagnostic, SyntaxNode root, Document document, MemberAccessExpressionSyntax containsKeyAccess, InvocationExpressionSyntax containsKeyInvocation, CancellationToken cancellationToken)
        {
            var dictionaryAccessors = new List<SyntaxNode>();
            ExpressionStatementSyntax? addStatementNode = null;
            SyntaxNode? changedValueNode = null;
            string? variableName = null;
            LocalDeclarationStatementSyntax? localDeclarationStatement = null;
            VariableDeclaratorSyntax? variableDeclarator = null;
            var additionalNodes = 0;
            SyntaxNode? typeNode = null;
            foreach (var location in diagnostic.AdditionalLocations)
            {
                var node = root.FindNode(location.SourceSpan, getInnermostNodeForTie: true);
                switch (node)
                {
                    case ElementAccessExpressionSyntax:
                        dictionaryAccessors.Add(node);
                        typeNode ??= node;
                        break;
                    case ExpressionStatementSyntax exp:
                        if (addStatementNode != null)
                            return null;
 
                        addStatementNode = exp;
                        additionalNodes++;
                        switch (addStatementNode.Expression)
                        {
                            case AssignmentExpressionSyntax assign:
                                changedValueNode = assign.Right;
                                break;
                            case InvocationExpressionSyntax invocation:
                                changedValueNode = invocation.ArgumentList.Arguments[1].Expression;
                                break;
                            default:
                                return null;
                        }
 
                        break;
                    case LocalDeclarationStatementSyntax local:
                        localDeclarationStatement = local;
                        variableName = local.Declaration.Variables[0].Identifier.ValueText;
                        additionalNodes++;
                        typeNode ??= local.Declaration.Type;
                        break;
                    case VariableDeclaratorSyntax
                    {
                        Parent: VariableDeclarationSyntax
                        {
                            Parent: LocalDeclarationStatementSyntax local
                        }
                    } declarator:
                        variableDeclarator = declarator;
                        localDeclarationStatement = local;
                        variableName = declarator.Identifier.ValueText;
                        additionalNodes++;
                        typeNode ??= local.Declaration.Type;
                        break;
                }
            }
 
            if (diagnostic.AdditionalLocations.Count != dictionaryAccessors.Count + additionalNodes)
                return null;
 
            var model = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
            var type = model.GetTypeInfo(typeNode!, cancellationToken).Type;
 
            return CodeAction.Create(PreferDictionaryTryGetValueCodeFixTitle, async ct =>
            {
                var editor = await DocumentEditor.CreateAsync(document, ct).ConfigureAwait(false);
                var generator = editor.Generator;
 
                // Roslyn has reducers that are run after a code action is applied, one of which will
                // simplify a TypeSyntax to `var` if the user prefers that. So we generate TypeSyntax, add
                // simplifier annotation, and then let Roslyn decide whether to keep TypeSyntax or convert it to var.
                // If the type is unknown (null) (likely in error scenario), then fallback to using var.
                TypeSyntax typeSyntax;
                if (type is not null)
                {
                    typeSyntax = (TypeSyntax)generator.TypeExpression(type);
                    if (type.IsReferenceType)
                        typeSyntax = (TypeSyntax)generator.NullableTypeExpression(typeSyntax);
 
                    typeSyntax = typeSyntax.WithAdditionalAnnotations(Simplifier.Annotation);
                }
                else
                {
                    typeSyntax = IdentifierName(Var);
                }
 
                var identifierName = (IdentifierNameSyntax)(variableName is not null
                    ? generator.IdentifierName(variableName)
                    : generator.FirstUnusedIdentifierName(model, containsKeyInvocation.SpanStart, Value));
                var outArgument = (ArgumentSyntax)generator.Argument(RefKind.Out,
                    DeclarationExpression(
                        typeSyntax,
                        SingleVariableDesignation(identifierName.Identifier)
                    )
                );
 
                var tryGetValueInvocation = containsKeyInvocation
                    .ReplaceNode(containsKeyAccess.Name, IdentifierName(TryGetValue).WithTriviaFrom(containsKeyAccess.Name))
                    .AddArgumentListArguments(outArgument);
                editor.ReplaceNode(containsKeyInvocation, tryGetValueInvocation);
 
                if (addStatementNode != null)
                {
                    editor.InsertBefore(addStatementNode,
                        generator.ExpressionStatement(generator.AssignmentStatement(identifierName, changedValueNode)));
                    editor.ReplaceNode(changedValueNode!, identifierName);
                }
 
                foreach (var dictionaryAccess in dictionaryAccessors)
                {
                    switch (dictionaryAccess.Parent)
                    {
                        case PostfixUnaryExpressionSyntax { RawKind: (int)SyntaxKind.PostDecrementExpression } post:
                            editor.ReplaceNode(post, generator.AssignmentStatement(dictionaryAccess,
                                PrefixUnaryExpression(SyntaxKind.PreDecrementExpression, identifierName)).
                                WithTriviaFrom(post));
                            break;
                        case PostfixUnaryExpressionSyntax { RawKind: (int)SyntaxKind.PostIncrementExpression } post:
                            editor.ReplaceNode(post, generator.AssignmentStatement(dictionaryAccess,
                                PrefixUnaryExpression(SyntaxKind.PreIncrementExpression, identifierName)).
                                WithTriviaFrom(post));
                            break;
                        case PrefixUnaryExpressionSyntax pre:
                            editor.ReplaceNode(pre, generator.AssignmentStatement(dictionaryAccess,
                                pre.WithOperand(identifierName)).WithTriviaFrom(pre));
                            break;
                        default:
                            editor.ReplaceNode(dictionaryAccess, identifierName);
                            break;
                    }
                }
 
                if (localDeclarationStatement is not null)
                {
                    if (variableDeclarator is null)
                    {
                        editor.RemoveNode(localDeclarationStatement);
                    }
                    else
                    {
                        editor.RemoveNode(variableDeclarator);
                    }
                }
 
                return editor.GetChangedDocument();
            }, PreferDictionaryTryGetValueCodeFixTitle);
        }
 
        private static CodeAction? GetTryAddAction(Diagnostic diagnostic, SyntaxNode root, Document document, InvocationExpressionSyntax containsKeyInvocation, MemberAccessExpressionSyntax containsKeyAccess)
        {
            var dictionaryAdd = root.FindNode(diagnostic.AdditionalLocations[0].SourceSpan, getInnermostNodeForTie: true);
            if (dictionaryAdd is not InvocationExpressionSyntax dictionaryAddInvocation)
            {
                return null;
            }
 
            return CodeAction.Create(PreferDictionaryTryAddValueCodeFixTitle, async ct =>
            {
                var editor = await DocumentEditor.CreateAsync(document, ct).ConfigureAwait(false);
                var generator = editor.Generator;
 
                var tryAddValueAccess = generator.MemberAccessExpression(containsKeyAccess.Expression, TryAdd);
                var dictionaryAddArguments = dictionaryAddInvocation.ArgumentList.Arguments;
                var tryAddInvocation = generator.InvocationExpression(tryAddValueAccess, dictionaryAddArguments[0], dictionaryAddArguments[1]);
 
                var ifStatement = containsKeyInvocation.FirstAncestorOrSelf<IfStatementSyntax>();
                if (ifStatement is null)
                {
                    return editor.OriginalDocument;
                }
 
                if (ifStatement.Condition is PrefixUnaryExpressionSyntax unary && unary.IsKind(SyntaxKind.LogicalNotExpression))
                {
                    if (ifStatement.Statement is BlockSyntax { Statements.Count: 1 } or ExpressionStatementSyntax)
                    {
                        if (ifStatement.Else is null)
                        {
                            // d.Add() is the only statement in the if and is guarded with a !d.ContainsKey().
                            // Since there is no else-branch, we can replace the entire if-statement with a d.TryAdd() call.
                            var invocationWithTrivia = tryAddInvocation.WithTriviaFrom(ifStatement);
                            editor.ReplaceNode(ifStatement, generator.ExpressionStatement(invocationWithTrivia));
                        }
                        else
                        {
                            // d.Add() is the only statement in the if and is guarded with a !d.ContainsKey().
                            // In this case, we switch out the !d.ContainsKey() call with a !d.TryAdd() call and move the else-branch into the if.
                            editor.ReplaceNode(containsKeyInvocation, tryAddInvocation);
                            editor.ReplaceNode(ifStatement.Statement, ifStatement.Else.Statement);
                            editor.RemoveNode(ifStatement.Else, SyntaxRemoveOptions.KeepNoTrivia);
                        }
                    }
                    else
                    {
                        // d.Add() is one of many statements in the if and is guarded with a !d.ContainsKey().
                        // In this case, we switch out the !d.ContainsKey() call for a d.TryAdd() call.
                        editor.RemoveNode(dictionaryAddInvocation.Parent!, SyntaxRemoveOptions.KeepNoTrivia);
                        editor.ReplaceNode(unary, tryAddInvocation);
                    }
                }
                else if (ifStatement.Condition.IsKind(SyntaxKind.InvocationExpression) && ifStatement.Else is not null)
                {
                    var negatedTryAddInvocation = generator.LogicalNotExpression(tryAddInvocation);
                    editor.ReplaceNode(containsKeyInvocation, negatedTryAddInvocation);
                    if (ifStatement.Else.Statement is BlockSyntax { Statements.Count: 1 } or ExpressionStatementSyntax)
                    {
                        // d.Add() is the only statement the else-branch and guarded by a d.ContainsKey() call in the if.
                        // In this case we replace the d.ContainsKey() call with a !d.TryAdd() call and remove the entire else-branch.
                        editor.RemoveNode(ifStatement.Else);
                    }
                    else
                    {
                        // d.Add() is one of many statements in the else-branch and guarded by a d.ContainsKey() call in the if.
                        // In this case we replace the d.ContainsKey() call with a !d.TryAdd() call and remove the d.Add() call in the else-branch.
                        editor.RemoveNode(dictionaryAddInvocation.Parent!, SyntaxRemoveOptions.KeepNoTrivia);
                    }
                }
 
                return editor.GetChangedDocument();
            }, PreferDictionaryTryAddValueCodeFixTitle);
        }
    }
}