// File: Analyzers\ComInterfaceGeneratorDiagnosticsAnalyzer.cs
// Web Access
// Project: src\src\libraries\System.Runtime.InteropServices\gen\ComInterfaceGenerator\ComInterfaceGenerator.csproj (Microsoft.Interop.ComInterfaceGenerator)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.DotnetRuntime.Extensions;
 
namespace Microsoft.Interop.Analyzers
{
    [DiagnosticAnalyzer(LanguageNames.CSharp)]
    public class ComInterfaceGeneratorDiagnosticsAnalyzer : DiagnosticAnalyzer
    {
        /// <summary>
        /// The diagnostic descriptors this analyzer declares as supported, listed in three groups:
        /// interface-level, method-level and stub-level.
        /// </summary>
        public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get; } =
        [
            // Group 1: interface-level diagnostics
            GeneratorDiagnostics.RequiresAllowUnsafeBlocks,
            GeneratorDiagnostics.InvalidAttributedInterfaceGenericNotSupported,
            GeneratorDiagnostics.InvalidAttributedInterfaceMissingPartialModifiers,
            GeneratorDiagnostics.InvalidAttributedInterfaceNotAccessible,
            GeneratorDiagnostics.InvalidAttributedInterfaceMissingGuidAttribute,
            GeneratorDiagnostics.InvalidStringMarshallingMismatchBetweenBaseAndDerived,
            GeneratorDiagnostics.InvalidOptionsOnInterface,
            GeneratorDiagnostics.InvalidStringMarshallingConfigurationOnInterface,
            GeneratorDiagnostics.InvalidExceptionToUnmanagedMarshallerType,
            GeneratorDiagnostics.StringMarshallingCustomTypeNotAccessibleByGeneratedCode,
            GeneratorDiagnostics.ExceptionToUnmanagedMarshallerNotAccessibleByGeneratedCode,
            GeneratorDiagnostics.MultipleComInterfaceBaseTypes,
            GeneratorDiagnostics.BaseInterfaceIsNotGenerated,
            GeneratorDiagnostics.BaseInterfaceDefinedInOtherAssembly,

            // Group 2: method-level diagnostics
            GeneratorDiagnostics.MethodNotDeclaredInAttributedInterface,
            GeneratorDiagnostics.InstancePropertyDeclaredInInterface,
            GeneratorDiagnostics.InstanceEventDeclaredInInterface,
            GeneratorDiagnostics.CannotAnalyzeMethodPattern,
            GeneratorDiagnostics.CannotAnalyzeInterfacePattern,

            // Group 3: stub-level diagnostics
            GeneratorDiagnostics.ConfigurationNotSupported,
            GeneratorDiagnostics.InvalidStringMarshallingConfigurationOnMethod,
            GeneratorDiagnostics.ParameterTypeNotSupported,
            GeneratorDiagnostics.ReturnTypeNotSupported,
            GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails,
            GeneratorDiagnostics.ReturnTypeNotSupportedWithDetails,
            GeneratorDiagnostics.ParameterConfigurationNotSupported,
            GeneratorDiagnostics.ReturnConfigurationNotSupported,
            GeneratorDiagnostics.MarshalAsParameterConfigurationNotSupported,
            GeneratorDiagnostics.MarshalAsReturnConfigurationNotSupported,
            GeneratorDiagnostics.ConfigurationValueNotSupported,
            GeneratorDiagnostics.MarshallingAttributeConfigurationNotSupported,
            GeneratorDiagnostics.UnnecessaryParameterMarshallingInfo,
            GeneratorDiagnostics.UnnecessaryReturnMarshallingInfo,
            GeneratorDiagnostics.ComMethodManagedReturnWillBeOutVariable,
            GeneratorDiagnostics.HResultTypeWillBeTreatedAsStruct,
            GeneratorDiagnostics.SizeOfInCollectionMustBeDefinedAtCallOutParam,
            GeneratorDiagnostics.SizeOfInCollectionMustBeDefinedAtCallReturnValue,
            GeneratorDiagnostics.InvalidExceptionMarshallingConfiguration,
            GeneratorDiagnostics.GeneratedComInterfaceUsageDoesNotFollowBestPractices
        ];
 
        /// <summary>
        /// Registers the analysis callbacks. Generated code is excluded from analysis and concurrent
        /// execution is enabled. For each compilation in which the <c>[GeneratedComInterface]</c> attribute
        /// type can be resolved, a named-type symbol action is registered that analyzes every interface
        /// carrying that attribute.
        /// </summary>
        /// <param name="context">The analysis context on which the callbacks are registered.</param>
        public override void Initialize(AnalysisContext context)
        {
            context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None);
            context.EnableConcurrentExecution();
            context.RegisterCompilationStartAction(startContext =>
            {
                // Without the attribute type there is nothing for this analyzer to look at.
                INamedTypeSymbol? comInterfaceAttribute = startContext.Compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComInterfaceAttribute);
                if (comInterfaceAttribute is null)
                {
                    return;
                }

                StubEnvironment stubEnvironment = new(
                    startContext.Compilation,
                    startContext.Compilation.GetEnvironmentFlags());

                // One cache per compilation, shared by all symbol callbacks, so that a base interface used by
                // several derived interfaces has its ComInterfaceInfo looked up by symbol instead of recomputed.
                ConcurrentDictionary<INamedTypeSymbol, DiagnosticOr<(ComInterfaceInfo, INamedTypeSymbol)>> infoCache =
                    new(SymbolEqualityComparer.Default);

                startContext.RegisterSymbolAction(symbolContext =>
                {
                    INamedTypeSymbol candidate = (INamedTypeSymbol)symbolContext.Symbol;
                    if (candidate.TypeKind != TypeKind.Interface)
                    {
                        return;
                    }

                    // The helper yields null both when the attribute is absent and when no interface
                    // declaration syntax can be located, so a single call covers both checks.
                    InterfaceDeclarationSyntax? declaration = FindInterfaceSyntaxWithAttribute(
                        candidate, comInterfaceAttribute, symbolContext.CancellationToken);
                    if (declaration is null)
                    {
                        return;
                    }

                    AnalyzeInterface(symbolContext, candidate, declaration, stubEnvironment, comInterfaceAttribute, infoCache);
                }, SymbolKind.NamedType);
            });
        }
 
        /// <summary>
        /// Reports diagnostics for one attributed interface in three stages: the interface info itself,
        /// the interface context derived from its ancestor chain, and the marshalling stubs of each method.
        /// </summary>
        /// <param name="context">Symbol analysis context used to report diagnostics and to obtain the cancellation token.</param>
        /// <param name="typeSymbol">The interface symbol being analyzed.</param>
        /// <param name="ifaceSyntax">The interface declaration syntax passed to <c>ComInterfaceInfo.From</c>.</param>
        /// <param name="env">The stub environment for the current compilation.</param>
        /// <param name="generatedComInterfaceAttrType">The attribute type symbol used to recognize base COM interfaces.</param>
        /// <param name="interfaceInfoCache">Cache of computed interface info, keyed by symbol.</param>
        private static void AnalyzeInterface(
            SymbolAnalysisContext context,
            INamedTypeSymbol typeSymbol,
            InterfaceDeclarationSyntax ifaceSyntax,
            StubEnvironment env,
            INamedTypeSymbol generatedComInterfaceAttrType,
            ConcurrentDictionary<INamedTypeSymbol, DiagnosticOr<(ComInterfaceInfo, INamedTypeSymbol)>> interfaceInfoCache)
        {
            CancellationToken cancellationToken = context.CancellationToken;

            // Stage 1: interface info, taken from the cache when another callback has already computed it.
            DiagnosticOr<(ComInterfaceInfo, INamedTypeSymbol)> infoResult = interfaceInfoCache.GetOrAdd(
                typeSymbol, _ => ComInterfaceInfo.From(typeSymbol, ifaceSyntax, env, cancellationToken));

            if (infoResult.HasDiagnostic)
            {
                Report(infoResult.Diagnostics);
            }

            if (!infoResult.HasValue)
            {
                return;
            }

            ComInterfaceInfo interfaceInfo = infoResult.Value.Item1;

            // Stage 2: interface context. The chain lists ancestors first and this interface last, which lets
            // GetContexts detect BaseInterfaceIsNotGenerated. Vtable indices do not have to be accurate because
            // only diagnostics are produced here, no code.
            ImmutableArray<ComInterfaceInfo> chain = BuildContextChain(
                typeSymbol, interfaceInfo, env, generatedComInterfaceAttrType, interfaceInfoCache, cancellationToken);
            ImmutableArray<DiagnosticOr<ComInterfaceContext>> chainContexts = ComInterfaceContext.GetContexts(chain, cancellationToken);

            // BuildContextChain always ends the chain with interfaceInfo, so there is at least one entry.
            Debug.Assert(chainContexts.Length > 0);

            // The final entry belongs to the interface under analysis.
            DiagnosticOr<ComInterfaceContext> ownContext = chainContexts[chainContexts.Length - 1];
            if (ownContext.HasDiagnostic)
            {
                Report(ownContext.Diagnostics);
                return;
            }

            // Stage 3: per-method diagnostics.
            foreach (DiagnosticOr<(ComMethodInfo ComMethod, IMethodSymbol Symbol)> methodResult in
                ComMethodInfo.GetMethodsFromInterface((interfaceInfo, typeSymbol), cancellationToken))
            {
                if (methodResult.HasDiagnostic)
                {
                    Report(methodResult.Diagnostics);
                }

                if (!methodResult.HasValue)
                {
                    continue;
                }

                (ComMethodInfo methodInfo, IMethodSymbol methodSymbol) = methodResult.Value;

                // Externally-defined method: there is no syntax, hence no stub diagnostics to report.
                if (methodInfo.Syntax is null)
                {
                    continue;
                }

                // The vtable index argument (0) is a placeholder; it does not need to be the real slot
                // because the stub is only used as a source of diagnostics.
                IncrementalMethodStubGenerationContext stubInfo = ComInterfaceGenerator.CalculateStubInformation(
                    methodInfo.Syntax,
                    methodSymbol,
                    0,
                    env,
                    interfaceInfo,
                    cancellationToken);

                if (stubInfo is not SourceAvailableIncrementalMethodStubGenerationContext sourceStub)
                {
                    continue;
                }

                MarshalDirection direction = sourceStub.VtableIndexData.Direction;
                ImmutableArray<DiagnosticInfo> outboundDiagnostics = ImmutableArray<DiagnosticInfo>.Empty;
                ImmutableArray<DiagnosticInfo> inboundDiagnostics = ImmutableArray<DiagnosticInfo>.Empty;

                if (direction is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)
                {
                    (_, outboundDiagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(sourceStub, ComInterfaceGeneratorHelpers.GetGeneratorResolver);
                }

                if (direction is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional)
                {
                    (_, inboundDiagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(sourceStub, ComInterfaceGeneratorHelpers.GetGeneratorResolver);
                }

                // Union drops entries that both directions produced, so each one is reported once
                // (matching original generator behavior).
                Report(outboundDiagnostics.Union(inboundDiagnostics));
            }

            // Converts each DiagnosticInfo to a Diagnostic and reports it on the symbol context.
            void Report(IEnumerable<DiagnosticInfo> diagnostics)
            {
                foreach (DiagnosticInfo info in diagnostics)
                {
                    context.ReportDiagnostic(info.ToDiagnostic());
                }
            }
        }
 
        /// <summary>
        /// Produces the list of interface infos handed to <see cref="ComInterfaceContext.GetContexts"/>:
        /// ancestors in root-to-parent order followed by the current interface as the final element.
        /// An ancestor whose info could not be computed ends the walk, so the chain contains only the
        /// ancestors below it; <see cref="ComInterfaceContext.GetContexts"/> is then expected to emit
        /// <see cref="GeneratorDiagnostics.BaseInterfaceIsNotGenerated"/> for the next derived interface.
        /// </summary>
        /// <param name="typeSymbol">The interface whose ancestor chain is built.</param>
        /// <param name="cii">The already-computed info for <paramref name="typeSymbol"/>; always the last element of the result.</param>
        /// <param name="env">The stub environment used when a base interface's info has to be computed.</param>
        /// <param name="generatedComInterfaceAttrType">The attribute type symbol used to recognize base COM interfaces.</param>
        /// <param name="interfaceInfoCache">Cache of computed interface info, keyed by symbol.</param>
        /// <param name="ct">Cancellation token passed to syntax lookups and info computation.</param>
        /// <returns>The chain, never empty because <paramref name="cii"/> is always appended.</returns>
        private static ImmutableArray<ComInterfaceInfo> BuildContextChain(
            INamedTypeSymbol typeSymbol,
            ComInterfaceInfo cii,
            StubEnvironment env,
            INamedTypeSymbol generatedComInterfaceAttrType,
            ConcurrentDictionary<INamedTypeSymbol, DiagnosticOr<(ComInterfaceInfo, INamedTypeSymbol)>> interfaceInfoCache,
            CancellationToken ct)
        {
            // When the bases live in other compilations, CreateInterfaceInfoForBaseInterfacesInOtherCompilations
            // already returns the complete ancestor chain, ordered from root to immediate parent.
            ImmutableArray<(ComInterfaceInfo, INamedTypeSymbol)> directExternalBases =
                ComInterfaceInfo.CreateInterfaceInfoForBaseInterfacesInOtherCompilations(typeSymbol);
            if (!directExternalBases.IsEmpty)
            {
                return [.. directExternalBases.Select(static pair => pair.Item1), cii];
            }

            // Same-compilation ancestors are gathered nearest-first while walking upward and reversed at the
            // end; ancestors from another assembly (already root-first) are kept apart and emitted before them.
            List<ComInterfaceInfo> externalAncestors = new();
            List<ComInterfaceInfo> localAncestorsNearestFirst = new();

            INamedTypeSymbol current = typeSymbol;
            while (FindBaseComInterfaceSymbol(current, generatedComInterfaceAttrType) is INamedTypeSymbol baseSymbol)
            {
                bool baseIsInSameAssembly = SymbolEqualityComparer.Default.Equals(baseSymbol.ContainingAssembly, typeSymbol.ContainingAssembly);
                if (!baseIsInSameAssembly)
                {
                    // From here upward the hierarchy is external; take the external chain of the
                    // last same-assembly interface and stop walking.
                    externalAncestors.AddRange(
                        ComInterfaceInfo.CreateInterfaceInfoForBaseInterfacesInOtherCompilations(current)
                            .Select(static pair => pair.Item1));
                    break;
                }

                DiagnosticOr<(ComInterfaceInfo, INamedTypeSymbol)> baseInfo =
                    interfaceInfoCache.GetOrAdd(baseSymbol, symbol => ComputeBaseInfo(symbol));

                // A failed base ends the walk; GetContexts will report BaseInterfaceIsNotGenerated for this interface.
                if (!baseInfo.HasValue)
                {
                    break;
                }

                localAncestorsNearestFirst.Add(baseInfo.Value.Item1);
                current = baseSymbol;
            }

            ImmutableArray<ComInterfaceInfo>.Builder chain = ImmutableArray.CreateBuilder<ComInterfaceInfo>(
                externalAncestors.Count + localAncestorsNearestFirst.Count + 1);
            chain.AddRange(externalAncestors);
            for (int index = localAncestorsNearestFirst.Count - 1; index >= 0; index--)
            {
                chain.Add(localAncestorsNearestFirst[index]);
            }

            chain.Add(cii);
            return chain.ToImmutable();

            // Cache factory: computes the info for a base interface, or yields a CannotAnalyzeInterfacePattern
            // diagnostic when no interface declaration syntax can be found for it.
            DiagnosticOr<(ComInterfaceInfo, INamedTypeSymbol)> ComputeBaseInfo(INamedTypeSymbol symbol)
            {
                InterfaceDeclarationSyntax? declaration = FindInterfaceSyntaxWithAttribute(symbol, generatedComInterfaceAttrType, ct);
                if (declaration is not null)
                {
                    return ComInterfaceInfo.From(symbol, declaration, env, ct);
                }

                Location location = symbol.Locations.FirstOrDefault() ?? Location.None;
                return DiagnosticOr<(ComInterfaceInfo, INamedTypeSymbol)>.From(
                    DiagnosticInfo.Create(GeneratorDiagnostics.CannotAnalyzeInterfacePattern, location, symbol.Name));
            }
        }
 
        /// <summary>
        /// Returns the first entry of <paramref name="typeSymbol"/>'s directly declared interfaces that carries
        /// the <see cref="TypeNames.GeneratedComInterfaceAttribute"/>, or <see langword="null"/> when none does.
        /// </summary>
        /// <param name="typeSymbol">The interface whose direct base interfaces are inspected.</param>
        /// <param name="generatedComInterfaceAttrType">The attribute type symbol to look for.</param>
        private static INamedTypeSymbol? FindBaseComInterfaceSymbol(INamedTypeSymbol typeSymbol, INamedTypeSymbol generatedComInterfaceAttrType)
        {
            return typeSymbol.Interfaces.FirstOrDefault(
                baseInterface => baseInterface.GetAttributes().Any(
                    attribute => SymbolEqualityComparer.Default.Equals(attribute.AttributeClass, generatedComInterfaceAttrType)));
        }
 
        /// <summary>
        /// Locates the <see cref="InterfaceDeclarationSyntax"/> of <paramref name="symbol"/> on which the
        /// <see cref="TypeNames.GeneratedComInterfaceAttribute"/> is applied. For a partial interface this is
        /// the partial declaration the attribute is written on; when that cannot be determined, the first
        /// interface declaration among the symbol's declaring syntax references is used instead.
        /// </summary>
        /// <param name="symbol">The interface symbol to inspect.</param>
        /// <param name="generatedComInterfaceAttrType">The attribute type symbol to look for.</param>
        /// <param name="ct">Cancellation token passed to the syntax lookups.</param>
        /// <returns>
        /// The declaration syntax, or <see langword="null"/> when the symbol does not carry the attribute
        /// or no interface declaration syntax is found.
        /// </returns>
        private static InterfaceDeclarationSyntax? FindInterfaceSyntaxWithAttribute(
            INamedTypeSymbol symbol,
            INamedTypeSymbol generatedComInterfaceAttrType,
            CancellationToken ct)
        {
            // Only the first matching attribute is considered.
            AttributeData? comInterfaceAttribute = symbol.GetAttributes().FirstOrDefault(
                attribute => SymbolEqualityComparer.Default.Equals(attribute.AttributeClass, generatedComInterfaceAttrType));
            if (comInterfaceAttribute is null)
            {
                return null;
            }

            // Preferred: climb from the attribute to its declaration.
            // Syntax shape: AttributeSyntax -> AttributeListSyntax -> InterfaceDeclarationSyntax.
            if (comInterfaceAttribute.ApplicationSyntaxReference?.GetSyntax(ct).Parent?.Parent is InterfaceDeclarationSyntax attributedDeclaration)
            {
                return attributedDeclaration;
            }

            // Fallback: the first declaring reference that resolves to an interface declaration.
            // The query is lazy, so syntax is only fetched until a match is found.
            return symbol.DeclaringSyntaxReferences
                .Select(reference => reference.GetSyntax(ct))
                .OfType<InterfaceDeclarationSyntax>()
                .FirstOrDefault();
        }
    }
}