|
#pragma warning disable CA1063 // Implement IDisposable Correctly
#pragma warning disable CA1000 // Do not declare static members on generic types
#pragma warning disable IDE0016 // Use 'throw' expression
#pragma warning disable IDE0018 // Inline variable declaration
#pragma warning disable IDE0019 // Use pattern matching
#pragma warning disable IDE0028 // Simplify collection initialization
#pragma warning disable IDE0034 // Simplify 'default' expression
#pragma warning disable IDE0040 // Add accessibility modifiers
#pragma warning disable IDE0046 // Convert to conditional expression
#pragma warning disable IDE0058 // Expression value is never used
#pragma warning disable IDE0063 // Use simple 'using' statement
#pragma warning disable IDE0074 // Use compound assignment
#pragma warning disable IDE0090 // Use 'new(...)'
#pragma warning disable IDE0161 // Convert to file-scoped namespace
#pragma warning disable IDE0290 // Use primary constructor
#pragma warning disable IDE0300 // Simplify collection initialization
#if XUNIT_NULLABLE
#nullable enable
#else
// In case this is source-imported with global nullable enabled but no XUNIT_NULLABLE
#pragma warning disable CS8601
#pragma warning disable CS8603
#pragma warning disable CS8604
#pragma warning disable CS8605
#pragma warning disable CS8618
#endif
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
#if XUNIT_NULLABLE
using System.Diagnostics.CodeAnalysis;
#endif
namespace Xunit.Sdk
{
/// <summary>
/// Base class for generic <see cref="CollectionTracker{T}"/>, which also includes some public
/// static functionality.
/// </summary>
#if XUNIT_VISIBILITY_INTERNAL
internal
#else
public
#endif
abstract class CollectionTracker : IDisposable
{
	/// <summary>
	/// Initializes a new instance of the <see cref="CollectionTracker"/> class.
	/// </summary>
	/// <param name="innerEnumerable">The collection being tracked; exposed via <see cref="InnerEnumerable"/>.</param>
	/// <exception cref="ArgumentNullException">Thrown when <paramref name="innerEnumerable"/> is <c>null</c>.</exception>
	protected CollectionTracker(IEnumerable innerEnumerable)
	{
#if NET6_0_OR_GREATER
		ArgumentNullException.ThrowIfNull(innerEnumerable);
#else
		if (innerEnumerable == null)
			throw new ArgumentNullException(nameof(innerEnumerable));
#endif

		InnerEnumerable = innerEnumerable;
	}

	// The open generic CompareTypedSets<T>; CheckIfSetsAreEqual closes it over the set's
	// element type at runtime (non-AOT builds only).
	static readonly MethodInfo openGenericCompareTypedSetsMethod =
		typeof(CollectionTracker)
			.GetRuntimeMethods()
			.Single(m => m.Name == nameof(CompareTypedSets));

	/// <summary>
	/// Gets the inner enumerable that this collection tracker is wrapping. This is mostly
	/// provided for simplifying other APIs which require both the tracker and the collection
	/// (for example, <see cref="AreCollectionsEqual"/>).
	/// </summary>
	protected internal IEnumerable InnerEnumerable { get; protected set; }

	/// <summary>
	/// Determine if two enumerable collections are equal. It contains logic that varies depending
	/// on the collection type (supporting arrays, dictionaries, sets, and generic enumerables).
	/// </summary>
	/// <param name="x">First value to compare</param>
	/// <param name="y">Second value to compare</param>
	/// <param name="itemComparer">The comparer used for individual item comparisons</param>
	/// <param name="isDefaultItemComparer">Pass <c>true</c> if the <paramref name="itemComparer"/> is the default item
	/// comparer from <see cref="AssertEqualityComparer{T}"/>; pass <c>false</c>, otherwise.</param>
	/// <param name="mismatchedIndex">The output mismatched item index when the collections are not equal. It is only
	/// computed by element-by-element comparisons (single-dimensional zero-based arrays and general enumerables);
	/// it is <c>null</c> for dictionaries, sets, other arrays, and collections that compare equal.</param>
	/// <returns>Returns <c>true</c> if the collections are equal; <c>false</c>, otherwise.</returns>
	public static bool AreCollectionsEqual(
#if XUNIT_NULLABLE
		CollectionTracker? x,
		CollectionTracker? y,
#else
		CollectionTracker x,
		CollectionTracker y,
#endif
		IEqualityComparer itemComparer,
		bool isDefaultItemComparer,
		out int? mismatchedIndex)
	{
		Assert.GuardArgumentNotNull(nameof(itemComparer), itemComparer);

		mismatchedIndex = null;

		// Each specialized check returns null when it does not apply to this pair of collections,
		// so the first applicable check decides the result. The general enumerable comparison is
		// the final fallback. With the default item comparer, sets are compared using their own
		// comparers (null is passed) rather than being rebuilt around itemComparer.
		return
#if XUNIT_AOT
			CheckIfDictionariesAreEqual(x, y, itemComparer) ??
#else
			CheckIfDictionariesAreEqual(x, y) ??
#endif
			CheckIfSetsAreEqual(x, y, isDefaultItemComparer ? null : itemComparer) ??
			CheckIfArraysAreEqual(x, y, itemComparer, isDefaultItemComparer, out mismatchedIndex) ??
			CheckIfEnumerablesAreEqual(x, y, itemComparer, isDefaultItemComparer, out mismatchedIndex);
	}

	// Compares two arrays. Returns null (not applicable) when either tracker is null or either
	// inner collection is not an array.
	static bool? CheckIfArraysAreEqual(
#if XUNIT_NULLABLE
		CollectionTracker? x,
		CollectionTracker? y,
#else
		CollectionTracker x,
		CollectionTracker y,
#endif
		IEqualityComparer itemComparer,
		bool isDefaultItemComparer,
		out int? mismatchedIndex)
	{
		mismatchedIndex = null;

		if (x == null || y == null)
			return null;

		var expectedArray = x.InnerEnumerable as Array;
		var actualArray = y.InnerEnumerable as Array;

		if (expectedArray == null || actualArray == null)
			return null;

		// If we have single-dimensional zero-based arrays, then we delegate to the enumerable
		// version, since that uses the trackers and gets us the mismatch pointer.
		if (expectedArray.Rank == 1 && expectedArray.GetLowerBound(0) == 0 &&
			actualArray.Rank == 1 && actualArray.GetLowerBound(0) == 0)
			return CheckIfEnumerablesAreEqual(x, y, itemComparer, isDefaultItemComparer, out mismatchedIndex);

		if (expectedArray.Rank != actualArray.Rank)
			return false;

		// Differing bounds, aka object[2,1] vs. object[1,2]
		// You can also have non-zero-based arrays, so we don't just check lengths
		for (var rank = 0; rank < expectedArray.Rank; rank++)
			if (expectedArray.GetLowerBound(rank) != actualArray.GetLowerBound(rank) || expectedArray.GetUpperBound(rank) != actualArray.GetUpperBound(rank))
				return false;

		// Enumeration will flatten everything identically, so just enumerate at this point.
		// No mismatch index is reported on this path (mismatchedIndex stays null).
		var expectedEnumerator = x.GetSafeEnumerator();
		var actualEnumerator = y.GetSafeEnumerator();

		while (true)
		{
			var hasExpected = expectedEnumerator.MoveNext();
			var hasActual = actualEnumerator.MoveNext();

			// Once either side runs out, the arrays are equal only if both ran out together
			if (!hasExpected || !hasActual)
				return hasExpected == hasActual;

			if (!itemComparer.Equals(expectedEnumerator.Current, actualEnumerator.Current))
				return false;
		}
	}

	// Compares two non-generic IDictionary collections. Returns null (not applicable) when either
	// tracker is null or either inner collection is not an IDictionary.
	static bool? CheckIfDictionariesAreEqual(
#if XUNIT_NULLABLE
		CollectionTracker? x,
		CollectionTracker? y
#else
		CollectionTracker x,
		CollectionTracker y
#endif
#if XUNIT_AOT
		, IEqualityComparer itemComparer)
#else
		)
#endif
	{
		if (x == null || y == null)
			return null;

		var dictionaryX = x.InnerEnumerable as IDictionary;
		var dictionaryY = y.InnerEnumerable as IDictionary;

		if (dictionaryX == null || dictionaryY == null)
			return null;

		if (dictionaryX.Count != dictionaryY.Count)
			return false;

		// Y's keys are tracked in a HashSet with default equality. Each matched key is removed,
		// so any keys left over at the end exist only in Y.
		var dictionaryYKeys = new HashSet<object>(dictionaryY.Keys.Cast<object>());

		// We don't directly compare the KeyValuePair<> objects. Instead we rely on Contains() to
		// match up keys, and then compare the values. In non-AOT builds the values use a
		// type-appropriate default comparer (the itemComparer from AreCollectionsEqual is not
		// passed along). In AOT builds the supplied itemComparer is used for the values.
		foreach (var key in dictionaryX.Keys.Cast<object>())
		{
			if (!dictionaryYKeys.Contains(key))
				return false;

			var valueX = dictionaryX[key];
			var valueY = dictionaryY[key];

			if (valueX == null)
			{
				if (valueY != null)
					return false;
			}
			else if (valueY == null)
				return false;
			else
			{
				var valueXType = valueX.GetType();
				var valueYType = valueY.GetType();

#if XUNIT_AOT
				var comparer = itemComparer;
#else
				// Values of the same runtime type get that type's default comparer; mixed types
				// fall back to the object comparer
				var comparer = AssertEqualityComparer.GetDefaultComparer(valueXType == valueYType ? valueXType : typeof(object));
#endif
				if (!comparer.Equals(valueX, valueY))
					return false;
			}

			dictionaryYKeys.Remove(key);
		}

		return dictionaryYKeys.Count == 0;
	}

	// General element-by-element comparison, used as the final fallback. Two null trackers are
	// equal; one null tracker is not. On inequality, mismatchIndex is the index of the first
	// differing item, or the index where one collection ended before the other.
	static bool CheckIfEnumerablesAreEqual(
#if XUNIT_NULLABLE
		CollectionTracker? x,
		CollectionTracker? y,
#else
		CollectionTracker x,
		CollectionTracker y,
#endif
		IEqualityComparer itemComparer,
		bool isDefaultItemComparer,
		out int? mismatchIndex)
	{
		mismatchIndex = null;

		if (x == null)
			return y == null;
		if (y == null)
			return false;

		var enumeratorX = x.GetSafeEnumerator();
		var enumeratorY = y.GetSafeEnumerator();

		// Counts the items matched so far, so it doubles as the current item's index
		mismatchIndex = 0;

		while (true)
		{
			var hasNextX = enumeratorX.MoveNext();
			var hasNextY = enumeratorY.MoveNext();

			if (!hasNextX || !hasNextY)
			{
				if (hasNextX == hasNextY)
				{
					mismatchIndex = null;
					return true;
				}

				return false;
			}

			var xCurrent = enumeratorX.Current;
			var yCurrent = enumeratorY.Current;

			// With the default comparer, items that can be tracked as collections are compared
			// recursively. AsNonStringTracker is defined elsewhere; presumably it returns null for
			// strings and non-collection values — confirm.
			using (var xCurrentTracker = isDefaultItemComparer ? xCurrent.AsNonStringTracker() : null)
			using (var yCurrentTracker = isDefaultItemComparer ? yCurrent.AsNonStringTracker() : null)
			{
				if (xCurrentTracker != null && yCurrentTracker != null)
				{
					int? _;
					var innerCompare = AreCollectionsEqual(xCurrentTracker, yCurrentTracker, AssertEqualityComparer<object>.DefaultInnerComparer, true, out _);
					if (!innerCompare)
						return false;
				}
				else if (!itemComparer.Equals(xCurrent, yCurrent))
					return false;

				mismatchIndex++;
			}
		}
	}

	// Compares two sets. Returns null (not applicable) when either tracker is null or either
	// inner collection is not recognized as a set by ArgumentFormatter.GetSetElementType. Sets
	// with different element types are never equal.
	static bool? CheckIfSetsAreEqual(
#if XUNIT_NULLABLE
		CollectionTracker? x,
		CollectionTracker? y,
		IEqualityComparer? itemComparer)
#else
		CollectionTracker x,
		CollectionTracker y,
		IEqualityComparer itemComparer)
#endif
	{
		if (x == null || y == null)
			return null;

		var elementTypeX = ArgumentFormatter.GetSetElementType(x.InnerEnumerable);
		var elementTypeY = ArgumentFormatter.GetSetElementType(y.InnerEnumerable);

		if (elementTypeX == null || elementTypeY == null)
			return null;

		if (elementTypeX != elementTypeY)
			return false;

#if XUNIT_AOT
		// Can't use MakeGenericType in AOT.
		// NOTE(review): this path uses default object equality and ignores itemComparer.
		return CompareUntypedSets(x.InnerEnumerable, y.InnerEnumerable);
#else
		var genericCompareMethod = openGenericCompareTypedSetsMethod.MakeGenericMethod(elementTypeX);
#if XUNIT_NULLABLE
		return (bool)genericCompareMethod.Invoke(null, new object?[] { x.InnerEnumerable, y.InnerEnumerable, itemComparer })!;
#else
		return (bool)genericCompareMethod.Invoke(null, new object[] { x.InnerEnumerable, y.InnerEnumerable, itemComparer });
#endif
#endif // XUNIT_AOT
	}

	// Set comparison without generics: both sides are copied into HashSet<object> instances that
	// use default object equality
	static bool CompareUntypedSets(
		IEnumerable enumX,
		IEnumerable enumY)
	{
		var setX = new HashSet<object>(enumX.Cast<object>());
		var setY = new HashSet<object>(enumY.Cast<object>());

		return setX.SetEquals(setY);
	}

	// Typed set comparison, invoked via reflection from CheckIfSetsAreEqual. The Count check uses
	// the original sets. When a comparer is supplied, both sets are rebuilt around it so that
	// membership is decided by that comparer.
	static bool CompareTypedSets<T>(
		ISet<T> setX,
		ISet<T> setY,
#if XUNIT_NULLABLE
		IEqualityComparer<T>? itemComparer)
#else
		IEqualityComparer<T> itemComparer)
#endif
	{
		if (setX.Count != setY.Count)
			return false;

		if (itemComparer != null)
		{
			setX = new HashSet<T>(setX, itemComparer);
			setY = new HashSet<T>(setY, itemComparer);
		}

		return setX.SetEquals(setY);
	}

	/// <inheritdoc/>
	public abstract void Dispose();

	/// <summary>
	/// Formats the collection when you have a mismatched index. The formatted result will be the section of the
	/// collection surrounding the mismatched item.
	/// </summary>
	/// <param name="mismatchedIndex">The index of the mismatched item</param>
	/// <param name="pointerIndent">How many spaces into the output value the pointed-to item begins at</param>
	/// <param name="depth">The optional printing depth (1 indicates a top-level value)</param>
	/// <returns>The formatted collection</returns>
	public abstract string FormatIndexedMismatch(
		int? mismatchedIndex,
		out int? pointerIndent,
		int depth = 1);

	/// <summary>
	/// Formats the collection when you have a mismatched index. The formatted result will be the section of the
	/// collection from <paramref name="startIndex"/> to <paramref name="endIndex"/>. These indices are usually
	/// obtained by calling <see cref="GetMismatchExtents"/>.
	/// </summary>
	/// <param name="startIndex">The start index of the collection to print</param>
	/// <param name="endIndex">The end index of the collection to print</param>
	/// <param name="mismatchedIndex">The mismatched item index</param>
	/// <param name="pointerIndent">How many spaces into the output value the pointed-to item begins at</param>
	/// <param name="depth">The optional printing depth (1 indicates a top-level value)</param>
	/// <returns>The formatted collection</returns>
	public abstract string FormatIndexedMismatch(
		int startIndex,
		int endIndex,
		int? mismatchedIndex,
		out int? pointerIndent,
		int depth = 1);

	/// <summary>
	/// Formats the beginning part of the collection.
	/// </summary>
	/// <param name="depth">The optional printing depth (1 indicates a top-level value)</param>
	/// <returns>The formatted collection</returns>
	public abstract string FormatStart(int depth = 1);

	/// <summary>
	/// Gets the extents to print when you find a mismatched index, in the form of
	/// a <paramref name="startIndex"/> and <paramref name="endIndex"/>. If the mismatched
	/// index is <c>null</c>, the extents will start at index 0.
	/// </summary>
	/// <param name="mismatchedIndex">The mismatched item index</param>
	/// <param name="startIndex">The start index that should be used for printing</param>
	/// <param name="endIndex">The end index that should be used for printing</param>
	public abstract void GetMismatchExtents(
		int? mismatchedIndex,
		out int startIndex,
		out int endIndex);

	/// <summary>
	/// Gets a safe version of <see cref="IEnumerator"/> that prevents double enumeration and does all
	/// the necessary tracking required for collection formatting. This should be the same value
	/// returned by <see cref="CollectionTracker{T}.GetEnumerator"/>, except non-generic.
	/// </summary>
	protected internal abstract IEnumerator GetSafeEnumerator();

	/// <summary>
	/// Gets the full name of the type of the element at the given index, if known.
	/// Since this uses the item cache produced by enumeration, it may return <c>null</c>
	/// when we haven't enumerated enough to see the given element, or if we enumerated
	/// so much that the item has left the cache, or if the item at the given index
	/// is <c>null</c>. It will also return <c>null</c> when the <paramref name="index"/>
	/// is <c>null</c>.
	/// </summary>
	/// <param name="index">The item index</param>
#if XUNIT_NULLABLE
	public abstract string? TypeAt(int? index);
#else
	public abstract string TypeAt(int? index);
#endif

	/// <summary>
	/// Wraps an untyped enumerable in an object-based <see cref="CollectionTracker{T}"/>.
	/// </summary>
	/// <param name="enumerable">The untyped enumerable to wrap</param>
	/// <exception cref="ArgumentNullException">Thrown when <paramref name="enumerable"/> is <c>null</c>.</exception>
	public static CollectionTracker<object> Wrap(IEnumerable enumerable)
	{
		// Validate up front so the exception names this parameter, rather than the "source"
		// parameter of Enumerable.Cast, which would otherwise be the first to see the null
		Assert.GuardArgumentNotNull(nameof(enumerable), enumerable);

		return new CollectionTracker<object>(enumerable, enumerable.Cast<object>());
	}
}
/// <summary>
/// A utility class that can be used to wrap enumerables to prevent double enumeration.
/// It offers the ability to safely print parts of the collection when failures are
/// encountered, as well as some static versions of the printing functionality.
/// </summary>
/// <typeparam name="T">The type of the items in the collection</typeparam>
#if XUNIT_VISIBILITY_INTERNAL
internal
#else
public
#endif
sealed class CollectionTracker<[global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers(
	global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicFields
	| global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.NonPublicFields
	| global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicProperties
	| global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.NonPublicProperties
	| global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicMethods)] T> : CollectionTracker, IEnumerable<T>
{
	// The attribute above is fully qualified because the file only imports
	// System.Diagnostics.CodeAnalysis when XUNIT_NULLABLE is defined.

	const int MAX_ENUMERABLE_LENGTH_HALF = ArgumentFormatter.MAX_ENUMERABLE_LENGTH / 2;

	readonly IEnumerable<T> collection;

	// Created lazily by GetEnumerator, FormatStart, or GetMismatchExtents. It caches enumerated
	// items so that parts of the collection can be formatted after a failure.
#pragma warning disable CA2213 // We move disposal to DisposeInternal, due to https://github.com/xunit/xunit/issues/2762
#if XUNIT_NULLABLE
	Enumerator? enumerator;
#else
	Enumerator enumerator;
#endif
#pragma warning restore CA2213

	/// <summary>
	/// INTERNAL CONSTRUCTOR. DO NOT CALL.
	/// </summary>
	internal CollectionTracker(
		IEnumerable collection,
		IEnumerable<T> castCollection) :
			base(collection)
	{
#if NET6_0_OR_GREATER
		ArgumentNullException.ThrowIfNull(castCollection);
#else
		if (castCollection == null)
			throw new ArgumentNullException(nameof(castCollection));
#endif

		this.collection = castCollection;
	}

	// Used by Wrap. The same object is both the untyped inner enumerable (the base constructor
	// rejects null) and the typed collection that gets enumerated.
	CollectionTracker(IEnumerable<T> collection) :
		base(collection)
	{
		this.collection = collection;
	}

	/// <summary>
	/// Gets the number of iterations that have happened so far.
	/// </summary>
	public int IterationCount =>
		enumerator == null ? 0 : enumerator.CurrentIndex + 1;

	/// <inheritdoc/>
	public override void Dispose() =>
		enumerator?.DisposeInternal();

	/// <inheritdoc/>
	public override string FormatIndexedMismatch(
		int? mismatchedIndex,
		out int? pointerIndent,
		int depth = 1)
	{
		if (depth == ArgumentFormatter.MAX_DEPTH)
		{
			pointerIndent = 1;
			return ArgumentFormatter.EllipsisInBrackets;
		}

		int startIndex;
		int endIndex;

		// GetMismatchExtents creates the enumerator if needed, so it is non-null below
		GetMismatchExtents(mismatchedIndex, out startIndex, out endIndex);

		return FormatIndexedMismatch(
#if XUNIT_NULLABLE
			enumerator!.CurrentItems,
#else
			enumerator.CurrentItems,
#endif
			enumerator.MoveNext,
			startIndex,
			endIndex,
			mismatchedIndex,
			out pointerIndent,
			depth
		);
	}

	/// <inheritdoc/>
	public override string FormatIndexedMismatch(
		int startIndex,
		int endIndex,
		int? mismatchedIndex,
		out int? pointerIndent,
		int depth = 1)
	{
		if (enumerator == null)
			throw new InvalidOperationException("Called FormatIndexedMismatch with indices without calling GetMismatchExtents first");

		return FormatIndexedMismatch(
			enumerator.CurrentItems,
			enumerator.MoveNext,
			startIndex,
			endIndex,
			mismatchedIndex,
			out pointerIndent,
			depth
		);
	}

#if XUNIT_SPAN
	/// <summary>
	/// Formats a span with a mismatched index.
	/// </summary>
	/// <param name="span">The span to be formatted</param>
	/// <param name="mismatchedIndex">The mismatched index point</param>
	/// <param name="pointerIndent">How many spaces into the output value the pointed-to item begins at</param>
	/// <param name="depth">The optional printing depth (1 indicates a top-level value)</param>
	/// <returns>The formatted span</returns>
	public static string FormatIndexedMismatch(
		ReadOnlySpan<T> span,
		int? mismatchedIndex,
		out int? pointerIndent,
		int depth = 1)
	{
		if (depth == ArgumentFormatter.MAX_DEPTH)
		{
			pointerIndent = 1;
			return ArgumentFormatter.EllipsisInBrackets;
		}

		// Center the window on the mismatch where possible, then clamp it to the span and
		// re-derive the start so the window stays MAX_ENUMERABLE_LENGTH wide near the end
		var startIndex = Math.Max(0, (mismatchedIndex ?? 0) - MAX_ENUMERABLE_LENGTH_HALF);
		var endIndex = Math.Min(span.Length - 1, startIndex + ArgumentFormatter.MAX_ENUMERABLE_LENGTH - 1);
		startIndex = Math.Max(0, endIndex - ArgumentFormatter.MAX_ENUMERABLE_LENGTH + 1);

		var moreItemsPastEndIndex = endIndex < span.Length - 1;
		var items = new Dictionary<int, T>();

		for (var idx = startIndex; idx <= endIndex; ++idx)
			items[idx] = span[idx];

		return FormatIndexedMismatch(
			items,
			() => moreItemsPastEndIndex,
			startIndex,
			endIndex,
			mismatchedIndex,
			out pointerIndent,
			depth
		);
	}
#endif

	// Shared formatter for a window of items keyed by collection index. Every index in
	// [startIndex, endIndex] must be present in items. moreItemsPastEndIndex is invoked once,
	// after the window is printed; for a tracker it is the enumerator's MoveNext, so the check
	// advances the enumeration by one item. pointerIndent is the character offset of the
	// mismatched item, or null if it is not inside the window.
	static string FormatIndexedMismatch(
		Dictionary<int, T> items,
		Func<bool> moreItemsPastEndIndex,
		int startIndex,
		int endIndex,
		int? mismatchedIndex,
		out int? pointerIndent,
		int depth)
	{
		pointerIndent = null;

		var printedValues = new StringBuilder("[");
		if (startIndex != 0)
			printedValues.Append(ArgumentFormatter.Ellipsis + ", ");

		for (var idx = startIndex; idx <= endIndex; ++idx)
		{
			if (idx != startIndex)
				printedValues.Append(", ");

			if (idx == mismatchedIndex)
				pointerIndent = printedValues.Length;

			printedValues.Append(ArgumentFormatter.Format(items[idx], depth));
		}

		if (moreItemsPastEndIndex())
			printedValues.Append(", " + ArgumentFormatter.Ellipsis);

		printedValues.Append(']');
		return printedValues.ToString();
	}

	/// <inheritdoc/>
	public override string FormatStart(int depth = 1)
	{
		if (depth == ArgumentFormatter.MAX_DEPTH)
			return ArgumentFormatter.EllipsisInBrackets;

		if (enumerator == null)
			enumerator = new Enumerator(collection.GetEnumerator());

		// Ensure we have already seen enough data to format
		while (enumerator.CurrentIndex <= ArgumentFormatter.MAX_ENUMERABLE_LENGTH)
			if (!enumerator.MoveNext())
				break;

		return FormatStart(enumerator.StartItems, enumerator.CurrentIndex, depth);
	}

	/// <summary>
	/// Formats the beginning part of a collection.
	/// </summary>
	/// <param name="collection">The collection to be formatted</param>
	/// <param name="depth">The optional printing depth (1 indicates a top-level value)</param>
	/// <returns>The formatted collection</returns>
	public static string FormatStart(
		IEnumerable<T> collection,
		int depth = 1)
	{
		Assert.GuardArgumentNotNull(nameof(collection), collection);

		if (depth == ArgumentFormatter.MAX_DEPTH)
			return ArgumentFormatter.EllipsisInBrackets;

		var startItems = new List<T>();
		var currentIndex = -1;

		// Dispose the enumerator so the source's cleanup (e.g. finally blocks in an iterator
		// method) runs, even though enumeration usually stops before the end
		using (var itemEnumerator = collection.GetEnumerator())
		{
			// Ensure we have already seen enough data to format
			while (currentIndex <= ArgumentFormatter.MAX_ENUMERABLE_LENGTH)
			{
				if (!itemEnumerator.MoveNext())
					break;

				startItems.Add(itemEnumerator.Current);
				++currentIndex;
			}
		}

		return FormatStart(startItems, currentIndex, depth);
	}

#if XUNIT_SPAN
	/// <summary>
	/// Formats the beginning part of a span.
	/// </summary>
	/// <param name="span">The span to be formatted</param>
	/// <param name="depth">The optional printing depth (1 indicates a top-level value)</param>
	/// <returns>The formatted span</returns>
	public static string FormatStart(
		ReadOnlySpan<T> span,
		int depth = 1)
	{
		if (depth == ArgumentFormatter.MAX_DEPTH)
			return ArgumentFormatter.EllipsisInBrackets;

		var startItems = new List<T>();
		var currentIndex = -1;
		var spanEnumerator = span.GetEnumerator();

		// Ensure we have already seen enough data to format
		while (currentIndex <= ArgumentFormatter.MAX_ENUMERABLE_LENGTH)
		{
			if (!spanEnumerator.MoveNext())
				break;

			startItems.Add(spanEnumerator.Current);
			++currentIndex;
		}

		return FormatStart(startItems, currentIndex, depth);
	}
#endif

	// Shared formatter for the start of a collection. It prints at most MAX_ENUMERABLE_LENGTH
	// items from the front of items, and appends an ellipsis when currentIndex (the last index
	// seen) shows there were at least that many more.
	static string FormatStart(
		List<T> items,
		int currentIndex,
		int depth)
	{
		var printedValues = new StringBuilder("[");
		var printLength = Math.Min(currentIndex + 1, ArgumentFormatter.MAX_ENUMERABLE_LENGTH);

		for (var idx = 0; idx < printLength; ++idx)
		{
			if (idx != 0)
				printedValues.Append(", ");

			printedValues.Append(ArgumentFormatter.Format(items[idx], depth));
		}

		if (currentIndex >= ArgumentFormatter.MAX_ENUMERABLE_LENGTH)
			printedValues.Append(", " + ArgumentFormatter.Ellipsis);

		printedValues.Append(']');
		return printedValues.ToString();
	}

	/// <inheritdoc/>
	public IEnumerator<T> GetEnumerator()
	{
		if (enumerator != null)
			throw new InvalidOperationException("Multiple enumeration is not supported");

		enumerator = new Enumerator(collection.GetEnumerator());
		return enumerator;
	}

	IEnumerator IEnumerable.GetEnumerator() =>
		GetEnumerator();

	/// <inheritdoc/>
	protected internal override IEnumerator GetSafeEnumerator() =>
		GetEnumerator();

	/// <inheritdoc/>
	public override void GetMismatchExtents(
		int? mismatchedIndex,
		out int startIndex,
		out int endIndex)
	{
		if (enumerator == null)
			enumerator = new Enumerator(collection.GetEnumerator());

		startIndex = Math.Max(0, (mismatchedIndex ?? 0) - MAX_ENUMERABLE_LENGTH_HALF);
		endIndex = startIndex + ArgumentFormatter.MAX_ENUMERABLE_LENGTH - 1;

		// Make sure our window starts with startIndex and ends with endIndex, as appropriate
		while (enumerator.CurrentIndex < endIndex)
			if (!enumerator.MoveNext())
				break;

		// The window ends at the last item actually enumerated and reaches back as far as the
		// ring buffer allows
		endIndex = enumerator.CurrentIndex;
		startIndex = Math.Max(0, endIndex - ArgumentFormatter.MAX_ENUMERABLE_LENGTH + 1);
	}

	/// <inheritdoc/>
#if XUNIT_NULLABLE
	public override string? TypeAt(int? index)
#else
	public override string TypeAt(int? index)
#endif
	{
		if (enumerator == null || !index.HasValue)
			return null;

#if XUNIT_NULLABLE
		T? item;
#else
		T item;
#endif
		if (!enumerator.TryGetCurrentItemAt(index.Value, out item))
			return null;

		return item?.GetType().FullName;
	}

	/// <summary>
	/// Wraps the given collection inside of a <see cref="CollectionTracker{T}"/>.
	/// </summary>
	/// <param name="collection">The collection to be wrapped</param>
	public static CollectionTracker<T> Wrap(IEnumerable<T> collection) =>
		new CollectionTracker<T>(collection);

	// Wraps the collection's enumerator and records what it sees. The first
	// (MAX_ENUMERABLE_LENGTH + 1) items go into StartItems, for FormatStart. The most recent
	// MAX_ENUMERABLE_LENGTH items go into a ring buffer, for FormatIndexedMismatch and TypeAt.
	sealed class Enumerator : IEnumerator<T>
	{
		// Ring buffer slot holding the item at CurrentIndex; -1 until the first successful MoveNext
		int currentItemsLastInsertionIndex = -1;
		readonly T[] currentItemsRingBuffer = new T[ArgumentFormatter.MAX_ENUMERABLE_LENGTH];
		readonly IEnumerator<T> innerEnumerator;

		public Enumerator(IEnumerator<T> innerEnumerator)
		{
			this.innerEnumerator = innerEnumerator;
		}

		public T Current =>
			innerEnumerator.Current;

#if XUNIT_NULLABLE
		object? IEnumerator.Current =>
#else
		object IEnumerator.Current =>
#endif
			Current;

		// Zero-based index of the current item; -1 before the first successful MoveNext
		public int CurrentIndex { get; private set; } = -1;

		// A fresh snapshot of the buffered window of items ending at CurrentIndex (at most
		// MAX_ENUMERABLE_LENGTH of them), keyed by their index in the collection
		public Dictionary<int, T> CurrentItems
		{
			get
			{
				var result = new Dictionary<int, T>();

				if (CurrentIndex > -1)
				{
					// Oldest collection index still held by the ring buffer, and its slot
					var itemIndex = Math.Max(0, CurrentIndex - ArgumentFormatter.MAX_ENUMERABLE_LENGTH + 1);
					var indexInRingBuffer = (currentItemsLastInsertionIndex - CurrentIndex + itemIndex) % ArgumentFormatter.MAX_ENUMERABLE_LENGTH;

					// C# remainder can be negative; wrap into [0, MAX_ENUMERABLE_LENGTH)
					if (indexInRingBuffer < 0)
						indexInRingBuffer += ArgumentFormatter.MAX_ENUMERABLE_LENGTH;

					while (itemIndex <= CurrentIndex)
					{
						result[itemIndex] = currentItemsRingBuffer[indexInRingBuffer];
						++itemIndex;
						indexInRingBuffer = (indexInRingBuffer + 1) % ArgumentFormatter.MAX_ENUMERABLE_LENGTH;
					}
				}

				return result;
			}
		}

		// The first (MAX_ENUMERABLE_LENGTH + 1) items, in enumeration order
		public List<T> StartItems { get; } = new List<T>();

		// Intentionally a no-op (see xunit issue #2762). Disposing the inner enumerator is
		// deferred to DisposeInternal, which the tracker's Dispose calls. This keeps the inner
		// enumerator usable by the formatting methods, which may call MoveNext after whoever
		// received it from GetEnumerator has finished.
		public void Dispose()
		{ }

		public void DisposeInternal() =>
			innerEnumerator.Dispose();

		public bool MoveNext()
		{
			if (!innerEnumerator.MoveNext())
				return false;

			CurrentIndex++;
			var current = innerEnumerator.Current;

			// Keep (MAX_ENUMERABLE_LENGTH + 1) items here, so we can
			// print the start of the collection when lengths differ
			if (CurrentIndex <= ArgumentFormatter.MAX_ENUMERABLE_LENGTH)
				StartItems.Add(current);

			// Keep a ring buffer filled with the most recent MAX_ENUMERABLE_LENGTH items
			// so we can print out the items when we've found a bad index
			currentItemsLastInsertionIndex = (currentItemsLastInsertionIndex + 1) % ArgumentFormatter.MAX_ENUMERABLE_LENGTH;
			currentItemsRingBuffer[currentItemsLastInsertionIndex] = current;

			return true;
		}

		public void Reset()
		{
			innerEnumerator.Reset();

			CurrentIndex = -1;
			currentItemsLastInsertionIndex = -1;
			StartItems.Clear();
		}

		// Looks up an item by collection index in the ring buffer. Returns false when the index
		// is negative, has not been reached yet, or has already been evicted from the buffer.
		public bool TryGetCurrentItemAt(
			int index,
#if XUNIT_NULLABLE
			[MaybeNullWhen(false)] out T item)
#else
			out T item)
#endif
		{
			item = default(T);

			if (index < 0 || index <= CurrentIndex - ArgumentFormatter.MAX_ENUMERABLE_LENGTH || index > CurrentIndex)
				return false;

			var indexInRingBuffer = (currentItemsLastInsertionIndex - CurrentIndex + index) % ArgumentFormatter.MAX_ENUMERABLE_LENGTH;

			// C# remainder can be negative; wrap into [0, MAX_ENUMERABLE_LENGTH)
			if (indexInRingBuffer < 0)
				indexInRingBuffer += ArgumentFormatter.MAX_ENUMERABLE_LENGTH;

			item = currentItemsRingBuffer[indexInRingBuffer];
			return true;
		}
	}
}
}
|