sql_cli/sql/
window_context.rs

1//! Window function context for managing partitioned and ordered data views
2//!
3//! This module provides the WindowContext helper class that enables window functions
4//! like LAG, LEAD, ROW_NUMBER, etc. by managing partitions and ordering.
5
6use std::collections::{BTreeMap, HashMap};
7use std::sync::Arc;
8
9use anyhow::{anyhow, Result};
10
11use crate::data::data_view::DataView;
12use crate::data::datatable::{DataTable, DataValue};
13use crate::sql::parser::ast::{FrameBound, FrameUnit, OrderByColumn, SortDirection, WindowSpec};
14
15/// Key for identifying a partition (combination of partition column values)
16/// We use String representation for now since DataValue doesn't impl Ord
17#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
18struct PartitionKey(String);
19
20impl PartitionKey {
21    /// Create a partition key from data values
22    fn from_values(values: Vec<DataValue>) -> Self {
23        // Create a unique string representation
24        let key_parts: Vec<String> = values
25            .iter()
26            .map(|v| match v {
27                DataValue::String(s) => format!("S:{}", s),
28                DataValue::InternedString(s) => format!("S:{}", s),
29                DataValue::Integer(i) => format!("I:{}", i),
30                DataValue::Float(f) => format!("F:{}", f),
31                DataValue::Boolean(b) => format!("B:{}", b),
32                DataValue::DateTime(dt) => format!("D:{}", dt),
33                DataValue::Null => "N".to_string(),
34            })
35            .collect();
36        let key = key_parts.join("|");
37        PartitionKey(key)
38    }
39}
40
/// An ordered partition containing row indices
#[derive(Debug, Clone)]
pub struct OrderedPartition {
    /// Original row indices from the DataView, in partition order
    rows: Vec<usize>,

    /// Reverse lookup: row index -> 0-based position within `rows`
    row_positions: HashMap<usize, usize>,
}

impl OrderedPartition {
    /// Create a partition from row indices that are already in their final order.
    ///
    /// Builds the reverse lookup from row index to position. If a row index
    /// appears more than once, the lookup keeps its last position.
    fn new(rows: Vec<usize>) -> Self {
        let row_positions: HashMap<usize, usize> = rows
            .iter()
            .enumerate()
            .map(|(pos, &row_idx)| (row_idx, pos))
            .collect();

        Self {
            rows,
            row_positions,
        }
    }

    /// Return the row that sits `offset` positions away from `current_row`.
    ///
    /// A negative `offset` moves towards the start of the partition, a positive
    /// one towards the end. Returns `None` when `current_row` is not in this
    /// partition or when the target position falls outside it.
    pub fn get_row_at_offset(&self, current_row: usize, offset: i32) -> Option<usize> {
        let current_pos = *self.row_positions.get(&current_row)?;

        // Checked `usize` arithmetic: an extreme offset (e.g. `i32::MAX`) or a very
        // large position yields `None` instead of overflowing or truncating.
        let target_pos = if offset >= 0 {
            current_pos.checked_add(usize::try_from(offset).ok()?)?
        } else {
            current_pos.checked_sub(usize::try_from(offset.unsigned_abs()).ok()?)?
        };

        self.rows.get(target_pos).copied()
    }

    /// Get position of row in this partition (0-based), or `None` if absent
    pub fn get_position(&self, row_index: usize) -> Option<usize> {
        self.row_positions.get(&row_index).copied()
    }

    /// Get the first row index in this partition (`None` if the partition is empty)
    pub fn first_row(&self) -> Option<usize> {
        self.rows.first().copied()
    }

    /// Get the last row index in this partition (`None` if the partition is empty)
    pub fn last_row(&self) -> Option<usize> {
        self.rows.last().copied()
    }
}
94
/// Context for evaluating window functions
///
/// Holds the source view together with its visible rows grouped into
/// partitions; each partition keeps its rows in the order produced by the
/// window's ORDER BY columns (or in view order when there are none).
pub struct WindowContext {
    /// Source data view
    source: Arc<DataView>,

    /// Partitions with their ordered rows, keyed by partition key
    partitions: BTreeMap<PartitionKey, OrderedPartition>,

    /// Mapping from row index to the key of the partition that contains it
    row_to_partition: HashMap<usize, PartitionKey>,

    /// Window specification (partition columns, ORDER BY columns, optional frame)
    spec: WindowSpec,
}
109
110impl WindowContext {
111    /// Create a new window context with partitioning and ordering
112    pub fn new(
113        view: Arc<DataView>,
114        partition_by: Vec<String>,
115        order_by: Vec<OrderByColumn>,
116    ) -> Result<Self> {
117        Self::new_with_spec(
118            view,
119            WindowSpec {
120                partition_by,
121                order_by,
122                frame: None,
123            },
124        )
125    }
126
127    /// Create a new window context with a full window specification
128    pub fn new_with_spec(view: Arc<DataView>, spec: WindowSpec) -> Result<Self> {
129        let partition_by = spec.partition_by.clone();
130        let order_by = spec.order_by.clone();
131
132        // If no partition columns, treat entire view as single partition
133        if partition_by.is_empty() {
134            let single_partition = Self::create_single_partition(&view, &order_by)?;
135            let partition_key = PartitionKey::from_values(vec![]);
136
137            // Build row-to-partition mapping
138            let mut row_to_partition = HashMap::new();
139            for &row_idx in &single_partition.rows {
140                row_to_partition.insert(row_idx, partition_key.clone());
141            }
142
143            let mut partitions = BTreeMap::new();
144            partitions.insert(partition_key, single_partition);
145
146            return Ok(Self {
147                source: view,
148                partitions,
149                row_to_partition,
150                spec,
151            });
152        }
153
154        // Create partitions based on partition_by columns
155        let mut partition_map: BTreeMap<PartitionKey, Vec<usize>> = BTreeMap::new();
156        let mut row_to_partition = HashMap::new();
157
158        // Get column indices for partition columns
159        let source_table = view.source();
160        let partition_col_indices: Vec<usize> = partition_by
161            .iter()
162            .map(|col| {
163                source_table
164                    .get_column_index(col)
165                    .ok_or_else(|| anyhow!("Invalid partition column: {}", col))
166            })
167            .collect::<Result<Vec<_>>>()?;
168
169        // Group rows by partition key
170        for row_idx in view.get_visible_rows() {
171            // Build partition key from row values
172            let mut key_values = Vec::new();
173            for &col_idx in &partition_col_indices {
174                let value = source_table
175                    .get_value(row_idx, col_idx)
176                    .ok_or_else(|| anyhow!("Failed to get value for partition"))?
177                    .clone();
178                key_values.push(value);
179            }
180            let key = PartitionKey::from_values(key_values);
181
182            // Add row to partition
183            partition_map.entry(key.clone()).or_default().push(row_idx);
184            row_to_partition.insert(row_idx, key);
185        }
186
187        // Sort each partition according to ORDER BY
188        let mut partitions = BTreeMap::new();
189        for (key, mut rows) in partition_map {
190            // Sort rows within partition
191            if !order_by.is_empty() {
192                Self::sort_rows(&mut rows, source_table, &order_by)?;
193            }
194
195            partitions.insert(key, OrderedPartition::new(rows));
196        }
197
198        Ok(Self {
199            source: view,
200            partitions,
201            row_to_partition,
202            spec,
203        })
204    }
205
206    /// Create a single partition from the entire view
207    fn create_single_partition(
208        view: &DataView,
209        order_by: &[OrderByColumn],
210    ) -> Result<OrderedPartition> {
211        let mut rows: Vec<usize> = view.get_visible_rows();
212
213        if !order_by.is_empty() {
214            Self::sort_rows(&mut rows, view.source(), order_by)?;
215        }
216
217        Ok(OrderedPartition::new(rows))
218    }
219
220    /// Sort row indices according to ORDER BY specification
221    fn sort_rows(
222        rows: &mut Vec<usize>,
223        table: &DataTable,
224        order_by: &[OrderByColumn],
225    ) -> Result<()> {
226        // Get column indices for ORDER BY columns
227        let sort_cols: Vec<(usize, bool)> = order_by
228            .iter()
229            .map(|col| {
230                let idx = table
231                    .get_column_index(&col.column)
232                    .ok_or_else(|| anyhow!("Invalid ORDER BY column: {}", col.column))?;
233                let ascending = matches!(col.direction, SortDirection::Asc);
234                Ok((idx, ascending))
235            })
236            .collect::<Result<Vec<_>>>()?;
237
238        // Sort rows based on column values
239        rows.sort_by(|&a, &b| {
240            for &(col_idx, ascending) in &sort_cols {
241                let val_a = table.get_value(a, col_idx);
242                let val_b = table.get_value(b, col_idx);
243
244                match (val_a, val_b) {
245                    (None, None) => continue,
246                    (None, Some(_)) => {
247                        return if ascending {
248                            std::cmp::Ordering::Less
249                        } else {
250                            std::cmp::Ordering::Greater
251                        }
252                    }
253                    (Some(_), None) => {
254                        return if ascending {
255                            std::cmp::Ordering::Greater
256                        } else {
257                            std::cmp::Ordering::Less
258                        }
259                    }
260                    (Some(v_a), Some(v_b)) => {
261                        // DataValue only implements PartialOrd, not Ord
262                        let ord = v_a.partial_cmp(&v_b).unwrap_or(std::cmp::Ordering::Equal);
263                        if ord != std::cmp::Ordering::Equal {
264                            return if ascending { ord } else { ord.reverse() };
265                        }
266                    }
267                }
268            }
269            std::cmp::Ordering::Equal
270        });
271
272        Ok(())
273    }
274
275    /// Get value at offset from current row (for LAG/LEAD)
276    pub fn get_offset_value(
277        &self,
278        current_row: usize,
279        offset: i32,
280        column: &str,
281    ) -> Option<DataValue> {
282        // Find which partition this row belongs to
283        let partition_key = self.row_to_partition.get(&current_row)?;
284        let partition = self.partitions.get(partition_key)?;
285
286        // Navigate to target row
287        let target_row = partition.get_row_at_offset(current_row, offset)?;
288
289        // Get column value from target row
290        let source_table = self.source.source();
291        let col_idx = source_table.get_column_index(column)?;
292        source_table.get_value(target_row, col_idx).cloned()
293    }
294
295    /// Get row number within partition (1-based)
296    pub fn get_row_number(&self, row_index: usize) -> usize {
297        if let Some(partition_key) = self.row_to_partition.get(&row_index) {
298            if let Some(partition) = self.partitions.get(partition_key) {
299                if let Some(position) = partition.get_position(row_index) {
300                    return position + 1; // Convert to 1-based
301                }
302            }
303        }
304        0 // Should not happen for valid row
305    }
306
307    /// Get first value in frame
308    pub fn get_frame_first_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
309        let frame_rows = self.get_frame_rows(row_index);
310        if frame_rows.is_empty() {
311            return Some(DataValue::Null);
312        }
313
314        let source_table = self.source.source();
315        let col_idx = source_table.get_column_index(column)?;
316
317        // Get the first row in the frame
318        let first_row = frame_rows[0];
319        source_table.get_value(first_row, col_idx).cloned()
320    }
321
322    /// Get last value in frame
323    pub fn get_frame_last_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
324        let frame_rows = self.get_frame_rows(row_index);
325        if frame_rows.is_empty() {
326            return Some(DataValue::Null);
327        }
328
329        let source_table = self.source.source();
330        let col_idx = source_table.get_column_index(column)?;
331
332        // Get the last row in the frame
333        let last_row = frame_rows[frame_rows.len() - 1];
334        source_table.get_value(last_row, col_idx).cloned()
335    }
336
337    /// Get first value in partition
338    pub fn get_first_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
339        let partition_key = self.row_to_partition.get(&row_index)?;
340        let partition = self.partitions.get(partition_key)?;
341        let first_row = partition.first_row()?;
342
343        let source_table = self.source.source();
344        let col_idx = source_table.get_column_index(column)?;
345        source_table.get_value(first_row, col_idx).cloned()
346    }
347
348    /// Get last value in partition
349    pub fn get_last_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
350        let partition_key = self.row_to_partition.get(&row_index)?;
351        let partition = self.partitions.get(partition_key)?;
352        let last_row = partition.last_row()?;
353
354        let source_table = self.source.source();
355        let col_idx = source_table.get_column_index(column)?;
356        source_table.get_value(last_row, col_idx).cloned()
357    }
358
    /// Get the number of partitions held by this context.
    ///
    /// A context built without PARTITION BY columns holds exactly one partition.
    pub fn partition_count(&self) -> usize {
        self.partitions.len()
    }
363
    /// Check if context has partitions (vs single window).
    ///
    /// Returns `true` when the window specification names at least one
    /// PARTITION BY column.
    pub fn has_partitions(&self) -> bool {
        !self.spec.partition_by.is_empty()
    }
368
    /// Check if context has a window frame specification.
    ///
    /// Returns `true` when the window specification carries an explicit frame.
    pub fn has_frame(&self) -> bool {
        self.spec.frame.is_some()
    }
373
    /// Get the `DataTable` that backs the source `DataView`.
    pub fn source(&self) -> &DataTable {
        self.source.source()
    }
378
379    /// Get row indices within the window frame for a given row
380    pub fn get_frame_rows(&self, row_index: usize) -> Vec<usize> {
381        // Find which partition this row belongs to
382        let partition_key = match self.row_to_partition.get(&row_index) {
383            Some(key) => key,
384            None => return vec![],
385        };
386
387        let partition = match self.partitions.get(partition_key) {
388            Some(p) => p,
389            None => return vec![],
390        };
391
392        // Get current row's position in partition
393        let current_pos = match partition.get_position(row_index) {
394            Some(pos) => pos as i64,
395            None => return vec![],
396        };
397
398        // If no frame specified, return entire partition (default behavior)
399        let frame = match &self.spec.frame {
400            Some(f) => f,
401            None => return partition.rows.clone(),
402        };
403
404        // Calculate frame bounds
405        let (start_pos, end_pos) = match frame.unit {
406            FrameUnit::Rows => {
407                // ROWS frame - based on physical row positions
408                let start =
409                    self.calculate_frame_position(&frame.start, current_pos, partition.rows.len());
410                let end = match &frame.end {
411                    Some(bound) => {
412                        self.calculate_frame_position(bound, current_pos, partition.rows.len())
413                    }
414                    None => current_pos, // Default to CURRENT ROW
415                };
416                (start, end)
417            }
418            FrameUnit::Range => {
419                // RANGE frame - based on ORDER BY values (not yet fully implemented)
420                // For now, treat like ROWS
421                let start =
422                    self.calculate_frame_position(&frame.start, current_pos, partition.rows.len());
423                let end = match &frame.end {
424                    Some(bound) => {
425                        self.calculate_frame_position(bound, current_pos, partition.rows.len())
426                    }
427                    None => current_pos,
428                };
429                (start, end)
430            }
431        };
432
433        // Collect rows within frame bounds
434        let mut frame_rows = Vec::new();
435        for i in start_pos..=end_pos {
436            if i >= 0 && (i as usize) < partition.rows.len() {
437                frame_rows.push(partition.rows[i as usize]);
438            }
439        }
440
441        frame_rows
442    }
443
444    /// Calculate absolute position from frame bound
445    fn calculate_frame_position(
446        &self,
447        bound: &FrameBound,
448        current_pos: i64,
449        partition_size: usize,
450    ) -> i64 {
451        match bound {
452            FrameBound::UnboundedPreceding => 0,
453            FrameBound::UnboundedFollowing => partition_size as i64 - 1,
454            FrameBound::CurrentRow => current_pos,
455            FrameBound::Preceding(n) => current_pos - n,
456            FrameBound::Following(n) => current_pos + n,
457        }
458    }
459
460    /// Calculate sum of a column within the window frame for the given row
461    pub fn get_frame_sum(&self, row_index: usize, column: &str) -> Option<DataValue> {
462        let frame_rows = self.get_frame_rows(row_index);
463        if frame_rows.is_empty() {
464            return Some(DataValue::Null);
465        }
466
467        let source_table = self.source.source();
468        let col_idx = source_table.get_column_index(column)?;
469
470        let mut sum = 0.0;
471        let mut has_float = false;
472        let mut has_value = false;
473
474        // Sum all values in the frame
475        for &row_idx in &frame_rows {
476            if let Some(value) = source_table.get_value(row_idx, col_idx) {
477                match value {
478                    DataValue::Integer(i) => {
479                        sum += *i as f64;
480                        has_value = true;
481                    }
482                    DataValue::Float(f) => {
483                        sum += f;
484                        has_float = true;
485                        has_value = true;
486                    }
487                    DataValue::Null => {
488                        // Skip NULL values
489                    }
490                    _ => {
491                        // Non-numeric values - return NULL
492                        return Some(DataValue::Null);
493                    }
494                }
495            }
496        }
497
498        if !has_value {
499            return Some(DataValue::Null);
500        }
501
502        // Return as integer if all values were integers and sum is whole
503        if !has_float && sum.fract() == 0.0 && sum >= i64::MIN as f64 && sum <= i64::MAX as f64 {
504            Some(DataValue::Integer(sum as i64))
505        } else {
506            Some(DataValue::Float(sum))
507        }
508    }
509
510    /// Calculate count within the window frame
511    pub fn get_frame_count(&self, row_index: usize, column: Option<&str>) -> Option<DataValue> {
512        let frame_rows = self.get_frame_rows(row_index);
513        if frame_rows.is_empty() {
514            return Some(DataValue::Integer(0));
515        }
516
517        if let Some(col_name) = column {
518            // COUNT(column) - count non-null values in frame
519            let source_table = self.source.source();
520            let col_idx = source_table.get_column_index(col_name)?;
521
522            let count = frame_rows
523                .iter()
524                .filter_map(|&row_idx| source_table.get_value(row_idx, col_idx))
525                .filter(|v| !matches!(v, DataValue::Null))
526                .count();
527
528            Some(DataValue::Integer(count as i64))
529        } else {
530            // COUNT(*) - count all rows in frame
531            Some(DataValue::Integer(frame_rows.len() as i64))
532        }
533    }
534
535    /// Calculate average of a column within the window frame
536    pub fn get_frame_avg(&self, row_index: usize, column: &str) -> Option<DataValue> {
537        let frame_rows = self.get_frame_rows(row_index);
538        if frame_rows.is_empty() {
539            return Some(DataValue::Null);
540        }
541
542        let source_table = self.source.source();
543        let col_idx = source_table.get_column_index(column)?;
544
545        let mut sum = 0.0;
546        let mut count = 0;
547
548        // Sum all non-null values in the frame
549        for &row_idx in &frame_rows {
550            if let Some(value) = source_table.get_value(row_idx, col_idx) {
551                match value {
552                    DataValue::Integer(i) => {
553                        sum += *i as f64;
554                        count += 1;
555                    }
556                    DataValue::Float(f) => {
557                        sum += f;
558                        count += 1;
559                    }
560                    DataValue::Null => {
561                        // Skip NULL values
562                    }
563                    _ => {
564                        // Non-numeric values - return NULL
565                        return Some(DataValue::Null);
566                    }
567                }
568            }
569        }
570
571        if count == 0 {
572            return Some(DataValue::Null);
573        }
574
575        Some(DataValue::Float(sum / count as f64))
576    }
577
578    /// Calculate standard deviation within the window frame (sample stddev)
579    pub fn get_frame_stddev(&self, row_index: usize, column: &str) -> Option<DataValue> {
580        let variance = self.get_frame_variance(row_index, column)?;
581        match variance {
582            DataValue::Float(v) => Some(DataValue::Float(v.sqrt())),
583            DataValue::Null => Some(DataValue::Null),
584            _ => Some(DataValue::Null),
585        }
586    }
587
588    /// Calculate variance within the window frame (sample variance with n-1)
589    pub fn get_frame_variance(&self, row_index: usize, column: &str) -> Option<DataValue> {
590        let frame_rows = self.get_frame_rows(row_index);
591        if frame_rows.is_empty() {
592            return Some(DataValue::Null);
593        }
594
595        let source_table = self.source.source();
596        let col_idx = source_table.get_column_index(column)?;
597
598        let mut values = Vec::new();
599
600        // Collect all non-null values in the frame
601        for &row_idx in &frame_rows {
602            if let Some(value) = source_table.get_value(row_idx, col_idx) {
603                match value {
604                    DataValue::Integer(i) => values.push(*i as f64),
605                    DataValue::Float(f) => values.push(*f),
606                    DataValue::Null => {
607                        // Skip NULL values
608                    }
609                    _ => {
610                        // Non-numeric values - return NULL
611                        return Some(DataValue::Null);
612                    }
613                }
614            }
615        }
616
617        if values.is_empty() {
618            return Some(DataValue::Null);
619        }
620
621        if values.len() == 1 {
622            // Variance of single value is 0
623            return Some(DataValue::Float(0.0));
624        }
625
626        // Calculate mean
627        let mean = values.iter().sum::<f64>() / values.len() as f64;
628
629        // Calculate sample variance (n-1 denominator)
630        let variance =
631            values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
632
633        Some(DataValue::Float(variance))
634    }
635
636    /// Calculate sum of a column over the partition containing the given row
637    pub fn get_partition_sum(&self, row_index: usize, column: &str) -> Option<DataValue> {
638        let partition_key = self.row_to_partition.get(&row_index)?;
639        let partition = self.partitions.get(partition_key)?;
640        let source_table = self.source.source();
641        let col_idx = source_table.get_column_index(column)?;
642
643        let mut sum = 0.0;
644        let mut has_float = false;
645        let mut has_value = false;
646
647        // Sum all values in the partition
648        for &row_idx in &partition.rows {
649            if let Some(value) = source_table.get_value(row_idx, col_idx) {
650                match value {
651                    DataValue::Integer(i) => {
652                        sum += *i as f64;
653                        has_value = true;
654                    }
655                    DataValue::Float(f) => {
656                        sum += f;
657                        has_float = true;
658                        has_value = true;
659                    }
660                    DataValue::Null => {
661                        // Skip NULL values
662                    }
663                    _ => {
664                        // Non-numeric values - return NULL
665                        return Some(DataValue::Null);
666                    }
667                }
668            }
669        }
670
671        if !has_value {
672            return Some(DataValue::Null);
673        }
674
675        // Return as integer if all values were integers and sum is whole
676        if !has_float && sum.fract() == 0.0 && sum >= i64::MIN as f64 && sum <= i64::MAX as f64 {
677            Some(DataValue::Integer(sum as i64))
678        } else {
679            Some(DataValue::Float(sum))
680        }
681    }
682
683    /// Calculate count of non-null values in a column over the partition
684    pub fn get_partition_count(&self, row_index: usize, column: Option<&str>) -> Option<DataValue> {
685        let partition_key = self.row_to_partition.get(&row_index)?;
686        let partition = self.partitions.get(partition_key)?;
687
688        if let Some(col_name) = column {
689            // COUNT(column) - count non-null values
690            let source_table = self.source.source();
691            let col_idx = source_table.get_column_index(col_name)?;
692
693            let count = partition
694                .rows
695                .iter()
696                .filter_map(|&row_idx| source_table.get_value(row_idx, col_idx))
697                .filter(|v| !matches!(v, DataValue::Null))
698                .count();
699
700            Some(DataValue::Integer(count as i64))
701        } else {
702            // COUNT(*) - count all rows in partition
703            Some(DataValue::Integer(partition.rows.len() as i64))
704        }
705    }
706}