sql_cli/sql/
window_context.rs

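//! Window function context for managing partitioned and ordered data views
//!
//! This module provides the `WindowContext` helper type, which supports window
//! functions such as LAG, LEAD, ROW_NUMBER, FIRST_VALUE/LAST_VALUE and framed
//! aggregates (SUM, COUNT, AVG, VARIANCE, STDDEV) by grouping rows into
//! partitions, ordering each partition, and resolving window frames.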
5
6use std::collections::{BTreeMap, HashMap};
7use std::sync::Arc;
8
9use anyhow::{anyhow, Result};
10
11use crate::data::data_view::DataView;
12use crate::data::datatable::{DataTable, DataValue};
13use crate::sql::parser::ast::{
14    FrameBound, FrameUnit, OrderByColumn, SortDirection, WindowFrame, WindowSpec,
15};
16
17/// Key for identifying a partition (combination of partition column values)
18/// We use String representation for now since DataValue doesn't impl Ord
19#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
20struct PartitionKey(String);
21
22impl PartitionKey {
23    /// Create a partition key from data values
24    fn from_values(values: Vec<DataValue>) -> Self {
25        // Create a unique string representation
26        let key_parts: Vec<String> = values
27            .iter()
28            .map(|v| match v {
29                DataValue::String(s) => format!("S:{}", s),
30                DataValue::InternedString(s) => format!("S:{}", s),
31                DataValue::Integer(i) => format!("I:{}", i),
32                DataValue::Float(f) => format!("F:{}", f),
33                DataValue::Boolean(b) => format!("B:{}", b),
34                DataValue::DateTime(dt) => format!("D:{}", dt),
35                DataValue::Null => "N".to_string(),
36            })
37            .collect();
38        let key = key_parts.join("|");
39        PartitionKey(key)
40    }
41}
42
/// The rows of one partition in window (ORDER BY) order, plus a reverse index
/// for constant-time position lookups.
#[derive(Debug, Clone)]
pub struct OrderedPartition {
    /// Original row indices from the DataView, in window order
    rows: Vec<usize>,

    /// Reverse lookup: row index -> position within `rows`
    row_positions: HashMap<usize, usize>,
}

impl OrderedPartition {
    /// Create a partition from row indices that are already in window order.
    fn new(rows: Vec<usize>) -> Self {
        let row_positions: HashMap<usize, usize> = rows
            .iter()
            .enumerate()
            .map(|(pos, &row_idx)| (row_idx, pos))
            .collect();

        Self {
            rows,
            row_positions,
        }
    }

    /// Row index located `offset` positions away from `current_row`
    /// (negative offsets look backwards, as for LAG; positive ones forwards,
    /// as for LEAD).
    ///
    /// Returns `None` when `current_row` is not in this partition or when the
    /// target position falls outside it.
    pub fn get_row_at_offset(&self, current_row: usize, offset: i32) -> Option<usize> {
        let current_pos = self.get_position(current_row)?;

        // Stay in unsigned arithmetic with checked operations. The previous
        // `usize as i32` cast truncated large positions, and `i32 + offset`
        // overflowed (panicking in debug builds) for offsets near i32::MAX.
        let target_pos = if offset >= 0 {
            current_pos.checked_add(offset as usize)?
        } else {
            current_pos.checked_sub(offset.unsigned_abs() as usize)?
        };

        self.rows.get(target_pos).copied()
    }

    /// Position of `row_index` within this partition (0-based), if present.
    pub fn get_position(&self, row_index: usize) -> Option<usize> {
        self.row_positions.get(&row_index).copied()
    }

    /// First row index in this partition, or `None` if it is empty.
    pub fn first_row(&self) -> Option<usize> {
        self.rows.first().copied()
    }

    /// Last row index in this partition, or `None` if it is empty.
    pub fn last_row(&self) -> Option<usize> {
        self.rows.last().copied()
    }
}
96
97/// Context for evaluating window functions
98pub struct WindowContext {
99    /// Source data view
100    source: Arc<DataView>,
101
102    /// Partitions with their ordered rows
103    partitions: BTreeMap<PartitionKey, OrderedPartition>,
104
105    /// Mapping from row index to its partition key
106    row_to_partition: HashMap<usize, PartitionKey>,
107
108    /// Window specification
109    spec: WindowSpec,
110}
111
112impl WindowContext {
113    /// Create a new window context with partitioning and ordering
114    pub fn new(
115        view: Arc<DataView>,
116        partition_by: Vec<String>,
117        order_by: Vec<OrderByColumn>,
118    ) -> Result<Self> {
119        Self::new_with_spec(
120            view,
121            WindowSpec {
122                partition_by,
123                order_by,
124                frame: None,
125            },
126        )
127    }
128
129    /// Create a new window context with a full window specification
130    pub fn new_with_spec(view: Arc<DataView>, spec: WindowSpec) -> Result<Self> {
131        let partition_by = spec.partition_by.clone();
132        let order_by = spec.order_by.clone();
133
134        // If no partition columns, treat entire view as single partition
135        if partition_by.is_empty() {
136            let single_partition = Self::create_single_partition(&view, &order_by)?;
137            let partition_key = PartitionKey::from_values(vec![]);
138
139            // Build row-to-partition mapping
140            let mut row_to_partition = HashMap::new();
141            for &row_idx in &single_partition.rows {
142                row_to_partition.insert(row_idx, partition_key.clone());
143            }
144
145            let mut partitions = BTreeMap::new();
146            partitions.insert(partition_key, single_partition);
147
148            return Ok(Self {
149                source: view,
150                partitions,
151                row_to_partition,
152                spec,
153            });
154        }
155
156        // Create partitions based on partition_by columns
157        let mut partition_map: BTreeMap<PartitionKey, Vec<usize>> = BTreeMap::new();
158        let mut row_to_partition = HashMap::new();
159
160        // Get column indices for partition columns
161        let source_table = view.source();
162        let partition_col_indices: Vec<usize> = partition_by
163            .iter()
164            .map(|col| {
165                source_table
166                    .get_column_index(col)
167                    .ok_or_else(|| anyhow!("Invalid partition column: {}", col))
168            })
169            .collect::<Result<Vec<_>>>()?;
170
171        // Group rows by partition key
172        for row_idx in view.get_visible_rows() {
173            // Build partition key from row values
174            let mut key_values = Vec::new();
175            for &col_idx in &partition_col_indices {
176                let value = source_table
177                    .get_value(row_idx, col_idx)
178                    .ok_or_else(|| anyhow!("Failed to get value for partition"))?
179                    .clone();
180                key_values.push(value);
181            }
182            let key = PartitionKey::from_values(key_values);
183
184            // Add row to partition
185            partition_map.entry(key.clone()).or_default().push(row_idx);
186            row_to_partition.insert(row_idx, key);
187        }
188
189        // Sort each partition according to ORDER BY
190        let mut partitions = BTreeMap::new();
191        for (key, mut rows) in partition_map {
192            // Sort rows within partition
193            if !order_by.is_empty() {
194                Self::sort_rows(&mut rows, source_table, &order_by)?;
195            }
196
197            partitions.insert(key, OrderedPartition::new(rows));
198        }
199
200        Ok(Self {
201            source: view,
202            partitions,
203            row_to_partition,
204            spec,
205        })
206    }
207
208    /// Create a single partition from the entire view
209    fn create_single_partition(
210        view: &DataView,
211        order_by: &[OrderByColumn],
212    ) -> Result<OrderedPartition> {
213        let mut rows: Vec<usize> = view.get_visible_rows();
214
215        if !order_by.is_empty() {
216            Self::sort_rows(&mut rows, view.source(), order_by)?;
217        }
218
219        Ok(OrderedPartition::new(rows))
220    }
221
222    /// Sort row indices according to ORDER BY specification
223    fn sort_rows(
224        rows: &mut Vec<usize>,
225        table: &DataTable,
226        order_by: &[OrderByColumn],
227    ) -> Result<()> {
228        // Get column indices for ORDER BY columns
229        let sort_cols: Vec<(usize, bool)> = order_by
230            .iter()
231            .map(|col| {
232                let idx = table
233                    .get_column_index(&col.column)
234                    .ok_or_else(|| anyhow!("Invalid ORDER BY column: {}", col.column))?;
235                let ascending = matches!(col.direction, SortDirection::Asc);
236                Ok((idx, ascending))
237            })
238            .collect::<Result<Vec<_>>>()?;
239
240        // Sort rows based on column values
241        rows.sort_by(|&a, &b| {
242            for &(col_idx, ascending) in &sort_cols {
243                let val_a = table.get_value(a, col_idx);
244                let val_b = table.get_value(b, col_idx);
245
246                match (val_a, val_b) {
247                    (None, None) => continue,
248                    (None, Some(_)) => {
249                        return if ascending {
250                            std::cmp::Ordering::Less
251                        } else {
252                            std::cmp::Ordering::Greater
253                        }
254                    }
255                    (Some(_), None) => {
256                        return if ascending {
257                            std::cmp::Ordering::Greater
258                        } else {
259                            std::cmp::Ordering::Less
260                        }
261                    }
262                    (Some(v_a), Some(v_b)) => {
263                        // DataValue only implements PartialOrd, not Ord
264                        let ord = v_a.partial_cmp(&v_b).unwrap_or(std::cmp::Ordering::Equal);
265                        if ord != std::cmp::Ordering::Equal {
266                            return if ascending { ord } else { ord.reverse() };
267                        }
268                    }
269                }
270            }
271            std::cmp::Ordering::Equal
272        });
273
274        Ok(())
275    }
276
277    /// Get value at offset from current row (for LAG/LEAD)
278    pub fn get_offset_value(
279        &self,
280        current_row: usize,
281        offset: i32,
282        column: &str,
283    ) -> Option<DataValue> {
284        // Find which partition this row belongs to
285        let partition_key = self.row_to_partition.get(&current_row)?;
286        let partition = self.partitions.get(partition_key)?;
287
288        // Navigate to target row
289        let target_row = partition.get_row_at_offset(current_row, offset)?;
290
291        // Get column value from target row
292        let source_table = self.source.source();
293        let col_idx = source_table.get_column_index(column)?;
294        source_table.get_value(target_row, col_idx).cloned()
295    }
296
297    /// Get row number within partition (1-based)
298    pub fn get_row_number(&self, row_index: usize) -> usize {
299        if let Some(partition_key) = self.row_to_partition.get(&row_index) {
300            if let Some(partition) = self.partitions.get(partition_key) {
301                if let Some(position) = partition.get_position(row_index) {
302                    return position + 1; // Convert to 1-based
303                }
304            }
305        }
306        0 // Should not happen for valid row
307    }
308
309    /// Get first value in frame
310    pub fn get_frame_first_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
311        let frame_rows = self.get_frame_rows(row_index);
312        if frame_rows.is_empty() {
313            return Some(DataValue::Null);
314        }
315
316        let source_table = self.source.source();
317        let col_idx = source_table.get_column_index(column)?;
318
319        // Get the first row in the frame
320        let first_row = frame_rows[0];
321        source_table.get_value(first_row, col_idx).cloned()
322    }
323
324    /// Get last value in frame
325    pub fn get_frame_last_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
326        let frame_rows = self.get_frame_rows(row_index);
327        if frame_rows.is_empty() {
328            return Some(DataValue::Null);
329        }
330
331        let source_table = self.source.source();
332        let col_idx = source_table.get_column_index(column)?;
333
334        // Get the last row in the frame
335        let last_row = frame_rows[frame_rows.len() - 1];
336        source_table.get_value(last_row, col_idx).cloned()
337    }
338
339    /// Get first value in partition
340    pub fn get_first_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
341        let partition_key = self.row_to_partition.get(&row_index)?;
342        let partition = self.partitions.get(partition_key)?;
343        let first_row = partition.first_row()?;
344
345        let source_table = self.source.source();
346        let col_idx = source_table.get_column_index(column)?;
347        source_table.get_value(first_row, col_idx).cloned()
348    }
349
350    /// Get last value in partition
351    pub fn get_last_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
352        let partition_key = self.row_to_partition.get(&row_index)?;
353        let partition = self.partitions.get(partition_key)?;
354        let last_row = partition.last_row()?;
355
356        let source_table = self.source.source();
357        let col_idx = source_table.get_column_index(column)?;
358        source_table.get_value(last_row, col_idx).cloned()
359    }
360
361    /// Get the number of partitions
362    pub fn partition_count(&self) -> usize {
363        self.partitions.len()
364    }
365
366    /// Check if context has partitions (vs single window)
367    pub fn has_partitions(&self) -> bool {
368        !self.spec.partition_by.is_empty()
369    }
370
371    /// Check if context has a window frame specification
372    pub fn has_frame(&self) -> bool {
373        self.spec.frame.is_some()
374    }
375
376    /// Get the source DataView
377    pub fn source(&self) -> &DataTable {
378        self.source.source()
379    }
380
381    /// Get row indices within the window frame for a given row
382    pub fn get_frame_rows(&self, row_index: usize) -> Vec<usize> {
383        // Find which partition this row belongs to
384        let partition_key = match self.row_to_partition.get(&row_index) {
385            Some(key) => key,
386            None => return vec![],
387        };
388
389        let partition = match self.partitions.get(partition_key) {
390            Some(p) => p,
391            None => return vec![],
392        };
393
394        // Get current row's position in partition
395        let current_pos = match partition.get_position(row_index) {
396            Some(pos) => pos as i64,
397            None => return vec![],
398        };
399
400        // If no frame specified, return entire partition (default behavior)
401        let frame = match &self.spec.frame {
402            Some(f) => f,
403            None => return partition.rows.clone(),
404        };
405
406        // Calculate frame bounds
407        let (start_pos, end_pos) = match frame.unit {
408            FrameUnit::Rows => {
409                // ROWS frame - based on physical row positions
410                let start =
411                    self.calculate_frame_position(&frame.start, current_pos, partition.rows.len());
412                let end = match &frame.end {
413                    Some(bound) => {
414                        self.calculate_frame_position(bound, current_pos, partition.rows.len())
415                    }
416                    None => current_pos, // Default to CURRENT ROW
417                };
418                (start, end)
419            }
420            FrameUnit::Range => {
421                // RANGE frame - based on ORDER BY values (not yet fully implemented)
422                // For now, treat like ROWS
423                let start =
424                    self.calculate_frame_position(&frame.start, current_pos, partition.rows.len());
425                let end = match &frame.end {
426                    Some(bound) => {
427                        self.calculate_frame_position(bound, current_pos, partition.rows.len())
428                    }
429                    None => current_pos,
430                };
431                (start, end)
432            }
433        };
434
435        // Collect rows within frame bounds
436        let mut frame_rows = Vec::new();
437        for i in start_pos..=end_pos {
438            if i >= 0 && (i as usize) < partition.rows.len() {
439                frame_rows.push(partition.rows[i as usize]);
440            }
441        }
442
443        frame_rows
444    }
445
446    /// Calculate absolute position from frame bound
447    fn calculate_frame_position(
448        &self,
449        bound: &FrameBound,
450        current_pos: i64,
451        partition_size: usize,
452    ) -> i64 {
453        match bound {
454            FrameBound::UnboundedPreceding => 0,
455            FrameBound::UnboundedFollowing => partition_size as i64 - 1,
456            FrameBound::CurrentRow => current_pos,
457            FrameBound::Preceding(n) => current_pos - n,
458            FrameBound::Following(n) => current_pos + n,
459        }
460    }
461
462    /// Calculate sum of a column within the window frame for the given row
463    pub fn get_frame_sum(&self, row_index: usize, column: &str) -> Option<DataValue> {
464        let frame_rows = self.get_frame_rows(row_index);
465        if frame_rows.is_empty() {
466            return Some(DataValue::Null);
467        }
468
469        let source_table = self.source.source();
470        let col_idx = source_table.get_column_index(column)?;
471
472        let mut sum = 0.0;
473        let mut has_float = false;
474        let mut has_value = false;
475
476        // Sum all values in the frame
477        for &row_idx in &frame_rows {
478            if let Some(value) = source_table.get_value(row_idx, col_idx) {
479                match value {
480                    DataValue::Integer(i) => {
481                        sum += *i as f64;
482                        has_value = true;
483                    }
484                    DataValue::Float(f) => {
485                        sum += f;
486                        has_float = true;
487                        has_value = true;
488                    }
489                    DataValue::Null => {
490                        // Skip NULL values
491                    }
492                    _ => {
493                        // Non-numeric values - return NULL
494                        return Some(DataValue::Null);
495                    }
496                }
497            }
498        }
499
500        if !has_value {
501            return Some(DataValue::Null);
502        }
503
504        // Return as integer if all values were integers and sum is whole
505        if !has_float && sum.fract() == 0.0 && sum >= i64::MIN as f64 && sum <= i64::MAX as f64 {
506            Some(DataValue::Integer(sum as i64))
507        } else {
508            Some(DataValue::Float(sum))
509        }
510    }
511
512    /// Calculate count within the window frame
513    pub fn get_frame_count(&self, row_index: usize, column: Option<&str>) -> Option<DataValue> {
514        let frame_rows = self.get_frame_rows(row_index);
515        if frame_rows.is_empty() {
516            return Some(DataValue::Integer(0));
517        }
518
519        if let Some(col_name) = column {
520            // COUNT(column) - count non-null values in frame
521            let source_table = self.source.source();
522            let col_idx = source_table.get_column_index(col_name)?;
523
524            let count = frame_rows
525                .iter()
526                .filter_map(|&row_idx| source_table.get_value(row_idx, col_idx))
527                .filter(|v| !matches!(v, DataValue::Null))
528                .count();
529
530            Some(DataValue::Integer(count as i64))
531        } else {
532            // COUNT(*) - count all rows in frame
533            Some(DataValue::Integer(frame_rows.len() as i64))
534        }
535    }
536
537    /// Calculate average of a column within the window frame
538    pub fn get_frame_avg(&self, row_index: usize, column: &str) -> Option<DataValue> {
539        let frame_rows = self.get_frame_rows(row_index);
540        if frame_rows.is_empty() {
541            return Some(DataValue::Null);
542        }
543
544        let source_table = self.source.source();
545        let col_idx = source_table.get_column_index(column)?;
546
547        let mut sum = 0.0;
548        let mut count = 0;
549
550        // Sum all non-null values in the frame
551        for &row_idx in &frame_rows {
552            if let Some(value) = source_table.get_value(row_idx, col_idx) {
553                match value {
554                    DataValue::Integer(i) => {
555                        sum += *i as f64;
556                        count += 1;
557                    }
558                    DataValue::Float(f) => {
559                        sum += f;
560                        count += 1;
561                    }
562                    DataValue::Null => {
563                        // Skip NULL values
564                    }
565                    _ => {
566                        // Non-numeric values - return NULL
567                        return Some(DataValue::Null);
568                    }
569                }
570            }
571        }
572
573        if count == 0 {
574            return Some(DataValue::Null);
575        }
576
577        Some(DataValue::Float(sum / count as f64))
578    }
579
580    /// Calculate standard deviation within the window frame (sample stddev)
581    pub fn get_frame_stddev(&self, row_index: usize, column: &str) -> Option<DataValue> {
582        let variance = self.get_frame_variance(row_index, column)?;
583        match variance {
584            DataValue::Float(v) => Some(DataValue::Float(v.sqrt())),
585            DataValue::Null => Some(DataValue::Null),
586            _ => Some(DataValue::Null),
587        }
588    }
589
590    /// Calculate variance within the window frame (sample variance with n-1)
591    pub fn get_frame_variance(&self, row_index: usize, column: &str) -> Option<DataValue> {
592        let frame_rows = self.get_frame_rows(row_index);
593        if frame_rows.is_empty() {
594            return Some(DataValue::Null);
595        }
596
597        let source_table = self.source.source();
598        let col_idx = source_table.get_column_index(column)?;
599
600        let mut values = Vec::new();
601
602        // Collect all non-null values in the frame
603        for &row_idx in &frame_rows {
604            if let Some(value) = source_table.get_value(row_idx, col_idx) {
605                match value {
606                    DataValue::Integer(i) => values.push(*i as f64),
607                    DataValue::Float(f) => values.push(*f),
608                    DataValue::Null => {
609                        // Skip NULL values
610                    }
611                    _ => {
612                        // Non-numeric values - return NULL
613                        return Some(DataValue::Null);
614                    }
615                }
616            }
617        }
618
619        if values.is_empty() {
620            return Some(DataValue::Null);
621        }
622
623        if values.len() == 1 {
624            // Variance of single value is 0
625            return Some(DataValue::Float(0.0));
626        }
627
628        // Calculate mean
629        let mean = values.iter().sum::<f64>() / values.len() as f64;
630
631        // Calculate sample variance (n-1 denominator)
632        let variance =
633            values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
634
635        Some(DataValue::Float(variance))
636    }
637
638    /// Calculate sum of a column over the partition containing the given row
639    pub fn get_partition_sum(&self, row_index: usize, column: &str) -> Option<DataValue> {
640        let partition_key = self.row_to_partition.get(&row_index)?;
641        let partition = self.partitions.get(partition_key)?;
642        let source_table = self.source.source();
643        let col_idx = source_table.get_column_index(column)?;
644
645        let mut sum = 0.0;
646        let mut has_float = false;
647        let mut has_value = false;
648
649        // Sum all values in the partition
650        for &row_idx in &partition.rows {
651            if let Some(value) = source_table.get_value(row_idx, col_idx) {
652                match value {
653                    DataValue::Integer(i) => {
654                        sum += *i as f64;
655                        has_value = true;
656                    }
657                    DataValue::Float(f) => {
658                        sum += f;
659                        has_float = true;
660                        has_value = true;
661                    }
662                    DataValue::Null => {
663                        // Skip NULL values
664                    }
665                    _ => {
666                        // Non-numeric values - return NULL
667                        return Some(DataValue::Null);
668                    }
669                }
670            }
671        }
672
673        if !has_value {
674            return Some(DataValue::Null);
675        }
676
677        // Return as integer if all values were integers and sum is whole
678        if !has_float && sum.fract() == 0.0 && sum >= i64::MIN as f64 && sum <= i64::MAX as f64 {
679            Some(DataValue::Integer(sum as i64))
680        } else {
681            Some(DataValue::Float(sum))
682        }
683    }
684
685    /// Calculate count of non-null values in a column over the partition
686    pub fn get_partition_count(&self, row_index: usize, column: Option<&str>) -> Option<DataValue> {
687        let partition_key = self.row_to_partition.get(&row_index)?;
688        let partition = self.partitions.get(partition_key)?;
689
690        if let Some(col_name) = column {
691            // COUNT(column) - count non-null values
692            let source_table = self.source.source();
693            let col_idx = source_table.get_column_index(col_name)?;
694
695            let count = partition
696                .rows
697                .iter()
698                .filter_map(|&row_idx| source_table.get_value(row_idx, col_idx))
699                .filter(|v| !matches!(v, DataValue::Null))
700                .count();
701
702            Some(DataValue::Integer(count as i64))
703        } else {
704            // COUNT(*) - count all rows in partition
705            Some(DataValue::Integer(partition.rows.len() as i64))
706        }
707    }
708}