File: src\libraries\System.Runtime.InteropServices\gen\Common\ConvertToSourceGeneratedInteropFixer.cs
Web Access
Project: src\src\libraries\System.Runtime.InteropServices\gen\LibraryImportGenerator\LibraryImportGenerator.csproj (Microsoft.Interop.LibraryImportGenerator)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Runtime.InteropServices;
 
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Editing;
 
namespace Microsoft.Interop.Analyzers
{
    public abstract class ConvertToSourceGeneratedInteropFixer : CodeFixProvider
    {
        /// <summary>
        /// Returns the shared <see cref="CustomFixAllProvider"/> instance used to apply this fixer across a "fix all" scope.
        /// </summary>
        public sealed override FixAllProvider GetFixAllProvider()
        {
            return CustomFixAllProvider.Instance;
        }
 
        /// <summary>
        /// The fixer-specific prefix of every equivalence key registered by this fixer.
        /// The serialized options of a fix are appended to this value to form the full equivalence key.
        /// </summary>
        protected abstract string BaseEquivalenceKey { get; }
 
        /// <summary>
        /// Enumerates every fix that should be offered for <paramref name="node"/> given the options of a diagnostic.
        /// </summary>
        /// <param name="node">The syntax node the diagnostic was reported on.</param>
        /// <param name="options">The options parsed from the diagnostic.</param>
        /// <returns>The fixes to register; the default implementation yields exactly one.</returns>
        protected virtual IEnumerable<ConvertToSourceGeneratedInteropFix> CreateAllFixesForDiagnosticOptions(SyntaxNode node, ImmutableDictionary<string, Option> options)
        {
            // Unless a derived fixer says otherwise, the diagnostic's own options describe the one and only fix.
            Func<DocumentEditor, CancellationToken, Task> applyFix = CreateFixForSelectedOptions(node, options);
            yield return new ConvertToSourceGeneratedInteropFix(applyFix, options);
        }
 
        /// <summary>
        /// Creates the callback that applies the fix to <paramref name="node"/> for the given options.
        /// </summary>
        /// <param name="node">The syntax node found at the diagnostic's span.</param>
        /// <param name="selectedOptions">The options selected for this particular fix.</param>
        /// <returns>
        /// A callback that is invoked with the <see cref="DocumentEditor"/> of the document being fixed and a cancellation token.
        /// </returns>
        protected abstract Func<DocumentEditor, CancellationToken, Task> CreateFixForSelectedOptions(SyntaxNode node, ImmutableDictionary<string, Option> selectedOptions);
 
        /// <summary>
        /// Gets the title used for the code action that applies a fix with the given options.
        /// </summary>
        /// <param name="selectedOptions">The options selected for the fix.</param>
        /// <returns>The code action title.</returns>
        protected abstract string GetDiagnosticTitle(ImmutableDictionary<string, Option> selectedOptions);
 
        /// <summary>
        /// A basic string-serializable option mechanism to enable the same fixer to be used for diagnostics with slightly different properties.
        /// </summary>
        /// <remarks>
        /// Options travel inside a code action's equivalence key. <see cref="CreateEquivalenceKeyFromOptions"/> appends one
        /// <c>;key=t:value;</c> segment per option to the base key (<c>t</c> is <c>b</c> for <see cref="Bool"/> and <c>s</c>
        /// for <see cref="String"/>), and <see cref="ParseOptionsFromEquivalenceKey"/> reads those segments back.
        /// Option keys and values that contain <c>;</c> or <c>=</c> do not round-trip.
        /// </remarks>
        public abstract record Option
        {
            // Only the nested option kinds declared in this file may derive from Option.
            private protected Option()
            {
            }

            /// <summary>Name of the option that records whether the fix must also enable unsafe code for the project.</summary>
            public const string AllowUnsafe = nameof(AllowUnsafe);

            /// <summary>Name of the option/diagnostic property that flags fixes which may need manual follow-up work.</summary>
            public const string MayRequireAdditionalWork = nameof(MayRequireAdditionalWork);

            /// <summary>A boolean option, serialized as <c>b:True;</c> or <c>b:False;</c>.</summary>
            public sealed record Bool(bool Value) : Option
            {
                public override string ToString() => $"b:{Value};";
            }

            /// <summary>A string option, serialized as <c>s:</c> followed by the value and a trailing <c>;</c>.</summary>
            public sealed record String(string Value) : Option
            {
                public override string ToString() => $"s:{Value};";
            }

            /// <summary>
            /// Recovers the options that <see cref="CreateEquivalenceKeyFromOptions"/> encoded into an equivalence key.
            /// </summary>
            /// <param name="equivalenceKey">The full equivalence key: the base key followed by the serialized options.</param>
            /// <returns>The parsed options. Segments that are not well-formed options are ignored.</returns>
            /// <exception cref="FormatException">A <c>b:</c> segment does not contain a valid boolean.</exception>
            public static ImmutableDictionary<string, Option> ParseOptionsFromEquivalenceKey(string equivalenceKey)
            {
                ImmutableDictionary<string, Option>.Builder options = ImmutableDictionary.CreateBuilder<string, Option>();

                // The first ';' separates the base equivalence key from the options, so the first segment
                // is the base key itself and must never be interpreted as an option.
                foreach (string option in equivalenceKey.Split(';').Skip(1))
                {
                    string[] optionKeyAndValue = option.Split('=');

                    // A usable segment is "key=t:value" with a non-empty value. This also skips the empty
                    // segments produced by the trailing ';' that each serialized option carries.
                    if (optionKeyAndValue.Length != 2 || optionKeyAndValue[1].Length < 3)
                    {
                        continue;
                    }

                    string type = optionKeyAndValue[1].Substring(0, 2);
                    string value = optionKeyAndValue[1].Substring(2);
                    if (type == "b:")
                    {
                        options.Add(optionKeyAndValue[0], new Bool(bool.Parse(value)));
                    }
                    else if (type == "s:")
                    {
                        options.Add(optionKeyAndValue[0], new String(value));
                    }
                }

                return options.ToImmutable();
            }

            /// <summary>
            /// Builds an equivalence key by appending every option to <paramref name="baseEquivalenceKey"/>.
            /// </summary>
            /// <param name="baseEquivalenceKey">The fixer-specific prefix of the key.</param>
            /// <param name="options">The options to encode.</param>
            /// <returns>The base key followed by one <c>;key=value</c> segment per option, ordered by option key.</returns>
            public static string CreateEquivalenceKeyFromOptions(string baseEquivalenceKey, ImmutableDictionary<string, Option> options)
            {
                StringBuilder equivalenceKeyBuilder = new(baseEquivalenceKey);

                // Sort by key so the same set of options always produces the same equivalence key.
                foreach (var option in options.OrderBy(item => item.Key))
                {
                    equivalenceKeyBuilder.Append($";{option.Key}={option.Value}");
                }
                return equivalenceKeyBuilder.ToString();
            }
        }
 
        /// <summary>
        /// Produces the fixer options that apply to the given diagnostic.
        /// </summary>
        /// <param name="diagnostic">The diagnostic being fixed.</param>
        /// <returns>The options to use when creating fixes for <paramref name="diagnostic"/>.</returns>
        protected abstract ImmutableDictionary<string, Option> ParseOptionsFromDiagnostic(Diagnostic diagnostic);
 
        /// <summary>
        /// Combines the options of a "fix all" operation with the options of one diagnostic in its scope,
        /// producing the options used to fix that diagnostic.
        /// </summary>
        /// <param name="fixAllOptions">The options parsed from the "fix all" code action's equivalence key.</param>
        /// <param name="diagnosticOptions">The options produced for the individual diagnostic.</param>
        /// <returns>The options to use for the individual fix.</returns>
        protected abstract ImmutableDictionary<string, Option> CombineOptions(ImmutableDictionary<string, Option> fixAllOptions, ImmutableDictionary<string, Option> diagnosticOptions);
 
        /// <summary>
        /// Computes the options used to fix a single diagnostic as part of a "fix all" operation.
        /// </summary>
        /// <param name="fixAllOptions">The options of the "fix all" operation.</param>
        /// <param name="diagnostic">The diagnostic about to be fixed.</param>
        /// <returns>The "fix all" options combined with the diagnostic's own options.</returns>
        private ImmutableDictionary<string, Option> GetOptionsForIndividualFix(ImmutableDictionary<string, Option> fixAllOptions, Diagnostic diagnostic)
        {
            ImmutableDictionary<string, Option> diagnosticOptions = ParseOptionsFromDiagnostic(diagnostic);
            return CombineOptions(fixAllOptions, diagnosticOptions);
        }
 
        /// <summary>
        /// Applies a document-level fix and then enables unsafe code for the project that owns the document.
        /// </summary>
        /// <param name="solution">The solution to start from.</param>
        /// <param name="documentId">The document the fix is applied to.</param>
        /// <param name="documentBasedFix">The callback that performs the edit on the document's editor.</param>
        /// <param name="ct">Token used to cancel the operation.</param>
        /// <returns>The edited solution with unsafe code allowed in the document's project.</returns>
        private static async Task<Solution> ApplyActionAndEnableUnsafe(Solution solution, DocumentId documentId, Func<DocumentEditor, CancellationToken, Task> documentBasedFix, CancellationToken ct)
        {
            SolutionEditor solutionEditor = new(solution);

            // Run the fix against the target document through the solution-wide editor.
            DocumentEditor documentEditor = await solutionEditor.GetDocumentEditorAsync(documentId, ct).ConfigureAwait(false);
            await documentBasedFix(documentEditor, ct).ConfigureAwait(false);

            // Look the project up in the edited solution so the compilation options are changed on top of the edits.
            ProjectId owningProjectId = documentId.ProjectId;
            Solution editedSolution = solutionEditor.GetChangedSolution();
            return AddUnsafe(editedSolution, editedSolution.GetProject(owningProjectId));
        }
 
        /// <summary>
        /// Registers one code action per available fix for every diagnostic in <paramref name="context"/>.
        /// </summary>
        /// <remarks>
        /// When the project is a C# project that does not allow unsafe code, each registered action edits the
        /// solution so that it also turns on unsafe code, and the <see cref="Option.AllowUnsafe"/> option is
        /// added to the action's options. Otherwise the action only edits the document.
        /// </remarks>
        /// <param name="context">The code fix context supplied by the host.</param>
        public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context)
        {
            Document document = context.Document;
            SyntaxNode? syntaxRoot = await document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false);
            if (syntaxRoot is null)
            {
                return;
            }

            SyntaxNode targetNode = syntaxRoot.FindNode(context.Span);

            // True only for C# projects whose compilation options currently disallow unsafe code.
            bool mustEnableUnsafe = document.Project.CompilationOptions is CSharpCompilationOptions { AllowUnsafe: false };

            foreach (Diagnostic diagnostic in context.Diagnostics)
            {
                ImmutableDictionary<string, Option> diagnosticOptions = ParseOptionsFromDiagnostic(diagnostic);
                foreach (ConvertToSourceGeneratedInteropFix candidate in CreateAllFixesForDiagnosticOptions(targetNode, diagnosticOptions))
                {
                    CodeAction action;
                    if (!mustEnableUnsafe)
                    {
                        // Unsafe code is already permitted (or not applicable): a document-only edit suffices.
                        action = CodeAction.Create(
                            GetDiagnosticTitle(candidate.SelectedOptions),
                            async ct =>
                            {
                                DocumentEditor documentEditor = await DocumentEditor.CreateAsync(document, ct).ConfigureAwait(false);
                                await candidate.ApplyFix(documentEditor, ct).ConfigureAwait(false);
                                return documentEditor.GetChangedDocument();
                            },
                            Option.CreateEquivalenceKeyFromOptions(BaseEquivalenceKey, candidate.SelectedOptions));
                    }
                    else
                    {
                        // Record in the options (and therefore in the equivalence key) that this action enables unsafe code.
                        ImmutableDictionary<string, Option> optionsWithUnsafe = candidate.SelectedOptions.Add(Option.AllowUnsafe, new Option.Bool(true));

                        action = CodeAction.Create(
                            GetDiagnosticTitle(optionsWithUnsafe),
                            ct => ApplyActionAndEnableUnsafe(document.Project.Solution, document.Id, candidate.ApplyFix, ct),
                            Option.CreateEquivalenceKeyFromOptions(BaseEquivalenceKey, optionsWithUnsafe));
                    }

                    context.RegisterCodeFix(action, diagnostic);
                }
            }
        }
 
        /// <summary>
        /// A single fix that can be offered for a diagnostic.
        /// </summary>
        /// <param name="ApplyFix">The callback that performs the fix on a document editor.</param>
        /// <param name="SelectedOptions">The options this fix was created with; they determine the fix's title and equivalence key.</param>
        protected record struct ConvertToSourceGeneratedInteropFix(Func<DocumentEditor, CancellationToken, Task> ApplyFix, ImmutableDictionary<string, Option> SelectedOptions);
 
        /// <summary>
        /// "Fix all" provider that replays the owning fixer over every diagnostic in scope, using the options
        /// encoded in the equivalence key of the code action that triggered the operation.
        /// </summary>
        private sealed class CustomFixAllProvider : FixAllProvider
        {
            public static readonly CustomFixAllProvider Instance = new();

            /// <summary>
            /// Creates the code action that fixes all diagnostics in the scope of <paramref name="fixAllContext"/>.
            /// </summary>
            /// <param name="fixAllContext">The "fix all" context supplied by the host.</param>
            /// <returns>A solution-level code action carrying the same equivalence key as the triggering action.</returns>
            public override async Task<CodeAction?> GetFixAsync(FixAllContext fixAllContext)
            {
                ImmutableDictionary<string, Option> fixAllOptions = Option.ParseOptionsFromEquivalenceKey(fixAllContext.CodeActionEquivalenceKey);
                var fixer = (ConvertToSourceGeneratedInteropFixer)fixAllContext.CodeFixProvider;

                bool shouldEnableUnsafe = IsOptionEnabled(fixAllOptions, Option.AllowUnsafe);
                bool fixDiagnosticsNeedingAdditionalWork = IsOptionEnabled(fixAllOptions, Option.MayRequireAdditionalWork);

                ImmutableArray<Diagnostic> diagnostics = await fixAllContext.GetDiagnosticsInScopeAsync().ConfigureAwait(false);

                return CodeAction.Create(
                    fixer.GetDiagnosticTitle(fixAllOptions),
                    async ct =>
                    {
                        HashSet<Project> touchedProjects = new();
                        SolutionEditor solutionEditor = new(fixAllContext.Solution);

                        foreach (Diagnostic diagnostic in diagnostics)
                        {
                            // Don't fix any diagnostics that require additional work if the "fix all" command wasn't
                            // triggered from a location that was able to warn the user that additional work may be required.
                            if (!fixDiagnosticsNeedingAdditionalWork && RequiresAdditionalWork(diagnostic))
                            {
                                continue;
                            }

                            DocumentId documentId = solutionEditor.OriginalSolution.GetDocumentId(diagnostic.Location.SourceTree)!;
                            DocumentEditor documentEditor = await solutionEditor.GetDocumentEditorAsync(documentId, ct).ConfigureAwait(false);
                            SyntaxNode syntaxRoot = await diagnostic.Location.SourceTree.GetRootAsync(ct).ConfigureAwait(false);
                            SyntaxNode targetNode = syntaxRoot.FindNode(diagnostic.Location.SourceSpan);

                            ImmutableDictionary<string, Option> individualOptions = fixer.GetOptionsForIndividualFix(fixAllOptions, diagnostic);
                            Func<DocumentEditor, CancellationToken, Task> applyFix = fixer.CreateFixForSelectedOptions(targetNode, individualOptions);
                            await applyFix(documentEditor, ct).ConfigureAwait(false);

                            // Record this project as a project we need to allow unsafe blocks in.
                            touchedProjects.Add(solutionEditor.OriginalSolution.GetDocument(documentId).Project);
                        }

                        Solution result = solutionEditor.GetChangedSolution();
                        if (!shouldEnableUnsafe)
                        {
                            return result;
                        }

                        foreach (Project project in touchedProjects)
                        {
                            result = AddUnsafe(result, project);
                        }

                        return result;
                    },
                    equivalenceKey: fixAllContext.CodeActionEquivalenceKey);
            }

            /// <summary>Returns true when <paramref name="options"/> holds a boolean option named <paramref name="name"/> set to true.</summary>
            private static bool IsOptionEnabled(ImmutableDictionary<string, Option> options, string name)
            {
                return options.TryGetValue(name, out Option option) && option is Option.Bool { Value: true };
            }

            /// <summary>
            /// Returns true when the diagnostic's <see cref="Option.MayRequireAdditionalWork"/> property parses to true.
            /// A missing or unparsable property counts as false.
            /// </summary>
            private static bool RequiresAdditionalWork(Diagnostic diagnostic)
            {
                return diagnostic.Properties.TryGetValue(Option.MayRequireAdditionalWork, out string rawValue)
                    && bool.TryParse(rawValue, out bool parsedValue)
                    && parsedValue;
            }
        }
 
        /// <summary>
        /// Returns a solution in which <paramref name="project"/> is compiled with unsafe code allowed.
        /// </summary>
        /// <param name="solution">The solution to update.</param>
        /// <param name="project">The project whose id and current compilation options are used; its options must be C# options.</param>
        /// <returns>The updated solution.</returns>
        private static Solution AddUnsafe(Solution solution, Project project)
        {
            var currentOptions = (CSharpCompilationOptions)project.CompilationOptions;
            CSharpCompilationOptions unsafeOptions = currentOptions.WithAllowUnsafe(true);
            return solution.WithProjectCompilationOptions(project.Id, unsafeOptions);
        }
 
        /// <summary>
        /// Requests the <c>partial</c> modifier on every ancestor of <paramref name="syntax"/>.
        /// </summary>
        /// <param name="editor">The editor that records the replacements.</param>
        /// <param name="syntax">The node whose ancestors are updated; the node itself is not touched.</param>
        protected static void MakeNodeParentsPartial(DocumentEditor editor, SyntaxNode syntax)
        {
            SyntaxNode? ancestor = syntax.Parent;
            while (ancestor is not null)
            {
                editor.ReplaceNode(ancestor, (current, generator) =>
                {
                    DeclarationModifiers withPartial = generator.GetModifiers(current).WithPartial(true);
                    return generator.WithModifiers(current, withPartial);
                });

                ancestor = ancestor.Parent;
            }
        }
 
        /// <summary>
        /// Adds an explicit <c>MarshalAs(UnmanagedType.<paramref name="unmanagedTypeMemberIdentifier"/>)</c> attribute to every
        /// <see cref="bool"/> parameter and to a <see cref="bool"/> return value that does not already carry a <c>MarshalAs</c> attribute.
        /// </summary>
        /// <param name="generator">The syntax generator used to build and attach the attributes.</param>
        /// <param name="methodSymbol">The method whose parameters and return type are inspected.</param>
        /// <param name="generatedDeclaration">The declaration syntax to annotate.</param>
        /// <param name="unmanagedTypeMemberIdentifier">The name of the <c>UnmanagedType</c> member to reference in the attribute.</param>
        /// <returns>The declaration with the attributes added, or the input declaration when nothing needed to change.</returns>
        protected static SyntaxNode AddExplicitDefaultBoolMarshalling(SyntaxGenerator generator, IMethodSymbol methodSymbol, SyntaxNode generatedDeclaration, string unmanagedTypeMemberIdentifier)
        {
            foreach (IParameterSymbol parameter in methodSymbol.Parameters)
            {
                bool needsExplicitMarshalling = parameter.Type.SpecialType == SpecialType.System_Boolean
                    && !HasMarshalAsAttribute(parameter.GetAttributes());
                if (!needsExplicitMarshalling)
                {
                    continue;
                }

                // Look the parameter up in the current declaration: earlier iterations may already have replaced nodes.
                SyntaxNode parameterSyntax = generator.GetParameters(generatedDeclaration)[parameter.Ordinal];
                SyntaxNode annotatedParameter = generator.AddAttributes(
                    parameterSyntax,
                    CreateBoolMarshalAsAttribute(generator, unmanagedTypeMemberIdentifier));
                generatedDeclaration = generator.ReplaceNode(generatedDeclaration, parameterSyntax, annotatedParameter);
            }

            bool returnNeedsExplicitMarshalling = methodSymbol.ReturnType.SpecialType == SpecialType.System_Boolean
                && !HasMarshalAsAttribute(methodSymbol.GetReturnTypeAttributes());
            if (returnNeedsExplicitMarshalling)
            {
                generatedDeclaration = generator.AddReturnAttributes(
                    generatedDeclaration,
                    CreateBoolMarshalAsAttribute(generator, unmanagedTypeMemberIdentifier));
            }

            return generatedDeclaration;

            // True when any of the attributes is a MarshalAsAttribute.
            static bool HasMarshalAsAttribute(ImmutableArray<AttributeData> attributes)
                => attributes.Any(attribute => attribute.AttributeClass?.ToDisplayString() == TypeNames.System_Runtime_InteropServices_MarshalAsAttribute);

            // MarshalAs(UnmanagedType.<memberName>)
            static SyntaxNode CreateBoolMarshalAsAttribute(SyntaxGenerator syntaxGenerator, string memberName)
            {
                SyntaxNode unmanagedTypeMember = syntaxGenerator.MemberAccessExpression(
                    syntaxGenerator.DottedName(TypeNames.System_Runtime_InteropServices_UnmanagedType),
                    syntaxGenerator.IdentifierName(memberName));

                return syntaxGenerator.Attribute(
                    TypeNames.System_Runtime_InteropServices_MarshalAsAttribute,
                    syntaxGenerator.AttributeArgument(unmanagedTypeMember));
            }
        }
 
        /// <summary>
        /// Adds a <c>MarshalAs(UnmanagedType.Error)</c> return attribute when the method returns a struct whose name is
        /// <c>hr</c> or <c>hresult</c> (case-insensitive) and the return value has no <c>MarshalAs</c> attribute yet.
        /// </summary>
        /// <param name="generator">The syntax generator used to build and attach the attribute.</param>
        /// <param name="methodSymbol">The method whose return type is inspected.</param>
        /// <param name="generatedDeclaration">The declaration syntax to annotate.</param>
        /// <returns>The declaration with the attribute added, or the input declaration when the conditions are not met.</returns>
        protected static SyntaxNode AddHResultStructAsErrorMarshalling(SyntaxGenerator generator, IMethodSymbol methodSymbol, SyntaxNode generatedDeclaration)
        {
            // Only struct return types named like an HRESULT are candidates.
            if (methodSymbol.ReturnType is not { TypeKind: TypeKind.Struct } || !IsHResultLikeType(methodSymbol.ReturnType))
            {
                return generatedDeclaration;
            }

            // Respect marshalling the user already specified explicitly.
            bool alreadyMarshalled = methodSymbol.GetReturnTypeAttributes()
                .Any(attr => attr.AttributeClass?.ToDisplayString() == TypeNames.System_Runtime_InteropServices_MarshalAsAttribute);
            if (alreadyMarshalled)
            {
                return generatedDeclaration;
            }

            return generator.AddReturnAttributes(generatedDeclaration, CreateMarshalAsErrorAttribute(generator));

            // Matches on the simple type name only.
            static bool IsHResultLikeType(ITypeSymbol type)
            {
                string simpleName = type.Name;
                if (simpleName.Equals("hr", StringComparison.OrdinalIgnoreCase))
                {
                    return true;
                }

                return simpleName.Equals("hresult", StringComparison.OrdinalIgnoreCase);
            }

            // MarshalAs(UnmanagedType.Error)
            static SyntaxNode CreateMarshalAsErrorAttribute(SyntaxGenerator syntaxGenerator)
            {
                SyntaxNode errorMember = syntaxGenerator.MemberAccessExpression(
                    syntaxGenerator.DottedName(TypeNames.System_Runtime_InteropServices_UnmanagedType),
                    syntaxGenerator.IdentifierName(nameof(UnmanagedType.Error)));

                return syntaxGenerator.Attribute(
                    TypeNames.System_Runtime_InteropServices_MarshalAsAttribute,
                    syntaxGenerator.AttributeArgument(errorMember));
            }
        }
    }
}