1use anyhow::{anyhow, Result};
4use fxhash::FxHashMap;
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::data::arithmetic_evaluator::ArithmeticEvaluator;
9use crate::data::data_view::DataView;
10use crate::data::datatable::{DataColumn, DataRow, DataTable, DataValue};
11use crate::data::query_engine::QueryEngine;
12use crate::sql::aggregates::contains_aggregate;
13use crate::sql::parser::ast::{ColumnRef, SelectItem, SqlExpression};
14use tracing::debug;
15
16#[derive(Debug, Clone, PartialEq, Eq, Hash)]
18pub struct GroupKey(pub Vec<DataValue>);
19
20pub trait GroupByExpressions {
22 fn group_by_expressions(
24 &self,
25 view: DataView,
26 group_by_exprs: &[SqlExpression],
27 ) -> Result<FxHashMap<GroupKey, DataView>>;
28
29 fn apply_group_by_expressions(
31 &self,
32 view: DataView,
33 group_by_exprs: &[SqlExpression],
34 select_items: &[SelectItem],
35 having: Option<&SqlExpression>,
36 _case_insensitive: bool,
37 date_notation: String,
38 ) -> Result<DataView>;
39}
40
41impl GroupByExpressions for QueryEngine {
42 fn group_by_expressions(
43 &self,
44 view: DataView,
45 group_by_exprs: &[SqlExpression],
46 ) -> Result<FxHashMap<GroupKey, DataView>> {
47 let estimated_groups = self.estimate_group_cardinality(&view, group_by_exprs);
49 let mut groups = FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
50 let mut group_rows: FxHashMap<GroupKey, Vec<usize>> =
51 FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
52
53 for row_idx in view.get_visible_rows().iter().copied() {
55 let mut key_values = Vec::new();
57 for expr in group_by_exprs {
58 let mut evaluator = ArithmeticEvaluator::new(view.source());
59 let value = evaluator.evaluate(expr, row_idx).unwrap_or(DataValue::Null);
60 key_values.push(value);
61 }
62
63 let key = GroupKey(key_values);
64 group_rows.entry(key).or_default().push(row_idx);
65 }
66
67 for (key, rows) in group_rows {
69 let mut group_view = DataView::new(view.source_arc());
70 group_view = group_view.with_rows(rows);
71 groups.insert(key, group_view);
72 }
73
74 Ok(groups)
75 }
76
77 fn apply_group_by_expressions(
78 &self,
79 view: DataView,
80 group_by_exprs: &[SqlExpression],
81 select_items: &[SelectItem],
82 having: Option<&SqlExpression>,
83 _case_insensitive: bool,
84 date_notation: String,
85 ) -> Result<DataView> {
86 debug!(
87 "apply_group_by_expressions - grouping by {} expressions",
88 group_by_exprs.len()
89 );
90
91 let groups = self.group_by_expressions(view.clone(), group_by_exprs)?;
93 debug!(
94 "apply_group_by_expressions - created {} groups",
95 groups.len()
96 );
97
98 let mut result_table = DataTable::new("grouped_result");
100
101 let mut aggregate_columns = Vec::new();
103 let mut non_aggregate_exprs = Vec::new();
104 let mut group_by_aliases = Vec::new();
105
106 for (i, group_expr) in group_by_exprs.iter().enumerate() {
108 let mut found_alias = None;
109
110 for item in select_items {
112 if let SelectItem::Expression { expr, alias } = item {
113 if !contains_aggregate(expr) && expressions_match(expr, group_expr) {
114 found_alias = Some(alias.clone());
115 break;
116 }
117 }
118 }
119
120 let alias = found_alias.unwrap_or_else(|| match group_expr {
122 SqlExpression::Column(column_ref) => column_ref.name.clone(),
123 _ => format!("group_expr_{}", i + 1),
124 });
125
126 result_table.add_column(DataColumn::new(&alias));
127 group_by_aliases.push(alias);
128 }
129
130 for item in select_items {
132 match item {
133 SelectItem::Expression { expr, alias } => {
134 if contains_aggregate(expr) {
135 result_table.add_column(DataColumn::new(alias));
137 aggregate_columns.push((expr.clone(), alias.clone()));
138 } else {
139 let mut found = false;
141 for group_expr in group_by_exprs {
142 if expressions_match(expr, group_expr) {
143 found = true;
144 non_aggregate_exprs.push((expr.clone(), alias.clone()));
145 break;
146 }
147 }
148 if !found {
149 if let SqlExpression::Column(col) = expr {
151 let referenced = group_by_exprs
153 .iter()
154 .any(|ge| expression_references_column(ge, &col.name));
155 if !referenced {
156 return Err(anyhow!(
157 "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
158 alias
159 ));
160 }
161 } else {
162 return Err(anyhow!(
163 "Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
164 alias
165 ));
166 }
167 }
168 }
169 }
170 SelectItem::Column(col_ref) => {
171 let in_group_by = group_by_exprs.iter().any(
173 |expr| matches!(expr, SqlExpression::Column(name) if name.name == col_ref.name),
174 );
175
176 if !in_group_by {
177 return Err(anyhow!(
178 "Column '{}' must appear in GROUP BY clause or be used in an aggregate function",
179 col_ref.name
180 ));
181 }
182 }
183 SelectItem::Star => {
184 }
187 }
188 }
189
190 for (group_key, group_view) in groups {
192 let mut row_values = Vec::new();
193
194 for value in &group_key.0 {
196 row_values.push(value.clone());
197 }
198
199 for (expr, _col_name) in &aggregate_columns {
201 let group_rows = group_view.get_visible_rows();
202 let mut evaluator = ArithmeticEvaluator::with_date_notation(
203 group_view.source(),
204 date_notation.clone(),
205 )
206 .with_visible_rows(group_rows.clone());
207
208 let value = if group_view.row_count() > 0 && !group_rows.is_empty() {
209 evaluator
210 .evaluate(expr, group_rows[0])
211 .unwrap_or(DataValue::Null)
212 } else {
213 DataValue::Null
214 };
215
216 row_values.push(value);
217 }
218
219 if let Some(having_expr) = having {
221 let mut temp_table = DataTable::new("having_eval");
223
224 for alias in &group_by_aliases {
226 temp_table.add_column(DataColumn::new(alias));
227 }
228
229 for (_, alias) in &aggregate_columns {
231 temp_table.add_column(DataColumn::new(alias));
232 }
233
234 temp_table
235 .add_row(DataRow::new(row_values.clone()))
236 .map_err(|e| anyhow!("Failed to create temp table for HAVING: {}", e))?;
237
238 let mut evaluator =
240 ArithmeticEvaluator::with_date_notation(&temp_table, date_notation.clone());
241 let having_result = evaluator.evaluate(having_expr, 0)?;
242
243 if !is_truthy(&having_result) {
245 continue;
246 }
247 }
248
249 result_table
251 .add_row(DataRow::new(row_values))
252 .map_err(|e| anyhow!("Failed to add grouped row: {}", e))?;
253 }
254
255 Ok(DataView::new(Arc::new(result_table)))
256 }
257}
258
259fn expressions_match(expr1: &SqlExpression, expr2: &SqlExpression) -> bool {
261 format!("{:?}", expr1) == format!("{:?}", expr2)
264}
265
266fn expression_references_column(expr: &SqlExpression, column: &str) -> bool {
268 match expr {
269 SqlExpression::Column(name) => name == column,
270 SqlExpression::BinaryOp { left, right, .. } => {
271 expression_references_column(left, column)
272 || expression_references_column(right, column)
273 }
274 SqlExpression::FunctionCall { args, .. } => args
275 .iter()
276 .any(|arg| expression_references_column(arg, column)),
277 SqlExpression::Between { expr, lower, upper } => {
278 expression_references_column(expr, column)
279 || expression_references_column(lower, column)
280 || expression_references_column(upper, column)
281 }
282 _ => false,
283 }
284}
285
286fn is_truthy(value: &DataValue) -> bool {
288 match value {
289 DataValue::Boolean(b) => *b,
290 DataValue::Integer(i) => *i != 0,
291 DataValue::Float(f) => *f != 0.0 && !f.is_nan(),
292 DataValue::Null => false,
293 _ => true,
294 }
295}