sql_cli/data/
group_by_expressions.rs

1// GROUP BY expression evaluation support
2
3use anyhow::{anyhow, Result};
4use fxhash::FxHashMap;
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::data::arithmetic_evaluator::ArithmeticEvaluator;
9use crate::data::data_view::DataView;
10use crate::data::datatable::{DataColumn, DataRow, DataTable, DataValue};
11use crate::data::query_engine::QueryEngine;
12use crate::sql::aggregates::contains_aggregate;
13use crate::sql::parser::ast::{ColumnRef, SelectItem, SqlExpression};
14use tracing::debug;
15
/// Key identifying one group of rows in a GROUP BY.
///
/// Holds the evaluated GROUP BY expression values for a row, in the same
/// order as the GROUP BY expressions. `Eq` and `Hash` are derived so the key
/// can index the hash maps used while grouping.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GroupKey(pub Vec<DataValue>);
19
/// Extension methods for QueryEngine to handle GROUP BY expressions
/// (arbitrary expressions, not only plain column names).
pub trait GroupByExpressions {
    /// Group rows by evaluating expressions for each row.
    ///
    /// * `view` - the rows to group; only its visible rows are considered.
    /// * `group_by_exprs` - expressions whose per-row values form the [`GroupKey`].
    ///
    /// Returns a map from each distinct key to a `DataView` over the rows
    /// that produced it.
    fn group_by_expressions(
        &self,
        view: DataView,
        group_by_exprs: &[SqlExpression],
    ) -> Result<FxHashMap<GroupKey, DataView>>;

    /// Apply GROUP BY with expressions to the view.
    ///
    /// * `view` - the rows to group.
    /// * `group_by_exprs` - the GROUP BY expressions.
    /// * `select_items` - the SELECT list; supplies aliases for the GROUP BY
    ///   columns and the aggregate expressions to compute per group.
    /// * `having` - optional filter evaluated once per group.
    /// * `_case_insensitive` - not used by the `QueryEngine` implementation
    ///   in this file.
    /// * `date_notation` - passed to the evaluators used for aggregates and
    ///   for HAVING.
    ///
    /// Returns a view over a new table with one row per surviving group:
    /// the GROUP BY columns first, followed by the aggregate columns.
    ///
    /// # Errors
    ///
    /// The `QueryEngine` implementation fails when a non-aggregate SELECT
    /// item is not covered by the GROUP BY clause, when HAVING cannot be
    /// evaluated, or when a result row cannot be added.
    fn apply_group_by_expressions(
        &self,
        view: DataView,
        group_by_exprs: &[SqlExpression],
        select_items: &[SelectItem],
        having: Option<&SqlExpression>,
        _case_insensitive: bool,
        date_notation: String,
    ) -> Result<DataView>;
}
40
41impl GroupByExpressions for QueryEngine {
42    fn group_by_expressions(
43        &self,
44        view: DataView,
45        group_by_exprs: &[SqlExpression],
46    ) -> Result<FxHashMap<GroupKey, DataView>> {
47        // Estimate cardinality for pre-sizing
48        let estimated_groups = self.estimate_group_cardinality(&view, group_by_exprs);
49        let mut groups = FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
50        let mut group_rows: FxHashMap<GroupKey, Vec<usize>> =
51            FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
52
53        // Process each visible row
54        for row_idx in view.get_visible_rows().iter().copied() {
55            // Evaluate GROUP BY expressions for this row
56            let mut key_values = Vec::new();
57            for expr in group_by_exprs {
58                let mut evaluator = ArithmeticEvaluator::new(view.source());
59                let value = evaluator.evaluate(expr, row_idx).unwrap_or(DataValue::Null);
60                key_values.push(value);
61            }
62
63            let key = GroupKey(key_values);
64            group_rows.entry(key).or_default().push(row_idx);
65        }
66
67        // Create DataViews for each group
68        for (key, rows) in group_rows {
69            let mut group_view = DataView::new(view.source_arc());
70            group_view = group_view.with_rows(rows);
71            groups.insert(key, group_view);
72        }
73
74        Ok(groups)
75    }
76
77    fn apply_group_by_expressions(
78        &self,
79        view: DataView,
80        group_by_exprs: &[SqlExpression],
81        select_items: &[SelectItem],
82        having: Option<&SqlExpression>,
83        _case_insensitive: bool,
84        date_notation: String,
85    ) -> Result<DataView> {
86        debug!(
87            "apply_group_by_expressions - grouping by {} expressions",
88            group_by_exprs.len()
89        );
90
91        // Build groups by evaluating expressions for each row
92        let groups = self.group_by_expressions(view.clone(), group_by_exprs)?;
93        debug!(
94            "apply_group_by_expressions - created {} groups",
95            groups.len()
96        );
97
98        // Create a result table for the grouped data
99        let mut result_table = DataTable::new("grouped_result");
100
101        // First, scan SELECT items to find non-aggregate expressions and their aliases
102        let mut aggregate_columns = Vec::new();
103        let mut non_aggregate_exprs = Vec::new();
104        let mut group_by_aliases = Vec::new();
105
106        // Map GROUP BY expressions to their aliases from SELECT items
107        for (i, group_expr) in group_by_exprs.iter().enumerate() {
108            let mut found_alias = None;
109
110            // Look for a matching SELECT item with an alias
111            for item in select_items {
112                if let SelectItem::Expression { expr, alias } = item {
113                    if !contains_aggregate(expr) && expressions_match(expr, group_expr) {
114                        found_alias = Some(alias.clone());
115                        break;
116                    }
117                }
118            }
119
120            // Use found alias or generate a default one
121            let alias = found_alias.unwrap_or_else(|| match group_expr {
122                SqlExpression::Column(column_ref) => column_ref.name.clone(),
123                _ => format!("group_expr_{}", i + 1),
124            });
125
126            result_table.add_column(DataColumn::new(&alias));
127            group_by_aliases.push(alias);
128        }
129
130        // Now process SELECT items to find aggregates and validate non-aggregates
131        for item in select_items {
132            match item {
133                SelectItem::Expression { expr, alias } => {
134                    if contains_aggregate(expr) {
135                        // Aggregate expression
136                        result_table.add_column(DataColumn::new(alias));
137                        aggregate_columns.push((expr.clone(), alias.clone()));
138                    } else {
139                        // Non-aggregate expression - must match a GROUP BY expression
140                        let mut found = false;
141                        for group_expr in group_by_exprs {
142                            if expressions_match(expr, group_expr) {
143                                found = true;
144                                non_aggregate_exprs.push((expr.clone(), alias.clone()));
145                                break;
146                            }
147                        }
148                        if !found {
149                            // Check if it's a simple column that's in GROUP BY
150                            if let SqlExpression::Column(col) = expr {
151                                // Check if this column is referenced in any GROUP BY expression
152                                let referenced = group_by_exprs
153                                    .iter()
154                                    .any(|ge| expression_references_column(ge, &col.name));
155                                if !referenced {
156                                    return Err(anyhow!(
157                                        "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
158                                        alias
159                                    ));
160                                }
161                            } else {
162                                return Err(anyhow!(
163                                    "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
164                                    alias
165                                ));
166                            }
167                        }
168                    }
169                }
170                SelectItem::Column(col_ref) => {
171                    // Check if this column is in a GROUP BY expression
172                    let in_group_by = group_by_exprs.iter().any(
173                        |expr| matches!(expr, SqlExpression::Column(name) if name.name == col_ref.name),
174                    );
175
176                    if !in_group_by {
177                        return Err(anyhow!(
178                            "Column '{}' must appear in GROUP BY clause or be used in an aggregate function",
179                            col_ref.name
180                        ));
181                    }
182                }
183                SelectItem::Star => {
184                    // For GROUP BY queries, * includes GROUP BY columns
185                    // Already handled by adding group_by_aliases columns
186                }
187            }
188        }
189
190        // Process each group
191        for (group_key, group_view) in groups {
192            let mut row_values = Vec::new();
193
194            // Add GROUP BY expression values
195            for value in &group_key.0 {
196                row_values.push(value.clone());
197            }
198
199            // Calculate aggregate values for this group
200            for (expr, _col_name) in &aggregate_columns {
201                let group_rows = group_view.get_visible_rows();
202                let mut evaluator = ArithmeticEvaluator::with_date_notation(
203                    group_view.source(),
204                    date_notation.clone(),
205                )
206                .with_visible_rows(group_rows.clone());
207
208                let value = if group_view.row_count() > 0 && !group_rows.is_empty() {
209                    evaluator
210                        .evaluate(expr, group_rows[0])
211                        .unwrap_or(DataValue::Null)
212                } else {
213                    DataValue::Null
214                };
215
216                row_values.push(value);
217            }
218
219            // Evaluate HAVING clause if present
220            if let Some(having_expr) = having {
221                // Create a temporary table with one row containing the group values
222                let mut temp_table = DataTable::new("having_eval");
223
224                // Add columns for GROUP BY expressions
225                for alias in &group_by_aliases {
226                    temp_table.add_column(DataColumn::new(alias));
227                }
228
229                // Add columns for aggregates
230                for (_, alias) in &aggregate_columns {
231                    temp_table.add_column(DataColumn::new(alias));
232                }
233
234                temp_table
235                    .add_row(DataRow::new(row_values.clone()))
236                    .map_err(|e| anyhow!("Failed to create temp table for HAVING: {}", e))?;
237
238                // Evaluate HAVING expression
239                let mut evaluator =
240                    ArithmeticEvaluator::with_date_notation(&temp_table, date_notation.clone());
241                let having_result = evaluator.evaluate(having_expr, 0)?;
242
243                // Skip this group if HAVING condition is not met
244                if !is_truthy(&having_result) {
245                    continue;
246                }
247            }
248
249            // Add the row to the result table
250            result_table
251                .add_row(DataRow::new(row_values))
252                .map_err(|e| anyhow!("Failed to add grouped row: {}", e))?;
253        }
254
255        Ok(DataView::new(Arc::new(result_table)))
256    }
257}
258
259/// Check if two expressions are equivalent (for GROUP BY validation)
260fn expressions_match(expr1: &SqlExpression, expr2: &SqlExpression) -> bool {
261    // Simple equality check for now
262    // Could be enhanced to handle semantic equivalence
263    format!("{:?}", expr1) == format!("{:?}", expr2)
264}
265
266/// Check if an expression references a column
267fn expression_references_column(expr: &SqlExpression, column: &str) -> bool {
268    match expr {
269        SqlExpression::Column(name) => name == column,
270        SqlExpression::BinaryOp { left, right, .. } => {
271            expression_references_column(left, column)
272                || expression_references_column(right, column)
273        }
274        SqlExpression::FunctionCall { args, .. } => args
275            .iter()
276            .any(|arg| expression_references_column(arg, column)),
277        SqlExpression::Between { expr, lower, upper } => {
278            expression_references_column(expr, column)
279                || expression_references_column(lower, column)
280                || expression_references_column(upper, column)
281        }
282        _ => false,
283    }
284}
285
286/// Check if a DataValue is truthy (for HAVING evaluation)
287fn is_truthy(value: &DataValue) -> bool {
288    match value {
289        DataValue::Boolean(b) => *b,
290        DataValue::Integer(i) => *i != 0,
291        DataValue::Float(f) => *f != 0.0 && !f.is_nan(),
292        DataValue::Null => false,
293        _ => true,
294    }
295}