File: System\Data\Filter\AggregateNode.cs
Web Access
Project: src\src\libraries\System.Data.Common\src\System.Data.Common.csproj (System.Data.Common)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
 
namespace System.Data
{
    internal enum Aggregate
    {
        None = FunctionId.none,
        Sum = FunctionId.Sum,
        Avg = FunctionId.Avg,
        Min = FunctionId.Min,
        Max = FunctionId.Max,
        Count = FunctionId.Count,
        StDev = FunctionId.StDev,   // Statistical standard deviation
        Var = FunctionId.Var,       // Statistical variance
    }
 
    internal sealed class AggregateNode : ExpressionNode
    {
        private readonly AggregateType _type;
        private readonly Aggregate _aggregate;
        private readonly bool _local;     // set to true if the aggregate calculated locally (for the current table)
 
        private readonly string? _relationName;
        private readonly string _columnName;
 
        // CONSIDER PERF: keep the objects, not names.
        // ? try to drop a column
        private DataTable? _childTable;
        private DataColumn? _column;
        private DataRelation? _relation;
 
        internal AggregateNode(DataTable? table, FunctionId aggregateType, string columnName) :
            this(table, aggregateType, columnName, true, null)
        {
        }
 
        internal AggregateNode(DataTable? table, FunctionId aggregateType, string columnName, bool local, string? relationName) : base(table)
        {
            Debug.Assert(columnName != null, "Invalid parameter column name (null).");
            _aggregate = (Aggregate)(int)aggregateType;
 
            if (aggregateType == FunctionId.Sum)
                _type = AggregateType.Sum;
            else if (aggregateType == FunctionId.Avg)
                _type = AggregateType.Mean;
            else if (aggregateType == FunctionId.Min)
                _type = AggregateType.Min;
            else if (aggregateType == FunctionId.Max)
                _type = AggregateType.Max;
            else if (aggregateType == FunctionId.Count)
                _type = AggregateType.Count;
            else if (aggregateType == FunctionId.Var)
                _type = AggregateType.Var;
            else if (aggregateType == FunctionId.StDev)
                _type = AggregateType.StDev;
            else
            {
                throw ExprException.UndefinedFunction(Function.s_functionName[(int)aggregateType]);
            }
 
            _local = local;
            _relationName = relationName;
            _columnName = columnName;
        }
        internal override void Bind(DataTable table, List<DataColumn> list)
        {
            BindTable(table);
            if (table == null)
                throw ExprException.AggregateUnbound(ToString()!);
 
            if (_local)
            {
                _relation = null;
            }
            else
            {
                DataRelationCollection relations;
                relations = table.ChildRelations;
 
                if (_relationName == null)
                {
                    // must have one and only one relation
 
                    if (relations.Count > 1)
                    {
                        throw ExprException.UnresolvedRelation(table.TableName, ToString()!);
                    }
                    if (relations.Count == 1)
                    {
                        _relation = relations[0];
                    }
                    else
                    {
                        throw ExprException.AggregateUnbound(ToString()!);
                    }
                }
                else
                {
                    _relation = relations[_relationName];
                }
            }
 
            _childTable = (_relation == null) ? table : _relation.ChildTable;
 
            _column = _childTable.Columns[_columnName];
 
            if (_column == null)
                throw ExprException.UnboundName(_columnName);
 
            // add column to the dependency list, do not add duplicate columns
 
            Debug.Assert(_column != null, $"Failed to bind column {_columnName}");
 
            int i;
            for (i = 0; i < list.Count; i++)
            {
                // walk the list, check if the current column already on the list
                DataColumn dataColumn = list[i];
                if (_column == dataColumn)
                {
                    break;
                }
            }
            if (i >= list.Count)
            {
                list.Add(_column);
            }
 
            AggregateNode.Bind(_relation, list);
        }
 
        internal static void Bind(DataRelation? relation, List<DataColumn> list)
        {
            if (null != relation)
            {
                // add the ends of the relationship the expression depends on
                foreach (DataColumn c in relation.ChildColumnsReference)
                {
                    if (!list.Contains(c))
                    {
                        list.Add(c);
                    }
                }
                foreach (DataColumn c in relation.ParentColumnsReference)
                {
                    if (!list.Contains(c))
                    {
                        list.Add(c);
                    }
                }
            }
        }
 
        [RequiresUnreferencedCode(DataSet.RequiresUnreferencedCodeMessage)]
        internal override object Eval()
        {
            return Eval(null, DataRowVersion.Default);
        }
 
        [RequiresUnreferencedCode(DataSet.RequiresUnreferencedCodeMessage)]
        internal override object Eval(DataRow? row, DataRowVersion version)
        {
            if (_childTable == null)
                throw ExprException.AggregateUnbound(ToString()!);
 
            DataRow[] rows;
 
            if (_local)
            {
                rows = new DataRow[_childTable.Rows.Count];
                _childTable.Rows.CopyTo(rows, 0);
            }
            else
            {
                if (row == null)
                {
                    throw ExprException.EvalNoContext();
                }
                if (_relation == null)
                {
                    throw ExprException.AggregateUnbound(ToString()!);
                }
                rows = row.GetChildRows(_relation, version);
            }
 
            int[] records;
            if (version == DataRowVersion.Proposed)
            {
                version = DataRowVersion.Default;
            }
 
            List<int> recordList = new List<int>();
 
            for (int i = 0; i < rows.Length; i++)
            {
                if (rows[i].RowState == DataRowState.Deleted)
                {
                    if (DataRowAction.Rollback != rows[i]._action)
                    {
                        continue;
                    }
                    Debug.Assert(DataRowVersion.Original == version, "wrong version");
                    version = DataRowVersion.Original;
                }
                else if ((DataRowAction.Rollback == rows[i]._action) && (rows[i].RowState == DataRowState.Added))
                {
                    continue;
                }
                if (version == DataRowVersion.Original && rows[i]._oldRecord == -1)
                {
                    continue;
                }
                recordList.Add(rows[i].GetRecordFromVersion(version));
            }
            records = recordList.ToArray();
 
            return _column!.GetAggregateValue(records, _type);
        }
 
        // Helper for the DataTable.Compute method
        [RequiresUnreferencedCode(DataSet.RequiresUnreferencedCodeMessage)]
        internal override object Eval(int[] records)
        {
            if (_childTable == null)
                throw ExprException.AggregateUnbound(ToString()!);
            if (!_local)
            {
                throw ExprException.ComputeNotAggregate(ToString()!);
            }
            return _column!.GetAggregateValue(records, _type);
        }
 
        internal override bool IsConstant()
        {
            return false;
        }
 
        internal override bool IsTableConstant()
        {
            return _local;
        }
 
        internal override bool HasLocalAggregate()
        {
            return _local;
        }
 
        internal override bool HasRemoteAggregate()
        {
            return !_local;
        }
 
        internal override bool DependsOn(DataColumn column)
        {
            if (_column == column)
            {
                return true;
            }
            if (_column!.Computed)
            {
                return _column.DataExpression!.DependsOn(column);
            }
            return false;
        }
 
        internal override ExpressionNode Optimize()
        {
            return this;
        }
    }
}