1use std::collections::{BTreeMap, HashMap};
7use std::sync::Arc;
8
9use anyhow::{anyhow, Result};
10
11use crate::data::data_view::DataView;
12use crate::data::datatable::{DataTable, DataValue};
13use crate::sql::parser::ast::{
14 FrameBound, FrameUnit, OrderByColumn, SortDirection, WindowFrame, WindowSpec,
15};
16
17#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
20struct PartitionKey(String);
21
22impl PartitionKey {
23 fn from_values(values: Vec<DataValue>) -> Self {
25 let key_parts: Vec<String> = values
27 .iter()
28 .map(|v| match v {
29 DataValue::String(s) => format!("S:{}", s),
30 DataValue::InternedString(s) => format!("S:{}", s),
31 DataValue::Integer(i) => format!("I:{}", i),
32 DataValue::Float(f) => format!("F:{}", f),
33 DataValue::Boolean(b) => format!("B:{}", b),
34 DataValue::DateTime(dt) => format!("D:{}", dt),
35 DataValue::Null => "N".to_string(),
36 })
37 .collect();
38 let key = key_parts.join("|");
39 PartitionKey(key)
40 }
41}
42
/// Rows of a single partition in window order, with a reverse lookup from row to position.
#[derive(Debug, Clone)]
pub struct OrderedPartition {
    /// Row indices in the order supplied at construction.
    rows: Vec<usize>,

    /// Maps a row index to its position within `rows`.
    row_positions: HashMap<usize, usize>,
}

impl OrderedPartition {
    /// Wraps already-ordered row indices and indexes each row's position.
    fn new(rows: Vec<usize>) -> Self {
        let row_positions: HashMap<usize, usize> = rows
            .iter()
            .enumerate()
            .map(|(pos, &row_idx)| (row_idx, pos))
            .collect();

        Self {
            rows,
            row_positions,
        }
    }

    /// Returns the row located `offset` positions away from `current_row`.
    ///
    /// Negative offsets look backwards, positive offsets look forwards. Returns `None` when
    /// `current_row` is not part of this partition or the target position falls outside it.
    pub fn get_row_at_offset(&self, current_row: usize, offset: i32) -> Option<usize> {
        let current_pos = *self.row_positions.get(&current_row)?;

        // Checked `usize` arithmetic: extreme offsets (e.g. `i32::MAX`) must yield `None`
        // rather than overflow, and large positions must not be truncated to `i32`.
        let target_pos = if offset >= 0 {
            // A non-negative `i32` always fits in `usize` on 32- and 64-bit targets.
            current_pos.checked_add(offset as usize)?
        } else {
            current_pos.checked_sub(offset.unsigned_abs() as usize)?
        };

        self.rows.get(target_pos).copied()
    }

    /// Returns the position of `row_index` within the partition, if it belongs to it.
    pub fn get_position(&self, row_index: usize) -> Option<usize> {
        self.row_positions.get(&row_index).copied()
    }

    /// Returns the first row of the partition, or `None` when it is empty.
    pub fn first_row(&self) -> Option<usize> {
        self.rows.first().copied()
    }

    /// Returns the last row of the partition, or `None` when it is empty.
    pub fn last_row(&self) -> Option<usize> {
        self.rows.last().copied()
    }
}
96
97pub struct WindowContext {
99 source: Arc<DataView>,
101
102 partitions: BTreeMap<PartitionKey, OrderedPartition>,
104
105 row_to_partition: HashMap<usize, PartitionKey>,
107
108 spec: WindowSpec,
110}
111
112impl WindowContext {
113 pub fn new(
115 view: Arc<DataView>,
116 partition_by: Vec<String>,
117 order_by: Vec<OrderByColumn>,
118 ) -> Result<Self> {
119 Self::new_with_spec(
120 view,
121 WindowSpec {
122 partition_by,
123 order_by,
124 frame: None,
125 },
126 )
127 }
128
129 pub fn new_with_spec(view: Arc<DataView>, spec: WindowSpec) -> Result<Self> {
131 let partition_by = spec.partition_by.clone();
132 let order_by = spec.order_by.clone();
133
134 if partition_by.is_empty() {
136 let single_partition = Self::create_single_partition(&view, &order_by)?;
137 let partition_key = PartitionKey::from_values(vec![]);
138
139 let mut row_to_partition = HashMap::new();
141 for &row_idx in &single_partition.rows {
142 row_to_partition.insert(row_idx, partition_key.clone());
143 }
144
145 let mut partitions = BTreeMap::new();
146 partitions.insert(partition_key, single_partition);
147
148 return Ok(Self {
149 source: view,
150 partitions,
151 row_to_partition,
152 spec,
153 });
154 }
155
156 let mut partition_map: BTreeMap<PartitionKey, Vec<usize>> = BTreeMap::new();
158 let mut row_to_partition = HashMap::new();
159
160 let source_table = view.source();
162 let partition_col_indices: Vec<usize> = partition_by
163 .iter()
164 .map(|col| {
165 source_table
166 .get_column_index(col)
167 .ok_or_else(|| anyhow!("Invalid partition column: {}", col))
168 })
169 .collect::<Result<Vec<_>>>()?;
170
171 for row_idx in view.get_visible_rows() {
173 let mut key_values = Vec::new();
175 for &col_idx in &partition_col_indices {
176 let value = source_table
177 .get_value(row_idx, col_idx)
178 .ok_or_else(|| anyhow!("Failed to get value for partition"))?
179 .clone();
180 key_values.push(value);
181 }
182 let key = PartitionKey::from_values(key_values);
183
184 partition_map.entry(key.clone()).or_default().push(row_idx);
186 row_to_partition.insert(row_idx, key);
187 }
188
189 let mut partitions = BTreeMap::new();
191 for (key, mut rows) in partition_map {
192 if !order_by.is_empty() {
194 Self::sort_rows(&mut rows, source_table, &order_by)?;
195 }
196
197 partitions.insert(key, OrderedPartition::new(rows));
198 }
199
200 Ok(Self {
201 source: view,
202 partitions,
203 row_to_partition,
204 spec,
205 })
206 }
207
208 fn create_single_partition(
210 view: &DataView,
211 order_by: &[OrderByColumn],
212 ) -> Result<OrderedPartition> {
213 let mut rows: Vec<usize> = view.get_visible_rows();
214
215 if !order_by.is_empty() {
216 Self::sort_rows(&mut rows, view.source(), order_by)?;
217 }
218
219 Ok(OrderedPartition::new(rows))
220 }
221
222 fn sort_rows(
224 rows: &mut Vec<usize>,
225 table: &DataTable,
226 order_by: &[OrderByColumn],
227 ) -> Result<()> {
228 let sort_cols: Vec<(usize, bool)> = order_by
230 .iter()
231 .map(|col| {
232 let idx = table
233 .get_column_index(&col.column)
234 .ok_or_else(|| anyhow!("Invalid ORDER BY column: {}", col.column))?;
235 let ascending = matches!(col.direction, SortDirection::Asc);
236 Ok((idx, ascending))
237 })
238 .collect::<Result<Vec<_>>>()?;
239
240 rows.sort_by(|&a, &b| {
242 for &(col_idx, ascending) in &sort_cols {
243 let val_a = table.get_value(a, col_idx);
244 let val_b = table.get_value(b, col_idx);
245
246 match (val_a, val_b) {
247 (None, None) => continue,
248 (None, Some(_)) => {
249 return if ascending {
250 std::cmp::Ordering::Less
251 } else {
252 std::cmp::Ordering::Greater
253 }
254 }
255 (Some(_), None) => {
256 return if ascending {
257 std::cmp::Ordering::Greater
258 } else {
259 std::cmp::Ordering::Less
260 }
261 }
262 (Some(v_a), Some(v_b)) => {
263 let ord = v_a.partial_cmp(&v_b).unwrap_or(std::cmp::Ordering::Equal);
265 if ord != std::cmp::Ordering::Equal {
266 return if ascending { ord } else { ord.reverse() };
267 }
268 }
269 }
270 }
271 std::cmp::Ordering::Equal
272 });
273
274 Ok(())
275 }
276
277 pub fn get_offset_value(
279 &self,
280 current_row: usize,
281 offset: i32,
282 column: &str,
283 ) -> Option<DataValue> {
284 let partition_key = self.row_to_partition.get(¤t_row)?;
286 let partition = self.partitions.get(partition_key)?;
287
288 let target_row = partition.get_row_at_offset(current_row, offset)?;
290
291 let source_table = self.source.source();
293 let col_idx = source_table.get_column_index(column)?;
294 source_table.get_value(target_row, col_idx).cloned()
295 }
296
297 pub fn get_row_number(&self, row_index: usize) -> usize {
299 if let Some(partition_key) = self.row_to_partition.get(&row_index) {
300 if let Some(partition) = self.partitions.get(partition_key) {
301 if let Some(position) = partition.get_position(row_index) {
302 return position + 1; }
304 }
305 }
306 0 }
308
309 pub fn get_frame_first_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
311 let frame_rows = self.get_frame_rows(row_index);
312 if frame_rows.is_empty() {
313 return Some(DataValue::Null);
314 }
315
316 let source_table = self.source.source();
317 let col_idx = source_table.get_column_index(column)?;
318
319 let first_row = frame_rows[0];
321 source_table.get_value(first_row, col_idx).cloned()
322 }
323
324 pub fn get_frame_last_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
326 let frame_rows = self.get_frame_rows(row_index);
327 if frame_rows.is_empty() {
328 return Some(DataValue::Null);
329 }
330
331 let source_table = self.source.source();
332 let col_idx = source_table.get_column_index(column)?;
333
334 let last_row = frame_rows[frame_rows.len() - 1];
336 source_table.get_value(last_row, col_idx).cloned()
337 }
338
339 pub fn get_first_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
341 let partition_key = self.row_to_partition.get(&row_index)?;
342 let partition = self.partitions.get(partition_key)?;
343 let first_row = partition.first_row()?;
344
345 let source_table = self.source.source();
346 let col_idx = source_table.get_column_index(column)?;
347 source_table.get_value(first_row, col_idx).cloned()
348 }
349
350 pub fn get_last_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
352 let partition_key = self.row_to_partition.get(&row_index)?;
353 let partition = self.partitions.get(partition_key)?;
354 let last_row = partition.last_row()?;
355
356 let source_table = self.source.source();
357 let col_idx = source_table.get_column_index(column)?;
358 source_table.get_value(last_row, col_idx).cloned()
359 }
360
361 pub fn partition_count(&self) -> usize {
363 self.partitions.len()
364 }
365
366 pub fn has_partitions(&self) -> bool {
368 !self.spec.partition_by.is_empty()
369 }
370
371 pub fn has_frame(&self) -> bool {
373 self.spec.frame.is_some()
374 }
375
376 pub fn source(&self) -> &DataTable {
378 self.source.source()
379 }
380
381 pub fn get_frame_rows(&self, row_index: usize) -> Vec<usize> {
383 let partition_key = match self.row_to_partition.get(&row_index) {
385 Some(key) => key,
386 None => return vec![],
387 };
388
389 let partition = match self.partitions.get(partition_key) {
390 Some(p) => p,
391 None => return vec![],
392 };
393
394 let current_pos = match partition.get_position(row_index) {
396 Some(pos) => pos as i64,
397 None => return vec![],
398 };
399
400 let frame = match &self.spec.frame {
402 Some(f) => f,
403 None => return partition.rows.clone(),
404 };
405
406 let (start_pos, end_pos) = match frame.unit {
408 FrameUnit::Rows => {
409 let start =
411 self.calculate_frame_position(&frame.start, current_pos, partition.rows.len());
412 let end = match &frame.end {
413 Some(bound) => {
414 self.calculate_frame_position(bound, current_pos, partition.rows.len())
415 }
416 None => current_pos, };
418 (start, end)
419 }
420 FrameUnit::Range => {
421 let start =
424 self.calculate_frame_position(&frame.start, current_pos, partition.rows.len());
425 let end = match &frame.end {
426 Some(bound) => {
427 self.calculate_frame_position(bound, current_pos, partition.rows.len())
428 }
429 None => current_pos,
430 };
431 (start, end)
432 }
433 };
434
435 let mut frame_rows = Vec::new();
437 for i in start_pos..=end_pos {
438 if i >= 0 && (i as usize) < partition.rows.len() {
439 frame_rows.push(partition.rows[i as usize]);
440 }
441 }
442
443 frame_rows
444 }
445
446 fn calculate_frame_position(
448 &self,
449 bound: &FrameBound,
450 current_pos: i64,
451 partition_size: usize,
452 ) -> i64 {
453 match bound {
454 FrameBound::UnboundedPreceding => 0,
455 FrameBound::UnboundedFollowing => partition_size as i64 - 1,
456 FrameBound::CurrentRow => current_pos,
457 FrameBound::Preceding(n) => current_pos - n,
458 FrameBound::Following(n) => current_pos + n,
459 }
460 }
461
462 pub fn get_frame_sum(&self, row_index: usize, column: &str) -> Option<DataValue> {
464 let frame_rows = self.get_frame_rows(row_index);
465 if frame_rows.is_empty() {
466 return Some(DataValue::Null);
467 }
468
469 let source_table = self.source.source();
470 let col_idx = source_table.get_column_index(column)?;
471
472 let mut sum = 0.0;
473 let mut has_float = false;
474 let mut has_value = false;
475
476 for &row_idx in &frame_rows {
478 if let Some(value) = source_table.get_value(row_idx, col_idx) {
479 match value {
480 DataValue::Integer(i) => {
481 sum += *i as f64;
482 has_value = true;
483 }
484 DataValue::Float(f) => {
485 sum += f;
486 has_float = true;
487 has_value = true;
488 }
489 DataValue::Null => {
490 }
492 _ => {
493 return Some(DataValue::Null);
495 }
496 }
497 }
498 }
499
500 if !has_value {
501 return Some(DataValue::Null);
502 }
503
504 if !has_float && sum.fract() == 0.0 && sum >= i64::MIN as f64 && sum <= i64::MAX as f64 {
506 Some(DataValue::Integer(sum as i64))
507 } else {
508 Some(DataValue::Float(sum))
509 }
510 }
511
512 pub fn get_frame_count(&self, row_index: usize, column: Option<&str>) -> Option<DataValue> {
514 let frame_rows = self.get_frame_rows(row_index);
515 if frame_rows.is_empty() {
516 return Some(DataValue::Integer(0));
517 }
518
519 if let Some(col_name) = column {
520 let source_table = self.source.source();
522 let col_idx = source_table.get_column_index(col_name)?;
523
524 let count = frame_rows
525 .iter()
526 .filter_map(|&row_idx| source_table.get_value(row_idx, col_idx))
527 .filter(|v| !matches!(v, DataValue::Null))
528 .count();
529
530 Some(DataValue::Integer(count as i64))
531 } else {
532 Some(DataValue::Integer(frame_rows.len() as i64))
534 }
535 }
536
537 pub fn get_frame_avg(&self, row_index: usize, column: &str) -> Option<DataValue> {
539 let frame_rows = self.get_frame_rows(row_index);
540 if frame_rows.is_empty() {
541 return Some(DataValue::Null);
542 }
543
544 let source_table = self.source.source();
545 let col_idx = source_table.get_column_index(column)?;
546
547 let mut sum = 0.0;
548 let mut count = 0;
549
550 for &row_idx in &frame_rows {
552 if let Some(value) = source_table.get_value(row_idx, col_idx) {
553 match value {
554 DataValue::Integer(i) => {
555 sum += *i as f64;
556 count += 1;
557 }
558 DataValue::Float(f) => {
559 sum += f;
560 count += 1;
561 }
562 DataValue::Null => {
563 }
565 _ => {
566 return Some(DataValue::Null);
568 }
569 }
570 }
571 }
572
573 if count == 0 {
574 return Some(DataValue::Null);
575 }
576
577 Some(DataValue::Float(sum / count as f64))
578 }
579
580 pub fn get_frame_stddev(&self, row_index: usize, column: &str) -> Option<DataValue> {
582 let variance = self.get_frame_variance(row_index, column)?;
583 match variance {
584 DataValue::Float(v) => Some(DataValue::Float(v.sqrt())),
585 DataValue::Null => Some(DataValue::Null),
586 _ => Some(DataValue::Null),
587 }
588 }
589
590 pub fn get_frame_variance(&self, row_index: usize, column: &str) -> Option<DataValue> {
592 let frame_rows = self.get_frame_rows(row_index);
593 if frame_rows.is_empty() {
594 return Some(DataValue::Null);
595 }
596
597 let source_table = self.source.source();
598 let col_idx = source_table.get_column_index(column)?;
599
600 let mut values = Vec::new();
601
602 for &row_idx in &frame_rows {
604 if let Some(value) = source_table.get_value(row_idx, col_idx) {
605 match value {
606 DataValue::Integer(i) => values.push(*i as f64),
607 DataValue::Float(f) => values.push(*f),
608 DataValue::Null => {
609 }
611 _ => {
612 return Some(DataValue::Null);
614 }
615 }
616 }
617 }
618
619 if values.is_empty() {
620 return Some(DataValue::Null);
621 }
622
623 if values.len() == 1 {
624 return Some(DataValue::Float(0.0));
626 }
627
628 let mean = values.iter().sum::<f64>() / values.len() as f64;
630
631 let variance =
633 values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
634
635 Some(DataValue::Float(variance))
636 }
637
638 pub fn get_partition_sum(&self, row_index: usize, column: &str) -> Option<DataValue> {
640 let partition_key = self.row_to_partition.get(&row_index)?;
641 let partition = self.partitions.get(partition_key)?;
642 let source_table = self.source.source();
643 let col_idx = source_table.get_column_index(column)?;
644
645 let mut sum = 0.0;
646 let mut has_float = false;
647 let mut has_value = false;
648
649 for &row_idx in &partition.rows {
651 if let Some(value) = source_table.get_value(row_idx, col_idx) {
652 match value {
653 DataValue::Integer(i) => {
654 sum += *i as f64;
655 has_value = true;
656 }
657 DataValue::Float(f) => {
658 sum += f;
659 has_float = true;
660 has_value = true;
661 }
662 DataValue::Null => {
663 }
665 _ => {
666 return Some(DataValue::Null);
668 }
669 }
670 }
671 }
672
673 if !has_value {
674 return Some(DataValue::Null);
675 }
676
677 if !has_float && sum.fract() == 0.0 && sum >= i64::MIN as f64 && sum <= i64::MAX as f64 {
679 Some(DataValue::Integer(sum as i64))
680 } else {
681 Some(DataValue::Float(sum))
682 }
683 }
684
685 pub fn get_partition_count(&self, row_index: usize, column: Option<&str>) -> Option<DataValue> {
687 let partition_key = self.row_to_partition.get(&row_index)?;
688 let partition = self.partitions.get(partition_key)?;
689
690 if let Some(col_name) = column {
691 let source_table = self.source.source();
693 let col_idx = source_table.get_column_index(col_name)?;
694
695 let count = partition
696 .rows
697 .iter()
698 .filter_map(|&row_idx| source_table.get_value(row_idx, col_idx))
699 .filter(|v| !matches!(v, DataValue::Null))
700 .count();
701
702 Some(DataValue::Integer(count as i64))
703 } else {
704 Some(DataValue::Integer(partition.rows.len() as i64))
706 }
707 }
708}