// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Simplification;

namespace Microsoft.CodeAnalysis.Shared.Extensions;

internal static partial class SyntaxGeneratorExtensions
{
    // NOTE(review): EqualsName / ObjName / OtherName are not referenced in this part of the partial
    // class; presumably they are used by another part — confirm before removing.
    private const string EqualsName = "Equals";
    private const string DefaultName = "Default";
    private const string ObjName = "obj";
    public const string OtherName = "other";

    /// <summary>
    /// Creates a statement that throws a newly constructed <see cref="NotImplementedException"/>.
    /// </summary>
    public static SyntaxNode CreateThrowNotImplementedStatement(
        this SyntaxGenerator codeDefinitionFactory, Compilation compilation)
    {
        return codeDefinitionFactory.ThrowStatement(
            CreateNewNotImplementedException(codeDefinitionFactory, compilation));
    }

    /// <summary>
    /// Creates a throw <em>expression</em> that throws a newly constructed <see cref="NotImplementedException"/>.
    /// </summary>
    public static SyntaxNode CreateThrowNotImplementedExpression(
        this SyntaxGenerator codeDefinitionFactory, Compilation compilation)
    {
        return codeDefinitionFactory.ThrowExpression(
            CreateNewNotImplementedException(codeDefinitionFactory, compilation));
    }

    /// <summary>
    /// Creates <c>new NotImplementedException()</c>. When the compilation can resolve the exception type,
    /// the type expression is built from that symbol (without adding an import); otherwise the
    /// syntactic name <c>System.NotImplementedException</c> is used.
    /// </summary>
    private static SyntaxNode CreateNewNotImplementedException(SyntaxGenerator codeDefinitionFactory, Compilation compilation)
    {
        var notImplementedExceptionTypeSyntax = compilation.NotImplementedExceptionType() is INamedTypeSymbol symbol
            ? codeDefinitionFactory.TypeExpression(symbol, addImport: false)
            : codeDefinitionFactory.QualifiedName(
                codeDefinitionFactory.IdentifierName(nameof(System)),
                codeDefinitionFactory.IdentifierName(nameof(NotImplementedException)));

        return codeDefinitionFactory.ObjectCreationExpression(
            notImplementedExceptionTypeSyntax,
            arguments: []);
    }

    /// <summary>
    /// Returns a single-statement body consisting of <c>throw new NotImplementedException();</c>.
    /// </summary>
    public static ImmutableArray<SyntaxNode> CreateThrowNotImplementedStatementBlock(
        this SyntaxGenerator codeDefinitionFactory, Compilation compilation)
        => [CreateThrowNotImplementedStatement(codeDefinitionFactory, compilation)];

    /// <summary>
    /// Creates one argument per parameter, passing each parameter by name with its original ref-kind.
    /// </summary>
    public static ImmutableArray<SyntaxNode> CreateArguments(
        this SyntaxGenerator factory,
        ImmutableArray<IParameterSymbol> parameters)
    {
        return parameters.SelectAsArray(p => CreateArgument(factory, p));
    }

    /// <summary>
    /// Creates an argument that forwards <paramref name="parameter"/> by name, preserving its ref-kind.
    /// </summary>
    private static SyntaxNode CreateArgument(
        this SyntaxGenerator factory,
        IParameterSymbol parameter)
    {
        return factory.Argument(parameter.RefKind, factory.IdentifierName(parameter.Name));
    }

    /// <summary>
    /// Creates an expression for <c>EqualityComparer&lt;<paramref name="type"/>&gt;.Default</c>.
    /// Uses the compilation's <c>EqualityComparer&lt;T&gt;</c> symbol when it can be resolved;
    /// otherwise falls back to an unqualified generic name.
    /// </summary>
    public static SyntaxNode GetDefaultEqualityComparer(
        this SyntaxGenerator factory,
        SyntaxGeneratorInternal generatorInternal,
        Compilation compilation,
        ITypeSymbol type)
    {
        var equalityComparerType = compilation.EqualityComparerOfTType();
        var typeExpression = equalityComparerType == null
            ? factory.GenericName(nameof(EqualityComparer<int>), type)
            : generatorInternal.Type(equalityComparerType.Construct(type), typeContext: false);

        return factory.MemberAccessExpression(typeExpression, factory.IdentifierName(DefaultName));
    }

    /// <summary>
    /// Returns the declared type of a field or property; for any other symbol kind returns
    /// <c>System.Object</c> from <paramref name="compilation"/>.
    /// </summary>
    private static ITypeSymbol GetType(Compilation compilation, ISymbol symbol)
        => symbol switch
        {
            IFieldSymbol field => field.Type,
            IPropertySymbol property => property.Type,
            _ => compilation.GetSpecialType(SpecialType.System_Object),
        };

    /// <summary>
    /// Convenience overload that builds an <c>is</c>-pattern expression, passing a default
    /// <c>isToken</c> to the underlying generator.
    /// </summary>
    public static SyntaxNode IsPatternExpression(this SyntaxGeneratorInternal generator, SyntaxNode expression, SyntaxNode pattern)
        => generator.IsPatternExpression(expression, isToken: default, pattern);

    /// <summary>
    /// Generates a call to a method *through* an existing field or property symbol.
    /// </summary>
    /// <returns>
    /// An expression statement wrapping the invocation when <paramref name="method"/> returns void;
    /// otherwise a return statement returning the invocation's result.
    /// </returns>
    public static SyntaxNode GenerateDelegateThroughMemberStatement(
        this SyntaxGenerator generator, IMethodSymbol method, ISymbol throughMember)
    {
        var through = generator.MemberAccessExpression(
            CreateDelegateThroughExpression(generator, method, throughMember),
            method.IsGenericMethod
                ? generator.GenericName(method.Name, method.TypeArguments)
                : generator.IdentifierName(method.Name));

        var invocationExpression = generator.InvocationExpression(through, generator.CreateArguments(method.Parameters));
        return method.ReturnsVoid
            ? generator.ExpressionStatement(invocationExpression)
            : generator.ReturnStatement(invocationExpression);
    }

    /// <summary>
    /// Creates the receiver expression used to reach <paramref name="throughMember"/> when delegating
    /// <paramref name="member"/> to it. The result is annotated for simplification.
    /// </summary>
    /// <remarks>
    /// Static members are qualified with their containing type's name. Primary-constructor parameters are left
    /// unqualified. All other members are qualified with <c>this</c>. When <paramref name="member"/> belongs to
    /// an interface, a cast to that interface (or an explicit-implementation access) may be inserted.
    /// </remarks>
    public static SyntaxNode CreateDelegateThroughExpression(
        this SyntaxGenerator generator, ISymbol member, ISymbol throughMember)
    {
        var name = generator.IdentifierName(throughMember.Name);
        var through = throughMember.IsStatic
            ? GenerateContainerName(generator, throughMember)
            // If we're delegating through a primary constructor parameter, we cannot qualify the name at all.
            : throughMember is IParameterSymbol
                ? null
                : generator.ThisExpression();

        through = through is null
            ? name
            : generator.MemberAccessExpression(through, name);

        var throughMemberType = throughMember.GetMemberType();
        if (throughMemberType != null &&
            member.ContainingType is { TypeKind: TypeKind.Interface } interfaceBeingImplemented)
        {
            // In the case of 'implement interface through field / property', we need to know what
            // interface we are implementing so that we can insert casts to this interface on every
            // usage of the field in the generated code. Without these casts we would end up generating
            // code that fails compilation in certain situations.
            //
            // For example consider the following code.
            //      class C : IReadOnlyList<int> { int[] field; }
            // When applying the 'implement interface through field' code fix in the above example,
            // we need to generate the following code to implement the Count property on IReadOnlyList<int>
            //      class C : IReadOnlyList<int> { int[] field; int Count { get { ((IReadOnlyList<int>)field).Count; } ...}
            // as opposed to the following code which will fail to compile (because the array field
            // doesn't have a property named .Count) -
            //      class C : IReadOnlyList<int> { int[] field; int Count { get { field.Count; } ...}
            //
            // The 'InterfaceTypes' property on the state object always contains only one item
            // in the case of C# i.e. it will contain exactly the interface we are trying to implement.
            // This is also the case most of the time in the case of VB, except in certain error conditions
            // (recursive / circular cases) where the span of the squiggle for the corresponding
            // diagnostic (BC30149) changes and 'InterfaceTypes' ends up including all interfaces
            // in the Implements clause. For the purposes of inserting the above cast, we ignore the
            // uncommon case and optimize for the common one - in other words, we only apply the cast
            // in cases where we can unambiguously figure out which interface we are trying to implement.
            if (!throughMemberType.Equals(interfaceBeingImplemented))
            {
                through = generator.CastExpression(interfaceBeingImplemented,
                    through.WithAdditionalAnnotations(Simplifier.Annotation));
            }
            else if (throughMember is IPropertySymbol { IsStatic: false, ExplicitInterfaceImplementations: [var explicitlyImplementedProperty, ..] })
            {
                // If we are implementing through an explicitly implemented property, we need to cast 'this' to
                // the explicitly implemented interface type before calling the member, as in:
                //       ((IA)this).Prop.Member();
                //
                var explicitImplementationCast = generator.CastExpression(
                    explicitlyImplementedProperty.ContainingType,
                    generator.ThisExpression());

                through = generator.MemberAccessExpression(explicitImplementationCast,
                    generator.IdentifierName(explicitlyImplementedProperty.Name));

                through = through.WithAdditionalAnnotations(Simplifier.Annotation);
            }
        }

        return through.WithAdditionalAnnotations(Simplifier.Annotation);

        // local functions

        // Names the type that contains a static through-member, including its type arguments when generic.
        static SyntaxNode GenerateContainerName(SyntaxGenerator factory, ISymbol throughMember)
        {
            var classOrStructType = throughMember.ContainingType;
            return classOrStructType.IsGenericType
                ? factory.GenericName(classOrStructType.Name, classOrStructType.TypeArguments)
                : factory.IdentifierName(classOrStructType.Name);
        }
    }

    /// <summary>
    /// Builds the expression that reads or writes <paramref name="property"/> through <paramref name="throughMember"/>.
    /// For an indexer the through-expression itself is used as the target; otherwise the property's name is
    /// accessed on it. If the property has parameters, they are forwarded as an element access.
    /// </summary>
    /// <remarks>Shared by <see cref="GetGetAccessorStatements"/> and <see cref="GetSetAccessorStatements"/>.</remarks>
    private static SyntaxNode CreateThroughMemberPropertyAccess(
        SyntaxGenerator generator, IPropertySymbol property, ISymbol throughMember)
    {
        var throughExpression = CreateDelegateThroughExpression(generator, property, throughMember);
        var expression = property.IsIndexer
            ? throughExpression
            : generator.MemberAccessExpression(
                throughExpression, generator.IdentifierName(property.Name));

        if (property.Parameters.Length > 0)
        {
            var arguments = generator.CreateArguments(property.Parameters);
            expression = generator.ElementAccessExpression(expression, arguments);
        }

        return expression;
    }

    /// <summary>
    /// Produces the statements for a generated property's <c>get</c> accessor.
    /// </summary>
    /// <returns>
    /// <list type="bullet">
    /// <item>When <paramref name="throughMember"/> is non-null: a return of the property read through that member.</item>
    /// <item>Otherwise, when <paramref name="preferAutoProperties"/> is true: a <c>default</c> (uninitialized) array
    /// (presumably signalling the caller to emit an auto-accessor — confirm at call sites).</item>
    /// <item>Otherwise, when a parameterless <paramref name="conflictingProperty"/> with a getter exists and its type
    /// implicitly converts to the property type: <c>return this.Name;</c>.</item>
    /// <item>Otherwise: <c>throw new NotImplementedException();</c>.</item>
    /// </list>
    /// </returns>
    public static ImmutableArray<SyntaxNode> GetGetAccessorStatements(
        this SyntaxGenerator generator,
        Compilation compilation,
        IPropertySymbol property,
        IPropertySymbol? conflictingProperty,
        ISymbol? throughMember,
        bool preferAutoProperties)
    {
        if (throughMember != null)
        {
            var expression = CreateThroughMemberPropertyAccess(generator, property, throughMember);
            return [generator.ReturnStatement(expression)];
        }

        if (preferAutoProperties)
            return default;

        // Forward from the explicit property we're creating to the existing property it conflicts with if possible.
        if (conflictingProperty is { GetMethod: not null, Parameters.Length: 0 } &&
            property is { GetMethod: not null, Parameters.Length: 0 })
        {
            // The getter returns the conflicting property's value, so it must convert to this property's type.
            if (compilation.ClassifyCommonConversion(conflictingProperty.Type, property.Type) is { Exists: true, IsImplicit: true })
                return [generator.ReturnStatement(generator.MemberAccessExpression(generator.ThisExpression(), property.Name))];
        }

        return generator.CreateThrowNotImplementedStatementBlock(compilation);
    }

    /// <summary>
    /// Produces the statements for a generated property's <c>set</c> accessor.
    /// </summary>
    /// <returns>
    /// <list type="bullet">
    /// <item>When <paramref name="throughMember"/> is non-null: an assignment of <c>value</c> to the property
    /// accessed through that member.</item>
    /// <item>Otherwise, when <paramref name="preferAutoProperties"/> is true: a <c>default</c> (uninitialized) array.</item>
    /// <item>Otherwise, when both properties have single-parameter setters and the property type implicitly converts
    /// to the conflicting property's type: an assignment of the setter parameter to <c>this.Name</c>.</item>
    /// <item>Otherwise: <c>throw new NotImplementedException();</c>.</item>
    /// </list>
    /// </returns>
    public static ImmutableArray<SyntaxNode> GetSetAccessorStatements(
        this SyntaxGenerator generator,
        Compilation compilation,
        IPropertySymbol property,
        IPropertySymbol? conflictingProperty,
        ISymbol? throughMember,
        bool preferAutoProperties)
    {
        if (throughMember != null)
        {
            var target = CreateThroughMemberPropertyAccess(generator, property, throughMember);
            var assignment = generator.AssignmentStatement(target, generator.IdentifierName("value"));
            return [generator.ExpressionStatement(assignment)];
        }

        if (preferAutoProperties)
            return default;

        // Forward from the explicit property we're creating to the existing property it conflicts with if possible.
        if (conflictingProperty is { SetMethod.Parameters.Length: 1 } &&
            property is { SetMethod.Parameters: [var parameter] })
        {
            // The setter stores this property's value into the conflicting one, so the conversion runs the other way.
            if (compilation.ClassifyCommonConversion(property.Type, conflictingProperty.Type) is { Exists: true, IsImplicit: true })
                return [generator.ExpressionStatement(generator.AssignmentStatement(generator.MemberAccessExpression(generator.ThisExpression(), property.Name), generator.IdentifierName(parameter.Name)))];
        }

        return generator.CreateThrowNotImplementedStatementBlock(compilation);
    }

    /// <summary>
    /// Null-tolerant lookup: returns false (with <paramref name="value"/> null) when
    /// <paramref name="dictionary"/> is null or lacks <paramref name="key"/>.
    /// </summary>
    private static bool TryGetValue(IDictionary<string, string>? dictionary, string key, [NotNullWhen(true)] out string? value)
    {
        value = null;
        return
            dictionary != null &&
            dictionary.TryGetValue(key, out value);
    }

    /// <summary>
    /// Null-tolerant lookup that yields the <em>name</em> of the symbol mapped to <paramref name="key"/>.
    /// </summary>
    private static bool TryGetValue(IDictionary<string, ISymbol>? dictionary, string key, [NotNullWhen(true)] out string? value)
    {
        value = null;
        if (dictionary != null && dictionary.TryGetValue(key, out var symbol))
        {
            value = symbol.Name;
            return true;
        }

        return false;
    }

    /// <summary>
    /// Creates a private field symbol for each non-<c>out</c> parameter that has an entry in
    /// <paramref name="parameterToNewFieldMap"/>, using the mapped value as the field name and the parameter's type.
    /// The field is marked <c>unsafe</c> when its type requires it and the containing type is not already unsafe.
    /// </summary>
    public static ImmutableArray<ISymbol> CreateFieldsForParameters(
        ImmutableArray<IParameterSymbol> parameters, ImmutableDictionary<string, string>? parameterToNewFieldMap, bool isContainedInUnsafeType)
    {
        using var _ = ArrayBuilder<ISymbol>.GetInstance(out var result);
        foreach (var parameter in parameters)
        {
            // For non-out parameters, create a field and assign the parameter to it.
            if (parameter.RefKind != RefKind.Out &&
                TryGetValue(parameterToNewFieldMap, parameter.Name, out var fieldName))
            {
                result.Add(CodeGenerationSymbolFactory.CreateFieldSymbol(
                    attributes: default,
                    accessibility: Accessibility.Private,
                    modifiers: DeclarationModifiers.None.WithIsUnsafe(!isContainedInUnsafeType && parameter.RequiresUnsafeModifier()),
                    type: parameter.Type,
                    name: fieldName));
            }
        }

        return result.ToImmutableAndClear();
    }

    /// <summary>
    /// Creates a public, get-only property symbol (no setter) for each non-<c>out</c> parameter that has an entry in
    /// <paramref name="parameterToNewPropertyMap"/>, using the mapped value as the property name and the parameter's type.
    /// The property is marked <c>unsafe</c> when its type requires it and the containing type is not already unsafe.
    /// </summary>
    public static ImmutableArray<ISymbol> CreatePropertiesForParameters(
        ImmutableArray<IParameterSymbol> parameters, ImmutableDictionary<string, string>? parameterToNewPropertyMap, bool isContainedInUnsafeType)
    {
        using var _ = ArrayBuilder<ISymbol>.GetInstance(out var result);
        foreach (var parameter in parameters)
        {
            // For non-out parameters, create a property and assign the parameter to it.
            if (parameter.RefKind != RefKind.Out &&
                TryGetValue(parameterToNewPropertyMap, parameter.Name, out var propertyName))
            {
                result.Add(CodeGenerationSymbolFactory.CreatePropertySymbol(
                    attributes: default,
                    accessibility: Accessibility.Public,
                    modifiers: DeclarationModifiers.None.WithIsUnsafe(!isContainedInUnsafeType && parameter.RequiresUnsafeModifier()),
                    type: parameter.Type,
                    refKind: RefKind.None,
                    explicitInterfaceImplementations: [],
                    name: propertyName,
                    parameters: [],
                    getMethod: CodeGenerationSymbolFactory.CreateAccessorSymbol(
                        attributes: default,
                        accessibility: default,
                        statements: default),
                    setMethod: null));
            }
        }

        return result.ToImmutableAndClear();
    }

    /// <summary>
    /// Creates constructor-body statements that store each parameter into its backing field.
    /// </summary>
    /// <remarks>
    /// <c>out</c> parameters are assigned <c>default(T)</c> instead of being stored. For other parameters the field
    /// name is looked up first in <paramref name="parameterToExistingFieldMap"/>, then in
    /// <paramref name="parameterToNewFieldMap"/>; parameters found in neither are skipped.
    /// All null-check statements are placed before all assignment statements in the result.
    /// </remarks>
    public static ImmutableArray<SyntaxNode> CreateAssignmentStatements(
        this SyntaxGenerator factory,
        SyntaxGeneratorInternal generatorInternal,
        SemanticModel semanticModel,
        ImmutableArray<IParameterSymbol> parameters,
        IDictionary<string, ISymbol>? parameterToExistingFieldMap,
        IDictionary<string, string>? parameterToNewFieldMap,
        bool addNullChecks,
        bool preferThrowExpression)
    {
        using var _1 = ArrayBuilder<SyntaxNode>.GetInstance(out var nullCheckStatements);
        using var _2 = ArrayBuilder<SyntaxNode>.GetInstance(out var assignStatements);

        foreach (var parameter in parameters)
        {
            var refKind = parameter.RefKind;
            var parameterType = parameter.Type;
            var parameterName = parameter.Name;

            if (refKind == RefKind.Out)
            {
                // If it's an out param, then don't create a field for it.  Instead, assign
                // the default value for that type (i.e. "default(...)") to it.
                var assignExpression = factory.AssignmentStatement(
                    factory.IdentifierName(parameterName),
                    factory.DefaultExpression(parameterType));
                var statement = factory.ExpressionStatement(assignExpression);
                assignStatements.Add(statement);
            }
            else
            {
                // For non-out parameters, create a field and assign the parameter to it.
                // TODO: I'm not sure that's what we really want for ref parameters.
                if (TryGetValue(parameterToExistingFieldMap, parameterName, out var fieldName) ||
                    TryGetValue(parameterToNewFieldMap, parameterName, out fieldName))
                {
                    var fieldAccess = factory.MemberAccessExpression(factory.ThisExpression(), factory.IdentifierName(fieldName))
                                             .WithAdditionalAnnotations(Simplifier.Annotation);

                    factory.AddAssignmentStatements(
                        generatorInternal,
                        semanticModel, parameter, fieldAccess,
                        addNullChecks, preferThrowExpression,
                        nullCheckStatements, assignStatements);
                }
            }
        }

        return [.. nullCheckStatements, .. assignStatements];
    }

    /// <summary>
    /// Appends the statements that assign <paramref name="parameter"/> to <paramref name="fieldAccess"/>, optionally
    /// guarding against null.
    /// </summary>
    /// <remarks>
    /// When a null check applies and <paramref name="preferThrowExpression"/> is set (and the language supports throw
    /// expressions), a single <c>field = param ?? throw ...</c> statement is appended to
    /// <paramref name="assignStatements"/>. Otherwise any null check goes to <paramref name="nullCheckStatements"/> and
    /// a plain <c>field = param</c> goes to <paramref name="assignStatements"/>.
    /// </remarks>
    public static void AddAssignmentStatements(
         this SyntaxGenerator factory,
         SyntaxGeneratorInternal generatorInternal,
         SemanticModel semanticModel,
         IParameterSymbol parameter,
         SyntaxNode fieldAccess,
         bool addNullChecks,
         bool preferThrowExpression,
         ArrayBuilder<SyntaxNode> nullCheckStatements,
         ArrayBuilder<SyntaxNode> assignStatements)
    {
        // Don't want to add a null check for something of the form `int?`.  The type was
        // already declared as nullable to indicate that null is ok.  Adding a null check
        // just disallows something that should be allowed.
        var shouldAddNullCheck = addNullChecks && parameter.Type.CanAddNullCheck() && !parameter.Type.IsNullable();

        if (shouldAddNullCheck && preferThrowExpression && generatorInternal.SupportsThrowExpression())
        {
            // Generate: this.x = x ?? throw ...
            assignStatements.Add(CreateAssignWithNullCheckStatement(
                factory, semanticModel.Compilation, parameter, fieldAccess));
        }
        else
        {
            if (shouldAddNullCheck)
            {
                // generate: if (x == null) throw ...
                nullCheckStatements.Add(
                    factory.CreateNullCheckAndThrowStatement(generatorInternal, semanticModel, parameter));
            }

            // generate: this.x = x;
            assignStatements.Add(
                factory.ExpressionStatement(
                    factory.AssignmentStatement(
                        fieldAccess,
                        factory.IdentifierName(parameter.Name))));
        }
    }

    /// <summary>
    /// Creates <c>field = parameter ?? throw new ArgumentNullException(nameof(parameter));</c>.
    /// </summary>
    public static SyntaxNode CreateAssignWithNullCheckStatement(
        this SyntaxGenerator factory, Compilation compilation, IParameterSymbol parameter, SyntaxNode fieldAccess)
    {
        return factory.ExpressionStatement(factory.AssignmentStatement(
            fieldAccess,
            factory.CoalesceExpression(
                factory.IdentifierName(parameter.Name),
                factory.CreateThrowArgumentNullExpression(compilation, parameter))));
    }

    /// <summary>
    /// Creates <c>throw new ArgumentNullException(nameof(parameter))</c> as a throw expression.
    /// </summary>
    public static SyntaxNode CreateThrowArgumentNullExpression(this SyntaxGenerator factory, Compilation compilation, IParameterSymbol parameter)
        => factory.ThrowExpression(CreateNewArgumentNullException(factory, compilation, parameter));

    /// <summary>
    /// Creates <c>new ArgumentNullException(nameof(parameter))</c>, annotated so that a missing import can be added.
    /// Throws (via <c>Contract.ThrowIfNull</c>) if the compilation cannot resolve <see cref="ArgumentNullException"/>.
    /// </summary>
    private static SyntaxNode CreateNewArgumentNullException(SyntaxGenerator factory, Compilation compilation, IParameterSymbol parameter)
    {
        var type = compilation.GetTypeByMetadataName(typeof(ArgumentNullException).FullName!);
        Contract.ThrowIfNull(type);
        return factory.ObjectCreationExpression(type,
            factory.NameOfExpression(
                factory.IdentifierName(parameter.Name))).WithAdditionalAnnotations(Simplifier.AddImportsAnnotation);
    }

    /// <summary>
    /// Creates an <c>if</c> statement that throws <see cref="ArgumentNullException"/> when
    /// <paramref name="parameter"/> is null.
    /// </summary>
    public static SyntaxNode CreateNullCheckAndThrowStatement(
        this SyntaxGenerator factory,
        SyntaxGeneratorInternal generatorInternal,
        SemanticModel semanticModel,
        IParameterSymbol parameter)
    {
        var condition = factory.CreateNullCheckExpression(generatorInternal, semanticModel, parameter.Name);
        var throwStatement = factory.CreateThrowArgumentNullExceptionStatement(semanticModel.Compilation, parameter);

        // generates: if (s is null) { throw new ArgumentNullException(nameof(s)); }
        // When the language version does not support patterns, the condition is a reference-equality
        // comparison against null instead of an 'is null' pattern (see CreateNullCheckExpression).
        return factory.IfStatement(condition, [throwStatement]);
    }

    /// <summary>
    /// Creates a condition that is true when <paramref name="identifierName"/> is null: an <c>is null</c>
    /// constant pattern when the syntax tree's options support patterns, otherwise a reference-equality
    /// comparison with the null literal.
    /// </summary>
    public static SyntaxNode CreateNullCheckExpression(
        this SyntaxGenerator factory,
        SyntaxGeneratorInternal generatorInternal,
        SemanticModel semanticModel,
        string identifierName)
    {
        var identifier = factory.IdentifierName(identifierName);
        var nullExpr = factory.NullLiteralExpression();
        var condition = generatorInternal.SupportsPatterns(semanticModel.SyntaxTree.Options)
            ? generatorInternal.IsPatternExpression(identifier, generatorInternal.ConstantPattern(nullExpr))
            : factory.ReferenceEqualsExpression(identifier, nullExpr);
        return condition;
    }

    /// <summary>
    /// Creates the statement <c>throw new ArgumentNullException(nameof(parameter));</c>.
    /// </summary>
    public static SyntaxNode CreateThrowArgumentNullExceptionStatement(this SyntaxGenerator factory, Compilation compilation, IParameterSymbol parameter)
        => factory.ThrowStatement(CreateNewArgumentNullException(factory, compilation, parameter));
}