vibesql_executor/select/executor/validation/
aggregates.rs

//! Aggregate function validation
//!
//! Validates aggregate function usage, including:
//! - Argument count validation
//! - Nested aggregate detection
//! - Aliased aggregate misuse in HAVING clause
//! - Detection of window functions in positions where they are not allowed

8use std::collections::HashSet;
9
10use vibesql_ast::{Expression, SelectItem};
11
12use crate::{errors::ExecutorError, schema::CombinedSchema};
13
/// Returns `true` when `name` names one of the built-in aggregate functions:
/// `COUNT`, `SUM`, `AVG`, `MIN`, `MAX`, `TOTAL` or `GROUP_CONCAT`.
///
/// The comparison is case-insensitive (the name is upper-cased first).
pub fn is_aggregate_function(name: &str) -> bool {
    const AGGREGATE_NAMES: [&str; 7] =
        ["COUNT", "SUM", "AVG", "MIN", "MAX", "TOTAL", "GROUP_CONCAT"];
    let normalized = name.to_uppercase();
    AGGREGATE_NAMES.contains(&normalized.as_str())
}

20/// Check if an aggregate function has wrong number of arguments
21/// Returns Some((function_name, arg_count)) if there's an error, None otherwise
22pub fn check_aggregate_arg_count(expr: &Expression) -> Option<String> {
23    match expr {
24        Expression::AggregateFunction { name, args, distinct, .. } => {
25            let upper = name.to_uppercase();
26            let arg_count = args.len();
27
28            // Check for wildcard in non-COUNT aggregates
29            let has_wildcard = args.iter().any(|arg| {
30                let is_wildcard = matches!(arg, Expression::Wildcard);
31                let is_star_ref = matches!(
32                    arg,
33                    Expression::ColumnRef(col_id) if col_id.schema_canonical().is_none() && col_id.table_canonical().is_none() && col_id.column_canonical() == "*"
34                );
35                is_wildcard || is_star_ref
36            });
37
38            match upper.as_str() {
39                "COUNT" => {
40                    // Multi-arg COUNT without DISTINCT is an error
41                    // SQLite: "wrong number of arguments to function count()"
42                    if arg_count > 1 && !*distinct {
43                        Some(name.display().to_string())
44                    } else {
45                        None
46                    }
47                }
48                "MIN" | "MAX" => {
49                    if has_wildcard || arg_count == 0 {
50                        Some(name.display().to_string())
51                    } else {
52                        None
53                    }
54                }
55                "SUM" | "AVG" | "TOTAL" => {
56                    if has_wildcard || arg_count == 0 || arg_count > 1 {
57                        Some(name.display().to_string())
58                    } else {
59                        None
60                    }
61                }
62                "GROUP_CONCAT" => {
63                    if arg_count == 0 || arg_count > 2 {
64                        Some(name.display().to_string())
65                    } else {
66                        None
67                    }
68                }
69                _ => None,
70            }
71        }
72        Expression::Function { name, args, .. } => {
73            // Check if this is an aggregate function with wrong args
74            if is_aggregate_function(name.as_str()) {
75                let upper = name.to_uppercase();
76                let arg_count = args.len();
77
78                // Check for wildcard
79                let has_wildcard = args.iter().any(|arg| {
80                    matches!(arg, Expression::Wildcard)
81                        || matches!(
82                            arg,
83                            Expression::ColumnRef(col_id) if col_id.schema_canonical().is_none() && col_id.table_canonical().is_none() && col_id.column_canonical() == "*"
84                        )
85                });
86
87                match upper.as_str() {
88                    "COUNT" => {
89                        // count(a, b) without DISTINCT is wrong
90                        // Regular count without DISTINCT can only have 0-1 args
91                        if arg_count > 1 {
92                            Some(name.display().to_string())
93                        } else {
94                            None
95                        }
96                    }
97                    "MIN" | "MAX" => {
98                        // Multi-arg min/max are scalar, so only check single arg case
99                        if arg_count <= 1 && (has_wildcard || arg_count == 0) {
100                            Some(name.display().to_string())
101                        } else {
102                            None
103                        }
104                    }
105                    "SUM" | "AVG" | "TOTAL" => {
106                        if has_wildcard || arg_count == 0 || arg_count > 1 {
107                            Some(name.display().to_string())
108                        } else {
109                            None
110                        }
111                    }
112                    "GROUP_CONCAT" => {
113                        if arg_count == 0 || arg_count > 2 {
114                            Some(name.display().to_string())
115                        } else {
116                            None
117                        }
118                    }
119                    _ => None,
120                }
121            } else {
122                // Check function arguments recursively
123                for arg in args {
124                    if let Some(found) = check_aggregate_arg_count(arg) {
125                        return Some(found);
126                    }
127                }
128                None
129            }
130        }
131        Expression::BinaryOp { left, right, .. } => {
132            check_aggregate_arg_count(left).or_else(|| check_aggregate_arg_count(right))
133        }
134        Expression::UnaryOp { expr, .. } => check_aggregate_arg_count(expr),
135        Expression::Case { operand, when_clauses, else_result } => {
136            if let Some(op) = operand {
137                if let Some(found) = check_aggregate_arg_count(op) {
138                    return Some(found);
139                }
140            }
141            for case_when in when_clauses {
142                for cond in &case_when.conditions {
143                    if let Some(found) = check_aggregate_arg_count(cond) {
144                        return Some(found);
145                    }
146                }
147                if let Some(found) = check_aggregate_arg_count(&case_when.result) {
148                    return Some(found);
149                }
150            }
151            if let Some(else_expr) = else_result {
152                check_aggregate_arg_count(else_expr)
153            } else {
154                None
155            }
156        }
157        Expression::IsNull { expr, .. } => check_aggregate_arg_count(expr),
158        Expression::Cast { expr, .. } => check_aggregate_arg_count(expr),
159        Expression::Conjunction(children) | Expression::Disjunction(children) => {
160            for child in children {
161                if let Some(found) = check_aggregate_arg_count(child) {
162                    return Some(found);
163                }
164            }
165            None
166        }
167        _ => None,
168    }
169}
170
171/// Find the first aggregate function in an expression
172/// Returns the function name (original case preserved) if found, None otherwise
173pub fn find_aggregate_in_expression(expr: &Expression) -> Option<String> {
174    match expr {
175        Expression::AggregateFunction { name, .. } => Some(name.to_string()), /* Preserve original case */
176        Expression::Function { name, args, .. } => {
177            // Check if this function is a built-in aggregate
178            // Note: MIN/MAX with multiple args are scalar functions in SQLite
179            if is_aggregate_function(name.as_str()) {
180                let upper = name.to_uppercase();
181                if matches!(upper.as_str(), "MIN" | "MAX") && args.len() > 1 {
182                    // Multi-arg min/max are scalar, not aggregate
183                    None
184                } else {
185                    Some(name.to_string()) // Preserve original case
186                }
187            } else {
188                // Check function arguments recursively
189                for arg in args {
190                    if let Some(found) = find_aggregate_in_expression(arg) {
191                        return Some(found);
192                    }
193                }
194                None
195            }
196        }
197        Expression::BinaryOp { left, right, .. } => {
198            find_aggregate_in_expression(left).or_else(|| find_aggregate_in_expression(right))
199        }
200        Expression::UnaryOp { expr, .. } => find_aggregate_in_expression(expr),
201        Expression::Case { operand, when_clauses, else_result } => {
202            if let Some(op) = operand {
203                if let Some(found) = find_aggregate_in_expression(op) {
204                    return Some(found);
205                }
206            }
207            for case_when in when_clauses {
208                for cond in &case_when.conditions {
209                    if let Some(found) = find_aggregate_in_expression(cond) {
210                        return Some(found);
211                    }
212                }
213                if let Some(found) = find_aggregate_in_expression(&case_when.result) {
214                    return Some(found);
215                }
216            }
217            if let Some(else_expr) = else_result {
218                find_aggregate_in_expression(else_expr)
219            } else {
220                None
221            }
222        }
223        Expression::IsNull { expr, .. } => find_aggregate_in_expression(expr),
224        Expression::IsDistinctFrom { left, right, .. } => {
225            find_aggregate_in_expression(left).or_else(|| find_aggregate_in_expression(right))
226        }
227        Expression::IsTruthValue { expr, .. } => find_aggregate_in_expression(expr),
228        Expression::Between { expr, low, high, .. } => find_aggregate_in_expression(expr)
229            .or_else(|| find_aggregate_in_expression(low))
230            .or_else(|| find_aggregate_in_expression(high)),
231        Expression::InList { expr, values, .. } => {
232            if let Some(found) = find_aggregate_in_expression(expr) {
233                return Some(found);
234            }
235            for val in values {
236                if let Some(found) = find_aggregate_in_expression(val) {
237                    return Some(found);
238                }
239            }
240            None
241        }
242        Expression::In { expr, .. } => find_aggregate_in_expression(expr),
243        Expression::Exists { .. } => None, // EXISTS subqueries have their own scope
244        Expression::Cast { expr, .. } => find_aggregate_in_expression(expr),
245        Expression::Like { expr, pattern, .. } => {
246            find_aggregate_in_expression(expr).or_else(|| find_aggregate_in_expression(pattern))
247        }
248        Expression::Position { substring, string, .. } => {
249            find_aggregate_in_expression(substring).or_else(|| find_aggregate_in_expression(string))
250        }
251        Expression::Trim { removal_char, string, .. } => {
252            if let Some(char_expr) = removal_char {
253                if let Some(found) = find_aggregate_in_expression(char_expr) {
254                    return Some(found);
255                }
256            }
257            find_aggregate_in_expression(string)
258        }
259        Expression::Extract { expr, .. } => find_aggregate_in_expression(expr),
260        Expression::ScalarSubquery(_) => None, // Subqueries have their own scope
261        Expression::QuantifiedComparison { expr, .. } => find_aggregate_in_expression(expr),
262        Expression::Interval { value, .. } => find_aggregate_in_expression(value),
263        Expression::WindowFunction { .. } => None, // Window functions are not regular aggregates
264        Expression::MatchAgainst { search_modifier, .. } => {
265            find_aggregate_in_expression(search_modifier)
266        }
267        Expression::Conjunction(children) | Expression::Disjunction(children) => {
268            for child in children {
269                if let Some(found) = find_aggregate_in_expression(child) {
270                    return Some(found);
271                }
272            }
273            None
274        }
275        _ => None,
276    }
277}
278
279/// Find nested aggregate function in an expression
280///
281/// A nested aggregate is when one aggregate's arguments contain another aggregate,
282/// e.g., `SUM(MIN(x))`. This is invalid in SQL.
283///
284/// Returns Some(inner_aggregate_name) if found, None otherwise.
285pub fn find_nested_aggregate(expr: &Expression) -> Option<String> {
286    match expr {
287        Expression::AggregateFunction { args, order_by, .. } => {
288            // Check if any argument contains an aggregate function
289            for arg in args {
290                if let Some(inner_name) = find_aggregate_in_expression(arg) {
291                    return Some(inner_name);
292                }
293            }
294            // Also check ORDER BY expressions for aggregate functions
295            // e.g., group_concat(a ORDER BY max(d)) is invalid
296            if let Some(order_items) = order_by {
297                for item in order_items {
298                    if let Some(inner_name) = find_aggregate_in_expression(&item.expr) {
299                        return Some(inner_name);
300                    }
301                }
302            }
303            None
304        }
305        Expression::Function { name, args, .. } => {
306            // Check if this function is a built-in aggregate with nested aggregate args
307            if is_aggregate_function(name.as_str()) {
308                let upper = name.to_uppercase();
309                // Multi-arg MIN/MAX are scalar functions, not aggregates
310                let is_scalar_minmax = matches!(upper.as_str(), "MIN" | "MAX") && args.len() > 1;
311                if !is_scalar_minmax {
312                    // This is an aggregate - check for nested aggregates in args
313                    for arg in args {
314                        if let Some(inner_name) = find_aggregate_in_expression(arg) {
315                            return Some(inner_name);
316                        }
317                    }
318                }
319            }
320            // Check arguments recursively (for non-aggregate functions)
321            for arg in args {
322                if let Some(found) = find_nested_aggregate(arg) {
323                    return Some(found);
324                }
325            }
326            None
327        }
328        Expression::BinaryOp { left, right, .. } => {
329            find_nested_aggregate(left).or_else(|| find_nested_aggregate(right))
330        }
331        Expression::UnaryOp { expr, .. } => find_nested_aggregate(expr),
332        Expression::Cast { expr, .. } => find_nested_aggregate(expr),
333        Expression::Case { operand, when_clauses, else_result } => {
334            if let Some(op) = operand {
335                if let Some(found) = find_nested_aggregate(op) {
336                    return Some(found);
337                }
338            }
339            for case_when in when_clauses {
340                for cond in &case_when.conditions {
341                    if let Some(found) = find_nested_aggregate(cond) {
342                        return Some(found);
343                    }
344                }
345                if let Some(found) = find_nested_aggregate(&case_when.result) {
346                    return Some(found);
347                }
348            }
349            if let Some(else_expr) = else_result {
350                find_nested_aggregate(else_expr)
351            } else {
352                None
353            }
354        }
355        Expression::IsNull { expr, .. } => find_nested_aggregate(expr),
356        Expression::Between { expr, low, high, .. } => find_nested_aggregate(expr)
357            .or_else(|| find_nested_aggregate(low))
358            .or_else(|| find_nested_aggregate(high)),
359        Expression::InList { expr, values, .. } => {
360            if let Some(found) = find_nested_aggregate(expr) {
361                return Some(found);
362            }
363            for val in values {
364                if let Some(found) = find_nested_aggregate(val) {
365                    return Some(found);
366                }
367            }
368            None
369        }
370        Expression::Conjunction(children) | Expression::Disjunction(children) => {
371            for child in children {
372                if let Some(found) = find_nested_aggregate(child) {
373                    return Some(found);
374                }
375            }
376            None
377        }
378        _ => None,
379    }
380}
381
382/// Validate aggregate function argument counts in SELECT list
383///
384/// This validates that aggregate functions have the correct number of arguments:
385/// - MIN, MAX, SUM, AVG, TOTAL require exactly 1 argument (no wildcard)
386/// - COUNT allows 0-1 arguments (supports *)
387/// - GROUP_CONCAT requires 1-2 arguments
388///
389/// Returns an error with SQLite-compatible message if validation fails.
390pub fn validate_aggregate_arguments(select_list: &[SelectItem]) -> Result<(), ExecutorError> {
391    for item in select_list {
392        if let SelectItem::Expression { expr, .. } = item {
393            if let Some(agg_name) = check_aggregate_arg_count(expr) {
394                return Err(ExecutorError::WrongNumberOfArguments { function_name: agg_name });
395            }
396        }
397    }
398    Ok(())
399}
400
401/// Validate that there are no nested aggregate functions in the SELECT list
402///
403/// Nested aggregates like `SUM(MIN(x))` are invalid in SQL.
404/// Returns an error with SQLite-compatible message if nested aggregates are found.
405///
406/// Note: This uses the "misuse of aggregate function X()" format (with "function")
407/// as SQLite detects this during name resolution, not during execution.
408pub fn validate_no_nested_aggregates(select_list: &[SelectItem]) -> Result<(), ExecutorError> {
409    for item in select_list {
410        if let SelectItem::Expression { expr, .. } = item {
411            if let Some(inner_agg_name) = find_nested_aggregate(expr) {
412                return Err(ExecutorError::MisuseOfAggregate { function_name: inner_agg_name });
413            }
414        }
415    }
416    Ok(())
417}
418
419/// Build a set of aggregate alias names from the SELECT list
420///
421/// An "aggregate alias" is an alias that refers to an expression containing
422/// an aggregate function, e.g., `min(f1) AS m` makes `m` an aggregate alias.
423pub fn build_aggregate_aliases(select_list: &[SelectItem]) -> HashSet<String> {
424    let mut aliases = HashSet::new();
425
426    for item in select_list {
427        if let SelectItem::Expression { expr, alias: Some(alias_name), .. } = item {
428            // Check if this expression contains an aggregate
429            if expression_contains_aggregate(expr) {
430                // Store alias in lowercase for case-insensitive matching
431                aliases.insert(alias_name.to_lowercase());
432            }
433        }
434    }
435
436    aliases
437}
438
439/// Check if an expression contains an aggregate function
440pub fn expression_contains_aggregate(expr: &Expression) -> bool {
441    match expr {
442        Expression::AggregateFunction { .. } => true,
443        Expression::Function { name, args, .. } => {
444            // Check if this function is a built-in aggregate
445            if is_aggregate_function(name.as_str()) {
446                let upper = name.to_uppercase();
447                // Multi-arg MIN/MAX are scalar functions
448                if matches!(upper.as_str(), "MIN" | "MAX") && args.len() > 1 {
449                    // Still check arguments for nested aggregates
450                    args.iter().any(expression_contains_aggregate)
451                } else {
452                    true
453                }
454            } else {
455                // Check function arguments
456                args.iter().any(expression_contains_aggregate)
457            }
458        }
459        Expression::BinaryOp { left, right, .. } => {
460            expression_contains_aggregate(left) || expression_contains_aggregate(right)
461        }
462        Expression::UnaryOp { expr, .. } => expression_contains_aggregate(expr),
463        Expression::Cast { expr, .. } => expression_contains_aggregate(expr),
464        Expression::Case { operand, when_clauses, else_result } => {
465            operand.as_ref().is_some_and(|e| expression_contains_aggregate(e))
466                || when_clauses.iter().any(|w| {
467                    w.conditions.iter().any(expression_contains_aggregate)
468                        || expression_contains_aggregate(&w.result)
469                })
470                || else_result.as_ref().is_some_and(|e| expression_contains_aggregate(e))
471        }
472        Expression::IsNull { expr, .. } => expression_contains_aggregate(expr),
473        Expression::Between { expr, low, high, .. } => {
474            expression_contains_aggregate(expr)
475                || expression_contains_aggregate(low)
476                || expression_contains_aggregate(high)
477        }
478        Expression::InList { expr, values, .. } => {
479            expression_contains_aggregate(expr) || values.iter().any(expression_contains_aggregate)
480        }
481        Expression::In { expr, .. } => expression_contains_aggregate(expr),
482        Expression::Like { expr, pattern, .. } => {
483            expression_contains_aggregate(expr) || expression_contains_aggregate(pattern)
484        }
485        Expression::Position { substring, string, .. } => {
486            expression_contains_aggregate(substring) || expression_contains_aggregate(string)
487        }
488        Expression::Trim { removal_char, string, .. } => {
489            removal_char.as_ref().is_some_and(|e| expression_contains_aggregate(e))
490                || expression_contains_aggregate(string)
491        }
492        Expression::Extract { expr, .. } => expression_contains_aggregate(expr),
493        Expression::Interval { value, .. } => expression_contains_aggregate(value),
494        Expression::Conjunction(children) | Expression::Disjunction(children) => {
495            children.iter().any(expression_contains_aggregate)
496        }
497        // Subqueries have their own scope
498        Expression::ScalarSubquery(_) | Expression::Exists { .. } => false,
499        // Window functions are not aggregates in this context
500        Expression::WindowFunction { .. } => false,
501        // Other expressions don't contain aggregates
502        _ => false,
503    }
504}
505
506/// Find the first window function in an expression
507///
508/// Window functions can only appear in SELECT list and ORDER BY clauses.
509/// This function finds window functions in expressions where they are not allowed
510/// (e.g., WHERE, HAVING, GROUP BY, aggregate function arguments).
511///
512/// Returns the function name if found, None otherwise.
513pub fn find_window_function_in_expression(expr: &Expression) -> Option<String> {
514    match expr {
515        Expression::WindowFunction { function, .. } => {
516            // Return the function name
517            Some(function.name())
518        }
519        Expression::AggregateFunction { args, order_by, filter, .. } => {
520            // Check aggregate function arguments for window functions
521            for arg in args {
522                if let Some(found) = find_window_function_in_expression(arg) {
523                    return Some(found);
524                }
525            }
526            // Check ORDER BY expressions
527            if let Some(order_items) = order_by {
528                for item in order_items {
529                    if let Some(found) = find_window_function_in_expression(&item.expr) {
530                        return Some(found);
531                    }
532                }
533            }
534            // Check FILTER clause
535            if let Some(filter_expr) = filter {
536                if let Some(found) = find_window_function_in_expression(filter_expr) {
537                    return Some(found);
538                }
539            }
540            None
541        }
542        Expression::Function { args, .. } => {
543            for arg in args {
544                if let Some(found) = find_window_function_in_expression(arg) {
545                    return Some(found);
546                }
547            }
548            None
549        }
550        Expression::BinaryOp { left, right, .. } => {
551            find_window_function_in_expression(left)
552                .or_else(|| find_window_function_in_expression(right))
553        }
554        Expression::UnaryOp { expr, .. } => find_window_function_in_expression(expr),
555        Expression::Case { operand, when_clauses, else_result } => {
556            if let Some(op) = operand {
557                if let Some(found) = find_window_function_in_expression(op) {
558                    return Some(found);
559                }
560            }
561            for case_when in when_clauses {
562                for cond in &case_when.conditions {
563                    if let Some(found) = find_window_function_in_expression(cond) {
564                        return Some(found);
565                    }
566                }
567                if let Some(found) = find_window_function_in_expression(&case_when.result) {
568                    return Some(found);
569                }
570            }
571            if let Some(else_expr) = else_result {
572                find_window_function_in_expression(else_expr)
573            } else {
574                None
575            }
576        }
577        Expression::IsNull { expr, .. } => find_window_function_in_expression(expr),
578        Expression::IsDistinctFrom { left, right, .. } => {
579            find_window_function_in_expression(left)
580                .or_else(|| find_window_function_in_expression(right))
581        }
582        Expression::IsTruthValue { expr, .. } => find_window_function_in_expression(expr),
583        Expression::Between { expr, low, high, .. } => find_window_function_in_expression(expr)
584            .or_else(|| find_window_function_in_expression(low))
585            .or_else(|| find_window_function_in_expression(high)),
586        Expression::InList { expr, values, .. } => {
587            if let Some(found) = find_window_function_in_expression(expr) {
588                return Some(found);
589            }
590            for val in values {
591                if let Some(found) = find_window_function_in_expression(val) {
592                    return Some(found);
593                }
594            }
595            None
596        }
597        Expression::In { expr, .. } => find_window_function_in_expression(expr),
598        Expression::Exists { .. } => None, // EXISTS subqueries have their own scope
599        Expression::Cast { expr, .. } => find_window_function_in_expression(expr),
600        Expression::Like { expr, pattern, .. } => find_window_function_in_expression(expr)
601            .or_else(|| find_window_function_in_expression(pattern)),
602        Expression::Position { substring, string, .. } => {
603            find_window_function_in_expression(substring)
604                .or_else(|| find_window_function_in_expression(string))
605        }
606        Expression::Trim { removal_char, string, .. } => {
607            if let Some(char_expr) = removal_char {
608                if let Some(found) = find_window_function_in_expression(char_expr) {
609                    return Some(found);
610                }
611            }
612            find_window_function_in_expression(string)
613        }
614        Expression::Extract { expr, .. } => find_window_function_in_expression(expr),
615        Expression::ScalarSubquery(_) => None, // Subqueries have their own scope
616        Expression::QuantifiedComparison { expr, .. } => find_window_function_in_expression(expr),
617        Expression::Interval { value, .. } => find_window_function_in_expression(value),
618        Expression::MatchAgainst { search_modifier, .. } => {
619            find_window_function_in_expression(search_modifier)
620        }
621        Expression::Conjunction(children) | Expression::Disjunction(children) => {
622            for child in children {
623                if let Some(found) = find_window_function_in_expression(child) {
624                    return Some(found);
625                }
626            }
627            None
628        }
629        _ => None,
630    }
631}
632
633/// Check for misuse of aliased aggregates in HAVING clause
634///
635/// SQLite error: When an aggregate alias (e.g., `m` from `min(f1) AS m`) is used
636/// inside another aggregate function in the HAVING clause (e.g., `HAVING max(m) < 10`),
637/// it's an error. This function detects such misuse.
638///
639/// The `schema_columns` set contains column names from the actual table schema.
640/// If a column reference matches an actual table column, it's NOT a reference to an alias,
641/// even if an alias with the same name exists.
642///
643/// Returns Some(alias_name) if misuse is found, None otherwise.
644fn find_aliased_aggregate_misuse_in_expression(
645    expr: &Expression,
646    aggregate_aliases: &HashSet<String>,
647    schema_columns: &HashSet<String>,
648    inside_aggregate: bool,
649) -> Option<String> {
650    match expr {
651        // Check if this is an aggregate function - if so, mark that we're inside one
652        Expression::AggregateFunction { args, .. } => {
653            for arg in args {
654                if let Some(alias) = find_aliased_aggregate_misuse_in_expression(
655                    arg,
656                    aggregate_aliases,
657                    schema_columns,
658                    true,
659                ) {
660                    return Some(alias);
661                }
662            }
663            None
664        }
665        Expression::Function { name, args, .. } => {
666            // Check if this function is a built-in aggregate
667            let is_agg = is_aggregate_function(name.as_str());
668            let upper = name.to_uppercase();
669            // Multi-arg MIN/MAX are scalar functions
670            let effectively_aggregate =
671                is_agg && !(matches!(upper.as_str(), "MIN" | "MAX") && args.len() > 1);
672
673            let new_inside_aggregate = inside_aggregate || effectively_aggregate;
674
675            for arg in args {
676                if let Some(alias) = find_aliased_aggregate_misuse_in_expression(
677                    arg,
678                    aggregate_aliases,
679                    schema_columns,
680                    new_inside_aggregate,
681                ) {
682                    return Some(alias);
683                }
684            }
685            None
686        }
687        // Column reference - check if it's an aggregate alias used inside an aggregate
688        Expression::ColumnRef(col_id)
689            if col_id.schema_canonical().is_none() && col_id.table_canonical().is_none() =>
690        {
691            let column = col_id.column_canonical();
692            // If this column exists in the actual table schema, it's a real column reference,
693            // not a reference to a SELECT alias (even if an alias with the same name exists).
694            // Table columns take precedence over aliases in HAVING clause.
695            if schema_columns.contains(&column.to_lowercase()) {
696                return None; // Real column, not an alias reference
697            }
698
699            if inside_aggregate && aggregate_aliases.contains(&column.to_lowercase()) {
700                // Found misuse: aggregate alias used inside another aggregate
701                Some(column.to_string())
702            } else {
703                None
704            }
705        }
706        Expression::ColumnRef(_) => None, // Qualified refs can't be aliases
707        // Recursively check composite expressions
708        Expression::BinaryOp { left, right, .. } => find_aliased_aggregate_misuse_in_expression(
709            left,
710            aggregate_aliases,
711            schema_columns,
712            inside_aggregate,
713        )
714        .or_else(|| {
715            find_aliased_aggregate_misuse_in_expression(
716                right,
717                aggregate_aliases,
718                schema_columns,
719                inside_aggregate,
720            )
721        }),
722        Expression::UnaryOp { expr, .. } => find_aliased_aggregate_misuse_in_expression(
723            expr,
724            aggregate_aliases,
725            schema_columns,
726            inside_aggregate,
727        ),
728        Expression::Cast { expr, .. } => find_aliased_aggregate_misuse_in_expression(
729            expr,
730            aggregate_aliases,
731            schema_columns,
732            inside_aggregate,
733        ),
734        Expression::Case { operand, when_clauses, else_result } => {
735            if let Some(op) = operand {
736                if let Some(alias) = find_aliased_aggregate_misuse_in_expression(
737                    op,
738                    aggregate_aliases,
739                    schema_columns,
740                    inside_aggregate,
741                ) {
742                    return Some(alias);
743                }
744            }
745            for when_clause in when_clauses {
746                for cond in &when_clause.conditions {
747                    if let Some(alias) = find_aliased_aggregate_misuse_in_expression(
748                        cond,
749                        aggregate_aliases,
750                        schema_columns,
751                        inside_aggregate,
752                    ) {
753                        return Some(alias);
754                    }
755                }
756                if let Some(alias) = find_aliased_aggregate_misuse_in_expression(
757                    &when_clause.result,
758                    aggregate_aliases,
759                    schema_columns,
760                    inside_aggregate,
761                ) {
762                    return Some(alias);
763                }
764            }
765            if let Some(else_expr) = else_result {
766                return find_aliased_aggregate_misuse_in_expression(
767                    else_expr,
768                    aggregate_aliases,
769                    schema_columns,
770                    inside_aggregate,
771                );
772            }
773            None
774        }
775        Expression::IsNull { expr, .. } => find_aliased_aggregate_misuse_in_expression(
776            expr,
777            aggregate_aliases,
778            schema_columns,
779            inside_aggregate,
780        ),
781        Expression::Between { expr, low, high, .. } => find_aliased_aggregate_misuse_in_expression(
782            expr,
783            aggregate_aliases,
784            schema_columns,
785            inside_aggregate,
786        )
787        .or_else(|| {
788            find_aliased_aggregate_misuse_in_expression(
789                low,
790                aggregate_aliases,
791                schema_columns,
792                inside_aggregate,
793            )
794        })
795        .or_else(|| {
796            find_aliased_aggregate_misuse_in_expression(
797                high,
798                aggregate_aliases,
799                schema_columns,
800                inside_aggregate,
801            )
802        }),
803        Expression::InList { expr, values, .. } => {
804            if let Some(alias) = find_aliased_aggregate_misuse_in_expression(
805                expr,
806                aggregate_aliases,
807                schema_columns,
808                inside_aggregate,
809            ) {
810                return Some(alias);
811            }
812            for val in values {
813                if let Some(alias) = find_aliased_aggregate_misuse_in_expression(
814                    val,
815                    aggregate_aliases,
816                    schema_columns,
817                    inside_aggregate,
818                ) {
819                    return Some(alias);
820                }
821            }
822            None
823        }
824        Expression::In { expr, .. } => find_aliased_aggregate_misuse_in_expression(
825            expr,
826            aggregate_aliases,
827            schema_columns,
828            inside_aggregate,
829        ),
830        Expression::Like { expr, pattern, .. } => find_aliased_aggregate_misuse_in_expression(
831            expr,
832            aggregate_aliases,
833            schema_columns,
834            inside_aggregate,
835        )
836        .or_else(|| {
837            find_aliased_aggregate_misuse_in_expression(
838                pattern,
839                aggregate_aliases,
840                schema_columns,
841                inside_aggregate,
842            )
843        }),
844        Expression::Position { substring, string, .. } => {
845            find_aliased_aggregate_misuse_in_expression(
846                substring,
847                aggregate_aliases,
848                schema_columns,
849                inside_aggregate,
850            )
851            .or_else(|| {
852                find_aliased_aggregate_misuse_in_expression(
853                    string,
854                    aggregate_aliases,
855                    schema_columns,
856                    inside_aggregate,
857                )
858            })
859        }
860        Expression::Trim { removal_char, string, .. } => {
861            if let Some(rc) = removal_char {
862                if let Some(alias) = find_aliased_aggregate_misuse_in_expression(
863                    rc,
864                    aggregate_aliases,
865                    schema_columns,
866                    inside_aggregate,
867                ) {
868                    return Some(alias);
869                }
870            }
871            find_aliased_aggregate_misuse_in_expression(
872                string,
873                aggregate_aliases,
874                schema_columns,
875                inside_aggregate,
876            )
877        }
878        Expression::Extract { expr, .. } => find_aliased_aggregate_misuse_in_expression(
879            expr,
880            aggregate_aliases,
881            schema_columns,
882            inside_aggregate,
883        ),
884        Expression::Interval { value, .. } => find_aliased_aggregate_misuse_in_expression(
885            value,
886            aggregate_aliases,
887            schema_columns,
888            inside_aggregate,
889        ),
890        Expression::Conjunction(children) | Expression::Disjunction(children) => {
891            for child in children {
892                if let Some(alias) = find_aliased_aggregate_misuse_in_expression(
893                    child,
894                    aggregate_aliases,
895                    schema_columns,
896                    inside_aggregate,
897                ) {
898                    return Some(alias);
899                }
900            }
901            None
902        }
903        // Subqueries have their own scope
904        Expression::ScalarSubquery(_) | Expression::Exists { .. } => None,
905        Expression::QuantifiedComparison { expr, .. } => {
906            find_aliased_aggregate_misuse_in_expression(
907                expr,
908                aggregate_aliases,
909                schema_columns,
910                inside_aggregate,
911            )
912        }
913        Expression::IsDistinctFrom { left, right, .. } => {
914            find_aliased_aggregate_misuse_in_expression(
915                left,
916                aggregate_aliases,
917                schema_columns,
918                inside_aggregate,
919            )
920            .or_else(|| {
921                find_aliased_aggregate_misuse_in_expression(
922                    right,
923                    aggregate_aliases,
924                    schema_columns,
925                    inside_aggregate,
926                )
927            })
928        }
929        Expression::IsTruthValue { expr, .. } => find_aliased_aggregate_misuse_in_expression(
930            expr,
931            aggregate_aliases,
932            schema_columns,
933            inside_aggregate,
934        ),
935        // Other expressions don't contain column refs that could be aggregate aliases
936        _ => None,
937    }
938}
939
940/// Validate HAVING clause for misuse of aliased aggregates
941///
942/// This should be called after building the aggregate aliases from the SELECT list.
943/// Returns an error if an aggregate alias is used inside another aggregate in HAVING.
944///
945/// The `schema` parameter provides the actual table columns. If a column reference
946/// matches an actual table column, it's not considered an alias reference, even if
947/// an alias with the same name exists in the SELECT list.
948pub fn validate_having_aliased_aggregates(
949    having_clause: Option<&Expression>,
950    select_list: &[SelectItem],
951    schema: &CombinedSchema,
952) -> Result<(), ExecutorError> {
953    let Some(having_expr) = having_clause else {
954        return Ok(());
955    };
956
957    // Build the set of aggregate aliases
958    let aggregate_aliases = build_aggregate_aliases(select_list);
959
960    if aggregate_aliases.is_empty() {
961        return Ok(()); // No aggregate aliases to check
962    }
963
964    // Build the set of actual table column names (lowercase for case-insensitive matching)
965    let schema_columns: HashSet<String> = schema
966        .table_schemas
967        .values()
968        .flat_map(|(_, table_schema)| table_schema.columns.iter().map(|c| c.name.to_lowercase()))
969        .collect();
970
971    // Check for misuse in HAVING clause
972    if let Some(alias_name) = find_aliased_aggregate_misuse_in_expression(
973        having_expr,
974        &aggregate_aliases,
975        &schema_columns,
976        false,
977    ) {
978        return Err(ExecutorError::MisuseOfAliasedAggregate { alias_name });
979    }
980
981    Ok(())
982}
983
#[cfg(test)]
mod tests {
    use vibesql_ast::{BinaryOperator, ColumnIdentifier, FunctionIdentifier, UnaryOperator};
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_types::{DataType, SqlValue};

    use super::*;

    // ---------------------------------------------------------------------------------------
    // Fixture helpers. Each AST/schema literal is written once, so a new field on these
    // types only needs updating here.
    // ---------------------------------------------------------------------------------------

    /// Build a nullable INTEGER column with no default, generated expression, or collation.
    fn int_column(name: &str) -> ColumnSchema {
        ColumnSchema {
            name: name.to_string(),
            data_type: DataType::Integer,
            nullable: true,
            default_value: None,
            generated_expr: None,
            is_exact_integer_type: false,
            collation: None,
        }
    }

    /// Build a single-table schema named `table` whose columns are all nullable INTEGERs.
    fn make_int_schema(table: &str, column_names: &[&str]) -> CombinedSchema {
        let columns: Vec<ColumnSchema> = column_names.iter().copied().map(int_column).collect();
        let table_schema = TableSchema::new(table.to_string(), columns);
        CombinedSchema::from_table(table.to_string(), table_schema)
    }

    /// Create a schema with f1 and f2 columns (for aliased aggregate tests)
    fn make_f1_f2_schema() -> CombinedSchema {
        make_int_schema("test1", &["f1", "f2"])
    }

    /// Unqualified column reference to `name`.
    fn col(name: &str) -> Expression {
        Expression::ColumnRef(ColumnIdentifier::simple(name, false))
    }

    /// Literal expression wrapping `value`.
    fn lit(value: SqlValue) -> Expression {
        Expression::Literal(value)
    }

    /// Non-DISTINCT aggregate call `name(args...)` with no ORDER BY and no FILTER.
    fn agg(name: &str, args: Vec<Expression>) -> Expression {
        Expression::AggregateFunction {
            name: FunctionIdentifier::new(name),
            distinct: false,
            args,
            order_by: None,
            filter: None,
        }
    }

    /// Binary operation `left <op> right`.
    fn binop(op: BinaryOperator, left: Expression, right: Expression) -> Expression {
        Expression::BinaryOp { op, left: Box::new(left), right: Box::new(right) }
    }

    /// Unary minus applied to `expr`.
    fn neg(expr: Expression) -> Expression {
        Expression::UnaryOp { op: UnaryOperator::Minus, expr: Box::new(expr) }
    }

    /// SELECT-list item `expr [AS alias]`.
    fn select_item(expr: Expression, alias: Option<&str>) -> SelectItem {
        SelectItem::Expression { expr, alias: alias.map(str::to_string), source_text: None }
    }

    // ---------------------------------------------------------------------------------------
    // Argument-count validation
    // ---------------------------------------------------------------------------------------

    #[test]
    fn test_min_star_invalid() {
        // MIN(*) is invalid; the reported function name keeps its original case.
        let expr = agg("MIN", vec![col("*")]);
        let result = check_aggregate_arg_count(&expr);
        assert_eq!(result.as_deref(), Some("MIN"), "MIN(*) should be invalid");
    }

    #[test]
    fn test_max_star_invalid() {
        // MAX(*) is invalid; the reported function name keeps its original case.
        let expr = agg("MAX", vec![col("*")]);
        let result = check_aggregate_arg_count(&expr);
        assert_eq!(result.as_deref(), Some("MAX"), "MAX(*) should be invalid");
    }

    #[test]
    fn test_min_no_args_invalid() {
        // MIN() with no arguments is invalid.
        let expr = agg("MIN", vec![]);
        let result = check_aggregate_arg_count(&expr);
        assert_eq!(result.as_deref(), Some("MIN"), "MIN() should be invalid");
    }

    #[test]
    fn test_validate_aggregate_arguments() {
        // The public entry point surfaces the MIN(*) error.
        let select_list = vec![select_item(agg("MIN", vec![col("*")]), None)];
        let result = validate_aggregate_arguments(&select_list);
        assert!(result.is_err());
    }

    // ---------------------------------------------------------------------------------------
    // Aliased-aggregate misuse in HAVING
    // ---------------------------------------------------------------------------------------

    #[test]
    fn test_having_with_aliased_aggregate_inside_aggregate() {
        // SELECT min(f1) AS m FROM test1 GROUP BY f1 HAVING max(m+5)<10
        // 'm' names an aggregate and is used inside max(), so this is an error.
        // 'm' is NOT a column of test1, so it resolves to the alias.
        let select_list = vec![select_item(agg("min", vec![col("f1")]), Some("m"))];

        let m_plus_5 = binop(BinaryOperator::Plus, col("m"), lit(SqlValue::Integer(5)));
        let having_expr = binop(
            BinaryOperator::LessThan,
            agg("max", vec![m_plus_5]),
            lit(SqlValue::Integer(10)),
        );

        let schema = make_f1_f2_schema();
        let result = validate_having_aliased_aggregates(Some(&having_expr), &select_list, &schema);
        match result {
            Err(ExecutorError::MisuseOfAliasedAggregate { alias_name }) => {
                assert_eq!(alias_name, "m");
            }
            other => panic!("Expected MisuseOfAliasedAggregate error, got {:?}", other),
        }
    }

    #[test]
    fn test_having_with_aggregate_alias_not_inside_aggregate() {
        // SELECT min(f1) AS m FROM test1 GROUP BY f1 HAVING m>0
        // 'm' names an aggregate but is not nested inside another aggregate. This validator
        // only flags the nested case, so it must accept this clause.
        // NOTE(review): an earlier comment claimed SQLite rejects this form as well. SQLite's
        // resolver appears to accept aggregate aliases in HAVING outside another aggregate;
        // confirm against SQLite before relying on either reading.
        let select_list = vec![select_item(agg("min", vec![col("f1")]), Some("m"))];
        let having_expr = binop(BinaryOperator::GreaterThan, col("m"), lit(SqlValue::Integer(0)));

        let schema = make_f1_f2_schema();
        let result = validate_having_aliased_aggregates(Some(&having_expr), &select_list, &schema);
        assert!(result.is_ok(), "Expected Ok but got {:?}", result);
    }

    #[test]
    fn test_having_without_aggregate_alias() {
        // SELECT count(*) FROM test1 GROUP BY f1 HAVING f1>0
        // The SELECT list has no aliases at all, so this passes.
        let select_list = vec![select_item(agg("count", vec![Expression::Wildcard]), None)];
        let having_expr =
            binop(BinaryOperator::GreaterThan, col("f1"), lit(SqlValue::Integer(0)));

        let schema = make_f1_f2_schema();
        let result = validate_having_aliased_aggregates(Some(&having_expr), &select_list, &schema);
        assert!(result.is_ok(), "Expected Ok but got {:?}", result);
    }

    #[test]
    fn test_having_with_non_aggregate_alias() {
        // SELECT f1 AS x, count(*) FROM test1 GROUP BY f1 HAVING max(x)<10
        // 'x' aliases the plain column f1, not an aggregate, so this passes.
        let select_list = vec![
            select_item(col("f1"), Some("x")),
            select_item(agg("count", vec![Expression::Wildcard]), None),
        ];
        let having_expr = binop(
            BinaryOperator::LessThan,
            agg("max", vec![col("x")]),
            lit(SqlValue::Integer(10)),
        );

        let schema = make_f1_f2_schema();
        let result = validate_having_aliased_aggregates(Some(&having_expr), &select_list, &schema);
        assert!(result.is_ok(), "Expected Ok but got {:?}", result);
    }

    #[test]
    fn test_having_alias_shadows_column_uses_column() {
        // SELECT - col2 * - AVG(-col2) AS col0 FROM tab0 GROUP BY col2 HAVING AVG(col0) IS NULL
        // The alias 'col0' names an aggregate expression, but 'col0' is also a real column
        // of tab0. HAVING therefore refers to the column, not the alias, and this is NOT an
        // error.
        let select_expr = binop(
            BinaryOperator::Multiply,
            neg(col("col2")),
            neg(agg("AVG", vec![neg(col("col2"))])),
        );
        // The alias deliberately matches a column name.
        let select_list = vec![select_item(select_expr, Some("col0"))];

        // HAVING AVG(col0) IS NULL: col0 is the real column, not the alias.
        let having_expr = Expression::IsNull {
            expr: Box::new(agg("AVG", vec![col("col0")])),
            negated: false,
        };

        // col0 exists as an actual column here.
        let schema = make_int_schema("tab0", &["col0", "col1", "col2"]);

        let result = validate_having_aliased_aggregates(Some(&having_expr), &select_list, &schema);
        assert!(result.is_ok(), "Expected Ok but got {:?}", result);
    }

    #[test]
    fn test_having_absent_is_ok() {
        // With no HAVING clause there is nothing to validate, even if aggregate aliases exist.
        let select_list = vec![select_item(agg("min", vec![col("f1")]), Some("m"))];
        let schema = make_f1_f2_schema();
        let result = validate_having_aliased_aggregates(None, &select_list, &schema);
        assert!(result.is_ok(), "Expected Ok but got {:?}", result);
    }

    #[test]
    fn test_having_alias_inside_multi_arg_max_is_scalar() {
        // SELECT min(f1) AS m FROM test1 GROUP BY f1 HAVING max(m, 1)<10
        // Multi-argument MAX is the scalar function, not an aggregate, so 'm' is not nested
        // inside an aggregate and is not flagged.
        let select_list = vec![select_item(agg("min", vec![col("f1")]), Some("m"))];
        let scalar_max = Expression::Function {
            name: FunctionIdentifier::new("max"),
            args: vec![col("m"), lit(SqlValue::Integer(1))],
            character_unit: None,
        };
        let having_expr =
            binop(BinaryOperator::LessThan, scalar_max, lit(SqlValue::Integer(10)));

        let schema = make_f1_f2_schema();
        let result = validate_having_aliased_aggregates(Some(&having_expr), &select_list, &schema);
        assert!(result.is_ok(), "Expected Ok but got {:?}", result);
    }

    #[test]
    fn test_having_alias_inside_single_arg_max_function_is_misuse() {
        // Same query shape, but max(m) has one argument. Even in the generic
        // `Expression::Function` form it is an aggregate, so the nested alias is a misuse.
        let select_list = vec![select_item(agg("min", vec![col("f1")]), Some("m"))];
        let aggregate_max = Expression::Function {
            name: FunctionIdentifier::new("max"),
            args: vec![col("m")],
            character_unit: None,
        };
        let having_expr =
            binop(BinaryOperator::LessThan, aggregate_max, lit(SqlValue::Integer(10)));

        let schema = make_f1_f2_schema();
        let result = validate_having_aliased_aggregates(Some(&having_expr), &select_list, &schema);
        match result {
            Err(ExecutorError::MisuseOfAliasedAggregate { alias_name }) => {
                assert_eq!(alias_name, "m");
            }
            other => panic!("Expected MisuseOfAliasedAggregate error, got {:?}", other),
        }
    }

    // ---------------------------------------------------------------------------------------
    // Alias collection
    // ---------------------------------------------------------------------------------------

    #[test]
    fn test_build_aggregate_aliases() {
        // coalesce(min(f1)+5, 11): an aggregate nested inside a scalar function.
        let min_plus_5 = binop(
            BinaryOperator::Plus,
            agg("min", vec![col("f1")]),
            lit(SqlValue::Integer(5)),
        );
        let coalesce = Expression::Function {
            name: FunctionIdentifier::new("coalesce"),
            args: vec![min_plus_5, lit(SqlValue::Integer(11))],
            character_unit: None,
        };

        let select_list = vec![
            select_item(agg("min", vec![col("f1")]), Some("m")),
            select_item(col("f2"), Some("col2")),
            select_item(coalesce, Some("m2")),
        ];

        let aliases = build_aggregate_aliases(&select_list);
        assert!(aliases.contains("m")); // min(f1) AS m is an aggregate alias
        assert!(!aliases.contains("col2")); // f2 AS col2 is NOT an aggregate alias
        assert!(aliases.contains("m2")); // coalesce(min(f1)+5, 11) contains an aggregate
    }
}