|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Immutable;
using System.Diagnostics;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.PooledObjects;
namespace Microsoft.CodeAnalysis.CSharp
{
/// <summary>
/// Rewrite special Union type matching with appropriate BoundRecursivePatterns
/// against an IUnion.Value property.
///
/// The rewrite happens bottom-up. Nodes that require special Union treatment are represented with
/// a <see cref="BoundPatternWithUnionMatching"/> node created during the rewrite. The rewriter keeps a
/// <see cref="BoundPatternWithUnionMatching"/> node at the top of the result until we reach a point
/// where we are ready to perform its transformation; at that point we call the
/// <see cref="RewritePatternWithUnionMatchingToPropertyPattern(BoundPattern)"/> helper.
/// Generally, the transformation must be performed for a pattern when, and only when, we know that no
/// more conjunctions are coming where the pattern could be the left-hand side.
///
/// Assuming that '^' marks a Union matching pattern:
///
/// A pattern 'unionTypeInstance is int^' is transformed to 'unionTypeInstance is { Value: int }'.
///
/// A pattern 'unionTypeInstance is not^(int or string)' is transformed to 'unionTypeInstance is { Value: not (int or string) }'.
///
/// A pattern 'unionTypeInstance is int^ or string^' is transformed to 'unionTypeInstance is { Value: int } or { Value: string }'.
///
/// A pattern 'unionTypeInstance is int^ and 15 or string^' is transformed to 'unionTypeInstance is { Value: int and 15 } or { Value: string }'.
/// </summary>
sealed class UnionMatchingRewriter : BoundTreeRewriter
{
/// <summary>
/// Compilation used by this rewriter to look up special types (see <see cref="ObjectType"/>).
/// </summary>
private readonly CSharpCompilation _compilation;

/// <summary>
/// Creates a rewriter bound to <paramref name="compilation"/>.
/// The constructor is private; <see cref="Rewrite"/> is the entry point that instantiates it.
/// </summary>
private UnionMatchingRewriter(CSharpCompilation compilation) => _compilation = compilation;
/// <summary>
/// Entry point: visits <paramref name="pattern"/> with a fresh rewriter and then finalizes any Union
/// matching that is still pending at the top of the visited result.
/// </summary>
/// <param name="compilation">Compilation the rewriter is bound to.</param>
/// <param name="pattern">The pattern to rewrite.</param>
/// <returns>The rewritten pattern.</returns>
public static BoundPattern Rewrite(CSharpCompilation compilation, BoundPattern pattern)
{
    var rewriter = new UnionMatchingRewriter(compilation);
    var visited = rewriter.Visit(pattern);

    // The visit is expected to produce a node different from the input.
    Debug.Assert(visited != pattern);

    return RewritePatternWithUnionMatchingToPropertyPattern((BoundPattern)visited);
}
/// <summary>
/// Forwards directly to <c>Visit</c> for the given <paramref name="node"/>.
/// </summary>
protected override BoundNode? VisitExpressionOrPatternWithoutStackGuard(BoundNode node) => Visit(node);
/// <summary>
/// The <c>System.Object</c> special type obtained from <see cref="_compilation"/>.
/// It is used as the input type of patterns that are retargeted to match against a Union's Value.
/// </summary>
private NamedTypeSymbol ObjectType
{
    get { return _compilation.GetSpecialType(SpecialType.System_Object); }
}
/// <summary>
/// Wraps <paramref name="innerPattern"/> (which must already take <c>object</c> as its input) into a
/// compiler-generated <see cref="BoundPatternWithUnionMatching"/> whose input type is
/// <paramref name="unionMatchingInputType"/>.
/// </summary>
/// <param name="unionMatchingInputType">The type that is subject to Union matching.</param>
/// <param name="innerPattern">The pattern to apply to the Union's Value.</param>
private static BoundPatternWithUnionMatching CreatePatternWithUnionMatching(NamedTypeSymbol unionMatchingInputType, BoundPattern innerPattern)
{
    Debug.Assert(unionMatchingInputType.IsSubjectForUnionMatching);
    Debug.Assert(innerPattern.InputType.IsObjectType());

    // The Value property is looked up on the stripped type; a failed lookup is recorded as an error on the member.
    var strippedUnionType = (NamedTypeSymbol)unionMatchingInputType.StrippedType();
    PropertySymbol? valuePropertySymbol = Binder.GetUnionTypeValuePropertyNoUseSiteDiagnostics(strippedUnionType);

    BoundPropertySubpatternMember valueMember = new BoundPropertySubpatternMember(
        innerPattern.Syntax,
        receiver: null,
        valuePropertySymbol,
        type: innerPattern.InputType,
        hasErrors: valuePropertySymbol is null).MakeCompilerGenerated();

    var wrapper = new BoundPatternWithUnionMatching(
        syntax: innerPattern.Syntax,
        unionMatchingInputType,
        valueMember,
        innerPattern,
        inputType: unionMatchingInputType);

    return wrapper.MakeCompilerGenerated();
}
/// <summary>
/// Rewrites a constant pattern. A pattern that is not marked for Union matching is returned as visited.
/// Otherwise the pattern is retargeted to an <c>object</c> input and wrapped with
/// <see cref="CreatePatternWithUnionMatching"/>, except for the special null test described below,
/// which is expanded into a disjunction right away.
/// </summary>
public override BoundNode? VisitConstantPattern(BoundConstantPattern node)
{
    var visited = (BoundConstantPattern)base.VisitConstantPattern(node)!;
    if (!visited.IsUnionMatching)
    {
        return visited;
    }

    Debug.Assert(visited.InputType.IsSubjectForUnionMatching);

    bool isNullTestOnUnionItself =
        Binder.IsClassOrNullableValueTypeUnionNullPatternMatching((NamedTypeSymbol)visited.InputType, visited.ConstantValue)
        && visited.NarrowedType.Equals(visited.InputType, TypeCompareKind.AllIgnoreOptions);

    if (!isNullTestOnUnionItself)
    {
        // Regular case: match the constant against the Union's Value.
        return CreatePatternWithUnionMatching(
            (NamedTypeSymbol)visited.InputType,
            visited.Update(visited.Value, visited.ConstantValue, isUnionMatching: false, inputType: ObjectType, narrowedType: visited.NarrowedType));
    }

    // Special case of a null test for a class Union. Its meaning is equivalent to:
    //     (<union instance> is null or <union instance>.Value is null)
    // Or a special case of a null test for a Nullable<Union>. Its meaning is equivalent to:
    //     (<input value> is null or <input value>.GetValueOrDefault().Value is null)
    BoundPatternWithUnionMatching valueIsNull = CreatePatternWithUnionMatching(
        (NamedTypeSymbol)visited.InputType,
        visited.Update(visited.Value, visited.ConstantValue, isUnionMatching: false, inputType: ObjectType, narrowedType: ObjectType));

    BoundPattern instanceIsNull = visited.Update(visited.Value, visited.ConstantValue, isUnionMatching: false, visited.InputType, visited.InputType);
    BoundPattern valueIsNullAsPropertyPattern = RewritePatternWithUnionMatchingToPropertyPattern(valueIsNull);

    return new BoundBinaryPattern(
        visited.Syntax,
        disjunction: true,
        left: instanceIsNull,
        right: valueIsNullAsPropertyPattern,
        inputType: visited.InputType,
        narrowedType: visited.InputType)
    {
        WasCompilerGenerated = true
    };
}
/// <summary>
/// Rewrites a recursive pattern. When it is marked for Union matching, the visited pattern is
/// retargeted to an <c>object</c> input and wrapped with <see cref="CreatePatternWithUnionMatching"/>;
/// otherwise the visited node is returned as is.
/// </summary>
public override BoundNode? VisitRecursivePattern(BoundRecursivePattern node)
{
    var visited = (BoundRecursivePattern)base.VisitRecursivePattern(node)!;
    if (!visited.IsUnionMatching)
    {
        return visited;
    }

    var unionInputType = (NamedTypeSymbol)visited.InputType;
    BoundPattern valuePattern = visited.Update(
        visited.DeclaredType,
        visited.DeconstructMethod,
        visited.Deconstruction,
        visited.Properties,
        visited.IsExplicitNotNullTest,
        visited.Variable,
        visited.VariableAccess,
        isUnionMatching: false,
        inputType: ObjectType,
        narrowedType: visited.NarrowedType);

    return CreatePatternWithUnionMatching(unionInputType, valuePattern);
}
/// <summary>
/// Rewrites a list pattern. Each subpattern is visited and any Union matching pending in it is
/// finalized. The length/indexer accesses, placeholders and variable are carried over unchanged.
/// When the list pattern itself is marked for Union matching, it is retargeted to an <c>object</c>
/// input and wrapped with <see cref="CreatePatternWithUnionMatching"/>.
/// </summary>
public override BoundNode? VisitListPattern(BoundListPattern node)
{
    ImmutableArray<BoundPattern> rewrittenSubpatterns =
        this.VisitList(node.Subpatterns).SelectAsArray(RewritePatternWithUnionMatchingToPropertyPattern);

    if (!node.IsUnionMatching)
    {
        return node.Update(
            rewrittenSubpatterns,
            node.HasSlice,
            node.LengthAccess,
            node.IndexerAccess,
            node.ReceiverPlaceholder,
            node.ArgumentPlaceholder,
            node.Variable,
            node.VariableAccess,
            isUnionMatching: false,
            node.InputType,
            node.NarrowedType);
    }

    return CreatePatternWithUnionMatching(
        (NamedTypeSymbol)node.InputType,
        node.Update(
            rewrittenSubpatterns,
            node.HasSlice,
            node.LengthAccess,
            node.IndexerAccess,
            node.ReceiverPlaceholder,
            node.ArgumentPlaceholder,
            node.Variable,
            node.VariableAccess,
            isUnionMatching: false,
            inputType: ObjectType,
            node.NarrowedType));
}
/// <summary>
/// Rewrites an ITuple pattern. When it is marked for Union matching, the visited pattern is
/// retargeted to an <c>object</c> input and wrapped with <see cref="CreatePatternWithUnionMatching"/>;
/// otherwise the visited node is returned as is.
/// </summary>
public override BoundNode? VisitITuplePattern(BoundITuplePattern node)
{
    var visited = (BoundITuplePattern)base.VisitITuplePattern(node)!;
    if (!visited.IsUnionMatching)
    {
        return visited;
    }

    var unionInputType = (NamedTypeSymbol)visited.InputType;
    BoundPattern valuePattern = visited.Update(
        visited.GetLengthMethod,
        visited.GetItemMethod,
        visited.Subpatterns,
        isUnionMatching: false,
        inputType: ObjectType,
        narrowedType: visited.NarrowedType);

    return CreatePatternWithUnionMatching(unionInputType, valuePattern);
}
/// <summary>
/// Rewrites a declaration pattern. When it is marked for Union matching, the visited pattern is
/// retargeted to an <c>object</c> input and wrapped with <see cref="CreatePatternWithUnionMatching"/>;
/// otherwise the visited node is returned as is.
/// </summary>
public override BoundNode? VisitDeclarationPattern(BoundDeclarationPattern node)
{
    var visited = (BoundDeclarationPattern)base.VisitDeclarationPattern(node)!;
    if (!visited.IsUnionMatching)
    {
        return visited;
    }

    var unionInputType = (NamedTypeSymbol)visited.InputType;
    BoundPattern valuePattern = visited.Update(
        visited.DeclaredType,
        visited.IsVar,
        visited.Variable,
        visited.VariableAccess,
        isUnionMatching: false,
        inputType: ObjectType,
        narrowedType: visited.NarrowedType);

    return CreatePatternWithUnionMatching(unionInputType, valuePattern);
}
/// <summary>
/// Rewrites a type pattern. When it is marked for Union matching, the visited pattern is
/// retargeted to an <c>object</c> input and wrapped with <see cref="CreatePatternWithUnionMatching"/>;
/// otherwise the visited node is returned as is.
/// </summary>
public override BoundNode? VisitTypePattern(BoundTypePattern node)
{
    var visited = (BoundTypePattern)base.VisitTypePattern(node)!;
    if (!visited.IsUnionMatching)
    {
        return visited;
    }

    var unionInputType = (NamedTypeSymbol)visited.InputType;
    BoundPattern valuePattern = visited.Update(
        visited.DeclaredType,
        visited.IsExplicitNotNullTest,
        isUnionMatching: false,
        inputType: ObjectType,
        narrowedType: visited.NarrowedType);

    return CreatePatternWithUnionMatching(unionInputType, valuePattern);
}
/// <summary>
/// Rewrites a relational pattern. When it is marked for Union matching, the visited pattern is
/// retargeted to an <c>object</c> input and wrapped with <see cref="CreatePatternWithUnionMatching"/>;
/// otherwise the visited node is returned as is.
/// </summary>
public override BoundNode? VisitRelationalPattern(BoundRelationalPattern node)
{
    var visited = (BoundRelationalPattern)base.VisitRelationalPattern(node)!;
    if (!visited.IsUnionMatching)
    {
        return visited;
    }

    var unionInputType = (NamedTypeSymbol)visited.InputType;
    BoundPattern valuePattern = visited.Update(
        visited.Relation,
        visited.Value,
        visited.ConstantValue,
        isUnionMatching: false,
        inputType: ObjectType,
        narrowedType: visited.NarrowedType);

    return CreatePatternWithUnionMatching(unionInputType, valuePattern);
}
/// <summary>
/// Rewrites a negated pattern. The operand is visited and any Union matching pending in it is
/// finalized first. When the negation itself is marked for Union matching, it is retargeted to an
/// <c>object</c> input and wrapped with <see cref="CreatePatternWithUnionMatching"/>.
/// </summary>
public override BoundNode? VisitNegatedPattern(BoundNegatedPattern node)
{
    var visitedOperand = (BoundPattern)this.Visit(node.Negated);
    BoundPattern operand = RewritePatternWithUnionMatchingToPropertyPattern(visitedOperand);

    if (!node.IsUnionMatching)
    {
        return node.Update(operand, isUnionMatching: false, node.InputType, node.NarrowedType);
    }

    return CreatePatternWithUnionMatching(
        (NamedTypeSymbol)node.InputType,
        node.Update(operand, isUnionMatching: false, inputType: ObjectType, node.NarrowedType));
}
/// <summary>
/// Rewrites a slice pattern: the nested pattern is visited and any Union matching pending in it is
/// finalized. The indexer access, placeholders and types are carried over unchanged.
/// </summary>
public override BoundNode? VisitSlicePattern(BoundSlicePattern node)
{
    BoundPattern? rewrittenInner = RewritePatternWithUnionMatchingToPropertyPattern((BoundPattern)this.Visit(node.Pattern));

    return node.Update(
        rewrittenInner,
        node.IndexerAccess,
        node.ReceiverPlaceholder,
        node.ArgumentPlaceholder,
        node.InputType,
        node.NarrowedType);
}
/// <summary>
/// Rewrites a positional subpattern: its pattern is visited and any Union matching pending in it is
/// finalized. The symbol is carried over unchanged.
/// </summary>
public override BoundNode? VisitPositionalSubpattern(BoundPositionalSubpattern node)
{
    var visited = (BoundPattern)this.Visit(node.Pattern);
    BoundPattern finalized = RewritePatternWithUnionMatchingToPropertyPattern(visited);
    return node.Update(node.Symbol, finalized);
}
/// <summary>
/// Rewrites a property subpattern: its pattern is visited and any Union matching pending in it is
/// finalized. The member and the length-or-count flag are carried over unchanged.
/// </summary>
public override BoundNode? VisitPropertySubpattern(BoundPropertySubpattern node)
{
    var visited = (BoundPattern)this.Visit(node.Pattern);
    BoundPattern finalized = RewritePatternWithUnionMatchingToPropertyPattern(visited);
    return node.Update(node.Member, node.IsLengthOrCount, finalized);
}
/// <summary>
/// Rewrites a (possibly left-nested) chain of binary patterns without recursing into the left operand:
/// the chain is collected on an explicit stack, the innermost left operand is visited first, and the
/// binary nodes are then re-applied from the innermost one outwards.
///
/// For a disjunction, both operands are finalized with
/// <see cref="RewritePatternWithUnionMatchingToPropertyPattern(BoundPattern)"/> before the node is updated.
/// For a conjunction, the operands are combined by the local <c>makeConjunction</c> function, which keeps
/// any pending <see cref="BoundPatternWithUnionMatching"/> at the top of the result.
///
/// In DEBUG builds the narrowed type of every disjunction is recomputed and compared against the
/// narrowed type stored on the node.
/// </summary>
public override BoundNode? VisitBinaryPattern(BoundBinaryPattern node)
{
// Collect the chain of binary patterns reachable through Left; the innermost one ends up on top of the stack.
var binaryPatternStack = ArrayBuilder<BoundBinaryPattern>.GetInstance();
BoundBinaryPattern? currentNode = node;
do
{
binaryPatternStack.Push(currentNode);
currentNode = currentNode.Left as BoundBinaryPattern;
} while (currentNode != null);
Debug.Assert(binaryPatternStack.Count > 0);
// The innermost binary pattern has a left operand that is not a binary pattern; visit it the regular way.
var binaryPattern = binaryPatternStack.Pop();
BoundPattern result = (BoundPattern)Visit(binaryPattern.Left);
#if DEBUG
// Seed the debug-only list of narrowed type candidates from the left-most operand.
var narrowedTypeCandidates = ArrayBuilder<TypeSymbol>.GetInstance(2);
if (result is BoundPatternWithUnionMatching unionPattern)
{
narrowedTypeCandidates.Add(getDisjunctionType(unionPattern));
}
else
{
Binder.CollectDisjunctionTypes(result, narrowedTypeCandidates, hasUnionMatching: false);
}
#endif
// Re-apply the binary patterns from the innermost to the outermost, feeding each one the result built so far as its left operand.
do
{
result = rewriteBinaryPattern(
this,
result,
binaryPattern
#if DEBUG
, narrowedTypeCandidates
#endif
);
}
while (binaryPatternStack.TryPop(out binaryPattern));
binaryPatternStack.Free();
#if DEBUG
narrowedTypeCandidates.Free();
#endif
return result;
// Rewrites a single binary pattern 'node' whose left operand has already been rewritten to 'preboundLeft'.
static BoundPattern rewriteBinaryPattern(
UnionMatchingRewriter rewriter,
BoundPattern preboundLeft,
BoundBinaryPattern node
#if DEBUG
, ArrayBuilder<TypeSymbol> narrowedTypeCandidates
#endif
)
{
if (node.Disjunction)
{
// No conjunction can continue either operand of a disjunction, so pending union matching is finalized on both sides.
preboundLeft = RewritePatternWithUnionMatchingToPropertyPattern(preboundLeft);
var right = RewritePatternWithUnionMatchingToPropertyPattern((BoundPattern)rewriter.Visit(node.Right));
#if DEBUG
// Here we are verifying that the narrowed type computed during the initial binding phase in
// 'Binder.BindBinaryPattern.bindBinaryPattern' matches what we compute here
// with all recursive patterns in place. So, if the algorithm changes there, we might need to
// update it here as well. However, we are trying to share the same helpers in both places as much
// as possible.
// Compute the common type. This algorithm is quadratic, but disjunctive patterns are unlikely to be huge
Binder.CollectDisjunctionTypes(right, narrowedTypeCandidates, hasUnionMatching: false);
var discardedSiteInfo = CompoundUseSiteInfo<AssemblySymbol>.Discarded;
TypeSymbol? leastSpecific = Binder.LeastSpecificType(narrowedTypeCandidates, rewriter._compilation.Conversions, ref discardedSiteInfo);
Debug.Assert(node.NarrowedType.Equals(leastSpecific ?? node.InputType, TypeCompareKind.ConsiderEverything));
#endif
return node.Update(disjunction: true, preboundLeft, right, inputType: node.InputType, narrowedType: node.NarrowedType);
}
else
{
// Conjunction: the right operand is visited but not finalized here; makeConjunction decides how the two sides combine.
var right = (BoundPattern)rewriter.Visit(node.Right);
BoundPattern result = makeConjunction(node.Syntax, preboundLeft, right, makeCompilerGenerated: node.WasCompilerGenerated);
#if DEBUG
// A conjunction restarts the candidate list with the single type contributed by the combined result.
narrowedTypeCandidates.Clear();
narrowedTypeCandidates.Add(result is BoundPatternWithUnionMatching unionResult ? getDisjunctionType(unionResult) : result.NarrowedType);
#endif
return result;
}
// If left and right are not BoundPatternWithUnionMatching, simply produce a regular BoundBinaryPattern representing the conjunction.
// Otherwise, create a BoundPatternWithUnionMatching representing the conjunction with the following properties:
// - There are no BoundPatternWithUnionMatching nodes under any ValuePattern.
// - The last union matching in evaluation order is always at the top
// - A preceding BoundPatternWithUnionMatching in evaluation order, if any, is the BoundPatternWithUnionMatching.LeftOfPendingConjunction.
// A null 'right' means there is nothing to conjoin, in which case 'left' is returned as is.
static BoundPattern makeConjunction(SyntaxNode node, BoundPattern left, BoundPattern? right, bool makeCompilerGenerated)
{
if (right is BoundPatternWithUnionMatching rightUnionPattern)
{
// Update LeftOfPendingConjunction with the conjunction of left and LeftOfPendingConjunction
// The code below unwraps the following recursive operation:
//     return new BoundPatternWithUnionMatching(
//         syntax: node,
//         rightUnionPattern.UnionMatchingInputType,
//         makeConjunction(node, left, rightUnionPattern.LeftOfPendingConjunction, makeCompilerGenerated: true),
//         rightUnionPattern.ValueProperty,
//         rightUnionPattern.ValuePattern,
//         inputType: left.InputType).MakeCompilerGenerated();
var stack = ArrayBuilder<BoundPatternWithUnionMatching>.GetInstance();
stack.Push(rightUnionPattern);
// Walk down the LeftOfPendingConjunction chain, remembering every union pattern on the way.
while (rightUnionPattern.LeftOfPendingConjunction is BoundPatternWithUnionMatching other)
{
stack.Push(other);
rightUnionPattern = other;
}
Debug.Assert(rightUnionPattern.LeftOfPendingConjunction is not BoundPatternWithUnionMatching);
// The bottom of the chain has a LeftOfPendingConjunction that is not a union pattern (possibly null); conjoin 'left' with it.
var leftOfPendingConjunction = makeConjunction(node, left, rightUnionPattern.LeftOfPendingConjunction, makeCompilerGenerated: true);
// Rebuild the chain from the bottom up on top of the new left-hand side.
do
{
rightUnionPattern = stack.Pop();
leftOfPendingConjunction = new BoundPatternWithUnionMatching(
syntax: node,
rightUnionPattern.UnionMatchingInputType,
leftOfPendingConjunction,
rightUnionPattern.ValueProperty,
rightUnionPattern.ValuePattern,
inputType: left.InputType).MakeCompilerGenerated();
}
while (!stack.IsEmpty);
stack.Free();
return leftOfPendingConjunction;
}
else if (right is { })
{
if (left is BoundPatternWithUnionMatching leftUnionPattern)
{
// The right is just a continuation of the ValuePattern.
// Update ValuePattern with the conjunction of ValuePattern and right,
// since neither of them contain union patterns, we can simply create a BoundBinaryPattern for that.
return new BoundPatternWithUnionMatching(
syntax: node,
leftUnionPattern.UnionMatchingInputType,
leftUnionPattern.LeftOfPendingConjunction,
leftUnionPattern.ValueProperty,
MakeBinaryAnd(node, leftUnionPattern.ValuePattern, right, makeCompilerGenerated),
inputType: leftUnionPattern.InputType).MakeCompilerGenerated();
}
else
{
// Neither left nor right contain union patterns, create a BoundBinaryPattern for that.
return MakeBinaryAnd(node, left, right, makeCompilerGenerated);
}
}
else
{
return left;
}
}
}
#if DEBUG
// Debug-only helper: returns the type a pending union matching contributes to the narrowed type of a disjunction.
static TypeSymbol getDisjunctionType(BoundPatternWithUnionMatching unionPattern)
{
// Disjunction type is the UnionMatchingInputType for the first BoundPatternWithUnionMatching in evaluation order.
// That type won't be narrowed more for the purposes of a possible upcoming disjunction, since
// everything after that goes into a subpattern of a recursive pattern.
while (unionPattern.LeftOfPendingConjunction is BoundPatternWithUnionMatching leftUnionPattern)
{
unionPattern = leftUnionPattern;
}
return unionPattern.UnionMatchingInputType;
}
#endif
}
/// <summary>
/// Creates a conjunction (<c>left and right</c>) whose input type is taken from <paramref name="left"/>
/// and whose narrowed type is taken from <paramref name="right"/>.
/// </summary>
/// <param name="node">Syntax to associate with the new pattern.</param>
/// <param name="left">Left operand of the conjunction.</param>
/// <param name="right">Right operand of the conjunction.</param>
/// <param name="makeCompilerGenerated">Value assigned to <c>WasCompilerGenerated</c> on the new pattern.</param>
private static BoundBinaryPattern MakeBinaryAnd(SyntaxNode node, BoundPattern left, BoundPattern right, bool makeCompilerGenerated)
{
    var conjunction = new BoundBinaryPattern(
        node,
        disjunction: false,
        left,
        right,
        inputType: left.InputType,
        narrowedType: right.NarrowedType)
    {
        WasCompilerGenerated = makeCompilerGenerated
    };

    return conjunction;
}
/// <summary>
/// Finalizes a pending Union matching. When <paramref name="pattern"/> is a
/// <see cref="BoundPatternWithUnionMatching"/>, it and every union matching reachable through
/// <c>LeftOfPendingConjunction</c> are converted into compiler-generated recursive patterns that
/// match the value pattern against the Value member. Any other pattern is returned unchanged.
/// </summary>
/// <remarks>
/// Callers invoke this only once no further conjunction can extend the pattern, so the top-most value
/// pattern is complete. Pending patterns, if present anywhere, are expected to sit at the top.
/// </remarks>
private static BoundPattern RewritePatternWithUnionMatchingToPropertyPattern(BoundPattern pattern)
{
    if (pattern is not BoundPatternWithUnionMatching pending)
    {
        return pattern;
    }

    // The chain is processed from the top-most (last in evaluation order) union matching downwards.
    // Each step produces a recursive pattern that is nested into the value pattern of the preceding
    // union matching, so the recursive pattern for the top-most node ends up innermost and the one for
    // the bottom-most node ends up outermost.
    TypeSymbol currentInputType = pending.UnionMatchingInputType;
    BoundPropertySubpatternMember currentValueMember = pending.ValueProperty;
    BoundPattern? preceding = pending.LeftOfPendingConjunction;
    BoundPattern currentValuePattern = pending.ValuePattern;

    for (; ; )
    {
        var strippedUnionType = currentInputType.StrippedType();

        var valueSubpattern = new BoundPropertySubpattern(
            currentValuePattern.Syntax,
            currentValueMember,
            isLengthOrCount: false,
            currentValuePattern).MakeCompilerGenerated();

        BoundPattern rewritten = new BoundRecursivePattern(
            syntax: currentValuePattern.Syntax,
            declaredType: null,
            deconstructMethod: null,
            deconstruction: default,
            properties: [valueSubpattern],
            variable: null,
            variableAccess: null,
            isExplicitNotNullTest: false,
            isUnionMatching: false,
            inputType: strippedUnionType,
            narrowedType: strippedUnionType).MakeCompilerGenerated();

        if (currentInputType.IsNullableType())
        {
            // Prepend the 'Value' property pattern with a type pattern unwrapping the nullable value.
            var syntax = rewritten.Syntax;
            var underlyingTypeExpression = new BoundTypeExpression(syntax, aliasOpt: null, strippedUnionType).MakeCompilerGenerated();
            var unwrapNullable = new BoundTypePattern(
                syntax,
                declaredType: underlyingTypeExpression,
                isExplicitNotNullTest: false, // https://github.com/dotnet/roslyn/issues/82636: Is passing 'true' going to make a difference?
                isUnionMatching: false,
                inputType: currentInputType,
                narrowedType: strippedUnionType).MakeCompilerGenerated();

            rewritten = MakeBinaryAnd(syntax, unwrapNullable, rewritten, makeCompilerGenerated: true);
        }

        if (preceding is not BoundPatternWithUnionMatching precedingUnion)
        {
            // End of the chain: conjoin with the remaining ordinary left-hand side, if there is one.
            if (preceding is not null)
            {
                rewritten = MakeBinaryAnd(pattern.Syntax, preceding, rewritten, makeCompilerGenerated: true);
            }

            Debug.Assert(rewritten.InputType.Equals(pattern.InputType, TypeCompareKind.AllIgnoreOptions));
            return rewritten;
        }

        // Another union matching precedes this one: what was built so far becomes a continuation of
        // its value pattern, and the loop repeats for that preceding node.
        currentInputType = precedingUnion.UnionMatchingInputType;
        currentValueMember = precedingUnion.ValueProperty;
        preceding = precedingUnion.LeftOfPendingConjunction;
        currentValuePattern = MakeBinaryAnd(pattern.Syntax, precedingUnion.ValuePattern, rewritten, makeCompilerGenerated: true);
    }
}
}
}
|