File: src\Analyzers\CSharp\CodeFixes\Iterator\CSharpChangeToIEnumerableCodeFixProvider.cs
Web Access
Project: src\src\CodeStyle\CSharp\CodeFixes\Microsoft.CodeAnalysis.CSharp.CodeStyle.Fixes.csproj (Microsoft.CodeAnalysis.CSharp.CodeStyle.Fixes)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Composition;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CodeFixes.Iterator;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.CSharp.CodeFixes.Iterator;
 
/// <summary>
/// Fixes CS1624 by changing the declared return type of an iterator member to <see cref="IEnumerable{T}"/>.
/// Supported members are methods, user-defined operators, local functions, and the property or indexer
/// that owns an accessor.
/// </summary>
[ExportCodeFixProvider(LanguageNames.CSharp, Name = PredefinedCodeFixProviderNames.ChangeReturnType), Shared]
[method: ImportingConstructor]
[method: SuppressMessage("RoslynDiagnosticsReliability", "RS0033:Importing constructor should be [Obsolete]", Justification = "Used in test code: https://github.com/dotnet/roslyn/issues/42814")]
internal sealed class CSharpChangeToIEnumerableCodeFixProvider() : AbstractIteratorCodeFixProvider
{
    /// <summary>
    /// CS1624: The body of 'x' cannot be an iterator block because 'y' is not an iterator interface type
    /// </summary>
    private const string CS1624 = nameof(CS1624);

    public override ImmutableArray<string> FixableDiagnosticIds => [CS1624];

    /// <summary>
    /// Builds a code action that rewrites the return type of the member declared at <paramref name="node"/>
    /// to a constructed <see cref="IEnumerable{T}"/>.
    /// </summary>
    /// <param name="root">Syntax root of <paramref name="document"/>.</param>
    /// <param name="node">
    /// The declaration the diagnostic was reported on: a method, operator, local function, or property/indexer accessor.
    /// </param>
    /// <param name="document">The document being fixed.</param>
    /// <param name="diagnostics">The CS1624 diagnostic (not inspected here).</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>
    /// The code action, or <see langword="null"/> in any of these cases: the declaration is not a method-like
    /// symbol with a non-void return type; the IEnumerable types cannot be found; no element type can be
    /// determined; or the syntax kind is not supported.
    /// </returns>
    protected override async Task<CodeAction?> GetCodeFixAsync(SyntaxNode root, SyntaxNode node, Document document, Diagnostic diagnostics, CancellationToken cancellationToken)
    {
        var model = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);

        // Accessors, operators and local functions all surface as IMethodSymbol as well as ordinary methods.
        if (model.GetDeclaredSymbol(node, cancellationToken) is not IMethodSymbol { ReturnsVoid: false } methodSymbol)
        {
            return null;
        }

        if (!TryGetIEnumerableSymbols(model, out var ienumerableSymbol, out var ienumerableGenericSymbol))
        {
            return null;
        }

        var oldReturnType = methodSymbol.ReturnType;
        var newReturnType = TryCreateIEnumerableReturnType(oldReturnType, ienumerableSymbol, ienumerableGenericSymbol);
        if (newReturnType is null)
        {
            return null;
        }

        var newRoot = TryReplaceReturnType(root, node, newReturnType.GenerateTypeSyntax());
        if (newRoot is null)
        {
            return null;
        }

        // Computed eagerly so the action's lambda captures a definitely non-null document.
        var newDocument = document.WithSyntaxRoot(newRoot);

        var title = string.Format(CSharpCodeFixesResources.Change_return_type_from_0_to_1,
            oldReturnType.ToMinimalDisplayString(model, node.SpanStart),
            newReturnType.ToMinimalDisplayString(model, node.SpanStart));

        return CodeAction.Create(title, _ => Task.FromResult(newDocument), title);
    }

    /// <summary>
    /// Picks the <see cref="IEnumerable{T}"/> instantiation that should replace <paramref name="type"/>.
    /// </summary>
    /// <returns>
    /// <c>IEnumerable&lt;type&gt;</c> when <paramref name="type"/> is not enumerable at all.
    /// <c>IEnumerable&lt;T&gt;</c> when <paramref name="type"/> is an enumerable with a single type argument
    /// <c>T</c>, or an array with element type <c>T</c>.
    /// Otherwise <see langword="null"/>; for example, <c>string</c> or a two-argument generic collection.
    /// </returns>
    private static INamedTypeSymbol? TryCreateIEnumerableReturnType(
        ITypeSymbol type,
        INamedTypeSymbol ienumerableSymbol,
        INamedTypeSymbol ienumerableGenericSymbol)
    {
        // A non-collection return type becomes the element type, e.g. 'int' -> 'IEnumerable<int>'.
        if (!type.InheritsFromOrEquals(ienumerableSymbol, includeInterfaces: true))
        {
            return ienumerableGenericSymbol.Construct(type);
        }

        var arity = type.GetArity();

        // Single-argument generic collections such as 'List<T>' keep their type argument.
        if (arity == 1)
        {
            return ienumerableGenericSymbol.Construct(type.GetTypeArguments().First());
        }

        // Arrays are non-generic named-type-wise but expose their element type directly.
        if (arity == 0 && type is IArrayTypeSymbol arrayType)
        {
            return ienumerableGenericSymbol.Construct(arrayType.ElementType);
        }

        // Other enumerable types: no element type is chosen, so no fix is offered.
        return null;
    }

    /// <summary>
    /// Returns a copy of <paramref name="root"/> in which the declaration at <paramref name="node"/> has its
    /// return type replaced by <paramref name="newReturnType"/>. Returns <see langword="null"/> when the
    /// declaration kind is not supported.
    /// </summary>
    private static SyntaxNode? TryReplaceReturnType(SyntaxNode root, SyntaxNode node, TypeSyntax newReturnType)
        => node switch
        {
            MethodDeclarationSyntax method
                => root.ReplaceNode(method, method.WithReturnType(newReturnType)),
            OperatorDeclarationSyntax @operator
                => root.ReplaceNode(@operator, @operator.WithReturnType(newReturnType)),
            LocalFunctionStatementSyntax localFunction
                => root.ReplaceNode(localFunction, localFunction.WithReturnType(newReturnType)),

            // An accessor's declared type lives on its owning property/indexer (accessor -> accessor list -> owner).
            { Parent.Parent: PropertyDeclarationSyntax property }
                => root.ReplaceNode(property, property.WithType(newReturnType)),
            { Parent.Parent: IndexerDeclarationSyntax indexer }
                => root.ReplaceNode(indexer, indexer.WithType(newReturnType)),

            _ => null,
        };

    /// <summary>
    /// Resolves the non-generic <see cref="IEnumerable"/> and open generic <see cref="IEnumerable{T}"/> symbols
    /// by metadata name from the model's compilation.
    /// </summary>
    /// <returns><see langword="true"/> when both symbols were found.</returns>
    private static bool TryGetIEnumerableSymbols(
        SemanticModel model,
        [NotNullWhen(true)] out INamedTypeSymbol? ienumerableSymbol,
        [NotNullWhen(true)] out INamedTypeSymbol? ienumerableGenericSymbol)
    {
        ienumerableSymbol = model.Compilation.GetTypeByMetadataName(typeof(IEnumerable).FullName!);
        ienumerableGenericSymbol = model.Compilation.GetTypeByMetadataName(typeof(IEnumerable<>).FullName!);

        return ienumerableGenericSymbol != null && ienumerableSymbol != null;
    }
}