1use anyhow::{anyhow, Result};
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use crate::data::arithmetic_evaluator::ArithmeticEvaluator;
8use crate::data::data_view::DataView;
9use crate::data::datatable::{DataColumn, DataRow, DataTable, DataValue};
10use crate::data::query_engine::QueryEngine;
11use crate::sql::aggregates::contains_aggregate;
12use crate::sql::parser::ast::{SelectItem, SqlExpression};
13use tracing::debug;
14
15#[derive(Debug, Clone, PartialEq, Eq, Hash)]
17pub struct GroupKey(pub Vec<DataValue>);
18
19pub trait GroupByExpressions {
21 fn group_by_expressions(
23 &self,
24 view: DataView,
25 group_by_exprs: &[SqlExpression],
26 ) -> Result<HashMap<GroupKey, DataView>>;
27
28 fn apply_group_by_expressions(
30 &self,
31 view: DataView,
32 group_by_exprs: &[SqlExpression],
33 select_items: &[SelectItem],
34 having: Option<&SqlExpression>,
35 _case_insensitive: bool,
36 date_notation: String,
37 ) -> Result<DataView>;
38}
39
40impl GroupByExpressions for QueryEngine {
41 fn group_by_expressions(
42 &self,
43 view: DataView,
44 group_by_exprs: &[SqlExpression],
45 ) -> Result<HashMap<GroupKey, DataView>> {
46 let mut groups = HashMap::new();
47 let mut group_rows: HashMap<GroupKey, Vec<usize>> = HashMap::new();
48
49 for row_idx in view.get_visible_rows().iter().copied() {
51 let mut key_values = Vec::new();
53 for expr in group_by_exprs {
54 let mut evaluator = ArithmeticEvaluator::new(view.source());
55 let value = evaluator.evaluate(expr, row_idx).unwrap_or(DataValue::Null);
56 key_values.push(value);
57 }
58
59 let key = GroupKey(key_values);
60 group_rows.entry(key).or_default().push(row_idx);
61 }
62
63 for (key, rows) in group_rows {
65 let mut group_view = DataView::new(view.source_arc());
66 group_view = group_view.with_rows(rows);
67 groups.insert(key, group_view);
68 }
69
70 Ok(groups)
71 }
72
73 fn apply_group_by_expressions(
74 &self,
75 view: DataView,
76 group_by_exprs: &[SqlExpression],
77 select_items: &[SelectItem],
78 having: Option<&SqlExpression>,
79 _case_insensitive: bool,
80 date_notation: String,
81 ) -> Result<DataView> {
82 debug!(
83 "apply_group_by_expressions - grouping by {} expressions",
84 group_by_exprs.len()
85 );
86
87 let groups = self.group_by_expressions(view.clone(), group_by_exprs)?;
89 debug!(
90 "apply_group_by_expressions - created {} groups",
91 groups.len()
92 );
93
94 let mut result_table = DataTable::new("grouped_result");
96
97 let mut aggregate_columns = Vec::new();
99 let mut non_aggregate_exprs = Vec::new();
100 let mut group_by_aliases = Vec::new();
101
102 for (i, group_expr) in group_by_exprs.iter().enumerate() {
104 let mut found_alias = None;
105
106 for item in select_items {
108 if let SelectItem::Expression { expr, alias } = item {
109 if !contains_aggregate(expr) && expressions_match(expr, group_expr) {
110 found_alias = Some(alias.clone());
111 break;
112 }
113 }
114 }
115
116 let alias = found_alias.unwrap_or_else(|| match group_expr {
118 SqlExpression::Column(name) => name.clone(),
119 _ => format!("group_expr_{}", i + 1),
120 });
121
122 result_table.add_column(DataColumn::new(&alias));
123 group_by_aliases.push(alias);
124 }
125
126 for item in select_items {
128 match item {
129 SelectItem::Expression { expr, alias } => {
130 if contains_aggregate(expr) {
131 result_table.add_column(DataColumn::new(alias));
133 aggregate_columns.push((expr.clone(), alias.clone()));
134 } else {
135 let mut found = false;
137 for group_expr in group_by_exprs {
138 if expressions_match(expr, group_expr) {
139 found = true;
140 non_aggregate_exprs.push((expr.clone(), alias.clone()));
141 break;
142 }
143 }
144 if !found {
145 if let SqlExpression::Column(col) = expr {
147 let referenced = group_by_exprs
149 .iter()
150 .any(|ge| expression_references_column(ge, col));
151 if !referenced {
152 return Err(anyhow!(
153 "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
154 alias
155 ));
156 }
157 } else {
158 return Err(anyhow!(
159 "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
160 alias
161 ));
162 }
163 }
164 }
165 }
166 SelectItem::Column(col_name) => {
167 let in_group_by = group_by_exprs.iter().any(
169 |expr| matches!(expr, SqlExpression::Column(name) if name == col_name),
170 );
171
172 if !in_group_by {
173 return Err(anyhow!(
174 "Column '{}' must appear in GROUP BY clause or be used in an aggregate function",
175 col_name
176 ));
177 }
178 }
179 SelectItem::Star => {
180 }
183 }
184 }
185
186 for (group_key, group_view) in groups {
188 let mut row_values = Vec::new();
189
190 for value in &group_key.0 {
192 row_values.push(value.clone());
193 }
194
195 for (expr, _col_name) in &aggregate_columns {
197 let group_rows = group_view.get_visible_rows();
198 let mut evaluator = ArithmeticEvaluator::with_date_notation(
199 group_view.source(),
200 date_notation.clone(),
201 )
202 .with_visible_rows(group_rows.clone());
203
204 let value = if group_view.row_count() > 0 && !group_rows.is_empty() {
205 evaluator
206 .evaluate(expr, group_rows[0])
207 .unwrap_or(DataValue::Null)
208 } else {
209 DataValue::Null
210 };
211
212 row_values.push(value);
213 }
214
215 if let Some(having_expr) = having {
217 let mut temp_table = DataTable::new("having_eval");
219
220 for alias in &group_by_aliases {
222 temp_table.add_column(DataColumn::new(alias));
223 }
224
225 for (_, alias) in &aggregate_columns {
227 temp_table.add_column(DataColumn::new(alias));
228 }
229
230 temp_table
231 .add_row(DataRow::new(row_values.clone()))
232 .map_err(|e| anyhow!("Failed to create temp table for HAVING: {}", e))?;
233
234 let mut evaluator =
236 ArithmeticEvaluator::with_date_notation(&temp_table, date_notation.clone());
237 let having_result = evaluator.evaluate(having_expr, 0)?;
238
239 if !is_truthy(&having_result) {
241 continue;
242 }
243 }
244
245 result_table
247 .add_row(DataRow::new(row_values))
248 .map_err(|e| anyhow!("Failed to add grouped row: {}", e))?;
249 }
250
251 Ok(DataView::new(Arc::new(result_table)))
252 }
253}
254
255fn expressions_match(expr1: &SqlExpression, expr2: &SqlExpression) -> bool {
257 format!("{:?}", expr1) == format!("{:?}", expr2)
260}
261
262fn expression_references_column(expr: &SqlExpression, column: &str) -> bool {
264 match expr {
265 SqlExpression::Column(name) => name == column,
266 SqlExpression::BinaryOp { left, right, .. } => {
267 expression_references_column(left, column)
268 || expression_references_column(right, column)
269 }
270 SqlExpression::FunctionCall { args, .. } => args
271 .iter()
272 .any(|arg| expression_references_column(arg, column)),
273 SqlExpression::Between { expr, lower, upper } => {
274 expression_references_column(expr, column)
275 || expression_references_column(lower, column)
276 || expression_references_column(upper, column)
277 }
278 _ => false,
279 }
280}
281
282fn is_truthy(value: &DataValue) -> bool {
284 match value {
285 DataValue::Boolean(b) => *b,
286 DataValue::Integer(i) => *i != 0,
287 DataValue::Float(f) => *f != 0.0 && !f.is_nan(),
288 DataValue::Null => false,
289 _ => true,
290 }
291}