// File: GenerateFilteredReferenceAssembliesTask.cs
// Project: src\src\Tools\SemanticSearch\BuildTask\SemanticSearch.BuildTask.csproj (SemanticSearch.BuildTask)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Reflection.Metadata;
using System.Reflection.Metadata.Ecma335;
using System.Reflection.PortableExecutable;
using System.Security.Cryptography;
using System.Text.RegularExpressions;
using Microsoft.Build.Framework;
using Microsoft.Build.Utilities;
using Microsoft.CodeAnalysis.CSharp;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.Tools;
 
/// <summary>
/// One parsed line of an API set file.
/// </summary>
/// <param name="SymbolKinds">The kinds of symbols (types, methods, fields) this pattern applies to.</param>
/// <param name="MetadataNamePattern">
/// Regex matched against a symbol's documentation comment ID with its kind prefix (e.g. <c>M:</c>) removed.
/// </param>
/// <param name="IsIncluded">
/// <see langword="true"/> if matching symbols keep their visibility; <see langword="false"/> if they are hidden.
/// </param>
internal readonly record struct ApiPattern(SymbolKindFlags SymbolKinds, Regex MetadataNamePattern, bool IsIncluded);
 
/// <summary>
/// Bit flags selecting which kinds of symbol definitions an <c>ApiPattern</c> applies to.
/// </summary>
/// <remarks>
/// Bit <c>0x4</c> is intentionally left unassigned; the existing values are kept stable.
/// </remarks>
[Flags]
internal enum SymbolKindFlags
{
    /// <summary>No symbol kind selected.</summary>
    None = 0x0,

    /// <summary>Type definitions (kind letter <c>T</c> in an API set).</summary>
    NamedType = 0x1,

    /// <summary>Method definitions (kind letter <c>M</c> in an API set).</summary>
    Method = 0x2,

    /// <summary>Field definitions (kind letter <c>F</c> in an API set).</summary>
    Field = 0x8,
}
 
/// <summary>
/// The task transforms given assemblies by changing the visibility of members defined in these assemblies
/// based on filter patterns specified in the corresponding <see cref="ApiSets"/>.
/// <see cref="ApiSets"/> are text files whose file names (without extension) match the file names of <see cref="References"/>.
/// Each API set specifies a list of patterns that define which members should be included or excluded from the output assembly.
/// All excluded members are made internal or private.
/// </summary>
/// <summary>
/// The task transforms given assemblies by changing the visibility of members defined in these assemblies
/// based on filter patterns specified in the corresponding <see cref="ApiSets"/>.
/// <see cref="ApiSets"/> are text files whose file names (without extension) match the file names of <see cref="References"/>.
/// Each API set specifies a list of patterns that define which members should be included or excluded from the output assembly.
/// All excluded members are made internal or private.
/// </summary>
/// <remarks>
/// <para>
/// Each line of an API set has the form <c>[+ or -][Kinds:]MetadataName[# comment]</c>:
/// </para>
/// <list type="bullet">
/// <item><description>
/// The optional marker <c>-</c> excludes matching symbols; <c>+</c> or no marker includes them.
/// </description></item>
/// <item><description>
/// <c>Kinds</c> is any combination of <c>T</c> (types), <c>M</c> (methods) and <c>F</c> (fields); it defaults to <c>T</c>.
/// </description></item>
/// <item><description>
/// <c>MetadataName</c> is matched against the symbol's documentation comment ID without its <c>X:</c> prefix;
/// <c>*</c> matches any sequence of characters.
/// </description></item>
/// </list>
/// <para>
/// Patterns are applied in order and the last matching pattern wins. Types that no pattern matches are excluded;
/// methods and fields that no pattern matches stay included, since their containing type already limits their visibility.
/// An API set with no patterns leaves the assembly unchanged.
/// </para>
/// </remarks>
public sealed class GenerateFilteredReferenceAssembliesTask : Task
{
    /// <summary>
    /// Syntax of one API set line: optional leading whitespace, an optional inclusion marker (<c>+</c> or <c>-</c>),
    /// optional whitespace, an optional symbol-kind list terminated by <c>:</c>, the metadata name pattern,
    /// and an optional <c>#</c> comment running to the end of the line.
    /// </summary>
    /// <remarks>
    /// The pattern is compiled with <see cref="RegexOptions.IgnorePatternWhitespace"/>, so a literal <c>#</c>
    /// must be written as <c>[#]</c>.
    /// </remarks>
    private static readonly Regex s_lineSyntax = new("""
        ^
        \s*
        (?<Inclusion>[+-]?)
        \s*
        ((?<Kinds>[A-Za-z]+):)?
        (?<MetadataName>[^#]*)
        ([#].*)?
        $
        """, RegexOptions.Singleline | RegexOptions.IgnorePatternWhitespace);

    [Required]
    public ITaskItem[] ApiSets { get; private set; } = null!;

    [Required]
    public ITaskItem[] References { get; private set; } = null!;

    [Required]
    public string OutputDir { get; private set; } = null!;

    /// <summary>
    /// MSBuild entry point. Any exception thrown while processing is logged as a task error instead of propagating.
    /// </summary>
    /// <returns><see langword="true"/> if no errors were logged.</returns>
    public override bool Execute()
    {
        try
        {
            ExecuteImpl();
        }
        catch (Exception e)
        {
            Log.LogError($"GenerateFilteredReferenceAssembliesTask failed with exception:{Environment.NewLine}{e}");
        }

        return !Log.HasLoggedErrors;
    }

    /// <summary>
    /// Reads every API set file listed in <see cref="ApiSets"/> from disk and processes it.
    /// </summary>
    private void ExecuteImpl()
    {
        ExecuteImpl(ApiSets.Select(item => (item.ItemSpec, (IReadOnlyList<string>)File.ReadAllLines(item.ItemSpec))));
    }

    /// <summary>
    /// Processes API sets given as (path, lines) pairs.
    /// </summary>
    /// <remarks>
    /// For each set, the reference whose file name (without extension) equals the set's file name is read,
    /// rewritten according to the set's patterns, and written to <see cref="OutputDir"/> as <c>&lt;name&gt;.dll</c>.
    /// A set without a matching reference, and each malformed pattern line, is reported as a warning.
    /// </remarks>
    /// <param name="apiSets">The API set file paths paired with their lines.</param>
    internal void ExecuteImpl(IEnumerable<(string apiSpecPath, IReadOnlyList<string> lines)> apiSets)
    {
        // Note: ToDictionary throws if two references share a file name; Execute reports that as a task error.
        var referencesByName = References.ToDictionary(r => Path.GetFileNameWithoutExtension(r.ItemSpec), r => r.ItemSpec);

        foreach (var (specPath, filters) in apiSets)
        {
            var assemblyName = Path.GetFileNameWithoutExtension(specPath);
            if (!referencesByName.TryGetValue(assemblyName, out var originalReferencePath))
            {
                Log.LogWarning($"Assembly '{assemblyName}' not found among project references");
                continue;
            }

            var filteredReferencePath = Path.Combine(OutputDir, assemblyName + ".dll");
            var errors = new List<(string message, int line)>();
            var patterns = new List<ApiPattern>();
            ParseApiPatterns(filters, errors, patterns);

            foreach (var (message, line) in errors)
            {
                Log.LogWarning($"Invalid API pattern at {specPath} line {line}: {message}");
            }

            var peImageBuffer = File.ReadAllBytes(originalReferencePath);
            Rewrite(peImageBuffer, patterns.ToImmutableArray());

            try
            {
                File.WriteAllBytes(filteredReferencePath, peImageBuffer);
            }
            catch when (File.Exists(filteredReferencePath))
            {
                // Another instance of the task might already be writing the content.
                Log.LogMessage($"Output file '{filteredReferencePath}' already exists.");
            }
        }
    }

    /// <summary>
    /// Parses the lines of an API set into <paramref name="patterns"/>, in line order.
    /// </summary>
    /// <remarks>
    /// Blank lines and comment-only lines produce no pattern. Malformed lines are skipped and reported in
    /// <paramref name="errors"/> with a 1-based line number.
    /// </remarks>
    /// <param name="lines">The API set lines.</param>
    /// <param name="errors">Receives a message and a 1-based line number for each malformed line.</param>
    /// <param name="patterns">Receives the parsed patterns.</param>
    internal static void ParseApiPatterns(IReadOnlyList<string> lines, List<(string message, int line)> errors, List<ApiPattern> patterns)
    {
        for (var i = 0; i < lines.Count; i++)
        {
            var line = lines[i];

            var match = s_lineSyntax.Match(line);
            if (!match.Success)
            {
                errors.Add(("unable to parse", i + 1));
                continue;
            }

            var inclusion = match.Groups["Inclusion"].Value;
            var kinds = match.Groups["Kinds"].Value;

            // The name group also captures whitespace before a trailing comment; trim it so that a
            // whitespace-only name is recognized as missing rather than producing a never-matching pattern.
            var metadataName = match.Groups["MetadataName"].Value.Trim();

            var hasSymbolKindError = false;
            var symbolKinds = SymbolKindFlags.None;
            foreach (var kind in kinds)
            {
                symbolKinds |= kind switch
                {
                    'F' => SymbolKindFlags.Field,
                    'M' => SymbolKindFlags.Method,
                    'T' => SymbolKindFlags.NamedType,
                    _ => Unexpected()
                };

                SymbolKindFlags Unexpected()
                {
                    hasSymbolKindError = true;
                    errors.Add(($"unexpected symbol kind: '{kind}'", i + 1));
                    return SymbolKindFlags.None;
                }
            }

            if (hasSymbolKindError)
            {
                continue;
            }

            // Patterns without an explicit kind list apply to types.
            if (symbolKinds == SymbolKindFlags.None)
            {
                symbolKinds = SymbolKindFlags.NamedType;
            }

            if (metadataName.Length == 0)
            {
                // A marker or kind list without a name is an error; a blank or comment-only line is not.
                if (inclusion.Length > 0 || kinds.Length > 0)
                {
                    errors.Add(("expected metadata name", i + 1));
                }

                continue;
            }

            patterns.Add(new()
            {
                SymbolKinds = symbolKinds,
                MetadataNamePattern = ParseApiPattern(metadataName),
                IsIncluded = inclusion is not ['-']
            });
        }
    }

    /// <summary>
    /// Interprets `*` as `.*` and escapes the rest of regex-special characters.
    /// The pattern is trimmed and the resulting regex must match the whole input (it is anchored with <c>^</c> and <c>$</c>).
    /// </summary>
    internal static Regex ParseApiPattern(string pattern)
        => new("^" + string.Join(".*", pattern.Trim().Split('*').Select(Regex.Escape)) + "$",
            RegexOptions.Singleline | RegexOptions.IgnorePatternWhitespace | RegexOptions.Compiled);

    /// <summary>
    /// Collects every type, method and field symbol with a non-zero metadata token, walking all namespaces and
    /// (recursively) all nested types reachable from the compilation's global namespace.
    /// </summary>
    internal static void GetAllMembers(
        Compilation compilation,
        List<INamedTypeSymbol> types,
        List<IMethodSymbol> methods,
        List<IFieldSymbol> fields)
    {
        Recurse(compilation.GlobalNamespace.GetMembers());

        void Recurse(IEnumerable<ISymbol> members)
        {
            foreach (var member in members)
            {
                switch (member)
                {
                    case INamedTypeSymbol type:
                        if (type.MetadataToken != 0)
                        {
                            types.Add(type);
                            Recurse(type.GetMembers());
                        }
                        break;

                    case IMethodSymbol method:
                        if (method.MetadataToken != 0)
                        {
                            methods.Add(method);
                        }
                        break;

                    case IFieldSymbol field:
                        if (field.MetadataToken != 0)
                        {
                            fields.Add(field);
                        }
                        break;

                    case INamespaceSymbol ns:
                        Recurse(ns.GetMembers());
                        break;
                }
            }
        }
    }

    /// <summary>
    /// Evaluates <paramref name="patterns"/> in order against the symbol's documentation comment ID (kind prefix
    /// removed); the last pattern whose kinds and name both match decides the result.
    /// </summary>
    private static bool IsIncluded(ISymbol symbol, ImmutableArray<ApiPattern> patterns)
    {
        var id = symbol.GetDocumentationCommentId();
        Debug.Assert(id is [_, ':', ..]);
        id = id[2..];

        var kind = GetKindFlags(symbol);

        // Type symbols are considered excluded by default.
        // Member symbols are included by default since their type limits the effective visibility.
        var isIncluded = symbol is not INamedTypeSymbol;

        foreach (var pattern in patterns)
        {
            if ((pattern.SymbolKinds & kind) == kind && pattern.MetadataNamePattern.IsMatch(id))
            {
                isIncluded = pattern.IsIncluded;
            }
        }

        return isIncluded;
    }

    /// <summary>
    /// Maps a field, method or named type symbol to its <see cref="SymbolKindFlags"/> value; throws for any other kind.
    /// </summary>
    private static SymbolKindFlags GetKindFlags(ISymbol symbol)
        => symbol.Kind switch
        {
            SymbolKind.Field => SymbolKindFlags.Field,
            SymbolKind.Method => SymbolKindFlags.Method,
            SymbolKind.NamedType => SymbolKindFlags.NamedType,
            _ => throw ExceptionUtilities.UnexpectedValue(symbol.Kind)
        };

    /// <summary>
    /// Rewrites the PE image in <paramref name="peImage"/> in place.
    /// </summary>
    /// <remarks>
    /// The rewrite has three steps:
    /// <list type="number">
    /// <item><description>
    /// Reduces the visibility of every type, method and field definition that <paramref name="patterns"/> exclude,
    /// or that could not be decoded into a symbol.
    /// </description></item>
    /// <item><description>Clears the certificate table data directory entry ("unsign").</description></item>
    /// <item><description>Replaces the module MVID with one derived from a SHA-256 hash of the resulting image.</description></item>
    /// </list>
    /// The image is left untouched when <paramref name="patterns"/> is empty.
    /// </remarks>
    internal static unsafe void Rewrite(byte[] peImage, ImmutableArray<ApiPattern> patterns)
    {
        // Include all APIs if no patterns are specified.
        if (patterns.IsEmpty)
        {
            return;
        }

        using var readableStream = new MemoryStream(peImage, writable: false);
        var metadataRef = MetadataReference.CreateFromStream(readableStream);
        var compilation = CSharpCompilation.Create("Metadata", references: [metadataRef]);

        // Collect all member definitions that have visibility flags:
        var types = new List<INamedTypeSymbol>();
        var methods = new List<IMethodSymbol>();
        var fields = new List<IFieldSymbol>();
        GetAllMembers(compilation, types, methods, fields);

        // Update visibility flags:
        using var writableStream = new MemoryStream(peImage, writable: true);
        using var peReader = new PEReader(writableStream);
        using var writer = new BinaryWriter(writableStream);

        var headers = peReader.PEHeaders;
        Debug.Assert(headers.PEHeader != null);

        var metadataReader = peReader.GetMetadataReader();
        var metadataOffset = peReader.PEHeaders.MetadataStartOffset;

        // Symbols are sorted by token so they can be paired with definitions visited in row order.
        UpdateTypeDefinitions(
            writer,
            metadataReader,
            patterns,
            types.OrderBy(t => t.MetadataToken).ToImmutableArray(),
            metadataOffset);

        UpdateMethodDefinitions(
            writer,
            metadataReader,
            patterns,
            methods.OrderBy(t => t.MetadataToken).ToImmutableArray(),
            metadataOffset);

        UpdateFieldDefinitions(
            writer,
            metadataReader,
            patterns,
            fields.OrderBy(t => t.MetadataToken).ToImmutableArray(),
            metadataOffset);

        // unsign: zero the RVA and size (8 bytes) of the certificate table data directory entry,
        // located 128 (PE32) or 144 (PE32+) bytes into the optional header.
        if (headers.PEHeader.CertificateTableDirectory.Size > 0)
        {
            var certificateTableDirectoryOffset = (headers.PEHeader.Magic == PEMagic.PE32Plus) ? 144 : 128;
            writableStream.Position = peReader.PEHeaders.PEHeaderStartOffset + certificateTableDirectoryOffset;
            writer.Write((long)0);
        }

        writer.Flush();

        // update mvid: GUID heap entries are 16 bytes each and the MVID handle's heap offset is a 1-based index.
        var moduleDef = metadataReader.GetModuleDefinition();
        var mvidOffset = metadataOffset + metadataReader.GetHeapMetadataOffset(HeapIndex.Guid) + (MetadataTokens.GetHeapOffset(moduleDef.Mvid) - 1) * sizeof(Guid);
#if DEBUG
        writableStream.Position = mvidOffset;
        Debug.Assert(metadataReader.GetGuid(moduleDef.Mvid) == ReadGuid(writableStream));
#endif
        var newMvid = CreateMvid(writableStream);
        writableStream.Position = mvidOffset;
        WriteGuid(writer, newMvid);

        writer.Flush();
    }

    /// <summary>
    /// Returns the symbol at <paramref name="symbolIndex"/> and advances the index if its metadata token is that of
    /// <paramref name="handle"/>; otherwise returns <see langword="null"/> and leaves the index unchanged.
    /// </summary>
    /// <remarks>
    /// Requires <paramref name="symbols"/> to be sorted by token and definitions to be visited in row order.
    /// If the current definition does not have a corresponding symbol, it couldn't be decoded from metadata;
    /// callers treat such a definition as excluded.
    /// </remarks>
    private static unsafe TSymbol? GetSymbolWithToken<TSymbol>(ImmutableArray<TSymbol> symbols, ref int symbolIndex, EntityHandle handle) where TSymbol : class, ISymbol
        => (symbolIndex < symbols.Length && symbols[symbolIndex].MetadataToken == MetadataTokens.GetToken(handle)) ? symbols[symbolIndex++] : null;

    /// <summary>
    /// Makes excluded type definitions non-public: top-level public types become <see cref="TypeAttributes.NotPublic"/>,
    /// and public, family or family-or-assembly nested types become <see cref="TypeAttributes.NestedAssembly"/>.
    /// </summary>
    private static unsafe void UpdateTypeDefinitions(BinaryWriter writer, MetadataReader metadataReader, ImmutableArray<ApiPattern> patterns, ImmutableArray<INamedTypeSymbol> symbols, int metadataOffset)
    {
        var tableOffset = metadataOffset + metadataReader.GetTableMetadataOffset(TableIndex.TypeDef);
        var tableRowSize = metadataReader.GetTableRowSize(TableIndex.TypeDef);
        var symbolIndex = 0;

        foreach (var handle in metadataReader.TypeDefinitions)
        {
            var symbol = GetSymbolWithToken(symbols, ref symbolIndex, handle);
            if (symbol == null || !IsIncluded(symbol, patterns))
            {
                var typeDef = metadataReader.GetTypeDefinition(handle);

                // reduce visibility so that the type is not visible outside the assembly:
                var oldVisibility = typeDef.Attributes & TypeAttributes.VisibilityMask;
                var newVisibility = oldVisibility switch
                {
                    TypeAttributes.Public => TypeAttributes.NotPublic,
                    TypeAttributes.NestedPublic or TypeAttributes.NestedFamily or TypeAttributes.NestedFamORAssem => TypeAttributes.NestedAssembly,
                    _ => oldVisibility
                };

                if (oldVisibility == newVisibility)
                {
                    continue;
                }

                // Type attributes are stored as the first field of the row and are 4B.
                var offset = tableOffset + (MetadataTokens.GetRowNumber(handle) - 1) * tableRowSize + 0;
#if DEBUG
                writer.BaseStream.Position = offset;
                Debug.Assert((TypeAttributes)ReadUInt32(writer.BaseStream) == typeDef.Attributes);
#endif
                writer.BaseStream.Position = offset;
                Debug.Assert(BitConverter.IsLittleEndian);
                writer.Write((uint)(typeDef.Attributes & ~TypeAttributes.VisibilityMask | newVisibility));
            }
        }
    }

    /// <summary>
    /// Sets the member access of every excluded, not-already-private method definition to <see cref="MethodAttributes.Private"/>.
    /// </summary>
    private static unsafe void UpdateMethodDefinitions(BinaryWriter writer, MetadataReader metadataReader, ImmutableArray<ApiPattern> patterns, ImmutableArray<IMethodSymbol> symbols, int metadataOffset)
    {
        var tableOffset = metadataOffset + metadataReader.GetTableMetadataOffset(TableIndex.MethodDef);
        var tableRowSize = metadataReader.GetTableRowSize(TableIndex.MethodDef);
        var symbolIndex = 0;

        foreach (var handle in metadataReader.MethodDefinitions)
        {
            var symbol = GetSymbolWithToken(symbols, ref symbolIndex, handle);
            if (symbol == null || !IsIncluded(symbol, patterns))
            {
                var def = metadataReader.GetMethodDefinition(handle);

                // reduce visibility so that the method is not visible outside the assembly:
                var oldVisibility = def.Attributes & MethodAttributes.MemberAccessMask;
                var newVisibility = MethodAttributes.Private;
                if (oldVisibility == newVisibility)
                {
                    continue;
                }

                // Row: RvaOffset (4B), ImplAttributes (2B), Attributes (2B), ...
                var offset = tableOffset + (MetadataTokens.GetRowNumber(handle) - 1) * tableRowSize + sizeof(uint) + sizeof(ushort);
#if DEBUG
                writer.BaseStream.Position = offset;
                Debug.Assert((MethodAttributes)ReadUInt16(writer.BaseStream) == def.Attributes);
#endif
                writer.BaseStream.Position = offset;
                Debug.Assert(BitConverter.IsLittleEndian);
                writer.Write((ushort)(def.Attributes & ~MethodAttributes.MemberAccessMask | newVisibility));
            }
        }
    }

    /// <summary>
    /// Sets the access of every excluded, not-already-private field definition to <see cref="FieldAttributes.Private"/>.
    /// </summary>
    private static unsafe void UpdateFieldDefinitions(BinaryWriter writer, MetadataReader metadataReader, ImmutableArray<ApiPattern> patterns, ImmutableArray<IFieldSymbol> symbols, int metadataOffset)
    {
        var tableOffset = metadataOffset + metadataReader.GetTableMetadataOffset(TableIndex.Field);
        var tableRowSize = metadataReader.GetTableRowSize(TableIndex.Field);
        var symbolIndex = 0;

        foreach (var handle in metadataReader.FieldDefinitions)
        {
            var symbol = GetSymbolWithToken(symbols, ref symbolIndex, handle);
            if (symbol == null || !IsIncluded(symbol, patterns))
            {
                var def = metadataReader.GetFieldDefinition(handle);

                // reduce visibility so that the field is not visible outside the assembly:
                var oldVisibility = def.Attributes & FieldAttributes.FieldAccessMask;
                var newVisibility = FieldAttributes.Private;
                if (oldVisibility == newVisibility)
                {
                    continue;
                }

                // Row: Attributes (2B), ...
                var offset = tableOffset + (MetadataTokens.GetRowNumber(handle) - 1) * tableRowSize + 0;
#if DEBUG
                writer.BaseStream.Position = offset;
                Debug.Assert((FieldAttributes)ReadUInt16(writer.BaseStream) == def.Attributes);
#endif
                writer.BaseStream.Position = offset;
                Debug.Assert(BitConverter.IsLittleEndian);
                writer.Write((ushort)(def.Attributes & ~FieldAttributes.FieldAccessMask | newVisibility));
            }
        }
    }

    /// <summary>Reads a little-endian 32-bit value; used by the DEBUG-only consistency checks in this file.</summary>
    private static uint ReadUInt32(Stream stream)
        => unchecked((uint)(stream.ReadByte() | stream.ReadByte() << 8 | stream.ReadByte() << 16 | stream.ReadByte() << 24));

    /// <summary>Reads a little-endian 16-bit value; used by the DEBUG-only consistency checks in this file.</summary>
    private static uint ReadUInt16(Stream stream)
        => unchecked((uint)(stream.ReadByte() | stream.ReadByte() << 8));

    /// <summary>
    /// Reads a 16-byte GUID in metadata (GUID heap) layout from the current stream position.
    /// </summary>
    private static unsafe Guid ReadGuid(Stream stream)
    {
        var buffer = new byte[sizeof(Guid)];

        // Keep the read outside Debug.Assert: the assert call is compiled out of non-DEBUG builds,
        // which would otherwise drop the read itself.
        var bytesRead = stream.Read(buffer, 0, buffer.Length);
        Debug.Assert(bytesRead == buffer.Length);

        fixed (byte* ptr = buffer)
        {
            var reader = new BlobReader(ptr, buffer.Length);
            return reader.ReadGuid();
        }
    }

    /// <summary>
    /// Writes <paramref name="guid"/> in metadata (GUID heap) layout at the writer's current position.
    /// </summary>
    private static unsafe void WriteGuid(BinaryWriter writer, Guid guid)
    {
        var buffer = new byte[sizeof(Guid)];
        var blob = new BlobWriter(buffer);
        blob.WriteGuid(guid);
        writer.Write(buffer, 0, buffer.Length);
    }

    /// <summary>
    /// Derives a deterministic MVID from a SHA-256 hash of the entire stream content (read from position 0).
    /// </summary>
    private static Guid CreateMvid(Stream stream)
    {
        stream.Position = 0;
        using var sha = SHA256.Create();
        return BlobContentId.FromHash(sha.ComputeHash(stream)).Guid;
    }
}