File: ConfigurablePInvokePolicy.cs
Web Access
Project: src\src\runtime\src\coreclr\tools\aot\ILCompiler\ILCompiler_inbuild.csproj (ilc)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.InteropServices;

using Internal.IL;
using Internal.TypeSystem;

namespace ILCompiler
{
    /// <summary>
    /// P/Invoke policy driven by a configurable list of entries. Each entry is either a library name
    /// (every entry point of that library is a direct call) or <c>library!entrypoint</c>
    /// (only that entry point of the library is a direct call).
    /// </summary>
    internal sealed class ConfigurablePInvokePolicy : PInvokeILEmitterConfiguration
    {
        private readonly TargetDetails _target;

        // Maps a library name to the set of entry points that are direct calls.
        // A null set means that every entry point of the library is a direct call.
        private readonly Dictionary<string, HashSet<string>> _directPInvokes;

        /// <summary>
        /// Builds the policy from inline entries and from list files.
        /// </summary>
        /// <param name="target">Target details; decides library name case sensitivity and the name variations that are probed.</param>
        /// <param name="directPInvokes">Entries in the <c>library</c> or <c>library!entrypoint</c> form.</param>
        /// <param name="directPInvokeLists">Paths of files whose lines are each processed as an entry.</param>
        public ConfigurablePInvokePolicy(TargetDetails target, IReadOnlyList<string> directPInvokes, IReadOnlyList<string> directPInvokeLists)
        {
            _target = target;

            // Library names are compared case-insensitively only when targeting Windows.
            _directPInvokes = new Dictionary<string, HashSet<string>>(target.IsWindows ? StringComparer.OrdinalIgnoreCase : StringComparer.Ordinal);

            // * is always a direct call
            _directPInvokes.Add("*", null);

            foreach (var file in directPInvokeLists)
            {
                foreach (var entry in File.ReadLines(file))
                {
                    AddDirectPInvoke(entry);
                }
            }

            foreach (var entry in directPInvokes)
            {
                AddDirectPInvoke(entry);
            }
        }

        /// <summary>
        /// Records a single entry. Blank entries and entries starting with <c>#</c> are ignored.
        /// </summary>
        /// <param name="entry">A library name, or <c>library!entrypoint</c>.</param>
        private void AddDirectPInvoke(string entry)
        {
            // Trim before looking for the comment marker so that an indented comment line is
            // ignored instead of being registered as a library whose name starts with '#'.
            entry = entry.Trim();

            // Ignore empty entries and comments
            if (entry.Length == 0 || entry.StartsWith('#'))
                return;

            int separator = entry.IndexOf('!');

            if (separator != -1)
            {
                string libraryName = entry.Substring(0, separator);
                string entrypointName = entry.Substring(separator + 1);

                if (_directPInvokes.TryGetValue(libraryName, out HashSet<string> entrypointSet))
                {
                    // All entrypoints from the library are already direct; nothing to narrow down
                    if (entrypointSet == null)
                        return;
                }
                else
                {
                    _directPInvokes.Add(libraryName, entrypointSet = new HashSet<string>());
                }

                entrypointSet.Add(entrypointName);
            }
            else
            {
                // All entrypoints from the library are direct. This replaces any set of
                // individual entry points recorded for the library earlier.
                _directPInvokes[entry] = null;
            }
        }

        /// <summary>
        /// Enumerates the library names to look up for a module name: the name as given, followed by
        /// the name with the shared library suffix and/or the <c>lib</c> prefix removed.
        /// </summary>
        private IEnumerable<string> ModuleNameVariations(string name)
        {
            yield return name;

            if (_target.IsWindows)
            {
                string suffix = ".dll";

                // Only the extension is stripped on Windows, and it is matched case-insensitively
                if (name.EndsWith(suffix, StringComparison.OrdinalIgnoreCase))
                    yield return name.Substring(0, name.Length - suffix.Length);
            }
            else
            {
                string suffix = _target.IsApplePlatform ? ".dylib" : ".so";
                bool hasSharedLibraryExtension = name.EndsWith(suffix, StringComparison.Ordinal);
                const string LibPrefix = "lib";
                bool hasLibPrefix = name.StartsWith(LibPrefix, StringComparison.Ordinal);

                // Without the extension, without the prefix, and finally without both
                if (hasSharedLibraryExtension)
                    yield return name.Substring(0, name.Length - suffix.Length);
                if (hasLibPrefix)
                    yield return name.Substring(LibPrefix.Length);
                if (hasLibPrefix && hasSharedLibraryExtension)
                    yield return name.Substring(LibPrefix.Length, name.Length - suffix.Length - LibPrefix.Length);
            }
        }

        /// <summary>
        /// Enumerates the entry point names to look up, in order of precedence. When targeting Windows
        /// without exact spelling, the <c>A</c> or <c>W</c> suffixed name is included as well.
        /// </summary>
        private IEnumerable<string> EntryPointNameVariations(string name, PInvokeFlags flags)
        {
            if (_target.IsWindows && !flags.ExactSpelling)
            {
                // Mirror CharSet normalization from Marshaller.CreateMarshaller
                // Any value other than Unicode and Auto is treated as ANSI.
                bool isAnsi = flags.CharSet switch
                {
                    CharSet.Ansi => true,
                    CharSet.Unicode => false,
                    CharSet.Auto => false,
                    _ => true
                };

                if (isAnsi)
                {
                    // For ANSI, look for the user-provided entry point name first.
                    // If that does not exist, try the charset suffix.
                    yield return name;
                    yield return name + "A";
                }
                else
                {
                    // For Unicode, look for the entry point name with the charset suffix first.
                    // The 'W' API takes precedence over the undecorated one.
                    yield return name + "W";
                    yield return name;
                }
            }
            else
            {
                yield return name;
            }
        }

        /// <summary>
        /// Determines whether the P/Invoke <paramref name="method"/> is a direct call.
        /// </summary>
        /// <param name="method">The P/Invoke method to check.</param>
        /// <param name="externName">
        /// On success, the entry point name to use: the metadata entry point name (or the method name when
        /// the metadata has none) if the whole library is direct, otherwise the entry point name variation
        /// that matched a configured entry. Null when the method returns false.
        /// </param>
        /// <returns>True if the method is a direct call; otherwise false.</returns>
        public override bool GenerateDirectCall(MethodDesc method, out string externName)
        {
            var pInvokeMetadata = method.GetPInvokeMethodMetadata();

            foreach (var moduleName in ModuleNameVariations(pInvokeMetadata.Module))
            {
                if (_directPInvokes.TryGetValue(moduleName, out HashSet<string> entrypoints))
                {
                    string entryPointMetadataName = pInvokeMetadata.Name ?? method.GetName();

                    // A null set means all entry points from the library are direct
                    if (entrypoints == null)
                    {
                        externName = entryPointMetadataName;
                        return true;
                    }

                    foreach (var entryPointName in EntryPointNameVariations(entryPointMetadataName, pInvokeMetadata.Flags))
                    {
                        if (entrypoints.Contains(entryPointName))
                        {
                            externName = entryPointName;
                            return true;
                        }
                    }

                    // No entry point matched under this library name; keep probing the other variations
                }
            }

            externName = null;
            return false;
        }
    }
}