1use std::collections::{BTreeMap, HashMap};
7use std::sync::Arc;
8
9use anyhow::{anyhow, Result};
10
11use crate::data::data_view::DataView;
12use crate::data::datatable::{DataTable, DataValue};
13use crate::sql::parser::ast::{FrameBound, FrameUnit, OrderByColumn, SortDirection, WindowSpec};
14
15#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
18struct PartitionKey(String);
19
20impl PartitionKey {
21 fn from_values(values: Vec<DataValue>) -> Self {
23 let key_parts: Vec<String> = values
25 .iter()
26 .map(|v| match v {
27 DataValue::String(s) => format!("S:{}", s),
28 DataValue::InternedString(s) => format!("S:{}", s),
29 DataValue::Integer(i) => format!("I:{}", i),
30 DataValue::Float(f) => format!("F:{}", f),
31 DataValue::Boolean(b) => format!("B:{}", b),
32 DataValue::DateTime(dt) => format!("D:{}", dt),
33 DataValue::Null => "N".to_string(),
34 })
35 .collect();
36 let key = key_parts.join("|");
37 PartitionKey(key)
38 }
39}
40
/// Rows of a single window partition, kept in their final order.
#[derive(Debug, Clone)]
pub struct OrderedPartition {
    /// Row indices in partition order.
    rows: Vec<usize>,

    /// Reverse lookup: row index -> position of that row inside `rows`.
    row_positions: HashMap<usize, usize>,
}

impl OrderedPartition {
    /// Wraps already-ordered row indices and builds the reverse position index.
    fn new(rows: Vec<usize>) -> Self {
        let row_positions: HashMap<usize, usize> = rows
            .iter()
            .enumerate()
            .map(|(pos, &row_idx)| (row_idx, pos))
            .collect();

        Self {
            rows,
            row_positions,
        }
    }

    /// Returns the row located `offset` positions away from `current_row`.
    ///
    /// Negative offsets look backwards, positive offsets look forwards.
    /// Returns `None` when `current_row` is not part of this partition or the
    /// target position falls outside of it.
    pub fn get_row_at_offset(&self, current_row: usize, offset: i32) -> Option<usize> {
        let current_pos = *self.row_positions.get(&current_row)?;

        // Checked `usize` arithmetic: extreme offsets (e.g. `i32::MAX`) or very
        // large positions must yield `None` instead of overflowing or truncating.
        let target_pos = if offset >= 0 {
            current_pos.checked_add(offset as usize)?
        } else {
            current_pos.checked_sub(offset.unsigned_abs() as usize)?
        };

        self.rows.get(target_pos).copied()
    }

    /// Returns the zero-based position of `row_index` within the partition.
    pub fn get_position(&self, row_index: usize) -> Option<usize> {
        self.row_positions.get(&row_index).copied()
    }

    /// Returns the first row of the partition, or `None` if it is empty.
    pub fn first_row(&self) -> Option<usize> {
        self.rows.first().copied()
    }

    /// Returns the last row of the partition, or `None` if it is empty.
    pub fn last_row(&self) -> Option<usize> {
        self.rows.last().copied()
    }
}
94
/// Precomputed partitioning and ordering state for one window specification.
pub struct WindowContext {
    /// View whose visible rows were partitioned; also gives access to the underlying table.
    source: Arc<DataView>,

    /// Partitions by key, each holding its rows in order.
    partitions: BTreeMap<PartitionKey, OrderedPartition>,

    /// Maps a row index to the key of the partition containing it.
    row_to_partition: HashMap<usize, PartitionKey>,

    /// Window specification (partition columns, ordering, optional frame) this context was built from.
    spec: WindowSpec,
}
109
110impl WindowContext {
111 pub fn new(
113 view: Arc<DataView>,
114 partition_by: Vec<String>,
115 order_by: Vec<OrderByColumn>,
116 ) -> Result<Self> {
117 Self::new_with_spec(
118 view,
119 WindowSpec {
120 partition_by,
121 order_by,
122 frame: None,
123 },
124 )
125 }
126
127 pub fn new_with_spec(view: Arc<DataView>, spec: WindowSpec) -> Result<Self> {
129 let partition_by = spec.partition_by.clone();
130 let order_by = spec.order_by.clone();
131
132 if partition_by.is_empty() {
134 let single_partition = Self::create_single_partition(&view, &order_by)?;
135 let partition_key = PartitionKey::from_values(vec![]);
136
137 let mut row_to_partition = HashMap::new();
139 for &row_idx in &single_partition.rows {
140 row_to_partition.insert(row_idx, partition_key.clone());
141 }
142
143 let mut partitions = BTreeMap::new();
144 partitions.insert(partition_key, single_partition);
145
146 return Ok(Self {
147 source: view,
148 partitions,
149 row_to_partition,
150 spec,
151 });
152 }
153
154 let mut partition_map: BTreeMap<PartitionKey, Vec<usize>> = BTreeMap::new();
156 let mut row_to_partition = HashMap::new();
157
158 let source_table = view.source();
160 let partition_col_indices: Vec<usize> = partition_by
161 .iter()
162 .map(|col| {
163 source_table
164 .get_column_index(col)
165 .ok_or_else(|| anyhow!("Invalid partition column: {}", col))
166 })
167 .collect::<Result<Vec<_>>>()?;
168
169 for row_idx in view.get_visible_rows() {
171 let mut key_values = Vec::new();
173 for &col_idx in &partition_col_indices {
174 let value = source_table
175 .get_value(row_idx, col_idx)
176 .ok_or_else(|| anyhow!("Failed to get value for partition"))?
177 .clone();
178 key_values.push(value);
179 }
180 let key = PartitionKey::from_values(key_values);
181
182 partition_map.entry(key.clone()).or_default().push(row_idx);
184 row_to_partition.insert(row_idx, key);
185 }
186
187 let mut partitions = BTreeMap::new();
189 for (key, mut rows) in partition_map {
190 if !order_by.is_empty() {
192 Self::sort_rows(&mut rows, source_table, &order_by)?;
193 }
194
195 partitions.insert(key, OrderedPartition::new(rows));
196 }
197
198 Ok(Self {
199 source: view,
200 partitions,
201 row_to_partition,
202 spec,
203 })
204 }
205
206 fn create_single_partition(
208 view: &DataView,
209 order_by: &[OrderByColumn],
210 ) -> Result<OrderedPartition> {
211 let mut rows: Vec<usize> = view.get_visible_rows();
212
213 if !order_by.is_empty() {
214 Self::sort_rows(&mut rows, view.source(), order_by)?;
215 }
216
217 Ok(OrderedPartition::new(rows))
218 }
219
220 fn sort_rows(
222 rows: &mut Vec<usize>,
223 table: &DataTable,
224 order_by: &[OrderByColumn],
225 ) -> Result<()> {
226 let sort_cols: Vec<(usize, bool)> = order_by
228 .iter()
229 .map(|col| {
230 let idx = table
231 .get_column_index(&col.column)
232 .ok_or_else(|| anyhow!("Invalid ORDER BY column: {}", col.column))?;
233 let ascending = matches!(col.direction, SortDirection::Asc);
234 Ok((idx, ascending))
235 })
236 .collect::<Result<Vec<_>>>()?;
237
238 rows.sort_by(|&a, &b| {
240 for &(col_idx, ascending) in &sort_cols {
241 let val_a = table.get_value(a, col_idx);
242 let val_b = table.get_value(b, col_idx);
243
244 match (val_a, val_b) {
245 (None, None) => continue,
246 (None, Some(_)) => {
247 return if ascending {
248 std::cmp::Ordering::Less
249 } else {
250 std::cmp::Ordering::Greater
251 }
252 }
253 (Some(_), None) => {
254 return if ascending {
255 std::cmp::Ordering::Greater
256 } else {
257 std::cmp::Ordering::Less
258 }
259 }
260 (Some(v_a), Some(v_b)) => {
261 let ord = v_a.partial_cmp(&v_b).unwrap_or(std::cmp::Ordering::Equal);
263 if ord != std::cmp::Ordering::Equal {
264 return if ascending { ord } else { ord.reverse() };
265 }
266 }
267 }
268 }
269 std::cmp::Ordering::Equal
270 });
271
272 Ok(())
273 }
274
275 pub fn get_offset_value(
277 &self,
278 current_row: usize,
279 offset: i32,
280 column: &str,
281 ) -> Option<DataValue> {
282 let partition_key = self.row_to_partition.get(¤t_row)?;
284 let partition = self.partitions.get(partition_key)?;
285
286 let target_row = partition.get_row_at_offset(current_row, offset)?;
288
289 let source_table = self.source.source();
291 let col_idx = source_table.get_column_index(column)?;
292 source_table.get_value(target_row, col_idx).cloned()
293 }
294
295 pub fn get_row_number(&self, row_index: usize) -> usize {
297 if let Some(partition_key) = self.row_to_partition.get(&row_index) {
298 if let Some(partition) = self.partitions.get(partition_key) {
299 if let Some(position) = partition.get_position(row_index) {
300 return position + 1; }
302 }
303 }
304 0 }
306
307 pub fn get_frame_first_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
309 let frame_rows = self.get_frame_rows(row_index);
310 if frame_rows.is_empty() {
311 return Some(DataValue::Null);
312 }
313
314 let source_table = self.source.source();
315 let col_idx = source_table.get_column_index(column)?;
316
317 let first_row = frame_rows[0];
319 source_table.get_value(first_row, col_idx).cloned()
320 }
321
322 pub fn get_frame_last_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
324 let frame_rows = self.get_frame_rows(row_index);
325 if frame_rows.is_empty() {
326 return Some(DataValue::Null);
327 }
328
329 let source_table = self.source.source();
330 let col_idx = source_table.get_column_index(column)?;
331
332 let last_row = frame_rows[frame_rows.len() - 1];
334 source_table.get_value(last_row, col_idx).cloned()
335 }
336
337 pub fn get_first_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
339 let partition_key = self.row_to_partition.get(&row_index)?;
340 let partition = self.partitions.get(partition_key)?;
341 let first_row = partition.first_row()?;
342
343 let source_table = self.source.source();
344 let col_idx = source_table.get_column_index(column)?;
345 source_table.get_value(first_row, col_idx).cloned()
346 }
347
348 pub fn get_last_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
350 let partition_key = self.row_to_partition.get(&row_index)?;
351 let partition = self.partitions.get(partition_key)?;
352 let last_row = partition.last_row()?;
353
354 let source_table = self.source.source();
355 let col_idx = source_table.get_column_index(column)?;
356 source_table.get_value(last_row, col_idx).cloned()
357 }
358
    /// Returns the number of partitions held by this context.
    pub fn partition_count(&self) -> usize {
        self.partitions.len()
    }
363
    /// Returns `true` when the window specification names at least one partition column.
    pub fn has_partitions(&self) -> bool {
        !self.spec.partition_by.is_empty()
    }
368
    /// Returns `true` when the window specification carries a frame.
    pub fn has_frame(&self) -> bool {
        self.spec.frame.is_some()
    }
373
    /// Returns the table underlying the view this context was built on.
    pub fn source(&self) -> &DataTable {
        self.source.source()
    }
378
379 pub fn get_frame_rows(&self, row_index: usize) -> Vec<usize> {
381 let partition_key = match self.row_to_partition.get(&row_index) {
383 Some(key) => key,
384 None => return vec![],
385 };
386
387 let partition = match self.partitions.get(partition_key) {
388 Some(p) => p,
389 None => return vec![],
390 };
391
392 let current_pos = match partition.get_position(row_index) {
394 Some(pos) => pos as i64,
395 None => return vec![],
396 };
397
398 let frame = match &self.spec.frame {
400 Some(f) => f,
401 None => return partition.rows.clone(),
402 };
403
404 let (start_pos, end_pos) = match frame.unit {
406 FrameUnit::Rows => {
407 let start =
409 self.calculate_frame_position(&frame.start, current_pos, partition.rows.len());
410 let end = match &frame.end {
411 Some(bound) => {
412 self.calculate_frame_position(bound, current_pos, partition.rows.len())
413 }
414 None => current_pos, };
416 (start, end)
417 }
418 FrameUnit::Range => {
419 let start =
422 self.calculate_frame_position(&frame.start, current_pos, partition.rows.len());
423 let end = match &frame.end {
424 Some(bound) => {
425 self.calculate_frame_position(bound, current_pos, partition.rows.len())
426 }
427 None => current_pos,
428 };
429 (start, end)
430 }
431 };
432
433 let mut frame_rows = Vec::new();
435 for i in start_pos..=end_pos {
436 if i >= 0 && (i as usize) < partition.rows.len() {
437 frame_rows.push(partition.rows[i as usize]);
438 }
439 }
440
441 frame_rows
442 }
443
444 fn calculate_frame_position(
446 &self,
447 bound: &FrameBound,
448 current_pos: i64,
449 partition_size: usize,
450 ) -> i64 {
451 match bound {
452 FrameBound::UnboundedPreceding => 0,
453 FrameBound::UnboundedFollowing => partition_size as i64 - 1,
454 FrameBound::CurrentRow => current_pos,
455 FrameBound::Preceding(n) => current_pos - n,
456 FrameBound::Following(n) => current_pos + n,
457 }
458 }
459
460 pub fn get_frame_sum(&self, row_index: usize, column: &str) -> Option<DataValue> {
462 let frame_rows = self.get_frame_rows(row_index);
463 if frame_rows.is_empty() {
464 return Some(DataValue::Null);
465 }
466
467 let source_table = self.source.source();
468 let col_idx = source_table.get_column_index(column)?;
469
470 let mut sum = 0.0;
471 let mut has_float = false;
472 let mut has_value = false;
473
474 for &row_idx in &frame_rows {
476 if let Some(value) = source_table.get_value(row_idx, col_idx) {
477 match value {
478 DataValue::Integer(i) => {
479 sum += *i as f64;
480 has_value = true;
481 }
482 DataValue::Float(f) => {
483 sum += f;
484 has_float = true;
485 has_value = true;
486 }
487 DataValue::Null => {
488 }
490 _ => {
491 return Some(DataValue::Null);
493 }
494 }
495 }
496 }
497
498 if !has_value {
499 return Some(DataValue::Null);
500 }
501
502 if !has_float && sum.fract() == 0.0 && sum >= i64::MIN as f64 && sum <= i64::MAX as f64 {
504 Some(DataValue::Integer(sum as i64))
505 } else {
506 Some(DataValue::Float(sum))
507 }
508 }
509
510 pub fn get_frame_count(&self, row_index: usize, column: Option<&str>) -> Option<DataValue> {
512 let frame_rows = self.get_frame_rows(row_index);
513 if frame_rows.is_empty() {
514 return Some(DataValue::Integer(0));
515 }
516
517 if let Some(col_name) = column {
518 let source_table = self.source.source();
520 let col_idx = source_table.get_column_index(col_name)?;
521
522 let count = frame_rows
523 .iter()
524 .filter_map(|&row_idx| source_table.get_value(row_idx, col_idx))
525 .filter(|v| !matches!(v, DataValue::Null))
526 .count();
527
528 Some(DataValue::Integer(count as i64))
529 } else {
530 Some(DataValue::Integer(frame_rows.len() as i64))
532 }
533 }
534
535 pub fn get_frame_avg(&self, row_index: usize, column: &str) -> Option<DataValue> {
537 let frame_rows = self.get_frame_rows(row_index);
538 if frame_rows.is_empty() {
539 return Some(DataValue::Null);
540 }
541
542 let source_table = self.source.source();
543 let col_idx = source_table.get_column_index(column)?;
544
545 let mut sum = 0.0;
546 let mut count = 0;
547
548 for &row_idx in &frame_rows {
550 if let Some(value) = source_table.get_value(row_idx, col_idx) {
551 match value {
552 DataValue::Integer(i) => {
553 sum += *i as f64;
554 count += 1;
555 }
556 DataValue::Float(f) => {
557 sum += f;
558 count += 1;
559 }
560 DataValue::Null => {
561 }
563 _ => {
564 return Some(DataValue::Null);
566 }
567 }
568 }
569 }
570
571 if count == 0 {
572 return Some(DataValue::Null);
573 }
574
575 Some(DataValue::Float(sum / count as f64))
576 }
577
578 pub fn get_frame_stddev(&self, row_index: usize, column: &str) -> Option<DataValue> {
580 let variance = self.get_frame_variance(row_index, column)?;
581 match variance {
582 DataValue::Float(v) => Some(DataValue::Float(v.sqrt())),
583 DataValue::Null => Some(DataValue::Null),
584 _ => Some(DataValue::Null),
585 }
586 }
587
588 pub fn get_frame_variance(&self, row_index: usize, column: &str) -> Option<DataValue> {
590 let frame_rows = self.get_frame_rows(row_index);
591 if frame_rows.is_empty() {
592 return Some(DataValue::Null);
593 }
594
595 let source_table = self.source.source();
596 let col_idx = source_table.get_column_index(column)?;
597
598 let mut values = Vec::new();
599
600 for &row_idx in &frame_rows {
602 if let Some(value) = source_table.get_value(row_idx, col_idx) {
603 match value {
604 DataValue::Integer(i) => values.push(*i as f64),
605 DataValue::Float(f) => values.push(*f),
606 DataValue::Null => {
607 }
609 _ => {
610 return Some(DataValue::Null);
612 }
613 }
614 }
615 }
616
617 if values.is_empty() {
618 return Some(DataValue::Null);
619 }
620
621 if values.len() == 1 {
622 return Some(DataValue::Float(0.0));
624 }
625
626 let mean = values.iter().sum::<f64>() / values.len() as f64;
628
629 let variance =
631 values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
632
633 Some(DataValue::Float(variance))
634 }
635
636 pub fn get_partition_sum(&self, row_index: usize, column: &str) -> Option<DataValue> {
638 let partition_key = self.row_to_partition.get(&row_index)?;
639 let partition = self.partitions.get(partition_key)?;
640 let source_table = self.source.source();
641 let col_idx = source_table.get_column_index(column)?;
642
643 let mut sum = 0.0;
644 let mut has_float = false;
645 let mut has_value = false;
646
647 for &row_idx in &partition.rows {
649 if let Some(value) = source_table.get_value(row_idx, col_idx) {
650 match value {
651 DataValue::Integer(i) => {
652 sum += *i as f64;
653 has_value = true;
654 }
655 DataValue::Float(f) => {
656 sum += f;
657 has_float = true;
658 has_value = true;
659 }
660 DataValue::Null => {
661 }
663 _ => {
664 return Some(DataValue::Null);
666 }
667 }
668 }
669 }
670
671 if !has_value {
672 return Some(DataValue::Null);
673 }
674
675 if !has_float && sum.fract() == 0.0 && sum >= i64::MIN as f64 && sum <= i64::MAX as f64 {
677 Some(DataValue::Integer(sum as i64))
678 } else {
679 Some(DataValue::Float(sum))
680 }
681 }
682
683 pub fn get_partition_count(&self, row_index: usize, column: Option<&str>) -> Option<DataValue> {
685 let partition_key = self.row_to_partition.get(&row_index)?;
686 let partition = self.partitions.get(partition_key)?;
687
688 if let Some(col_name) = column {
689 let source_table = self.source.source();
691 let col_idx = source_table.get_column_index(col_name)?;
692
693 let count = partition
694 .rows
695 .iter()
696 .filter_map(|&row_idx| source_table.get_value(row_idx, col_idx))
697 .filter(|v| !matches!(v, DataValue::Null))
698 .count();
699
700 Some(DataValue::Integer(count as i64))
701 } else {
702 Some(DataValue::Integer(partition.rows.len() as i64))
704 }
705 }
706}