File: Platform\CoreClr\TestExecutionLoadContext.cs
Web Access
Project: src\src\Compilers\Test\Core\Microsoft.CodeAnalysis.Test.Utilities.csproj (Microsoft.CodeAnalysis.Test.Utilities)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#if NET
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Numerics;
using System.Reflection;
using System.Reflection.Metadata;
using System.Reflection.PortableExecutable;
using System.Runtime.Loader;
using System.Text;
using Roslyn.Test.Utilities;
using Roslyn.Utilities;
 
namespace Roslyn.Test.Utilities.CoreClr
{
    /// <summary>
    /// A collectible <see cref="AssemblyLoadContext"/> that loads test assemblies from in-memory
    /// images, runs their entry points with console output captured, and reads member signatures
    /// from their metadata.
    /// </summary>
    internal sealed class TestExecutionLoadContext : AssemblyLoadContext
    {
        /// <summary>
        /// Dependency modules keyed by their <c>FullName</c>, compared ordinally.
        /// </summary>
        private readonly ImmutableDictionary<string, ModuleData> _dependencies;

        /// <summary>
        /// Creates a collectible load context that can resolve the given modules by full name.
        /// </summary>
        /// <param name="dependencies">Modules that <see cref="Load"/> may resolve from their images.</param>
        public TestExecutionLoadContext(ImmutableArray<ModuleData> dependencies)
            : base(isCollectible: true)
        {
            _dependencies = dependencies.ToImmutableDictionary(d => d.FullName, StringComparer.Ordinal);
        }

        /// <summary>
        /// Resolves <paramref name="assemblyName"/> from the dependency images supplied at construction.
        /// </summary>
        /// <returns>
        /// The assembly loaded from the matching dependency image, or <see langword="null"/> when the
        /// name looks like a platform assembly or no dependency has that full name.
        /// </returns>
        protected override Assembly? Load(AssemblyName assemblyName)
        {
            Debug.Assert(assemblyName.Name is not null);

            // Platform assemblies are never served from the test images; returning null
            // leaves their resolution to the runtime's fallback.
            if (IsPlatformAssemblyName(assemblyName.Name))
            {
                return null;
            }

            return _dependencies.TryGetValue(assemblyName.FullName, out var moduleData)
                ? LoadImageAsAssembly(moduleData.Image)
                : null;
        }

        /// <summary>
        /// Returns <see langword="true"/> for simple names starting with "System." or "Microsoft.",
        /// or equal to "mscorlib", "System" or "netstandard" (all compared case-insensitively).
        /// </summary>
        private static bool IsPlatformAssemblyName(string simpleName)
        {
            const StringComparison IgnoreCase = StringComparison.OrdinalIgnoreCase;

            return simpleName.StartsWith("System.", IgnoreCase)
                || simpleName.StartsWith("Microsoft.", IgnoreCase)
                || string.Equals(simpleName, "mscorlib", IgnoreCase)
                || string.Equals(simpleName, "System", IgnoreCase)
                || string.Equals(simpleName, "netstandard", IgnoreCase);
        }

        /// <summary>
        /// Loads a PE image into this context from a private copy of its bytes.
        /// </summary>
        private Assembly LoadImageAsAssembly(ImmutableArray<byte> mainImage)
        {
            using var assemblyStream = new MemoryStream(mainImage.ToArray());
            return LoadFromStream(assemblyStream);
        }

        /// <summary>
        /// Loads <paramref name="mainModuleData"/> into this context and invokes its entry point
        /// while console output is captured.
        /// </summary>
        /// <param name="mainModuleData">Module whose image contains the entry point to run.</param>
        /// <param name="mainArgs">
        /// Arguments passed to an entry point that declares one parameter; a null value is
        /// replaced by an empty array.
        /// </param>
        /// <param name="maxOutputLength">Forwarded unchanged to <c>SharedConsole.CaptureOutput</c>.</param>
        /// <returns>
        /// The entry point's <see cref="int"/> return value (0 when it returns anything else),
        /// together with the captured standard output and standard error text.
        /// </returns>
        /// <exception cref="InvalidOperationException">The loaded assembly has no entry point.</exception>
        internal (int ExitCode, string Output, string ErrorOutput) Execute(ModuleData mainModuleData, string[] mainArgs, int? maxOutputLength)
        {
            var mainAssembly = LoadImageAsAssembly(mainModuleData.Image);

            // Fail with a descriptive error in every build configuration. A Debug.Assert alone
            // disappears in release builds and leaves a bare NullReferenceException behind.
            var entryPoint = mainAssembly.EntryPoint
                ?? throw new InvalidOperationException($"Assembly '{mainAssembly.FullName}' does not define an entry point.");

            int exitCode = 0;
            SharedConsole.CaptureOutput(() =>
            {
                var count = entryPoint.GetParameters().Length;
                object[] args;
                if (count == 0)
                {
                    args = Array.Empty<object>();
                }
                else if (count == 1)
                {
                    // The single parameter receives the whole argument array as one value.
                    args = [mainArgs ?? []];
                }
                else
                {
                    throw new Exception("Unrecognized entry point");
                }

                // Any non-int return value (including null) counts as exit code 0.
                exitCode = entryPoint.Invoke(null, args) is int exit ? exit : 0;
            }, maxOutputLength, out var stdOut, out var stdErr);

            return (exitCode, stdOut, stdErr);
        }

        /// <summary>
        /// Collects the signatures reported by <c>MetadataSignatureHelper.GetMemberSignatures</c> for
        /// <paramref name="memberName"/> on <paramref name="fullyQualifiedTypeName"/> across every
        /// module in <paramref name="searchModules"/>.
        /// </summary>
        /// <returns>The distinct signatures found, in sorted order.</returns>
        /// <exception cref="Exception">
        /// Any failure while loading a module or reading its signatures, wrapped with the type and
        /// member name; the original error is the inner exception.
        /// </exception>
        public SortedSet<string> GetMemberSignaturesFromMetadata(string fullyQualifiedTypeName, string memberName, IEnumerable<ModuleDataId> searchModules)
        {
            try
            {
                var signatures = new SortedSet<string>();
                foreach (var id in searchModules)
                {
                    // Resolve through this context so that Load can serve the module from its image.
                    var name = new AssemblyName(id.FullName);
                    var assembly = LoadFromAssemblyName(name);
                    foreach (var signature in MetadataSignatureHelper.GetMemberSignatures(assembly, fullyQualifiedTypeName, memberName))
                    {
                        signatures.Add(signature);
                    }
                }
                return signatures;
            }
            catch (Exception ex)
            {
                throw new Exception($"Error getting signatures {fullyQualifiedTypeName}.{memberName}", ex);
            }
        }
    }
}
#endif