|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Composition;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Text;
namespace Microsoft.Extensions.ExtraAnalyzers;
[ExportCodeFixProvider(LanguageNames.CSharp, Name = nameof(LegacyLoggingFixer))]
[Shared]
public sealed partial class LegacyLoggingFixer : CodeFixProvider
{
// mimics the definition from Microsoft.Extensions.Logging.Abstractions
internal enum LogLevel
{
Trace = 0,
Debug = 1,
Information = 2,
Warning = 3,
Error = 4,
Critical = 5,
None = 6,
}
// function pointers that can be patched by test code to exercise obscure failure paths
internal Func<Document, CancellationToken, Task<SyntaxNode?>> GetSyntaxRootAsync = (d, t) => d.GetSyntaxRootAsync(t);
internal Func<Document, CancellationToken, Task<SemanticModel?>> GetSemanticModelAsync = (d, t) => d.GetSemanticModelAsync(t);
internal Func<SemanticModel, SyntaxNode, CancellationToken, IOperation?> GetOperation = (sm, sn, t) => sm.GetOperation(sn, t);
internal Func<Compilation, string, INamedTypeSymbol?> GetTypeByMetadataName1 = (c, n) => c.GetTypeByMetadataName(n);
internal Func<Compilation, string, INamedTypeSymbol?> GetTypeByMetadataName2 = (c, n) => c.GetTypeByMetadataName(n);
internal Func<Compilation, string, INamedTypeSymbol?> GetTypeByMetadataName3 = (c, n) => c.GetTypeByMetadataName(n);
internal Func<SemanticModel, BaseMethodDeclarationSyntax, CancellationToken, IMethodSymbol?> GetDeclaredSymbol = (sm, m, t) => sm.GetDeclaredSymbol(m, t);
private const string LoggerMessageAttribute = "Microsoft.Extensions.Logging.LoggerMessageAttribute";
/// <inheritdoc/>
public override ImmutableArray<string> FixableDiagnosticIds => ImmutableArray.Create(DiagDescriptors.LegacyLogging.Id);
/// <inheritdoc/>
public override FixAllProvider? GetFixAllProvider() => null;
/// <inheritdoc/>
public override async Task RegisterCodeFixesAsync(CodeFixContext context)
{
var (invocationExpression, details) = await CheckIfCanFixAsync(context.Document, context.Span, context.CancellationToken).ConfigureAwait(false);
if (invocationExpression != null && details != null)
{
context.RegisterCodeFix(
CodeAction.Create(
title: Resources.GenerateStronglyTypedLoggingMethod,
createChangedSolution: cancellationToken => ApplyFixAsync(context.Document, invocationExpression, details, cancellationToken),
equivalenceKey: nameof(Resources.GenerateStronglyTypedLoggingMethod)),
context.Diagnostics);
}
}
internal async Task<(ExpressionSyntax? invocationExpression, FixDetails? details)>
CheckIfCanFixAsync(Document invocationDoc, TextSpan span, CancellationToken cancellationToken)
{
var root = await GetSyntaxRootAsync(invocationDoc, cancellationToken).ConfigureAwait(false);
if (root?.FindNode(span) is not ExpressionSyntax invocationExpression)
{
// shouldn't happen, we only get called for invocations
return (null, null);
}
var sm = await GetSemanticModelAsync(invocationDoc, cancellationToken).ConfigureAwait(false);
if (sm == null)
{
// shouldn't happen
return (null, null);
}
var comp = sm.Compilation;
var loggerExtensions = GetTypeByMetadataName1(comp, "Microsoft.Extensions.Logging.LoggerExtensions");
if (loggerExtensions == null)
{
// shouldn't happen, we only get called for methods on this type
return (null, null);
}
var invocationOp = GetOperation(sm, invocationExpression, cancellationToken) as IInvocationOperation;
if (invocationOp == null)
{
// shouldn't happen, we're dealing with an invocation expression
return (null, null);
}
var method = invocationOp.TargetMethod;
var details = new FixDetails(method, invocationOp, invocationDoc.Project.DefaultNamespace, invocationDoc.Project.Documents);
if (string.IsNullOrWhiteSpace(details.Message))
{
// can't auto-generate without a valid message string
return (null, null);
}
if (details.EventIdParamIndex >= 0)
{
// can't auto-generate the variants using event id
return (null, null);
}
if (string.IsNullOrWhiteSpace(details.Level))
{
// can't auto-generate without a valid level
return (null, null);
}
return (invocationExpression, details);
}
/// <summary>
/// Get the final name of the target method. If there's an existing method with the right
/// message, level, and argument types, we just use that. Otherwise, we create a new method.
/// </summary>
internal async Task<(string methodName, bool existing)> GetFinalTargetMethodNameAsync(
Document targetDoc,
ClassDeclarationSyntax targetClass,
Document invocationDoc,
ExpressionSyntax invocationExpression,
FixDetails details,
CancellationToken cancellationToken)
{
var invocationSM = (await invocationDoc.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false))!;
var invocationOp = (invocationSM.GetOperation(invocationExpression, cancellationToken) as IInvocationOperation)!;
var docEditor = await DocumentEditor.CreateAsync(targetDoc, cancellationToken).ConfigureAwait(false);
var sm = docEditor.SemanticModel;
var comp = sm.Compilation;
var logMethodAttribute = GetTypeByMetadataName2(comp, LoggerMessageAttribute);
if (logMethodAttribute is null)
{
// strange that we can't find the attribute, but supply a potential useful value instead
return (details.TargetMethodName, false);
}
var invocationArgList = MakeArgumentList(details, invocationOp);
var conflict = false;
var count = 2;
string methodName;
do
{
methodName = details.TargetMethodName;
if (conflict)
{
methodName = $"{methodName}{count}";
count++;
conflict = false;
}
foreach (var method in targetClass.Members.Where(m => m.IsKind(SyntaxKind.MethodDeclaration)).OfType<MethodDeclarationSyntax>())
{
var methodSymbol = GetDeclaredSymbol(sm, method, cancellationToken);
if (methodSymbol == null)
{
// hmmm, this shouldn't happen should it?
continue;
}
var matchName = method.Identifier.ToString() == methodName;
var matchParams = invocationArgList.Count == methodSymbol.Parameters.Length;
if (matchParams)
{
for (int i = 0; i < invocationArgList.Count; i++)
{
matchParams = invocationArgList[i].Equals(methodSymbol.Parameters[i].Type, SymbolEqualityComparer.Default);
if (!matchParams)
{
break;
}
}
}
if (matchName && matchParams)
{
conflict = true;
}
foreach (var methodAttr in methodSymbol.GetAttributes())
{
if (SymbolEqualityComparer.Default.Equals(methodAttr.AttributeClass, logMethodAttribute) &&
methodAttr.AttributeConstructor is not null)
{
var argLevel = GetLogMethodAttributeParameter(methodAttr.AttributeConstructor.Parameters, methodAttr.ConstructorArguments, "level");
if (!argLevel.HasValue)
{
break;
}
var level = (LogLevel)argLevel.Value.Value!;
var argMessage = GetLogMethodAttributeParameter(methodAttr.AttributeConstructor.Parameters, methodAttr.ConstructorArguments, "message");
if (!argMessage.HasValue)
{
break;
}
var message = argMessage.Value.Value!.ToString();
var matchMessage = message == details.Message;
var matchLevel = FixDetails.GetLogLevelName(level) == details.Level;
if (matchLevel && matchMessage && matchParams)
{
// found a match, use this one
return (method.Identifier.ToString(), true);
}
break;
}
}
}
}
while (conflict);
return (methodName, false);
}
private static TypedConstant? GetLogMethodAttributeParameter(
ImmutableArray<IParameterSymbol> attributeCtorParams,
ImmutableArray<TypedConstant> constructorArguments,
string paramName)
{
foreach (var param in attributeCtorParams)
{
if (param.Name == paramName)
{
foreach (var ctorArg in constructorArguments)
{
if (SymbolEqualityComparer.Default.Equals(ctorArg.Type, param.Type))
{
return ctorArg;
}
}
}
}
return null;
}
/// <summary>
/// Finds the class into which to create the logging method signature, or creates it if it doesn't exist.
/// </summary>
private static async Task<(Solution solution, ClassDeclarationSyntax declarationSyntax, Document document)>
GetOrMakeTargetClassAsync(Project proj, FixDetails details, CancellationToken cancellationToken)
{
while (true)
{
var comp = (await proj.GetCompilationAsync(cancellationToken).ConfigureAwait(false))!;
var allNodes = comp.SyntaxTrees.SelectMany(s => s.GetRoot().DescendantNodes());
var allClasses = allNodes.Where(d => d.IsKind(SyntaxKind.ClassDeclaration)).OfType<ClassDeclarationSyntax>();
foreach (var cl in allClasses)
{
var nspace = GetNamespace(cl);
if (nspace != details.TargetNamespace)
{
continue;
}
if (cl.Identifier.Text == details.TargetClassName)
{
return (proj.Solution, cl, proj.GetDocument(cl.SyntaxTree)!);
}
}
var text = $@"
#pragma warning disable CS8019
using Microsoft.Extensions.Logging;
using System;
#pragma warning restore CS8019
static partial class {details.TargetClassName}
{{
}}
";
if (!string.IsNullOrEmpty(details.TargetNamespace))
{
text = $@"
namespace {details.TargetNamespace}
{{
#pragma warning disable CS8019
using Microsoft.Extensions.Logging;
using System;
#pragma warning restore CS8019
static partial class {details.TargetClassName}
{{
}}
}}
";
}
proj = proj.AddDocument(details.TargetFilename, text).Project;
}
}
/// <summary>
/// Remaps an invocation expression to a new doc.
/// </summary>
private static async Task<(Document document, ExpressionSyntax expressionSyntax)>
RemapAsync(Solution sol, DocumentId docId, ExpressionSyntax invocationExpression)
{
var doc = sol.GetDocument(docId)!;
var root = await doc.GetSyntaxRootAsync().ConfigureAwait(false);
return (doc, (root!.FindNode(invocationExpression.Span) as ExpressionSyntax)!);
}
private static string GetNamespace(ClassDeclarationSyntax cl)
{
var ns = cl.Parent as BaseNamespaceDeclarationSyntax;
if (ns == null)
{
if (cl.Parent is not CompilationUnitSyntax)
{
// nested type, we don't do those
return "<+Invalid Namespace+>";
}
return string.Empty;
}
var nspace = ns.Name.ToString();
while (true)
{
ns = ns.Parent as BaseNamespaceDeclarationSyntax;
if (ns == null)
{
break;
}
nspace = $"{ns.Name}.{nspace}";
}
return nspace;
}
/// <summary>
/// Given a LoggerExtensions method invocation, produce a parameter list for the corresponding generated logging method.
/// </summary>
private static List<SyntaxNode> MakeParameterList(
FixDetails details,
IInvocationOperation invocationOp,
SyntaxGenerator gen)
{
var t = invocationOp.Arguments[0].Value.Type!;
var loggerType = gen.TypeExpression(t);
if (invocationOp.Parent?.Kind == OperationKind.ConditionalAccess)
{
loggerType = gen.TypeExpression(t.WithNullableAnnotation(NullableAnnotation.Annotated));
}
var loggerParam = gen.ParameterDeclaration("logger", loggerType);
if (loggerParam is ParameterSyntax parameterSyntax)
{
loggerParam = parameterSyntax.WithModifiers(SyntaxFactory.TokenList(SyntaxFactory.Token(SyntaxKind.ThisKeyword)));
}
var parameters = new List<SyntaxNode>
{
loggerParam
};
if (details.ExceptionParamIndex >= 0)
{
parameters.Add(gen.ParameterDeclaration("exception", gen.TypeExpression(invocationOp.Arguments[details.ExceptionParamIndex].Value.Type!)));
}
var index = 0;
if (details.InterpolationArgs != null)
{
foreach (var o in details.InterpolationArgs)
{
parameters.Add(gen.ParameterDeclaration(details.MessageArgs[index++], gen.TypeExpression(o.Type!)));
}
}
var paramsArg = invocationOp.Arguments[details.ArgsParamIndex];
if (paramsArg != null)
{
var arrayCreation = (IArrayCreationOperation)paramsArg.Value;
foreach (var e in arrayCreation.Initializer!.ElementValues)
{
var type = e.SemanticModel?.GetTypeInfo(e.Syntax).Type!;
string name;
if (index < details.MessageArgs.Count)
{
name = details.MessageArgs[index];
}
else
{
name = $"arg{index}";
}
parameters.Add(gen.ParameterDeclaration(name, gen.TypeExpression(type)));
index++;
}
}
return parameters;
}
/// <summary>
/// Given a LoggerExtensions method invocation, produce an argument list in the shape of a corresponding generated logging method.
/// </summary>
private static List<ITypeSymbol> MakeArgumentList(FixDetails details, IInvocationOperation invocationOp)
{
var args = new List<ITypeSymbol>
{
invocationOp.Arguments[0].Value.Type!
};
if (details.ExceptionParamIndex >= 0)
{
args.Add(invocationOp.Arguments[details.ExceptionParamIndex].Value.Type!);
}
if (details.InterpolationArgs != null)
{
foreach (var a in details.InterpolationArgs)
{
args.Add(a.Type!);
}
}
var paramsArg = invocationOp.Arguments[details.ArgsParamIndex];
if (paramsArg != null)
{
var arrayCreation = (IArrayCreationOperation)paramsArg.Value;
foreach (var e in arrayCreation.Initializer!.ElementValues)
{
foreach (var d in e.Descendants())
{
args.Add(d.Type!);
}
}
}
return args;
}
private static async Task<Solution> RewriteLoggingCallAsync(
Document doc,
ExpressionSyntax invocationExpression,
FixDetails details,
string methodName,
CancellationToken cancellationToken)
{
var solEditor = new SolutionEditor(doc.Project.Solution);
var docEditor = await solEditor.GetDocumentEditorAsync(doc.Id, cancellationToken).ConfigureAwait(false);
var sm = docEditor.SemanticModel;
var comp = sm.Compilation;
var gen = docEditor.Generator;
var invocation = sm.GetOperation(invocationExpression, cancellationToken) as IInvocationOperation;
var argList = new List<SyntaxNode>();
int index = 0;
SyntaxNode loggerSyntaxNode = null!;
foreach (var arg in invocation!.Arguments)
{
if ((index == details.MessageParamIndex) || (index == details.LogLevelParamIndex))
{
index++;
continue;
}
index++;
if (index == 1)
{
loggerSyntaxNode = arg.Syntax;
}
else
{
if (arg.ArgumentKind == ArgumentKind.ParamArray)
{
if (details.InterpolationArgs != null)
{
foreach (var a in details.InterpolationArgs)
{
argList.Add(a.Syntax.WithoutTrivia());
}
}
var arrayCreation = (IArrayCreationOperation)arg.Value;
foreach (var e in arrayCreation.Initializer!.ElementValues)
{
argList.Add(e.Syntax.WithoutTrivia());
}
}
else
{
argList.Add(arg.Syntax.WithoutTrivia());
}
}
}
var memberAccessExpression = gen.MemberAccessExpression(loggerSyntaxNode!, methodName);
var call = gen.InvocationExpression(memberAccessExpression, argList).WithTriviaFrom(invocationExpression);
if (invocationExpression.Parent!.IsKind(SyntaxKind.ConditionalAccessExpression))
{
invocationExpression = (ExpressionSyntax)invocationExpression.Parent;
}
docEditor.ReplaceNode(invocationExpression, call);
return solEditor.GetChangedSolution();
}
/// <summary>
/// Orchestrate all the work needed to fix an issue.
/// </summary>
private async Task<Solution> ApplyFixAsync(Document invocationDoc, ExpressionSyntax invocationExpression, FixDetails details, CancellationToken cancellationToken)
{
ClassDeclarationSyntax targetClass;
Document targetDoc;
Solution sol;
// stable id surviving across solution generations
var invocationDocId = invocationDoc.Id;
// get a reference to the class where to insert the logging method, creating it if necessary
(sol, targetClass, targetDoc) = await GetOrMakeTargetClassAsync(invocationDoc.Project, details, cancellationToken).ConfigureAwait(false);
// find the doc and invocation in the current solution
(invocationDoc, invocationExpression) = await RemapAsync(sol, invocationDocId, invocationExpression).ConfigureAwait(false);
// determine the final name of the logging method and whether we need to generate it or not
var (methodName, existing) = await GetFinalTargetMethodNameAsync(targetDoc, targetClass, invocationDoc, invocationExpression, details, cancellationToken).ConfigureAwait(false);
// if the target method doesn't already exist, go make it
if (!existing)
{
// generate the logging method signature in the target class
sol = await InsertLoggingMethodSignatureAsync(targetDoc, targetClass, invocationDoc, invocationExpression, details, cancellationToken).ConfigureAwait(false);
// find the doc and invocation in the current solution
(invocationDoc, invocationExpression) = await RemapAsync(sol, invocationDocId, invocationExpression).ConfigureAwait(false);
}
// rewrite the call site to invoke the generated logging method
sol = await RewriteLoggingCallAsync(invocationDoc, invocationExpression, details, methodName, cancellationToken).ConfigureAwait(false);
return sol;
}
private async Task<Solution> InsertLoggingMethodSignatureAsync(
Document targetDoc,
ClassDeclarationSyntax targetClass,
Document invocationDoc,
ExpressionSyntax invocationExpression,
FixDetails details,
CancellationToken cancellationToken)
{
var invocationSM = (await invocationDoc.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false))!;
var invocationOp = (invocationSM.GetOperation(invocationExpression, cancellationToken) as IInvocationOperation)!;
var solEditor = new SolutionEditor(targetDoc.Project.Solution);
var docEditor = await solEditor.GetDocumentEditorAsync(targetDoc.Id, cancellationToken).ConfigureAwait(false);
var sm = docEditor.SemanticModel;
var comp = sm.Compilation;
var gen = docEditor.Generator;
var logMethod = gen.MethodDeclaration(
details.TargetMethodName,
MakeParameterList(details, invocationOp, gen),
accessibility: Accessibility.Internal,
modifiers: DeclarationModifiers.Partial | DeclarationModifiers.Static);
var attrArgs = new[]
{
gen.LiteralExpression(CalcEventId(comp, targetClass, cancellationToken)),
gen.MemberAccessExpression(gen.TypeExpression(comp.GetTypeByMetadataName("Microsoft.Extensions.Logging.LogLevel")!), details.Level),
gen.LiteralExpression(details.Message),
};
var attr = gen.Attribute(LoggerMessageAttribute, attrArgs);
logMethod = gen.AddAttributes(logMethod, attr);
var line = SyntaxFactory.ParseLeadingTrivia($@"
");
logMethod = logMethod.WithLeadingTrivia(line);
docEditor.AddMember(targetClass, logMethod);
return solEditor.GetChangedSolution();
}
/// <summary>
/// Iterate through the existing methods in the target class
/// and look at any method annotated with [LoggerMessage],
/// get their event ids, and then return 1 larger than any event id
/// found.
/// </summary>
private int CalcEventId(Compilation comp, ClassDeclarationSyntax targetClass, CancellationToken cancellationToken)
{
var logMethodAttribute = GetTypeByMetadataName3(comp, LoggerMessageAttribute);
if (logMethodAttribute is null)
{
// strange we can't find the attribute, but supply a potential useful value instead
return targetClass.Members.Count + 1;
}
var max = 0;
var semanticModel = comp.GetSemanticModel(targetClass.SyntaxTree);
var targetClassSymbol = semanticModel.GetDeclaredSymbol(targetClass, cancellationToken);
if (targetClassSymbol is null || targetClassSymbol is IErrorTypeSymbol)
{
return max;
}
foreach (var methodSymbol in targetClassSymbol.GetMembers().Where(m => m.Kind == SymbolKind.Method).OfType<IMethodSymbol>())
{
foreach (var methodAttr in methodSymbol.GetAttributes())
{
if (SymbolEqualityComparer.Default.Equals(methodAttr.AttributeClass, logMethodAttribute) &&
methodAttr.AttributeConstructor is not null)
{
var arg = GetLogMethodAttributeParameter(methodAttr.AttributeConstructor.Parameters, methodAttr.ConstructorArguments, "eventId");
if (!arg.HasValue)
{
continue;
}
var eventId = (int)arg.Value.Value!;
if (eventId >= max)
{
max = eventId + 1;
}
}
}
}
return max;
}
}
|