1use std::collections::HashSet;
9
10use vibesql_ast::{Expression, SelectItem};
11
12use crate::{errors::ExecutorError, schema::CombinedSchema};
13
/// Reports whether `name` is one of the aggregate function names this module
/// recognises.
///
/// The check is case-insensitive: `name` is upper-cased and then compared
/// against `COUNT`, `SUM`, `AVG`, `MIN`, `MAX`, `TOTAL` and `GROUP_CONCAT`.
pub fn is_aggregate_function(name: &str) -> bool {
    const AGGREGATE_NAMES: [&str; 7] =
        ["COUNT", "SUM", "AVG", "MIN", "MAX", "TOTAL", "GROUP_CONCAT"];

    let normalized = name.to_uppercase();
    AGGREGATE_NAMES.contains(&normalized.as_str())
}
19
20pub fn check_aggregate_arg_count(expr: &Expression) -> Option<String> {
23 match expr {
24 Expression::AggregateFunction { name, args, distinct, .. } => {
25 let upper = name.to_uppercase();
26 let arg_count = args.len();
27
28 let has_wildcard = args.iter().any(|arg| {
30 let is_wildcard = matches!(arg, Expression::Wildcard);
31 let is_star_ref = matches!(
32 arg,
33 Expression::ColumnRef(col_id) if col_id.schema_canonical().is_none() && col_id.table_canonical().is_none() && col_id.column_canonical() == "*"
34 );
35 is_wildcard || is_star_ref
36 });
37
38 match upper.as_str() {
39 "COUNT" => {
40 if arg_count > 1 && !*distinct {
43 Some(name.display().to_string())
44 } else {
45 None
46 }
47 }
48 "MIN" | "MAX" => {
49 if has_wildcard || arg_count == 0 {
50 Some(name.display().to_string())
51 } else {
52 None
53 }
54 }
55 "SUM" | "AVG" | "TOTAL" => {
56 if has_wildcard || arg_count == 0 || arg_count > 1 {
57 Some(name.display().to_string())
58 } else {
59 None
60 }
61 }
62 "GROUP_CONCAT" => {
63 if arg_count == 0 || arg_count > 2 {
64 Some(name.display().to_string())
65 } else {
66 None
67 }
68 }
69 _ => None,
70 }
71 }
72 Expression::Function { name, args, .. } => {
73 if is_aggregate_function(name.as_str()) {
75 let upper = name.to_uppercase();
76 let arg_count = args.len();
77
78 let has_wildcard = args.iter().any(|arg| {
80 matches!(arg, Expression::Wildcard)
81 || matches!(
82 arg,
83 Expression::ColumnRef(col_id) if col_id.schema_canonical().is_none() && col_id.table_canonical().is_none() && col_id.column_canonical() == "*"
84 )
85 });
86
87 match upper.as_str() {
88 "COUNT" => {
89 if arg_count > 1 {
92 Some(name.display().to_string())
93 } else {
94 None
95 }
96 }
97 "MIN" | "MAX" => {
98 if arg_count <= 1 && (has_wildcard || arg_count == 0) {
100 Some(name.display().to_string())
101 } else {
102 None
103 }
104 }
105 "SUM" | "AVG" | "TOTAL" => {
106 if has_wildcard || arg_count == 0 || arg_count > 1 {
107 Some(name.display().to_string())
108 } else {
109 None
110 }
111 }
112 "GROUP_CONCAT" => {
113 if arg_count == 0 || arg_count > 2 {
114 Some(name.display().to_string())
115 } else {
116 None
117 }
118 }
119 _ => None,
120 }
121 } else {
122 for arg in args {
124 if let Some(found) = check_aggregate_arg_count(arg) {
125 return Some(found);
126 }
127 }
128 None
129 }
130 }
131 Expression::BinaryOp { left, right, .. } => {
132 check_aggregate_arg_count(left).or_else(|| check_aggregate_arg_count(right))
133 }
134 Expression::UnaryOp { expr, .. } => check_aggregate_arg_count(expr),
135 Expression::Case { operand, when_clauses, else_result } => {
136 if let Some(op) = operand {
137 if let Some(found) = check_aggregate_arg_count(op) {
138 return Some(found);
139 }
140 }
141 for case_when in when_clauses {
142 for cond in &case_when.conditions {
143 if let Some(found) = check_aggregate_arg_count(cond) {
144 return Some(found);
145 }
146 }
147 if let Some(found) = check_aggregate_arg_count(&case_when.result) {
148 return Some(found);
149 }
150 }
151 if let Some(else_expr) = else_result {
152 check_aggregate_arg_count(else_expr)
153 } else {
154 None
155 }
156 }
157 Expression::IsNull { expr, .. } => check_aggregate_arg_count(expr),
158 Expression::Cast { expr, .. } => check_aggregate_arg_count(expr),
159 Expression::Conjunction(children) | Expression::Disjunction(children) => {
160 for child in children {
161 if let Some(found) = check_aggregate_arg_count(child) {
162 return Some(found);
163 }
164 }
165 None
166 }
167 _ => None,
168 }
169}
170
171pub fn find_aggregate_in_expression(expr: &Expression) -> Option<String> {
174 match expr {
175 Expression::AggregateFunction { name, .. } => Some(name.to_string()), Expression::Function { name, args, .. } => {
177 if is_aggregate_function(name.as_str()) {
180 let upper = name.to_uppercase();
181 if matches!(upper.as_str(), "MIN" | "MAX") && args.len() > 1 {
182 None
184 } else {
185 Some(name.to_string()) }
187 } else {
188 for arg in args {
190 if let Some(found) = find_aggregate_in_expression(arg) {
191 return Some(found);
192 }
193 }
194 None
195 }
196 }
197 Expression::BinaryOp { left, right, .. } => {
198 find_aggregate_in_expression(left).or_else(|| find_aggregate_in_expression(right))
199 }
200 Expression::UnaryOp { expr, .. } => find_aggregate_in_expression(expr),
201 Expression::Case { operand, when_clauses, else_result } => {
202 if let Some(op) = operand {
203 if let Some(found) = find_aggregate_in_expression(op) {
204 return Some(found);
205 }
206 }
207 for case_when in when_clauses {
208 for cond in &case_when.conditions {
209 if let Some(found) = find_aggregate_in_expression(cond) {
210 return Some(found);
211 }
212 }
213 if let Some(found) = find_aggregate_in_expression(&case_when.result) {
214 return Some(found);
215 }
216 }
217 if let Some(else_expr) = else_result {
218 find_aggregate_in_expression(else_expr)
219 } else {
220 None
221 }
222 }
223 Expression::IsNull { expr, .. } => find_aggregate_in_expression(expr),
224 Expression::IsDistinctFrom { left, right, .. } => {
225 find_aggregate_in_expression(left).or_else(|| find_aggregate_in_expression(right))
226 }
227 Expression::IsTruthValue { expr, .. } => find_aggregate_in_expression(expr),
228 Expression::Between { expr, low, high, .. } => find_aggregate_in_expression(expr)
229 .or_else(|| find_aggregate_in_expression(low))
230 .or_else(|| find_aggregate_in_expression(high)),
231 Expression::InList { expr, values, .. } => {
232 if let Some(found) = find_aggregate_in_expression(expr) {
233 return Some(found);
234 }
235 for val in values {
236 if let Some(found) = find_aggregate_in_expression(val) {
237 return Some(found);
238 }
239 }
240 None
241 }
242 Expression::In { expr, .. } => find_aggregate_in_expression(expr),
243 Expression::Exists { .. } => None, Expression::Cast { expr, .. } => find_aggregate_in_expression(expr),
245 Expression::Like { expr, pattern, .. } => {
246 find_aggregate_in_expression(expr).or_else(|| find_aggregate_in_expression(pattern))
247 }
248 Expression::Position { substring, string, .. } => {
249 find_aggregate_in_expression(substring).or_else(|| find_aggregate_in_expression(string))
250 }
251 Expression::Trim { removal_char, string, .. } => {
252 if let Some(char_expr) = removal_char {
253 if let Some(found) = find_aggregate_in_expression(char_expr) {
254 return Some(found);
255 }
256 }
257 find_aggregate_in_expression(string)
258 }
259 Expression::Extract { expr, .. } => find_aggregate_in_expression(expr),
260 Expression::ScalarSubquery(_) => None, Expression::QuantifiedComparison { expr, .. } => find_aggregate_in_expression(expr),
262 Expression::Interval { value, .. } => find_aggregate_in_expression(value),
263 Expression::WindowFunction { .. } => None, Expression::MatchAgainst { search_modifier, .. } => {
265 find_aggregate_in_expression(search_modifier)
266 }
267 Expression::Conjunction(children) | Expression::Disjunction(children) => {
268 for child in children {
269 if let Some(found) = find_aggregate_in_expression(child) {
270 return Some(found);
271 }
272 }
273 None
274 }
275 _ => None,
276 }
277}
278
279pub fn find_nested_aggregate(expr: &Expression) -> Option<String> {
286 match expr {
287 Expression::AggregateFunction { args, order_by, .. } => {
288 for arg in args {
290 if let Some(inner_name) = find_aggregate_in_expression(arg) {
291 return Some(inner_name);
292 }
293 }
294 if let Some(order_items) = order_by {
297 for item in order_items {
298 if let Some(inner_name) = find_aggregate_in_expression(&item.expr) {
299 return Some(inner_name);
300 }
301 }
302 }
303 None
304 }
305 Expression::Function { name, args, .. } => {
306 if is_aggregate_function(name.as_str()) {
308 let upper = name.to_uppercase();
309 let is_scalar_minmax = matches!(upper.as_str(), "MIN" | "MAX") && args.len() > 1;
311 if !is_scalar_minmax {
312 for arg in args {
314 if let Some(inner_name) = find_aggregate_in_expression(arg) {
315 return Some(inner_name);
316 }
317 }
318 }
319 }
320 for arg in args {
322 if let Some(found) = find_nested_aggregate(arg) {
323 return Some(found);
324 }
325 }
326 None
327 }
328 Expression::BinaryOp { left, right, .. } => {
329 find_nested_aggregate(left).or_else(|| find_nested_aggregate(right))
330 }
331 Expression::UnaryOp { expr, .. } => find_nested_aggregate(expr),
332 Expression::Cast { expr, .. } => find_nested_aggregate(expr),
333 Expression::Case { operand, when_clauses, else_result } => {
334 if let Some(op) = operand {
335 if let Some(found) = find_nested_aggregate(op) {
336 return Some(found);
337 }
338 }
339 for case_when in when_clauses {
340 for cond in &case_when.conditions {
341 if let Some(found) = find_nested_aggregate(cond) {
342 return Some(found);
343 }
344 }
345 if let Some(found) = find_nested_aggregate(&case_when.result) {
346 return Some(found);
347 }
348 }
349 if let Some(else_expr) = else_result {
350 find_nested_aggregate(else_expr)
351 } else {
352 None
353 }
354 }
355 Expression::IsNull { expr, .. } => find_nested_aggregate(expr),
356 Expression::Between { expr, low, high, .. } => find_nested_aggregate(expr)
357 .or_else(|| find_nested_aggregate(low))
358 .or_else(|| find_nested_aggregate(high)),
359 Expression::InList { expr, values, .. } => {
360 if let Some(found) = find_nested_aggregate(expr) {
361 return Some(found);
362 }
363 for val in values {
364 if let Some(found) = find_nested_aggregate(val) {
365 return Some(found);
366 }
367 }
368 None
369 }
370 Expression::Conjunction(children) | Expression::Disjunction(children) => {
371 for child in children {
372 if let Some(found) = find_nested_aggregate(child) {
373 return Some(found);
374 }
375 }
376 None
377 }
378 _ => None,
379 }
380}
381
382pub fn validate_aggregate_arguments(select_list: &[SelectItem]) -> Result<(), ExecutorError> {
391 for item in select_list {
392 if let SelectItem::Expression { expr, .. } = item {
393 if let Some(agg_name) = check_aggregate_arg_count(expr) {
394 return Err(ExecutorError::WrongNumberOfArguments { function_name: agg_name });
395 }
396 }
397 }
398 Ok(())
399}
400
401pub fn validate_no_nested_aggregates(select_list: &[SelectItem]) -> Result<(), ExecutorError> {
409 for item in select_list {
410 if let SelectItem::Expression { expr, .. } = item {
411 if let Some(inner_agg_name) = find_nested_aggregate(expr) {
412 return Err(ExecutorError::MisuseOfAggregate { function_name: inner_agg_name });
413 }
414 }
415 }
416 Ok(())
417}
418
419pub fn build_aggregate_aliases(select_list: &[SelectItem]) -> HashSet<String> {
424 let mut aliases = HashSet::new();
425
426 for item in select_list {
427 if let SelectItem::Expression { expr, alias: Some(alias_name), .. } = item {
428 if expression_contains_aggregate(expr) {
430 aliases.insert(alias_name.to_lowercase());
432 }
433 }
434 }
435
436 aliases
437}
438
439pub fn expression_contains_aggregate(expr: &Expression) -> bool {
441 match expr {
442 Expression::AggregateFunction { .. } => true,
443 Expression::Function { name, args, .. } => {
444 if is_aggregate_function(name.as_str()) {
446 let upper = name.to_uppercase();
447 if matches!(upper.as_str(), "MIN" | "MAX") && args.len() > 1 {
449 args.iter().any(expression_contains_aggregate)
451 } else {
452 true
453 }
454 } else {
455 args.iter().any(expression_contains_aggregate)
457 }
458 }
459 Expression::BinaryOp { left, right, .. } => {
460 expression_contains_aggregate(left) || expression_contains_aggregate(right)
461 }
462 Expression::UnaryOp { expr, .. } => expression_contains_aggregate(expr),
463 Expression::Cast { expr, .. } => expression_contains_aggregate(expr),
464 Expression::Case { operand, when_clauses, else_result } => {
465 operand.as_ref().is_some_and(|e| expression_contains_aggregate(e))
466 || when_clauses.iter().any(|w| {
467 w.conditions.iter().any(expression_contains_aggregate)
468 || expression_contains_aggregate(&w.result)
469 })
470 || else_result.as_ref().is_some_and(|e| expression_contains_aggregate(e))
471 }
472 Expression::IsNull { expr, .. } => expression_contains_aggregate(expr),
473 Expression::Between { expr, low, high, .. } => {
474 expression_contains_aggregate(expr)
475 || expression_contains_aggregate(low)
476 || expression_contains_aggregate(high)
477 }
478 Expression::InList { expr, values, .. } => {
479 expression_contains_aggregate(expr) || values.iter().any(expression_contains_aggregate)
480 }
481 Expression::In { expr, .. } => expression_contains_aggregate(expr),
482 Expression::Like { expr, pattern, .. } => {
483 expression_contains_aggregate(expr) || expression_contains_aggregate(pattern)
484 }
485 Expression::Position { substring, string, .. } => {
486 expression_contains_aggregate(substring) || expression_contains_aggregate(string)
487 }
488 Expression::Trim { removal_char, string, .. } => {
489 removal_char.as_ref().is_some_and(|e| expression_contains_aggregate(e))
490 || expression_contains_aggregate(string)
491 }
492 Expression::Extract { expr, .. } => expression_contains_aggregate(expr),
493 Expression::Interval { value, .. } => expression_contains_aggregate(value),
494 Expression::Conjunction(children) | Expression::Disjunction(children) => {
495 children.iter().any(expression_contains_aggregate)
496 }
497 Expression::ScalarSubquery(_) | Expression::Exists { .. } => false,
499 Expression::WindowFunction { .. } => false,
501 _ => false,
503 }
504}
505
506pub fn find_window_function_in_expression(expr: &Expression) -> Option<String> {
514 match expr {
515 Expression::WindowFunction { function, .. } => {
516 Some(function.name())
518 }
519 Expression::AggregateFunction { args, order_by, filter, .. } => {
520 for arg in args {
522 if let Some(found) = find_window_function_in_expression(arg) {
523 return Some(found);
524 }
525 }
526 if let Some(order_items) = order_by {
528 for item in order_items {
529 if let Some(found) = find_window_function_in_expression(&item.expr) {
530 return Some(found);
531 }
532 }
533 }
534 if let Some(filter_expr) = filter {
536 if let Some(found) = find_window_function_in_expression(filter_expr) {
537 return Some(found);
538 }
539 }
540 None
541 }
542 Expression::Function { args, .. } => {
543 for arg in args {
544 if let Some(found) = find_window_function_in_expression(arg) {
545 return Some(found);
546 }
547 }
548 None
549 }
550 Expression::BinaryOp { left, right, .. } => {
551 find_window_function_in_expression(left)
552 .or_else(|| find_window_function_in_expression(right))
553 }
554 Expression::UnaryOp { expr, .. } => find_window_function_in_expression(expr),
555 Expression::Case { operand, when_clauses, else_result } => {
556 if let Some(op) = operand {
557 if let Some(found) = find_window_function_in_expression(op) {
558 return Some(found);
559 }
560 }
561 for case_when in when_clauses {
562 for cond in &case_when.conditions {
563 if let Some(found) = find_window_function_in_expression(cond) {
564 return Some(found);
565 }
566 }
567 if let Some(found) = find_window_function_in_expression(&case_when.result) {
568 return Some(found);
569 }
570 }
571 if let Some(else_expr) = else_result {
572 find_window_function_in_expression(else_expr)
573 } else {
574 None
575 }
576 }
577 Expression::IsNull { expr, .. } => find_window_function_in_expression(expr),
578 Expression::IsDistinctFrom { left, right, .. } => {
579 find_window_function_in_expression(left)
580 .or_else(|| find_window_function_in_expression(right))
581 }
582 Expression::IsTruthValue { expr, .. } => find_window_function_in_expression(expr),
583 Expression::Between { expr, low, high, .. } => find_window_function_in_expression(expr)
584 .or_else(|| find_window_function_in_expression(low))
585 .or_else(|| find_window_function_in_expression(high)),
586 Expression::InList { expr, values, .. } => {
587 if let Some(found) = find_window_function_in_expression(expr) {
588 return Some(found);
589 }
590 for val in values {
591 if let Some(found) = find_window_function_in_expression(val) {
592 return Some(found);
593 }
594 }
595 None
596 }
597 Expression::In { expr, .. } => find_window_function_in_expression(expr),
598 Expression::Exists { .. } => None, Expression::Cast { expr, .. } => find_window_function_in_expression(expr),
600 Expression::Like { expr, pattern, .. } => find_window_function_in_expression(expr)
601 .or_else(|| find_window_function_in_expression(pattern)),
602 Expression::Position { substring, string, .. } => {
603 find_window_function_in_expression(substring)
604 .or_else(|| find_window_function_in_expression(string))
605 }
606 Expression::Trim { removal_char, string, .. } => {
607 if let Some(char_expr) = removal_char {
608 if let Some(found) = find_window_function_in_expression(char_expr) {
609 return Some(found);
610 }
611 }
612 find_window_function_in_expression(string)
613 }
614 Expression::Extract { expr, .. } => find_window_function_in_expression(expr),
615 Expression::ScalarSubquery(_) => None, Expression::QuantifiedComparison { expr, .. } => find_window_function_in_expression(expr),
617 Expression::Interval { value, .. } => find_window_function_in_expression(value),
618 Expression::MatchAgainst { search_modifier, .. } => {
619 find_window_function_in_expression(search_modifier)
620 }
621 Expression::Conjunction(children) | Expression::Disjunction(children) => {
622 for child in children {
623 if let Some(found) = find_window_function_in_expression(child) {
624 return Some(found);
625 }
626 }
627 None
628 }
629 _ => None,
630 }
631}
632
633fn find_aliased_aggregate_misuse_in_expression(
645 expr: &Expression,
646 aggregate_aliases: &HashSet<String>,
647 schema_columns: &HashSet<String>,
648 inside_aggregate: bool,
649) -> Option<String> {
650 match expr {
651 Expression::AggregateFunction { args, .. } => {
653 for arg in args {
654 if let Some(alias) = find_aliased_aggregate_misuse_in_expression(
655 arg,
656 aggregate_aliases,
657 schema_columns,
658 true,
659 ) {
660 return Some(alias);
661 }
662 }
663 None
664 }
665 Expression::Function { name, args, .. } => {
666 let is_agg = is_aggregate_function(name.as_str());
668 let upper = name.to_uppercase();
669 let effectively_aggregate =
671 is_agg && !(matches!(upper.as_str(), "MIN" | "MAX") && args.len() > 1);
672
673 let new_inside_aggregate = inside_aggregate || effectively_aggregate;
674
675 for arg in args {
676 if let Some(alias) = find_aliased_aggregate_misuse_in_expression(
677 arg,
678 aggregate_aliases,
679 schema_columns,
680 new_inside_aggregate,
681 ) {
682 return Some(alias);
683 }
684 }
685 None
686 }
687 Expression::ColumnRef(col_id)
689 if col_id.schema_canonical().is_none() && col_id.table_canonical().is_none() =>
690 {
691 let column = col_id.column_canonical();
692 if schema_columns.contains(&column.to_lowercase()) {
696 return None; }
698
699 if inside_aggregate && aggregate_aliases.contains(&column.to_lowercase()) {
700 Some(column.to_string())
702 } else {
703 None
704 }
705 }
706 Expression::ColumnRef(_) => None, Expression::BinaryOp { left, right, .. } => find_aliased_aggregate_misuse_in_expression(
709 left,
710 aggregate_aliases,
711 schema_columns,
712 inside_aggregate,
713 )
714 .or_else(|| {
715 find_aliased_aggregate_misuse_in_expression(
716 right,
717 aggregate_aliases,
718 schema_columns,
719 inside_aggregate,
720 )
721 }),
722 Expression::UnaryOp { expr, .. } => find_aliased_aggregate_misuse_in_expression(
723 expr,
724 aggregate_aliases,
725 schema_columns,
726 inside_aggregate,
727 ),
728 Expression::Cast { expr, .. } => find_aliased_aggregate_misuse_in_expression(
729 expr,
730 aggregate_aliases,
731 schema_columns,
732 inside_aggregate,
733 ),
734 Expression::Case { operand, when_clauses, else_result } => {
735 if let Some(op) = operand {
736 if let Some(alias) = find_aliased_aggregate_misuse_in_expression(
737 op,
738 aggregate_aliases,
739 schema_columns,
740 inside_aggregate,
741 ) {
742 return Some(alias);
743 }
744 }
745 for when_clause in when_clauses {
746 for cond in &when_clause.conditions {
747 if let Some(alias) = find_aliased_aggregate_misuse_in_expression(
748 cond,
749 aggregate_aliases,
750 schema_columns,
751 inside_aggregate,
752 ) {
753 return Some(alias);
754 }
755 }
756 if let Some(alias) = find_aliased_aggregate_misuse_in_expression(
757 &when_clause.result,
758 aggregate_aliases,
759 schema_columns,
760 inside_aggregate,
761 ) {
762 return Some(alias);
763 }
764 }
765 if let Some(else_expr) = else_result {
766 return find_aliased_aggregate_misuse_in_expression(
767 else_expr,
768 aggregate_aliases,
769 schema_columns,
770 inside_aggregate,
771 );
772 }
773 None
774 }
775 Expression::IsNull { expr, .. } => find_aliased_aggregate_misuse_in_expression(
776 expr,
777 aggregate_aliases,
778 schema_columns,
779 inside_aggregate,
780 ),
781 Expression::Between { expr, low, high, .. } => find_aliased_aggregate_misuse_in_expression(
782 expr,
783 aggregate_aliases,
784 schema_columns,
785 inside_aggregate,
786 )
787 .or_else(|| {
788 find_aliased_aggregate_misuse_in_expression(
789 low,
790 aggregate_aliases,
791 schema_columns,
792 inside_aggregate,
793 )
794 })
795 .or_else(|| {
796 find_aliased_aggregate_misuse_in_expression(
797 high,
798 aggregate_aliases,
799 schema_columns,
800 inside_aggregate,
801 )
802 }),
803 Expression::InList { expr, values, .. } => {
804 if let Some(alias) = find_aliased_aggregate_misuse_in_expression(
805 expr,
806 aggregate_aliases,
807 schema_columns,
808 inside_aggregate,
809 ) {
810 return Some(alias);
811 }
812 for val in values {
813 if let Some(alias) = find_aliased_aggregate_misuse_in_expression(
814 val,
815 aggregate_aliases,
816 schema_columns,
817 inside_aggregate,
818 ) {
819 return Some(alias);
820 }
821 }
822 None
823 }
824 Expression::In { expr, .. } => find_aliased_aggregate_misuse_in_expression(
825 expr,
826 aggregate_aliases,
827 schema_columns,
828 inside_aggregate,
829 ),
830 Expression::Like { expr, pattern, .. } => find_aliased_aggregate_misuse_in_expression(
831 expr,
832 aggregate_aliases,
833 schema_columns,
834 inside_aggregate,
835 )
836 .or_else(|| {
837 find_aliased_aggregate_misuse_in_expression(
838 pattern,
839 aggregate_aliases,
840 schema_columns,
841 inside_aggregate,
842 )
843 }),
844 Expression::Position { substring, string, .. } => {
845 find_aliased_aggregate_misuse_in_expression(
846 substring,
847 aggregate_aliases,
848 schema_columns,
849 inside_aggregate,
850 )
851 .or_else(|| {
852 find_aliased_aggregate_misuse_in_expression(
853 string,
854 aggregate_aliases,
855 schema_columns,
856 inside_aggregate,
857 )
858 })
859 }
860 Expression::Trim { removal_char, string, .. } => {
861 if let Some(rc) = removal_char {
862 if let Some(alias) = find_aliased_aggregate_misuse_in_expression(
863 rc,
864 aggregate_aliases,
865 schema_columns,
866 inside_aggregate,
867 ) {
868 return Some(alias);
869 }
870 }
871 find_aliased_aggregate_misuse_in_expression(
872 string,
873 aggregate_aliases,
874 schema_columns,
875 inside_aggregate,
876 )
877 }
878 Expression::Extract { expr, .. } => find_aliased_aggregate_misuse_in_expression(
879 expr,
880 aggregate_aliases,
881 schema_columns,
882 inside_aggregate,
883 ),
884 Expression::Interval { value, .. } => find_aliased_aggregate_misuse_in_expression(
885 value,
886 aggregate_aliases,
887 schema_columns,
888 inside_aggregate,
889 ),
890 Expression::Conjunction(children) | Expression::Disjunction(children) => {
891 for child in children {
892 if let Some(alias) = find_aliased_aggregate_misuse_in_expression(
893 child,
894 aggregate_aliases,
895 schema_columns,
896 inside_aggregate,
897 ) {
898 return Some(alias);
899 }
900 }
901 None
902 }
903 Expression::ScalarSubquery(_) | Expression::Exists { .. } => None,
905 Expression::QuantifiedComparison { expr, .. } => {
906 find_aliased_aggregate_misuse_in_expression(
907 expr,
908 aggregate_aliases,
909 schema_columns,
910 inside_aggregate,
911 )
912 }
913 Expression::IsDistinctFrom { left, right, .. } => {
914 find_aliased_aggregate_misuse_in_expression(
915 left,
916 aggregate_aliases,
917 schema_columns,
918 inside_aggregate,
919 )
920 .or_else(|| {
921 find_aliased_aggregate_misuse_in_expression(
922 right,
923 aggregate_aliases,
924 schema_columns,
925 inside_aggregate,
926 )
927 })
928 }
929 Expression::IsTruthValue { expr, .. } => find_aliased_aggregate_misuse_in_expression(
930 expr,
931 aggregate_aliases,
932 schema_columns,
933 inside_aggregate,
934 ),
935 _ => None,
937 }
938}
939
940pub fn validate_having_aliased_aggregates(
949 having_clause: Option<&Expression>,
950 select_list: &[SelectItem],
951 schema: &CombinedSchema,
952) -> Result<(), ExecutorError> {
953 let Some(having_expr) = having_clause else {
954 return Ok(());
955 };
956
957 let aggregate_aliases = build_aggregate_aliases(select_list);
959
960 if aggregate_aliases.is_empty() {
961 return Ok(()); }
963
964 let schema_columns: HashSet<String> = schema
966 .table_schemas
967 .values()
968 .flat_map(|(_, table_schema)| table_schema.columns.iter().map(|c| c.name.to_lowercase()))
969 .collect();
970
971 if let Some(alias_name) = find_aliased_aggregate_misuse_in_expression(
973 having_expr,
974 &aggregate_aliases,
975 &schema_columns,
976 false,
977 ) {
978 return Err(ExecutorError::MisuseOfAliasedAggregate { alias_name });
979 }
980
981 Ok(())
982}
983
984#[cfg(test)]
985mod tests {
986 use vibesql_ast::{BinaryOperator, ColumnIdentifier, FunctionIdentifier, UnaryOperator};
987 use vibesql_catalog::{ColumnSchema, TableSchema};
988 use vibesql_types::{DataType, SqlValue};
989
990 use super::*;
991
992 fn make_f1_f2_schema() -> CombinedSchema {
994 let columns = vec![
995 ColumnSchema {
996 name: "f1".to_string(),
997 data_type: DataType::Integer,
998 nullable: true,
999 default_value: None,
1000 generated_expr: None,
1001 is_exact_integer_type: false,
1002 collation: None,
1003 },
1004 ColumnSchema {
1005 name: "f2".to_string(),
1006 data_type: DataType::Integer,
1007 nullable: true,
1008 default_value: None,
1009 generated_expr: None,
1010 is_exact_integer_type: false,
1011 collation: None,
1012 },
1013 ];
1014 let table_schema = TableSchema::new("test1".to_string(), columns);
1015 CombinedSchema::from_table("test1".to_string(), table_schema)
1016 }
1017
1018 #[test]
1019 fn test_min_star_invalid() {
1020 let expr = Expression::AggregateFunction {
1022 name: FunctionIdentifier::new("MIN"),
1023 distinct: false,
1024 args: vec![Expression::ColumnRef(ColumnIdentifier::simple("*", false))],
1025 order_by: None,
1026 filter: None,
1027 };
1028 let result = check_aggregate_arg_count(&expr);
1029 assert!(result.is_some(), "MIN(*) should be invalid");
1030 assert_eq!(result.unwrap(), "MIN"); }
1032
1033 #[test]
1034 fn test_max_star_invalid() {
1035 let expr = Expression::AggregateFunction {
1037 name: FunctionIdentifier::new("MAX"),
1038 distinct: false,
1039 args: vec![Expression::ColumnRef(ColumnIdentifier::simple("*", false))],
1040 order_by: None,
1041 filter: None,
1042 };
1043 let result = check_aggregate_arg_count(&expr);
1044 assert!(result.is_some(), "MAX(*) should be invalid");
1045 assert_eq!(result.unwrap(), "MAX"); }
1047
1048 #[test]
1049 fn test_min_no_args_invalid() {
1050 let expr = Expression::AggregateFunction {
1052 name: FunctionIdentifier::new("MIN"),
1053 distinct: false,
1054 args: vec![],
1055 order_by: None,
1056 filter: None,
1057 };
1058 let result = check_aggregate_arg_count(&expr);
1059 assert!(result.is_some(), "MIN() should be invalid");
1060 assert_eq!(result.unwrap(), "MIN"); }
1062
1063 #[test]
1064 fn test_validate_aggregate_arguments() {
1065 let select_list = vec![SelectItem::Expression {
1067 expr: Expression::AggregateFunction {
1068 name: FunctionIdentifier::new("MIN"),
1069 distinct: false,
1070 args: vec![Expression::ColumnRef(ColumnIdentifier::simple("*", false))],
1071 order_by: None,
1072 filter: None,
1073 },
1074 alias: None,
1075 source_text: None,
1076 }];
1077 let result = validate_aggregate_arguments(&select_list);
1078 assert!(result.is_err());
1079 }
1080
1081 #[test]
1082 fn test_having_with_aliased_aggregate_inside_aggregate() {
1083 let select_list = vec![SelectItem::Expression {
1087 expr: Expression::AggregateFunction {
1088 name: FunctionIdentifier::new("min"),
1089 distinct: false,
1090 args: vec![Expression::ColumnRef(ColumnIdentifier::simple("f1", false))],
1091 order_by: None,
1092 filter: None,
1093 },
1094 alias: Some("m".to_string()),
1095 source_text: None,
1096 }];
1097
1098 let having_expr = Expression::BinaryOp {
1100 op: BinaryOperator::LessThan,
1101 left: Box::new(Expression::AggregateFunction {
1102 name: FunctionIdentifier::new("max"),
1103 distinct: false,
1104 args: vec![Expression::BinaryOp {
1105 op: BinaryOperator::Plus,
1106 left: Box::new(Expression::ColumnRef(ColumnIdentifier::simple("m", false))),
1107 right: Box::new(Expression::Literal(SqlValue::Integer(5))),
1108 }],
1109 order_by: None,
1110 filter: None,
1111 }),
1112 right: Box::new(Expression::Literal(SqlValue::Integer(10))),
1113 };
1114
1115 let schema = make_f1_f2_schema();
1117 let result = validate_having_aliased_aggregates(Some(&having_expr), &select_list, &schema);
1118 assert!(result.is_err());
1119 match result {
1120 Err(ExecutorError::MisuseOfAliasedAggregate { alias_name }) => {
1121 assert_eq!(alias_name, "m");
1122 }
1123 _ => panic!("Expected MisuseOfAliasedAggregate error"),
1124 }
1125 }
1126
1127 #[test]
1128 fn test_having_with_aggregate_alias_not_inside_aggregate() {
1129 let select_list = vec![SelectItem::Expression {
1135 expr: Expression::AggregateFunction {
1136 name: FunctionIdentifier::new("min"),
1137 distinct: false,
1138 args: vec![Expression::ColumnRef(ColumnIdentifier::simple("f1", false))],
1139 order_by: None,
1140 filter: None,
1141 },
1142 alias: Some("m".to_string()),
1143 source_text: None,
1144 }];
1145
1146 let having_expr = Expression::BinaryOp {
1148 op: BinaryOperator::GreaterThan,
1149 left: Box::new(Expression::ColumnRef(ColumnIdentifier::simple("m", false))),
1150 right: Box::new(Expression::Literal(SqlValue::Integer(0))),
1151 };
1152
1153 let schema = make_f1_f2_schema();
1156 let result = validate_having_aliased_aggregates(Some(&having_expr), &select_list, &schema);
1157 assert!(result.is_ok());
1158 }
1159
1160 #[test]
1161 fn test_having_without_aggregate_alias() {
1162 let select_list = vec![SelectItem::Expression {
1165 expr: Expression::AggregateFunction {
1166 name: FunctionIdentifier::new("count"),
1167 distinct: false,
1168 args: vec![Expression::Wildcard],
1169 order_by: None,
1170 filter: None,
1171 },
1172 alias: None, source_text: None,
1174 }];
1175
1176 let having_expr = Expression::BinaryOp {
1177 op: BinaryOperator::GreaterThan,
1178 left: Box::new(Expression::ColumnRef(ColumnIdentifier::simple("f1", false))),
1179 right: Box::new(Expression::Literal(SqlValue::Integer(0))),
1180 };
1181
1182 let schema = make_f1_f2_schema();
1183 let result = validate_having_aliased_aggregates(Some(&having_expr), &select_list, &schema);
1184 assert!(result.is_ok());
1185 }
1186
/// HAVING wraps the alias `x` in `max(...)`, where `x` names a plain column
/// reference (`f1`) in the SELECT list rather than an aggregate; the validator
/// is expected to return `Ok`.
#[test]
fn test_having_with_non_aggregate_alias() {
    // First select item: `f1 AS x` (the alias points at a non-aggregate expression).
    let aliased_column = SelectItem::Expression {
        expr: Expression::ColumnRef(ColumnIdentifier::simple("f1", false)),
        alias: Some("x".to_string()),
        source_text: None,
    };

    // Second select item: an un-aliased `count(*)`.
    let count_star = SelectItem::Expression {
        expr: Expression::AggregateFunction {
            name: FunctionIdentifier::new("count"),
            distinct: false,
            args: vec![Expression::Wildcard],
            order_by: None,
            filter: None,
        },
        alias: None,
        source_text: None,
    };

    let select_list = vec![aliased_column, count_star];

    // HAVING max(x) < 10
    let max_of_x = Expression::AggregateFunction {
        name: FunctionIdentifier::new("max"),
        distinct: false,
        args: vec![Expression::ColumnRef(ColumnIdentifier::simple("x", false))],
        order_by: None,
        filter: None,
    };
    let ten = Expression::Literal(SqlValue::Integer(10));
    let having_expr = Expression::BinaryOp {
        op: BinaryOperator::LessThan,
        left: Box::new(max_of_x),
        right: Box::new(ten),
    };

    let schema = make_f1_f2_schema();
    let outcome = validate_having_aliased_aggregates(Some(&having_expr), &select_list, &schema);
    assert!(outcome.is_ok());
}
1227
/// The SELECT alias `col0` is attached to an expression containing `AVG(...)`,
/// and `col0` is also the name of a real column of table `tab0`. HAVING uses
/// `AVG(col0) IS NULL`; the validator is expected to return `Ok`.
///
/// NOTE(review): judging by the test name, the column is meant to win over the
/// alias inside the aggregate argument; confirm against the validator.
#[test]
fn test_having_alias_shadows_column_uses_column() {
    // Small builders for the two node shapes that repeat in the SELECT expression.
    let col2 = || Expression::ColumnRef(ColumnIdentifier::simple("col2", false));
    let negate =
        |inner: Expression| Expression::UnaryOp { op: UnaryOperator::Minus, expr: Box::new(inner) };

    // AVG(-col2)
    let avg_of_negated_col2 = Expression::AggregateFunction {
        name: FunctionIdentifier::new("AVG"),
        distinct: false,
        args: vec![negate(col2())],
        order_by: None,
        filter: None,
    };

    // SELECT (-col2) * (-AVG(-col2)) AS col0
    let product = Expression::BinaryOp {
        op: BinaryOperator::Multiply,
        left: Box::new(negate(col2())),
        right: Box::new(negate(avg_of_negated_col2)),
    };
    let select_list = vec![SelectItem::Expression {
        expr: product,
        alias: Some("col0".to_string()),
        source_text: None,
    }];

    // HAVING AVG(col0) IS NULL
    let avg_of_col0 = Expression::AggregateFunction {
        name: FunctionIdentifier::new("AVG"),
        distinct: false,
        args: vec![Expression::ColumnRef(ColumnIdentifier::simple("col0", false))],
        order_by: None,
        filter: None,
    };
    let having_expr = Expression::IsNull { expr: Box::new(avg_of_col0), negated: false };

    // Table `tab0` with three identically-shaped nullable integer columns.
    let columns: Vec<ColumnSchema> = ["col0", "col1", "col2"]
        .iter()
        .map(|&column_name| ColumnSchema {
            name: column_name.to_string(),
            data_type: DataType::Integer,
            nullable: true,
            default_value: None,
            generated_expr: None,
            is_exact_integer_type: false,
            collation: None,
        })
        .collect();
    let table_schema = TableSchema::new("tab0".to_string(), columns);
    let schema = CombinedSchema::from_table("tab0".to_string(), table_schema);

    let outcome = validate_having_aliased_aggregates(Some(&having_expr), &select_list, &schema);
    assert!(outcome.is_ok(), "Expected Ok but got {:?}", outcome);
}
1310
/// Checks which SELECT aliases `build_aggregate_aliases` reports:
/// - `m`    -> `min(f1)`, a direct aggregate: expected to be present;
/// - `col2` -> plain column `f2`: expected to be absent;
/// - `m2`   -> `coalesce(min(f1) + 5, 11)`, an aggregate nested inside a
///   scalar function call: expected to be present.
#[test]
fn test_build_aggregate_aliases() {
    // `min(f1)` appears twice in the select list, so build it on demand.
    let min_of_f1 = || Expression::AggregateFunction {
        name: FunctionIdentifier::new("min"),
        distinct: false,
        args: vec![Expression::ColumnRef(ColumnIdentifier::simple("f1", false))],
        order_by: None,
        filter: None,
    };

    // min(f1) AS m
    let direct_aggregate = SelectItem::Expression {
        expr: min_of_f1(),
        alias: Some("m".to_string()),
        source_text: None,
    };

    // f2 AS col2
    let plain_column = SelectItem::Expression {
        expr: Expression::ColumnRef(ColumnIdentifier::simple("f2", false)),
        alias: Some("col2".to_string()),
        source_text: None,
    };

    // coalesce(min(f1) + 5, 11) AS m2
    let min_plus_five = Expression::BinaryOp {
        op: BinaryOperator::Plus,
        left: Box::new(min_of_f1()),
        right: Box::new(Expression::Literal(SqlValue::Integer(5))),
    };
    let nested_aggregate = SelectItem::Expression {
        expr: Expression::Function {
            name: FunctionIdentifier::new("coalesce"),
            args: vec![min_plus_five, Expression::Literal(SqlValue::Integer(11))],
            character_unit: None,
        },
        alias: Some("m2".to_string()),
        source_text: None,
    };

    let select_list = vec![direct_aggregate, plain_column, nested_aggregate];

    let aliases = build_aggregate_aliases(&select_list);
    assert!(aliases.contains("m"));
    assert!(!aliases.contains("col2"));
    assert!(aliases.contains("m2"));
}
1363}