1use std::collections::{BTreeMap, HashMap};
7use std::sync::Arc;
8
9use anyhow::{anyhow, Result};
10
11use crate::data::data_view::DataView;
12use crate::data::datatable::{DataTable, DataValue};
13use crate::sql::parser::ast::{
14 FrameBound, FrameUnit, OrderByItem, SortDirection, SqlExpression, WindowSpec,
15};
16
17#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
20struct PartitionKey(String);
21
22impl PartitionKey {
23 fn from_values(values: Vec<DataValue>) -> Self {
25 let key_parts: Vec<String> = values
27 .iter()
28 .map(|v| match v {
29 DataValue::String(s) => format!("S:{}", s),
30 DataValue::InternedString(s) => format!("S:{}", s),
31 DataValue::Integer(i) => format!("I:{}", i),
32 DataValue::Float(f) => format!("F:{}", f),
33 DataValue::Boolean(b) => format!("B:{}", b),
34 DataValue::DateTime(dt) => format!("D:{}", dt),
35 DataValue::Null => "N".to_string(),
36 })
37 .collect();
38 let key = key_parts.join("|");
39 PartitionKey(key)
40 }
41}
42
/// The rows of one window partition in their window order, plus a reverse
/// index from row index to position for constant-time position lookups.
#[derive(Debug, Clone)]
pub struct OrderedPartition {
    /// Source-table row indices, in the order the partition was built with.
    rows: Vec<usize>,

    /// Maps a source-table row index to its 0-based position in `rows`.
    row_positions: HashMap<usize, usize>,
}

impl OrderedPartition {
    /// Wraps an already-ordered list of row indices and indexes their positions.
    fn new(rows: Vec<usize>) -> Self {
        let row_positions: HashMap<usize, usize> = rows
            .iter()
            .enumerate()
            .map(|(pos, &row_idx)| (row_idx, pos))
            .collect();

        Self {
            rows,
            row_positions,
        }
    }

    /// Returns the row `offset` positions away from `current_row` in this
    /// partition (negative offsets move toward the start).
    ///
    /// Returns `None` if `current_row` is not in this partition or the target
    /// position falls outside it.
    pub fn get_row_at_offset(&self, current_row: usize, offset: i32) -> Option<usize> {
        let current_pos = *self.row_positions.get(&current_row)?;
        // Compute in i64 with checked arithmetic: i32 addition overflows for
        // offsets near i32::MAX, and `rows.len() as i32` truncates very large
        // partitions.
        let target_pos = i64::try_from(current_pos)
            .ok()?
            .checked_add(i64::from(offset))?;
        let target_pos = usize::try_from(target_pos).ok()?;
        self.rows.get(target_pos).copied()
    }

    /// Returns the 0-based position of `row_index` in this partition, if present.
    pub fn get_position(&self, row_index: usize) -> Option<usize> {
        self.row_positions.get(&row_index).copied()
    }

    /// Returns the first row of the partition, or `None` if it is empty.
    pub fn first_row(&self) -> Option<usize> {
        self.rows.first().copied()
    }

    /// Returns the last row of the partition, or `None` if it is empty.
    pub fn last_row(&self) -> Option<usize> {
        self.rows.last().copied()
    }
}
96
97pub struct WindowContext {
99 source: Arc<DataView>,
101
102 partitions: BTreeMap<PartitionKey, OrderedPartition>,
104
105 row_to_partition: HashMap<usize, PartitionKey>,
107
108 spec: WindowSpec,
110}
111
112impl WindowContext {
113 pub fn new(
115 view: Arc<DataView>,
116 partition_by: Vec<String>,
117 order_by: Vec<OrderByItem>,
118 ) -> Result<Self> {
119 Self::new_with_spec(
120 view,
121 WindowSpec {
122 partition_by,
123 order_by,
124 frame: None,
125 },
126 )
127 }
128
129 pub fn new_with_spec(view: Arc<DataView>, spec: WindowSpec) -> Result<Self> {
131 let partition_by = spec.partition_by.clone();
132 let order_by = spec.order_by.clone();
133
134 if partition_by.is_empty() {
136 let single_partition = Self::create_single_partition(&view, &order_by)?;
137 let partition_key = PartitionKey::from_values(vec![]);
138
139 let mut row_to_partition = HashMap::new();
141 for &row_idx in &single_partition.rows {
142 row_to_partition.insert(row_idx, partition_key.clone());
143 }
144
145 let mut partitions = BTreeMap::new();
146 partitions.insert(partition_key, single_partition);
147
148 return Ok(Self {
149 source: view,
150 partitions,
151 row_to_partition,
152 spec,
153 });
154 }
155
156 let mut partition_map: BTreeMap<PartitionKey, Vec<usize>> = BTreeMap::new();
158 let mut row_to_partition = HashMap::new();
159
160 let source_table = view.source();
162 let partition_col_indices: Vec<usize> = partition_by
163 .iter()
164 .map(|col| {
165 source_table
166 .get_column_index(col)
167 .ok_or_else(|| anyhow!("Invalid partition column: {}", col))
168 })
169 .collect::<Result<Vec<_>>>()?;
170
171 for row_idx in view.get_visible_rows() {
173 let mut key_values = Vec::new();
175 for &col_idx in &partition_col_indices {
176 let value = source_table
177 .get_value(row_idx, col_idx)
178 .ok_or_else(|| anyhow!("Failed to get value for partition"))?
179 .clone();
180 key_values.push(value);
181 }
182 let key = PartitionKey::from_values(key_values);
183
184 partition_map.entry(key.clone()).or_default().push(row_idx);
186 row_to_partition.insert(row_idx, key);
187 }
188
189 let mut partitions = BTreeMap::new();
191 for (key, mut rows) in partition_map {
192 if !order_by.is_empty() {
194 Self::sort_rows(&mut rows, source_table, &order_by)?;
195 }
196
197 partitions.insert(key, OrderedPartition::new(rows));
198 }
199
200 Ok(Self {
201 source: view,
202 partitions,
203 row_to_partition,
204 spec,
205 })
206 }
207
208 fn create_single_partition(
210 view: &DataView,
211 order_by: &[OrderByItem],
212 ) -> Result<OrderedPartition> {
213 let mut rows: Vec<usize> = view.get_visible_rows();
214
215 if !order_by.is_empty() {
216 Self::sort_rows(&mut rows, view.source(), order_by)?;
217 }
218
219 Ok(OrderedPartition::new(rows))
220 }
221
222 fn sort_rows(rows: &mut Vec<usize>, table: &DataTable, order_by: &[OrderByItem]) -> Result<()> {
224 let sort_cols: Vec<(usize, bool)> = order_by
226 .iter()
227 .map(|col| {
228 let column_name = match &col.expr {
230 SqlExpression::Column(col_ref) => &col_ref.name,
231 _ => {
232 return Err(anyhow!("Window function ORDER BY only supports simple columns, not expressions"));
233 }
234 };
235 let idx = table
236 .get_column_index(column_name)
237 .ok_or_else(|| anyhow!("Invalid ORDER BY column: {}", column_name))?;
238 let ascending = matches!(col.direction, SortDirection::Asc);
239 Ok((idx, ascending))
240 })
241 .collect::<Result<Vec<_>>>()?;
242
243 rows.sort_by(|&a, &b| {
245 for &(col_idx, ascending) in &sort_cols {
246 let val_a = table.get_value(a, col_idx);
247 let val_b = table.get_value(b, col_idx);
248
249 match (val_a, val_b) {
250 (None, None) => continue,
251 (None, Some(_)) => {
252 return if ascending {
253 std::cmp::Ordering::Less
254 } else {
255 std::cmp::Ordering::Greater
256 }
257 }
258 (Some(_), None) => {
259 return if ascending {
260 std::cmp::Ordering::Greater
261 } else {
262 std::cmp::Ordering::Less
263 }
264 }
265 (Some(v_a), Some(v_b)) => {
266 let ord = v_a.partial_cmp(&v_b).unwrap_or(std::cmp::Ordering::Equal);
268 if ord != std::cmp::Ordering::Equal {
269 return if ascending { ord } else { ord.reverse() };
270 }
271 }
272 }
273 }
274 std::cmp::Ordering::Equal
275 });
276
277 Ok(())
278 }
279
280 pub fn get_offset_value(
282 &self,
283 current_row: usize,
284 offset: i32,
285 column: &str,
286 ) -> Option<DataValue> {
287 let partition_key = self.row_to_partition.get(¤t_row)?;
289 let partition = self.partitions.get(partition_key)?;
290
291 let target_row = partition.get_row_at_offset(current_row, offset)?;
293
294 let source_table = self.source.source();
296 let col_idx = source_table.get_column_index(column)?;
297 source_table.get_value(target_row, col_idx).cloned()
298 }
299
300 pub fn get_row_number(&self, row_index: usize) -> usize {
302 if let Some(partition_key) = self.row_to_partition.get(&row_index) {
303 if let Some(partition) = self.partitions.get(partition_key) {
304 if let Some(position) = partition.get_position(row_index) {
305 return position + 1; }
307 }
308 }
309 0 }
311
312 pub fn get_frame_first_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
314 let frame_rows = self.get_frame_rows(row_index);
315 if frame_rows.is_empty() {
316 return Some(DataValue::Null);
317 }
318
319 let source_table = self.source.source();
320 let col_idx = source_table.get_column_index(column)?;
321
322 let first_row = frame_rows[0];
324 source_table.get_value(first_row, col_idx).cloned()
325 }
326
327 pub fn get_frame_last_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
329 let frame_rows = self.get_frame_rows(row_index);
330 if frame_rows.is_empty() {
331 return Some(DataValue::Null);
332 }
333
334 let source_table = self.source.source();
335 let col_idx = source_table.get_column_index(column)?;
336
337 let last_row = frame_rows[frame_rows.len() - 1];
339 source_table.get_value(last_row, col_idx).cloned()
340 }
341
342 pub fn get_first_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
344 let partition_key = self.row_to_partition.get(&row_index)?;
345 let partition = self.partitions.get(partition_key)?;
346 let first_row = partition.first_row()?;
347
348 let source_table = self.source.source();
349 let col_idx = source_table.get_column_index(column)?;
350 source_table.get_value(first_row, col_idx).cloned()
351 }
352
353 pub fn get_last_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
355 let partition_key = self.row_to_partition.get(&row_index)?;
356 let partition = self.partitions.get(partition_key)?;
357 let last_row = partition.last_row()?;
358
359 let source_table = self.source.source();
360 let col_idx = source_table.get_column_index(column)?;
361 source_table.get_value(last_row, col_idx).cloned()
362 }
363
364 pub fn partition_count(&self) -> usize {
366 self.partitions.len()
367 }
368
369 pub fn has_partitions(&self) -> bool {
371 !self.spec.partition_by.is_empty()
372 }
373
374 pub fn has_frame(&self) -> bool {
376 self.spec.frame.is_some()
377 }
378
379 pub fn source(&self) -> &DataTable {
381 self.source.source()
382 }
383
384 pub fn get_frame_rows(&self, row_index: usize) -> Vec<usize> {
386 let partition_key = match self.row_to_partition.get(&row_index) {
388 Some(key) => key,
389 None => return vec![],
390 };
391
392 let partition = match self.partitions.get(partition_key) {
393 Some(p) => p,
394 None => return vec![],
395 };
396
397 let current_pos = match partition.get_position(row_index) {
399 Some(pos) => pos as i64,
400 None => return vec![],
401 };
402
403 let frame = match &self.spec.frame {
405 Some(f) => f,
406 None => return partition.rows.clone(),
407 };
408
409 let (start_pos, end_pos) = match frame.unit {
411 FrameUnit::Rows => {
412 let start =
414 self.calculate_frame_position(&frame.start, current_pos, partition.rows.len());
415 let end = match &frame.end {
416 Some(bound) => {
417 self.calculate_frame_position(bound, current_pos, partition.rows.len())
418 }
419 None => current_pos, };
421 (start, end)
422 }
423 FrameUnit::Range => {
424 let start =
427 self.calculate_frame_position(&frame.start, current_pos, partition.rows.len());
428 let end = match &frame.end {
429 Some(bound) => {
430 self.calculate_frame_position(bound, current_pos, partition.rows.len())
431 }
432 None => current_pos,
433 };
434 (start, end)
435 }
436 };
437
438 let mut frame_rows = Vec::new();
440 for i in start_pos..=end_pos {
441 if i >= 0 && (i as usize) < partition.rows.len() {
442 frame_rows.push(partition.rows[i as usize]);
443 }
444 }
445
446 frame_rows
447 }
448
449 fn calculate_frame_position(
451 &self,
452 bound: &FrameBound,
453 current_pos: i64,
454 partition_size: usize,
455 ) -> i64 {
456 match bound {
457 FrameBound::UnboundedPreceding => 0,
458 FrameBound::UnboundedFollowing => partition_size as i64 - 1,
459 FrameBound::CurrentRow => current_pos,
460 FrameBound::Preceding(n) => current_pos - n,
461 FrameBound::Following(n) => current_pos + n,
462 }
463 }
464
465 pub fn get_frame_sum(&self, row_index: usize, column: &str) -> Option<DataValue> {
467 let frame_rows = self.get_frame_rows(row_index);
468 if frame_rows.is_empty() {
469 return Some(DataValue::Null);
470 }
471
472 let source_table = self.source.source();
473 let col_idx = source_table.get_column_index(column)?;
474
475 let mut sum = 0.0;
476 let mut has_float = false;
477 let mut has_value = false;
478
479 for &row_idx in &frame_rows {
481 if let Some(value) = source_table.get_value(row_idx, col_idx) {
482 match value {
483 DataValue::Integer(i) => {
484 sum += *i as f64;
485 has_value = true;
486 }
487 DataValue::Float(f) => {
488 sum += f;
489 has_float = true;
490 has_value = true;
491 }
492 DataValue::Null => {
493 }
495 _ => {
496 return Some(DataValue::Null);
498 }
499 }
500 }
501 }
502
503 if !has_value {
504 return Some(DataValue::Null);
505 }
506
507 if !has_float && sum.fract() == 0.0 && sum >= i64::MIN as f64 && sum <= i64::MAX as f64 {
509 Some(DataValue::Integer(sum as i64))
510 } else {
511 Some(DataValue::Float(sum))
512 }
513 }
514
515 pub fn get_frame_count(&self, row_index: usize, column: Option<&str>) -> Option<DataValue> {
517 let frame_rows = self.get_frame_rows(row_index);
518 if frame_rows.is_empty() {
519 return Some(DataValue::Integer(0));
520 }
521
522 if let Some(col_name) = column {
523 let source_table = self.source.source();
525 let col_idx = source_table.get_column_index(col_name)?;
526
527 let count = frame_rows
528 .iter()
529 .filter_map(|&row_idx| source_table.get_value(row_idx, col_idx))
530 .filter(|v| !matches!(v, DataValue::Null))
531 .count();
532
533 Some(DataValue::Integer(count as i64))
534 } else {
535 Some(DataValue::Integer(frame_rows.len() as i64))
537 }
538 }
539
540 pub fn get_frame_avg(&self, row_index: usize, column: &str) -> Option<DataValue> {
542 let frame_rows = self.get_frame_rows(row_index);
543 if frame_rows.is_empty() {
544 return Some(DataValue::Null);
545 }
546
547 let source_table = self.source.source();
548 let col_idx = source_table.get_column_index(column)?;
549
550 let mut sum = 0.0;
551 let mut count = 0;
552
553 for &row_idx in &frame_rows {
555 if let Some(value) = source_table.get_value(row_idx, col_idx) {
556 match value {
557 DataValue::Integer(i) => {
558 sum += *i as f64;
559 count += 1;
560 }
561 DataValue::Float(f) => {
562 sum += f;
563 count += 1;
564 }
565 DataValue::Null => {
566 }
568 _ => {
569 return Some(DataValue::Null);
571 }
572 }
573 }
574 }
575
576 if count == 0 {
577 return Some(DataValue::Null);
578 }
579
580 Some(DataValue::Float(sum / count as f64))
581 }
582
583 pub fn get_frame_stddev(&self, row_index: usize, column: &str) -> Option<DataValue> {
585 let variance = self.get_frame_variance(row_index, column)?;
586 match variance {
587 DataValue::Float(v) => Some(DataValue::Float(v.sqrt())),
588 DataValue::Null => Some(DataValue::Null),
589 _ => Some(DataValue::Null),
590 }
591 }
592
593 pub fn get_frame_variance(&self, row_index: usize, column: &str) -> Option<DataValue> {
595 let frame_rows = self.get_frame_rows(row_index);
596 if frame_rows.is_empty() {
597 return Some(DataValue::Null);
598 }
599
600 let source_table = self.source.source();
601 let col_idx = source_table.get_column_index(column)?;
602
603 let mut values = Vec::new();
604
605 for &row_idx in &frame_rows {
607 if let Some(value) = source_table.get_value(row_idx, col_idx) {
608 match value {
609 DataValue::Integer(i) => values.push(*i as f64),
610 DataValue::Float(f) => values.push(*f),
611 DataValue::Null => {
612 }
614 _ => {
615 return Some(DataValue::Null);
617 }
618 }
619 }
620 }
621
622 if values.is_empty() {
623 return Some(DataValue::Null);
624 }
625
626 if values.len() == 1 {
627 return Some(DataValue::Float(0.0));
629 }
630
631 let mean = values.iter().sum::<f64>() / values.len() as f64;
633
634 let variance =
636 values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
637
638 Some(DataValue::Float(variance))
639 }
640
641 pub fn get_partition_sum(&self, row_index: usize, column: &str) -> Option<DataValue> {
643 let partition_key = self.row_to_partition.get(&row_index)?;
644 let partition = self.partitions.get(partition_key)?;
645 let source_table = self.source.source();
646 let col_idx = source_table.get_column_index(column)?;
647
648 let mut sum = 0.0;
649 let mut has_float = false;
650 let mut has_value = false;
651
652 for &row_idx in &partition.rows {
654 if let Some(value) = source_table.get_value(row_idx, col_idx) {
655 match value {
656 DataValue::Integer(i) => {
657 sum += *i as f64;
658 has_value = true;
659 }
660 DataValue::Float(f) => {
661 sum += f;
662 has_float = true;
663 has_value = true;
664 }
665 DataValue::Null => {
666 }
668 _ => {
669 return Some(DataValue::Null);
671 }
672 }
673 }
674 }
675
676 if !has_value {
677 return Some(DataValue::Null);
678 }
679
680 if !has_float && sum.fract() == 0.0 && sum >= i64::MIN as f64 && sum <= i64::MAX as f64 {
682 Some(DataValue::Integer(sum as i64))
683 } else {
684 Some(DataValue::Float(sum))
685 }
686 }
687
688 pub fn get_partition_count(&self, row_index: usize, column: Option<&str>) -> Option<DataValue> {
690 let partition_key = self.row_to_partition.get(&row_index)?;
691 let partition = self.partitions.get(partition_key)?;
692
693 if let Some(col_name) = column {
694 let source_table = self.source.source();
696 let col_idx = source_table.get_column_index(col_name)?;
697
698 let count = partition
699 .rows
700 .iter()
701 .filter_map(|&row_idx| source_table.get_value(row_idx, col_idx))
702 .filter(|v| !matches!(v, DataValue::Null))
703 .count();
704
705 Some(DataValue::Integer(count as i64))
706 } else {
707 Some(DataValue::Integer(partition.rows.len() as i64))
709 }
710 }
711}