sql_cli/data/
group_by_expressions.rs

1// GROUP BY expression evaluation support
2
3use anyhow::{anyhow, Result};
4use fxhash::FxHashMap;
5use std::sync::Arc;
6use std::time::Duration;
7
8use crate::data::arithmetic_evaluator::ArithmeticEvaluator;
9use crate::data::data_view::DataView;
10use crate::data::datatable::{DataColumn, DataRow, DataTable, DataValue};
11use crate::data::query_engine::QueryEngine;
12use crate::sql::aggregates::contains_aggregate;
13use crate::sql::parser::ast::{SelectItem, SqlExpression};
14use tracing::debug;
15
/// Timing and cardinality breakdown of one GROUP BY execution.
///
/// Returned alongside the grouped view by
/// `GroupByExpressions::apply_group_by_expressions`. Every `Duration` is a
/// wall-clock measurement taken with `std::time::Instant`; a phase that is not
/// timed separately by the implementation is reported as `Duration::ZERO`.
#[derive(Debug, Clone)]
pub struct GroupByPhaseInfo {
    /// Row count of the input view that was grouped.
    pub total_rows: usize,
    /// Number of groups that made it into the output (after HAVING).
    pub num_groups: usize,
    /// Number of GROUP BY expressions.
    pub num_expressions: usize,
    /// Time spent estimating the number of distinct groups.
    pub phase1_cardinality_estimation: Duration,
    /// Time spent building group keys (may cover the whole grouping step).
    pub phase2_key_building: Duration,
    /// Time spent evaluating GROUP BY expressions, when tracked on its own.
    pub phase2_expression_evaluation: Duration,
    /// Time spent wrapping each group in a `DataView`, when tracked on its own.
    pub phase3_dataview_creation: Duration,
    /// Accumulated time spent computing aggregate values across all groups.
    pub phase4_aggregation: Duration,
    /// Accumulated time spent evaluating the HAVING predicate across all groups.
    pub phase4_having_evaluation: Duration,
    /// Number of groups discarded because HAVING was not satisfied.
    pub groups_filtered_by_having: usize,
    /// End-to-end wall-clock time of the operation.
    pub total_time: Duration,
}
31
32/// Key for grouping rows - contains the evaluated expression values
33#[derive(Debug, Clone, PartialEq, Eq, Hash)]
34pub struct GroupKey(pub Vec<DataValue>);
35
36/// Extension methods for QueryEngine to handle GROUP BY expressions
37pub trait GroupByExpressions {
38    /// Group rows by evaluating expressions for each row
39    fn group_by_expressions(
40        &self,
41        view: DataView,
42        group_by_exprs: &[SqlExpression],
43    ) -> Result<FxHashMap<GroupKey, DataView>>;
44
45    /// Apply GROUP BY with expressions to the view
46    fn apply_group_by_expressions(
47        &self,
48        view: DataView,
49        group_by_exprs: &[SqlExpression],
50        select_items: &[SelectItem],
51        having: Option<&SqlExpression>,
52        _case_insensitive: bool,
53        date_notation: String,
54    ) -> Result<(DataView, GroupByPhaseInfo)>;
55}
56
57impl GroupByExpressions for QueryEngine {
58    fn group_by_expressions(
59        &self,
60        view: DataView,
61        group_by_exprs: &[SqlExpression],
62    ) -> Result<FxHashMap<GroupKey, DataView>> {
63        use std::time::Instant;
64        let start = Instant::now();
65
66        // Phase 1: Estimate cardinality for pre-sizing
67        let phase1_start = Instant::now();
68        let estimated_groups = self.estimate_group_cardinality(&view, group_by_exprs);
69        let mut groups = FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
70        let mut group_rows: FxHashMap<GroupKey, Vec<usize>> =
71            FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
72        let phase1_time = phase1_start.elapsed();
73        debug!(
74            "GROUP BY Phase 1 (cardinality estimation): {:?}, estimated {} groups",
75            phase1_time, estimated_groups
76        );
77
78        // Phase 2: Process each visible row and build group keys
79        let phase2_start = Instant::now();
80        let visible_rows = view.get_visible_rows();
81        let total_rows = visible_rows.len();
82        debug!("GROUP BY Phase 2 starting: processing {} rows", total_rows);
83
84        // OPTIMIZATION: Create evaluator once outside the loop!
85        let mut evaluator = ArithmeticEvaluator::new(view.source());
86
87        // OPTIMIZATION: Pre-allocate key_values vector with the right capacity
88        let mut key_values = Vec::with_capacity(group_by_exprs.len());
89
90        for row_idx in visible_rows.iter().copied() {
91            // Clear and reuse the vector instead of allocating new one
92            key_values.clear();
93
94            // Evaluate GROUP BY expressions for this row
95            for expr in group_by_exprs {
96                let value = evaluator.evaluate(expr, row_idx).unwrap_or(DataValue::Null);
97                key_values.push(value);
98            }
99
100            let key = GroupKey(key_values.clone()); // Need to clone here for the key
101            group_rows.entry(key).or_default().push(row_idx);
102        }
103        let phase2_time = phase2_start.elapsed();
104        debug!(
105            "GROUP BY Phase 2 (expression evaluation & key building): {:?}, created {} unique keys",
106            phase2_time,
107            group_rows.len()
108        );
109
110        // Phase 3: Create DataViews for each group
111        let phase3_start = Instant::now();
112        for (key, rows) in group_rows {
113            let mut group_view = DataView::new(view.source_arc());
114            group_view = group_view.with_rows(rows);
115            groups.insert(key, group_view);
116        }
117        let phase3_time = phase3_start.elapsed();
118        debug!("GROUP BY Phase 3 (DataView creation): {:?}", phase3_time);
119
120        let total_time = start.elapsed();
121        debug!(
122            "GROUP BY Total time: {:?} (P1: {:?}, P2: {:?}, P3: {:?})",
123            total_time, phase1_time, phase2_time, phase3_time
124        );
125
126        Ok(groups)
127    }
128
129    fn apply_group_by_expressions(
130        &self,
131        view: DataView,
132        group_by_exprs: &[SqlExpression],
133        select_items: &[SelectItem],
134        having: Option<&SqlExpression>,
135        _case_insensitive: bool,
136        date_notation: String,
137    ) -> Result<(DataView, GroupByPhaseInfo)> {
138        use std::time::Instant;
139        let start = Instant::now();
140
141        debug!(
142            "apply_group_by_expressions - grouping by {} expressions, {} select items",
143            group_by_exprs.len(),
144            select_items.len()
145        );
146
147        // Phase 1: Build groups by evaluating expressions for each row
148        let phase1_start = Instant::now();
149        let groups = self.group_by_expressions(view.clone(), group_by_exprs)?;
150        let phase1_time = phase1_start.elapsed();
151        debug!(
152            "apply_group_by_expressions Phase 1 (group building): {:?}, created {} groups",
153            phase1_time,
154            groups.len()
155        );
156
157        // Create a result table for the grouped data
158        let mut result_table = DataTable::new("grouped_result");
159
160        // First, scan SELECT items to find non-aggregate expressions and their aliases
161        let mut aggregate_columns = Vec::new();
162        let mut non_aggregate_exprs = Vec::new();
163        let mut group_by_aliases = Vec::new();
164
165        // Map GROUP BY expressions to their aliases from SELECT items
166        for (i, group_expr) in group_by_exprs.iter().enumerate() {
167            let mut found_alias = None;
168
169            // Look for a matching SELECT item with an alias
170            for item in select_items {
171                if let SelectItem::Expression { expr, alias } = item {
172                    if !contains_aggregate(expr) && expressions_match(expr, group_expr) {
173                        found_alias = Some(alias.clone());
174                        break;
175                    }
176                }
177            }
178
179            // Use found alias or generate a default one
180            let alias = found_alias.unwrap_or_else(|| match group_expr {
181                SqlExpression::Column(column_ref) => column_ref.name.clone(),
182                _ => format!("group_expr_{}", i + 1),
183            });
184
185            result_table.add_column(DataColumn::new(&alias));
186            group_by_aliases.push(alias);
187        }
188
189        // Now process SELECT items to find aggregates and validate non-aggregates
190        for item in select_items {
191            match item {
192                SelectItem::Expression { expr, alias } => {
193                    if contains_aggregate(expr) {
194                        // Aggregate expression
195                        result_table.add_column(DataColumn::new(alias));
196                        aggregate_columns.push((expr.clone(), alias.clone()));
197                    } else {
198                        // Non-aggregate expression - must match a GROUP BY expression
199                        let mut found = false;
200                        for group_expr in group_by_exprs {
201                            if expressions_match(expr, group_expr) {
202                                found = true;
203                                non_aggregate_exprs.push((expr.clone(), alias.clone()));
204                                break;
205                            }
206                        }
207                        if !found {
208                            // Check if it's a simple column that's in GROUP BY
209                            if let SqlExpression::Column(col) = expr {
210                                // Check if this column is referenced in any GROUP BY expression
211                                let referenced = group_by_exprs
212                                    .iter()
213                                    .any(|ge| expression_references_column(ge, &col.name));
214                                if !referenced {
215                                    return Err(anyhow!(
216                                        "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
217                                        alias
218                                    ));
219                                }
220                            } else {
221                                return Err(anyhow!(
222                                    "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
223                                    alias
224                                ));
225                            }
226                        }
227                    }
228                }
229                SelectItem::Column(col_ref) => {
230                    // Check if this column is in a GROUP BY expression
231                    let in_group_by = group_by_exprs.iter().any(
232                        |expr| matches!(expr, SqlExpression::Column(name) if name.name == col_ref.name),
233                    );
234
235                    if !in_group_by {
236                        return Err(anyhow!(
237                            "Column '{}' must appear in GROUP BY clause or be used in an aggregate function",
238                            col_ref.name
239                        ));
240                    }
241                }
242                SelectItem::Star => {
243                    // For GROUP BY queries, * includes GROUP BY columns
244                    // Already handled by adding group_by_aliases columns
245                }
246            }
247        }
248
249        // Phase 2: Process each group (aggregate computation)
250        let phase2_start = Instant::now();
251        let mut aggregation_time = std::time::Duration::ZERO;
252        let mut having_time = std::time::Duration::ZERO;
253        let mut groups_processed = 0;
254        let mut groups_filtered = 0;
255
256        for (group_key, group_view) in groups {
257            let mut row_values = Vec::new();
258
259            // Add GROUP BY expression values
260            for value in &group_key.0 {
261                row_values.push(value.clone());
262            }
263
264            // Calculate aggregate values for this group
265            let agg_start = Instant::now();
266            for (expr, _col_name) in &aggregate_columns {
267                let group_rows = group_view.get_visible_rows();
268                let mut evaluator = ArithmeticEvaluator::with_date_notation(
269                    group_view.source(),
270                    date_notation.clone(),
271                )
272                .with_visible_rows(group_rows.clone());
273
274                let value = if group_view.row_count() > 0 && !group_rows.is_empty() {
275                    evaluator
276                        .evaluate(expr, group_rows[0])
277                        .unwrap_or(DataValue::Null)
278                } else {
279                    DataValue::Null
280                };
281
282                row_values.push(value);
283            }
284            aggregation_time += agg_start.elapsed();
285
286            // Evaluate HAVING clause if present
287            let having_start = Instant::now();
288            if let Some(having_expr) = having {
289                // Create a temporary table with one row containing the group values
290                let mut temp_table = DataTable::new("having_eval");
291
292                // Add columns for GROUP BY expressions
293                for alias in &group_by_aliases {
294                    temp_table.add_column(DataColumn::new(alias));
295                }
296
297                // Add columns for aggregates
298                for (_, alias) in &aggregate_columns {
299                    temp_table.add_column(DataColumn::new(alias));
300                }
301
302                temp_table
303                    .add_row(DataRow::new(row_values.clone()))
304                    .map_err(|e| anyhow!("Failed to create temp table for HAVING: {}", e))?;
305
306                // Evaluate HAVING expression
307                let mut evaluator =
308                    ArithmeticEvaluator::with_date_notation(&temp_table, date_notation.clone());
309                let having_result = evaluator.evaluate(having_expr, 0)?;
310
311                // Skip this group if HAVING condition is not met
312                if !is_truthy(&having_result) {
313                    groups_filtered += 1;
314                    having_time += having_start.elapsed();
315                    continue;
316                }
317            }
318            having_time += having_start.elapsed();
319
320            groups_processed += 1;
321
322            // Add the row to the result table
323            result_table
324                .add_row(DataRow::new(row_values))
325                .map_err(|e| anyhow!("Failed to add grouped row: {}", e))?;
326        }
327
328        let phase2_time = phase2_start.elapsed();
329        let total_time = start.elapsed();
330
331        debug!(
332            "apply_group_by_expressions Phase 2 (aggregation): {:?}",
333            phase2_time
334        );
335        debug!("  - Aggregation time: {:?}", aggregation_time);
336        debug!("  - HAVING evaluation time: {:?}", having_time);
337        debug!(
338            "  - Groups processed: {}, filtered by HAVING: {}",
339            groups_processed, groups_filtered
340        );
341        debug!(
342            "apply_group_by_expressions Total time: {:?} (P1: {:?}, P2: {:?})",
343            total_time, phase1_time, phase2_time
344        );
345
346        let phase_info = GroupByPhaseInfo {
347            total_rows: view.row_count(),
348            num_groups: groups_processed,
349            num_expressions: group_by_exprs.len(),
350            phase1_cardinality_estimation: Duration::ZERO, // Not tracked separately in phase1
351            phase2_key_building: phase1_time,              // This is actually the grouping phase
352            phase2_expression_evaluation: Duration::ZERO,  // Included in phase2_key_building
353            phase3_dataview_creation: Duration::ZERO,      // Included in phase1_time
354            phase4_aggregation: aggregation_time,
355            phase4_having_evaluation: having_time,
356            groups_filtered_by_having: groups_filtered,
357            total_time,
358        };
359
360        Ok((DataView::new(Arc::new(result_table)), phase_info))
361    }
362}
363
364/// Check if two expressions are equivalent (for GROUP BY validation)
365fn expressions_match(expr1: &SqlExpression, expr2: &SqlExpression) -> bool {
366    // Simple equality check for now
367    // Could be enhanced to handle semantic equivalence
368    format!("{:?}", expr1) == format!("{:?}", expr2)
369}
370
371/// Check if an expression references a column
372fn expression_references_column(expr: &SqlExpression, column: &str) -> bool {
373    match expr {
374        SqlExpression::Column(name) => name == column,
375        SqlExpression::BinaryOp { left, right, .. } => {
376            expression_references_column(left, column)
377                || expression_references_column(right, column)
378        }
379        SqlExpression::FunctionCall { args, .. } => args
380            .iter()
381            .any(|arg| expression_references_column(arg, column)),
382        SqlExpression::Between { expr, lower, upper } => {
383            expression_references_column(expr, column)
384                || expression_references_column(lower, column)
385                || expression_references_column(upper, column)
386        }
387        _ => false,
388    }
389}
390
391/// Check if a DataValue is truthy (for HAVING evaluation)
392fn is_truthy(value: &DataValue) -> bool {
393    match value {
394        DataValue::Boolean(b) => *b,
395        DataValue::Integer(i) => *i != 0,
396        DataValue::Float(f) => *f != 0.0 && !f.is_nan(),
397        DataValue::Null => false,
398        _ => true,
399    }
400}