File: CodeFixes\FixAllOccurrences\BatchFixAllProvider.cs
Project: src\src\Workspaces\Core\Portable\Microsoft.CodeAnalysis.Workspaces.csproj (Microsoft.CodeAnalysis.Workspaces)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixesAndRefactorings;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
namespace Microsoft.CodeAnalysis.CodeFixes;

/// <summary>
/// Helper class for "Fix all occurrences" code fix providers.  For every <see cref="FixAllContext"/> it computes the
/// fix for each diagnostic independently, collects the documents those fixes change, merges the resulting text
/// changes per document, and finally applies all merged texts to produce one fixed <see cref="Solution"/>.
/// </summary>
internal sealed class BatchFixAllProvider : FixAllProvider
{
    /// <summary>
    /// The single shared instance of this provider (the constructor is private).
    /// </summary>
    public static readonly FixAllProvider Instance = new BatchFixAllProvider();

    private BatchFixAllProvider()
    {
    }

    /// <summary>
    /// Returns the scopes supported by this provider: document, project, solution, containing member and
    /// containing type.
    /// </summary>
    public override IEnumerable<FixAllScope> GetSupportedFixAllScopes()
        => [FixAllScope.Document, FixAllScope.Project, FixAllScope.Solution, FixAllScope.ContainingMember, FixAllScope.ContainingType];

    /// <summary>
    /// Produces the fix-all code action for <paramref name="fixAllContext"/>.  The work is delegated to
    /// <c>DefaultFixAllProviderHelpers.GetFixAsync</c>, with <c>FixAllContextsAsync</c> supplied as the callback that
    /// computes the fixed solution.
    /// </summary>
    public override Task<CodeAction?> GetFixAsync(FixAllContext fixAllContext)
        => DefaultFixAllProviderHelpers.GetFixAsync(
            fixAllContext.GetDefaultFixAllTitle(), fixAllContext, FixAllContextsAsync);

    /// <summary>
    /// Computes the solution produced by fixing every context in <paramref name="fixAllContexts"/>.
    /// </summary>
    /// <param name="originalFixAllContext">The context the fix-all was invoked with.  It supplies the cancellation
    /// token, the progress reporter and the starting solution.</param>
    /// <param name="fixAllContexts">The individual contexts to process.  Each must be document, project, containing
    /// member or containing type scoped; any other scope fails the contract check.</param>
    /// <returns>The solution with all merged document texts applied, or <see langword="null"/> if no document was
    /// changed.</returns>
    private async Task<Solution?> FixAllContextsAsync(
        FixAllContext originalFixAllContext,
        ImmutableArray<FixAllContext> fixAllContexts)
    {
        var cancellationToken = originalFixAllContext.CancellationToken;
        var progressTracker = originalFixAllContext.Progress;

        // We have 2*P + 1 pieces of work.  Computing diagnostics and fixes/changes per context, and then one pass
        // applying fixes.
        progressTracker.AddItems(fixAllContexts.Length * 2 + 1);

        // Mapping from document to the cumulative text changes created for that document.
        var docIdToTextMerger = new Dictionary<DocumentId, TextChangeMerger>();

        // Process each context one at a time, allowing us to dump most of the information we computed for each once
        // done with it.  The only information we need to preserve is the data we store in docIdToTextMerger.
        foreach (var fixAllContext in fixAllContexts)
        {
            // Solution scope is not accepted here.  Presumably the caller has already split a solution-wide request
            // into per-project contexts.
            Contract.ThrowIfFalse(fixAllContext.Scope is FixAllScope.Document or
                FixAllScope.Project or FixAllScope.ContainingMember or FixAllScope.ContainingType);
            await FixSingleContextAsync(fixAllContext, progressTracker, docIdToTextMerger).ConfigureAwait(false);
        }

        // Finally, merge in all text changes into the solution.  We can't do this per-project as we have to have
        // processed *all* diagnostics in the solution to find the changes made to all documents.
        using (progressTracker.ItemCompletedScope())
        {
            if (docIdToTextMerger.Count == 0)
                return null;

            var currentSolution = originalFixAllContext.Solution;
            foreach (var group in docIdToTextMerger.GroupBy(kvp => kvp.Key.ProjectId))
                currentSolution = await ApplyChangesAsync(currentSolution, group.SelectAsArray(kvp => (kvp.Key, kvp.Value)), cancellationToken).ConfigureAwait(false);

            return currentSolution;
        }
    }

    /// <summary>
    /// Computes the diagnostics for a single context and folds the text changes produced by fixing them into
    /// <paramref name="docIdToTextMerger"/>.  Completes two progress items: one for diagnostics, one for changes.
    /// </summary>
    private static async Task FixSingleContextAsync(
        FixAllContext fixAllContext, IProgress<CodeAnalysisProgress> progressTracker, Dictionary<DocumentId, TextChangeMerger> docIdToTextMerger)
    {
        // First, determine the diagnostics to fix for that context.
        var documentToDiagnostics = await DetermineDiagnosticsAsync(fixAllContext, progressTracker).ConfigureAwait(false);

        // Second, process all those diagnostics, merging the cumulative set of text changes per document into docIdToTextMerger.
        await AddDocumentChangesAsync(fixAllContext, progressTracker, docIdToTextMerger, documentToDiagnostics).ConfigureAwait(false);
    }

    /// <summary>
    /// Gets the document diagnostics to fix for <paramref name="fixAllContext"/>.  Results are restricted to
    /// documents in the context's project, and additionally to the context's document when one is set.
    /// </summary>
    private static async Task<ImmutableDictionary<Document, ImmutableArray<Diagnostic>>> DetermineDiagnosticsAsync(FixAllContext fixAllContext, IProgress<CodeAnalysisProgress> progressTracker)
    {
        using var _ = progressTracker.ItemCompletedScope();

        var documentToDiagnostics = await fixAllContext.GetDocumentDiagnosticsToFixAsync().ConfigureAwait(false);

        var filtered = documentToDiagnostics.Where(kvp =>
        {
            if (kvp.Key.Project != fixAllContext.Project)
                return false;

            if (fixAllContext.Document != null && fixAllContext.Document != kvp.Key)
                return false;

            return true;
        });

        return filtered.ToImmutableDictionary();
    }

    /// <summary>
    /// Fixes every in-source diagnostic in <paramref name="documentToDiagnostics"/>.  The text changes of every
    /// document touched by those fixes are merged into <paramref name="docIdToTextMerger"/>.
    /// </summary>
    private static async Task AddDocumentChangesAsync(
        FixAllContext fixAllContext,
        IProgress<CodeAnalysisProgress> progressTracker,
        Dictionary<DocumentId, TextChangeMerger> docIdToTextMerger,
        ImmutableDictionary<Document, ImmutableArray<Diagnostic>> documentToDiagnostics)
    {
        using var _ = progressTracker.ItemCompletedScope();

        // First, order the diagnostics so we process them in a consistent manner and get the same results given the
        // same input solution.  File paths are compared ordinally.  The default string comparer is culture-sensitive,
        // and the in-order merge below depends on this ordering, so the outcome must not vary with the current culture.
        var orderedDiagnostics = documentToDiagnostics.SelectMany(kvp => kvp.Value)
                                                      .Where(d => d.Location.IsInSource)
                                                      .OrderBy(d => d.Location.SourceTree!.FilePath, StringComparer.Ordinal)
                                                      .ThenBy(d => d.Location.SourceSpan.Start)
                                                      .ToImmutableArray();

        // Now determine all the document changes caused from these diagnostics.
        var allChangedDocumentsInDiagnosticsOrder =
            await GetAllChangedDocumentsInDiagnosticsOrderAsync(fixAllContext, orderedDiagnostics).ConfigureAwait(false);

        // Finally, take all the changes made to each document and merge them together into docIdToTextMerger to
        // keep track of the total set of changes to any particular document.
        await MergeTextChangesAsync(fixAllContext, allChangedDocumentsInDiagnosticsOrder, docIdToTextMerger).ConfigureAwait(false);
    }

    /// <summary>
    /// Returns all the changed documents produced by fixing the list of provided <paramref
    /// name="orderedDiagnostics"/>.  The documents will be returned such that fixed documents for a later
    /// diagnostic will appear later than those for an earlier diagnostic.
    /// </summary>
    /// <remarks>
    /// Each diagnostic is processed on its own <see cref="Task.Run(Func{Task}, CancellationToken)"/> task.  Only code
    /// actions whose equivalence key matches <see cref="FixAllContext.CodeActionEquivalenceKey"/> are applied.
    /// </remarks>
    private static async Task<ImmutableArray<Document>> GetAllChangedDocumentsInDiagnosticsOrderAsync(
        FixAllContext fixAllContext, ImmutableArray<Diagnostic> orderedDiagnostics)
    {
        var solution = fixAllContext.Solution;
        var cancellationToken = fixAllContext.CancellationToken;

        // Process each diagnostic, determine the code actions to fix it, then figure out the document changes
        // produced by that code action.
        using var _1 = ArrayBuilder<Task<ImmutableArray<Document>>>.GetInstance(out var tasks);
        foreach (var diagnostic in orderedDiagnostics)
        {
            var document = solution.GetRequiredDocument(diagnostic.Location.SourceTree!);

            // Stop queuing further work once cancellation has been requested.
            cancellationToken.ThrowIfCancellationRequested();
            tasks.Add(Task.Run(async () =>
            {
                // Create a context that will add the reported code actions into codeActions.
                using var _2 = ArrayBuilder<CodeAction>.GetInstance(out var codeActions);
                var action = GetRegisterCodeFixAction(fixAllContext.CodeActionEquivalenceKey, codeActions);
                var context = new CodeFixContext(document, diagnostic.Location.SourceSpan, [diagnostic], action, cancellationToken);

                // Wait for all the code actions to be reported for this diagnostic.  A provider returning null is
                // treated as having completed synchronously.
                var registerTask = fixAllContext.CodeFixProvider.RegisterCodeFixesAsync(context) ?? Task.CompletedTask;
                await registerTask.ConfigureAwait(false);

                // Now, process each code action and find out all the document changes caused by it.
                using var _3 = ArrayBuilder<Document>.GetInstance(out var changedDocuments);
                foreach (var codeAction in codeActions)
                {
                    var changedSolution = await codeAction.GetChangedSolutionInternalAsync(
                        solution, fixAllContext.Progress, cancellationToken: cancellationToken).ConfigureAwait(false);
                    if (changedSolution != null)
                    {
                        var changedDocumentIds = new SolutionChanges(changedSolution, solution).GetProjectChanges().SelectMany(p => p.GetChangedDocuments());
                        changedDocuments.AddRange(changedDocumentIds.Select(id => changedSolution.GetRequiredDocument(id)));
                    }
                }

                return changedDocuments.ToImmutableAndClear();
            }, cancellationToken));
        }

        // Wait for all that work to finish.
        await Task.WhenAll(tasks).ConfigureAwait(false);

        // Flatten the set of changed documents.  These will naturally still be ordered by the diagnostic that
        // caused the change.
        using var _4 = ArrayBuilder<Document>.GetInstance(out var result);
        foreach (var task in tasks)
            result.AddRange(await task.ConfigureAwait(false));

        return result.ToImmutableAndClear();
    }

    /// <summary>
    /// Take all the changes made to a particular document and determine the text changes caused by each one.  Take
    /// those individual text changes and attempt to merge them together in order into <paramref
    /// name="docIdToTextMerger"/>.
    /// </summary>
    /// <remarks>
    /// <paramref name="docIdToTextMerger"/> itself is only read and written on the calling thread.  The per-document
    /// merges run in parallel, one task per document id.
    /// </remarks>
    private static async Task MergeTextChangesAsync(
        FixAllContext fixAllContext,
        ImmutableArray<Document> allChangedDocumentsInDiagnosticsOrder,
        Dictionary<DocumentId, TextChangeMerger> docIdToTextMerger)
    {
        var cancellationToken = fixAllContext.CancellationToken;

        // Now for each document that is changed, grab all the documents it was changed to (remember, many code
        // actions might have touched that document).  Figure out the actual change, and then add that to the
        // interval tree of changes we're keeping track of for that document.
        using var _ = ArrayBuilder<Task>.GetInstance(out var tasks);
        foreach (var group in allChangedDocumentsInDiagnosticsOrder.GroupBy(d => d.Id))
        {
            var docId = group.Key;
            var allDocChanges = group.ToImmutableArray();

            // If we don't have a text merger for this doc yet, create one to keep track of all the changes.
            if (!docIdToTextMerger.TryGetValue(docId, out var textMerger))
            {
                var originalDocument = fixAllContext.Solution.GetRequiredDocument(docId);
                textMerger = new TextChangeMerger(originalDocument);
                docIdToTextMerger.Add(docId, textMerger);
            }

            // Process all document groups in parallel.  For each group, merge all the doc changes into an
            // aggregated set of changes in the TextChangeMerger type.
            tasks.Add(Task.Run(
                async () => await textMerger.TryMergeChangesAsync(allDocChanges, cancellationToken).ConfigureAwait(false), cancellationToken));
        }

        await Task.WhenAll(tasks).ConfigureAwait(false);
    }

    /// <summary>
    /// Creates the registration callback handed to <see cref="CodeFixContext"/>.  For every reported action, the
    /// callback walks the action and all of its nested actions.  Each one whose
    /// <see cref="CodeAction.EquivalenceKey"/> equals <paramref name="codeActionEquivalenceKey"/> is added to
    /// <paramref name="codeActions"/>; additions are guarded by a lock on that builder.
    /// </summary>
    private static Action<CodeAction, ImmutableArray<Diagnostic>> GetRegisterCodeFixAction(
        string? codeActionEquivalenceKey, ArrayBuilder<CodeAction> codeActions)
    {
        return (action, diagnostics) =>
        {
            // Explicit work stack so arbitrarily nested actions are visited without recursion.
            using var _ = ArrayBuilder<CodeAction>.GetInstance(out var builder);
            builder.Push(action);
            while (builder.TryPop(out var currentAction))
            {
                if (currentAction is { EquivalenceKey: var equivalenceKey }
                    && codeActionEquivalenceKey == equivalenceKey)
                {
                    lock (codeActions)
                        codeActions.Add(currentAction);
                }

                foreach (var nestedAction in currentAction.NestedActions)
                    builder.Push(nestedAction);
            }
        };
    }

    /// <summary>
    /// Replaces the text of every document in <paramref name="docIdsAndMerger"/> with its merger's final merged text
    /// and returns the updated solution.
    /// </summary>
    private static async Task<Solution> ApplyChangesAsync(
        Solution currentSolution,
        ImmutableArray<(DocumentId documentId, TextChangeMerger merger)> docIdsAndMerger,
        CancellationToken cancellationToken)
    {
        var docIdsAndTexts = await docIdsAndMerger.SelectAsArrayAsync(async t => (t.documentId, await t.merger.GetFinalMergedTextAsync(cancellationToken).ConfigureAwait(false))).ConfigureAwait(false);
        return currentSolution.WithDocumentTexts(docIdsAndTexts);
    }
}