File: System\ComponentModel\Composition\Hosting\FilteredCatalog.DependenciesTraversal.cs
Web Access
Project: src\src\libraries\System.ComponentModel.Composition\src\System.ComponentModel.Composition.csproj (System.ComponentModel.Composition)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Generic;
using System.ComponentModel.Composition.Primitives;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
 
namespace System.ComponentModel.Composition.Hosting
{
    public partial class FilteredCatalog
    {
        /// <summary>
        /// Catalog traversal that follows the imports of a part to the parts that export them.
        /// </summary>
        /// <remarks>
        /// <see cref="Initialize"/> must be called before <see cref="TryTraverse"/>: the traversal
        /// relies on an index from export contract name to exporting parts that is only built there.
        /// </remarks>
        internal sealed class DependenciesTraversal : IComposablePartCatalogTraversal
        {
            // Every part enumerated from the filtered catalog's _innerCatalog.
            private readonly IEnumerable<ComposablePartDefinition> _parts;

            // Decides which imports of a part are followed; imports it rejects are ignored.
            private readonly Func<ImportDefinition, bool> _importFilter;

            // Export contract name -> parts declaring an export with that name. Null until Initialize() runs.
            private Dictionary<string, List<ComposablePartDefinition>>? _exportersIndex;

            /// <summary>
            /// Creates a traversal over the parts of <paramref name="catalog"/>'s inner catalog.
            /// </summary>
            /// <param name="catalog">The filtered catalog whose inner catalog supplies the candidate parts.</param>
            /// <param name="importFilter">Predicate selecting the imports that should be followed.</param>
            /// <exception cref="ArgumentNullException">
            /// <paramref name="catalog"/> or <paramref name="importFilter"/> is <see langword="null"/>.
            /// </exception>
            public DependenciesTraversal(FilteredCatalog catalog, Func<ImportDefinition, bool> importFilter)
            {
                ArgumentNullException.ThrowIfNull(catalog);
                ArgumentNullException.ThrowIfNull(importFilter);

                _parts = catalog._innerCatalog;
                _importFilter = importFilter;
            }

            /// <summary>
            /// Builds the contract-name index used by <see cref="TryTraverse"/>.
            /// </summary>
            public void Initialize()
            {
                BuildExportersIndex();
            }

            /// <summary>
            /// Replaces the index with a fresh one containing an entry for every export of every part.
            /// </summary>
            private void BuildExportersIndex()
            {
                _exportersIndex = new Dictionary<string, List<ComposablePartDefinition>>();
                foreach (ComposablePartDefinition part in _parts)
                {
                    foreach (var export in part.ExportDefinitions)
                    {
                        AddToExportersIndex(export.ContractName, part);
                    }
                }
            }

            /// <summary>
            /// Records <paramref name="part"/> as an exporter of <paramref name="contractName"/>,
            /// creating the bucket for that contract name on first use.
            /// </summary>
            private void AddToExportersIndex(string contractName, ComposablePartDefinition part)
            {
                if (!_exportersIndex!.TryGetValue(contractName, out List<ComposablePartDefinition>? parts))
                {
                    parts = new List<ComposablePartDefinition>();
                    _exportersIndex.Add(contractName, parts);
                }
                parts.Add(part);
            }

            /// <summary>
            /// Collects the parts that <paramref name="part"/> depends on through the imports
            /// accepted by the import filter.
            /// </summary>
            /// <param name="part">The part whose imports are followed.</param>
            /// <param name="reachableParts">
            /// When the method returns <see langword="true"/>, the parts found; otherwise <see langword="null"/>.
            /// A part is listed at most once per import and candidate contract name, but may still
            /// appear more than once when it is reached through several imports or contract names.
            /// </param>
            /// <returns><see langword="true"/> if at least one part was found; otherwise <see langword="false"/>.</returns>
            public bool TryTraverse(ComposablePartDefinition part, [NotNullWhen(true)] out IEnumerable<ComposablePartDefinition>? reachableParts)
            {
                // Allocated lazily so that a part without matching dependencies costs no allocation.
                List<ComposablePartDefinition>? reachablePartList = null;

                // Only the imports accepted by the filter are followed.
                foreach (ImportDefinition import in part.ImportDefinitions.Where(_importFilter))
                {
                    Debug.Assert(_exportersIndex != null);

                    // Use the index to narrow the search to parts that export under a contract name
                    // this import may bind to.
                    foreach (var contractName in import.GetCandidateContractNames(part))
                    {
                        if (!_exportersIndex.TryGetValue(contractName, out List<ComposablePartDefinition>? candidateParts))
                        {
                            continue;
                        }

                        // A matching contract name is only a hint; confirm against the actual exports.
                        foreach (ComposablePartDefinition candidatePart in candidateParts)
                        {
                            foreach (ExportDefinition export in candidatePart.ExportDefinitions)
                            {
                                // The last argument is true when exactly one of the two parts is generic.
                                if (import.IsImportDependentOnPart(candidatePart, export, part.IsGeneric() != candidatePart.IsGeneric()))
                                {
                                    reachablePartList ??= new List<ComposablePartDefinition>();
                                    reachablePartList.Add(candidatePart);

                                    // One matching export is enough to make the candidate reachable;
                                    // checking its remaining exports would only add the same part again.
                                    break;
                                }
                            }
                        }
                    }
                }

                reachableParts = reachablePartList;
                return reachableParts != null;
            }
        }
    }
}