File: CodeRefactoringRunner.cs
Web Access
Project: src\src\Tools\AnalyzerRunner\AnalyzerRunner.csproj (AnalyzerRunner)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#nullable disable
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.Host.Mef;
using Microsoft.CodeAnalysis.Text;
using Microsoft.VisualStudio.Composition;
using static AnalyzerRunner.Program;
 
namespace AnalyzerRunner
{
    public sealed class CodeRefactoringRunner
    {
        private readonly Workspace _workspace;
        private readonly Options _options;
        private readonly ImmutableDictionary<string, ImmutableArray<CodeRefactoringProvider>> _refactorings;
        private readonly ImmutableDictionary<string, ImmutableHashSet<int>> _syntaxKinds;
 
        public CodeRefactoringRunner(Workspace workspace, Options options)
        {
            _workspace = workspace;
            _options = options;
 
            var refactorings = GetCodeRefactoringProviders(options.AnalyzerPath);
            _refactorings = FilterRefactorings(refactorings, options);
 
            _syntaxKinds = GetSyntaxKinds(options.RefactoringNodes);
        }
 
        public bool HasRefactorings => _refactorings.Any(pair => pair.Value.Any());
 
        public async Task RunAsync(CancellationToken cancellationToken)
        {
            if (!HasRefactorings)
            {
                return;
            }
 
            var solution = _workspace.CurrentSolution;
            var updatedSolution = solution;
 
            foreach (var project in solution.Projects)
            {
                foreach (var document in project.Documents)
                {
                    var newDocument = await RefactorDocumentAsync(document, cancellationToken).ConfigureAwait(false);
                    if (newDocument is null)
                    {
                        continue;
                    }
 
                    updatedSolution = updatedSolution.WithDocumentSyntaxRoot(document.Id, await newDocument.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false));
                }
            }
 
            if (_options.ApplyChanges)
            {
                _workspace.TryApplyChanges(updatedSolution);
            }
        }
 
        private async Task<Document> RefactorDocumentAsync(Document document, CancellationToken cancellationToken)
        {
            var syntaxKinds = _syntaxKinds[document.Project.Language];
            var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
            foreach (var node in root.DescendantNodesAndTokens(descendIntoTrivia: true))
            {
                if (!syntaxKinds.Contains(node.RawKind))
                {
                    continue;
                }
 
                foreach (var refactoringProvider in _refactorings[document.Project.Language])
                {
                    var codeActions = new List<CodeAction>();
                    var context = new CodeRefactoringContext(document, new TextSpan(node.SpanStart, 0), codeActions.Add, cancellationToken);
                    await refactoringProvider.ComputeRefactoringsAsync(context).ConfigureAwait(false);
 
                    foreach (var codeAction in codeActions)
                    {
                        var operations = await codeAction.GetOperationsAsync(
                            document.Project.Solution, CodeAnalysisProgress.None, cancellationToken).ConfigureAwait(false);
                        foreach (var operation in operations)
                        {
                            if (operation is not ApplyChangesOperation applyChangesOperation)
                            {
                                continue;
                            }
 
                            var changes = applyChangesOperation.ChangedSolution.GetChanges(document.Project.Solution);
                            var projectChanges = changes.GetProjectChanges().ToArray();
                            if (projectChanges.Length != 1 || projectChanges[0].ProjectId != document.Project.Id)
                            {
                                continue;
                            }
 
                            var documentChanges = projectChanges[0].GetChangedDocuments().ToArray();
                            if (documentChanges.Length != 1 || documentChanges[0] != document.Id)
                            {
                                continue;
                            }
 
                            return projectChanges[0].NewProject.GetDocument(document.Id);
                        }
                    }
                }
            }
 
            return null;
        }
 
        private static ImmutableDictionary<string, ImmutableHashSet<int>> GetSyntaxKinds(ImmutableHashSet<string> refactoringNodes)
        {
            var knownLanguages = new[]
            {
                (LanguageNames.CSharp, typeof(Microsoft.CodeAnalysis.CSharp.SyntaxKind)),
                (LanguageNames.VisualBasic, typeof(Microsoft.CodeAnalysis.VisualBasic.SyntaxKind)),
            };
 
            var builder = ImmutableDictionary.CreateBuilder<string, ImmutableHashSet<int>>();
            foreach (var (language, enumType) in knownLanguages)
            {
                var kindBuilder = ImmutableHashSet.CreateBuilder<int>();
                foreach (var name in refactoringNodes)
                {
                    if (!Enum.IsDefined(enumType, name))
                    {
                        continue;
                    }
 
                    kindBuilder.Add(Convert.ToInt32(Enum.Parse(enumType, name)));
                }
 
                builder.Add(language, kindBuilder.ToImmutable());
            }
 
            return builder.ToImmutable();
        }
 
        private static ImmutableDictionary<string, ImmutableArray<CodeRefactoringProvider>> FilterRefactorings(ImmutableDictionary<string, ImmutableArray<Lazy<CodeRefactoringProvider, CodeRefactoringProviderMetadata>>> refactorings, Options options)
        {
            return refactorings.ToImmutableDictionary(
                pair => pair.Key,
                pair => FilterRefactorings(pair.Value, options).ToImmutableArray());
        }
 
        private static IEnumerable<CodeRefactoringProvider> FilterRefactorings(IEnumerable<Lazy<CodeRefactoringProvider, CodeRefactoringProviderMetadata>> refactorings, Options options)
        {
            if (options.IncrementalAnalyzerNames.Any())
            {
                // AnalyzerRunner is running for IIncrementalAnalyzer testing. DiagnosticAnalyzer testing is disabled
                // unless /all or /a was used.
                if (!options.UseAll && options.AnalyzerNames.IsEmpty)
                {
                    yield break;
                }
            }
 
            if (options.RefactoringNodes.IsEmpty)
            {
                // AnalyzerRunner isn't configured to run refactorings on any nodes.
                yield break;
            }
 
            var refactoringTypes = new HashSet<Type>();
 
            foreach (var refactoring in refactorings.Select(refactoring => refactoring.Value))
            {
                if (!refactoringTypes.Add(refactoring.GetType()))
                {
                    // Avoid running the same analyzer multiple times
                    continue;
                }
 
                if (options.AnalyzerNames.Count == 0)
                {
                    yield return refactoring;
                }
                else if (options.AnalyzerNames.Contains(refactoring.GetType().Name))
                {
                    yield return refactoring;
                }
            }
        }
 
        private static ImmutableDictionary<string, ImmutableArray<Lazy<CodeRefactoringProvider, CodeRefactoringProviderMetadata>>> GetCodeRefactoringProviders(string path)
        {
            var assemblies = new List<Assembly>(MefHostServices.DefaultAssemblies);
            if (File.Exists(path))
            {
                assemblies.Add(Assembly.LoadFrom(path));
            }
            else if (Directory.Exists(path))
            {
                foreach (var file in Directory.GetFiles(path, "*.dll", SearchOption.AllDirectories))
                {
                    try
                    {
                        assemblies.Add(Assembly.LoadFrom(file));
                    }
                    catch
                    {
                        WriteLine($"Skipped assembly '{Path.GetFileNameWithoutExtension(file)}' during code refactoring discovery.", ConsoleColor.Yellow);
                    }
                }
            }
 
            var discovery = new AttributedPartDiscovery(Resolver.DefaultInstance, isNonPublicSupported: true);
            var parts = Task.Run(() => discovery.CreatePartsAsync(assemblies)).GetAwaiter().GetResult();
            var catalog = ComposableCatalog.Create(Resolver.DefaultInstance).AddParts(parts);
 
            var configuration = CompositionConfiguration.Create(catalog);
            var runtimeConfiguration = RuntimeComposition.CreateRuntimeComposition(configuration);
            var exportProviderFactory = runtimeConfiguration.CreateExportProviderFactory();
 
            var exportProvider = exportProviderFactory.CreateExportProvider();
            var refactorings = exportProvider.GetExports<CodeRefactoringProvider, CodeRefactoringProviderMetadata>();
            var languages = refactorings.SelectMany(refactoring => refactoring.Metadata.Languages).Distinct();
            return languages.ToImmutableDictionary(
                language => language,
                language => refactorings.Where(refactoring => refactoring.Metadata.Languages.Contains(language)).ToImmutableArray());
        }
 
        private class CodeRefactoringProviderMetadata
        {
            public IEnumerable<string> Languages { get; }
 
            public CodeRefactoringProviderMetadata(IDictionary<string, object> data)
            {
                data.TryGetValue(nameof(Languages), out var languages);
 
                Languages = languages switch
                {
                    IEnumerable<string> values => values,
                    string value => [value],
                    _ => Array.Empty<string>(),
                };
            }
        }
    }
}