sql_cli/sql/
window_context.rs

1//! Window function context for managing partitioned and ordered data views
2//!
3//! This module provides the WindowContext helper class that enables window functions
4//! like LAG, LEAD, ROW_NUMBER, etc. by managing partitions and ordering.
5
6use std::collections::{BTreeMap, HashMap};
7use std::sync::Arc;
8
9use anyhow::{anyhow, Result};
10
11use crate::data::data_view::DataView;
12use crate::data::datatable::{DataTable, DataValue};
13use crate::sql::parser::ast::{
14    FrameBound, FrameUnit, OrderByItem, SortDirection, SqlExpression, WindowSpec,
15};
16
17/// Key for identifying a partition (combination of partition column values)
18/// We use String representation for now since DataValue doesn't impl Ord
19#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
20struct PartitionKey(String);
21
22impl PartitionKey {
23    /// Create a partition key from data values
24    fn from_values(values: Vec<DataValue>) -> Self {
25        // Create a unique string representation
26        let key_parts: Vec<String> = values
27            .iter()
28            .map(|v| match v {
29                DataValue::String(s) => format!("S:{}", s),
30                DataValue::InternedString(s) => format!("S:{}", s),
31                DataValue::Integer(i) => format!("I:{}", i),
32                DataValue::Float(f) => format!("F:{}", f),
33                DataValue::Boolean(b) => format!("B:{}", b),
34                DataValue::DateTime(dt) => format!("D:{}", dt),
35                DataValue::Null => "N".to_string(),
36            })
37            .collect();
38        let key = key_parts.join("|");
39        PartitionKey(key)
40    }
41}
42
/// An ordered partition containing row indices
#[derive(Debug, Clone)]
pub struct OrderedPartition {
    /// Original row indices from DataView, in sorted order
    rows: Vec<usize>,

    /// Quick lookup: row_index -> position in partition
    row_positions: HashMap<usize, usize>,
}

impl OrderedPartition {
    /// Create a new ordered partition from row indices that are already in
    /// their final (sorted) order.
    fn new(rows: Vec<usize>) -> Self {
        // Build the reverse lookup once so position queries are O(1).
        let row_positions: HashMap<usize, usize> = rows
            .iter()
            .enumerate()
            .map(|(pos, &row_idx)| (row_idx, pos))
            .collect();

        Self {
            rows,
            row_positions,
        }
    }

    /// Navigate `offset` positions away from `current_row` inside this partition.
    ///
    /// A negative offset moves towards the start of the partition, a positive
    /// one towards the end. Returns `None` when `current_row` is not part of
    /// this partition or when the target position falls outside of it.
    pub fn get_row_at_offset(&self, current_row: usize, offset: i32) -> Option<usize> {
        let current_pos = *self.row_positions.get(&current_row)?;

        // Work in `usize` with checked arithmetic: casting the position to `i32`
        // and adding the offset can overflow for extreme offsets (e.g. `i32::MAX`)
        // and truncates positions beyond `i32::MAX`.
        let target_pos = if offset >= 0 {
            // Non-negative `i32` always fits in `usize` on 32/64-bit targets.
            current_pos.checked_add(offset as usize)?
        } else {
            current_pos.checked_sub(offset.unsigned_abs() as usize)?
        };

        self.rows.get(target_pos).copied()
    }

    /// Get position of row in this partition (0-based)
    pub fn get_position(&self, row_index: usize) -> Option<usize> {
        self.row_positions.get(&row_index).copied()
    }

    /// Get the first row index in this partition (`None` if the partition is empty)
    pub fn first_row(&self) -> Option<usize> {
        self.rows.first().copied()
    }

    /// Get the last row index in this partition (`None` if the partition is empty)
    pub fn last_row(&self) -> Option<usize> {
        self.rows.last().copied()
    }
}
96
97/// Context for evaluating window functions
98pub struct WindowContext {
99    /// Source data view
100    source: Arc<DataView>,
101
102    /// Partitions with their ordered rows
103    partitions: BTreeMap<PartitionKey, OrderedPartition>,
104
105    /// Mapping from row index to its partition key
106    row_to_partition: HashMap<usize, PartitionKey>,
107
108    /// Window specification
109    spec: WindowSpec,
110}
111
112impl WindowContext {
113    /// Create a new window context with partitioning and ordering
114    pub fn new(
115        view: Arc<DataView>,
116        partition_by: Vec<String>,
117        order_by: Vec<OrderByItem>,
118    ) -> Result<Self> {
119        Self::new_with_spec(
120            view,
121            WindowSpec {
122                partition_by,
123                order_by,
124                frame: None,
125            },
126        )
127    }
128
129    /// Create a new window context with a full window specification
130    pub fn new_with_spec(view: Arc<DataView>, spec: WindowSpec) -> Result<Self> {
131        let partition_by = spec.partition_by.clone();
132        let order_by = spec.order_by.clone();
133
134        // If no partition columns, treat entire view as single partition
135        if partition_by.is_empty() {
136            let single_partition = Self::create_single_partition(&view, &order_by)?;
137            let partition_key = PartitionKey::from_values(vec![]);
138
139            // Build row-to-partition mapping
140            let mut row_to_partition = HashMap::new();
141            for &row_idx in &single_partition.rows {
142                row_to_partition.insert(row_idx, partition_key.clone());
143            }
144
145            let mut partitions = BTreeMap::new();
146            partitions.insert(partition_key, single_partition);
147
148            return Ok(Self {
149                source: view,
150                partitions,
151                row_to_partition,
152                spec,
153            });
154        }
155
156        // Create partitions based on partition_by columns
157        let mut partition_map: BTreeMap<PartitionKey, Vec<usize>> = BTreeMap::new();
158        let mut row_to_partition = HashMap::new();
159
160        // Get column indices for partition columns
161        let source_table = view.source();
162        let partition_col_indices: Vec<usize> = partition_by
163            .iter()
164            .map(|col| {
165                source_table
166                    .get_column_index(col)
167                    .ok_or_else(|| anyhow!("Invalid partition column: {}", col))
168            })
169            .collect::<Result<Vec<_>>>()?;
170
171        // Group rows by partition key
172        for row_idx in view.get_visible_rows() {
173            // Build partition key from row values
174            let mut key_values = Vec::new();
175            for &col_idx in &partition_col_indices {
176                let value = source_table
177                    .get_value(row_idx, col_idx)
178                    .ok_or_else(|| anyhow!("Failed to get value for partition"))?
179                    .clone();
180                key_values.push(value);
181            }
182            let key = PartitionKey::from_values(key_values);
183
184            // Add row to partition
185            partition_map.entry(key.clone()).or_default().push(row_idx);
186            row_to_partition.insert(row_idx, key);
187        }
188
189        // Sort each partition according to ORDER BY
190        let mut partitions = BTreeMap::new();
191        for (key, mut rows) in partition_map {
192            // Sort rows within partition
193            if !order_by.is_empty() {
194                Self::sort_rows(&mut rows, source_table, &order_by)?;
195            }
196
197            partitions.insert(key, OrderedPartition::new(rows));
198        }
199
200        Ok(Self {
201            source: view,
202            partitions,
203            row_to_partition,
204            spec,
205        })
206    }
207
208    /// Create a single partition from the entire view
209    fn create_single_partition(
210        view: &DataView,
211        order_by: &[OrderByItem],
212    ) -> Result<OrderedPartition> {
213        let mut rows: Vec<usize> = view.get_visible_rows();
214
215        if !order_by.is_empty() {
216            Self::sort_rows(&mut rows, view.source(), order_by)?;
217        }
218
219        Ok(OrderedPartition::new(rows))
220    }
221
222    /// Sort row indices according to ORDER BY specification
223    fn sort_rows(rows: &mut Vec<usize>, table: &DataTable, order_by: &[OrderByItem]) -> Result<()> {
224        // Get column indices for ORDER BY columns
225        let sort_cols: Vec<(usize, bool)> = order_by
226            .iter()
227            .map(|col| {
228                // Extract column name from expression (currently only supports simple columns)
229                let column_name = match &col.expr {
230                    SqlExpression::Column(col_ref) => &col_ref.name,
231                    _ => {
232                        return Err(anyhow!("Window function ORDER BY only supports simple columns, not expressions"));
233                    }
234                };
235                let idx = table
236                    .get_column_index(column_name)
237                    .ok_or_else(|| anyhow!("Invalid ORDER BY column: {}", column_name))?;
238                let ascending = matches!(col.direction, SortDirection::Asc);
239                Ok((idx, ascending))
240            })
241            .collect::<Result<Vec<_>>>()?;
242
243        // Sort rows based on column values
244        rows.sort_by(|&a, &b| {
245            for &(col_idx, ascending) in &sort_cols {
246                let val_a = table.get_value(a, col_idx);
247                let val_b = table.get_value(b, col_idx);
248
249                match (val_a, val_b) {
250                    (None, None) => continue,
251                    (None, Some(_)) => {
252                        return if ascending {
253                            std::cmp::Ordering::Less
254                        } else {
255                            std::cmp::Ordering::Greater
256                        }
257                    }
258                    (Some(_), None) => {
259                        return if ascending {
260                            std::cmp::Ordering::Greater
261                        } else {
262                            std::cmp::Ordering::Less
263                        }
264                    }
265                    (Some(v_a), Some(v_b)) => {
266                        // DataValue only implements PartialOrd, not Ord
267                        let ord = v_a.partial_cmp(&v_b).unwrap_or(std::cmp::Ordering::Equal);
268                        if ord != std::cmp::Ordering::Equal {
269                            return if ascending { ord } else { ord.reverse() };
270                        }
271                    }
272                }
273            }
274            std::cmp::Ordering::Equal
275        });
276
277        Ok(())
278    }
279
280    /// Get value at offset from current row (for LAG/LEAD)
281    pub fn get_offset_value(
282        &self,
283        current_row: usize,
284        offset: i32,
285        column: &str,
286    ) -> Option<DataValue> {
287        // Find which partition this row belongs to
288        let partition_key = self.row_to_partition.get(&current_row)?;
289        let partition = self.partitions.get(partition_key)?;
290
291        // Navigate to target row
292        let target_row = partition.get_row_at_offset(current_row, offset)?;
293
294        // Get column value from target row
295        let source_table = self.source.source();
296        let col_idx = source_table.get_column_index(column)?;
297        source_table.get_value(target_row, col_idx).cloned()
298    }
299
300    /// Get row number within partition (1-based)
301    pub fn get_row_number(&self, row_index: usize) -> usize {
302        if let Some(partition_key) = self.row_to_partition.get(&row_index) {
303            if let Some(partition) = self.partitions.get(partition_key) {
304                if let Some(position) = partition.get_position(row_index) {
305                    return position + 1; // Convert to 1-based
306                }
307            }
308        }
309        0 // Should not happen for valid row
310    }
311
312    /// Get first value in frame
313    pub fn get_frame_first_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
314        let frame_rows = self.get_frame_rows(row_index);
315        if frame_rows.is_empty() {
316            return Some(DataValue::Null);
317        }
318
319        let source_table = self.source.source();
320        let col_idx = source_table.get_column_index(column)?;
321
322        // Get the first row in the frame
323        let first_row = frame_rows[0];
324        source_table.get_value(first_row, col_idx).cloned()
325    }
326
327    /// Get last value in frame
328    pub fn get_frame_last_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
329        let frame_rows = self.get_frame_rows(row_index);
330        if frame_rows.is_empty() {
331            return Some(DataValue::Null);
332        }
333
334        let source_table = self.source.source();
335        let col_idx = source_table.get_column_index(column)?;
336
337        // Get the last row in the frame
338        let last_row = frame_rows[frame_rows.len() - 1];
339        source_table.get_value(last_row, col_idx).cloned()
340    }
341
342    /// Get first value in partition
343    pub fn get_first_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
344        let partition_key = self.row_to_partition.get(&row_index)?;
345        let partition = self.partitions.get(partition_key)?;
346        let first_row = partition.first_row()?;
347
348        let source_table = self.source.source();
349        let col_idx = source_table.get_column_index(column)?;
350        source_table.get_value(first_row, col_idx).cloned()
351    }
352
353    /// Get last value in partition
354    pub fn get_last_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
355        let partition_key = self.row_to_partition.get(&row_index)?;
356        let partition = self.partitions.get(partition_key)?;
357        let last_row = partition.last_row()?;
358
359        let source_table = self.source.source();
360        let col_idx = source_table.get_column_index(column)?;
361        source_table.get_value(last_row, col_idx).cloned()
362    }
363
364    /// Get the number of partitions
365    pub fn partition_count(&self) -> usize {
366        self.partitions.len()
367    }
368
369    /// Check if context has partitions (vs single window)
370    pub fn has_partitions(&self) -> bool {
371        !self.spec.partition_by.is_empty()
372    }
373
374    /// Check if context has a window frame specification
375    pub fn has_frame(&self) -> bool {
376        self.spec.frame.is_some()
377    }
378
379    /// Get the source DataView
380    pub fn source(&self) -> &DataTable {
381        self.source.source()
382    }
383
384    /// Get row indices within the window frame for a given row
385    pub fn get_frame_rows(&self, row_index: usize) -> Vec<usize> {
386        // Find which partition this row belongs to
387        let partition_key = match self.row_to_partition.get(&row_index) {
388            Some(key) => key,
389            None => return vec![],
390        };
391
392        let partition = match self.partitions.get(partition_key) {
393            Some(p) => p,
394            None => return vec![],
395        };
396
397        // Get current row's position in partition
398        let current_pos = match partition.get_position(row_index) {
399            Some(pos) => pos as i64,
400            None => return vec![],
401        };
402
403        // If no frame specified, return entire partition (default behavior)
404        let frame = match &self.spec.frame {
405            Some(f) => f,
406            None => return partition.rows.clone(),
407        };
408
409        // Calculate frame bounds
410        let (start_pos, end_pos) = match frame.unit {
411            FrameUnit::Rows => {
412                // ROWS frame - based on physical row positions
413                let start =
414                    self.calculate_frame_position(&frame.start, current_pos, partition.rows.len());
415                let end = match &frame.end {
416                    Some(bound) => {
417                        self.calculate_frame_position(bound, current_pos, partition.rows.len())
418                    }
419                    None => current_pos, // Default to CURRENT ROW
420                };
421                (start, end)
422            }
423            FrameUnit::Range => {
424                // RANGE frame - based on ORDER BY values (not yet fully implemented)
425                // For now, treat like ROWS
426                let start =
427                    self.calculate_frame_position(&frame.start, current_pos, partition.rows.len());
428                let end = match &frame.end {
429                    Some(bound) => {
430                        self.calculate_frame_position(bound, current_pos, partition.rows.len())
431                    }
432                    None => current_pos,
433                };
434                (start, end)
435            }
436        };
437
438        // Collect rows within frame bounds
439        let mut frame_rows = Vec::new();
440        for i in start_pos..=end_pos {
441            if i >= 0 && (i as usize) < partition.rows.len() {
442                frame_rows.push(partition.rows[i as usize]);
443            }
444        }
445
446        frame_rows
447    }
448
449    /// Calculate absolute position from frame bound
450    fn calculate_frame_position(
451        &self,
452        bound: &FrameBound,
453        current_pos: i64,
454        partition_size: usize,
455    ) -> i64 {
456        match bound {
457            FrameBound::UnboundedPreceding => 0,
458            FrameBound::UnboundedFollowing => partition_size as i64 - 1,
459            FrameBound::CurrentRow => current_pos,
460            FrameBound::Preceding(n) => current_pos - n,
461            FrameBound::Following(n) => current_pos + n,
462        }
463    }
464
465    /// Calculate sum of a column within the window frame for the given row
466    pub fn get_frame_sum(&self, row_index: usize, column: &str) -> Option<DataValue> {
467        let frame_rows = self.get_frame_rows(row_index);
468        if frame_rows.is_empty() {
469            return Some(DataValue::Null);
470        }
471
472        let source_table = self.source.source();
473        let col_idx = source_table.get_column_index(column)?;
474
475        let mut sum = 0.0;
476        let mut has_float = false;
477        let mut has_value = false;
478
479        // Sum all values in the frame
480        for &row_idx in &frame_rows {
481            if let Some(value) = source_table.get_value(row_idx, col_idx) {
482                match value {
483                    DataValue::Integer(i) => {
484                        sum += *i as f64;
485                        has_value = true;
486                    }
487                    DataValue::Float(f) => {
488                        sum += f;
489                        has_float = true;
490                        has_value = true;
491                    }
492                    DataValue::Null => {
493                        // Skip NULL values
494                    }
495                    _ => {
496                        // Non-numeric values - return NULL
497                        return Some(DataValue::Null);
498                    }
499                }
500            }
501        }
502
503        if !has_value {
504            return Some(DataValue::Null);
505        }
506
507        // Return as integer if all values were integers and sum is whole
508        if !has_float && sum.fract() == 0.0 && sum >= i64::MIN as f64 && sum <= i64::MAX as f64 {
509            Some(DataValue::Integer(sum as i64))
510        } else {
511            Some(DataValue::Float(sum))
512        }
513    }
514
515    /// Calculate count within the window frame
516    pub fn get_frame_count(&self, row_index: usize, column: Option<&str>) -> Option<DataValue> {
517        let frame_rows = self.get_frame_rows(row_index);
518        if frame_rows.is_empty() {
519            return Some(DataValue::Integer(0));
520        }
521
522        if let Some(col_name) = column {
523            // COUNT(column) - count non-null values in frame
524            let source_table = self.source.source();
525            let col_idx = source_table.get_column_index(col_name)?;
526
527            let count = frame_rows
528                .iter()
529                .filter_map(|&row_idx| source_table.get_value(row_idx, col_idx))
530                .filter(|v| !matches!(v, DataValue::Null))
531                .count();
532
533            Some(DataValue::Integer(count as i64))
534        } else {
535            // COUNT(*) - count all rows in frame
536            Some(DataValue::Integer(frame_rows.len() as i64))
537        }
538    }
539
540    /// Calculate average of a column within the window frame
541    pub fn get_frame_avg(&self, row_index: usize, column: &str) -> Option<DataValue> {
542        let frame_rows = self.get_frame_rows(row_index);
543        if frame_rows.is_empty() {
544            return Some(DataValue::Null);
545        }
546
547        let source_table = self.source.source();
548        let col_idx = source_table.get_column_index(column)?;
549
550        let mut sum = 0.0;
551        let mut count = 0;
552
553        // Sum all non-null values in the frame
554        for &row_idx in &frame_rows {
555            if let Some(value) = source_table.get_value(row_idx, col_idx) {
556                match value {
557                    DataValue::Integer(i) => {
558                        sum += *i as f64;
559                        count += 1;
560                    }
561                    DataValue::Float(f) => {
562                        sum += f;
563                        count += 1;
564                    }
565                    DataValue::Null => {
566                        // Skip NULL values
567                    }
568                    _ => {
569                        // Non-numeric values - return NULL
570                        return Some(DataValue::Null);
571                    }
572                }
573            }
574        }
575
576        if count == 0 {
577            return Some(DataValue::Null);
578        }
579
580        Some(DataValue::Float(sum / count as f64))
581    }
582
583    /// Calculate standard deviation within the window frame (sample stddev)
584    pub fn get_frame_stddev(&self, row_index: usize, column: &str) -> Option<DataValue> {
585        let variance = self.get_frame_variance(row_index, column)?;
586        match variance {
587            DataValue::Float(v) => Some(DataValue::Float(v.sqrt())),
588            DataValue::Null => Some(DataValue::Null),
589            _ => Some(DataValue::Null),
590        }
591    }
592
593    /// Calculate variance within the window frame (sample variance with n-1)
594    pub fn get_frame_variance(&self, row_index: usize, column: &str) -> Option<DataValue> {
595        let frame_rows = self.get_frame_rows(row_index);
596        if frame_rows.is_empty() {
597            return Some(DataValue::Null);
598        }
599
600        let source_table = self.source.source();
601        let col_idx = source_table.get_column_index(column)?;
602
603        let mut values = Vec::new();
604
605        // Collect all non-null values in the frame
606        for &row_idx in &frame_rows {
607            if let Some(value) = source_table.get_value(row_idx, col_idx) {
608                match value {
609                    DataValue::Integer(i) => values.push(*i as f64),
610                    DataValue::Float(f) => values.push(*f),
611                    DataValue::Null => {
612                        // Skip NULL values
613                    }
614                    _ => {
615                        // Non-numeric values - return NULL
616                        return Some(DataValue::Null);
617                    }
618                }
619            }
620        }
621
622        if values.is_empty() {
623            return Some(DataValue::Null);
624        }
625
626        if values.len() == 1 {
627            // Variance of single value is 0
628            return Some(DataValue::Float(0.0));
629        }
630
631        // Calculate mean
632        let mean = values.iter().sum::<f64>() / values.len() as f64;
633
634        // Calculate sample variance (n-1 denominator)
635        let variance =
636            values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
637
638        Some(DataValue::Float(variance))
639    }
640
641    /// Calculate sum of a column over the partition containing the given row
642    pub fn get_partition_sum(&self, row_index: usize, column: &str) -> Option<DataValue> {
643        let partition_key = self.row_to_partition.get(&row_index)?;
644        let partition = self.partitions.get(partition_key)?;
645        let source_table = self.source.source();
646        let col_idx = source_table.get_column_index(column)?;
647
648        let mut sum = 0.0;
649        let mut has_float = false;
650        let mut has_value = false;
651
652        // Sum all values in the partition
653        for &row_idx in &partition.rows {
654            if let Some(value) = source_table.get_value(row_idx, col_idx) {
655                match value {
656                    DataValue::Integer(i) => {
657                        sum += *i as f64;
658                        has_value = true;
659                    }
660                    DataValue::Float(f) => {
661                        sum += f;
662                        has_float = true;
663                        has_value = true;
664                    }
665                    DataValue::Null => {
666                        // Skip NULL values
667                    }
668                    _ => {
669                        // Non-numeric values - return NULL
670                        return Some(DataValue::Null);
671                    }
672                }
673            }
674        }
675
676        if !has_value {
677            return Some(DataValue::Null);
678        }
679
680        // Return as integer if all values were integers and sum is whole
681        if !has_float && sum.fract() == 0.0 && sum >= i64::MIN as f64 && sum <= i64::MAX as f64 {
682            Some(DataValue::Integer(sum as i64))
683        } else {
684            Some(DataValue::Float(sum))
685        }
686    }
687
688    /// Calculate count of non-null values in a column over the partition
689    pub fn get_partition_count(&self, row_index: usize, column: Option<&str>) -> Option<DataValue> {
690        let partition_key = self.row_to_partition.get(&row_index)?;
691        let partition = self.partitions.get(partition_key)?;
692
693        if let Some(col_name) = column {
694            // COUNT(column) - count non-null values
695            let source_table = self.source.source();
696            let col_idx = source_table.get_column_index(col_name)?;
697
698            let count = partition
699                .rows
700                .iter()
701                .filter_map(|&row_idx| source_table.get_value(row_idx, col_idx))
702                .filter(|v| !matches!(v, DataValue::Null))
703                .count();
704
705            Some(DataValue::Integer(count as i64))
706        } else {
707            // COUNT(*) - count all rows in partition
708            Some(DataValue::Integer(partition.rows.len() as i64))
709        }
710    }
711}