sql_cli/sql/
window_context.rs

//! Window function context for managing partitioned and ordered data views.
//!
//! This module provides the `WindowContext` helper type that enables window functions
//! like LAG, LEAD, ROW_NUMBER, etc. by managing partitions and ordering.
5
6use std::collections::{BTreeMap, HashMap};
7use std::sync::Arc;
8
9use anyhow::{anyhow, Result};
10
11use crate::data::data_view::DataView;
12use crate::data::datatable::{DataTable, DataValue};
13use crate::sql::recursive_parser::{OrderByColumn, SortDirection};
14
15/// Key for identifying a partition (combination of partition column values)
16/// We use String representation for now since DataValue doesn't impl Ord
17#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
18struct PartitionKey(String);
19
20impl PartitionKey {
21    /// Create a partition key from data values
22    fn from_values(values: Vec<DataValue>) -> Self {
23        // Create a unique string representation
24        let key_parts: Vec<String> = values
25            .iter()
26            .map(|v| match v {
27                DataValue::String(s) => format!("S:{}", s),
28                DataValue::InternedString(s) => format!("S:{}", s),
29                DataValue::Integer(i) => format!("I:{}", i),
30                DataValue::Float(f) => format!("F:{}", f),
31                DataValue::Boolean(b) => format!("B:{}", b),
32                DataValue::DateTime(dt) => format!("D:{}", dt),
33                DataValue::Null => "N".to_string(),
34            })
35            .collect();
36        let key = key_parts.join("|");
37        PartitionKey(key)
38    }
39}
40
/// An ordered partition: the row indices belonging to one partition, stored
/// in the order given to [`OrderedPartition::new`].
#[derive(Debug, Clone)]
pub struct OrderedPartition {
    /// Original row indices from the data view, in partition order
    rows: Vec<usize>,

    /// Quick lookup: row index -> position of that row within `rows`
    row_positions: HashMap<usize, usize>,
}

impl OrderedPartition {
    /// Create a partition from row indices that are already in their final order.
    ///
    /// Builds the reverse lookup from row index to position alongside the list.
    fn new(rows: Vec<usize>) -> Self {
        let row_positions: HashMap<usize, usize> = rows
            .iter()
            .enumerate()
            .map(|(pos, &row_idx)| (row_idx, pos))
            .collect();

        Self {
            rows,
            row_positions,
        }
    }

    /// Return the row index located `offset` positions away from `current_row`.
    ///
    /// A negative `offset` moves towards the start of the partition, a positive
    /// one towards the end. Returns `None` when `current_row` is not part of
    /// this partition or when the target position falls outside it.
    pub fn get_row_at_offset(&self, current_row: usize, offset: i32) -> Option<usize> {
        let current_pos = *self.row_positions.get(&current_row)?;

        // Work in `usize` with checked arithmetic: casting the position to
        // `i32` and adding the offset could overflow for extreme offsets
        // (e.g. `i32::MAX`) or truncate very large positions.
        let distance = offset.unsigned_abs() as usize;
        let target_pos = if offset >= 0 {
            current_pos.checked_add(distance)?
        } else {
            current_pos.checked_sub(distance)?
        };

        self.rows.get(target_pos).copied()
    }

    /// Get the 0-based position of `row_index` in this partition, if present.
    pub fn get_position(&self, row_index: usize) -> Option<usize> {
        self.row_positions.get(&row_index).copied()
    }

    /// Get the first row index in this partition (`None` if it is empty).
    pub fn first_row(&self) -> Option<usize> {
        self.rows.first().copied()
    }

    /// Get the last row index in this partition (`None` if it is empty).
    pub fn last_row(&self) -> Option<usize> {
        self.rows.last().copied()
    }
}
94
/// Window specification defining partitioning and ordering
/// (the `PARTITION BY` and `ORDER BY` parts of a window).
#[derive(Debug, Clone)]
pub struct WindowSpec {
    // Names of the columns whose values group rows into partitions;
    // empty means the whole view forms a single window.
    pub partition_by: Vec<String>,
    // Ordering applied to the rows inside each partition; empty means the
    // rows keep the order in which the view supplied them.
    pub order_by: Vec<OrderByColumn>,
}
101
/// Context for evaluating window functions.
///
/// Holds the source view together with its rows grouped into partitions and
/// ordered inside each partition, plus the lookup needed to find the
/// partition a given row belongs to.
pub struct WindowContext {
    /// Source data view
    source: Arc<DataView>,

    /// Partitions with their ordered rows, keyed by partition key
    partitions: BTreeMap<PartitionKey, OrderedPartition>,

    /// Mapping from row index to the key of the partition containing it
    row_to_partition: HashMap<usize, PartitionKey>,

    /// Window specification (partition and order columns) used to build this context
    spec: WindowSpec,
}
116
117impl WindowContext {
118    /// Create a new window context with partitioning and ordering
119    pub fn new(
120        view: Arc<DataView>,
121        partition_by: Vec<String>,
122        order_by: Vec<OrderByColumn>,
123    ) -> Result<Self> {
124        let spec = WindowSpec {
125            partition_by: partition_by.clone(),
126            order_by: order_by.clone(),
127        };
128
129        // If no partition columns, treat entire view as single partition
130        if partition_by.is_empty() {
131            let single_partition = Self::create_single_partition(&view, &order_by)?;
132            let partition_key = PartitionKey::from_values(vec![]);
133
134            // Build row-to-partition mapping
135            let mut row_to_partition = HashMap::new();
136            for &row_idx in &single_partition.rows {
137                row_to_partition.insert(row_idx, partition_key.clone());
138            }
139
140            let mut partitions = BTreeMap::new();
141            partitions.insert(partition_key, single_partition);
142
143            return Ok(Self {
144                source: view,
145                partitions,
146                row_to_partition,
147                spec,
148            });
149        }
150
151        // Create partitions based on partition_by columns
152        let mut partition_map: BTreeMap<PartitionKey, Vec<usize>> = BTreeMap::new();
153        let mut row_to_partition = HashMap::new();
154
155        // Get column indices for partition columns
156        let source_table = view.source();
157        let partition_col_indices: Vec<usize> = partition_by
158            .iter()
159            .map(|col| {
160                source_table
161                    .get_column_index(col)
162                    .ok_or_else(|| anyhow!("Invalid partition column: {}", col))
163            })
164            .collect::<Result<Vec<_>>>()?;
165
166        // Group rows by partition key
167        for row_idx in view.get_visible_rows() {
168            // Build partition key from row values
169            let mut key_values = Vec::new();
170            for &col_idx in &partition_col_indices {
171                let value = source_table
172                    .get_value(row_idx, col_idx)
173                    .ok_or_else(|| anyhow!("Failed to get value for partition"))?
174                    .clone();
175                key_values.push(value);
176            }
177            let key = PartitionKey::from_values(key_values);
178
179            // Add row to partition
180            partition_map.entry(key.clone()).or_default().push(row_idx);
181            row_to_partition.insert(row_idx, key);
182        }
183
184        // Sort each partition according to ORDER BY
185        let mut partitions = BTreeMap::new();
186        for (key, mut rows) in partition_map {
187            // Sort rows within partition
188            if !order_by.is_empty() {
189                Self::sort_rows(&mut rows, source_table, &order_by)?;
190            }
191
192            partitions.insert(key, OrderedPartition::new(rows));
193        }
194
195        Ok(Self {
196            source: view,
197            partitions,
198            row_to_partition,
199            spec,
200        })
201    }
202
203    /// Create a single partition from the entire view
204    fn create_single_partition(
205        view: &DataView,
206        order_by: &[OrderByColumn],
207    ) -> Result<OrderedPartition> {
208        let mut rows: Vec<usize> = view.get_visible_rows();
209
210        if !order_by.is_empty() {
211            Self::sort_rows(&mut rows, view.source(), order_by)?;
212        }
213
214        Ok(OrderedPartition::new(rows))
215    }
216
217    /// Sort row indices according to ORDER BY specification
218    fn sort_rows(
219        rows: &mut Vec<usize>,
220        table: &DataTable,
221        order_by: &[OrderByColumn],
222    ) -> Result<()> {
223        // Get column indices for ORDER BY columns
224        let sort_cols: Vec<(usize, bool)> = order_by
225            .iter()
226            .map(|col| {
227                let idx = table
228                    .get_column_index(&col.column)
229                    .ok_or_else(|| anyhow!("Invalid ORDER BY column: {}", col.column))?;
230                let ascending = matches!(col.direction, SortDirection::Asc);
231                Ok((idx, ascending))
232            })
233            .collect::<Result<Vec<_>>>()?;
234
235        // Sort rows based on column values
236        rows.sort_by(|&a, &b| {
237            for &(col_idx, ascending) in &sort_cols {
238                let val_a = table.get_value(a, col_idx);
239                let val_b = table.get_value(b, col_idx);
240
241                match (val_a, val_b) {
242                    (None, None) => continue,
243                    (None, Some(_)) => {
244                        return if ascending {
245                            std::cmp::Ordering::Less
246                        } else {
247                            std::cmp::Ordering::Greater
248                        }
249                    }
250                    (Some(_), None) => {
251                        return if ascending {
252                            std::cmp::Ordering::Greater
253                        } else {
254                            std::cmp::Ordering::Less
255                        }
256                    }
257                    (Some(v_a), Some(v_b)) => {
258                        // DataValue only implements PartialOrd, not Ord
259                        let ord = v_a.partial_cmp(&v_b).unwrap_or(std::cmp::Ordering::Equal);
260                        if ord != std::cmp::Ordering::Equal {
261                            return if ascending { ord } else { ord.reverse() };
262                        }
263                    }
264                }
265            }
266            std::cmp::Ordering::Equal
267        });
268
269        Ok(())
270    }
271
272    /// Get value at offset from current row (for LAG/LEAD)
273    pub fn get_offset_value(
274        &self,
275        current_row: usize,
276        offset: i32,
277        column: &str,
278    ) -> Option<DataValue> {
279        // Find which partition this row belongs to
280        let partition_key = self.row_to_partition.get(&current_row)?;
281        let partition = self.partitions.get(partition_key)?;
282
283        // Navigate to target row
284        let target_row = partition.get_row_at_offset(current_row, offset)?;
285
286        // Get column value from target row
287        let source_table = self.source.source();
288        let col_idx = source_table.get_column_index(column)?;
289        source_table.get_value(target_row, col_idx).cloned()
290    }
291
292    /// Get row number within partition (1-based)
293    pub fn get_row_number(&self, row_index: usize) -> usize {
294        if let Some(partition_key) = self.row_to_partition.get(&row_index) {
295            if let Some(partition) = self.partitions.get(partition_key) {
296                if let Some(position) = partition.get_position(row_index) {
297                    return position + 1; // Convert to 1-based
298                }
299            }
300        }
301        0 // Should not happen for valid row
302    }
303
304    /// Get first value in partition
305    pub fn get_first_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
306        let partition_key = self.row_to_partition.get(&row_index)?;
307        let partition = self.partitions.get(partition_key)?;
308        let first_row = partition.first_row()?;
309
310        let source_table = self.source.source();
311        let col_idx = source_table.get_column_index(column)?;
312        source_table.get_value(first_row, col_idx).cloned()
313    }
314
315    /// Get last value in partition
316    pub fn get_last_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
317        let partition_key = self.row_to_partition.get(&row_index)?;
318        let partition = self.partitions.get(partition_key)?;
319        let last_row = partition.last_row()?;
320
321        let source_table = self.source.source();
322        let col_idx = source_table.get_column_index(column)?;
323        source_table.get_value(last_row, col_idx).cloned()
324    }
325
    /// Get the number of partitions held by this context.
    ///
    /// A context built without partition columns holds exactly one partition
    /// covering the whole view.
    pub fn partition_count(&self) -> usize {
        self.partitions.len()
    }
330
331    /// Check if context has partitions (vs single window)
332    pub fn has_partitions(&self) -> bool {
333        !self.spec.partition_by.is_empty()
334    }
335
336    /// Calculate sum of a column over the partition containing the given row
337    pub fn get_partition_sum(&self, row_index: usize, column: &str) -> Option<DataValue> {
338        let partition_key = self.row_to_partition.get(&row_index)?;
339        let partition = self.partitions.get(partition_key)?;
340        let source_table = self.source.source();
341        let col_idx = source_table.get_column_index(column)?;
342
343        let mut sum = 0.0;
344        let mut has_float = false;
345        let mut has_value = false;
346
347        // Sum all values in the partition
348        for &row_idx in &partition.rows {
349            if let Some(value) = source_table.get_value(row_idx, col_idx) {
350                match value {
351                    DataValue::Integer(i) => {
352                        sum += *i as f64;
353                        has_value = true;
354                    }
355                    DataValue::Float(f) => {
356                        sum += f;
357                        has_float = true;
358                        has_value = true;
359                    }
360                    DataValue::Null => {
361                        // Skip NULL values
362                    }
363                    _ => {
364                        // Non-numeric values - return NULL
365                        return Some(DataValue::Null);
366                    }
367                }
368            }
369        }
370
371        if !has_value {
372            return Some(DataValue::Null);
373        }
374
375        // Return as integer if all values were integers and sum is whole
376        if !has_float && sum.fract() == 0.0 && sum >= i64::MIN as f64 && sum <= i64::MAX as f64 {
377            Some(DataValue::Integer(sum as i64))
378        } else {
379            Some(DataValue::Float(sum))
380        }
381    }
382
383    /// Calculate count of non-null values in a column over the partition
384    pub fn get_partition_count(&self, row_index: usize, column: Option<&str>) -> Option<DataValue> {
385        let partition_key = self.row_to_partition.get(&row_index)?;
386        let partition = self.partitions.get(partition_key)?;
387
388        if let Some(col_name) = column {
389            // COUNT(column) - count non-null values
390            let source_table = self.source.source();
391            let col_idx = source_table.get_column_index(col_name)?;
392
393            let count = partition
394                .rows
395                .iter()
396                .filter_map(|&row_idx| source_table.get_value(row_idx, col_idx))
397                .filter(|v| !matches!(v, DataValue::Null))
398                .count();
399
400            Some(DataValue::Integer(count as i64))
401        } else {
402            // COUNT(*) - count all rows in partition
403            Some(DataValue::Integer(partition.rows.len() as i64))
404        }
405    }
406}