sql_cli/data/group_by_expressions.rs

1// GROUP BY expression evaluation support
2
3use anyhow::{anyhow, Result};
4use fxhash::FxHashMap;
5use std::sync::Arc;
6use std::time::Duration;
7
8use crate::data::arithmetic_evaluator::ArithmeticEvaluator;
9use crate::data::data_view::DataView;
10use crate::data::datatable::{DataColumn, DataRow, DataTable, DataValue};
11use crate::data::query_engine::QueryEngine;
12use crate::sql::aggregates::contains_aggregate;
13use crate::sql::parser::ast::{SelectItem, SqlExpression};
14use tracing::debug;
15
/// Per-phase timings and counters collected while executing a GROUP BY.
///
/// NOTE(review): `apply_group_by_expressions` in this module times the whole
/// grouping step as one unit and stores it in `phase2_key_building`; it leaves
/// `phase1_cardinality_estimation`, `phase2_expression_evaluation` and
/// `phase3_dataview_creation` at `Duration::ZERO`.
#[derive(Clone, Debug)]
pub struct GroupByPhaseInfo {
    /// Number of rows in the input view.
    pub total_rows: usize,
    /// Number of groups emitted into the result (i.e. after HAVING filtering).
    pub num_groups: usize,
    /// Number of GROUP BY expressions.
    pub num_expressions: usize,
    /// Time spent estimating how many groups to expect.
    pub phase1_cardinality_estimation: Duration,
    /// Time spent building the group keys.
    pub phase2_key_building: Duration,
    /// Time spent evaluating the GROUP BY expressions.
    pub phase2_expression_evaluation: Duration,
    /// Time spent creating one `DataView` per group.
    pub phase3_dataview_creation: Duration,
    /// Time spent computing aggregate values, summed over all groups.
    pub phase4_aggregation: Duration,
    /// Time spent evaluating the HAVING predicate, summed over all groups.
    pub phase4_having_evaluation: Duration,
    /// Number of groups discarded because HAVING was not satisfied.
    pub groups_filtered_by_having: usize,
    /// Wall-clock time of the whole GROUP BY operation.
    pub total_time: Duration,
}
31
32/// Key for grouping rows - contains the evaluated expression values
33#[derive(Debug, Clone, PartialEq, Eq, Hash)]
34pub struct GroupKey(pub Vec<DataValue>);
35
36/// Extension methods for QueryEngine to handle GROUP BY expressions
37pub trait GroupByExpressions {
38    /// Group rows by evaluating expressions for each row
39    fn group_by_expressions(
40        &self,
41        view: DataView,
42        group_by_exprs: &[SqlExpression],
43    ) -> Result<FxHashMap<GroupKey, DataView>>;
44
45    /// Apply GROUP BY with expressions to the view
46    fn apply_group_by_expressions(
47        &self,
48        view: DataView,
49        group_by_exprs: &[SqlExpression],
50        select_items: &[SelectItem],
51        having: Option<&SqlExpression>,
52        _case_insensitive: bool,
53        date_notation: String,
54    ) -> Result<(DataView, GroupByPhaseInfo)>;
55}
56
57impl GroupByExpressions for QueryEngine {
58    fn group_by_expressions(
59        &self,
60        view: DataView,
61        group_by_exprs: &[SqlExpression],
62    ) -> Result<FxHashMap<GroupKey, DataView>> {
63        use std::time::Instant;
64        let start = Instant::now();
65
66        // Phase 1: Estimate cardinality for pre-sizing
67        let phase1_start = Instant::now();
68        let estimated_groups = self.estimate_group_cardinality(&view, group_by_exprs);
69        let mut groups = FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
70        let mut group_rows: FxHashMap<GroupKey, Vec<usize>> =
71            FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
72        let phase1_time = phase1_start.elapsed();
73        debug!(
74            "GROUP BY Phase 1 (cardinality estimation): {:?}, estimated {} groups",
75            phase1_time, estimated_groups
76        );
77
78        // Phase 2: Process each visible row and build group keys
79        let phase2_start = Instant::now();
80        let visible_rows = view.get_visible_rows();
81        let total_rows = visible_rows.len();
82        debug!("GROUP BY Phase 2 starting: processing {} rows", total_rows);
83
84        // OPTIMIZATION: Create evaluator once outside the loop!
85        let mut evaluator = ArithmeticEvaluator::new(view.source());
86
87        // OPTIMIZATION: Pre-allocate key_values vector with the right capacity
88        let mut key_values = Vec::with_capacity(group_by_exprs.len());
89
90        for row_idx in visible_rows.iter().copied() {
91            // Clear and reuse the vector instead of allocating new one
92            key_values.clear();
93
94            // Evaluate GROUP BY expressions for this row
95            for expr in group_by_exprs {
96                let value = evaluator.evaluate(expr, row_idx).unwrap_or(DataValue::Null);
97                key_values.push(value);
98            }
99
100            let key = GroupKey(key_values.clone()); // Need to clone here for the key
101            group_rows.entry(key).or_default().push(row_idx);
102        }
103        let phase2_time = phase2_start.elapsed();
104        debug!(
105            "GROUP BY Phase 2 (expression evaluation & key building): {:?}, created {} unique keys",
106            phase2_time,
107            group_rows.len()
108        );
109
110        // Phase 3: Create DataViews for each group
111        let phase3_start = Instant::now();
112        for (key, rows) in group_rows {
113            let mut group_view = DataView::new(view.source_arc());
114            group_view = group_view.with_rows(rows);
115            groups.insert(key, group_view);
116        }
117        let phase3_time = phase3_start.elapsed();
118        debug!("GROUP BY Phase 3 (DataView creation): {:?}", phase3_time);
119
120        let total_time = start.elapsed();
121        debug!(
122            "GROUP BY Total time: {:?} (P1: {:?}, P2: {:?}, P3: {:?})",
123            total_time, phase1_time, phase2_time, phase3_time
124        );
125
126        Ok(groups)
127    }
128
129    fn apply_group_by_expressions(
130        &self,
131        view: DataView,
132        group_by_exprs: &[SqlExpression],
133        select_items: &[SelectItem],
134        having: Option<&SqlExpression>,
135        _case_insensitive: bool,
136        date_notation: String,
137    ) -> Result<(DataView, GroupByPhaseInfo)> {
138        use std::time::Instant;
139        let start = Instant::now();
140
141        debug!(
142            "apply_group_by_expressions - grouping by {} expressions, {} select items",
143            group_by_exprs.len(),
144            select_items.len()
145        );
146
147        // Phase 1: Build groups by evaluating expressions for each row
148        let phase1_start = Instant::now();
149        let groups = self.group_by_expressions(view.clone(), group_by_exprs)?;
150        let phase1_time = phase1_start.elapsed();
151        debug!(
152            "apply_group_by_expressions Phase 1 (group building): {:?}, created {} groups",
153            phase1_time,
154            groups.len()
155        );
156
157        // Create a result table for the grouped data
158        let mut result_table = DataTable::new("grouped_result");
159
160        // First, scan SELECT items to find non-aggregate expressions and their aliases
161        let mut aggregate_columns = Vec::new();
162        let mut non_aggregate_exprs = Vec::new();
163        let mut group_by_aliases = Vec::new();
164
165        // Map GROUP BY expressions to their aliases from SELECT items
166        for (i, group_expr) in group_by_exprs.iter().enumerate() {
167            let mut found_alias = None;
168
169            // Look for a matching SELECT item with an alias
170            for item in select_items {
171                if let SelectItem::Expression { expr, alias, .. } = item {
172                    if !contains_aggregate(expr) && expressions_match(expr, group_expr) {
173                        found_alias = Some(alias.clone());
174                        break;
175                    }
176                }
177            }
178
179            // Use found alias or generate a default one
180            let alias = found_alias.unwrap_or_else(|| match group_expr {
181                SqlExpression::Column(column_ref) => column_ref.name.clone(),
182                _ => format!("group_expr_{}", i + 1),
183            });
184
185            result_table.add_column(DataColumn::new(&alias));
186            group_by_aliases.push(alias);
187        }
188
189        // Now process SELECT items to find aggregates and validate non-aggregates
190        for item in select_items {
191            match item {
192                SelectItem::Expression { expr, alias, .. } => {
193                    if contains_aggregate(expr) {
194                        // Aggregate expression
195                        result_table.add_column(DataColumn::new(alias));
196                        aggregate_columns.push((expr.clone(), alias.clone()));
197                    } else {
198                        // Non-aggregate expression - must match a GROUP BY expression
199                        let mut found = false;
200                        for group_expr in group_by_exprs {
201                            if expressions_match(expr, group_expr) {
202                                found = true;
203                                non_aggregate_exprs.push((expr.clone(), alias.clone()));
204                                break;
205                            }
206                        }
207                        if !found {
208                            // Check if it's a simple column that's in GROUP BY
209                            if let SqlExpression::Column(col) = expr {
210                                // Check if this column is referenced in any GROUP BY expression
211                                let referenced = group_by_exprs
212                                    .iter()
213                                    .any(|ge| expression_references_column(ge, &col.name));
214                                if !referenced {
215                                    return Err(anyhow!(
216                                        "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
217                                        alias
218                                    ));
219                                }
220                            } else {
221                                return Err(anyhow!(
222                                    "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
223                                    alias
224                                ));
225                            }
226                        }
227                    }
228                }
229                SelectItem::Column {
230                    column: col_ref, ..
231                } => {
232                    // Check if this column is in a GROUP BY expression
233                    let in_group_by = group_by_exprs.iter().any(
234                        |expr| matches!(expr, SqlExpression::Column(name) if name.name == col_ref.name),
235                    );
236
237                    if !in_group_by {
238                        return Err(anyhow!(
239                            "Column '{}' must appear in GROUP BY clause or be used in an aggregate function",
240                            col_ref.name
241                        ));
242                    }
243                }
244                SelectItem::Star { .. } => {
245                    // For GROUP BY queries, * includes GROUP BY columns
246                    // Already handled by adding group_by_aliases columns
247                }
248            }
249        }
250
251        // Phase 2: Process each group (aggregate computation)
252        let phase2_start = Instant::now();
253        let mut aggregation_time = std::time::Duration::ZERO;
254        let mut having_time = std::time::Duration::ZERO;
255        let mut groups_processed = 0;
256        let mut groups_filtered = 0;
257
258        for (group_key, group_view) in groups {
259            let mut row_values = Vec::new();
260
261            // Add GROUP BY expression values
262            for value in &group_key.0 {
263                row_values.push(value.clone());
264            }
265
266            // Calculate aggregate values for this group
267            let agg_start = Instant::now();
268            for (expr, _col_name) in &aggregate_columns {
269                let group_rows = group_view.get_visible_rows();
270                let mut evaluator = ArithmeticEvaluator::with_date_notation(
271                    group_view.source(),
272                    date_notation.clone(),
273                )
274                .with_visible_rows(group_rows.clone());
275
276                let value = if group_view.row_count() > 0 && !group_rows.is_empty() {
277                    evaluator
278                        .evaluate(expr, group_rows[0])
279                        .unwrap_or(DataValue::Null)
280                } else {
281                    DataValue::Null
282                };
283
284                row_values.push(value);
285            }
286            aggregation_time += agg_start.elapsed();
287
288            // Evaluate HAVING clause if present
289            let having_start = Instant::now();
290            if let Some(having_expr) = having {
291                // Create a temporary table with one row containing the group values
292                let mut temp_table = DataTable::new("having_eval");
293
294                // Add columns for GROUP BY expressions
295                for alias in &group_by_aliases {
296                    temp_table.add_column(DataColumn::new(alias));
297                }
298
299                // Add columns for aggregates
300                for (_, alias) in &aggregate_columns {
301                    temp_table.add_column(DataColumn::new(alias));
302                }
303
304                temp_table
305                    .add_row(DataRow::new(row_values.clone()))
306                    .map_err(|e| anyhow!("Failed to create temp table for HAVING: {}", e))?;
307
308                // Evaluate HAVING expression
309                let mut evaluator =
310                    ArithmeticEvaluator::with_date_notation(&temp_table, date_notation.clone());
311                let having_result = evaluator.evaluate(having_expr, 0)?;
312
313                // Skip this group if HAVING condition is not met
314                if !is_truthy(&having_result) {
315                    groups_filtered += 1;
316                    having_time += having_start.elapsed();
317                    continue;
318                }
319            }
320            having_time += having_start.elapsed();
321
322            groups_processed += 1;
323
324            // Add the row to the result table
325            result_table
326                .add_row(DataRow::new(row_values))
327                .map_err(|e| anyhow!("Failed to add grouped row: {}", e))?;
328        }
329
330        let phase2_time = phase2_start.elapsed();
331        let total_time = start.elapsed();
332
333        debug!(
334            "apply_group_by_expressions Phase 2 (aggregation): {:?}",
335            phase2_time
336        );
337        debug!("  - Aggregation time: {:?}", aggregation_time);
338        debug!("  - HAVING evaluation time: {:?}", having_time);
339        debug!(
340            "  - Groups processed: {}, filtered by HAVING: {}",
341            groups_processed, groups_filtered
342        );
343        debug!(
344            "apply_group_by_expressions Total time: {:?} (P1: {:?}, P2: {:?})",
345            total_time, phase1_time, phase2_time
346        );
347
348        let phase_info = GroupByPhaseInfo {
349            total_rows: view.row_count(),
350            num_groups: groups_processed,
351            num_expressions: group_by_exprs.len(),
352            phase1_cardinality_estimation: Duration::ZERO, // Not tracked separately in phase1
353            phase2_key_building: phase1_time,              // This is actually the grouping phase
354            phase2_expression_evaluation: Duration::ZERO,  // Included in phase2_key_building
355            phase3_dataview_creation: Duration::ZERO,      // Included in phase1_time
356            phase4_aggregation: aggregation_time,
357            phase4_having_evaluation: having_time,
358            groups_filtered_by_having: groups_filtered,
359            total_time,
360        };
361
362        Ok((DataView::new(Arc::new(result_table)), phase_info))
363    }
364}
365
366/// Check if two expressions are equivalent (for GROUP BY validation)
367fn expressions_match(expr1: &SqlExpression, expr2: &SqlExpression) -> bool {
368    // Simple equality check for now
369    // Could be enhanced to handle semantic equivalence
370    format!("{:?}", expr1) == format!("{:?}", expr2)
371}
372
373/// Check if an expression references a column
374fn expression_references_column(expr: &SqlExpression, column: &str) -> bool {
375    match expr {
376        SqlExpression::Column(name) => name == column,
377        SqlExpression::BinaryOp { left, right, .. } => {
378            expression_references_column(left, column)
379                || expression_references_column(right, column)
380        }
381        SqlExpression::FunctionCall { args, .. } => args
382            .iter()
383            .any(|arg| expression_references_column(arg, column)),
384        SqlExpression::Between { expr, lower, upper } => {
385            expression_references_column(expr, column)
386                || expression_references_column(lower, column)
387                || expression_references_column(upper, column)
388        }
389        _ => false,
390    }
391}
392
393/// Check if a DataValue is truthy (for HAVING evaluation)
394fn is_truthy(value: &DataValue) -> bool {
395    match value {
396        DataValue::Boolean(b) => *b,
397        DataValue::Integer(i) => *i != 0,
398        DataValue::Float(f) => *f != 0.0 && !f.is_nan(),
399        DataValue::Null => false,
400        _ => true,
401    }
402}