|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Simplification;
using Roslyn.Utilities;
#if CODE_STYLE
using DeclarationModifiers = Microsoft.CodeAnalysis.Internal.Editing.DeclarationModifiers;
#else
using DeclarationModifiers = Microsoft.CodeAnalysis.Editing.DeclarationModifiers;
#endif
namespace Microsoft.CodeAnalysis.Shared.Extensions;
internal static partial class SyntaxGeneratorExtensions
{
private const string EqualsName = "Equals";
private const string DefaultName = "Default";
private const string ObjName = "obj";
public const string OtherName = "other";
public static SyntaxNode CreateThrowNotImplementedStatement(
this SyntaxGenerator codeDefinitionFactory, Compilation compilation)
{
return codeDefinitionFactory.ThrowStatement(
CreateNewNotImplementedException(codeDefinitionFactory, compilation));
}
public static SyntaxNode CreateThrowNotImplementedExpression(
this SyntaxGenerator codeDefinitionFactory, Compilation compilation)
{
return codeDefinitionFactory.ThrowExpression(
CreateNewNotImplementedException(codeDefinitionFactory, compilation));
}
private static SyntaxNode CreateNewNotImplementedException(SyntaxGenerator codeDefinitionFactory, Compilation compilation)
{
var notImplementedExceptionTypeSyntax = compilation.NotImplementedExceptionType() is INamedTypeSymbol symbol
? codeDefinitionFactory.TypeExpression(symbol, addImport: false)
: codeDefinitionFactory.QualifiedName(codeDefinitionFactory.IdentifierName(nameof(System)), codeDefinitionFactory.IdentifierName(nameof(NotImplementedException)));
return codeDefinitionFactory.ObjectCreationExpression(
notImplementedExceptionTypeSyntax,
arguments: []);
}
public static ImmutableArray<SyntaxNode> CreateThrowNotImplementedStatementBlock(
this SyntaxGenerator codeDefinitionFactory, Compilation compilation)
=> [CreateThrowNotImplementedStatement(codeDefinitionFactory, compilation)];
public static ImmutableArray<SyntaxNode> CreateArguments(
this SyntaxGenerator factory,
ImmutableArray<IParameterSymbol> parameters)
{
return parameters.SelectAsArray(p => CreateArgument(factory, p));
}
private static SyntaxNode CreateArgument(
this SyntaxGenerator factory,
IParameterSymbol parameter)
{
return factory.Argument(parameter.RefKind, factory.IdentifierName(parameter.Name));
}
public static SyntaxNode GetDefaultEqualityComparer(
this SyntaxGenerator factory,
SyntaxGeneratorInternal generatorInternal,
Compilation compilation,
ITypeSymbol type)
{
var equalityComparerType = compilation.EqualityComparerOfTType();
var typeExpression = equalityComparerType == null
? factory.GenericName(nameof(EqualityComparer<int>), type)
: generatorInternal.Type(equalityComparerType.Construct(type), typeContext: false);
return factory.MemberAccessExpression(typeExpression, factory.IdentifierName(DefaultName));
}
private static ITypeSymbol GetType(Compilation compilation, ISymbol symbol)
=> symbol switch
{
IFieldSymbol field => field.Type,
IPropertySymbol property => property.Type,
_ => compilation.GetSpecialType(SpecialType.System_Object),
};
public static SyntaxNode IsPatternExpression(this SyntaxGeneratorInternal generator, SyntaxNode expression, SyntaxNode pattern)
=> generator.IsPatternExpression(expression, isToken: default, pattern);
/// <summary>
/// Generates a call to a method *through* an existing field or property symbol.
/// </summary>
/// <returns></returns>
public static SyntaxNode GenerateDelegateThroughMemberStatement(
this SyntaxGenerator generator, IMethodSymbol method, ISymbol throughMember)
{
var through = generator.MemberAccessExpression(
CreateDelegateThroughExpression(generator, method, throughMember),
method.IsGenericMethod
? generator.GenericName(method.Name, method.TypeArguments)
: generator.IdentifierName(method.Name));
var invocationExpression = generator.InvocationExpression(through, generator.CreateArguments(method.Parameters));
return method.ReturnsVoid
? generator.ExpressionStatement(invocationExpression)
: generator.ReturnStatement(invocationExpression);
}
public static SyntaxNode CreateDelegateThroughExpression(
this SyntaxGenerator generator, ISymbol member, ISymbol throughMember)
{
var name = generator.IdentifierName(throughMember.Name);
var through = throughMember.IsStatic
? GenerateContainerName(generator, throughMember)
// If we're delegating through a primary constructor parameter, we cannot qualify the name at all.
: throughMember is IParameterSymbol
? null
: generator.ThisExpression();
through = through is null ? name : generator.MemberAccessExpression(through, name);
var throughMemberType = throughMember.GetMemberType();
if (throughMemberType != null &&
member.ContainingType is { TypeKind: TypeKind.Interface } interfaceBeingImplemented)
{
// In the case of 'implement interface through field / property', we need to know what
// interface we are implementing so that we can insert casts to this interface on every
// usage of the field in the generated code. Without these casts we would end up generating
// code that fails compilation in certain situations.
//
// For example consider the following code.
// class C : IReadOnlyList<int> { int[] field; }
// When applying the 'implement interface through field' code fix in the above example,
// we need to generate the following code to implement the Count property on IReadOnlyList<int>
// class C : IReadOnlyList<int> { int[] field; int Count { get { ((IReadOnlyList<int>)field).Count; } ...}
// as opposed to the following code which will fail to compile (because the array field
// doesn't have a property named .Count) -
// class C : IReadOnlyList<int> { int[] field; int Count { get { field.Count; } ...}
//
// The 'InterfaceTypes' property on the state object always contains only one item
// in the case of C# i.e. it will contain exactly the interface we are trying to implement.
// This is also the case most of the time in the case of VB, except in certain error conditions
// (recursive / circular cases) where the span of the squiggle for the corresponding
// diagnostic (BC30149) changes and 'InterfaceTypes' ends up including all interfaces
// in the Implements clause. For the purposes of inserting the above cast, we ignore the
// uncommon case and optimize for the common one - in other words, we only apply the cast
// in cases where we can unambiguously figure out which interface we are trying to implement.
if (!throughMemberType.Equals(interfaceBeingImplemented))
{
through = generator.CastExpression(interfaceBeingImplemented,
through.WithAdditionalAnnotations(Simplifier.Annotation));
}
else if (throughMember is IPropertySymbol { IsStatic: false, ExplicitInterfaceImplementations: [var explicitlyImplementedProperty, ..] })
{
// If we are implementing through an explicitly implemented property, we need to cast 'this' to
// the explicitly implemented interface type before calling the member, as in:
// ((IA)this).Prop.Member();
//
var explicitImplementationCast = generator.CastExpression(
explicitlyImplementedProperty.ContainingType,
generator.ThisExpression());
through = generator.MemberAccessExpression(explicitImplementationCast,
generator.IdentifierName(explicitlyImplementedProperty.Name));
through = through.WithAdditionalAnnotations(Simplifier.Annotation);
}
}
return through.WithAdditionalAnnotations(Simplifier.Annotation);
// local functions
static SyntaxNode GenerateContainerName(SyntaxGenerator factory, ISymbol throughMember)
{
var classOrStructType = throughMember.ContainingType;
return classOrStructType.IsGenericType
? factory.GenericName(classOrStructType.Name, classOrStructType.TypeArguments)
: factory.IdentifierName(classOrStructType.Name);
}
}
public static ImmutableArray<SyntaxNode> GetGetAccessorStatements(
this SyntaxGenerator generator, Compilation compilation,
IPropertySymbol property, ISymbol? throughMember, bool preferAutoProperties)
{
if (throughMember != null)
{
var throughExpression = CreateDelegateThroughExpression(generator, property, throughMember);
var expression = property.IsIndexer
? throughExpression
: generator.MemberAccessExpression(
throughExpression, generator.IdentifierName(property.Name));
if (property.Parameters.Length > 0)
{
var arguments = generator.CreateArguments(property.Parameters);
expression = generator.ElementAccessExpression(expression, arguments);
}
return [generator.ReturnStatement(expression)];
}
return preferAutoProperties ? default : generator.CreateThrowNotImplementedStatementBlock(compilation);
}
public static ImmutableArray<SyntaxNode> GetSetAccessorStatements(
this SyntaxGenerator generator, Compilation compilation,
IPropertySymbol property, ISymbol? throughMember, bool preferAutoProperties)
{
if (throughMember != null)
{
var throughExpression = CreateDelegateThroughExpression(generator, property, throughMember);
var expression = property.IsIndexer
? throughExpression
: generator.MemberAccessExpression(
throughExpression, generator.IdentifierName(property.Name));
if (property.Parameters.Length > 0)
{
var arguments = generator.CreateArguments(property.Parameters);
expression = generator.ElementAccessExpression(expression, arguments);
}
expression = generator.AssignmentStatement(expression, generator.IdentifierName("value"));
return [generator.ExpressionStatement(expression)];
}
return preferAutoProperties
? default
: generator.CreateThrowNotImplementedStatementBlock(compilation);
}
private static bool TryGetValue(IDictionary<string, string>? dictionary, string key, [NotNullWhen(true)] out string? value)
{
value = null;
return
dictionary != null &&
dictionary.TryGetValue(key, out value);
}
private static bool TryGetValue(IDictionary<string, ISymbol>? dictionary, string key, [NotNullWhen(true)] out string? value)
{
value = null;
if (dictionary != null && dictionary.TryGetValue(key, out var symbol))
{
value = symbol.Name;
return true;
}
return false;
}
public static ImmutableArray<ISymbol> CreateFieldsForParameters(
ImmutableArray<IParameterSymbol> parameters, ImmutableDictionary<string, string>? parameterToNewFieldMap, bool isContainedInUnsafeType)
{
using var _ = ArrayBuilder<ISymbol>.GetInstance(out var result);
foreach (var parameter in parameters)
{
// For non-out parameters, create a field and assign the parameter to it.
if (parameter.RefKind != RefKind.Out &&
TryGetValue(parameterToNewFieldMap, parameter.Name, out var fieldName))
{
result.Add(CodeGenerationSymbolFactory.CreateFieldSymbol(
attributes: default,
accessibility: Accessibility.Private,
modifiers: new DeclarationModifiers(isUnsafe: !isContainedInUnsafeType && parameter.RequiresUnsafeModifier()),
type: parameter.Type,
name: fieldName));
}
}
return result.ToImmutableAndClear();
}
public static ImmutableArray<ISymbol> CreatePropertiesForParameters(
ImmutableArray<IParameterSymbol> parameters, ImmutableDictionary<string, string>? parameterToNewPropertyMap, bool isContainedInUnsafeType)
{
using var _ = ArrayBuilder<ISymbol>.GetInstance(out var result);
foreach (var parameter in parameters)
{
// For non-out parameters, create a property and assign the parameter to it.
if (parameter.RefKind != RefKind.Out &&
TryGetValue(parameterToNewPropertyMap, parameter.Name, out var propertyName))
{
result.Add(CodeGenerationSymbolFactory.CreatePropertySymbol(
attributes: default,
accessibility: Accessibility.Public,
modifiers: new DeclarationModifiers(isUnsafe: !isContainedInUnsafeType && parameter.RequiresUnsafeModifier()),
type: parameter.Type,
refKind: RefKind.None,
explicitInterfaceImplementations: [],
name: propertyName,
parameters: [],
getMethod: CodeGenerationSymbolFactory.CreateAccessorSymbol(
attributes: default,
accessibility: default,
statements: default),
setMethod: null));
}
}
return result.ToImmutableAndClear();
}
public static ImmutableArray<SyntaxNode> CreateAssignmentStatements(
this SyntaxGenerator factory,
SyntaxGeneratorInternal generatorInternal,
SemanticModel semanticModel,
ImmutableArray<IParameterSymbol> parameters,
IDictionary<string, ISymbol>? parameterToExistingFieldMap,
IDictionary<string, string>? parameterToNewFieldMap,
bool addNullChecks,
bool preferThrowExpression)
{
var nullCheckStatements = ArrayBuilder<SyntaxNode>.GetInstance();
var assignStatements = ArrayBuilder<SyntaxNode>.GetInstance();
foreach (var parameter in parameters)
{
var refKind = parameter.RefKind;
var parameterType = parameter.Type;
var parameterName = parameter.Name;
if (refKind == RefKind.Out)
{
// If it's an out param, then don't create a field for it. Instead, assign
// the default value for that type (i.e. "default(...)") to it.
var assignExpression = factory.AssignmentStatement(
factory.IdentifierName(parameterName),
factory.DefaultExpression(parameterType));
var statement = factory.ExpressionStatement(assignExpression);
assignStatements.Add(statement);
}
else
{
// For non-out parameters, create a field and assign the parameter to it.
// TODO: I'm not sure that's what we really want for ref parameters.
if (TryGetValue(parameterToExistingFieldMap, parameterName, out var fieldName) ||
TryGetValue(parameterToNewFieldMap, parameterName, out fieldName))
{
var fieldAccess = factory.MemberAccessExpression(factory.ThisExpression(), factory.IdentifierName(fieldName))
.WithAdditionalAnnotations(Simplifier.Annotation);
factory.AddAssignmentStatements(
generatorInternal,
semanticModel, parameter, fieldAccess,
addNullChecks, preferThrowExpression,
nullCheckStatements, assignStatements);
}
}
}
return nullCheckStatements.ToImmutableAndFree().Concat(assignStatements.ToImmutableAndFree());
}
public static void AddAssignmentStatements(
this SyntaxGenerator factory,
SyntaxGeneratorInternal generatorInternal,
SemanticModel semanticModel,
IParameterSymbol parameter,
SyntaxNode fieldAccess,
bool addNullChecks,
bool preferThrowExpression,
ArrayBuilder<SyntaxNode> nullCheckStatements,
ArrayBuilder<SyntaxNode> assignStatements)
{
// Don't want to add a null check for something of the form `int?`. The type was
// already declared as nullable to indicate that null is ok. Adding a null check
// just disallows something that should be allowed.
var shouldAddNullCheck = addNullChecks && parameter.Type.CanAddNullCheck() && !parameter.Type.IsNullable();
if (shouldAddNullCheck && preferThrowExpression && generatorInternal.SupportsThrowExpression())
{
// Generate: this.x = x ?? throw ...
assignStatements.Add(CreateAssignWithNullCheckStatement(
factory, semanticModel.Compilation, parameter, fieldAccess));
}
else
{
if (shouldAddNullCheck)
{
// generate: if (x == null) throw ...
nullCheckStatements.Add(
factory.CreateNullCheckAndThrowStatement(generatorInternal, semanticModel, parameter));
}
// generate: this.x = x;
assignStatements.Add(
factory.ExpressionStatement(
factory.AssignmentStatement(
fieldAccess,
factory.IdentifierName(parameter.Name))));
}
}
public static SyntaxNode CreateAssignWithNullCheckStatement(
this SyntaxGenerator factory, Compilation compilation, IParameterSymbol parameter, SyntaxNode fieldAccess)
{
return factory.ExpressionStatement(factory.AssignmentStatement(
fieldAccess,
factory.CoalesceExpression(
factory.IdentifierName(parameter.Name),
factory.CreateThrowArgumentNullExpression(compilation, parameter))));
}
public static SyntaxNode CreateThrowArgumentNullExpression(this SyntaxGenerator factory, Compilation compilation, IParameterSymbol parameter)
=> factory.ThrowExpression(CreateNewArgumentNullException(factory, compilation, parameter));
private static SyntaxNode CreateNewArgumentNullException(SyntaxGenerator factory, Compilation compilation, IParameterSymbol parameter)
{
var type = compilation.GetTypeByMetadataName(typeof(ArgumentNullException).FullName!);
Contract.ThrowIfNull(type);
return factory.ObjectCreationExpression(type,
factory.NameOfExpression(
factory.IdentifierName(parameter.Name))).WithAdditionalAnnotations(Simplifier.AddImportsAnnotation);
}
public static SyntaxNode CreateNullCheckAndThrowStatement(
this SyntaxGenerator factory,
SyntaxGeneratorInternal generatorInternal,
SemanticModel semanticModel,
IParameterSymbol parameter)
{
var condition = factory.CreateNullCheckExpression(generatorInternal, semanticModel, parameter.Name);
var throwStatement = factory.CreateThrowArgumentNullExceptionStatement(semanticModel.Compilation, parameter);
// generates: if (s is null) { throw new ArgumentNullException(nameof(s)); }
return factory.IfStatement(condition, [throwStatement]);
}
public static SyntaxNode CreateNullCheckExpression(
this SyntaxGenerator factory, SyntaxGeneratorInternal generatorInternal, SemanticModel semanticModel, string identifierName)
{
var identifier = factory.IdentifierName(identifierName);
var nullExpr = factory.NullLiteralExpression();
var condition = generatorInternal.SupportsPatterns(semanticModel.SyntaxTree.Options)
? generatorInternal.IsPatternExpression(identifier, generatorInternal.ConstantPattern(nullExpr))
: factory.ReferenceEqualsExpression(identifier, nullExpr);
return condition;
}
public static SyntaxNode CreateThrowArgumentNullExceptionStatement(this SyntaxGenerator factory, Compilation compilation, IParameterSymbol parameter)
=> factory.ThrowStatement(CreateNewArgumentNullException(factory, compilation, parameter));
}
|