sql_cli/data/
group_by_expressions.rs

1// GROUP BY expression evaluation support
2
3use anyhow::{anyhow, Result};
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use crate::data::arithmetic_evaluator::ArithmeticEvaluator;
8use crate::data::data_view::DataView;
9use crate::data::datatable::{DataColumn, DataRow, DataTable, DataValue};
10use crate::data::query_engine::QueryEngine;
11use crate::sql::aggregates::contains_aggregate;
12use crate::sql::parser::ast::{SelectItem, SqlExpression};
13use tracing::debug;
14
15/// Key for grouping rows - contains the evaluated expression values
16#[derive(Debug, Clone, PartialEq, Eq, Hash)]
17pub struct GroupKey(pub Vec<DataValue>);
18
19/// Extension methods for QueryEngine to handle GROUP BY expressions
20pub trait GroupByExpressions {
21    /// Group rows by evaluating expressions for each row
22    fn group_by_expressions(
23        &self,
24        view: DataView,
25        group_by_exprs: &[SqlExpression],
26    ) -> Result<HashMap<GroupKey, DataView>>;
27
28    /// Apply GROUP BY with expressions to the view
29    fn apply_group_by_expressions(
30        &self,
31        view: DataView,
32        group_by_exprs: &[SqlExpression],
33        select_items: &[SelectItem],
34        having: Option<&SqlExpression>,
35        _case_insensitive: bool,
36        date_notation: String,
37    ) -> Result<DataView>;
38}
39
40impl GroupByExpressions for QueryEngine {
41    fn group_by_expressions(
42        &self,
43        view: DataView,
44        group_by_exprs: &[SqlExpression],
45    ) -> Result<HashMap<GroupKey, DataView>> {
46        let mut groups = HashMap::new();
47        let mut group_rows: HashMap<GroupKey, Vec<usize>> = HashMap::new();
48
49        // Process each visible row
50        for row_idx in view.get_visible_rows().iter().copied() {
51            // Evaluate GROUP BY expressions for this row
52            let mut key_values = Vec::new();
53            for expr in group_by_exprs {
54                let mut evaluator = ArithmeticEvaluator::new(view.source());
55                let value = evaluator.evaluate(expr, row_idx).unwrap_or(DataValue::Null);
56                key_values.push(value);
57            }
58
59            let key = GroupKey(key_values);
60            group_rows.entry(key).or_default().push(row_idx);
61        }
62
63        // Create DataViews for each group
64        for (key, rows) in group_rows {
65            let mut group_view = DataView::new(view.source_arc());
66            group_view = group_view.with_rows(rows);
67            groups.insert(key, group_view);
68        }
69
70        Ok(groups)
71    }
72
73    fn apply_group_by_expressions(
74        &self,
75        view: DataView,
76        group_by_exprs: &[SqlExpression],
77        select_items: &[SelectItem],
78        having: Option<&SqlExpression>,
79        _case_insensitive: bool,
80        date_notation: String,
81    ) -> Result<DataView> {
82        debug!(
83            "apply_group_by_expressions - grouping by {} expressions",
84            group_by_exprs.len()
85        );
86
87        // Build groups by evaluating expressions for each row
88        let groups = self.group_by_expressions(view.clone(), group_by_exprs)?;
89        debug!(
90            "apply_group_by_expressions - created {} groups",
91            groups.len()
92        );
93
94        // Create a result table for the grouped data
95        let mut result_table = DataTable::new("grouped_result");
96
97        // First, scan SELECT items to find non-aggregate expressions and their aliases
98        let mut aggregate_columns = Vec::new();
99        let mut non_aggregate_exprs = Vec::new();
100        let mut group_by_aliases = Vec::new();
101
102        // Map GROUP BY expressions to their aliases from SELECT items
103        for (i, group_expr) in group_by_exprs.iter().enumerate() {
104            let mut found_alias = None;
105
106            // Look for a matching SELECT item with an alias
107            for item in select_items {
108                if let SelectItem::Expression { expr, alias } = item {
109                    if !contains_aggregate(expr) && expressions_match(expr, group_expr) {
110                        found_alias = Some(alias.clone());
111                        break;
112                    }
113                }
114            }
115
116            // Use found alias or generate a default one
117            let alias = found_alias.unwrap_or_else(|| match group_expr {
118                SqlExpression::Column(name) => name.clone(),
119                _ => format!("group_expr_{}", i + 1),
120            });
121
122            result_table.add_column(DataColumn::new(&alias));
123            group_by_aliases.push(alias);
124        }
125
126        // Now process SELECT items to find aggregates and validate non-aggregates
127        for item in select_items {
128            match item {
129                SelectItem::Expression { expr, alias } => {
130                    if contains_aggregate(expr) {
131                        // Aggregate expression
132                        result_table.add_column(DataColumn::new(alias));
133                        aggregate_columns.push((expr.clone(), alias.clone()));
134                    } else {
135                        // Non-aggregate expression - must match a GROUP BY expression
136                        let mut found = false;
137                        for group_expr in group_by_exprs {
138                            if expressions_match(expr, group_expr) {
139                                found = true;
140                                non_aggregate_exprs.push((expr.clone(), alias.clone()));
141                                break;
142                            }
143                        }
144                        if !found {
145                            // Check if it's a simple column that's in GROUP BY
146                            if let SqlExpression::Column(col) = expr {
147                                // Check if this column is referenced in any GROUP BY expression
148                                let referenced = group_by_exprs
149                                    .iter()
150                                    .any(|ge| expression_references_column(ge, col));
151                                if !referenced {
152                                    return Err(anyhow!(
153                                        "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
154                                        alias
155                                    ));
156                                }
157                            } else {
158                                return Err(anyhow!(
159                                    "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
160                                    alias
161                                ));
162                            }
163                        }
164                    }
165                }
166                SelectItem::Column(col_name) => {
167                    // Check if this column is in a GROUP BY expression
168                    let in_group_by = group_by_exprs.iter().any(
169                        |expr| matches!(expr, SqlExpression::Column(name) if name == col_name),
170                    );
171
172                    if !in_group_by {
173                        return Err(anyhow!(
174                            "Column '{}' must appear in GROUP BY clause or be used in an aggregate function",
175                            col_name
176                        ));
177                    }
178                }
179                SelectItem::Star => {
180                    // For GROUP BY queries, * includes GROUP BY columns
181                    // Already handled by adding group_by_aliases columns
182                }
183            }
184        }
185
186        // Process each group
187        for (group_key, group_view) in groups {
188            let mut row_values = Vec::new();
189
190            // Add GROUP BY expression values
191            for value in &group_key.0 {
192                row_values.push(value.clone());
193            }
194
195            // Calculate aggregate values for this group
196            for (expr, _col_name) in &aggregate_columns {
197                let group_rows = group_view.get_visible_rows();
198                let mut evaluator = ArithmeticEvaluator::with_date_notation(
199                    group_view.source(),
200                    date_notation.clone(),
201                )
202                .with_visible_rows(group_rows.clone());
203
204                let value = if group_view.row_count() > 0 && !group_rows.is_empty() {
205                    evaluator
206                        .evaluate(expr, group_rows[0])
207                        .unwrap_or(DataValue::Null)
208                } else {
209                    DataValue::Null
210                };
211
212                row_values.push(value);
213            }
214
215            // Evaluate HAVING clause if present
216            if let Some(having_expr) = having {
217                // Create a temporary table with one row containing the group values
218                let mut temp_table = DataTable::new("having_eval");
219
220                // Add columns for GROUP BY expressions
221                for alias in &group_by_aliases {
222                    temp_table.add_column(DataColumn::new(alias));
223                }
224
225                // Add columns for aggregates
226                for (_, alias) in &aggregate_columns {
227                    temp_table.add_column(DataColumn::new(alias));
228                }
229
230                temp_table
231                    .add_row(DataRow::new(row_values.clone()))
232                    .map_err(|e| anyhow!("Failed to create temp table for HAVING: {}", e))?;
233
234                // Evaluate HAVING expression
235                let mut evaluator =
236                    ArithmeticEvaluator::with_date_notation(&temp_table, date_notation.clone());
237                let having_result = evaluator.evaluate(having_expr, 0)?;
238
239                // Skip this group if HAVING condition is not met
240                if !is_truthy(&having_result) {
241                    continue;
242                }
243            }
244
245            // Add the row to the result table
246            result_table
247                .add_row(DataRow::new(row_values))
248                .map_err(|e| anyhow!("Failed to add grouped row: {}", e))?;
249        }
250
251        Ok(DataView::new(Arc::new(result_table)))
252    }
253}
254
255/// Check if two expressions are equivalent (for GROUP BY validation)
256fn expressions_match(expr1: &SqlExpression, expr2: &SqlExpression) -> bool {
257    // Simple equality check for now
258    // Could be enhanced to handle semantic equivalence
259    format!("{:?}", expr1) == format!("{:?}", expr2)
260}
261
262/// Check if an expression references a column
263fn expression_references_column(expr: &SqlExpression, column: &str) -> bool {
264    match expr {
265        SqlExpression::Column(name) => name == column,
266        SqlExpression::BinaryOp { left, right, .. } => {
267            expression_references_column(left, column)
268                || expression_references_column(right, column)
269        }
270        SqlExpression::FunctionCall { args, .. } => args
271            .iter()
272            .any(|arg| expression_references_column(arg, column)),
273        SqlExpression::Between { expr, lower, upper } => {
274            expression_references_column(expr, column)
275                || expression_references_column(lower, column)
276                || expression_references_column(upper, column)
277        }
278        _ => false,
279    }
280}
281
282/// Check if a DataValue is truthy (for HAVING evaluation)
283fn is_truthy(value: &DataValue) -> bool {
284    match value {
285        DataValue::Boolean(b) => *b,
286        DataValue::Integer(i) => *i != 0,
287        DataValue::Float(f) => *f != 0.0 && !f.is_nan(),
288        DataValue::Null => false,
289        _ => true,
290    }
291}