File: CodeFixes\FixAllOccurrences\BatchFixAllProvider.cs
Web Access
Project: src\src\Workspaces\Core\Portable\Microsoft.CodeAnalysis.Workspaces.csproj (Microsoft.CodeAnalysis.Workspaces)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixesAndRefactorings;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.CodeFixes;
 
/// <summary>
/// Helper class for "Fix all occurrences" code fix providers.  For every diagnostic in scope it asks the fix
/// context's code fix provider for its fixes (keeping only those whose equivalence key matches the fix-all
/// request), computes the documents each such fix changes, and merges the resulting text changes per document
/// into a single final solution.
/// </summary>
internal sealed class BatchFixAllProvider : FixAllProvider
{
    /// <summary>
    /// The shared instance.  The constructor is private, so this is the only way to obtain the provider.
    /// </summary>
    public static readonly FixAllProvider Instance = new BatchFixAllProvider();

    private BatchFixAllProvider()
    {
    }

    /// <summary>
    /// Supports document, project, solution, containing-member and containing-type scopes.
    /// </summary>
    public override IEnumerable<FixAllScope> GetSupportedFixAllScopes()
        => [FixAllScope.Document, FixAllScope.Project, FixAllScope.Solution, FixAllScope.ContainingMember, FixAllScope.ContainingType];

    /// <summary>
    /// Produces the fix-all code action, titled with the context's default fix-all title.  The fixed solution is
    /// computed by FixAllContextsAsync, which is handed to the shared default fix-all helper as its callback.
    /// </summary>
    public override Task<CodeAction?> GetFixAsync(FixAllContext fixAllContext)
        => DefaultFixAllProviderHelpers.GetFixAsync(
            fixAllContext.GetDefaultFixAllTitle(), fixAllContext, FixAllContextsAsync);

    /// <summary>
    /// Computes the solution produced by fixing every diagnostic covered by <paramref name="fixAllContexts"/>.
    /// </summary>
    /// <param name="originalFixAllContext">The context the fix-all was invoked with.  It supplies the starting
    /// solution, the progress reporter, the title and the cancellation token.</param>
    /// <param name="fixAllContexts">The contexts to fix, processed sequentially.  Each must have document,
    /// project, containing-member or containing-type scope (asserted below); solution scope is never passed
    /// here directly.</param>
    /// <returns>The updated solution, or <see langword="null"/> when no fix changed any document.</returns>
    private async Task<Solution?> FixAllContextsAsync(
        FixAllContext originalFixAllContext,
        ImmutableArray<FixAllContext> fixAllContexts)
    {
        var cancellationToken = originalFixAllContext.CancellationToken;
        var progressTracker = originalFixAllContext.Progress;
        progressTracker.Report(CodeAnalysisProgress.Description(originalFixAllContext.GetDefaultFixAllTitle()));

        // We have 2*P + 1 pieces of work.  Computing diagnostics and fixes/changes per context, and then one pass
        // applying fixes.
        progressTracker.AddItems(fixAllContexts.Length * 2 + 1);

        // Mapping from document to the cumulative text changes created for that document.
        var docIdToTextMerger = new Dictionary<DocumentId, TextChangeMerger>();

        // Process each context one at a time, allowing us to dump most of the information we computed for each once
        // done with it.  The only information we need to preserve is the data we store in docIdToTextMerger.
        foreach (var fixAllContext in fixAllContexts)
        {
            Contract.ThrowIfFalse(fixAllContext.Scope is FixAllScope.Document or
                FixAllScope.Project or FixAllScope.ContainingMember or FixAllScope.ContainingType);
            await FixSingleContextAsync(fixAllContext, progressTracker, docIdToTextMerger).ConfigureAwait(false);
        }

        // Finally, merge in all text changes into the solution.  We can't do this per-project as we have to
        // process *all* diagnostics in the solution to find the changes made to all documents.
        using (progressTracker.ItemCompletedScope())
        {
            if (docIdToTextMerger.Count == 0)
                return null;

            var currentSolution = originalFixAllContext.Solution;
            foreach (var group in docIdToTextMerger.GroupBy(kvp => kvp.Key.ProjectId))
                currentSolution = await ApplyChangesAsync(currentSolution, group.SelectAsArray(kvp => (kvp.Key, kvp.Value)), cancellationToken).ConfigureAwait(false);

            return currentSolution;
        }
    }

    /// <summary>
    /// Finds the diagnostics for a single context and folds the text changes their fixes produce into
    /// <paramref name="docIdToTextMerger"/>.  Completes two progress items (one per step).
    /// </summary>
    private static async Task FixSingleContextAsync(
        FixAllContext fixAllContext, IProgress<CodeAnalysisProgress> progressTracker, Dictionary<DocumentId, TextChangeMerger> docIdToTextMerger)
    {
        // First, determine the diagnostics to fix for that context.
        var documentToDiagnostics = await DetermineDiagnosticsAsync(fixAllContext, progressTracker).ConfigureAwait(false);

        // Second, process all those diagnostics, merging the cumulative set of text changes per document into docIdToTextMerger.
        await AddDocumentChangesAsync(fixAllContext, progressTracker, docIdToTextMerger, documentToDiagnostics).ConfigureAwait(false);
    }

    /// <summary>
    /// Returns the document diagnostics to fix for <paramref name="fixAllContext"/>.  Only documents in the
    /// context's project are kept, and only the context's document when it has one.
    /// </summary>
    private static async Task<ImmutableDictionary<Document, ImmutableArray<Diagnostic>>> DetermineDiagnosticsAsync(FixAllContext fixAllContext, IProgress<CodeAnalysisProgress> progressTracker)
    {
        using var _ = progressTracker.ItemCompletedScope();

        var documentToDiagnostics = await fixAllContext.GetDocumentDiagnosticsToFixAsync().ConfigureAwait(false);

        var filtered = documentToDiagnostics.Where(kvp =>
        {
            if (kvp.Key.Project != fixAllContext.Project)
                return false;

            if (fixAllContext.Document != null && fixAllContext.Document != kvp.Key)
                return false;

            return true;
        });

        return filtered.ToImmutableDictionary();
    }

    /// <summary>
    /// Computes the fixes for every in-source diagnostic in <paramref name="documentToDiagnostics"/> and merges
    /// the text changes they produce into <paramref name="docIdToTextMerger"/>.
    /// </summary>
    private static async Task AddDocumentChangesAsync(
        FixAllContext fixAllContext,
        IProgress<CodeAnalysisProgress> progressTracker,
        Dictionary<DocumentId, TextChangeMerger> docIdToTextMerger,
        ImmutableDictionary<Document, ImmutableArray<Diagnostic>> documentToDiagnostics)
    {
        using var _ = progressTracker.ItemCompletedScope();

        // First, order the diagnostics so we process them in a consistent manner and get the same results given the
        // same input solution.  File paths are compared ordinally so the order (and therefore which of two
        // conflicting changes gets merged first) does not depend on the current culture.
        var orderedDiagnostics = documentToDiagnostics.SelectMany(kvp => kvp.Value)
                                                      .Where(d => d.Location.IsInSource)
                                                      .OrderBy(d => d.Location.SourceTree!.FilePath, StringComparer.Ordinal)
                                                      .ThenBy(d => d.Location.SourceSpan.Start)
                                                      .ToImmutableArray();

        // Now determine all the document changes caused from these diagnostics.
        var allChangedDocumentsInDiagnosticsOrder =
            await GetAllChangedDocumentsInDiagnosticsOrderAsync(fixAllContext, orderedDiagnostics).ConfigureAwait(false);

        // Finally, take all the changes made to each document and merge them together into docIdToTextMerger to
        // keep track of the total set of changes to any particular document.
        await MergeTextChangesAsync(fixAllContext, allChangedDocumentsInDiagnosticsOrder, docIdToTextMerger).ConfigureAwait(false);
    }

    /// <summary>
    /// Returns all the changed documents produced by fixing the list of provided <paramref
    /// name="orderedDiagnostics"/>.  The documents will be returned such that fixed documents for a later
    /// diagnostic will appear later than those for an earlier diagnostic.
    /// </summary>
    /// <remarks>
    /// Each diagnostic is handled in its own <see cref="Task.Run(Func{Task})"/> work item.  Every fix is applied
    /// to the same original solution, so fixes do not see each other's edits.  Combining them is left to
    /// MergeTextChangesAsync.
    /// </remarks>
    private static async Task<ImmutableArray<Document>> GetAllChangedDocumentsInDiagnosticsOrderAsync(
        FixAllContext fixAllContext, ImmutableArray<Diagnostic> orderedDiagnostics)
    {
        var solution = fixAllContext.Solution;
        var cancellationToken = fixAllContext.CancellationToken;

        // Process each diagnostic, determine the code actions to fix it, then figure out the document changes
        // produced by that code action.
        using var _1 = ArrayBuilder<Task<ImmutableArray<Document>>>.GetInstance(out var tasks);
        foreach (var diagnostic in orderedDiagnostics)
        {
            var document = solution.GetRequiredDocument(diagnostic.Location.SourceTree!);

            cancellationToken.ThrowIfCancellationRequested();
            tasks.Add(Task.Run(async () =>
            {
                // Create a fix context whose registration callback collects the matching reported code actions
                // into codeActions.
                using var _2 = ArrayBuilder<CodeAction>.GetInstance(out var codeActions);
                var action = GetRegisterCodeFixAction(fixAllContext.CodeActionEquivalenceKey, codeActions);
                var context = new CodeFixContext(document, diagnostic.Location.SourceSpan, [diagnostic], action, cancellationToken);

                // Wait for all the code actions to be reported for this diagnostic.
                var registerTask = fixAllContext.CodeFixProvider.RegisterCodeFixesAsync(context) ?? Task.CompletedTask;
                await registerTask.ConfigureAwait(false);

                // Now, process each code action and find out all the document changes caused by it.
                using var _3 = ArrayBuilder<Document>.GetInstance(out var changedDocuments);

                foreach (var codeAction in codeActions)
                {
                    var changedSolution = await codeAction.GetChangedSolutionInternalAsync(
                        solution, fixAllContext.Progress, cancellationToken: cancellationToken).ConfigureAwait(false);
                    if (changedSolution != null)
                    {
                        var changedDocumentIds = new SolutionChanges(changedSolution, solution).GetProjectChanges().SelectMany(p => p.GetChangedDocuments());
                        changedDocuments.AddRange(changedDocumentIds.Select(id => changedSolution.GetRequiredDocument(id)));
                    }
                }

                return changedDocuments.ToImmutableAndClear();
            }, cancellationToken));
        }

        // Wait for all that work to finish.
        await Task.WhenAll(tasks).ConfigureAwait(false);

        // Flatten the set of changed documents.  These will naturally still be ordered by the diagnostic that
        // caused the change.
        using var _4 = ArrayBuilder<Document>.GetInstance(out var result);
        foreach (var task in tasks)
            result.AddRange(await task.ConfigureAwait(false));

        return result.ToImmutableAndClear();
    }

    /// <summary>
    /// Take all the changes made to a particular document and determine the text changes caused by each one.  Take
    /// those individual text changes and attempt to merge them together in order into <paramref
    /// name="docIdToTextMerger"/>.
    /// </summary>
    /// <remarks>
    /// The dictionary itself is only read and written on the calling thread.  The per-document merges run
    /// concurrently, one work item per document.
    /// </remarks>
    private static async Task MergeTextChangesAsync(
        FixAllContext fixAllContext,
        ImmutableArray<Document> allChangedDocumentsInDiagnosticsOrder,
        Dictionary<DocumentId, TextChangeMerger> docIdToTextMerger)
    {
        var cancellationToken = fixAllContext.CancellationToken;

        // Now for each document that is changed, grab all the documents it was changed to (remember, many code
        // actions might have touched that document).  Figure out the actual change, and then add that to the
        // interval tree of changes we're keeping track of for that document.
        using var _ = ArrayBuilder<Task>.GetInstance(out var tasks);
        foreach (var group in allChangedDocumentsInDiagnosticsOrder.GroupBy(d => d.Id))
        {
            var docId = group.Key;
            var allDocChanges = group.ToImmutableArray();

            // If we don't have a text merger for this doc yet, create one to keep track of all the changes.  It is
            // seeded with the document from the original (unfixed) solution.
            if (!docIdToTextMerger.TryGetValue(docId, out var textMerger))
            {
                var originalDocument = fixAllContext.Solution.GetRequiredDocument(docId);
                textMerger = new TextChangeMerger(originalDocument);
                docIdToTextMerger.Add(docId, textMerger);
            }

            // Process all document groups in parallel.  For each group, merge all the doc changes into an
            // aggregated set of changes in the TextChangeMerger type.
            tasks.Add(Task.Run(
                async () => await textMerger.TryMergeChangesAsync(allDocChanges, cancellationToken).ConfigureAwait(false), cancellationToken));
        }

        await Task.WhenAll(tasks).ConfigureAwait(false);
    }

    /// <summary>
    /// Builds the registration callback for a <see cref="CodeFixContext"/>.  The callback walks each reported
    /// action and all of its nested actions using an explicit stack.  It adds every action (parent or nested)
    /// whose equivalence key equals <paramref name="codeActionEquivalenceKey"/> to <paramref name="codeActions"/>.
    /// Nested actions are still visited when their parent matches.  The reported diagnostics are ignored.
    /// </summary>
    private static Action<CodeAction, ImmutableArray<Diagnostic>> GetRegisterCodeFixAction(
        string? codeActionEquivalenceKey, ArrayBuilder<CodeAction> codeActions)
    {
        return (action, diagnostics) =>
        {
            using var _ = ArrayBuilder<CodeAction>.GetInstance(out var builder);
            builder.Push(action);
            while (builder.TryPop(out var currentAction))
            {
                // The property pattern also serves as a null check on currentAction.
                if (currentAction is { EquivalenceKey: var equivalenceKey }
                    && codeActionEquivalenceKey == equivalenceKey)
                {
                    // Guard the shared builder in case a provider registers fixes from more than one thread.
                    lock (codeActions)
                        codeActions.Add(currentAction);
                }

                foreach (var nestedAction in currentAction.NestedActions)
                    builder.Push(nestedAction);
            }
        };
    }

    /// <summary>
    /// Computes the final merged text for each document in <paramref name="docIdsAndMerger"/>.  Returns
    /// <paramref name="currentSolution"/> with all of those texts applied in a single WithDocumentTexts call.
    /// </summary>
    private static async Task<Solution> ApplyChangesAsync(
        Solution currentSolution,
        ImmutableArray<(DocumentId documentId, TextChangeMerger merger)> docIdsAndMerger,
        CancellationToken cancellationToken)
    {
        var docIdsAndTexts = await docIdsAndMerger.SelectAsArrayAsync(async t => (t.documentId, await t.merger.GetFinalMergedTextAsync(cancellationToken).ConfigureAwait(false))).ConfigureAwait(false);
        return currentSolution.WithDocumentTexts(docIdsAndTexts);
    }
}