sql_cli/data/
group_by_expressions.rs

1// GROUP BY expression evaluation support
2
3use anyhow::{anyhow, Result};
4use fxhash::FxHashMap;
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::Duration;
8
9use crate::data::arithmetic_evaluator::ArithmeticEvaluator;
10use crate::data::data_view::DataView;
11use crate::data::datatable::{DataColumn, DataRow, DataTable, DataValue};
12use crate::data::query_engine::QueryEngine;
13use crate::sql::aggregates::contains_aggregate;
14use crate::sql::parser::ast::{ColumnRef, SelectItem, SqlExpression};
15use tracing::debug;
16
/// Timing and cardinality statistics collected while executing a GROUP BY.
///
/// The per-field notes below describe how `apply_group_by_expressions` in this
/// module fills the struct; several timing slots are currently reported as
/// zero because that function does not measure them separately.
#[derive(Clone, Debug)]
pub struct GroupByPhaseInfo {
    /// Row count of the input view (`view.row_count()`).
    pub total_rows: usize,
    /// Number of groups written to the result, i.e. groups that passed HAVING.
    pub num_groups: usize,
    /// Number of GROUP BY expressions supplied by the query.
    pub num_expressions: usize,
    /// Time spent estimating group cardinality (currently always zero).
    pub phase1_cardinality_estimation: Duration,
    /// Time spent building the groups; currently holds the whole grouping phase.
    pub phase2_key_building: Duration,
    /// Time spent evaluating GROUP BY expressions (currently always zero,
    /// folded into `phase2_key_building`).
    pub phase2_expression_evaluation: Duration,
    /// Time spent creating the per-group views (currently always zero,
    /// folded into `phase2_key_building`).
    pub phase3_dataview_creation: Duration,
    /// Accumulated time spent evaluating aggregate expressions across groups.
    pub phase4_aggregation: Duration,
    /// Accumulated time spent evaluating the HAVING clause across groups.
    pub phase4_having_evaluation: Duration,
    /// Number of groups dropped because the HAVING condition was not met.
    pub groups_filtered_by_having: usize,
    /// Wall-clock time of the whole GROUP BY operation.
    pub total_time: Duration,
}
32
33/// Key for grouping rows - contains the evaluated expression values
34#[derive(Debug, Clone, PartialEq, Eq, Hash)]
35pub struct GroupKey(pub Vec<DataValue>);
36
37/// Extension methods for QueryEngine to handle GROUP BY expressions
38pub trait GroupByExpressions {
39    /// Group rows by evaluating expressions for each row
40    fn group_by_expressions(
41        &self,
42        view: DataView,
43        group_by_exprs: &[SqlExpression],
44    ) -> Result<FxHashMap<GroupKey, DataView>>;
45
46    /// Apply GROUP BY with expressions to the view
47    fn apply_group_by_expressions(
48        &self,
49        view: DataView,
50        group_by_exprs: &[SqlExpression],
51        select_items: &[SelectItem],
52        having: Option<&SqlExpression>,
53        _case_insensitive: bool,
54        date_notation: String,
55    ) -> Result<(DataView, GroupByPhaseInfo)>;
56}
57
58impl GroupByExpressions for QueryEngine {
59    fn group_by_expressions(
60        &self,
61        view: DataView,
62        group_by_exprs: &[SqlExpression],
63    ) -> Result<FxHashMap<GroupKey, DataView>> {
64        use std::time::Instant;
65        let start = Instant::now();
66
67        // Phase 1: Estimate cardinality for pre-sizing
68        let phase1_start = Instant::now();
69        let estimated_groups = self.estimate_group_cardinality(&view, group_by_exprs);
70        let mut groups = FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
71        let mut group_rows: FxHashMap<GroupKey, Vec<usize>> =
72            FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
73        let phase1_time = phase1_start.elapsed();
74        debug!(
75            "GROUP BY Phase 1 (cardinality estimation): {:?}, estimated {} groups",
76            phase1_time, estimated_groups
77        );
78
79        // Phase 2: Process each visible row and build group keys
80        let phase2_start = Instant::now();
81        let visible_rows = view.get_visible_rows();
82        let total_rows = visible_rows.len();
83        debug!("GROUP BY Phase 2 starting: processing {} rows", total_rows);
84
85        // OPTIMIZATION: Create evaluator once outside the loop!
86        let mut evaluator = ArithmeticEvaluator::new(view.source());
87
88        // OPTIMIZATION: Pre-allocate key_values vector with the right capacity
89        let mut key_values = Vec::with_capacity(group_by_exprs.len());
90
91        for row_idx in visible_rows.iter().copied() {
92            // Clear and reuse the vector instead of allocating new one
93            key_values.clear();
94
95            // Evaluate GROUP BY expressions for this row
96            for expr in group_by_exprs {
97                let value = evaluator.evaluate(expr, row_idx).unwrap_or(DataValue::Null);
98                key_values.push(value);
99            }
100
101            let key = GroupKey(key_values.clone()); // Need to clone here for the key
102            group_rows.entry(key).or_default().push(row_idx);
103        }
104        let phase2_time = phase2_start.elapsed();
105        debug!(
106            "GROUP BY Phase 2 (expression evaluation & key building): {:?}, created {} unique keys",
107            phase2_time,
108            group_rows.len()
109        );
110
111        // Phase 3: Create DataViews for each group
112        let phase3_start = Instant::now();
113        for (key, rows) in group_rows {
114            let mut group_view = DataView::new(view.source_arc());
115            group_view = group_view.with_rows(rows);
116            groups.insert(key, group_view);
117        }
118        let phase3_time = phase3_start.elapsed();
119        debug!("GROUP BY Phase 3 (DataView creation): {:?}", phase3_time);
120
121        let total_time = start.elapsed();
122        debug!(
123            "GROUP BY Total time: {:?} (P1: {:?}, P2: {:?}, P3: {:?})",
124            total_time, phase1_time, phase2_time, phase3_time
125        );
126
127        Ok(groups)
128    }
129
130    fn apply_group_by_expressions(
131        &self,
132        view: DataView,
133        group_by_exprs: &[SqlExpression],
134        select_items: &[SelectItem],
135        having: Option<&SqlExpression>,
136        _case_insensitive: bool,
137        date_notation: String,
138    ) -> Result<(DataView, GroupByPhaseInfo)> {
139        use std::time::Instant;
140        let start = Instant::now();
141
142        debug!(
143            "apply_group_by_expressions - grouping by {} expressions, {} select items",
144            group_by_exprs.len(),
145            select_items.len()
146        );
147
148        // Phase 1: Build groups by evaluating expressions for each row
149        let phase1_start = Instant::now();
150        let groups = self.group_by_expressions(view.clone(), group_by_exprs)?;
151        let phase1_time = phase1_start.elapsed();
152        debug!(
153            "apply_group_by_expressions Phase 1 (group building): {:?}, created {} groups",
154            phase1_time,
155            groups.len()
156        );
157
158        // Create a result table for the grouped data
159        let mut result_table = DataTable::new("grouped_result");
160
161        // First, scan SELECT items to find non-aggregate expressions and their aliases
162        let mut aggregate_columns = Vec::new();
163        let mut non_aggregate_exprs = Vec::new();
164        let mut group_by_aliases = Vec::new();
165
166        // Map GROUP BY expressions to their aliases from SELECT items
167        for (i, group_expr) in group_by_exprs.iter().enumerate() {
168            let mut found_alias = None;
169
170            // Look for a matching SELECT item with an alias
171            for item in select_items {
172                if let SelectItem::Expression { expr, alias } = item {
173                    if !contains_aggregate(expr) && expressions_match(expr, group_expr) {
174                        found_alias = Some(alias.clone());
175                        break;
176                    }
177                }
178            }
179
180            // Use found alias or generate a default one
181            let alias = found_alias.unwrap_or_else(|| match group_expr {
182                SqlExpression::Column(column_ref) => column_ref.name.clone(),
183                _ => format!("group_expr_{}", i + 1),
184            });
185
186            result_table.add_column(DataColumn::new(&alias));
187            group_by_aliases.push(alias);
188        }
189
190        // Now process SELECT items to find aggregates and validate non-aggregates
191        for item in select_items {
192            match item {
193                SelectItem::Expression { expr, alias } => {
194                    if contains_aggregate(expr) {
195                        // Aggregate expression
196                        result_table.add_column(DataColumn::new(alias));
197                        aggregate_columns.push((expr.clone(), alias.clone()));
198                    } else {
199                        // Non-aggregate expression - must match a GROUP BY expression
200                        let mut found = false;
201                        for group_expr in group_by_exprs {
202                            if expressions_match(expr, group_expr) {
203                                found = true;
204                                non_aggregate_exprs.push((expr.clone(), alias.clone()));
205                                break;
206                            }
207                        }
208                        if !found {
209                            // Check if it's a simple column that's in GROUP BY
210                            if let SqlExpression::Column(col) = expr {
211                                // Check if this column is referenced in any GROUP BY expression
212                                let referenced = group_by_exprs
213                                    .iter()
214                                    .any(|ge| expression_references_column(ge, &col.name));
215                                if !referenced {
216                                    return Err(anyhow!(
217                                        "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
218                                        alias
219                                    ));
220                                }
221                            } else {
222                                return Err(anyhow!(
223                                    "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
224                                    alias
225                                ));
226                            }
227                        }
228                    }
229                }
230                SelectItem::Column(col_ref) => {
231                    // Check if this column is in a GROUP BY expression
232                    let in_group_by = group_by_exprs.iter().any(
233                        |expr| matches!(expr, SqlExpression::Column(name) if name.name == col_ref.name),
234                    );
235
236                    if !in_group_by {
237                        return Err(anyhow!(
238                            "Column '{}' must appear in GROUP BY clause or be used in an aggregate function",
239                            col_ref.name
240                        ));
241                    }
242                }
243                SelectItem::Star => {
244                    // For GROUP BY queries, * includes GROUP BY columns
245                    // Already handled by adding group_by_aliases columns
246                }
247            }
248        }
249
250        // Phase 2: Process each group (aggregate computation)
251        let phase2_start = Instant::now();
252        let mut aggregation_time = std::time::Duration::ZERO;
253        let mut having_time = std::time::Duration::ZERO;
254        let mut groups_processed = 0;
255        let mut groups_filtered = 0;
256
257        for (group_key, group_view) in groups {
258            let mut row_values = Vec::new();
259
260            // Add GROUP BY expression values
261            for value in &group_key.0 {
262                row_values.push(value.clone());
263            }
264
265            // Calculate aggregate values for this group
266            let agg_start = Instant::now();
267            for (expr, _col_name) in &aggregate_columns {
268                let group_rows = group_view.get_visible_rows();
269                let mut evaluator = ArithmeticEvaluator::with_date_notation(
270                    group_view.source(),
271                    date_notation.clone(),
272                )
273                .with_visible_rows(group_rows.clone());
274
275                let value = if group_view.row_count() > 0 && !group_rows.is_empty() {
276                    evaluator
277                        .evaluate(expr, group_rows[0])
278                        .unwrap_or(DataValue::Null)
279                } else {
280                    DataValue::Null
281                };
282
283                row_values.push(value);
284            }
285            aggregation_time += agg_start.elapsed();
286
287            // Evaluate HAVING clause if present
288            let having_start = Instant::now();
289            if let Some(having_expr) = having {
290                // Create a temporary table with one row containing the group values
291                let mut temp_table = DataTable::new("having_eval");
292
293                // Add columns for GROUP BY expressions
294                for alias in &group_by_aliases {
295                    temp_table.add_column(DataColumn::new(alias));
296                }
297
298                // Add columns for aggregates
299                for (_, alias) in &aggregate_columns {
300                    temp_table.add_column(DataColumn::new(alias));
301                }
302
303                temp_table
304                    .add_row(DataRow::new(row_values.clone()))
305                    .map_err(|e| anyhow!("Failed to create temp table for HAVING: {}", e))?;
306
307                // Evaluate HAVING expression
308                let mut evaluator =
309                    ArithmeticEvaluator::with_date_notation(&temp_table, date_notation.clone());
310                let having_result = evaluator.evaluate(having_expr, 0)?;
311
312                // Skip this group if HAVING condition is not met
313                if !is_truthy(&having_result) {
314                    groups_filtered += 1;
315                    having_time += having_start.elapsed();
316                    continue;
317                }
318            }
319            having_time += having_start.elapsed();
320
321            groups_processed += 1;
322
323            // Add the row to the result table
324            result_table
325                .add_row(DataRow::new(row_values))
326                .map_err(|e| anyhow!("Failed to add grouped row: {}", e))?;
327        }
328
329        let phase2_time = phase2_start.elapsed();
330        let total_time = start.elapsed();
331
332        debug!(
333            "apply_group_by_expressions Phase 2 (aggregation): {:?}",
334            phase2_time
335        );
336        debug!("  - Aggregation time: {:?}", aggregation_time);
337        debug!("  - HAVING evaluation time: {:?}", having_time);
338        debug!(
339            "  - Groups processed: {}, filtered by HAVING: {}",
340            groups_processed, groups_filtered
341        );
342        debug!(
343            "apply_group_by_expressions Total time: {:?} (P1: {:?}, P2: {:?})",
344            total_time, phase1_time, phase2_time
345        );
346
347        let phase_info = GroupByPhaseInfo {
348            total_rows: view.row_count(),
349            num_groups: groups_processed,
350            num_expressions: group_by_exprs.len(),
351            phase1_cardinality_estimation: Duration::ZERO, // Not tracked separately in phase1
352            phase2_key_building: phase1_time,              // This is actually the grouping phase
353            phase2_expression_evaluation: Duration::ZERO,  // Included in phase2_key_building
354            phase3_dataview_creation: Duration::ZERO,      // Included in phase1_time
355            phase4_aggregation: aggregation_time,
356            phase4_having_evaluation: having_time,
357            groups_filtered_by_having: groups_filtered,
358            total_time,
359        };
360
361        Ok((DataView::new(Arc::new(result_table)), phase_info))
362    }
363}
364
365/// Check if two expressions are equivalent (for GROUP BY validation)
366fn expressions_match(expr1: &SqlExpression, expr2: &SqlExpression) -> bool {
367    // Simple equality check for now
368    // Could be enhanced to handle semantic equivalence
369    format!("{:?}", expr1) == format!("{:?}", expr2)
370}
371
372/// Check if an expression references a column
373fn expression_references_column(expr: &SqlExpression, column: &str) -> bool {
374    match expr {
375        SqlExpression::Column(name) => name == column,
376        SqlExpression::BinaryOp { left, right, .. } => {
377            expression_references_column(left, column)
378                || expression_references_column(right, column)
379        }
380        SqlExpression::FunctionCall { args, .. } => args
381            .iter()
382            .any(|arg| expression_references_column(arg, column)),
383        SqlExpression::Between { expr, lower, upper } => {
384            expression_references_column(expr, column)
385                || expression_references_column(lower, column)
386                || expression_references_column(upper, column)
387        }
388        _ => false,
389    }
390}
391
392/// Check if a DataValue is truthy (for HAVING evaluation)
393fn is_truthy(value: &DataValue) -> bool {
394    match value {
395        DataValue::Boolean(b) => *b,
396        DataValue::Integer(i) => *i != 0,
397        DataValue::Float(f) => *f != 0.0 && !f.is_nan(),
398        DataValue::Null => false,
399        _ => true,
400    }
401}