// File: Analyzers\RuntimeComApiUsageWithSourceGeneratedComAnalyzer.cs
// Project: src\src\libraries\System.Runtime.InteropServices\gen\ComInterfaceGenerator\ComInterfaceGenerator.csproj (Microsoft.Interop.ComInterfaceGenerator)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.DotnetRuntime.Extensions;
using Microsoft.CodeAnalysis.Operations;
using static Microsoft.Interop.Analyzers.AnalyzerDiagnostics;
 
namespace Microsoft.Interop.Analyzers
{
    /// <summary>
    /// Flags code that mixes the runtime's built-in COM interop with source-generated COM:
    /// <list type="bullet">
    /// <item><description>calls to the built-in COM APIs on <c>Marshal</c> that involve a source-generated COM type, and</description></item>
    /// <item><description>conversions between <c>[ComImport]</c> types and source-generated COM types.</description></item>
    /// </list>
    /// </summary>
    [DiagnosticAnalyzer(LanguageNames.CSharp)]
    public class RuntimeComApiUsageWithSourceGeneratedComAnalyzer : DiagnosticAnalyzer
    {
        // Global analyzer-config key for the MSBuild property that permits casts between [ComImport] types and
        // source-generated COM types. The value is parsed with bool.TryParse; a missing or unparsable value means "disabled".
        private const string EnableGeneratedComInterfaceComImportInteropOption = "build_property.EnableGeneratedComInterfaceComImportInterop";

        public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics => ImmutableArray.Create(RuntimeComApisDoNotSupportSourceGeneratedCom, CastsBetweenRuntimeComAndSourceGeneratedComNotSupported);

        /// <inheritdoc />
        public override void Initialize(AnalysisContext context)
        {
            context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None);
            context.EnableConcurrentExecution();
            context.RegisterCompilationStartAction(context =>
            {
                INamedTypeSymbol? marshalType = context.Compilation.GetBestTypeByMetadataName(TypeNames.System_Runtime_InteropServices_Marshal);
                INamedTypeSymbol? generatedComClassAttribute = context.Compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComClassAttribute);
                INamedTypeSymbol? generatedComInterfaceAttribute = context.Compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComInterfaceAttribute);
                INamedTypeSymbol? comObjectType = context.Compilation.GetBestTypeByMetadataName(TypeNames.System_Runtime_InteropServices_Marshalling_ComObject);

                // Each recognizer answers "is this type a source-generated COM type?" for one of the markers that exists in this compilation.
                List<Func<ITypeSymbol, bool>> sourceGeneratedComRecognizers = new();
                if (generatedComClassAttribute is not null)
                {
                    sourceGeneratedComRecognizers.Add(type => type.GetAttributes().Any(attr => generatedComClassAttribute.Equals(attr.AttributeClass, SymbolEqualityComparer.Default)));
                }
                if (generatedComInterfaceAttribute is not null)
                {
                    sourceGeneratedComRecognizers.Add(type => type.GetAttributes().Any(attr => generatedComInterfaceAttribute.Equals(attr.AttributeClass, SymbolEqualityComparer.Default)));
                }
                if (comObjectType is not null)
                {
                    sourceGeneratedComRecognizers.Add(type => type.Equals(comObjectType, SymbolEqualityComparer.Default));
                }

                // Nothing to analyze without Marshal or without any way to recognize a source-generated COM type.
                if (marshalType is null || sourceGeneratedComRecognizers.Count == 0)
                {
                    return;
                }

                // Maps each Marshal method of interest to the lookups that extract the (type, location) pairs to check from a call to it.
                var methodsOfInterest = new Dictionary<ISymbol, ImmutableArray<Func<IInvocationOperation, (ITypeSymbol, Location)?>>>(SymbolEqualityComparer.Default);

                var firstArgumentTypeLookup = CreateArgumentTypeLookup(0);
                var firstArgumentTypeLookupOnly = ImmutableArray.Create(firstArgumentTypeLookup);

                // Members are enumerated by name instead of being indexed with [0]. A Marshal surface that does not declare
                // one of these methods is then skipped rather than crashing the analyzer with IndexOutOfRangeException.
                foreach (string methodName in new[] { "SetComObjectData", "GetComObjectData", "ReleaseComObject", "FinalReleaseComObject", "GetIUnknownForObject", "GetIDispatchForObject" })
                {
                    foreach (ISymbol method in marshalType.GetMembers(methodName))
                    {
                        // For these methods only the first argument's type is checked.
                        methodsOfInterest.Add(method, firstArgumentTypeLookupOnly);
                    }
                }

                foreach (ISymbol getTypedObjectForIUnknown in marshalType.GetMembers("GetTypedObjectForIUnknown"))
                {
                    // Only a typeof(...) expression passed as the second argument (ordinal 1) is checked.
                    methodsOfInterest.Add(getTypedObjectForIUnknown, ImmutableArray.Create(CreateTypeOfArgumentTypeLookup(1)));
                }

                foreach (var createAggregatedObject in marshalType.GetMembers("CreateAggregatedObject"))
                {
                    if (createAggregatedObject is IMethodSymbol { IsGenericMethod: true })
                    {
                        methodsOfInterest.Add(createAggregatedObject, ImmutableArray.Create(CreateTypeArgumentTypeLookup(0), CreateArgumentTypeLookup(1)));
                    }
                    else
                    {
                        methodsOfInterest.Add(createAggregatedObject, ImmutableArray.Create(CreateArgumentTypeLookup(1)));
                    }
                }

                foreach (var createWrapperOfType in marshalType.GetMembers("CreateWrapperOfType"))
                {
                    if (createWrapperOfType is IMethodSymbol { IsGenericMethod: true })
                    {
                        methodsOfInterest.Add(createWrapperOfType, ImmutableArray.Create(CreateTypeArgumentTypeLookup(0), CreateTypeArgumentTypeLookup(1), firstArgumentTypeLookup));
                    }
                    else
                    {
                        methodsOfInterest.Add(createWrapperOfType, ImmutableArray.Create(firstArgumentTypeLookup, CreateTypeOfArgumentTypeLookup(1)));
                    }
                }

                foreach (var getComInterfaceForObject in marshalType.GetMembers("GetComInterfaceForObject"))
                {
                    if (getComInterfaceForObject is IMethodSymbol { IsGenericMethod: true })
                    {
                        methodsOfInterest.Add(getComInterfaceForObject, ImmutableArray.Create(CreateTypeArgumentTypeLookup(0), CreateTypeArgumentTypeLookup(1), firstArgumentTypeLookup));
                    }
                    else
                    {
                        methodsOfInterest.Add(getComInterfaceForObject, ImmutableArray.Create(firstArgumentTypeLookup, CreateTypeOfArgumentTypeLookup(1)));
                    }
                }

                context.RegisterOperationAction(context =>
                {
                    var invocation = (IInvocationOperation)context.Operation;

                    if (!methodsOfInterest.TryGetValue(invocation.TargetMethod.OriginalDefinition, out ImmutableArray<Func<IInvocationOperation, (ITypeSymbol, Location)?>> discoverers))
                    {
                        return;
                    }

                    foreach (Func<IInvocationOperation, (ITypeSymbol, Location)?> discoverer in discoverers)
                    {
                        // A discoverer returns null when it has nothing to check (e.g. an inferred type argument).
                        if (discoverer(invocation) is (ITypeSymbol targetType, Location diagnosticLocation) && IsSourceGeneratedComType(targetType))
                        {
                            context.ReportDiagnostic(
                                Diagnostic.Create(
                                    RuntimeComApisDoNotSupportSourceGeneratedCom,
                                    diagnosticLocation,
                                    invocation.TargetMethod.ToMinimalDisplayString(invocation.SemanticModel, invocation.Syntax.SpanStart),
                                    targetType.ToMinimalDisplayString(invocation.SemanticModel, invocation.Syntax.SpanStart)));
                        }
                    }
                }, OperationKind.Invocation);

                bool enableGeneratedComInterfaceComImportInterop =
                    context.Options.AnalyzerConfigOptionsProvider.GlobalOptions.TryGetValue(EnableGeneratedComInterfaceComImportInteropOption, out string? interopOptionValue)
                    && bool.TryParse(interopOptionValue, out bool interopOptionEnabled)
                    && interopOptionEnabled;

                // May be null when this Marshal surface does not declare the method; it then never matches below.
                ISymbol? getObjectForIUnknown = marshalType.GetMembers("GetObjectForIUnknown").FirstOrDefault();

                context.RegisterOperationAction(context =>
                {
                    var conversion = (IConversionOperation)context.Operation;

                    // The recognizers require a non-null type, so conversions without a result type are skipped.
                    if (conversion.Type is not ITypeSymbol targetType)
                    {
                        return;
                    }

                    IOperation operand = UnwrapObjectConversion(conversion.Operand);

                    // Some operands, like the "null" literal, don't have a type.
                    // Those definitely aren't source-generated COM types.
                    ITypeSymbol? operandType = operand.Type;

                    // Source-generated COM type -> [ComImport] type: allowed only when the interop feature is enabled.
                    if (!enableGeneratedComInterfaceComImportInterop
                        && targetType is INamedTypeSymbol { IsComImport: true }
                        && operandType is not null
                        && IsSourceGeneratedComType(operandType))
                    {
                        context.ReportDiagnostic(
                            Diagnostic.Create(
                                CastsBetweenRuntimeComAndSourceGeneratedComNotSupported,
                                conversion.Syntax.GetLocation()));
                    }

                    if (IsSourceGeneratedComType(targetType))
                    {
                        // [ComImport] type -> source-generated COM type: allowed only when the interop feature is enabled.
                        bool fromComImport = !enableGeneratedComInterfaceComImportInterop
                            && operandType is INamedTypeSymbol { IsComImport: true };

                        // The returned value from Marshal.GetObjectForIUnknown will always be a built-in COM object, which can't be
                        // cast to a source-generated COM type, even with the interop feature enabled.
                        bool fromGetObjectForIUnknown = getObjectForIUnknown is not null
                            && operand is IInvocationOperation invocation
                            && invocation.TargetMethod.Equals(getObjectForIUnknown, SymbolEqualityComparer.Default);

                        if (fromComImport || fromGetObjectForIUnknown)
                        {
                            context.ReportDiagnostic(
                                Diagnostic.Create(
                                    CastsBetweenRuntimeComAndSourceGeneratedComNotSupported,
                                    conversion.Syntax.GetLocation()));
                        }
                    }
                }, OperationKind.Conversion);

                // True when any recognizer built above identifies the type as a source-generated COM type.
                bool IsSourceGeneratedComType(ITypeSymbol type) => sourceGeneratedComRecognizers.Any(recognizer => recognizer(type));

                // Looks through a single conversion to object, so that e.g. (T)(object)value is judged by value's own type.
                static IOperation UnwrapObjectConversion(IOperation operand)
                    => operand is IConversionOperation { Type.SpecialType: SpecialType.System_Object } objConversion ? objConversion.Operand : operand;

                // Reports the type of the argument at the given parameter ordinal, looking through any implicit conversion
                // (such as to object) so the caller's actual argument type is checked. Returns null when that type is unknown.
                static Func<IInvocationOperation, (ITypeSymbol Type, Location location)?> CreateArgumentTypeLookup(int ordinal) => invocation =>
                {
                    IOperation argument = invocation.GetArgumentByOrdinal(ordinal).Value;
                    if (argument is IConversionOperation conversion)
                    {
                        argument = conversion.Operand;
                    }

                    if (argument.Type is null)
                    {
                        return null;
                    }

                    return (argument.Type, argument.Syntax.GetLocation());
                };

                // Reports the method's type argument at the given ordinal, located at its explicit mention in source.
                // Returns null when the type arguments were inferred; the corresponding value argument is checked instead.
                static Func<IInvocationOperation, (ITypeSymbol Type, Location location)?> CreateTypeArgumentTypeLookup(int ordinal) => invocation =>
                {
                    if (invocation.Syntax is not InvocationExpressionSyntax invocationSyntax)
                    {
                        return null;
                    }

                    ExpressionSyntax callee = invocationSyntax.Expression;

                    // Marshal.Method<T1, T2>(...) -> Method<T1, T2>
                    if (callee.IsKind(SyntaxKind.SimpleMemberAccessExpression))
                    {
                        callee = ((MemberAccessExpressionSyntax)callee).Name;
                    }

                    if (callee is not GenericNameSyntax genericName)
                    {
                        // If we couldn't find the type argument in source, then it was inferred. Don't emit a warning for the inferred type argument.
                        // We'll emit a warning for the argument that was passed in instead.
                        return null;
                    }

                    return (invocation.TargetMethod.TypeArguments[ordinal], genericName.TypeArgumentList.Arguments[ordinal].GetLocation());
                };

                // Reports the operand of a typeof(...) expression passed at the given parameter ordinal; null for any other argument shape.
                static Func<IInvocationOperation, (ITypeSymbol Type, Location location)?> CreateTypeOfArgumentTypeLookup(int ordinal) => invocation => invocation.GetArgumentByOrdinal(ordinal).Value switch
                {
                    ITypeOfOperation { Syntax: TypeOfExpressionSyntax typeOfSyntax } typeOf => (typeOf.TypeOperand, typeOfSyntax.Type.GetLocation()),
                    _ => null
                };
            });
        }
    }
}