vibesql_executor/evaluator/window/
aggregates.rs

1//! Aggregate window functions
2//!
3//! Implements COUNT, SUM, AVG, MIN, MAX with frame support.
4//! SQL:2003 FILTER clause support for conditional aggregation in window functions.
5//! SQL:2011 EXCLUDE clause support for frame exclusion.
6
7use std::cmp::Ordering;
8
9use vibesql_ast::Expression;
10use vibesql_storage::Row;
11use vibesql_types::SqlValue;
12
13use super::{partitioning::Partition, sorting::compare_values};
14
15/// Helper function to check if a row passes the FILTER condition
16/// Returns true if there's no filter, or if the filter evaluates to TRUE
17#[inline]
18fn passes_filter<F>(filter: Option<&Expression>, row: &Row, eval_fn: &F) -> bool
19where
20    F: Fn(&Expression, &Row) -> Result<SqlValue, String>,
21{
22    if let Some(filter_expr) = filter {
23        if let Ok(filter_result) = eval_fn(filter_expr, row) {
24            matches!(filter_result, SqlValue::Boolean(true))
25        } else {
26            false // Evaluation error = skip row
27        }
28    } else {
29        true // No filter = include all rows
30    }
31}
32
33/// Evaluate COUNT aggregate window function over a frame
34///
35/// Counts rows in the frame. Two variants:
36/// - COUNT(*): counts all rows in frame
37/// - COUNT(expr): counts rows where expr is not NULL
38///
39/// Supports FILTER clause for conditional aggregation.
40/// Supports EXCLUDE clause via frame_indices iterator.
41///
42/// Example: COUNT(*) FILTER (WHERE x > 0) OVER (ROWS BETWEEN 2 PRECEDING AND CURRENT ROW)
43pub fn evaluate_count_window<F, I>(
44    partition: &Partition,
45    frame_indices: I,
46    arg_expr: Option<&Expression>,
47    filter: Option<&Expression>,
48    eval_fn: F,
49) -> SqlValue
50where
51    F: Fn(&Expression, &Row) -> Result<SqlValue, String>,
52    I: IntoIterator<Item = usize>,
53{
54    let mut count = 0i64;
55
56    for idx in frame_indices {
57        if idx >= partition.len() {
58            continue;
59        }
60
61        let row = &partition.rows[idx];
62
63        // Check FILTER condition first
64        if !passes_filter(filter, row, &eval_fn) {
65            continue;
66        }
67
68        // COUNT(*) - count all rows
69        if arg_expr.is_none() {
70            count += 1;
71            continue;
72        }
73
74        // COUNT(expr) - count non-NULL values
75        if let Some(expr) = arg_expr {
76            if let Ok(val) = eval_fn(expr, row) {
77                if !matches!(val, SqlValue::Null) {
78                    count += 1;
79                }
80            }
81        }
82    }
83
84    SqlValue::Integer(count)
85}
86
87/// Evaluate SUM aggregate window function over a frame
88///
89/// Sums numeric values in the frame, ignoring NULLs.
90/// Returns NULL if all values are NULL or frame is empty.
91/// Supports FILTER clause for conditional aggregation.
92/// Supports EXCLUDE clause via frame_indices iterator.
93///
94/// Example: SUM(amount) FILTER (WHERE status = 'paid') OVER (ORDER BY date) for filtered running totals
95pub fn evaluate_sum_window<F, I>(
96    partition: &Partition,
97    frame_indices: I,
98    arg_expr: &Expression,
99    filter: Option<&Expression>,
100    eval_fn: F,
101) -> SqlValue
102where
103    F: Fn(&Expression, &Row) -> Result<SqlValue, String>,
104    I: IntoIterator<Item = usize>,
105{
106    let mut sum = 0.0f64;
107    let mut has_value = false;
108
109    for idx in frame_indices {
110        if idx >= partition.len() {
111            continue;
112        }
113
114        let row = &partition.rows[idx];
115
116        // Check FILTER condition first
117        if !passes_filter(filter, row, &eval_fn) {
118            continue;
119        }
120
121        if let Ok(val) = eval_fn(arg_expr, row) {
122            match val {
123                SqlValue::Integer(n) => {
124                    sum += n as f64;
125                    has_value = true;
126                }
127                SqlValue::Smallint(n) => {
128                    sum += n as f64;
129                    has_value = true;
130                }
131                SqlValue::Bigint(n) => {
132                    sum += n as f64;
133                    has_value = true;
134                }
135                SqlValue::Numeric(n) => {
136                    sum += n;
137                    has_value = true;
138                }
139                SqlValue::Float(n) => {
140                    sum += n as f64;
141                    has_value = true;
142                }
143                SqlValue::Real(n) => {
144                    sum += n as f64;
145                    has_value = true;
146                }
147                SqlValue::Double(n) => {
148                    sum += n;
149                    has_value = true;
150                }
151                SqlValue::Null => {} // Ignore NULL
152                _ => {}              // Ignore non-numeric values
153            }
154        }
155    }
156
157    if has_value {
158        // Return Integer if sum is a whole number, otherwise Numeric
159        if sum.fract() == 0.0 && sum >= i64::MIN as f64 && sum <= i64::MAX as f64 {
160            SqlValue::Integer(sum as i64)
161        } else {
162            SqlValue::Numeric(sum)
163        }
164    } else {
165        SqlValue::Null
166    }
167}
168
169/// Evaluate AVG aggregate window function over a frame
170///
171/// Computes average of numeric values in the frame, ignoring NULLs.
172/// Returns NULL if all values are NULL or frame is empty.
173/// Supports FILTER clause for conditional aggregation.
174/// Supports EXCLUDE clause via frame_indices iterator.
175///
176/// Example: AVG(temperature) FILTER (WHERE valid = 1) OVER (ORDER BY date ROWS BETWEEN 6 PRECEDING AND CURRENT ROW)
177/// for 7-day moving average of valid readings
178pub fn evaluate_avg_window<F, I>(
179    partition: &Partition,
180    frame_indices: I,
181    arg_expr: &Expression,
182    filter: Option<&Expression>,
183    eval_fn: F,
184) -> SqlValue
185where
186    F: Fn(&Expression, &Row) -> Result<SqlValue, String>,
187    I: IntoIterator<Item = usize>,
188{
189    let mut sum = 0.0f64;
190    let mut count = 0i64;
191
192    for idx in frame_indices {
193        if idx >= partition.len() {
194            continue;
195        }
196
197        let row = &partition.rows[idx];
198
199        // Check FILTER condition first
200        if !passes_filter(filter, row, &eval_fn) {
201            continue;
202        }
203
204        if let Ok(val) = eval_fn(arg_expr, row) {
205            match val {
206                SqlValue::Integer(n) => {
207                    sum += n as f64;
208                    count += 1;
209                }
210                SqlValue::Smallint(n) => {
211                    sum += n as f64;
212                    count += 1;
213                }
214                SqlValue::Bigint(n) => {
215                    sum += n as f64;
216                    count += 1;
217                }
218                SqlValue::Numeric(n) => {
219                    sum += n;
220                    count += 1;
221                }
222                SqlValue::Float(n) => {
223                    sum += n as f64;
224                    count += 1;
225                }
226                SqlValue::Real(n) => {
227                    sum += n as f64;
228                    count += 1;
229                }
230                SqlValue::Double(n) => {
231                    sum += n;
232                    count += 1;
233                }
234                SqlValue::Null => {} // Ignore NULL
235                _ => {}              // Ignore non-numeric values
236            }
237        }
238    }
239
240    if count > 0 {
241        SqlValue::Numeric(sum / count as f64)
242    } else {
243        SqlValue::Null
244    }
245}
246
247/// Evaluate MIN aggregate window function over a frame
248///
249/// Finds minimum value in the frame, ignoring NULLs.
250/// Returns NULL if all values are NULL or frame is empty.
251/// Supports FILTER clause for conditional aggregation.
252/// Supports EXCLUDE clause via frame_indices iterator.
253///
254/// Example: MIN(salary) FILTER (WHERE active = 1) OVER (PARTITION BY department)
255pub fn evaluate_min_window<F, I>(
256    partition: &Partition,
257    frame_indices: I,
258    arg_expr: &Expression,
259    filter: Option<&Expression>,
260    eval_fn: F,
261) -> SqlValue
262where
263    F: Fn(&Expression, &Row) -> Result<SqlValue, String>,
264    I: IntoIterator<Item = usize>,
265{
266    let mut min_val: Option<SqlValue> = None;
267
268    for idx in frame_indices {
269        if idx >= partition.len() {
270            continue;
271        }
272
273        let row = &partition.rows[idx];
274
275        // Check FILTER condition first
276        if !passes_filter(filter, row, &eval_fn) {
277            continue;
278        }
279
280        if let Ok(val) = eval_fn(arg_expr, row) {
281            if matches!(val, SqlValue::Null) {
282                continue; // Ignore NULL
283            }
284
285            if let Some(ref current_min) = min_val {
286                if compare_values(&val, current_min) == Ordering::Less {
287                    min_val = Some(val);
288                }
289            } else {
290                min_val = Some(val);
291            }
292        }
293    }
294
295    min_val.unwrap_or(SqlValue::Null)
296}
297
298/// Evaluate MAX aggregate window function over a frame
299///
300/// Finds maximum value in the frame, ignoring NULLs.
301/// Returns NULL if all values are NULL or frame is empty.
302/// Supports FILTER clause for conditional aggregation.
303/// Supports EXCLUDE clause via frame_indices iterator.
304///
305/// Example: MAX(salary) FILTER (WHERE active = 1) OVER (PARTITION BY department)
306pub fn evaluate_max_window<F, I>(
307    partition: &Partition,
308    frame_indices: I,
309    arg_expr: &Expression,
310    filter: Option<&Expression>,
311    eval_fn: F,
312) -> SqlValue
313where
314    F: Fn(&Expression, &Row) -> Result<SqlValue, String>,
315    I: IntoIterator<Item = usize>,
316{
317    let mut max_val: Option<SqlValue> = None;
318
319    for idx in frame_indices {
320        if idx >= partition.len() {
321            continue;
322        }
323
324        let row = &partition.rows[idx];
325
326        // Check FILTER condition first
327        if !passes_filter(filter, row, &eval_fn) {
328            continue;
329        }
330
331        if let Ok(val) = eval_fn(arg_expr, row) {
332            if matches!(val, SqlValue::Null) {
333                continue; // Ignore NULL
334            }
335
336            if let Some(ref current_max) = max_val {
337                if compare_values(&val, current_max) == Ordering::Greater {
338                    max_val = Some(val);
339                }
340            } else {
341                max_val = Some(val);
342            }
343        }
344    }
345
346    max_val.unwrap_or(SqlValue::Null)
347}
348
349/// Evaluate GROUP_CONCAT aggregate window function over a frame
350///
351/// Concatenates string values in the frame using the specified separator.
352/// Returns NULL if all values are NULL or frame is empty.
353/// Supports FILTER clause for conditional aggregation.
354/// Supports EXCLUDE clause via frame_indices iterator.
355///
356/// Example: GROUP_CONCAT(name, ',') FILTER (WHERE active = 1) OVER (ORDER BY date)
357pub fn evaluate_group_concat_window<F, I>(
358    partition: &Partition,
359    frame_indices: I,
360    arg_expr: &Expression,
361    separator: &str,
362    filter: Option<&Expression>,
363    eval_fn: F,
364) -> SqlValue
365where
366    F: Fn(&Expression, &Row) -> Result<SqlValue, String>,
367    I: IntoIterator<Item = usize>,
368{
369    let mut values: Vec<String> = Vec::new();
370
371    for idx in frame_indices {
372        if idx >= partition.len() {
373            continue;
374        }
375
376        let row = &partition.rows[idx];
377
378        // Check FILTER condition first
379        if !passes_filter(filter, row, &eval_fn) {
380            continue;
381        }
382
383        if let Ok(val) = eval_fn(arg_expr, row) {
384            match val {
385                SqlValue::Null => {} // Ignore NULL
386                SqlValue::Varchar(s) | SqlValue::Character(s) => {
387                    values.push(s.to_string());
388                }
389                SqlValue::Integer(n) => {
390                    values.push(n.to_string());
391                }
392                SqlValue::Bigint(n) => {
393                    values.push(n.to_string());
394                }
395                SqlValue::Smallint(n) => {
396                    values.push(n.to_string());
397                }
398                SqlValue::Numeric(n) => {
399                    // Format as integer if whole number
400                    if n.fract() == 0.0 {
401                        values.push((n as i64).to_string());
402                    } else {
403                        values.push(n.to_string());
404                    }
405                }
406                SqlValue::Float(n) => {
407                    if n.fract() == 0.0 {
408                        values.push((n as i64).to_string());
409                    } else {
410                        values.push(n.to_string());
411                    }
412                }
413                SqlValue::Real(n) => {
414                    if n.fract() == 0.0 {
415                        values.push((n as i64).to_string());
416                    } else {
417                        values.push(n.to_string());
418                    }
419                }
420                SqlValue::Double(n) => {
421                    if n.fract() == 0.0 {
422                        values.push((n as i64).to_string());
423                    } else {
424                        values.push(n.to_string());
425                    }
426                }
427                other => {
428                    // Convert other types to string
429                    values.push(format!("{}", other));
430                }
431            }
432        }
433    }
434
435    if values.is_empty() {
436        SqlValue::Null
437    } else {
438        SqlValue::Varchar(values.join(separator).into())
439    }
440}