// File: sql_cli/data/group_by_expressions.rs

1// GROUP BY expression evaluation support
2
3use anyhow::{anyhow, Result};
4use fxhash::FxHashMap;
5use std::sync::Arc;
6use std::time::Duration;
7
8use crate::data::arithmetic_evaluator::ArithmeticEvaluator;
9use crate::data::data_view::DataView;
10use crate::data::datatable::{DataColumn, DataRow, DataTable, DataValue};
11use crate::data::query_engine::QueryEngine;
12use crate::sql::aggregates::contains_aggregate;
13use crate::sql::parser::ast::{SelectItem, SqlExpression};
14use tracing::debug;
15
/// Timing and cardinality statistics describing one GROUP BY execution.
///
/// The field notes below describe how `apply_group_by_expressions` in this
/// file fills the struct; other producers (if any) may populate the
/// per-phase durations at a finer granularity.
///
/// `Default` yields an all-zero record, which allows struct-update syntax
/// (`GroupByPhaseInfo { num_groups: 3, ..Default::default() }`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GroupByPhaseInfo {
    /// Row count of the input view that was grouped.
    pub total_rows: usize,
    /// Number of groups emitted to the result (after HAVING filtering).
    pub num_groups: usize,
    /// Number of GROUP BY expressions.
    pub num_expressions: usize,
    /// Time spent estimating group cardinality (left at zero when the phase
    /// is not tracked separately).
    pub phase1_cardinality_estimation: Duration,
    /// Time spent building group keys; `apply_group_by_expressions` stores the
    /// duration of the whole grouping step here.
    pub phase2_key_building: Duration,
    /// Time spent evaluating GROUP BY expressions (left at zero when folded
    /// into `phase2_key_building`).
    pub phase2_expression_evaluation: Duration,
    /// Time spent creating the per-group views (left at zero when folded into
    /// `phase2_key_building`).
    pub phase3_dataview_creation: Duration,
    /// Accumulated time spent computing per-group output values.
    pub phase4_aggregation: Duration,
    /// Accumulated time spent evaluating the HAVING clause.
    pub phase4_having_evaluation: Duration,
    /// Number of groups rejected by the HAVING clause.
    pub groups_filtered_by_having: usize,
    /// Wall-clock time of the whole operation.
    pub total_time: Duration,
}
31
32/// Key for grouping rows - contains the evaluated expression values
33#[derive(Debug, Clone, PartialEq, Eq, Hash)]
34pub struct GroupKey(pub Vec<DataValue>);
35
36/// Extension methods for QueryEngine to handle GROUP BY expressions
37pub trait GroupByExpressions {
38    /// Group rows by evaluating expressions for each row
39    fn group_by_expressions(
40        &self,
41        view: DataView,
42        group_by_exprs: &[SqlExpression],
43    ) -> Result<FxHashMap<GroupKey, DataView>>;
44
45    /// Apply GROUP BY with expressions to the view
46    fn apply_group_by_expressions(
47        &self,
48        view: DataView,
49        group_by_exprs: &[SqlExpression],
50        select_items: &[SelectItem],
51        having: Option<&SqlExpression>,
52        _case_insensitive: bool,
53        date_notation: String,
54    ) -> Result<(DataView, GroupByPhaseInfo)>;
55}
56
57impl GroupByExpressions for QueryEngine {
58    fn group_by_expressions(
59        &self,
60        view: DataView,
61        group_by_exprs: &[SqlExpression],
62    ) -> Result<FxHashMap<GroupKey, DataView>> {
63        use std::time::Instant;
64        let start = Instant::now();
65
66        // Phase 1: Estimate cardinality for pre-sizing
67        let phase1_start = Instant::now();
68        let estimated_groups = self.estimate_group_cardinality(&view, group_by_exprs);
69        let mut groups = FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
70        let mut group_rows: FxHashMap<GroupKey, Vec<usize>> =
71            FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
72        let phase1_time = phase1_start.elapsed();
73        debug!(
74            "GROUP BY Phase 1 (cardinality estimation): {:?}, estimated {} groups",
75            phase1_time, estimated_groups
76        );
77
78        // Phase 2: Process each visible row and build group keys
79        let phase2_start = Instant::now();
80        let visible_rows = view.get_visible_rows();
81        let total_rows = visible_rows.len();
82        debug!("GROUP BY Phase 2 starting: processing {} rows", total_rows);
83
84        // OPTIMIZATION: Create evaluator once outside the loop!
85        let mut evaluator = ArithmeticEvaluator::new(view.source());
86
87        // OPTIMIZATION: Pre-allocate key_values vector with the right capacity
88        let mut key_values = Vec::with_capacity(group_by_exprs.len());
89
90        for row_idx in visible_rows.iter().copied() {
91            // Clear and reuse the vector instead of allocating new one
92            key_values.clear();
93
94            // Evaluate GROUP BY expressions for this row
95            for expr in group_by_exprs {
96                let value = evaluator.evaluate(expr, row_idx).unwrap_or(DataValue::Null);
97                key_values.push(value);
98            }
99
100            let key = GroupKey(key_values.clone()); // Need to clone here for the key
101            group_rows.entry(key).or_default().push(row_idx);
102        }
103        let phase2_time = phase2_start.elapsed();
104        debug!(
105            "GROUP BY Phase 2 (expression evaluation & key building): {:?}, created {} unique keys",
106            phase2_time,
107            group_rows.len()
108        );
109
110        // Phase 3: Create DataViews for each group
111        let phase3_start = Instant::now();
112        for (key, rows) in group_rows {
113            let mut group_view = DataView::new(view.source_arc());
114            group_view = group_view.with_rows(rows);
115            groups.insert(key, group_view);
116        }
117        let phase3_time = phase3_start.elapsed();
118        debug!("GROUP BY Phase 3 (DataView creation): {:?}", phase3_time);
119
120        let total_time = start.elapsed();
121        debug!(
122            "GROUP BY Total time: {:?} (P1: {:?}, P2: {:?}, P3: {:?})",
123            total_time, phase1_time, phase2_time, phase3_time
124        );
125
126        Ok(groups)
127    }
128
129    fn apply_group_by_expressions(
130        &self,
131        view: DataView,
132        group_by_exprs: &[SqlExpression],
133        select_items: &[SelectItem],
134        having: Option<&SqlExpression>,
135        _case_insensitive: bool,
136        date_notation: String,
137    ) -> Result<(DataView, GroupByPhaseInfo)> {
138        use std::time::Instant;
139        let start = Instant::now();
140
141        debug!(
142            "apply_group_by_expressions - grouping by {} expressions, {} select items",
143            group_by_exprs.len(),
144            select_items.len()
145        );
146
147        // Phase 1: Build groups by evaluating expressions for each row
148        let phase1_start = Instant::now();
149        let groups = self.group_by_expressions(view.clone(), group_by_exprs)?;
150        let phase1_time = phase1_start.elapsed();
151        debug!(
152            "apply_group_by_expressions Phase 1 (group building): {:?}, created {} groups",
153            phase1_time,
154            groups.len()
155        );
156
157        // Create a result table for the grouped data
158        let mut result_table = DataTable::new("grouped_result");
159
160        // First, scan SELECT items to find non-aggregate expressions and their aliases
161        let mut aggregate_columns = Vec::new();
162        let mut non_aggregate_exprs = Vec::new();
163        // Computed expressions that depend only on GROUP BY columns but aren't
164        // direct matches (e.g. `user_id + 10` when GROUP BY is `user_id`).
165        // These need a result column and per-group evaluation.
166        let mut derived_grouped_exprs: Vec<(SqlExpression, String)> = Vec::new();
167        let mut group_by_aliases = Vec::new();
168
169        // Map GROUP BY expressions to their aliases from SELECT items
170        for (i, group_expr) in group_by_exprs.iter().enumerate() {
171            let mut found_alias = None;
172
173            // Look for a matching SELECT item with an alias
174            for item in select_items {
175                if let SelectItem::Expression { expr, alias, .. } = item {
176                    if !contains_aggregate(expr) && expressions_match(expr, group_expr) {
177                        found_alias = Some(alias.clone());
178                        break;
179                    }
180                }
181            }
182
183            // Use found alias or generate a default one
184            let alias = found_alias.unwrap_or_else(|| match group_expr {
185                SqlExpression::Column(column_ref) => column_ref.name.clone(),
186                _ => format!("group_expr_{}", i + 1),
187            });
188
189            result_table.add_column(DataColumn::new(&alias));
190            group_by_aliases.push(alias);
191        }
192
193        // Now process SELECT items to find aggregates and validate non-aggregates
194        for item in select_items {
195            match item {
196                SelectItem::Expression { expr, alias, .. } => {
197                    if contains_aggregate(expr) {
198                        // Aggregate expression
199                        result_table.add_column(DataColumn::new(alias));
200                        aggregate_columns.push((expr.clone(), alias.clone()));
201                    } else {
202                        // Non-aggregate expression - must match a GROUP BY expression
203                        let mut found = false;
204                        for group_expr in group_by_exprs {
205                            if expressions_match(expr, group_expr) {
206                                found = true;
207                                non_aggregate_exprs.push((expr.clone(), alias.clone()));
208                                break;
209                            }
210                        }
211                        if !found {
212                            // Walk the expression and confirm every column ref
213                            // it touches is grouped (either appears in a GROUP
214                            // BY expression directly, or is referenced by one).
215                            // This permits computed expressions that depend
216                            // only on grouping keys, e.g.:
217                            //   SELECT user_id + 10 FROM t GROUP BY user_id
218                            // which standard SQL allows.
219                            let mut referenced_cols = Vec::new();
220                            collect_column_refs(expr, &mut referenced_cols);
221                            let all_grouped = !referenced_cols.is_empty()
222                                && referenced_cols.iter().all(|col_name| {
223                                    group_by_exprs
224                                        .iter()
225                                        .any(|ge| expression_references_column(ge, col_name))
226                                });
227                            if !all_grouped {
228                                return Err(anyhow!(
229                                    "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
230                                    alias
231                                ));
232                            }
233                            // Add a column for the derived expression and remember
234                            // it so Phase 2 evaluates it on each group's first row.
235                            result_table.add_column(DataColumn::new(alias));
236                            derived_grouped_exprs.push((expr.clone(), alias.clone()));
237                        }
238                    }
239                }
240                SelectItem::Column {
241                    column: col_ref, ..
242                } => {
243                    // Check if this column is in a GROUP BY expression
244                    let in_group_by = group_by_exprs.iter().any(
245                        |expr| matches!(expr, SqlExpression::Column(name) if name.name == col_ref.name),
246                    );
247
248                    if !in_group_by {
249                        return Err(anyhow!(
250                            "Column '{}' must appear in GROUP BY clause or be used in an aggregate function",
251                            col_ref.name
252                        ));
253                    }
254                }
255                SelectItem::Star { .. } => {
256                    // For GROUP BY queries, * includes GROUP BY columns
257                    // Already handled by adding group_by_aliases columns
258                }
259                SelectItem::StarExclude { .. } => {
260                    // StarExclude behaves like Star in GROUP BY context
261                    // Expansion happens later in the query execution pipeline
262                }
263            }
264        }
265
266        // Phase 2: Process each group (aggregate computation)
267        let phase2_start = Instant::now();
268        let mut aggregation_time = std::time::Duration::ZERO;
269        let mut having_time = std::time::Duration::ZERO;
270        let mut groups_processed = 0;
271        let mut groups_filtered = 0;
272
273        for (group_key, group_view) in groups {
274            let mut row_values = Vec::new();
275
276            // Add GROUP BY expression values
277            for value in &group_key.0 {
278                row_values.push(value.clone());
279            }
280
281            // Calculate aggregate values for this group
282            let agg_start = Instant::now();
283            for (expr, _col_name) in &aggregate_columns {
284                let group_rows = group_view.get_visible_rows();
285                let mut evaluator = ArithmeticEvaluator::with_date_notation(
286                    group_view.source(),
287                    date_notation.clone(),
288                )
289                .with_visible_rows(group_rows.clone());
290
291                let value = if group_view.row_count() > 0 && !group_rows.is_empty() {
292                    evaluator
293                        .evaluate(expr, group_rows[0])
294                        .unwrap_or(DataValue::Null)
295                } else {
296                    DataValue::Null
297                };
298
299                row_values.push(value);
300            }
301
302            // Evaluate derived expressions that depend only on GROUP BY columns.
303            // Same value for every row in the group, so evaluating on the first
304            // row of the group is correct.
305            for (expr, _alias) in &derived_grouped_exprs {
306                let group_rows = group_view.get_visible_rows();
307                let value = if !group_rows.is_empty() {
308                    let mut evaluator = ArithmeticEvaluator::with_date_notation(
309                        group_view.source(),
310                        date_notation.clone(),
311                    );
312                    evaluator
313                        .evaluate(expr, group_rows[0])
314                        .unwrap_or(DataValue::Null)
315                } else {
316                    DataValue::Null
317                };
318                row_values.push(value);
319            }
320
321            aggregation_time += agg_start.elapsed();
322
323            // Evaluate HAVING clause if present
324            let having_start = Instant::now();
325            if let Some(having_expr) = having {
326                // Create a temporary table with one row containing the group values
327                let mut temp_table = DataTable::new("having_eval");
328
329                // Add columns for GROUP BY expressions
330                for alias in &group_by_aliases {
331                    temp_table.add_column(DataColumn::new(alias));
332                }
333
334                // Add columns for aggregates
335                for (_, alias) in &aggregate_columns {
336                    temp_table.add_column(DataColumn::new(alias));
337                }
338
339                temp_table
340                    .add_row(DataRow::new(row_values.clone()))
341                    .map_err(|e| anyhow!("Failed to create temp table for HAVING: {}", e))?;
342
343                // Evaluate HAVING expression
344                let mut evaluator =
345                    ArithmeticEvaluator::with_date_notation(&temp_table, date_notation.clone());
346                let having_result = evaluator.evaluate(having_expr, 0)?;
347
348                // Skip this group if HAVING condition is not met
349                if !is_truthy(&having_result) {
350                    groups_filtered += 1;
351                    having_time += having_start.elapsed();
352                    continue;
353                }
354            }
355            having_time += having_start.elapsed();
356
357            groups_processed += 1;
358
359            // Add the row to the result table
360            result_table
361                .add_row(DataRow::new(row_values))
362                .map_err(|e| anyhow!("Failed to add grouped row: {}", e))?;
363        }
364
365        let phase2_time = phase2_start.elapsed();
366        let total_time = start.elapsed();
367
368        debug!(
369            "apply_group_by_expressions Phase 2 (aggregation): {:?}",
370            phase2_time
371        );
372        debug!("  - Aggregation time: {:?}", aggregation_time);
373        debug!("  - HAVING evaluation time: {:?}", having_time);
374        debug!(
375            "  - Groups processed: {}, filtered by HAVING: {}",
376            groups_processed, groups_filtered
377        );
378        debug!(
379            "apply_group_by_expressions Total time: {:?} (P1: {:?}, P2: {:?})",
380            total_time, phase1_time, phase2_time
381        );
382
383        let phase_info = GroupByPhaseInfo {
384            total_rows: view.row_count(),
385            num_groups: groups_processed,
386            num_expressions: group_by_exprs.len(),
387            phase1_cardinality_estimation: Duration::ZERO, // Not tracked separately in phase1
388            phase2_key_building: phase1_time,              // This is actually the grouping phase
389            phase2_expression_evaluation: Duration::ZERO,  // Included in phase2_key_building
390            phase3_dataview_creation: Duration::ZERO,      // Included in phase1_time
391            phase4_aggregation: aggregation_time,
392            phase4_having_evaluation: having_time,
393            groups_filtered_by_having: groups_filtered,
394            total_time,
395        };
396
397        Ok((DataView::new(Arc::new(result_table)), phase_info))
398    }
399}
400
401/// Check if two expressions are equivalent (for GROUP BY validation)
402fn expressions_match(expr1: &SqlExpression, expr2: &SqlExpression) -> bool {
403    // Simple equality check for now
404    // Could be enhanced to handle semantic equivalence
405    format!("{:?}", expr1) == format!("{:?}", expr2)
406}
407
408/// Recursively collect every column name referenced inside an expression.
409/// Used to validate that computed SELECT items in a GROUP BY query depend
410/// only on grouped columns.
411fn collect_column_refs(expr: &SqlExpression, out: &mut Vec<String>) {
412    match expr {
413        SqlExpression::Column(col_ref) => out.push(col_ref.name.clone()),
414        SqlExpression::BinaryOp { left, right, .. } => {
415            collect_column_refs(left, out);
416            collect_column_refs(right, out);
417        }
418        SqlExpression::FunctionCall { args, .. } => {
419            for arg in args {
420                collect_column_refs(arg, out);
421            }
422        }
423        SqlExpression::Between { expr, lower, upper } => {
424            collect_column_refs(expr, out);
425            collect_column_refs(lower, out);
426            collect_column_refs(upper, out);
427        }
428        SqlExpression::Not { expr } => collect_column_refs(expr, out),
429        SqlExpression::CaseExpression {
430            when_branches,
431            else_branch,
432        } => {
433            for branch in when_branches {
434                collect_column_refs(&branch.condition, out);
435                collect_column_refs(&branch.result, out);
436            }
437            if let Some(else_expr) = else_branch {
438                collect_column_refs(else_expr, out);
439            }
440        }
441        SqlExpression::SimpleCaseExpression {
442            expr,
443            when_branches,
444            else_branch,
445        } => {
446            collect_column_refs(expr, out);
447            for branch in when_branches {
448                collect_column_refs(&branch.value, out);
449                collect_column_refs(&branch.result, out);
450            }
451            if let Some(else_expr) = else_branch {
452                collect_column_refs(else_expr, out);
453            }
454        }
455        // Literals, NULL, subqueries, etc. — no plain column refs to collect
456        _ => {}
457    }
458}
459
460/// Check if an expression references a column
461fn expression_references_column(expr: &SqlExpression, column: &str) -> bool {
462    match expr {
463        SqlExpression::Column(name) => name == column,
464        SqlExpression::BinaryOp { left, right, .. } => {
465            expression_references_column(left, column)
466                || expression_references_column(right, column)
467        }
468        SqlExpression::FunctionCall { args, .. } => args
469            .iter()
470            .any(|arg| expression_references_column(arg, column)),
471        SqlExpression::Between { expr, lower, upper } => {
472            expression_references_column(expr, column)
473                || expression_references_column(lower, column)
474                || expression_references_column(upper, column)
475        }
476        _ => false,
477    }
478}
479
480/// Check if a DataValue is truthy (for HAVING evaluation)
481fn is_truthy(value: &DataValue) -> bool {
482    match value {
483        DataValue::Boolean(b) => *b,
484        DataValue::Integer(i) => *i != 0,
485        DataValue::Float(f) => *f != 0.0 && !f.is_nan(),
486        DataValue::Null => false,
487        _ => true,
488    }
489}