File: test\Generators\Shared\RoslynTestUtils.cs
Web Access
Project: src\test\Generators\Microsoft.Gen.ComplianceReports\Unit\Microsoft.Gen.ComplianceReports.Unit.Tests.csproj (Microsoft.Gen.ComplianceReports.Unit.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Text;
using Microsoft.Shared.Collections;
using Xunit;
 
#pragma warning disable CA1716
namespace Microsoft.Gen.Shared;
#pragma warning restore CA1716
 
internal static class RoslynTestUtils
{
    internal const string RoslynVersion = "4.0";

#if DEBUG
    internal const string BuildType = "Debug";
#else
    internal const string BuildType = "Release";
#endif

    /// <summary>
    /// Maximum number of analyze/fix rounds performed by <see cref="RunAnalyzerAndFixer"/> before it fails,
    /// so that a fixer which never converges fails the test instead of hanging it.
    /// </summary>
    private const int MaxFixIterations = 1000;

    /// <summary>
    /// Creates a canonical Roslyn project for testing.
    /// </summary>
    /// <param name="references">Assembly references to include in the project.</param>
    /// <param name="includeBaseReferences">Whether to include references to the BCL assemblies.</param>
    /// <returns>An empty C# project parsed with <see cref="LanguageVersion.Preview"/> and no preprocessor symbols.</returns>
    public static Project CreateTestProject(IEnumerable<Assembly>? references, bool includeBaseReferences = true)
    {
        return CreateTestProject(references, Empty.Enumerable<string>(), includeBaseReferences);
    }

    /// <summary>
    /// Creates a canonical Roslyn project for testing.
    /// </summary>
    /// <param name="references">Assembly references to include in the project.</param>
    /// <param name="preprocessorSymbols">Preprocessor symbols to run compilation with.</param>
    /// <param name="includeBaseReferences">Whether to include references to the BCL assemblies.</param>
    /// <param name="langVersion">C# language version used to parse the project's documents.</param>
    /// <returns>
    /// An empty C# project named "Test" (output "test.dll"), compiled as a DLL with nullable reference types enabled,
    /// hosted in a fresh <see cref="AdhocWorkspace"/>.
    /// </returns>
    public static Project CreateTestProject(
        IEnumerable<Assembly>? references,
        IEnumerable<string> preprocessorSymbols,
        bool includeBaseReferences = true,
        LanguageVersion langVersion = LanguageVersion.Preview)
    {
        var corelib = Assembly.GetAssembly(typeof(object))!.Location;
        var runtimeDir = Path.GetDirectoryName(corelib)!;

        var refs = new List<MetadataReference>();
        if (includeBaseReferences)
        {
            // The core library plus netstandard.dll and System.Runtime.dll taken from the same runtime directory.
            refs.Add(MetadataReference.CreateFromFile(corelib));
            refs.Add(MetadataReference.CreateFromFile(Path.Combine(runtimeDir, "netstandard.dll")));
            refs.Add(MetadataReference.CreateFromFile(Path.Combine(runtimeDir, "System.Runtime.dll")));
        }

        if (references != null)
        {
            foreach (var r in references)
            {
                refs.Add(MetadataReference.CreateFromFile(r.Location));
            }
        }

#pragma warning disable CA2000 // Dispose objects before losing scope
        return new AdhocWorkspace()
                    .AddSolution(SolutionInfo.Create(SolutionId.CreateNewId(), VersionStamp.Create()))
                    .AddProject("Test", "test.dll", "C#")
                        .WithMetadataReferences(refs)
                        .WithParseOptions(new CSharpParseOptions(langVersion).WithPreprocessorSymbols(preprocessorSymbols))
                        .WithCompilationOptions(
                            new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary).WithNullableContextOptions(NullableContextOptions.Enable));
#pragma warning restore CA2000 // Dispose objects before losing scope
    }

    /// <summary>
    /// Applies the project's solution to its workspace and asserts that the workspace accepted the changes.
    /// </summary>
    /// <param name="proj">The project whose solution should be applied.</param>
    public static void CommitChanges(this Project proj)
    {
        Assert.True(proj.Solution.Workspace.TryApplyChanges(proj.Solution));
    }

    /// <summary>
    /// Adds a document with the given name and text and returns the resulting project.
    /// </summary>
    /// <param name="proj">The project to add the document to.</param>
    /// <param name="name">The document name.</param>
    /// <param name="text">The document's source text.</param>
    /// <returns>The project that contains the new document.</returns>
    public static Project WithDocument(this Project proj, string name, string text)
    {
        return proj.AddDocument(name, text).Project;
    }

    /// <summary>
    /// Returns the first document in the project whose name equals <paramref name="name"/>.
    /// </summary>
    /// <param name="proj">The project to search.</param>
    /// <param name="name">The document name to look for.</param>
    /// <returns>The matching document.</returns>
    /// <exception cref="FileNotFoundException">No document with that name exists in the project.</exception>
    public static Document FindDocument(this Project proj, string name)
    {
        return proj.Documents.FirstOrDefault(doc => doc.Name == name) ?? throw new FileNotFoundException(name);
    }

    /// <summary>
    /// Looks for /*N+*/ and /*-N*/ markers in a string and creates a TextSpan containing the enclosed text.
    /// </summary>
    /// <param name="text">The text containing the markers.</param>
    /// <param name="spanNum">The marker number N.</param>
    /// <returns>
    /// The span between the end of the first /*N+*/ marker and the next /*-N*/ marker after it,
    /// or <see langword="null"/> if either marker is missing.
    /// </returns>
    public static TextSpan? MakeTextSpan(this string text, int spanNum)
    {
        var startMarker = $"/*{spanNum}+*/";
        int start = text.IndexOf(startMarker, StringComparison.Ordinal);
        if (start < 0)
        {
            return null;
        }

        start += startMarker.Length;

        // Only look for the closing marker after the opening one; a stray closing marker earlier
        // in the text would otherwise yield a negative length and make TextSpan throw.
        int end = text.IndexOf($"/*-{spanNum}*/", start, StringComparison.Ordinal);
        if (end < 0)
        {
            return null;
        }

        return new TextSpan(start, end - start);
    }

    /// <summary>
    /// Asserts that <paramref name="actual"/> has the id of <paramref name="expected"/> and is located exactly
    /// at the span delimited by the /*N+*/ and /*-N*/ markers for <paramref name="spanNum"/>.
    /// </summary>
    /// <param name="text">The source text containing the span markers.</param>
    /// <param name="spanNum">The marker number N.</param>
    /// <param name="expected">The descriptor whose id is expected.</param>
    /// <param name="actual">The diagnostic to check.</param>
    public static void AssertDiagnostic(this string text, int spanNum, DiagnosticDescriptor expected, Diagnostic actual)
    {
        var expectedSpan = text.MakeTextSpan(spanNum);
        if (expectedSpan != null)
        {
            Assert.True(expected.Id == actual.Id,
                $"Span {spanNum} doesn't match: expected {expected.Id} but got {actual}");

            Assert.True(expectedSpan.Equals(actual.Location.SourceSpan),
                $"Span {spanNum} doesn't match: expected {expectedSpan} but got {actual.Location.SourceSpan}");
        }
        else
        {
            Assert.Fail($"Unexpected diagnostics {actual}");
        }
    }

    /// <summary>
    /// Asserts that every diagnostic in <paramref name="actual"/> has the id of <paramref name="expected"/> and that
    /// located diagnostics match the marker spans 0, 1, 2, ... in order. Diagnostics with <see cref="Location.None"/>
    /// must have an empty source span and do not consume a marker number. Fails if a marker span is left unmatched.
    /// </summary>
    /// <param name="text">The source text containing the span markers.</param>
    /// <param name="expected">The descriptor whose id every diagnostic must have.</param>
    /// <param name="actual">The diagnostics to check, in the order they should match the markers.</param>
    public static void AssertDiagnostics(this string text, DiagnosticDescriptor expected, IEnumerable<Diagnostic> actual)
    {
        int spanNum = 0;
        foreach (var d in actual)
        {
            bool hasLocation = d.Location != Location.None;

            TextSpan? expectedSpan = Location.None.SourceSpan;
            if (hasLocation)
            {
                expectedSpan = text.MakeTextSpan(spanNum);
                if (expectedSpan == null)
                {
                    Assert.Fail($"No span detected for diagnostic #{spanNum}, {d}");
                }
            }

            Assert.True(expected.Id == d.Id,
                $"Span {spanNum} doesn't match: expected {expected.Id} but got {d}");

            Assert.True(expectedSpan.Equals(d.Location.SourceSpan),
                $"Span {spanNum} doesn't match: expected {expectedSpan} but got {d.Location.SourceSpan}");

            // Only located diagnostics consume a marker number.
            if (hasLocation)
            {
                spanNum++;
            }
        }

        if (text.MakeTextSpan(spanNum) != null)
        {
            Assert.Fail($"Diagnostic {spanNum} was not detected");
        }
    }

    /// <summary>
    /// Keeps only the diagnostics whose id matches (ordinally) one of the given descriptors.
    /// </summary>
    /// <param name="diagnostics">The diagnostics to filter.</param>
    /// <param name="filter">The descriptors whose ids should be kept.</param>
    /// <returns>The matching diagnostics, in their original order.</returns>
    public static IReadOnlyList<Diagnostic> FilterDiagnostics(this IEnumerable<Diagnostic> diagnostics, params DiagnosticDescriptor[] filter)
    {
        var ids = new HashSet<string>(filter.Select(f => f.Id), StringComparer.Ordinal);
        return diagnostics.Where(d => ids.Contains(d.Id)).ToList();
    }

    /// <summary>
    /// Removes the diagnostics whose id matches (ordinally) one of the given descriptors.
    /// </summary>
    /// <param name="diagnostics">The diagnostics to filter.</param>
    /// <param name="filter">The descriptors whose ids should be removed.</param>
    /// <returns>The remaining diagnostics, in their original order.</returns>
    public static IReadOnlyList<Diagnostic> FilterOutDiagnostics(this IEnumerable<Diagnostic> diagnostics, params DiagnosticDescriptor[] filter)
    {
        var ids = new HashSet<string>(filter.Select(f => f.Id), StringComparer.Ordinal);
        return diagnostics.Where(d => !ids.Contains(d.Id)).ToList();
    }

    /// <summary>
    /// Runs a Roslyn generator given a Compilation.
    /// </summary>
    /// <param name="compilation">The compilation to run the generator over.</param>
    /// <param name="generator">The incremental generator to run.</param>
    /// <param name="cancellationToken">Token used to cancel the generator run.</param>
    /// <returns>The generator-reported diagnostics, sorted by source position, and the generated sources.</returns>
    public static (IReadOnlyList<Diagnostic> diagnostics, ImmutableArray<GeneratedSourceResult> generatedSources) RunGenerator(
        Compilation compilation,
        IIncrementalGenerator generator,
        CancellationToken cancellationToken = default)
    {
        CSharpGeneratorDriver cgd = CSharpGeneratorDriver.Create(new[] { generator });
        GeneratorDriver gd = cgd.RunGenerators(compilation, cancellationToken);

        GeneratorDriverRunResult r = gd.GetRunResult();

        // Sorted for consistency with the other RunGenerator overloads.
        return (Sort(r.Results[0].Diagnostics), r.Results[0].GeneratedSources);
    }

    /// <summary>
    /// Runs a Roslyn generator over a set of source files.
    /// </summary>
    /// <param name="generator">The source generator to run.</param>
    /// <param name="references">Assembly references to include in the project.</param>
    /// <param name="sources">Source texts, added as documents named src-0.cs, src-1.cs, ...</param>
    /// <param name="optionsProvider">Optional analyzer config options passed to the generator driver.</param>
    /// <param name="includeBaseReferences">Whether to include references to the BCL assemblies.</param>
    /// <param name="cancellationToken">Token used to cancel compilation and the generator run.</param>
    /// <returns>The generator-reported diagnostics, sorted by source position, and the generated sources.</returns>
    public static Task<(IReadOnlyList<Diagnostic> diagnostics, ImmutableArray<GeneratedSourceResult> generatedSources)> RunGenerator(
        ISourceGenerator generator,
        IEnumerable<Assembly>? references,
        IEnumerable<string> sources,
        AnalyzerConfigOptionsProvider? optionsProvider = null,
        bool includeBaseReferences = true,
        CancellationToken cancellationToken = default)
    {
        return RunGenerator(generator, references, sources, Empty.Enumerable<string>(), optionsProvider, includeBaseReferences, cancellationToken);
    }

    /// <summary>
    /// Runs a Roslyn generator over a set of source files.
    /// </summary>
    /// <param name="generator">The source generator to run.</param>
    /// <param name="references">Assembly references to include in the project.</param>
    /// <param name="sources">Source texts, added as documents named src-0.cs, src-1.cs, ...</param>
    /// <param name="preprocessorSymbols">Preprocessor symbols to run compilation with.</param>
    /// <param name="optionsProvider">Optional analyzer config options passed to the generator driver.</param>
    /// <param name="includeBaseReferences">Whether to include references to the BCL assemblies.</param>
    /// <param name="cancellationToken">Token used to cancel compilation and the generator run.</param>
    /// <returns>The generator-reported diagnostics, sorted by source position, and the generated sources.</returns>
    public static async Task<(IReadOnlyList<Diagnostic> diagnostics, ImmutableArray<GeneratedSourceResult> generatedSources)> RunGenerator(
        ISourceGenerator generator,
        IEnumerable<Assembly>? references,
        IEnumerable<string> sources,
        IEnumerable<string> preprocessorSymbols,
        AnalyzerConfigOptionsProvider? optionsProvider = null,
        bool includeBaseReferences = true,
        CancellationToken cancellationToken = default)
    {
        var proj = AddSources(CreateTestProject(references, preprocessorSymbols, includeBaseReferences), sources);
        proj.CommitChanges();

        var comp = await proj.GetCompilationAsync(cancellationToken).ConfigureAwait(false);

        var cgd = CSharpGeneratorDriver.Create(new[] { generator }, optionsProvider: optionsProvider);
        var gd = cgd.RunGenerators(comp!, cancellationToken);

        var r = gd.GetRunResult();
        return (Sort(r.Results[0].Diagnostics), r.Results[0].GeneratedSources);
    }

    /// <summary>
    /// Runs a Roslyn generator over a set of source files.
    /// </summary>
    /// <param name="generator">The incremental generator to run.</param>
    /// <param name="references">Assembly references to include in the project.</param>
    /// <param name="sources">Source texts, added as documents named src-0.cs, src-1.cs, ...</param>
    /// <param name="includeBaseReferences">Whether to include references to the BCL assemblies.</param>
    /// <param name="cancellationToken">Token used to cancel compilation and the generator run.</param>
    /// <returns>The generator-reported diagnostics, sorted by source position, and the generated sources.</returns>
    public static Task<(IReadOnlyList<Diagnostic> diagnostics, ImmutableArray<GeneratedSourceResult> generatedSources)> RunGenerator(
        IIncrementalGenerator generator,
        IEnumerable<Assembly>? references,
        IEnumerable<string> sources,
        bool includeBaseReferences = true,
        CancellationToken cancellationToken = default)
    {
        return RunGenerator(generator, references, sources, preprocessorSymbols: Empty.Enumerable<string>(), includeBaseReferences: includeBaseReferences, cancellationToken: cancellationToken);
    }

    /// <summary>
    /// Runs a Roslyn generator over a set of source files.
    /// </summary>
    /// <param name="generator">The incremental generator to run.</param>
    /// <param name="references">Assembly references to include in the project.</param>
    /// <param name="sources">Source texts, added as documents named src-0.cs, src-1.cs, ...</param>
    /// <param name="preprocessorSymbols">Preprocessor symbols to run compilation with.</param>
    /// <param name="includeBaseReferences">Whether to include references to the BCL assemblies.</param>
    /// <param name="langVersion">C# language version used to parse the sources.</param>
    /// <param name="cancellationToken">Token used to cancel compilation and the generator run.</param>
    /// <returns>The generator-reported diagnostics, sorted by source position, and the generated sources.</returns>
    public static async Task<(IReadOnlyList<Diagnostic> diagnostics, ImmutableArray<GeneratedSourceResult> generatedSources)> RunGenerator(
        IIncrementalGenerator generator,
        IEnumerable<Assembly>? references,
        IEnumerable<string> sources,
        IEnumerable<string> preprocessorSymbols,
        bool includeBaseReferences = true,
        LanguageVersion langVersion = LanguageVersion.Preview,
        CancellationToken cancellationToken = default)
    {
        var proj = AddSources(CreateTestProject(references, preprocessorSymbols, includeBaseReferences, langVersion), sources);
        proj.CommitChanges();

        Compilation? comp = await proj.GetCompilationAsync(cancellationToken).ConfigureAwait(false);

        CSharpGeneratorDriver cgd = CSharpGeneratorDriver.Create(generator);
        GeneratorDriver gd = cgd.RunGenerators(comp!, cancellationToken);

        GeneratorDriverRunResult r = gd.GetRunResult();
        return (Sort(r.Results[0].Diagnostics), r.Results[0].GeneratedSources);
    }

    /// <summary>
    /// A no-op source generator whose only purpose is to feed syntax notifications to a caller-supplied receiver.
    /// </summary>
    [Generator]
#pragma warning disable RS1036 // Specify analyzer banned API enforcement setting. Testing code.
    private sealed class Generator : ISourceGenerator
#pragma warning restore RS1036 // Specify analyzer banned API enforcement setting
    {
        private readonly ISyntaxReceiver _receiver;

        public Generator(ISyntaxReceiver receiver)
        {
            _receiver = receiver;
        }

        public void Execute(GeneratorExecutionContext context)
        {
            // Method intentionally left empty.
        }

        public void Initialize(GeneratorInitializationContext context) =>
            context.RegisterForSyntaxNotifications(() => _receiver);
    }

    /// <summary>
    /// Compiles a set of source files and runs a no-op generator over them so that <paramref name="receiver"/>
    /// is fed the compilation's syntax nodes.
    /// </summary>
    /// <param name="receiver">The syntax receiver to populate.</param>
    /// <param name="references">Assembly references to include in the project.</param>
    /// <param name="sources">Source texts, added as documents named src-0.cs, src-1.cs, ...</param>
    /// <param name="includeBaseReferences">Whether to include references to the BCL assemblies.</param>
    /// <returns>The compilation the receiver was run over.</returns>
    public static async Task<Compilation> RunSyntaxContextReceiver(
        ISyntaxReceiver receiver,
        IEnumerable<Assembly>? references,
        IEnumerable<string> sources,
        bool includeBaseReferences = true)
    {
        var proj = AddSources(CreateTestProject(references, includeBaseReferences), sources);
        proj.CommitChanges();

        var comp = await proj.GetCompilationAsync().ConfigureAwait(false);
        _ = CSharpGeneratorDriver.Create(new Generator(receiver)).RunGenerators(comp!);
        return comp!;
    }

    /// <summary>
    /// Populates <paramref name="receiver"/> from the given sources, then invokes <paramref name="parser"/>
    /// with the receiver and the resulting compilation.
    /// </summary>
    /// <param name="receiver">The syntax receiver to populate.</param>
    /// <param name="parser">Function that turns the populated receiver and compilation into a result.</param>
    /// <param name="references">Assembly references to include in the project.</param>
    /// <param name="sources">Source texts, added as documents named src-0.cs, src-1.cs, ...</param>
    /// <param name="includeBaseReferences">Whether to include references to the BCL assemblies.</param>
    /// <returns>Whatever <paramref name="parser"/> returns.</returns>
    public static async Task<TParserOutput?> RunParser<TReceiver, TParserOutput>(
        TReceiver receiver,
        Func<TReceiver, Compilation, TParserOutput> parser,
        IEnumerable<Assembly>? references,
        IEnumerable<string> sources,
        bool includeBaseReferences = true)
        where TReceiver : ISyntaxReceiver
    {
        var comp = await RunSyntaxContextReceiver(receiver, references, sources, includeBaseReferences).ConfigureAwait(false);
        return parser(receiver, comp);
    }

    /// <summary>
    /// Runs a Roslyn analyzer over a set of source files.
    /// </summary>
    /// <param name="analyzer">The analyzer to run.</param>
    /// <param name="references">Assembly references to include in the project.</param>
    /// <param name="sources">Source texts, added as documents named src-0.cs, src-1.cs, ...</param>
    /// <returns>All compiler and analyzer diagnostics, sorted by source position.</returns>
    public static async Task<IReadOnlyList<Diagnostic>> RunAnalyzer(
        DiagnosticAnalyzer analyzer,
        IEnumerable<Assembly>? references,
        IEnumerable<string> sources)
    {
        var proj = AddSources(CreateTestProject(references), sources);
        proj.CommitChanges();

        var analyzers = ImmutableArray.Create(analyzer);

        var comp = await proj.GetCompilationAsync().ConfigureAwait(false);
        var diags = await comp!.WithAnalyzers(analyzers).GetAllDiagnosticsAsync().ConfigureAwait(false);
        return Sort(diags);
    }

    /// <summary>
    /// Orders diagnostics by the start position of their source span. The ordering is stable, so diagnostics
    /// that start at the same position keep the order in which they were reported.
    /// </summary>
    private static IReadOnlyList<Diagnostic> Sort(ImmutableArray<Diagnostic> diags)
    {
        return diags.OrderBy(d => d.Location.SourceSpan.Start).ToImmutableArray();
    }

    /// <summary>
    /// Adds each source as a document named <c>src-{index}.cs</c>, in enumeration order.
    /// </summary>
    private static Project AddSources(Project proj, IEnumerable<string> sources)
    {
        var count = 0;
        foreach (var s in sources)
        {
            proj = proj.WithDocument($"src-{count++}.cs", s);
        }

        return proj;
    }

    /// <summary>
    /// Converts CRLF line endings in the given text to LF.
    /// </summary>
    private static string NormalizeLineEndings(SourceText text)
    {
        return Regex.Replace(text.ToString(), "\r\n", "\n");
    }

    /// <summary>
    /// Runs a Roslyn analyzer and fixer.
    /// </summary>
    /// <remarks>
    /// Repeatedly analyzes the project and applies the first registered code fix until no diagnostics remain or
    /// the fixer registers no fix. Only diagnostics whose id is in <see cref="CodeFixProvider.FixableDiagnosticIds"/>
    /// and that are located in a document of the project are offered to the fixer. The run fails if a fix does not
    /// change the project, or if fixing does not converge within a bounded number of rounds.
    /// </remarks>
    /// <param name="analyzer">The analyzer producing the diagnostics to fix.</param>
    /// <param name="fixer">The code fix provider to apply.</param>
    /// <param name="references">Assembly references to include in the project.</param>
    /// <param name="sources">Source texts to add as documents.</param>
    /// <param name="sourceNames">Optional document names, one per source; defaults to src-0.cs, src-1.cs, ...</param>
    /// <param name="defaultNamespace">Optional default namespace for the project.</param>
    /// <param name="extraFile">Optional name of an additional document (e.g. one created by the fixer) to return.</param>
    /// <returns>The final text of each source document (plus <paramref name="extraFile"/> if given), with LF line endings.</returns>
    public static async Task<IReadOnlyList<string>> RunAnalyzerAndFixer(
        DiagnosticAnalyzer analyzer,
        CodeFixProvider fixer,
        IEnumerable<Assembly>? references,
        IEnumerable<string> sources,
        IEnumerable<string>? sourceNames = null,
        string? defaultNamespace = null,
        string? extraFile = null)
    {
        var proj = CreateTestProject(references);

        // Materialize the names once; they are used both to add the documents and to read them back.
        List<string>? explicitNames = sourceNames?.ToList();
        var names = new List<string>();
        foreach (var s in sources)
        {
            var name = explicitNames != null ? explicitNames[names.Count] : $"src-{names.Count}.cs";
            names.Add(name);
            proj = proj.WithDocument(name, s);
        }

        if (defaultNamespace != null)
        {
            proj = proj.WithDefaultNamespace(defaultNamespace);
        }

        proj.CommitChanges();

        var analyzers = ImmutableArray.Create(analyzer);
        var fixableIds = fixer.FixableDiagnosticIds;
        int iterations = 0;

        while (true)
        {
            var comp = await proj.GetCompilationAsync().ConfigureAwait(false);
            var diags = await comp!.WithAnalyzers(analyzers).GetAllDiagnosticsAsync().ConfigureAwait(false);
            if (diags.IsEmpty)
            {
                // no more diagnostics reported by the analyzers
                break;
            }

            var actions = new List<CodeAction>();
            foreach (var d in diags)
            {
                // Only offer the fixer diagnostics it declares it can fix and that live in one of our documents.
                if (!fixableIds.Contains(d.Id) || d.Location.SourceTree is not { } tree)
                {
                    continue;
                }

                var doc = proj.GetDocument(tree);
                if (doc == null)
                {
                    continue;
                }

                var context = new CodeFixContext(doc, d, (action, _) => actions.Add(action), CancellationToken.None);
                await fixer.RegisterCodeFixesAsync(context).ConfigureAwait(false);
            }

            if (actions.Count == 0)
            {
                // nothing to fix
                break;
            }

            if (++iterations > MaxFixIterations)
            {
                Assert.Fail($"Code fixes did not converge after {MaxFixIterations} iterations");
            }

            var operations = await actions[0].GetOperationsAsync(CancellationToken.None).ConfigureAwait(false);
            var solution = operations.OfType<ApplyChangesOperation>().Single().ChangedSolution;
            var changedProj = solution.GetProject(proj.Id);

            // An unchanged project would make the next round see the same diagnostics and loop forever.
            if (changedProj == null || ReferenceEquals(changedProj, proj))
            {
                Assert.Fail($"Code fix '{actions[0].Title}' did not change the project");
            }

            proj = await RecreateProjectDocumentsAsync(changedProj!).ConfigureAwait(false);
        }

        var results = new List<string>();
        foreach (var name in names)
        {
            var s = await proj.FindDocument(name).GetTextAsync().ConfigureAwait(false);
            results.Add(NormalizeLineEndings(s));
        }

        if (extraFile != null)
        {
            var s = await proj.FindDocument(extraFile).GetTextAsync().ConfigureAwait(false);
            results.Add(NormalizeLineEndings(s));
        }

        return results;
    }

    /// <summary>
    /// Replaces every document's text with a freshly created <see cref="SourceText"/> holding the same content,
    /// encoding and checksum algorithm.
    /// </summary>
    private static async Task<Project> RecreateProjectDocumentsAsync(Project project)
    {
        foreach (var documentId in project.DocumentIds)
        {
            var document = project.GetDocument(documentId);
            document = await RecreateDocumentAsync(document!).ConfigureAwait(false);
            project = document.Project;
        }

        return project;
    }

    /// <summary>
    /// Returns the document with its text rebuilt from the string form of its current text.
    /// </summary>
    private static async Task<Document> RecreateDocumentAsync(Document document)
    {
        var newText = await document.GetTextAsync().ConfigureAwait(false);
        return document.WithText(SourceText.From(newText.ToString(), newText.Encoding, newText.ChecksumAlgorithm));
    }
}