Skip to main content

shape_runtime/
window_executor.rs

1//! Window function execution engine
2//!
3//! Executes SQL-style window functions over datasets:
4//! - Ranking functions: ROW_NUMBER, RANK, DENSE_RANK, NTILE
5//! - Navigation functions: LAG, LEAD, FIRST_VALUE, LAST_VALUE, NTH_VALUE
6//! - Aggregate functions: SUM, AVG, MIN, MAX, COUNT over window frames
7
8use crate::context::ExecutionContext;
9use shape_ast::ast::{Expr, SortDirection, WindowBound, WindowExpr, WindowFrame, WindowFunction};
10use shape_ast::error::Result;
11use shape_value::ValueWord;
12use std::collections::HashMap;
13
/// Executes one window expression at a time over a set of rows.
///
/// The partition map below is working state: `execute` clears and rebuilds it on
/// every call, so a single executor can be reused for successive window expressions.
pub struct WindowExecutor {
    /// Input rows grouped by their evaluated `PARTITION BY` key. When the window
    /// has no `PARTITION BY` clause, all rows live under the empty key `vec![]`.
    partitions: HashMap<Vec<OrderedValue>, Vec<RowData>>,
}
19
20/// A row of data with its original index for result placement
21struct RowData {
22    /// Original row index in the input dataset
23    original_index: usize,
24    /// Row values by field name
25    values: HashMap<String, ValueWord>,
26}
27
28/// Wrapper for ValueWord that implements Eq + Hash for partition keys
29#[derive(Clone, Debug)]
30struct OrderedValue(ValueWord);
31
32impl PartialEq for OrderedValue {
33    fn eq(&self, other: &Self) -> bool {
34        use shape_value::NanTag;
35        match (self.0.tag(), other.0.tag()) {
36            (NanTag::F64, NanTag::F64)
37            | (NanTag::I48, NanTag::I48)
38            | (NanTag::F64, NanTag::I48)
39            | (NanTag::I48, NanTag::F64) => match (self.0.as_f64(), other.0.as_f64()) {
40                (Some(a), Some(b)) => {
41                    if a.is_nan() && b.is_nan() {
42                        true
43                    } else {
44                        a == b
45                    }
46                }
47                _ => false,
48            },
49            (NanTag::Heap, NanTag::Heap) => {
50                if let (Some(a), Some(b)) = (self.0.as_str(), other.0.as_str()) {
51                    a == b
52                } else {
53                    false
54                }
55            }
56            (NanTag::Bool, NanTag::Bool) => self.0.as_bool() == other.0.as_bool(),
57            (NanTag::None, NanTag::None) => true,
58            _ => {
59                if let (Some(a), Some(b)) = (self.0.as_time(), other.0.as_time()) {
60                    a == b
61                } else {
62                    false
63                }
64            }
65        }
66    }
67}
68
// Marker impl so `OrderedValue` can serve as a `HashMap` key.
// NOTE(review): `eq` returns false for two non-string heap values and for numeric
// tags whose `as_f64()` is `None`, so reflexivity is not guaranteed for those values.
impl Eq for OrderedValue {}
70
71impl std::hash::Hash for OrderedValue {
72    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
73        use shape_value::NanTag;
74        match self.0.tag() {
75            NanTag::F64 | NanTag::I48 => {
76                state.write_u8(0);
77                if let Some(n) = self.0.as_f64() {
78                    state.write_u64(n.to_bits());
79                }
80            }
81            NanTag::Heap => {
82                if let Some(s) = self.0.as_str() {
83                    state.write_u8(1);
84                    s.hash(state);
85                } else {
86                    state.write_u8(255);
87                }
88            }
89            NanTag::Bool => {
90                state.write_u8(2);
91                if let Some(b) = self.0.as_bool() {
92                    b.hash(state);
93                }
94            }
95            NanTag::None => {
96                state.write_u8(4);
97            }
98            _ => {
99                if let Some(t) = self.0.as_time() {
100                    state.write_u8(3);
101                    t.timestamp_nanos_opt().unwrap_or(0).hash(state);
102                } else {
103                    state.write_u8(255);
104                }
105            }
106        }
107    }
108}
109
110impl WindowExecutor {
111    /// Create a new window executor
112    pub fn new() -> Self {
113        Self {
114            partitions: HashMap::new(),
115        }
116    }
117
118    /// Execute a window function over rows
119    pub fn execute(
120        &mut self,
121        rows: &[HashMap<String, ValueWord>],
122        window_expr: &WindowExpr,
123        evaluator: Option<&dyn crate::engine::ExpressionEvaluator>,
124        ctx: &mut ExecutionContext,
125    ) -> Result<Vec<ValueWord>> {
126        self.partitions.clear();
127
128        // 1. Partition rows
129        self.partition_rows(rows, &window_expr.over.partition_by, evaluator, ctx)?;
130
131        // 2. Sort each partition
132        if let Some(ref order_by) = window_expr.over.order_by {
133            self.sort_partitions(order_by)?;
134        }
135
136        // 3. Apply window function
137        let mut results = vec![ValueWord::none(); rows.len()];
138
139        for partition in self.partitions.values() {
140            for (pos, row) in partition.iter().enumerate() {
141                let value = evaluate_window_function(
142                    &window_expr.function,
143                    partition,
144                    pos,
145                    &window_expr.over.frame,
146                    evaluator,
147                    ctx,
148                )?;
149                results[row.original_index] = value;
150            }
151        }
152
153        Ok(results)
154    }
155
156    fn partition_rows(
157        &mut self,
158        rows: &[HashMap<String, ValueWord>],
159        partition_by: &[Expr],
160        evaluator: Option<&dyn crate::engine::ExpressionEvaluator>,
161        ctx: &mut ExecutionContext,
162    ) -> Result<()> {
163        if partition_by.is_empty() {
164            // Single partition with all rows
165            let all_rows: Vec<_> = rows
166                .iter()
167                .enumerate()
168                .map(|(idx, row)| RowData {
169                    original_index: idx,
170                    values: row.clone(),
171                })
172                .collect();
173            self.partitions.insert(vec![], all_rows);
174            return Ok(());
175        }
176
177        for (idx, row) in rows.iter().enumerate() {
178            ctx.push_scope();
179            for (key, value) in row {
180                let _ = ctx.set_variable_nb(key, value.clone());
181            }
182
183            let mut key = Vec::with_capacity(partition_by.len());
184            for expr in partition_by {
185                let value = if let Some(eval) = evaluator {
186                    eval.eval_expr(expr, ctx).unwrap_or(ValueWord::none())
187                } else {
188                    ValueWord::none()
189                };
190                key.push(OrderedValue(value));
191            }
192
193            ctx.pop_scope();
194
195            self.partitions.entry(key).or_default().push(RowData {
196                original_index: idx,
197                values: row.clone(),
198            });
199        }
200
201        Ok(())
202    }
203
204    fn sort_partitions(&mut self, order_by: &shape_ast::ast::OrderByClause) -> Result<()> {
205        for partition in self.partitions.values_mut() {
206            partition.sort_by(|a, b| {
207                for (expr, direction) in &order_by.columns {
208                    let a_val = extract_sort_value(&a.values, expr);
209                    let b_val = extract_sort_value(&b.values, expr);
210
211                    let cmp = compare_nb_values(&a_val, &b_val);
212                    let cmp = match direction {
213                        SortDirection::Ascending => cmp,
214                        SortDirection::Descending => cmp.reverse(),
215                    };
216
217                    if cmp != std::cmp::Ordering::Equal {
218                        return cmp;
219                    }
220                }
221                std::cmp::Ordering::Equal
222            });
223        }
224        Ok(())
225    }
226}
227
228impl Default for WindowExecutor {
229    fn default() -> Self {
230        Self::new()
231    }
232}
233
234/// Evaluate window function for a specific row
235fn evaluate_window_function(
236    func: &WindowFunction,
237    partition: &[RowData],
238    current_idx: usize,
239    frame: &Option<WindowFrame>,
240    evaluator: Option<&dyn crate::engine::ExpressionEvaluator>,
241    ctx: &mut ExecutionContext,
242) -> Result<ValueWord> {
243    match func {
244        WindowFunction::RowNumber => Ok(ValueWord::from_f64((current_idx + 1) as f64)),
245
246        WindowFunction::Rank => {
247            let rank = calculate_rank(partition, current_idx, false);
248            Ok(ValueWord::from_f64(rank as f64))
249        }
250
251        WindowFunction::DenseRank => {
252            let rank = calculate_rank(partition, current_idx, true);
253            Ok(ValueWord::from_f64(rank as f64))
254        }
255
256        WindowFunction::Ntile(n) => {
257            let bucket = if partition.is_empty() {
258                1
259            } else {
260                (current_idx * *n / partition.len()) + 1
261            };
262            Ok(ValueWord::from_f64(bucket as f64))
263        }
264
265        WindowFunction::Lag {
266            expr,
267            offset,
268            default,
269        } => {
270            if let Some(target_idx) = current_idx.checked_sub(*offset) {
271                if target_idx < partition.len() {
272                    return eval_expr_at(expr, &partition[target_idx], evaluator, ctx);
273                }
274            }
275            if let Some(def) = default {
276                if let Some(eval) = evaluator {
277                    Ok(eval.eval_expr(def, ctx)?)
278                } else {
279                    Ok(ValueWord::none())
280                }
281            } else {
282                Ok(ValueWord::none())
283            }
284        }
285
286        WindowFunction::Lead {
287            expr,
288            offset,
289            default,
290        } => {
291            let target_idx = current_idx + *offset;
292            if target_idx < partition.len() {
293                return eval_expr_at(expr, &partition[target_idx], evaluator, ctx);
294            }
295            if let Some(def) = default {
296                if let Some(eval) = evaluator {
297                    Ok(eval.eval_expr(def, ctx)?)
298                } else {
299                    Ok(ValueWord::none())
300                }
301            } else {
302                Ok(ValueWord::none())
303            }
304        }
305
306        WindowFunction::FirstValue(expr) => {
307            let (start, _) = get_frame_bounds(frame, partition.len(), current_idx);
308            eval_expr_at(expr, &partition[start], evaluator, ctx)
309        }
310
311        WindowFunction::LastValue(expr) => {
312            let (_, end) = get_frame_bounds(frame, partition.len(), current_idx);
313            eval_expr_at(expr, &partition[end], evaluator, ctx)
314        }
315
316        WindowFunction::NthValue(expr, n) => {
317            let (start, end) = get_frame_bounds(frame, partition.len(), current_idx);
318            let target_idx = start + n - 1;
319            if target_idx <= end && target_idx < partition.len() {
320                eval_expr_at(expr, &partition[target_idx], evaluator, ctx)
321            } else {
322                Ok(ValueWord::none())
323            }
324        }
325
326        WindowFunction::Sum(expr)
327        | WindowFunction::Avg(expr)
328        | WindowFunction::Min(expr)
329        | WindowFunction::Max(expr) => {
330            let (start, end) = get_frame_bounds(frame, partition.len(), current_idx);
331            let mut values = Vec::new();
332
333            for i in start..=end.min(partition.len().saturating_sub(1)) {
334                let nb = eval_expr_at(expr, &partition[i], evaluator, ctx)?;
335                if let Some(n) = nb.as_f64() {
336                    values.push(n);
337                }
338            }
339
340            if values.is_empty() {
341                return Ok(ValueWord::none());
342            }
343
344            let result = match func {
345                WindowFunction::Sum(_) => values.iter().sum::<f64>(),
346                WindowFunction::Avg(_) => values.iter().sum::<f64>() / values.len() as f64,
347                WindowFunction::Min(_) => values
348                    .iter()
349                    .cloned()
350                    .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
351                    .unwrap_or(f64::NAN),
352                WindowFunction::Max(_) => values
353                    .iter()
354                    .cloned()
355                    .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
356                    .unwrap_or(f64::NAN),
357                _ => unreachable!(),
358            };
359
360            Ok(ValueWord::from_f64(result))
361        }
362
363        WindowFunction::Count(expr_opt) => {
364            let (start, end) = get_frame_bounds(frame, partition.len(), current_idx);
365
366            let count = if let Some(expr) = expr_opt {
367                (start..=end.min(partition.len().saturating_sub(1)))
368                    .filter(|&i| {
369                        eval_expr_at(expr, &partition[i], evaluator, ctx)
370                            .map(|v| !v.is_none())
371                            .unwrap_or(false)
372                    })
373                    .count()
374            } else {
375                end.min(partition.len().saturating_sub(1))
376                    .saturating_sub(start)
377                    + 1
378            };
379
380            Ok(ValueWord::from_f64(count as f64))
381        }
382    }
383}
384
385/// Evaluate expression with row context
386fn eval_expr_at(
387    expr: &Expr,
388    row: &RowData,
389    evaluator: Option<&dyn crate::engine::ExpressionEvaluator>,
390    ctx: &mut ExecutionContext,
391) -> Result<ValueWord> {
392    ctx.push_scope();
393    for (key, value) in &row.values {
394        let _ = ctx.set_variable_nb(key, value.clone());
395    }
396    let result = if let Some(eval) = evaluator {
397        Ok(eval.eval_expr(expr, ctx)?)
398    } else {
399        // Fallback: try simple identifier lookup
400        if let Expr::Identifier(name, _) = expr {
401            Ok(row.values.get(name).cloned().unwrap_or(ValueWord::none()))
402        } else {
403            Ok(ValueWord::none())
404        }
405    };
406    ctx.pop_scope();
407    result
408}
409
410/// Calculate rank within partition
411fn calculate_rank(_partition: &[RowData], current_idx: usize, dense: bool) -> usize {
412    if current_idx == 0 {
413        return 1;
414    }
415    // Simplified: each row gets sequential rank
416    // Full implementation would compare ORDER BY values
417    if dense {
418        current_idx + 1
419    } else {
420        current_idx + 1
421    }
422}
423
424/// Get frame bounds for aggregate functions
425fn get_frame_bounds(
426    frame: &Option<WindowFrame>,
427    partition_len: usize,
428    current_idx: usize,
429) -> (usize, usize) {
430    match frame {
431        Some(f) => {
432            let start = match &f.start {
433                WindowBound::UnboundedPreceding => 0,
434                WindowBound::CurrentRow => current_idx,
435                WindowBound::Preceding(n) => current_idx.saturating_sub(*n),
436                WindowBound::Following(n) => (current_idx + n).min(partition_len.saturating_sub(1)),
437                WindowBound::UnboundedFollowing => partition_len.saturating_sub(1),
438            };
439            let end = match &f.end {
440                WindowBound::UnboundedPreceding => 0,
441                WindowBound::CurrentRow => current_idx,
442                WindowBound::Preceding(n) => current_idx.saturating_sub(*n),
443                WindowBound::Following(n) => (current_idx + n).min(partition_len.saturating_sub(1)),
444                WindowBound::UnboundedFollowing => partition_len.saturating_sub(1),
445            };
446            (start, end)
447        }
448        None => (0, current_idx),
449    }
450}
451
452/// Extract sort value from expression
453fn extract_sort_value(row: &HashMap<String, ValueWord>, expr: &Expr) -> ValueWord {
454    if let Expr::Identifier(name, _) = expr {
455        return row.get(name).cloned().unwrap_or(ValueWord::none());
456    }
457    ValueWord::none()
458}
459
460/// Compare two ValueWord values for sorting
461fn compare_nb_values(a: &ValueWord, b: &ValueWord) -> std::cmp::Ordering {
462    use shape_value::NanTag;
463    match (a.tag(), b.tag()) {
464        (NanTag::F64, NanTag::F64)
465        | (NanTag::I48, NanTag::I48)
466        | (NanTag::F64, NanTag::I48)
467        | (NanTag::I48, NanTag::F64) => match (a.as_f64(), b.as_f64()) {
468            (Some(an), Some(bn)) => an.partial_cmp(&bn).unwrap_or(std::cmp::Ordering::Equal),
469            _ => std::cmp::Ordering::Equal,
470        },
471        (NanTag::Heap, NanTag::Heap) => match (a.as_str(), b.as_str()) {
472            (Some(sa), Some(sb)) => sa.cmp(sb),
473            _ => std::cmp::Ordering::Equal,
474        },
475        (NanTag::Bool, NanTag::Bool) => match (a.as_bool(), b.as_bool()) {
476            (Some(ba), Some(bb)) => ba.cmp(&bb),
477            _ => std::cmp::Ordering::Equal,
478        },
479        (NanTag::None, NanTag::None) => std::cmp::Ordering::Equal,
480        (NanTag::None, _) => std::cmp::Ordering::Less,
481        (_, NanTag::None) => std::cmp::Ordering::Greater,
482        _ => match (a.as_time(), b.as_time()) {
483            (Some(ta), Some(tb)) => ta.cmp(&tb),
484            _ => std::cmp::Ordering::Equal,
485        },
486    }
487}
488
#[cfg(test)]
mod tests {
    use super::*;

    /// Build input rows from `(column, value)` pairs.
    fn make_rows(data: Vec<Vec<(&str, ValueWord)>>) -> Vec<HashMap<String, ValueWord>> {
        data.into_iter()
            .map(|row| row.into_iter().map(|(k, v)| (k.to_string(), v)).collect())
            .collect()
    }

    /// Rows `{x: 1.0}`, `{x: 2.0}`, ..., `{x: n}`.
    fn numbered_rows(n: usize) -> Vec<HashMap<String, ValueWord>> {
        make_rows(
            (1..=n)
                .map(|i| vec![("x", ValueWord::from_f64(i as f64))])
                .collect(),
        )
    }

    /// Run `function` with an empty `OVER ()` clause (one partition, no ORDER BY,
    /// default frame) and no expression evaluator. Returns each result's
    /// numeric reading.
    fn run_over_all(
        function: WindowFunction,
        rows: &[HashMap<String, ValueWord>],
    ) -> Vec<Option<f64>> {
        let mut ctx = ExecutionContext::new_empty();
        let window_expr = WindowExpr {
            function,
            over: shape_ast::ast::WindowSpec {
                partition_by: vec![],
                order_by: None,
                frame: None,
            },
        };
        let results = WindowExecutor::new()
            .execute(rows, &window_expr, None, &mut ctx)
            .unwrap();
        results.iter().map(|v| v.as_f64()).collect()
    }

    #[test]
    fn test_row_number_simple() {
        let results = run_over_all(WindowFunction::RowNumber, &numbered_rows(3));
        assert_eq!(results, vec![Some(1.0), Some(2.0), Some(3.0)]);
    }

    #[test]
    fn test_empty_input_yields_no_results() {
        assert!(run_over_all(WindowFunction::RowNumber, &[]).is_empty());
    }

    #[test]
    fn test_ntile_splits_evenly() {
        let results = run_over_all(WindowFunction::Ntile(2), &numbered_rows(4));
        assert_eq!(results, vec![Some(1.0), Some(1.0), Some(2.0), Some(2.0)]);
    }

    #[test]
    fn test_count_star_uses_running_default_frame() {
        let results = run_over_all(WindowFunction::Count(None), &numbered_rows(3));
        assert_eq!(results, vec![Some(1.0), Some(2.0), Some(3.0)]);
    }
}