File: src\Analyzers\Core\CodeFixes\SimplifyLinqExpression\SimplifyLinqExpressionCodeFixProvider.cs
Web Access
Project: src\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Immutable;
using System.Composition;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Shared.Extensions;
 
namespace Microsoft.CodeAnalysis.SimplifyLinqExpression;
 
/// <summary>
/// Code fix for <see cref="IDEDiagnosticIds.SimplifyLinqExpressionDiagnosticId"/>.  Collapses a chained
/// call of the shape <c>x.Where(...).Count()</c> into the single call <c>x.Count(...)</c> by keeping the
/// inner invocation and swapping its method name for the name of the outer call.
/// </summary>
[ExportCodeFixProvider(LanguageNames.CSharp, LanguageNames.VisualBasic, Name = PredefinedCodeFixProviderNames.SimplifyLinqExpression), Shared]
[method: ImportingConstructor]
[method: SuppressMessage("RoslynDiagnosticsReliability", "RS0033:Importing constructor should be [Obsolete]", Justification = "Used in test code: https://github.com/dotnet/roslyn/issues/42814")]
internal sealed class SimplifyLinqExpressionCodeFixProvider() : SyntaxEditorBasedCodeFixProvider
{
    /// <summary>
    /// The single diagnostic id this provider offers a fix for.
    /// </summary>
    public sealed override ImmutableArray<string> FixableDiagnosticIds
        => [IDEDiagnosticIds.SimplifyLinqExpressionDiagnosticId];

    /// <summary>
    /// Registers the "Simplify LINQ expression" fix on <paramref name="context"/>.  The resource name is
    /// passed as the equivalence key.
    /// </summary>
    public sealed override Task RegisterCodeFixesAsync(CodeFixContext context)
    {
        RegisterCodeFix(
            context,
            AnalyzersResources.Simplify_LINQ_expression,
            nameof(AnalyzersResources.Simplify_LINQ_expression));

        return Task.CompletedTask;
    }

    /// <summary>
    /// Queues one node replacement on <paramref name="editor"/> for every diagnostic in
    /// <paramref name="diagnostics"/>.  All work is synchronous, so a completed task is returned.
    /// </summary>
    protected override Task FixAllAsync(
        Document document,
        ImmutableArray<Diagnostic> diagnostics,
        SyntaxEditor editor,
        CancellationToken cancellationToken)
    {
        var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();
        var originalRoot = editor.OriginalRoot;

        // Visit the diagnostics from the largest span start down to the smallest.
        var lastToFirst = diagnostics.OrderByDescending(d => d.Location.SourceSpan.Start);

        foreach (var diagnostic in lastToFirst)
        {
            var reportedSpan = diagnostic.Location.SourceSpan;
            var target = originalRoot.FindNode(reportedSpan, getInnermostNodeForTie: true);

            editor.ReplaceNode(target, (current, _) => CollapseCallChain(current));
        }

        return Task.CompletedTask;

        // Rewrites the full expression (e.g. 'x.Where(...).Count()') held in 'current'.
        SyntaxNode CollapseCallChain(SyntaxNode current)
        {
            // Outer pieces: 'x.Where(...).Count' and its name, 'Count'.
            var outerAccess = syntaxFacts.GetExpressionOfInvocationExpression(current);
            var outerName = syntaxFacts.GetNameOfMemberAccessExpression(outerAccess);

            // Inner pieces: 'x.Where(...)', then 'x.Where', then the name 'Where'.
            var innerCall = syntaxFacts.GetExpressionOfMemberAccessExpression(outerAccess)!;
            var innerAccess = syntaxFacts.GetExpressionOfInvocationExpression(innerCall);
            var innerName = syntaxFacts.GetNameOfMemberAccessExpression(innerAccess);

            // Keep only 'x.Where(...)', with 'Where' swapped for 'Count' (the trivia around the old
            // name is carried over to the new one).
            var renamedCall = innerCall.ReplaceNode(innerName, outerName.WithTriviaFrom(innerName));

            // The result takes over the trailing trivia of the expression it replaces.
            return renamedCall.WithTrailingTrivia(current.GetTrailingTrivia());
        }
    }
}