File: SemanticSearch\AbstractSemanticSearchService.cs
Web Access
Project: src\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
#if NET6_0_OR_GREATER
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Composition;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.Loader;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.Classification;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Emit;
using Microsoft.CodeAnalysis.ErrorReporting;
using Microsoft.CodeAnalysis.FindSymbols;
using Microsoft.CodeAnalysis.FindUsages;
using Microsoft.CodeAnalysis.Host;
using Microsoft.CodeAnalysis.Host.Mef;
using Microsoft.CodeAnalysis.Internal.Log;
using Microsoft.CodeAnalysis.NavigateTo;
using Microsoft.CodeAnalysis.Notification;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Collections;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Shared.Utilities;
using Microsoft.CodeAnalysis.Tags;
using Microsoft.CodeAnalysis.Text;
using Microsoft.CodeAnalysis.Threading;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.SemanticSearch;
 
internal abstract partial class AbstractSemanticSearchService : ISemanticSearchService
{
    internal sealed class LoadContext() : AssemblyLoadContext("SemanticSearchLoadContext", isCollectible: true)
    {
        private readonly AssemblyLoadContext _current = GetLoadContext(typeof(LoadContext).Assembly)!;
 
        protected override Assembly? Load(AssemblyName assemblyName)
            => _current.LoadFromAssemblyName(assemblyName);
 
        protected override IntPtr LoadUnmanagedDll(string unmanagedDllName)
            => IntPtr.Zero;
    }
 
    private readonly struct CompiledQuery(MemoryStream peStream, MemoryStream pdbStream, SourceText text) : IDisposable
    {
        public MemoryStream PEStream { get; } = peStream;
        public MemoryStream PdbStream { get; } = pdbStream;
        public SourceText Text { get; } = text;
 
        public void Dispose()
        {
            PEStream.Dispose();
            PdbStream.Dispose();
        }
    }
 
    /// <summary>
    /// Mapping from the parameter type of the <c>Find</c> method to the <see cref="QueryKind"/> value.
    /// </summary>
    private static readonly ImmutableDictionary<Type, QueryKind> s_queryKindByParameterType = ImmutableDictionary<Type, QueryKind>.Empty
        .Add(typeof(Compilation), QueryKind.Compilation)
        .Add(typeof(INamespaceSymbol), QueryKind.Namespace)
        .Add(typeof(INamedTypeSymbol), QueryKind.NamedType)
        .Add(typeof(IMethodSymbol), QueryKind.Method)
        .Add(typeof(IFieldSymbol), QueryKind.Field)
        .Add(typeof(IPropertySymbol), QueryKind.Property)
        .Add(typeof(IEventSymbol), QueryKind.Event);
 
    private ImmutableDictionary<CompiledQueryId, CompiledQuery> _compiledQueries = ImmutableDictionary<CompiledQueryId, CompiledQuery>.Empty;
 
    protected abstract Compilation CreateCompilation(SourceText query, IEnumerable<MetadataReference> references, SolutionServices services, out SyntaxTree queryTree, CancellationToken cancellationToken);
 
    public CompileQueryResult CompileQuery(
        SolutionServices services,
        string query,
        string referenceAssembliesDir,
        TraceSource traceSource,
        CancellationToken cancellationToken)
    {
        var metadataService = services.GetRequiredService<IMetadataService>();
        var metadataReferences = SemanticSearchUtilities.GetMetadataReferences(metadataService, referenceAssembliesDir);
        var queryText = SemanticSearchUtilities.CreateSourceText(query);
        var queryCompilation = CreateCompilation(queryText, metadataReferences, services, out var queryTree, cancellationToken);
 
        cancellationToken.ThrowIfCancellationRequested();
 
        var emitOptions = new EmitOptions(
            debugInformationFormat: DebugInformationFormat.PortablePdb,
            instrumentationKinds: [InstrumentationKind.StackOverflowProbing, InstrumentationKind.ModuleCancellation]);
 
        var peStream = new MemoryStream();
        var pdbStream = new MemoryStream();
 
        var emitDifferenceTimer = SharedStopwatch.StartNew();
        var emitResult = queryCompilation.Emit(peStream, pdbStream, options: emitOptions, cancellationToken: cancellationToken);
        var emitTime = emitDifferenceTimer.Elapsed;
 
        CompiledQueryId queryId;
        ImmutableArray<QueryCompilationError> errors;
        if (emitResult.Success)
        {
            queryId = CompiledQueryId.Create(queryCompilation.Language);
            Contract.ThrowIfFalse(ImmutableInterlocked.TryAdd(ref _compiledQueries, queryId, new CompiledQuery(peStream, pdbStream, queryText)));
 
            errors = [];
        }
        else
        {
            queryId = default;
 
            foreach (var diagnostic in emitResult.Diagnostics)
            {
                if (diagnostic.Severity == DiagnosticSeverity.Error)
                {
                    traceSource.TraceInformation($"Semantic search query compilation failed: {diagnostic}");
                }
            }
 
            errors = emitResult.Diagnostics.SelectAsArray(
                d => d.Severity == DiagnosticSeverity.Error,
                d => new QueryCompilationError(d.Id, d.GetMessage(), (d.Location.SourceTree == queryTree) ? d.Location.SourceSpan : default));
        }
 
        return new CompileQueryResult(queryId, errors, emitTime);
    }
 
    public void DiscardQuery(CompiledQueryId queryId)
    {
        Contract.ThrowIfFalse(ImmutableInterlocked.TryRemove(ref _compiledQueries, queryId, out var compiledQuery));
        compiledQuery.Dispose();
    }
 
    public async Task<ExecuteQueryResult> ExecuteQueryAsync(
        Solution solution,
        CompiledQueryId queryId,
        ISemanticSearchResultsObserver observer,
        OptionsProvider<ClassificationOptions> classificationOptions,
        TraceSource traceSource,
        CancellationToken cancellationToken)
    {
        Contract.ThrowIfFalse(ImmutableInterlocked.TryRemove(ref _compiledQueries, queryId, out var query));
 
        try
        {
            var executionTime = TimeSpan.Zero;
 
            var remainingProgressItemCount = solution.ProjectIds.Count;
            await observer.AddItemsAsync(remainingProgressItemCount, cancellationToken).ConfigureAwait(false);
 
            query.PEStream.Position = 0;
            query.PdbStream.Position = 0;
            var loadContext = new LoadContext();
            try
            {
                var queryAssembly = loadContext.LoadFromStream(query.PEStream, query.PdbStream);
                SetModuleCancellationToken(queryAssembly, cancellationToken);
 
                SetToolImplementations(
                    queryAssembly,
                    new ReferencingSyntaxFinder(solution, cancellationToken),
                    new SemanticModelGetter(solution, cancellationToken));
 
                if (!TryGetFindMethod(queryAssembly, out var findMethod, out var queryKind, out var errorMessage, out var errorMessageArgs))
                {
                    traceSource.TraceInformation($"Semantic search failed: {errorMessage}");
                    return CreateResult(errorMessage, errorMessageArgs);
                }
 
                var invocationContext = new QueryExecutionContext(query.Text, findMethod, observer, classificationOptions, traceSource);
                try
                {
                    await invocationContext.InvokeAsync(solution, queryKind, cancellationToken).ConfigureAwait(false);
 
                    if (invocationContext.TerminatedWithException)
                    {
                        return CreateResult(FeaturesResources.Semantic_search_query_terminated_with_exception);
                    }
                }
                finally
                {
                    executionTime = new TimeSpan(invocationContext.ExecutionTime);
                    remainingProgressItemCount -= invocationContext.ProcessedProjectCount;
                }
            }
            finally
            {
                loadContext.Unload();
 
                // complete the remaining items (in case the search gets interrupted)
                if (remainingProgressItemCount > 0)
                {
                    await observer.ItemsCompletedAsync(remainingProgressItemCount, cancellationToken).ConfigureAwait(false);
                }
            }
 
            return CreateResult(errorMessage: null);
 
            ExecuteQueryResult CreateResult(string? errorMessage, params string[]? args)
                => new(errorMessage, args, executionTime);
        }
        catch (Exception e) when (FatalError.ReportAndPropagateUnlessCanceled(e, cancellationToken, ErrorSeverity.Critical))
        {
            throw ExceptionUtilities.Unreachable();
        }
        finally
        {
            query.Dispose();
        }
    }
 
    private static void SetModuleCancellationToken(Assembly queryAssembly, CancellationToken cancellationToken)
    {
        var pidType = queryAssembly.GetType("<PrivateImplementationDetails>", throwOnError: true);
        Contract.ThrowIfNull(pidType);
        var moduleCancellationTokenField = pidType.GetField("ModuleCancellationToken", BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Static);
        Contract.ThrowIfNull(moduleCancellationTokenField);
        moduleCancellationTokenField.SetValue(null, cancellationToken);
    }
 
    private static void SetToolImplementations(Assembly queryAssembly, ReferencingSyntaxFinder finder, SemanticModelGetter semanticModelGetter)
    {
        var toolsType = queryAssembly.GetType(SemanticSearchUtilities.ToolsTypeName, throwOnError: true);
        Contract.ThrowIfNull(toolsType);
 
        SetFieldValue(SemanticSearchUtilities.FindReferencingSyntaxNodesImplName, new Func<ISymbol, IEnumerable<SyntaxNode>>(finder.Find));
        SetFieldValue(SemanticSearchUtilities.GetSemanticModelImplName, new Func<SyntaxTree, Task<SemanticModel>>(semanticModelGetter.GetSemanticModelAsync));
 
        void SetFieldValue(string fieldName, object value)
        {
            var field = toolsType.GetField(fieldName, BindingFlags.NonPublic | BindingFlags.Static);
            Contract.ThrowIfNull(field);
            field.SetValue(null, value);
        }
    }
 
    private static bool TryGetFindMethod(
        Assembly queryAssembly,
        [NotNullWhen(true)] out MethodInfo? method,
        out QueryKind queryKind,
        [NotNullWhen(false)] out string? error,
        out string[]? errorMessageArgs)
    {
        method = null;
        errorMessageArgs = null;
        queryKind = default;
 
        Type? program;
        try
        {
            program = queryAssembly.GetType(WellKnownMemberNames.TopLevelStatementsEntryPointTypeName, throwOnError: true);
        }
        catch (Exception e)
        {
            error = FeaturesResources.Unable_to_load_type_0_1;
            errorMessageArgs = [WellKnownMemberNames.TopLevelStatementsEntryPointTypeName, e.Message];
            return false;
        }
 
        Contract.ThrowIfNull(program);
 
        using var _ = ArrayBuilder<MethodInfo>.GetInstance(out var candidates);
        try
        {
            foreach (var candidate in program.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance))
            {
                if (candidate.Name.StartsWith($"<{WellKnownMemberNames.TopLevelStatementsEntryPointMethodName}>g__{SemanticSearchUtilities.FindMethodName}|"))
                {
                    candidates.Add(candidate);
                }
            }
 
            if (candidates is [])
            {
                error = string.Format(FeaturesResources.The_query_does_not_specify_0_top_level_function, SemanticSearchUtilities.FindMethodName);
                return false;
            }
 
            candidates.RemoveAll(candidate => candidate.IsGenericMethod || !candidate.IsStatic);
            if (candidates is [])
            {
                error = string.Format(FeaturesResources.Method_0_must_be_static_and_non_generic, SemanticSearchUtilities.FindMethodName);
                return false;
            }
 
            if (candidates.Count > 1)
            {
                error = string.Format(FeaturesResources.The_query_specifies_multiple_top_level_functions_1, SemanticSearchUtilities.FindMethodName);
                return false;
            }
 
            method = candidates[0];
 
            if (method.GetParameters() is not [var parameter])
            {
                error = string.Format(FeaturesResources.The_query_specifies_multiple_top_level_functions_1, SemanticSearchUtilities.FindMethodName);
                return false;
            }
 
            if (!s_queryKindByParameterType.TryGetValue(parameter.ParameterType, out queryKind))
            {
                error = string.Format(
                    FeaturesResources.Parameter_type_0_is_not_among_supported_types_1,
                    parameter.ParameterType,
                    string.Join(", ", s_queryKindByParameterType.Keys.Select(t => $"'{t.Name}'")));
 
                return false;
            }
 
            if (method.ReturnType != typeof(IEnumerable<ISymbol>) &&
                method.ReturnType != typeof(IAsyncEnumerable<ISymbol>))
            {
                error = string.Format(
                    FeaturesResources.Return_type_0_is_not_among_supported_types_1,
                    method.ReturnType,
                    "'IEnumerable<ISymbol>', 'IAsyncEnumerable<ISymbol>'");
 
                return false;
            }
        }
        catch (Exception e)
        {
            error = e.Message;
            return false;
        }
 
        error = null;
        return true;
    }
}
#endif