File: Sdk\CollectionTracker.cs
Web Access
Project: src\src\Microsoft.DotNet.XUnitAssert\src\Microsoft.DotNet.XUnitAssert.csproj (xunit.assert)
#pragma warning disable CA1000 // Do not declare static members on generic types
#pragma warning disable CA1031 // Do not catch general exception types
#pragma warning disable CA1508 // Avoid dead conditional code
#pragma warning disable CA2213 // We move disposal to DisposeInternal, due to https://github.com/xunit/xunit/issues/2762
#pragma warning disable IDE0019 // Use pattern matching
#pragma warning disable IDE0028 // Simplify collection initialization
#pragma warning disable IDE0063 // Use simple 'using' statement
#pragma warning disable IDE0074 // Use compound assignment
#pragma warning disable IDE0090 // Use 'new(...)'
#pragma warning disable IDE0290 // Use primary constructor
#pragma warning disable IDE0300 // Simplify collection initialization

#if XUNIT_NULLABLE
#nullable enable
#else
// In case this is source-imported with global nullable enabled but no XUNIT_NULLABLE
#pragma warning disable CS8601
#pragma warning disable CS8603
#pragma warning disable CS8604
#pragma warning disable CS8618
#pragma warning disable CS8625
#endif

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;

#if XUNIT_NULLABLE
using System.Diagnostics.CodeAnalysis;
#endif

namespace Xunit.Sdk
{
	/// <summary>
	/// Base class for generic <see cref="CollectionTracker{T}"/>, which also includes some public
	/// static functionality.
	/// </summary>
#if XUNIT_VISIBILITY_INTERNAL
	internal
#else
	public
#endif
	abstract partial class CollectionTracker : IDisposable
	{
		/// <summary>
		/// Initializes a new instance of the <see cref="CollectionTracker"/> class.
		/// </summary>
		/// <param name="innerEnumerable"></param>
		/// <exception cref="ArgumentNullException"></exception>
		protected CollectionTracker(IEnumerable innerEnumerable) =>
			InnerEnumerable = innerEnumerable ?? throw new ArgumentNullException(nameof(innerEnumerable));

		/// <summary>
		/// Gets the inner enumerable that this collection track is wrapping. This is mostly
		/// provided for simplifying other APIs which require both the tracker and the collection
		/// (for example, <see cref="AreCollectionsEqual(CollectionTracker?, CollectionTracker?, IEqualityComparer, bool)"/>).
		/// </summary>
		protected internal IEnumerable InnerEnumerable { get; protected set; }

		/// <summary>
		/// Determine if two enumerable collections are equal. It contains logic that varies depending
		/// on the collection type (supporting arrays, dictionaries, sets, and generic enumerables).
		/// </summary>
		/// <param name="x">First value to compare</param>
		/// <param name="y">Second value to comare</param>
		/// <param name="itemComparer">The comparer used for individual item comparisons</param>
		/// <param name="isDefaultItemComparer">Pass <see langword="true"/> if the <paramref name="itemComparer"/> is the default item
		/// comparer from <see cref="AssertEqualityComparer{T}"/>; pass <see langword="false"/>, otherwise.</param>
		/// <returns>Returns <see langword="true"/> if the collections are equal; <see langword="false"/>, otherwise.</returns>
		public static AssertEqualityResult AreCollectionsEqual(
#if XUNIT_NULLABLE
			CollectionTracker? x,
			CollectionTracker? y,
#else
			CollectionTracker x,
			CollectionTracker y,
#endif
			IEqualityComparer itemComparer,
			bool isDefaultItemComparer)
		{
			Assert.GuardArgumentNotNull(nameof(itemComparer), itemComparer);

			try
			{
				return
					CheckIfDictionariesAreEqual(x, y) ??
					CheckIfSetsAreEqual(x, y, isDefaultItemComparer ? null : itemComparer) ??
					CheckIfArraysAreEqual(x, y, itemComparer, isDefaultItemComparer) ??
					CheckIfEnumerablesAreEqual(x, y, itemComparer, isDefaultItemComparer);
			}
			catch (Exception ex)
			{
				return AssertEqualityResult.ForResult(false, x?.InnerEnumerable, y?.InnerEnumerable, ex);
			}
		}

#if XUNIT_NULLABLE
		static AssertEqualityResult? CheckIfArraysAreEqual(
			CollectionTracker? x,
			CollectionTracker? y,
#else
		static AssertEqualityResult CheckIfArraysAreEqual(
			CollectionTracker x,
			CollectionTracker y,
#endif
			IEqualityComparer itemComparer,
			bool isDefaultItemComparer)
		{
			if (x == null || y == null)
				return null;

			var expectedArray = x.InnerEnumerable as Array;
			var actualArray = y.InnerEnumerable as Array;

			if (expectedArray == null || actualArray == null)
				return null;

			// If we have single-dimensional zero-based arrays, then we delegate to the enumerable
			// version, since that's uses the trackers and gets us the mismatch pointer.
			if (expectedArray.Rank == 1 && expectedArray.GetLowerBound(0) == 0 &&
				actualArray.Rank == 1 && actualArray.GetLowerBound(0) == 0)
				return CheckIfEnumerablesAreEqual(x, y, itemComparer, isDefaultItemComparer);

			if (expectedArray.Rank != actualArray.Rank)
				return AssertEqualityResult.ForResult(false, x.InnerEnumerable, y.InnerEnumerable);

			// Differing bounds, aka object[2,1] vs. object[1,2]
			// You can also have non-zero-based arrays, so we don't just check lengths
			for (var rank = 0; rank < expectedArray.Rank; rank++)
				if (expectedArray.GetLowerBound(rank) != actualArray.GetLowerBound(rank) || expectedArray.GetUpperBound(rank) != actualArray.GetUpperBound(rank))
					return AssertEqualityResult.ForResult(false, x.InnerEnumerable, y.InnerEnumerable);

			// Enumeration will flatten everything identically, so just enumerate at this point
			var expectedEnumerator = x.GetSafeEnumerator();
			var actualEnumerator = y.GetSafeEnumerator();

			while (true)
			{
				var hasExpected = expectedEnumerator.MoveNext();
				var hasActual = actualEnumerator.MoveNext();

				if (!hasExpected || !hasActual)
					return AssertEqualityResult.ForResult(hasExpected == hasActual, x.InnerEnumerable, y.InnerEnumerable);

				if (!itemComparer.Equals(expectedEnumerator.Current, actualEnumerator.Current))
					return AssertEqualityResult.ForResult(false, x.InnerEnumerable, y.InnerEnumerable);
			}
		}

#if XUNIT_NULLABLE
		static AssertEqualityResult? CheckIfDictionariesAreEqual(
			CollectionTracker? x,
			CollectionTracker? y)
#else
		static AssertEqualityResult CheckIfDictionariesAreEqual(
			CollectionTracker x,
			CollectionTracker y)
#endif
		{
			if (x == null || y == null)
				return null;

			var dictionaryX = x.InnerEnumerable as IDictionary;
			var dictionaryY = y.InnerEnumerable as IDictionary;

			if (dictionaryX == null || dictionaryY == null)
				return null;

			if (dictionaryX.Count != dictionaryY.Count)
				return AssertEqualityResult.ForResult(false, x.InnerEnumerable, y.InnerEnumerable);

			var dictionaryYKeys = new HashSet<object>(dictionaryY.Keys.Cast<object>());

			// We don't pass along the itemComparer from AreCollectionsEqual because we aren't directly
			// comparing the KeyValuePair<> objects. Instead we rely on Contains() on the dictionary to
			// match up keys, and then create type-appropriate comparers for the values.
			foreach (var key in dictionaryX.Keys.Cast<object>())
			{
				if (!dictionaryYKeys.Contains(key))
					return AssertEqualityResult.ForResult(false, x.InnerEnumerable, y.InnerEnumerable);

				var valueX = dictionaryX[key];
				var valueY = dictionaryY[key];

				if (valueX == null)
				{
					if (valueY != null)
						return AssertEqualityResult.ForResult(false, x.InnerEnumerable, y.InnerEnumerable);
				}
				else if (valueY == null)
					return AssertEqualityResult.ForResult(false, x.InnerEnumerable, y.InnerEnumerable);
				else
				{
					var valueXType = valueX.GetType();
					var valueYType = valueY.GetType();

					var comparer = AssertEqualityComparer.GetDefaultComparer(valueXType == valueYType ? valueXType : typeof(object));
					if (!comparer.Equals(valueX, valueY))
						return AssertEqualityResult.ForResult(false, x.InnerEnumerable, y.InnerEnumerable);
				}

				dictionaryYKeys.Remove(key);
			}

			return AssertEqualityResult.ForResult(dictionaryYKeys.Count == 0, x.InnerEnumerable, y.InnerEnumerable);
		}

		static AssertEqualityResult CheckIfEnumerablesAreEqual(
#if XUNIT_NULLABLE
			CollectionTracker? x,
			CollectionTracker? y,
#else
			CollectionTracker x,
			CollectionTracker y,
#endif
			IEqualityComparer itemComparer,
			bool isDefaultItemComparer)
		{
			if (x == null)
				return AssertEqualityResult.ForResult(y == null, null, y?.InnerEnumerable);
			if (y == null)
				return AssertEqualityResult.ForResult(false, x.InnerEnumerable, null);

			var (comparisonType, equalsMethod) = GetAssertEqualityComparerMetadata(itemComparer);
			var enumeratorX = x.GetSafeEnumerator();
			var enumeratorY = y.GetSafeEnumerator();
			var mismatchIndex = 0;

			while (true)
			{
				var hasNextX = enumeratorX.MoveNext();
				var hasNextY = enumeratorY.MoveNext();

				if (!hasNextX || !hasNextY)
					return hasNextX == hasNextY
						? AssertEqualityResult.ForResult(true, x.InnerEnumerable, y.InnerEnumerable)
						: AssertEqualityResult.ForMismatch(x.InnerEnumerable, y.InnerEnumerable, mismatchIndex);

				var xCurrent = enumeratorX.Current;
				var yCurrent = enumeratorY.Current;

				using (var xCurrentTracker = isDefaultItemComparer ? xCurrent.AsNonStringTracker() : null)
				using (var yCurrentTracker = isDefaultItemComparer ? yCurrent.AsNonStringTracker() : null)
				{
					try
					{
						if (xCurrentTracker != null && yCurrentTracker != null)
						{
							var innerCompare = AreCollectionsEqual(xCurrentTracker, yCurrentTracker, AssertEqualityComparer<object>.DefaultInnerComparer, true);
							if (!innerCompare.Equal)
								return AssertEqualityResult.ForMismatch(x.InnerEnumerable, y.InnerEnumerable, mismatchIndex, innerResult: innerCompare);
						}
						else
						{
							var assertEqualityResult = default(AssertEqualityResult);
							if (comparisonType?.IsAssignableFrom(xCurrent?.GetType()) == true && comparisonType?.IsAssignableFrom(yCurrent?.GetType()) == true)
								assertEqualityResult = equalsMethod?.Invoke(itemComparer, new[] { xCurrent, null, yCurrent, null }) as AssertEqualityResult;

							if (assertEqualityResult != null)
							{
								if (!assertEqualityResult.Equal)
									return AssertEqualityResult.ForMismatch(x.InnerEnumerable, y.InnerEnumerable, mismatchIndex, innerResult: assertEqualityResult);
							}
							else if (!itemComparer.Equals(xCurrent, yCurrent))
								return AssertEqualityResult.ForMismatch(x.InnerEnumerable, y.InnerEnumerable, mismatchIndex);
						}
					}
					catch (Exception ex)
					{
						return AssertEqualityResult.ForMismatch(x.InnerEnumerable, y.InnerEnumerable, mismatchIndex, ex);
					}

					mismatchIndex++;
				}
			}
		}

		static bool CompareTypedSets<T>(
			ISet<T> setX,
			ISet<T> setY,
#if XUNIT_NULLABLE
			IEqualityComparer<T>? itemComparer)
#else
			IEqualityComparer<T> itemComparer)
#endif
		{
			if (setX.Count != setY.Count)
				return false;

			if (itemComparer != null)
			{
				setX = new HashSet<T>(setX, itemComparer);
				setY = new HashSet<T>(setY, itemComparer);
			}

			return setX.SetEquals(setY);
		}

		/// <inheritdoc/>
		public void Dispose()
		{
			Dispose(true);

			GC.SuppressFinalize(this);
		}

		/// <summary>
		/// Override to provide an implementation of <see cref="IDisposable.Dispose"/>.
		/// </summary>
		/// <param name="disposing"></param>
		protected abstract void Dispose(bool disposing);

		/// <summary>
		/// Formats the collection when you have a mismatched index. The formatted result will be the section of the
		/// collection surrounded by the mismatched item.
		/// </summary>
		/// <param name="mismatchedIndex">The index of the mismatched item</param>
		/// <param name="pointerIndent">How many spaces into the output value the pointed-to item begins at</param>
		/// <param name="depth">The optional printing depth (1 indicates a top-level value)</param>
		/// <returns>The formatted collection</returns>
		public abstract string FormatIndexedMismatch(
			int? mismatchedIndex,
			out int? pointerIndent,
			int depth = 1);

		/// <summary>
		/// Formats the collection when you have a mismatched index. The formatted result will be the section of the
		/// collection from <paramref name="startIndex"/> to <paramref name="endIndex"/>. These indices are usually
		/// obtained by calling <see cref="GetMismatchExtents"/>.
		/// </summary>
		/// <param name="startIndex">The start index of the collection to print</param>
		/// <param name="endIndex">The end index of the collection to print</param>
		/// <param name="mismatchedIndex">The mismatched item index</param>
		/// <param name="pointerIndent">How many spaces into the output value the pointed-to item begins at</param>
		/// <param name="depth">The optional printing depth (1 indicates a top-level value)</param>
		/// <returns>The formatted collection</returns>
		public abstract string FormatIndexedMismatch(
			int startIndex,
			int endIndex,
			int? mismatchedIndex,
			out int? pointerIndent,
			int depth = 1);

		/// <summary>
		/// Formats the beginning part of the collection.
		/// </summary>
		/// <param name="depth">The optional printing depth (1 indicates a top-level value)</param>
		/// <returns>The formatted collection</returns>
		public abstract string FormatStart(int depth = 1);

		/// <summary>
		/// Gets the extents to print when you find a mismatched index, in the form of
		/// a <paramref name="startIndex"/> and <paramref name="endIndex"/>. If the mismatched
		/// index is <see langword="null"/>, the extents will start at index 0.
		/// </summary>
		/// <param name="mismatchedIndex">The mismatched item index</param>
		/// <param name="startIndex">The start index that should be used for printing</param>
		/// <param name="endIndex">The end index that should be used for printing</param>
		public abstract void GetMismatchExtents(
			int? mismatchedIndex,
			out int startIndex,
			out int endIndex);

		/// <summary>
		/// Gets a safe version of <see cref="IEnumerator"/> that prevents double enumeration and does all
		/// the necessary tracking required for collection formatting. Should should be the same value
		/// returned by <see cref="CollectionTracker{T}.GetEnumerator"/>, except non-generic.
		/// </summary>
		protected internal abstract IEnumerator GetSafeEnumerator();

		/// <summary>
		/// Gets the full name of the type of the element at the given index, if known.
		/// Since this uses the item cache produced by enumeration, it may return <see langword="null"/>
		/// when we haven't enumerated enough to see the given element, or if we enumerated
		/// so much that the item has left the cache, or if the item at the given index
		/// is <see langword="null"/>. It will also return <see langword="null"/> when the <paramref name="index"/>
		/// is <see langword="null"/>.
		/// </summary>
		/// <param name="index">The item index</param>
#if XUNIT_NULLABLE
		public abstract string? TypeAt(int? index);
#else
		public abstract string TypeAt(int? index);
#endif

		/// <summary>
		/// Wraps an untyped enumerable in an object-based <see cref="CollectionTracker{T}"/>.
		/// </summary>
		/// <param name="enumerable">The untyped enumerable to wrap</param>
		public static CollectionTracker<object> Wrap(IEnumerable enumerable) =>
			new CollectionTracker<object>(enumerable, enumerable.Cast<object>());
	}

	/// <summary>
	/// A utility class that can be used to wrap enumerables to prevent double enumeration.
	/// It offers the ability to safely print parts of the collection when failures are
	/// encountered, as well as some static versions of the printing functionality.
	/// </summary>
#if XUNIT_VISIBILITY_INTERNAL
	internal
#else
	public
#endif
	sealed class CollectionTracker<T> : CollectionTracker, IEnumerable<T>
	{
		readonly IEnumerable<T> collection;
#if XUNIT_NULLABLE
		BufferedEnumerator? enumerator;
#else
		BufferedEnumerator enumerator;
#endif

		/// <summary>
		/// INTERNAL CONSTRUCTOR. DO NOT CALL.
		/// </summary>
		internal CollectionTracker(
			IEnumerable collection,
			IEnumerable<T> castCollection) :
				base(collection) =>
					this.collection = castCollection ?? throw new ArgumentNullException(nameof(castCollection));

		CollectionTracker(IEnumerable<T> collection) :
			base(collection) =>
				this.collection = collection;

		/// <summary>
		/// Gets the number of iterations that have happened so far.
		/// </summary>
		public int IterationCount =>
			enumerator == null ? 0 : enumerator.CurrentIndex + 1;

		/// <inheritdoc/>
		protected override void Dispose(bool disposing) =>
			enumerator?.DisposeInternal();

		/// <inheritdoc/>
		public override string FormatIndexedMismatch(
			int? mismatchedIndex,
			out int? pointerIndent,
			int depth = 1)
		{
			if (depth > ArgumentFormatter.MaxEnumerableLength)
			{
				pointerIndent = 1;
				return ArgumentFormatter.EllipsisInBrackets;
			}

			GetMismatchExtents(mismatchedIndex, out var startIndex, out var endIndex);

			return FormatIndexedMismatch(
#if XUNIT_NULLABLE
				enumerator!.CurrentItemsIndexer,
#else
				enumerator.CurrentItemsIndexer,
#endif
				enumerator.MoveNext,
				startIndex,
				endIndex,
				mismatchedIndex,
				out pointerIndent,
				depth
			);
		}

		/// <inheritdoc/>
		public override string FormatIndexedMismatch(
			int startIndex,
			int endIndex,
			int? mismatchedIndex,
			out int? pointerIndent,
			int depth = 1)
		{
			if (enumerator == null)
				throw new InvalidOperationException("Called FormatIndexedMismatch with indices without calling GetMismatchExtents first");

			return FormatIndexedMismatch(
				enumerator.CurrentItemsIndexer,
				enumerator.MoveNext,
				startIndex,
				endIndex,
				mismatchedIndex,
				out pointerIndent,
				depth
			);
		}

		/// <summary>
		/// Formats a span with a mismatched index.
		/// </summary>
		/// <param name="span">The span to be formatted</param>
		/// <param name="mismatchedIndex">The mismatched index point</param>
		/// <param name="pointerIndent">How many spaces into the output value the pointed-to item begins at</param>
		/// <param name="depth">The optional printing depth (1 indicates a top-level value)</param>
		/// <returns>The formatted span</returns>
		public static string FormatIndexedMismatch(
			ReadOnlySpan<T> span,
			int? mismatchedIndex,
			out int? pointerIndent,
			int depth = 1)
		{
			if (depth > ArgumentFormatter.MaxEnumerableLength)
			{
				pointerIndent = 1;
				return ArgumentFormatter.EllipsisInBrackets;
			}

			int startIndex, endIndex;

			if (ArgumentFormatter.MaxEnumerableLength == int.MaxValue)
			{
				startIndex = 0;
				endIndex = span.Length - 1;
			}
			else
			{
				startIndex = Math.Max(0, (mismatchedIndex ?? 0) - ArgumentFormatter.MaxEnumerableLength / 2);
				endIndex = Math.Min(span.Length - 1, startIndex + ArgumentFormatter.MaxEnumerableLength - 1);
				startIndex = Math.Max(0, endIndex - ArgumentFormatter.MaxEnumerableLength + 1);
			}

			var moreItemsPastEndIndex = endIndex < span.Length - 1;
			var items = new Dictionary<int, T>();

			for (var idx = startIndex; idx <= endIndex; ++idx)
				items[idx] = span[idx];

			return FormatIndexedMismatch(
				idx => items[idx],
				() => moreItemsPastEndIndex,
				startIndex,
				endIndex,
				mismatchedIndex,
				out pointerIndent,
				depth
			);
		}

		static string FormatIndexedMismatch(
			Func<int, T> indexer,
			Func<bool> moreItemsPastEndIndex,
			int startIndex,
			int endIndex,
			int? mismatchedIndex,
			out int? pointerIndent,
			int depth)
		{
			pointerIndent = null;

			var printedValues = new StringBuilder("[");
			if (startIndex != 0)
				printedValues.Append(ArgumentFormatter.Ellipsis + ", ");

			for (var idx = startIndex; idx <= endIndex; ++idx)
			{
				if (idx != startIndex)
					printedValues.Append(", ");

				if (idx == mismatchedIndex)
					pointerIndent = printedValues.Length;

				printedValues.Append(ArgumentFormatter.Format(indexer(idx), depth));
			}

			if (moreItemsPastEndIndex())
				printedValues.Append(", " + ArgumentFormatter.Ellipsis);

			printedValues.Append(']');
			return printedValues.ToString();
		}

		/// <inheritdoc/>
		public override string FormatStart(int depth = 1)
		{
			if (depth > ArgumentFormatter.MaxEnumerableLength)
				return ArgumentFormatter.EllipsisInBrackets;

			if (enumerator == null)
				enumerator = BufferedEnumerator.Create(collection.GetEnumerator());

			// Ensure we have already seen enough data to format
			while (enumerator.CurrentIndex <= ArgumentFormatter.MaxEnumerableLength)
				if (!enumerator.MoveNext())
					break;

			return FormatStart(enumerator.StartItemsIndexer, enumerator.CurrentIndex, depth);
		}

		/// <summary>
		/// Formats the beginning part of a collection.
		/// </summary>
		/// <param name="collection">The collection to be formatted</param>
		/// <param name="depth">The optional printing depth (1 indicates a top-level value)</param>
		/// <returns>The formatted collection</returns>
		public static string FormatStart(
			IEnumerable<T> collection,
			int depth = 1)
		{
			Assert.GuardArgumentNotNull(nameof(collection), collection);

			if (depth > ArgumentFormatter.MaxEnumerableLength)
				return ArgumentFormatter.EllipsisInBrackets;

			var startItems = new List<T>();
			var currentIndex = -1;
			var spanEnumerator = collection.GetEnumerator();

			// Ensure we have already seen enough data to format
			while (currentIndex <= ArgumentFormatter.MaxEnumerableLength)
			{
				if (!spanEnumerator.MoveNext())
					break;

				startItems.Add(spanEnumerator.Current);
				++currentIndex;
			}

			return FormatStart(idx => startItems[idx], currentIndex, depth);
		}

		/// <summary>
		/// Formats the beginning part of a span.
		/// </summary>
		/// <param name="span">The span to be formatted</param>
		/// <param name="depth">The optional printing depth (1 indicates a top-level value)</param>
		/// <returns>The formatted span</returns>
		public static string FormatStart(
			ReadOnlySpan<T> span,
			int depth = 1)
		{
			if (depth > ArgumentFormatter.MaxEnumerableLength)
				return ArgumentFormatter.EllipsisInBrackets;

			var startItems = new List<T>();
			var currentIndex = -1;
			var spanEnumerator = span.GetEnumerator();

			// Ensure we have already seen enough data to format
			while (currentIndex <= ArgumentFormatter.MaxEnumerableLength)
			{
				if (!spanEnumerator.MoveNext())
					break;

				startItems.Add(spanEnumerator.Current);
				++currentIndex;
			}

			return FormatStart(idx => startItems[idx], currentIndex, depth);
		}

		static string FormatStart(
			Func<int, T> indexer,
			int currentIndex,
			int depth)
		{
			var printedValues = new StringBuilder("[");
			var printLength = Math.Min(currentIndex + 1, ArgumentFormatter.MaxEnumerableLength);

			for (var idx = 0; idx < printLength; ++idx)
			{
				if (idx != 0)
					printedValues.Append(", ");

				printedValues.Append(ArgumentFormatter.Format(indexer(idx), depth));
			}

			if (currentIndex >= ArgumentFormatter.MaxEnumerableLength)
				printedValues.Append(", " + ArgumentFormatter.Ellipsis);

			printedValues.Append(']');
			return printedValues.ToString();
		}

		/// <inheritdoc/>
		public IEnumerator<T> GetEnumerator()
		{
			if (enumerator != null)
				throw new InvalidOperationException("Multiple enumeration is not supported");

			enumerator = BufferedEnumerator.Create(collection.GetEnumerator());
			return enumerator;
		}

		IEnumerator IEnumerable.GetEnumerator() =>
			GetEnumerator();

		/// <inheritdoc/>
		protected internal override IEnumerator GetSafeEnumerator() =>
			GetEnumerator();

		/// <inheritdoc/>
		public override void GetMismatchExtents(
			int? mismatchedIndex,
			out int startIndex,
			out int endIndex)
		{
			if (enumerator == null)
				enumerator = BufferedEnumerator.Create(collection.GetEnumerator());

			if (ArgumentFormatter.MaxEnumerableLength == int.MaxValue)
			{
				startIndex = 0;
				endIndex = int.MaxValue;
			}
			else
			{
				startIndex = Math.Max(0, (mismatchedIndex ?? 0) - ArgumentFormatter.MaxEnumerableLength / 2);
				endIndex = startIndex + ArgumentFormatter.MaxEnumerableLength - 1;
			}

			// Make sure our window starts with startIndex and ends with endIndex, as appropriate
			while (enumerator.CurrentIndex < endIndex)
				if (!enumerator.MoveNext())
					break;

			endIndex = enumerator.CurrentIndex;

			if (ArgumentFormatter.MaxEnumerableLength != int.MaxValue)
				startIndex = Math.Max(0, endIndex - ArgumentFormatter.MaxEnumerableLength + 1);
		}

		/// <inheritdoc/>
#if XUNIT_NULLABLE
		public override string? TypeAt(int? index)
#else
		public override string TypeAt(int? index)
#endif
		{
			if (enumerator == null || !index.HasValue)
				return null;

			if (!enumerator.TryGetCurrentItemAt(index.Value, out var item))
				return null;

			return item?.GetType().FullName;
		}

		/// <summary>
		/// Wraps the given collection inside of a <see cref="CollectionTracker{T}"/>.
		/// </summary>
		/// <param name="collection">The collection to be wrapped</param>
		public static CollectionTracker<T> Wrap(IEnumerable<T> collection) =>
			new CollectionTracker<T>(collection);

		abstract class BufferedEnumerator : IEnumerator<T>
		{
			protected BufferedEnumerator(IEnumerator<T> innerEnumerator) =>
				InnerEnumerator = innerEnumerator;

			public T Current =>
				InnerEnumerator.Current;

#if XUNIT_NULLABLE
			object? IEnumerator.Current =>
#else
			object IEnumerator.Current =>
#endif
				Current;

			public int CurrentIndex { get; private set; } = -1;

			public abstract Func<int, T> CurrentItemsIndexer { get; }

			protected IEnumerator<T> InnerEnumerator { get; }

			public abstract Func<int, T> StartItemsIndexer { get; }

			public static BufferedEnumerator Create(IEnumerator<T> innerEnumerator) =>
				ArgumentFormatter.MaxEnumerableLength == int.MaxValue
					? (BufferedEnumerator)new ListBufferedEnumerator(innerEnumerator)
					: new RingBufferedEnumerator(innerEnumerator);

			public void Dispose() { }

			public void DisposeInternal() =>
				InnerEnumerator.Dispose();

			public virtual bool MoveNext()
			{
				if (!InnerEnumerator.MoveNext())
					return false;

				CurrentIndex++;
				return true;
			}

			public virtual void Reset()
			{
				InnerEnumerator.Reset();

				CurrentIndex = -1;
			}

			public abstract bool TryGetCurrentItemAt(
				int index,
#if XUNIT_NULLABLE
				[MaybeNullWhen(false)] out T item);
#else
				out T item);
#endif

			// Used when ArgumentFormatter.MaxEnumerableLength is unlimited (int.MaxValue)
			sealed class ListBufferedEnumerator : BufferedEnumerator
			{
				readonly List<T> buffer = new List<T>();

				public ListBufferedEnumerator(IEnumerator<T> innerEnumerator) :
					base(innerEnumerator)
				{ }

				public override Func<int, T> CurrentItemsIndexer =>
					idx => buffer[idx];

				public override Func<int, T> StartItemsIndexer =>
					idx => buffer[idx];

				public override bool MoveNext()
				{
					if (!base.MoveNext())
						return false;

					buffer.Add(InnerEnumerator.Current);
					return true;
				}

				public override void Reset()
				{
					base.Reset();

					buffer.Clear();
				}

				public override bool TryGetCurrentItemAt(
					int index,
#if XUNIT_NULLABLE
					[MaybeNullWhen(false)] out T item)
#else
					out T item)
#endif
				{
					if (index < 0 || index > CurrentIndex)
					{
						item = default;
						return false;
					}

					item = buffer[index];
					return true;
				}
			}

			// Used when ArgumentFormatter.MaxEnumerableLength is not unlimited
			sealed class RingBufferedEnumerator : BufferedEnumerator
			{
				int currentItemsLastInsertionIndex = -1;
				readonly T[] currentItemsRingBuffer = new T[ArgumentFormatter.MaxEnumerableLength];
				readonly List<T> startItems = new List<T>();

				public override Func<int, T> CurrentItemsIndexer
				{
					get
					{
						var result = new Dictionary<int, T>();

						if (CurrentIndex > -1)
						{
							var itemIndex = Math.Max(0, CurrentIndex - ArgumentFormatter.MaxEnumerableLength + 1);

							var indexInRingBuffer = (currentItemsLastInsertionIndex - CurrentIndex + itemIndex) % ArgumentFormatter.MaxEnumerableLength;
							if (indexInRingBuffer < 0)
								indexInRingBuffer += ArgumentFormatter.MaxEnumerableLength;

							while (itemIndex <= CurrentIndex)
							{
								result[itemIndex] = currentItemsRingBuffer[indexInRingBuffer];

								++itemIndex;
								indexInRingBuffer = (indexInRingBuffer + 1) % ArgumentFormatter.MaxEnumerableLength;
							}
						}

						return idx => result[idx];
					}
				}

				public override Func<int, T> StartItemsIndexer =>
					idx => startItems[idx];

				public RingBufferedEnumerator(IEnumerator<T> innerEnumerator) :
					base(innerEnumerator)
				{ }

				public override bool MoveNext()
				{
					if (!base.MoveNext())
						return false;

					var current = InnerEnumerator.Current;

					// Keep (MAX_ENUMERABLE_LENGTH + 1) items here, so we can
					// print the start of the collection when lengths differ
					if (CurrentIndex <= ArgumentFormatter.MaxEnumerableLength)
						startItems.Add(current);

					// Keep a ring buffer filled with the most recent MAX_ENUMERABLE_LENGTH items
					// so we can print out the items when we've found a bad index
					currentItemsLastInsertionIndex = (currentItemsLastInsertionIndex + 1) % ArgumentFormatter.MaxEnumerableLength;
					currentItemsRingBuffer[currentItemsLastInsertionIndex] = current;

					return true;
				}

				public override void Reset()
				{
					base.Reset();

					currentItemsLastInsertionIndex = -1;
					startItems.Clear();
				}

				public override bool TryGetCurrentItemAt(
					int index,
#if XUNIT_NULLABLE
					[MaybeNullWhen(false)] out T item)
#else
					out T item)
#endif
				{
					if (index < 0 || index <= CurrentIndex - ArgumentFormatter.MaxEnumerableLength || index > CurrentIndex)
					{
						item = default;
						return false;
					}

					var indexInRingBuffer = (currentItemsLastInsertionIndex - CurrentIndex + index) % ArgumentFormatter.MaxEnumerableLength;
					if (indexInRingBuffer < 0)
						indexInRingBuffer += ArgumentFormatter.MaxEnumerableLength;

					item = currentItemsRingBuffer[indexInRingBuffer];
					return true;
				}
			}
		}
	}
}