// File: DataFrameColumns\StringDataFrameColumn.cs
// Project: src\src\Microsoft.Data.Analysis\Microsoft.Data.Analysis.csproj (Microsoft.Data.Analysis)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using Microsoft.ML;
using Microsoft.ML.Data;
 
namespace Microsoft.Data.Analysis
{
    /// <summary>
    /// A mutable column to hold strings
    /// </summary>
    /// <remarks> Is NOT Arrow compatible </remarks>
    /// <summary>
    /// A mutable column to hold strings
    /// </summary>
    /// <remarks>
    /// Is NOT Arrow compatible.
    /// Values are stored in a list of <see cref="List{T}"/> buffers holding at most <see cref="MaxCapacity"/> strings
    /// each, so the column can hold more rows than a single list allows. Row <c>r</c> is stored in buffer
    /// <c>r / MaxCapacity</c> at offset <c>r % MaxCapacity</c>, which requires every buffer but the last to be full.
    /// </remarks>
    public partial class StringDataFrameColumn : DataFrameColumn, IEnumerable<string>
    {
        /// <summary>
        /// Maximum number of strings held by a single internal buffer.
        /// </summary>
        /// <remarks>
        /// NOTE(review): this field is mutable; changing it after columns exist breaks the row-to-buffer mapping
        /// (<c>rowIndex / MaxCapacity</c>) used by those columns.
        /// </remarks>
        public static int MaxCapacity = ArrayUtility.ArrayMaxSize / Unsafe.SizeOf<IntPtr>(); // Max Size in bytes / size of pointer (8 bytes on x64)

        private readonly List<List<string>> _stringBuffers = new List<List<string>>(); // To store more than intMax number of strings

        /// <summary>
        /// Creates a column holding <paramref name="length"/> null strings.
        /// </summary>
        /// <param name="name">The name of the column.</param>
        /// <param name="length">The number of rows to allocate; every row starts out null.</param>
        public StringDataFrameColumn(string name, long length = 0) : base(name, length, typeof(string))
        {
            // At least one buffer is always created (even for length 0) so Append has somewhere to write.
            int numberOfBuffersRequired = (int)(length / MaxCapacity + 1);
            long remaining = length;
            for (int i = 0; i < numberOfBuffersRequired; i++)
            {
                // Cap each buffer at MaxCapacity so row r lands in buffer r / MaxCapacity, as GetValue expects.
                int bufferLength = (int)Math.Min(MaxCapacity, remaining);
                List<string> buffer = new List<string>(bufferLength);
                for (int j = 0; j < bufferLength; j++)
                {
                    buffer.Add(default);
                }
                _stringBuffers.Add(buffer);
                remaining -= bufferLength;
            }
            _nullCount = length;
        }

        /// <summary>
        /// Creates a column containing <paramref name="values"/> in enumeration order.
        /// </summary>
        /// <param name="name">The name of the column.</param>
        /// <param name="values">The strings to append; null entries are counted as nulls.</param>
        /// <exception cref="ArgumentNullException"><paramref name="values"/> is null.</exception>
        public StringDataFrameColumn(string name, IEnumerable<string> values) : base(name, 0, typeof(string))
        {
            values = values ?? throw new ArgumentNullException(nameof(values));
            // Append always writes into the last buffer, so make sure one exists.
            if (_stringBuffers.Count == 0)
            {
                _stringBuffers.Add(new List<string>());
            }
            foreach (var value in values)
            {
                Append(value);
            }
        }

        // Number of null entries currently stored; maintained by the constructors, Append and SetValue.
        private long _nullCount;

        /// <inheritdoc/>
        public override long NullCount => _nullCount;

        /// <summary>
        /// Grows the column to <paramref name="length"/> rows by appending nulls.
        /// </summary>
        /// <exception cref="ArgumentException"><paramref name="length"/> is smaller than the current length.</exception>
        protected internal override void Resize(long length)
        {
            if (length < Length)
                throw new ArgumentException(Strings.CannotResizeDown, nameof(length));

            for (long i = Length; i < length; i++)
            {
                Append(null);
            }
        }

        /// <summary>
        /// Appends <paramref name="value"/> to the end of the column, starting a new buffer when the last one is full.
        /// </summary>
        /// <param name="value">The string to append; may be null.</param>
        public void Append(string value)
        {
            List<string> lastBuffer = _stringBuffers[_stringBuffers.Count - 1];
            if (lastBuffer.Count == MaxCapacity)
            {
                lastBuffer = new List<string>();
                _stringBuffers.Add(lastBuffer);
            }
            lastBuffer.Add(value);
            if (value == null)
                _nullCount++;
            Length++;
        }

        /// <summary>
        /// Applies a function to all values in the column, that are not null.
        /// </summary>
        /// <param name="func">The function to apply.</param>
        /// <param name="inPlace">A boolean flag to indicate if the operation should be in place.</param>
        /// <returns>A new <see cref="StringDataFrameColumn"/> if <paramref name="inPlace"/> is not set. Returns this column otherwise.</returns>
        public StringDataFrameColumn Apply(Func<string, string> func, bool inPlace = false)
        {
            var column = inPlace ? this : Clone();

            for (long i = 0; i < column.Length; i++)
            {
                var value = column[i];

                if (value != null)
                    column[i] = func(value);
            }

            return column;
        }

        /// <summary>
        /// Validates <paramref name="rowIndex"/> and returns the index of the buffer that stores it.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="rowIndex"/> is negative or not less than <see cref="DataFrameColumn.Length"/>.</exception>
        private int GetBufferIndexContainingRowIndex(long rowIndex)
        {
            if (rowIndex < 0)
            {
                throw new ArgumentOutOfRangeException(nameof(rowIndex));
            }
            if (rowIndex >= Length)
            {
                // The (paramName, message) overload: parameter name first, message second.
                throw new ArgumentOutOfRangeException(nameof(rowIndex), Strings.IndexIsGreaterThanColumnLength);
            }
            return (int)(rowIndex / MaxCapacity);
        }

        /// <summary>
        /// Lazily yields up to <paramref name="length"/> values starting at <paramref name="startIndex"/>,
        /// crossing buffer boundaries as needed. Stops early at the end of the column.
        /// </summary>
        private IEnumerable<string> EnumerateRange(long startIndex, int length)
        {
            int bufferIndex = GetBufferIndexContainingRowIndex(startIndex);
            int bufferOffset = (int)(startIndex % MaxCapacity);
            int produced = 0;
            while (produced < length && bufferIndex < _stringBuffers.Count)
            {
                List<string> buffer = _stringBuffers[bufferIndex];
                for (int i = bufferOffset; produced < length && i < buffer.Count; i++)
                {
                    yield return buffer[i];
                    produced++;
                }
                bufferIndex++;
                bufferOffset = 0;
            }
        }

        /// <inheritdoc/>
        protected override object GetValue(long rowIndex)
        {
            int bufferIndex = GetBufferIndexContainingRowIndex(rowIndex);
            return _stringBuffers[bufferIndex][(int)(rowIndex % MaxCapacity)];
        }

        /// <inheritdoc/>
        protected override IReadOnlyList<object> GetValues(long startIndex, int length)
        {
            return new List<object>(EnumerateRange(startIndex, length));
        }

        /// <summary>
        /// Stores <paramref name="value"/> (a string or null) at <paramref name="rowIndex"/> and keeps
        /// <see cref="NullCount"/> in sync.
        /// </summary>
        /// <exception cref="ArgumentException"><paramref name="value"/> is neither null nor a string.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="rowIndex"/> is out of range.</exception>
        protected override void SetValue(long rowIndex, object value)
        {
            if (value != null && !(value is string))
            {
                throw new ArgumentException(string.Format(Strings.MismatchedValueType, typeof(string)), nameof(value));
            }

            string newValue = (string)value;
            int bufferIndex = GetBufferIndexContainingRowIndex(rowIndex);
            int bufferOffset = (int)(rowIndex % MaxCapacity);
            List<string> buffer = _stringBuffers[bufferIndex];
            string oldValue = buffer[bufferOffset];
            buffer[bufferOffset] = newValue;

            // Only a null <-> non-null transition changes the null count.
            if (oldValue == null && newValue != null)
            {
                if (_nullCount > 0)
                    _nullCount--;
            }
            else if (oldValue != null && newValue == null)
            {
                _nullCount++;
            }
        }

        /// <summary>
        /// Gets or sets the string at <paramref name="rowIndex"/>.
        /// </summary>
        public new string this[long rowIndex]
        {
            get => (string)GetValue(rowIndex);
            set => SetValue(rowIndex, value);
        }

        /// <summary>
        /// Returns a copy of up to <paramref name="length"/> values starting at <paramref name="startIndex"/>.
        /// </summary>
        public new List<string> this[long startIndex, int length]
        {
            get => new List<string>(EnumerateRange(startIndex, length));
        }

        /// <summary>
        /// Enumerates every value in row order, including nulls.
        /// </summary>
        public IEnumerator<string> GetEnumerator()
        {
            foreach (List<string> buffer in _stringBuffers)
            {
                foreach (string value in buffer)
                {
                    yield return value;
                }
            }
        }

        /// <inheritdoc/>
        protected override IEnumerator GetEnumeratorCore() => GetEnumerator();

        /// <summary>Not supported for string columns.</summary>
        public override DataFrameColumn Clamp<U>(U min, U max, bool inPlace = false) => throw new NotSupportedException();

        /// <summary>Not supported for string columns.</summary>
        public override DataFrameColumn Filter<U>(U min, U max) => throw new NotSupportedException();

        /// <summary>
        /// Returns a sorted copy of this column.
        /// </summary>
        public new StringDataFrameColumn Sort(bool ascending = true, bool putNullValuesLast = true)
        {
            return (StringDataFrameColumn)base.Sort(ascending, putNullValuesLast);
        }

        /// <summary>
        /// Computes row indices that sort the column. Each buffer is sorted on its own, then the per-buffer orders
        /// are merged through a <see cref="SortedDictionary{TKey, TValue}"/> keyed on each buffer's next value.
        /// </summary>
        protected internal override PrimitiveDataFrameColumn<long> GetSortIndices(bool ascending, bool putNullValuesLast)
        {
            var comparer = Comparer<string>.Default;

            List<int[]> bufferSortIndices = new List<int[]>(_stringBuffers.Count);
            var columnNullIndices = new Int64DataFrameColumn("NullIndices", NullCount);
            long nullIndicesSlot = 0;
            foreach (List<string> buffer in _stringBuffers)
            {
                // Global row index of this buffer's first element; long arithmetic avoids int overflow.
                long bufferRowOffset = (long)bufferSortIndices.Count * MaxCapacity;
                var sortIndices = new int[buffer.Count];
                for (int i = 0; i < buffer.Count; i++)
                {
                    sortIndices[i] = i;
                    if (buffer[i] == null)
                    {
                        columnNullIndices[nullIndicesSlot] = i + bufferRowOffset;
                        nullIndicesSlot++;
                    }
                }
                // TODO: Refactor the sort routine to also work with IList?
                string[] array = buffer.ToArray();
                IntrospectiveSort(array, array.Length, sortIndices, comparer);
                bufferSortIndices.Add(sortIndices);
            }

            // Simple merge sort to build the full column's sort indices.
            // Returns the first non-null value at or after startIndex in the buffer's sorted order (or null plus the
            // buffer length if none remain). Callers must pass a startIndex inside a non-empty buffer.
            ValueTuple<string, int> GetFirstNonNullValueStartingAtIndex(int stringBufferIndex, int startIndex)
            {
                string value = _stringBuffers[stringBufferIndex][bufferSortIndices[stringBufferIndex][startIndex]];
                while (value == null && ++startIndex < bufferSortIndices[stringBufferIndex].Length)
                {
                    value = _stringBuffers[stringBufferIndex][bufferSortIndices[stringBufferIndex][startIndex]];
                }
                return (value, startIndex);
            }

            // Seed the heap with each buffer's first non-null value.
            SortedDictionary<string, List<ValueTuple<int, int>>> heapOfValueAndListOfTupleOfSortAndBufferIndex = new SortedDictionary<string, List<ValueTuple<int, int>>>(comparer);
            for (int i = 0; i < _stringBuffers.Count; i++)
            {
                // An empty buffer (empty column, or the trailing buffer when Length is a multiple of MaxCapacity)
                // has no element to seed with; indexing it would throw.
                if (bufferSortIndices[i].Length == 0)
                {
                    continue;
                }
                ValueTuple<string, int> valueAndBufferSortIndex = GetFirstNonNullValueStartingAtIndex(i, 0);
                if (valueAndBufferSortIndex.Item1 == null)
                {
                    // All nulls
                    continue;
                }
                if (heapOfValueAndListOfTupleOfSortAndBufferIndex.TryGetValue(valueAndBufferSortIndex.Item1, out List<ValueTuple<int, int>> entries))
                {
                    entries.Add((valueAndBufferSortIndex.Item2, i));
                }
                else
                {
                    heapOfValueAndListOfTupleOfSortAndBufferIndex.Add(valueAndBufferSortIndex.Item1, new List<ValueTuple<int, int>>() { (valueAndBufferSortIndex.Item2, i) });
                }
            }
            var columnSortIndices = new Int64DataFrameColumn("SortIndices", Length);

            // Buffer 0 is full whenever more than one buffer exists, so its length is the per-buffer row stride.
            GetBufferSortIndex getBufferSortIndex = new GetBufferSortIndex((int bufferIndex, int sortIndex) => (bufferSortIndices[bufferIndex][sortIndex]) + bufferIndex * bufferSortIndices[0].Length);
            GetValueAndBufferSortIndexAtBuffer<string> getValueAtBuffer = new GetValueAndBufferSortIndexAtBuffer<string>((int bufferIndex, int sortIndex) => GetFirstNonNullValueStartingAtIndex(bufferIndex, sortIndex));
            GetBufferLengthAtIndex getBufferLengthAtIndex = new GetBufferLengthAtIndex((int bufferIndex) => bufferSortIndices[bufferIndex].Length);

            PopulateColumnSortIndicesWithHeap(heapOfValueAndListOfTupleOfSortAndBufferIndex,
                columnSortIndices,
                columnNullIndices,
                ascending,
                putNullValuesLast,
                getBufferSortIndex,
                getValueAtBuffer,
                getBufferLengthAtIndex);

            return columnSortIndices;
        }

        /// <summary>
        /// Clones the column, optionally reordering or filtering rows by <paramref name="mapIndices"/>,
        /// then appends <paramref name="numberOfNullsToAppend"/> nulls.
        /// </summary>
        public new StringDataFrameColumn Clone(DataFrameColumn mapIndices, bool invertMapIndices, long numberOfNullsToAppend)
        {
            return (StringDataFrameColumn)CloneImplementation(mapIndices, invertMapIndices, numberOfNullsToAppend);
        }

        /// <summary>
        /// Returns a copy of this column followed by <paramref name="numberOfNullsToAppend"/> nulls.
        /// </summary>
        public new StringDataFrameColumn Clone(long numberOfNullsToAppend = 0)
        {
            return (StringDataFrameColumn)CloneImplementation(numberOfNullsToAppend);
        }

        /// <inheritdoc/>
        protected override DataFrameColumn CloneImplementation(long numberOfNullsToAppend)
        {
            StringDataFrameColumn ret = new StringDataFrameColumn(Name, Length);
            for (long i = 0; i < Length; i++)
                ret[i] = this[i];

            for (long i = 0; i < numberOfNullsToAppend; i++)
                ret.Append(null);

            return ret;
        }

        /// <summary>
        /// Dispatches on the element type of <paramref name="mapIndices"/>:
        /// <list type="bullet">
        /// <item><description><c>long</c>/<c>int</c> columns select rows by index.</description></item>
        /// <item><description><c>bool</c> columns act as a row filter.</description></item>
        /// <item><description>A null <paramref name="mapIndices"/> produces a plain copy.</description></item>
        /// </list>
        /// </summary>
        /// <exception cref="ArgumentException"><paramref name="mapIndices"/> has an unsupported element type.</exception>
        protected override DataFrameColumn CloneImplementation(DataFrameColumn mapIndices, bool invertMapIndices = false, long numberOfNullsToAppend = 0)
        {
            StringDataFrameColumn clone;
            if (!(mapIndices is null))
            {
                Type dataType = mapIndices.DataType;
                if (dataType != typeof(long) && dataType != typeof(int) && dataType != typeof(bool))
                    throw new ArgumentException(String.Format(Strings.MultipleMismatchedValueType, typeof(long), typeof(int), typeof(bool)), nameof(mapIndices));
                if (dataType == typeof(long))
                    clone = CloneImplementation(mapIndices as PrimitiveDataFrameColumn<long>, invertMapIndices);
                else if (dataType == typeof(int))
                    clone = CloneImplementation(mapIndices as PrimitiveDataFrameColumn<int>, invertMapIndices);
                else
                    clone = CloneImplementation(mapIndices as PrimitiveDataFrameColumn<bool>);

                for (long i = 0; i < numberOfNullsToAppend; i++)
                    clone.Append(null);
            }
            else
            {
                clone = Clone(numberOfNullsToAppend);
            }

            return clone;
        }

        // Keeps row i when boolColumn[i] is true; null or false rows are dropped.
        private StringDataFrameColumn CloneImplementation(PrimitiveDataFrameColumn<bool> boolColumn)
        {
            if (boolColumn.Length > Length)
                throw new ArgumentException(Strings.MapIndicesExceedsColumnLength, nameof(boolColumn));
            StringDataFrameColumn ret = new StringDataFrameColumn(Name, 0);
            for (long i = 0; i < boolColumn.Length; i++)
            {
                bool? value = boolColumn[i];
                if (value.HasValue && value.Value)
                    ret.Append(this[i]);
            }
            return ret;
        }

        // Builds a column whose k-th row is this[mapIndices[k]] (written in reverse when invertMapIndices);
        // null map entries produce null rows.
        private StringDataFrameColumn CloneImplementation(PrimitiveDataFrameColumn<int> mapIndices, bool invertMapIndices = false)
        {
            mapIndices = mapIndices ?? throw new ArgumentNullException(nameof(mapIndices));
            var ret = new StringDataFrameColumn(Name, mapIndices.Length);

            long rowIndex = 0;
            for (int b = 0; b < mapIndices.ColumnContainer.Buffers.Count; b++)
            {
                var span = mapIndices.ColumnContainer.Buffers[b].ReadOnlySpan;
                var validitySpan = mapIndices.ColumnContainer.NullBitMapBuffers[b].ReadOnlySpan;

                for (int i = 0; i < span.Length; i++)
                {
                    long index = invertMapIndices ? mapIndices.Length - 1 - rowIndex : rowIndex;
                    ret[index] = BitUtility.IsValid(validitySpan, i) ? this[span[i]] : null;
                    rowIndex++;
                }
            }

            return ret;
        }

        // Same as the int overload, for long map indices.
        private StringDataFrameColumn CloneImplementation(PrimitiveDataFrameColumn<long> mapIndices, bool invertMapIndices = false)
        {
            mapIndices = mapIndices ?? throw new ArgumentNullException(nameof(mapIndices));
            var ret = new StringDataFrameColumn(Name, mapIndices.Length);

            long rowIndex = 0;
            for (int b = 0; b < mapIndices.ColumnContainer.Buffers.Count; b++)
            {
                var span = mapIndices.ColumnContainer.Buffers[b].ReadOnlySpan;
                var validitySpan = mapIndices.ColumnContainer.NullBitMapBuffers[b].ReadOnlySpan;

                for (int i = 0; i < span.Length; i++)
                {
                    long index = invertMapIndices ? mapIndices.Length - 1 - rowIndex : rowIndex;
                    ret[index] = BitUtility.IsValid(validitySpan, i) ? this[span[i]] : null;
                    rowIndex++;
                }
            }

            return ret;
        }

        /// <summary>
        /// Builds a two-column frame ("Values", "Counts") from grouped row indices.
        /// </summary>
        internal static DataFrame ValueCountsImplementation(Dictionary<string, ICollection<long>> groupedValues)
        {
            StringDataFrameColumn keys = new StringDataFrameColumn("Values", 0);
            PrimitiveDataFrameColumn<long> counts = new PrimitiveDataFrameColumn<long>("Counts");
            foreach (KeyValuePair<string, ICollection<long>> keyValuePair in groupedValues)
            {
                keys.Append(keyValuePair.Key);
                counts.Append(keyValuePair.Value.Count);
            }
            return new DataFrame(new List<DataFrameColumn> { keys, counts });
        }

        /// <summary>
        /// Counts occurrences of each non-null value.
        /// </summary>
        public override DataFrame ValueCounts()
        {
            Dictionary<string, ICollection<long>> groupedValues = GroupColumnValues<string>(out HashSet<long> _);
            return ValueCountsImplementation(groupedValues);
        }

        /// <inheritdoc/>
        public override GroupBy GroupBy(int columnIndex, DataFrame parent)
        {
            Dictionary<string, ICollection<long>> dictionary = GroupColumnValues<string>(out HashSet<long> _);
            return new GroupBy<string>(parent, columnIndex, dictionary);
        }

        /// <summary>
        /// Maps each distinct non-null value to the row indices holding it; null rows go to <paramref name="nullIndices"/>.
        /// </summary>
        /// <exception cref="NotImplementedException"><typeparamref name="TKey"/> is not <see cref="string"/>.</exception>
        public override Dictionary<TKey, ICollection<long>> GroupColumnValues<TKey>(out HashSet<long> nullIndices)
        {
            if (typeof(TKey) == typeof(string))
            {
                Dictionary<string, ICollection<long>> multimap = new Dictionary<string, ICollection<long>>(EqualityComparer<string>.Default);
                nullIndices = new HashSet<long>();
                for (long i = 0; i < Length; i++)
                {
                    string str = this[i];
                    if (str != null)
                    {
                        bool containsKey = multimap.TryGetValue(str, out ICollection<long> values);
                        if (containsKey)
                        {
                            values.Add(i);
                        }
                        else
                        {
                            multimap.Add(str, new List<long>() { i });
                        }
                    }
                    else
                    {
                        nullIndices.Add(i);
                    }
                }
                return multimap as Dictionary<TKey, ICollection<long>>;
            }
            else
            {
                throw new NotImplementedException(nameof(TKey));
            }
        }

        /// <summary>
        /// Returns a new column with <see langword="null" /> elements replaced by <paramref name="value"/>.
        /// </summary>
        /// <param name="value">The non-null string written into every null row.</param>
        /// <param name="inPlace">Indicates if the operation should be performed in place</param>
        /// <returns>This column when <paramref name="inPlace"/> is set; otherwise a modified copy.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="value"/> is null.</exception>
        public StringDataFrameColumn FillNulls(string value, bool inPlace = false)
        {
            if (value == null)
                throw new ArgumentNullException(nameof(value));
            StringDataFrameColumn column = inPlace ? this : Clone();

            for (long i = 0; i < column.Length; i++)
            {
                if (column[i] == null)
                    column[i] = value;
            }
            return column;
        }

        /// <inheritdoc/>
        protected override DataFrameColumn FillNullsImplementation(object value, bool inPlace)
        {
            if (value is string valueString)
                return FillNulls(valueString, inPlace);
            else
                throw new ArgumentException(String.Format(Strings.MismatchedValueType, typeof(string)), nameof(value));
        }

        /// <inheritdoc/>
        public new StringDataFrameColumn DropNulls()
        {
            return (StringDataFrameColumn)DropNullsImplementation();
        }

        /// <inheritdoc/>
        protected override DataFrameColumn DropNullsImplementation()
        {
            var ret = new StringDataFrameColumn(Name, Length - NullCount);

            long j = 0;
            for (long i = 0; i < Length; i++)
            {
                var value = this[i];

                if (value != null)
                {
                    ret[j++] = value;
                }
            }

            return ret;
        }

        /// <summary>Registers this column as a text column in an ML.NET schema.</summary>
        protected internal override void AddDataViewColumn(DataViewSchema.Builder builder)
        {
            builder.AddColumn(Name, TextDataViewType.Instance);
        }

        /// <inheritdoc/>
        protected internal override Delegate GetDataViewGetter(DataViewRowCursor cursor)
        {
            return CreateValueGetterDelegate(cursor);
        }

        // Getter that reads the row at the cursor's current position; a null string yields an empty memory.
        private ValueGetter<ReadOnlyMemory<char>> CreateValueGetterDelegate(DataViewRowCursor cursor) =>
            (ref ReadOnlyMemory<char> value) => value = this[cursor.Position].AsMemory();

        /// <summary>
        /// Reads the value at the cursor's position through <paramref name="getter"/> and stores it:
        /// it overwrites an existing row, or appends when the cursor is exactly one row past the end.
        /// </summary>
        /// <exception cref="IndexOutOfRangeException">The cursor position is more than one row past the end.</exception>
        protected internal override void AddValueUsingCursor(DataViewRowCursor cursor, Delegate getter)
        {
            long row = cursor.Position;
            ReadOnlyMemory<char> value = default;
            Debug.Assert(getter != null, "Expected getter to be valid");

            (getter as ValueGetter<ReadOnlyMemory<char>>)(ref value);

            if (Length > row)
            {
                this[row] = value.ToString();
            }
            else if (Length == row)
            {
                Append(value.ToString());
            }
            else
            {
                throw new IndexOutOfRangeException(nameof(row));
            }
        }

        /// <inheritdoc/>
        protected internal override Delegate GetValueGetterUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn)
        {
            return cursor.GetGetter<ReadOnlyMemory<char>>(schemaColumn);
        }

        /// <inheritdoc/>
        public override Dictionary<long, ICollection<long>> GetGroupedOccurrences(DataFrameColumn other, out HashSet<long> otherColumnNullIndices)
        {
            return GetGroupedOccurrences<string>(other, out otherColumnNullIndices);
        }
    }
}