File: Editing\SymbolEditorExtensions.cs
Web Access
Project: src\src\Workspaces\Core\Portable\Microsoft.CodeAnalysis.Workspaces.csproj (Microsoft.CodeAnalysis.Workspaces)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#nullable disable
 
using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
 
namespace Microsoft.CodeAnalysis.Editing;
 
public static class SymbolEditorExtensions
{
    /// <summary>
    /// Gets the reference to the declaration of the base or interface type as part of the symbol's declaration. 
    /// </summary>
    public static async Task<SyntaxNode> GetBaseOrInterfaceDeclarationReferenceAsync(
        this SymbolEditor editor,
        ISymbol symbol,
        ITypeSymbol baseOrInterfaceType,
        CancellationToken cancellationToken = default)
    {
        if (baseOrInterfaceType == null)
        {
            throw new ArgumentNullException(nameof(baseOrInterfaceType));
        }
 
        if (baseOrInterfaceType.TypeKind != TypeKind.Error)
        {
            baseOrInterfaceType = (ITypeSymbol)(await editor.GetCurrentSymbolAsync(baseOrInterfaceType, cancellationToken).ConfigureAwait(false));
        }
 
        // look for the base or interface declaration in all declarations of the symbol
        var currentDecls = await editor.GetCurrentDeclarationsAsync(symbol, cancellationToken).ConfigureAwait(false);
 
        foreach (var decl in currentDecls)
        {
            var doc = editor.OriginalSolution.GetDocument(decl.SyntaxTree);
            var model = await doc.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
            var gen = SyntaxGenerator.GetGenerator(doc);
 
            var typeRef = gen.GetBaseAndInterfaceTypes(decl).FirstOrDefault(r => model.GetTypeInfo(r, cancellationToken).Type.Equals(baseOrInterfaceType));
            if (typeRef != null)
            {
                return typeRef;
            }
        }
 
        return null;
    }
 
    /// <summary>
    /// Changes the base type of the symbol.
    /// </summary>
    public static async Task<ISymbol> SetBaseTypeAsync(
        this SymbolEditor editor,
        INamedTypeSymbol symbol,
        Func<SyntaxGenerator, SyntaxNode> getNewBaseType,
        CancellationToken cancellationToken = default)
    {
        var baseType = symbol.BaseType;
 
        if (baseType != null)
        {
            // find existing declaration of the base type
            var typeRef = await editor.GetBaseOrInterfaceDeclarationReferenceAsync(symbol, baseType, cancellationToken).ConfigureAwait(false);
            if (typeRef != null)
            {
                return await editor.EditOneDeclarationAsync(
                    symbol,
                    typeRef.GetLocation(),
                    (e, d) => e.ReplaceNode(typeRef, getNewBaseType(e.Generator)),
                    cancellationToken).ConfigureAwait(false);
            }
        }
 
        // couldn't find the existing reference to change, so add it to one of the declarations
        return await editor.EditOneDeclarationAsync(symbol, (e, decl) =>
        {
            var newBaseType = getNewBaseType(e.Generator);
            if (newBaseType != null)
            {
                e.ReplaceNode(decl, (d, g) => g.AddBaseType(d, newBaseType));
            }
        }, cancellationToken).ConfigureAwait(false);
    }
 
    /// <summary>
    /// Changes the base type of the symbol.
    /// </summary>
    public static Task<ISymbol> SetBaseTypeAsync(
        this SymbolEditor editor,
        INamedTypeSymbol symbol,
        ITypeSymbol newBaseType,
        CancellationToken cancellationToken = default)
    {
        return editor.SetBaseTypeAsync(symbol, g => newBaseType != null ? g.TypeExpression(newBaseType) : null, cancellationToken);
    }
}