File: DefaultExtensionAssemblyLoader.cs
Web Access
Project: ..\..\..\src\RazorSdk\Tool\Microsoft.NET.Sdk.Razor.Tool.csproj (rzc)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
#nullable disable
 
using System.Reflection;
using System.Reflection.Metadata;
using System.Reflection.PortableExecutable;
using System.Runtime.Loader;
using Microsoft.CodeAnalysis;
 
namespace Microsoft.NET.Sdk.Razor.Tool
{
    internal class DefaultExtensionAssemblyLoader : ExtensionAssemblyLoader
    {
        private readonly string _baseDirectory;
 
        private readonly object _lock = new();
        private readonly Dictionary<string, (Assembly assembly, AssemblyIdentity identity)> _loadedByPath;
        private readonly Dictionary<AssemblyIdentity, Assembly> _loadedByIdentity;
        private readonly Dictionary<string, AssemblyIdentity> _identityCache;
        private readonly Dictionary<string, List<string>> _wellKnownAssemblies;
 
        private ShadowCopyManager _shadowCopyManager;
 
        public DefaultExtensionAssemblyLoader(string baseDirectory)
        {
            _baseDirectory = baseDirectory;
 
            _loadedByPath = new Dictionary<string, (Assembly assembly, AssemblyIdentity identity)>(StringComparer.OrdinalIgnoreCase);
            _loadedByIdentity = new Dictionary<AssemblyIdentity, Assembly>();
            _identityCache = new Dictionary<string, AssemblyIdentity>(StringComparer.OrdinalIgnoreCase);
            _wellKnownAssemblies = new Dictionary<string, List<string>>(StringComparer.OrdinalIgnoreCase);
 
            LoadContext = new ExtensionAssemblyLoadContext(AssemblyLoadContext.GetLoadContext(typeof(ExtensionAssemblyLoader).Assembly), this);
        }
 
        protected AssemblyLoadContext LoadContext { get; }
 
        public override void AddAssemblyLocation(string filePath)
        {
            if (filePath == null)
            {
                throw new ArgumentNullException(nameof(filePath));
            }
 
            if (!Path.IsPathRooted(filePath))
            {
                throw new ArgumentException(nameof(filePath));
            }
 
            var assemblyName = Path.GetFileNameWithoutExtension(filePath);
            lock (_lock)
            {
                if (!_wellKnownAssemblies.TryGetValue(assemblyName, out var paths))
                {
                    paths = new List<string>();
                    _wellKnownAssemblies.Add(assemblyName, paths);
                }
 
                if (!paths.Contains(filePath))
                {
                    paths.Add(filePath);
                }
            }
        }
 
        public override Assembly Load(string assemblyName)
        {
            if (!AssemblyIdentity.TryParseDisplayName(assemblyName, out var identity))
            {
                return null;
            }
 
            lock (_lock)
            {
                // First, check if this loader already loaded the requested assembly:
                if (_loadedByIdentity.TryGetValue(identity, out var assembly))
                {
                    return assembly;
                }
 
                // Second, check if an assembly file of the same simple name was registered with the loader:
                if (_wellKnownAssemblies.TryGetValue(identity.Name, out var paths))
                {
                    // Multiple assemblies of the same simple name but different identities might have been registered.
                    // Load the one that matches the requested identity (if any).
                    foreach (var path in paths)
                    {
                        var candidateIdentity = GetIdentity(path);
 
                        if (identity.Equals(candidateIdentity))
                        {
                            return LoadFromPathUnsafe(path, candidateIdentity);
                        }
                    }
                }
 
                // We only support loading by name from 'well-known' paths. If you need to load something by
                // name and you get here, then that means we don't know where to look.
                return null;
            }
        }
 
        public override Assembly LoadFromPath(string filePath)
        {
            if (filePath == null)
            {
                throw new ArgumentNullException(nameof(filePath));
            }
 
            if (!Path.IsPathRooted(filePath))
            {
                throw new ArgumentException(nameof(filePath));
            }
 
            lock (_lock)
            {
                return LoadFromPathUnsafe(filePath, identity: null);
            }
        }
 
        private Assembly LoadFromPathUnsafe(string filePath, AssemblyIdentity identity)
        {
            // If we've already loaded the assembly by path there should be nothing else to do,
            // all of our data is up to date.
            if (_loadedByPath.TryGetValue(filePath, out var entry))
            {
                return entry.assembly;
            }
 
            // If we've already loaded the assembly by identity, then we might has some updating
            // to do.
            identity = identity ?? GetIdentity(filePath);
            if (identity != null && _loadedByIdentity.TryGetValue(identity, out var assembly))
            {
                // An assembly file might be replaced by another file with a different identity.
                // Last one wins.
                _loadedByPath[filePath] = (assembly, identity);
                return assembly;
            }
 
            // Ok we don't have this cached. Let's actually try to load the assembly.
            assembly = LoadFromPathUnsafeCore(CopyAssembly(filePath));
 
            identity = identity ?? AssemblyIdentity.FromAssemblyDefinition(assembly);
 
            // It's possible an assembly was loaded by two different paths. Just use the original then.
            if (_loadedByIdentity.TryGetValue(identity, out var duplicate))
            {
                assembly = duplicate;
            }
            else
            {
                _loadedByIdentity.Add(identity, assembly);
            }
 
            _loadedByPath[filePath] = (assembly, identity);
            return assembly;
        }
 
        private AssemblyIdentity GetIdentity(string filePath)
        {
            if (!_identityCache.TryGetValue(filePath, out var identity))
            {
                identity = ReadAssemblyIdentity(filePath);
                _identityCache.Add(filePath, identity);
            }
 
            return identity;
        }
 
        protected virtual string CopyAssembly(string filePath)
        {
            if (_baseDirectory == null)
            {
                // Don't shadow-copy when base directory is null. This means we're running as a CLI not
                // a server.
                return filePath;
            }
 
            if (_shadowCopyManager == null)
            {
                _shadowCopyManager = new ShadowCopyManager(_baseDirectory);
            }
 
            return _shadowCopyManager.AddAssembly(filePath);
        }
 
        protected virtual Assembly LoadFromPathUnsafeCore(string filePath)
        {
            return LoadContext.LoadFromAssemblyPath(filePath);
        }
 
        private static AssemblyIdentity ReadAssemblyIdentity(string filePath)
        {
            try
            {
                using (var stream = new FileStream(filePath, FileMode.Open, FileAccess.Read, FileShare.ReadWrite | FileShare.Delete))
                using (var reader = new PEReader(stream))
                {
                    var metadataReader = reader.GetMetadataReader();
                    return metadataReader.GetAssemblyIdentity();
                }
            }
            catch
            {
            }
 
            return null;
        }
 
        private class ExtensionAssemblyLoadContext : AssemblyLoadContext
        {
            private readonly AssemblyLoadContext _parent;
            private readonly DefaultExtensionAssemblyLoader _loader;
 
            public ExtensionAssemblyLoadContext(AssemblyLoadContext parent, DefaultExtensionAssemblyLoader loader)
            {
                _parent = parent;
                _loader = loader;
            }
 
            protected override Assembly Load(AssemblyName assemblyName)
            {
                // Try to load from well-known paths. This will be called when loading a dependency of an extension.
                var assembly = _loader.Load(assemblyName.ToString());
                if (assembly != null)
                {
                    return assembly;
                }
 
                // If we don't have an entry, then fall back to the default load context. This allows extensions
                // to resolve assemblies that are provided by the host.
                return _parent.LoadFromAssemblyName(assemblyName);
            }
        }
    }
}