1use std::cmp::Ordering;
7use std::ops::Range;
8
9use vibesql_ast::{
10 Expression, FrameBound, FrameExclude, FrameUnit, OrderByItem, UnaryOperator, WindowFrame,
11};
12use vibesql_types::SqlValue;
13
14use super::partitioning::Partition;
15use super::sorting::compare_values;
16use super::utils::evaluate_expression_with_map;
17
18pub fn validate_frame(frame_spec: &Option<WindowFrame>) -> Result<(), String> {
22 let frame = match frame_spec {
23 Some(f) => f,
24 None => return Ok(()),
25 };
26
27 validate_frame_bound(&frame.start, true)?;
29
30 if let Some(ref end_bound) = frame.end {
32 validate_frame_bound(end_bound, false)?;
33 }
34
35 Ok(())
36}
37
38fn validate_frame_bound(bound: &FrameBound, is_start: bool) -> Result<(), String> {
40 match bound {
41 FrameBound::UnboundedPreceding
42 | FrameBound::UnboundedFollowing
43 | FrameBound::CurrentRow => Ok(()),
44
45 FrameBound::Preceding(offset_expr) | FrameBound::Following(offset_expr) => {
46 let bound_type = if is_start { "starting" } else { "ending" };
47 match evaluate_offset_expr(offset_expr)? {
48 Some(offset) if offset < 0.0 => {
49 Err(format!("frame {} offset must be a non-negative number", bound_type))
50 }
51 None => {
52 Err(format!("frame {} offset must be a non-negative number", bound_type))
54 }
55 Some(_) => Ok(()), }
57 }
58 }
59}
60
61fn evaluate_offset_expr(expr: &Expression) -> Result<Option<f64>, String> {
66 match expr {
67 Expression::Literal(SqlValue::Integer(n)) => Ok(Some(*n as f64)),
68 Expression::Literal(SqlValue::Smallint(n)) => Ok(Some(*n as f64)),
69 Expression::Literal(SqlValue::Bigint(n)) => Ok(Some(*n as f64)),
70 Expression::Literal(SqlValue::Unsigned(n)) => Ok(Some(*n as f64)),
71 Expression::Literal(SqlValue::Float(f)) => Ok(Some(*f as f64)),
72 Expression::Literal(SqlValue::Real(f)) => Ok(Some(*f)),
73 Expression::Literal(SqlValue::Double(f)) => Ok(Some(*f)),
74 Expression::Literal(SqlValue::Numeric(f)) => Ok(Some(*f)),
75 Expression::Literal(SqlValue::Null) => {
76 Ok(None)
78 }
79 Expression::Literal(SqlValue::Character(s)) | Expression::Literal(SqlValue::Varchar(s)) => {
80 let s_str = s.as_str();
82 if s_str.is_empty() {
83 Ok(None)
85 } else if let Ok(n) = s_str.parse::<f64>() {
86 Ok(Some(n))
87 } else {
88 Ok(None)
90 }
91 }
92 Expression::Literal(SqlValue::Blob(_)) => {
93 Ok(None)
95 }
96 Expression::UnaryOp { op: UnaryOperator::Minus, expr } => {
97 match evaluate_offset_expr(expr)? {
98 Some(inner) => Ok(Some(-inner)),
99 None => Ok(None),
100 }
101 }
102 Expression::UnaryOp { op: UnaryOperator::Plus, expr } => evaluate_offset_expr(expr),
103 _ => Err("frame offset must be a constant expression".to_string()),
104 }
105}
106
107pub fn calculate_frame(
112 partition: &Partition,
113 current_row_idx: usize,
114 order_by: &Option<Vec<OrderByItem>>,
115 frame_spec: &Option<WindowFrame>,
116) -> Range<usize> {
117 let partition_size = partition.len();
118
119 let frame = match frame_spec {
123 Some(f) => f,
124 None => {
125 let has_order_by = order_by.as_ref().is_some_and(|items| !items.is_empty());
127
128 if has_order_by {
129 return calculate_range_frame(
132 partition,
133 current_row_idx,
134 order_by,
135 &FrameBound::UnboundedPreceding,
136 &Some(FrameBound::CurrentRow),
137 );
138 } else {
139 return 0..partition_size;
141 }
142 }
143 };
144
145 match frame.unit {
147 FrameUnit::Rows => {
148 calculate_rows_frame(partition, current_row_idx, &frame.start, &frame.end)
149 }
150 FrameUnit::Range => {
151 calculate_range_frame(partition, current_row_idx, order_by, &frame.start, &frame.end)
152 }
153 FrameUnit::Groups => {
154 calculate_groups_frame(partition, current_row_idx, order_by, &frame.start, &frame.end)
155 }
156 }
157}
158
159fn calculate_rows_frame(
161 partition: &Partition,
162 current_row_idx: usize,
163 start: &FrameBound,
164 end: &Option<FrameBound>,
165) -> Range<usize> {
166 let partition_size = partition.len();
167
168 let start_idx = calculate_rows_boundary(start, current_row_idx, partition_size, true);
170
171 let end_idx = match end {
173 Some(end_bound) => calculate_rows_boundary(end_bound, current_row_idx, partition_size, false),
174 None => current_row_idx + 1, };
176
177 let start = start_idx.min(partition_size);
179 let end = end_idx.min(partition_size).max(start);
180
181 start..end
182}
183
184fn calculate_range_frame(
192 partition: &Partition,
193 current_row_idx: usize,
194 order_by: &Option<Vec<OrderByItem>>,
195 start: &FrameBound,
196 end: &Option<FrameBound>,
197) -> Range<usize> {
198 let partition_size = partition.len();
199
200 if partition_size == 0 || current_row_idx >= partition_size {
201 return 0..0;
202 }
203
204 let order_items = match order_by {
206 Some(items) if !items.is_empty() => items,
207 _ => return 0..partition_size,
208 };
209
210 let current_row = &partition.rows[current_row_idx];
212 let current_value = evaluate_expression_with_map(&order_items[0].expr, current_row, &partition.column_map).unwrap_or(SqlValue::Null);
213
214 let start_idx = calculate_range_boundary(
216 partition,
217 current_row_idx,
218 order_items,
219 ¤t_value,
220 start,
221 true,
222 );
223
224 let end_idx = match end {
226 Some(end_bound) => calculate_range_boundary(
227 partition,
228 current_row_idx,
229 order_items,
230 ¤t_value,
231 end_bound,
232 false,
233 ),
234 None => {
235 find_last_peer(partition, current_row_idx, order_items) + 1
237 }
238 };
239
240 let start = start_idx.min(partition_size);
242 let end = end_idx.min(partition_size).max(start);
243
244 start..end
245}
246
247fn calculate_range_boundary(
253 partition: &Partition,
254 current_row_idx: usize,
255 order_items: &[OrderByItem],
256 current_value: &SqlValue,
257 bound: &FrameBound,
258 is_start: bool,
259) -> usize {
260 use vibesql_ast::OrderDirection;
261
262 let partition_size = partition.len();
263
264 let is_desc = order_items
266 .first()
267 .is_some_and(|item| matches!(item.direction, OrderDirection::Desc));
268
269 match bound {
270 FrameBound::UnboundedPreceding => 0,
271 FrameBound::UnboundedFollowing => partition_size,
272
273 FrameBound::CurrentRow => {
274 if is_start {
275 find_first_peer(partition, current_row_idx, order_items)
277 } else {
278 find_last_peer(partition, current_row_idx, order_items) + 1
280 }
281 }
282
283 FrameBound::Preceding(offset_expr) => {
284 let offset = get_numeric_offset(offset_expr);
285
286 let target_value = if is_desc {
289 add_to_value(current_value, offset)
290 } else {
291 subtract_from_value(current_value, offset)
292 };
293
294 if is_start {
295 if is_desc {
298 find_first_row_le_desc(partition, order_items, &target_value)
299 } else {
300 find_first_row_ge(partition, order_items, &target_value)
301 }
302 } else {
303 if is_desc {
305 find_last_row_ge_desc(partition, order_items, &target_value) + 1
306 } else {
307 find_last_row_le(partition, order_items, &target_value) + 1
308 }
309 }
310 }
311
312 FrameBound::Following(offset_expr) => {
313 let offset = get_numeric_offset(offset_expr);
314
315 let target_value = if is_desc {
318 subtract_from_value(current_value, offset)
319 } else {
320 add_to_value(current_value, offset)
321 };
322
323 if is_start {
324 if is_desc {
327 find_first_row_le_desc(partition, order_items, &target_value)
328 } else {
329 find_first_row_ge(partition, order_items, &target_value)
330 }
331 } else {
332 if is_desc {
334 find_last_row_ge_desc(partition, order_items, &target_value) + 1
335 } else {
336 find_last_row_le(partition, order_items, &target_value) + 1
337 }
338 }
339 }
340 }
341}
342
343fn calculate_groups_frame(
351 partition: &Partition,
352 current_row_idx: usize,
353 order_by: &Option<Vec<OrderByItem>>,
354 start: &FrameBound,
355 end: &Option<FrameBound>,
356) -> Range<usize> {
357 let partition_size = partition.len();
358
359 if partition_size == 0 || current_row_idx >= partition_size {
360 return 0..0;
361 }
362
363 let order_items = match order_by {
365 Some(items) if !items.is_empty() => items,
366 _ => return 0..partition_size,
367 };
368
369 let group_boundaries = build_group_boundaries(partition, order_items);
371 let current_group = find_group_for_row(&group_boundaries, current_row_idx);
372
373 let start_idx =
375 calculate_groups_boundary(&group_boundaries, current_group, start, true, partition_size);
376
377 let end_idx = match end {
379 Some(end_bound) => calculate_groups_boundary(
380 &group_boundaries,
381 current_group,
382 end_bound,
383 false,
384 partition_size,
385 ),
386 None => {
387 if current_group + 1 < group_boundaries.len() {
389 group_boundaries[current_group + 1]
390 } else {
391 partition_size
392 }
393 }
394 };
395
396 let start = start_idx.min(partition_size);
398 let end = end_idx.min(partition_size).max(start);
399
400 start..end
401}
402
403fn calculate_groups_boundary(
405 group_boundaries: &[usize],
406 current_group: usize,
407 bound: &FrameBound,
408 is_start: bool,
409 partition_size: usize,
410) -> usize {
411 let num_groups = group_boundaries.len();
412
413 match bound {
414 FrameBound::UnboundedPreceding => 0,
415 FrameBound::UnboundedFollowing => partition_size,
416
417 FrameBound::CurrentRow => {
418 if is_start {
419 group_boundaries[current_group]
420 } else {
421 if current_group + 1 < num_groups {
422 group_boundaries[current_group + 1]
423 } else {
424 partition_size
425 }
426 }
427 }
428
429 FrameBound::Preceding(offset_expr) => {
430 let offset = get_numeric_offset(offset_expr) as usize;
431 let target_group = current_group.saturating_sub(offset);
432
433 if is_start {
434 group_boundaries[target_group]
435 } else {
436 if target_group + 1 < num_groups {
437 group_boundaries[target_group + 1]
438 } else {
439 partition_size
440 }
441 }
442 }
443
444 FrameBound::Following(offset_expr) => {
445 let offset = get_numeric_offset(offset_expr) as usize;
446 let target_group = (current_group + offset).min(num_groups.saturating_sub(1));
447
448 if is_start {
449 group_boundaries[target_group]
450 } else {
451 if target_group + 1 < num_groups {
452 group_boundaries[target_group + 1]
453 } else {
454 partition_size
455 }
456 }
457 }
458 }
459}
460
461fn build_group_boundaries(partition: &Partition, order_items: &[OrderByItem]) -> Vec<usize> {
463 let mut boundaries = vec![0];
464
465 for i in 1..partition.len() {
466 let prev_row = &partition.rows[i - 1];
467 let curr_row = &partition.rows[i];
468
469 let mut is_new_group = false;
471 for item in order_items {
472 let prev_val = evaluate_expression_with_map(&item.expr, prev_row, &partition.column_map).unwrap_or(SqlValue::Null);
473 let curr_val = evaluate_expression_with_map(&item.expr, curr_row, &partition.column_map).unwrap_or(SqlValue::Null);
474 if compare_values(&prev_val, &curr_val) != Ordering::Equal {
475 is_new_group = true;
476 break;
477 }
478 }
479
480 if is_new_group {
481 boundaries.push(i);
482 }
483 }
484
485 boundaries
486}
487
/// Returns the index of the group containing `row_idx`.
///
/// This is the highest index `g` with `group_boundaries[g] <= row_idx`, found by
/// scanning from the back. Falls back to 0 when no boundary qualifies or the
/// slice is empty.
fn find_group_for_row(group_boundaries: &[usize], row_idx: usize) -> usize {
    group_boundaries
        .iter()
        .rposition(|&group_start| row_idx >= group_start)
        .unwrap_or(0)
}
497
498fn find_first_peer(
500 partition: &Partition,
501 current_row_idx: usize,
502 order_items: &[OrderByItem],
503) -> usize {
504 let current_row = &partition.rows[current_row_idx];
505
506 for i in (0..current_row_idx).rev() {
507 let row = &partition.rows[i];
508 let mut is_peer = true;
509
510 for item in order_items {
511 let curr_val = evaluate_expression_with_map(&item.expr, current_row, &partition.column_map).unwrap_or(SqlValue::Null);
512 let row_val = evaluate_expression_with_map(&item.expr, row, &partition.column_map).unwrap_or(SqlValue::Null);
513 if compare_values(&curr_val, &row_val) != Ordering::Equal {
514 is_peer = false;
515 break;
516 }
517 }
518
519 if !is_peer {
520 return i + 1;
521 }
522 }
523
524 0
525}
526
527fn find_last_peer(
529 partition: &Partition,
530 current_row_idx: usize,
531 order_items: &[OrderByItem],
532) -> usize {
533 let current_row = &partition.rows[current_row_idx];
534
535 for i in (current_row_idx + 1)..partition.len() {
536 let row = &partition.rows[i];
537 let mut is_peer = true;
538
539 for item in order_items {
540 let curr_val = evaluate_expression_with_map(&item.expr, current_row, &partition.column_map).unwrap_or(SqlValue::Null);
541 let row_val = evaluate_expression_with_map(&item.expr, row, &partition.column_map).unwrap_or(SqlValue::Null);
542 if compare_values(&curr_val, &row_val) != Ordering::Equal {
543 is_peer = false;
544 break;
545 }
546 }
547
548 if !is_peer {
549 return i - 1;
550 }
551 }
552
553 partition.len() - 1
554}
555
556fn find_first_row_ge(
558 partition: &Partition,
559 order_items: &[OrderByItem],
560 target: &SqlValue,
561) -> usize {
562 for (i, row) in partition.rows.iter().enumerate() {
563 let val = evaluate_expression_with_map(&order_items[0].expr, row, &partition.column_map).unwrap_or(SqlValue::Null);
564 if compare_values(&val, target) != Ordering::Less {
565 return i;
566 }
567 }
568 partition.len()
569}
570
571fn find_last_row_le(
573 partition: &Partition,
574 order_items: &[OrderByItem],
575 target: &SqlValue,
576) -> usize {
577 for i in (0..partition.len()).rev() {
578 let val = evaluate_expression_with_map(&order_items[0].expr, &partition.rows[i], &partition.column_map).unwrap_or(SqlValue::Null);
579 if compare_values(&val, target) != Ordering::Greater {
580 return i;
581 }
582 }
583 0
584}
585
586fn find_first_row_le_desc(
591 partition: &Partition,
592 order_items: &[OrderByItem],
593 target: &SqlValue,
594) -> usize {
595 for (i, row) in partition.rows.iter().enumerate() {
596 let val = evaluate_expression_with_map(&order_items[0].expr, row, &partition.column_map).unwrap_or(SqlValue::Null);
597 if compare_values(&val, target) != Ordering::Greater {
598 return i;
599 }
600 }
601 partition.len()
602}
603
604fn find_last_row_ge_desc(
609 partition: &Partition,
610 order_items: &[OrderByItem],
611 target: &SqlValue,
612) -> usize {
613 for i in (0..partition.len()).rev() {
614 let val = evaluate_expression_with_map(&order_items[0].expr, &partition.rows[i], &partition.column_map).unwrap_or(SqlValue::Null);
615 if compare_values(&val, target) != Ordering::Less {
616 return i;
617 }
618 }
619 0
620}
621
622fn get_numeric_offset(expr: &Expression) -> f64 {
624 match evaluate_offset_expr(expr) {
625 Ok(Some(n)) => n,
626 _ => 0.0,
627 }
628}
629
630fn subtract_from_value(value: &SqlValue, offset: f64) -> SqlValue {
632 match value {
633 SqlValue::Integer(n) => SqlValue::Real(*n as f64 - offset),
634 SqlValue::Smallint(n) => SqlValue::Real(*n as f64 - offset),
635 SqlValue::Bigint(n) => SqlValue::Real(*n as f64 - offset),
636 SqlValue::Unsigned(n) => SqlValue::Real(*n as f64 - offset),
637 SqlValue::Float(f) => SqlValue::Real(*f as f64 - offset),
638 SqlValue::Real(f) => SqlValue::Real(*f - offset),
639 SqlValue::Double(f) => SqlValue::Real(*f - offset),
640 SqlValue::Numeric(f) => SqlValue::Real(*f - offset),
641 _ => value.clone(), }
643}
644
645fn add_to_value(value: &SqlValue, offset: f64) -> SqlValue {
647 match value {
648 SqlValue::Integer(n) => SqlValue::Real(*n as f64 + offset),
649 SqlValue::Smallint(n) => SqlValue::Real(*n as f64 + offset),
650 SqlValue::Bigint(n) => SqlValue::Real(*n as f64 + offset),
651 SqlValue::Unsigned(n) => SqlValue::Real(*n as f64 + offset),
652 SqlValue::Float(f) => SqlValue::Real(*f as f64 + offset),
653 SqlValue::Real(f) => SqlValue::Real(*f + offset),
654 SqlValue::Double(f) => SqlValue::Real(*f + offset),
655 SqlValue::Numeric(f) => SqlValue::Real(*f + offset),
656 _ => value.clone(), }
658}
659
660fn calculate_rows_boundary(
666 bound: &FrameBound,
667 current_row_idx: usize,
668 partition_size: usize,
669 is_start: bool,
670) -> usize {
671 match bound {
672 FrameBound::UnboundedPreceding => 0,
673
674 FrameBound::UnboundedFollowing => partition_size,
675
676 FrameBound::CurrentRow => {
677 if is_start {
678 current_row_idx
679 } else {
680 current_row_idx + 1 }
682 }
683
684 FrameBound::Preceding(offset_expr) => {
685 let offset = get_numeric_offset(offset_expr) as usize;
686 current_row_idx.saturating_sub(offset)
687 }
688
689 FrameBound::Following(offset_expr) => {
690 let offset = get_numeric_offset(offset_expr) as usize;
691 let result = current_row_idx + offset;
692
693 if is_start {
694 result.min(partition_size)
695 } else {
696 (result + 1).min(partition_size) }
698 }
699 }
700}
701
702#[derive(Debug, Clone)]
704pub struct FrameResult {
705 pub range: Range<usize>,
707 pub exclude: Option<FrameExclude>,
709 pub current_row_idx: usize,
711}
712
713impl FrameResult {
714 pub fn includes(
718 &self,
719 row_idx: usize,
720 partition: &Partition,
721 order_by: &Option<Vec<OrderByItem>>,
722 ) -> bool {
723 if !self.range.contains(&row_idx) {
725 return false;
726 }
727
728 match self.exclude {
730 None | Some(FrameExclude::NoOthers) => true,
731
732 Some(FrameExclude::CurrentRow) => row_idx != self.current_row_idx,
733
734 Some(FrameExclude::Group) => {
735 !is_peer(row_idx, self.current_row_idx, partition, order_by)
737 }
738
739 Some(FrameExclude::Ties) => {
740 row_idx == self.current_row_idx
742 || !is_peer(row_idx, self.current_row_idx, partition, order_by)
743 }
744 }
745 }
746
747 pub fn included_indices<'a>(
751 &'a self,
752 partition: &'a Partition,
753 order_by: &'a Option<Vec<OrderByItem>>,
754 ) -> impl Iterator<Item = usize> + 'a {
755 self.range.clone().filter(move |&idx| self.includes(idx, partition, order_by))
756 }
757}
758
759pub fn calculate_frame_with_exclusion(
763 partition: &Partition,
764 current_row_idx: usize,
765 order_by: &Option<Vec<OrderByItem>>,
766 frame_spec: &Option<WindowFrame>,
767) -> FrameResult {
768 let range = calculate_frame(partition, current_row_idx, order_by, frame_spec);
769 let exclude = frame_spec.as_ref().and_then(|f| f.exclude);
770
771 FrameResult { range, exclude, current_row_idx }
772}
773
774fn is_peer(
779 row_idx_a: usize,
780 row_idx_b: usize,
781 partition: &Partition,
782 order_by: &Option<Vec<OrderByItem>>,
783) -> bool {
784 if row_idx_a == row_idx_b {
786 return true;
787 }
788
789 let order_items = match order_by {
791 Some(items) if !items.is_empty() => items,
792 _ => return true,
793 };
794
795 if row_idx_a >= partition.len() || row_idx_b >= partition.len() {
797 return false;
798 }
799
800 let row_a = &partition.rows[row_idx_a];
801 let row_b = &partition.rows[row_idx_b];
802
803 for order_item in order_items {
805 let val_a = evaluate_expression_with_map(&order_item.expr, row_a, &partition.column_map).unwrap_or(SqlValue::Null);
806 let val_b = evaluate_expression_with_map(&order_item.expr, row_b, &partition.column_map).unwrap_or(SqlValue::Null);
807
808 if compare_values(&val_a, &val_b) != Ordering::Equal {
809 return false;
810 }
811 }
812
813 true
814}