|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Composition;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.CSharp.Formatting;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Shared.Utilities;
namespace Microsoft.CodeAnalysis.CSharp.CodeRefactorings.EnableNullable;
using static CSharpSyntaxTokens;
[ExportCodeRefactoringProvider(LanguageNames.CSharp, Name = PredefinedCodeRefactoringProviderNames.EnableNullable), Shared]
internal sealed partial class EnableNullableCodeRefactoringProvider : CodeRefactoringProvider
{
private static readonly Func<DirectiveTriviaSyntax, bool> s_isNullableDirectiveTriviaPredicate =
directive => directive.IsKind(SyntaxKind.NullableDirectiveTrivia);
[ImportingConstructor]
[SuppressMessage("RoslynDiagnosticsReliability", "RS0033:Importing constructor should be [Obsolete]", Justification = "Used in test code: https://github.com/dotnet/roslyn/issues/42814")]
public EnableNullableCodeRefactoringProvider()
{
}
public override async Task ComputeRefactoringsAsync(CodeRefactoringContext context)
{
var (document, textSpan, cancellationToken) = context;
if (!textSpan.IsEmpty)
return;
if (!ShouldOfferRefactoring(document.Project))
{
return;
}
var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var token = root.FindToken(textSpan.Start, findInsideTrivia: true);
if (token.IsKind(SyntaxKind.EndOfDirectiveToken))
token = root.FindToken(textSpan.Start - 1, findInsideTrivia: true);
if (token.Kind() is not (SyntaxKind.EnableKeyword or SyntaxKind.RestoreKeyword or SyntaxKind.DisableKeyword or SyntaxKind.NullableKeyword or SyntaxKind.HashToken) ||
token.Parent is not NullableDirectiveTriviaSyntax nullableDirectiveTrivia)
{
return;
}
context.RegisterRefactoring(new CustomCodeAction(
(purpose, _, cancellationToken) => EnableNullableReferenceTypesAsync(document.Project, purpose, cancellationToken)));
}
private static bool ShouldOfferRefactoring(Project project)
=> project is
{
ParseOptions: CSharpParseOptions { LanguageVersion: >= LanguageVersion.CSharp8 },
CompilationOptions.NullableContextOptions: NullableContextOptions.Disable,
};
private static async Task<Solution> EnableNullableReferenceTypesAsync(
Project project, CodeActionPurpose purpose, CancellationToken cancellationToken)
{
var solution = project.Solution;
var updatedDocumentRoots = await ProducerConsumer<(DocumentId documentId, SyntaxNode newRoot)>.RunParallelAsync(
source: project.Documents,
produceItems: static async (document, callback, _, cancellationToken) =>
{
if (await document.IsGeneratedCodeAsync(cancellationToken).ConfigureAwait(false))
return;
var updatedDocumentRoot = await EnableNullableReferenceTypesAsync(document, cancellationToken).ConfigureAwait(false);
callback((document.Id, updatedDocumentRoot));
},
args: 0,
cancellationToken).ConfigureAwait(false);
solution = solution.WithDocumentSyntaxRoots(updatedDocumentRoots);
if (purpose is CodeActionPurpose.Apply)
{
var compilationOptions = (CSharpCompilationOptions)project.CompilationOptions!;
solution = solution.WithProjectCompilationOptions(project.Id, compilationOptions.WithNullableContextOptions(NullableContextOptions.Enable));
}
return solution;
}
private static async Task<SyntaxNode> EnableNullableReferenceTypesAsync(Document document, CancellationToken cancellationToken)
{
var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var firstToken = GetFirstTokenOfInterest(root);
if (firstToken.IsKind(SyntaxKind.None))
{
// The document has no content, so it's fine to change the nullable context
return root;
}
// Update #nullable directives that already exist in the document
(root, firstToken) = RewriteExistingDirectives(root, firstToken);
// Update existing documents to retain their original semantics
//
// * Add '#nullable disable' if the document didn't specify other semantics
// * Remove leading '#nullable restore' (was '#nullable enable' prior to rewrite in the previous step)
// * Otherwise, leave existing '#nullable' directive since it will control the initial semantics for the document
return await DisableNullableReferenceTypesInExistingDocumentIfNecessaryAsync(document, root, firstToken, cancellationToken).ConfigureAwait(false);
}
private static (SyntaxNode root, SyntaxToken firstToken) RewriteExistingDirectives(SyntaxNode root, SyntaxToken firstToken)
{
var firstNonDirectiveToken = root.GetFirstToken();
var firstDirective = root.GetFirstDirective(s_isNullableDirectiveTriviaPredicate);
if (firstNonDirectiveToken.IsKind(SyntaxKind.None) && firstDirective is null)
{
// The document has no semantic content, and also has no nullable directives to update
return (root, firstToken);
}
// Update all prior nullable directives
var directives = new List<NullableDirectiveTriviaSyntax>();
for (var directive = firstDirective; directive is not null; directive = directive.GetNextDirective(s_isNullableDirectiveTriviaPredicate))
{
directives.Add((NullableDirectiveTriviaSyntax)directive);
}
var updatedRoot = root.ReplaceNodes(
directives,
(originalNode, rewrittenNode) =>
{
if (originalNode.SettingToken.IsKind(SyntaxKind.DisableKeyword))
{
// 'disable' keeps its meaning
return rewrittenNode;
}
if (originalNode.SettingToken.IsKind(SyntaxKind.RestoreKeyword))
{
return rewrittenNode.WithSettingToken(DisableKeyword.WithTriviaFrom(rewrittenNode.SettingToken));
}
if (originalNode.SettingToken.IsKind(SyntaxKind.EnableKeyword))
{
return rewrittenNode.WithSettingToken(RestoreKeyword.WithTriviaFrom(rewrittenNode.SettingToken));
}
Debug.Fail("Unexpected state?");
return rewrittenNode;
});
return (updatedRoot, GetFirstTokenOfInterest(updatedRoot));
}
private static async Task<SyntaxNode> DisableNullableReferenceTypesInExistingDocumentIfNecessaryAsync(Document document, SyntaxNode root, SyntaxToken firstToken, CancellationToken cancellationToken)
{
var options = await document.GetCSharpSyntaxFormattingOptionsAsync(cancellationToken).ConfigureAwait(false);
var newLine = SyntaxFactory.EndOfLine(options.NewLine);
// Add a new '#nullable disable' to the top of each file
if (!HasLeadingNullableDirective(root, out var leadingDirective))
{
var nullableDisableTrivia = SyntaxFactory.Trivia(SyntaxFactory.NullableDirectiveTrivia(DisableKeyword.WithPrependedLeadingTrivia(SyntaxFactory.ElasticSpace), isActive: true));
var existingTriviaList = firstToken.LeadingTrivia;
var insertionIndex = GetInsertionPoint(existingTriviaList);
return root.ReplaceToken(firstToken, firstToken.WithLeadingTrivia(existingTriviaList.InsertRange(insertionIndex, [nullableDisableTrivia, newLine, newLine])));
}
else if (leadingDirective.SettingToken.IsKind(SyntaxKind.RestoreKeyword) && leadingDirective.TargetToken.IsKind(SyntaxKind.None))
{
// Remove the leading `#nullable restore` directive because it's redundant. Since there is no
// RemoveTrivia call, we replace the trivia with an empty marker.
return root.ReplaceTrivia(leadingDirective.ParentTrivia, SyntaxFactory.ElasticMarker);
}
else
{
// No need to add a '#nullable disable' directive because the file already starts with an unconditional
// '#nullable' directive that will override it.
return root;
}
}
private static int GetInsertionPoint(SyntaxTriviaList list)
{
var insertionPoint = list.Count;
for (var i = list.Count - 1; i >= 0; i--)
{
switch (list[i].Kind())
{
case SyntaxKind.WhitespaceTrivia:
case SyntaxKind.EndOfLineTrivia:
case SyntaxKind.SingleLineCommentTrivia:
case SyntaxKind.MultiLineCommentTrivia:
continue;
case SyntaxKind.SingleLineDocumentationCommentTrivia:
case SyntaxKind.MultiLineDocumentationCommentTrivia:
// Insert before the documentation comment
insertionPoint = i;
continue;
default:
return insertionPoint;
}
}
return insertionPoint;
}
private static SyntaxToken GetFirstTokenOfInterest(SyntaxNode root)
{
var firstToken = root.GetFirstToken(includeDirectives: true);
if (firstToken.IsKind(SyntaxKind.None))
{
return firstToken;
}
if (firstToken.IsKind(SyntaxKind.HashToken) && firstToken.Parent.IsKind(SyntaxKind.RegionDirectiveTrivia))
{
// If the file starts with a #region/#endregion that contains no semantic content (e.g. just a file
// header), skip it.
var nextToken = firstToken.Parent.GetLastToken(includeDirectives: true).GetNextToken(includeDirectives: true);
if (nextToken.IsKind(SyntaxKind.HashToken) && nextToken.Parent.IsKind(SyntaxKind.EndRegionDirectiveTrivia))
{
firstToken = nextToken.Parent.GetLastToken(includeDirectives: true).GetNextToken(includeDirectives: true);
}
}
return firstToken;
}
private static bool HasLeadingNullableDirective(SyntaxNode root, [NotNullWhen(true)] out NullableDirectiveTriviaSyntax? leadingNullableDirective)
{
// A leading nullable directive is a '#nullable' directive which precedes any conditional directives ('#if')
// or code (non-trivia).
var firstRelevantDirective = root.GetFirstDirective(static directive => directive.Kind() is SyntaxKind.NullableDirectiveTrivia or SyntaxKind.IfDirectiveTrivia);
if (firstRelevantDirective is NullableDirectiveTriviaSyntax nullableDirective
&& nullableDirective.TargetToken.IsKind(SyntaxKind.None))
{
var firstSemanticToken = root.GetFirstToken();
if (firstSemanticToken.IsKind(SyntaxKind.None) || firstSemanticToken.SpanStart > nullableDirective.Span.End)
{
leadingNullableDirective = nullableDirective;
return true;
}
}
leadingNullableDirective = null;
return false;
}
private enum CodeActionPurpose
{
Preview,
Apply,
}
private sealed class CustomCodeAction(
Func<CodeActionPurpose, IProgress<CodeAnalysisProgress>, CancellationToken, Task<Solution>> createChangedSolution)
: CodeAction.SolutionChangeAction(
CSharpFeaturesResources.Enable_nullable_reference_types_in_project,
(progress, cancellationToken) => createChangedSolution(CodeActionPurpose.Apply, progress, cancellationToken),
nameof(CSharpFeaturesResources.Enable_nullable_reference_types_in_project))
{
private readonly Func<CodeActionPurpose, IProgress<CodeAnalysisProgress>, CancellationToken, Task<Solution>> _createChangedSolution = createChangedSolution;
protected override async Task<IEnumerable<CodeActionOperation>> ComputePreviewOperationsAsync(CancellationToken cancellationToken)
{
var changedSolution = await _createChangedSolution(CodeActionPurpose.Preview, CodeAnalysisProgress.None, cancellationToken).ConfigureAwait(false);
if (changedSolution is null)
return [];
return new CodeActionOperation[] { new ApplyChangesOperation(changedSolution) };
}
}
}
|