vibesql_executor/evaluator/window/
aggregates.rs1use std::cmp::Ordering;
8
9use vibesql_ast::Expression;
10use vibesql_storage::Row;
11use vibesql_types::SqlValue;
12
13use super::{partitioning::Partition, sorting::compare_values};
14
15#[inline]
18fn passes_filter<F>(filter: Option<&Expression>, row: &Row, eval_fn: &F) -> bool
19where
20 F: Fn(&Expression, &Row) -> Result<SqlValue, String>,
21{
22 if let Some(filter_expr) = filter {
23 if let Ok(filter_result) = eval_fn(filter_expr, row) {
24 matches!(filter_result, SqlValue::Boolean(true))
25 } else {
26 false }
28 } else {
29 true }
31}
32
33pub fn evaluate_count_window<F, I>(
44 partition: &Partition,
45 frame_indices: I,
46 arg_expr: Option<&Expression>,
47 filter: Option<&Expression>,
48 eval_fn: F,
49) -> SqlValue
50where
51 F: Fn(&Expression, &Row) -> Result<SqlValue, String>,
52 I: IntoIterator<Item = usize>,
53{
54 let mut count = 0i64;
55
56 for idx in frame_indices {
57 if idx >= partition.len() {
58 continue;
59 }
60
61 let row = &partition.rows[idx];
62
63 if !passes_filter(filter, row, &eval_fn) {
65 continue;
66 }
67
68 if arg_expr.is_none() {
70 count += 1;
71 continue;
72 }
73
74 if let Some(expr) = arg_expr {
76 if let Ok(val) = eval_fn(expr, row) {
77 if !matches!(val, SqlValue::Null) {
78 count += 1;
79 }
80 }
81 }
82 }
83
84 SqlValue::Integer(count)
85}
86
87pub fn evaluate_sum_window<F, I>(
96 partition: &Partition,
97 frame_indices: I,
98 arg_expr: &Expression,
99 filter: Option<&Expression>,
100 eval_fn: F,
101) -> SqlValue
102where
103 F: Fn(&Expression, &Row) -> Result<SqlValue, String>,
104 I: IntoIterator<Item = usize>,
105{
106 let mut sum = 0.0f64;
107 let mut has_value = false;
108
109 for idx in frame_indices {
110 if idx >= partition.len() {
111 continue;
112 }
113
114 let row = &partition.rows[idx];
115
116 if !passes_filter(filter, row, &eval_fn) {
118 continue;
119 }
120
121 if let Ok(val) = eval_fn(arg_expr, row) {
122 match val {
123 SqlValue::Integer(n) => {
124 sum += n as f64;
125 has_value = true;
126 }
127 SqlValue::Smallint(n) => {
128 sum += n as f64;
129 has_value = true;
130 }
131 SqlValue::Bigint(n) => {
132 sum += n as f64;
133 has_value = true;
134 }
135 SqlValue::Numeric(n) => {
136 sum += n;
137 has_value = true;
138 }
139 SqlValue::Float(n) => {
140 sum += n as f64;
141 has_value = true;
142 }
143 SqlValue::Real(n) => {
144 sum += n as f64;
145 has_value = true;
146 }
147 SqlValue::Double(n) => {
148 sum += n;
149 has_value = true;
150 }
151 SqlValue::Null => {} _ => {} }
154 }
155 }
156
157 if has_value {
158 if sum.fract() == 0.0 && sum >= i64::MIN as f64 && sum <= i64::MAX as f64 {
160 SqlValue::Integer(sum as i64)
161 } else {
162 SqlValue::Numeric(sum)
163 }
164 } else {
165 SqlValue::Null
166 }
167}
168
169pub fn evaluate_avg_window<F, I>(
179 partition: &Partition,
180 frame_indices: I,
181 arg_expr: &Expression,
182 filter: Option<&Expression>,
183 eval_fn: F,
184) -> SqlValue
185where
186 F: Fn(&Expression, &Row) -> Result<SqlValue, String>,
187 I: IntoIterator<Item = usize>,
188{
189 let mut sum = 0.0f64;
190 let mut count = 0i64;
191
192 for idx in frame_indices {
193 if idx >= partition.len() {
194 continue;
195 }
196
197 let row = &partition.rows[idx];
198
199 if !passes_filter(filter, row, &eval_fn) {
201 continue;
202 }
203
204 if let Ok(val) = eval_fn(arg_expr, row) {
205 match val {
206 SqlValue::Integer(n) => {
207 sum += n as f64;
208 count += 1;
209 }
210 SqlValue::Smallint(n) => {
211 sum += n as f64;
212 count += 1;
213 }
214 SqlValue::Bigint(n) => {
215 sum += n as f64;
216 count += 1;
217 }
218 SqlValue::Numeric(n) => {
219 sum += n;
220 count += 1;
221 }
222 SqlValue::Float(n) => {
223 sum += n as f64;
224 count += 1;
225 }
226 SqlValue::Real(n) => {
227 sum += n as f64;
228 count += 1;
229 }
230 SqlValue::Double(n) => {
231 sum += n;
232 count += 1;
233 }
234 SqlValue::Null => {} _ => {} }
237 }
238 }
239
240 if count > 0 {
241 SqlValue::Numeric(sum / count as f64)
242 } else {
243 SqlValue::Null
244 }
245}
246
247pub fn evaluate_min_window<F, I>(
256 partition: &Partition,
257 frame_indices: I,
258 arg_expr: &Expression,
259 filter: Option<&Expression>,
260 eval_fn: F,
261) -> SqlValue
262where
263 F: Fn(&Expression, &Row) -> Result<SqlValue, String>,
264 I: IntoIterator<Item = usize>,
265{
266 let mut min_val: Option<SqlValue> = None;
267
268 for idx in frame_indices {
269 if idx >= partition.len() {
270 continue;
271 }
272
273 let row = &partition.rows[idx];
274
275 if !passes_filter(filter, row, &eval_fn) {
277 continue;
278 }
279
280 if let Ok(val) = eval_fn(arg_expr, row) {
281 if matches!(val, SqlValue::Null) {
282 continue; }
284
285 if let Some(ref current_min) = min_val {
286 if compare_values(&val, current_min) == Ordering::Less {
287 min_val = Some(val);
288 }
289 } else {
290 min_val = Some(val);
291 }
292 }
293 }
294
295 min_val.unwrap_or(SqlValue::Null)
296}
297
298pub fn evaluate_max_window<F, I>(
307 partition: &Partition,
308 frame_indices: I,
309 arg_expr: &Expression,
310 filter: Option<&Expression>,
311 eval_fn: F,
312) -> SqlValue
313where
314 F: Fn(&Expression, &Row) -> Result<SqlValue, String>,
315 I: IntoIterator<Item = usize>,
316{
317 let mut max_val: Option<SqlValue> = None;
318
319 for idx in frame_indices {
320 if idx >= partition.len() {
321 continue;
322 }
323
324 let row = &partition.rows[idx];
325
326 if !passes_filter(filter, row, &eval_fn) {
328 continue;
329 }
330
331 if let Ok(val) = eval_fn(arg_expr, row) {
332 if matches!(val, SqlValue::Null) {
333 continue; }
335
336 if let Some(ref current_max) = max_val {
337 if compare_values(&val, current_max) == Ordering::Greater {
338 max_val = Some(val);
339 }
340 } else {
341 max_val = Some(val);
342 }
343 }
344 }
345
346 max_val.unwrap_or(SqlValue::Null)
347}
348
349pub fn evaluate_group_concat_window<F, I>(
358 partition: &Partition,
359 frame_indices: I,
360 arg_expr: &Expression,
361 separator: &str,
362 filter: Option<&Expression>,
363 eval_fn: F,
364) -> SqlValue
365where
366 F: Fn(&Expression, &Row) -> Result<SqlValue, String>,
367 I: IntoIterator<Item = usize>,
368{
369 let mut values: Vec<String> = Vec::new();
370
371 for idx in frame_indices {
372 if idx >= partition.len() {
373 continue;
374 }
375
376 let row = &partition.rows[idx];
377
378 if !passes_filter(filter, row, &eval_fn) {
380 continue;
381 }
382
383 if let Ok(val) = eval_fn(arg_expr, row) {
384 match val {
385 SqlValue::Null => {} SqlValue::Varchar(s) | SqlValue::Character(s) => {
387 values.push(s.to_string());
388 }
389 SqlValue::Integer(n) => {
390 values.push(n.to_string());
391 }
392 SqlValue::Bigint(n) => {
393 values.push(n.to_string());
394 }
395 SqlValue::Smallint(n) => {
396 values.push(n.to_string());
397 }
398 SqlValue::Numeric(n) => {
399 if n.fract() == 0.0 {
401 values.push((n as i64).to_string());
402 } else {
403 values.push(n.to_string());
404 }
405 }
406 SqlValue::Float(n) => {
407 if n.fract() == 0.0 {
408 values.push((n as i64).to_string());
409 } else {
410 values.push(n.to_string());
411 }
412 }
413 SqlValue::Real(n) => {
414 if n.fract() == 0.0 {
415 values.push((n as i64).to_string());
416 } else {
417 values.push(n.to_string());
418 }
419 }
420 SqlValue::Double(n) => {
421 if n.fract() == 0.0 {
422 values.push((n as i64).to_string());
423 } else {
424 values.push(n.to_string());
425 }
426 }
427 other => {
428 values.push(format!("{}", other));
430 }
431 }
432 }
433 }
434
435 if values.is_empty() {
436 SqlValue::Null
437 } else {
438 SqlValue::Varchar(values.join(separator).into())
439 }
440}