1use crate::context::ExecutionContext;
9use shape_ast::ast::{Expr, SortDirection, WindowBound, WindowExpr, WindowFrame, WindowFunction};
10use shape_ast::error::Result;
11use shape_value::ValueWord;
12use std::collections::HashMap;
13
14pub struct WindowExecutor {
16 partitions: HashMap<Vec<OrderedValue>, Vec<RowData>>,
18}
19
20struct RowData {
22 original_index: usize,
24 values: HashMap<String, ValueWord>,
26}
27
28#[derive(Clone, Debug)]
30struct OrderedValue(ValueWord);
31
32impl PartialEq for OrderedValue {
33 fn eq(&self, other: &Self) -> bool {
34 use shape_value::NanTag;
35 match (self.0.tag(), other.0.tag()) {
36 (NanTag::F64, NanTag::F64)
37 | (NanTag::I48, NanTag::I48)
38 | (NanTag::F64, NanTag::I48)
39 | (NanTag::I48, NanTag::F64) => match (self.0.as_f64(), other.0.as_f64()) {
40 (Some(a), Some(b)) => {
41 if a.is_nan() && b.is_nan() {
42 true
43 } else {
44 a == b
45 }
46 }
47 _ => false,
48 },
49 (NanTag::Heap, NanTag::Heap) => {
50 if let (Some(a), Some(b)) = (self.0.as_str(), other.0.as_str()) {
51 a == b
52 } else {
53 false
54 }
55 }
56 (NanTag::Bool, NanTag::Bool) => self.0.as_bool() == other.0.as_bool(),
57 (NanTag::None, NanTag::None) => true,
58 _ => {
59 if let (Some(a), Some(b)) = (self.0.as_time(), other.0.as_time()) {
60 a == b
61 } else {
62 false
63 }
64 }
65 }
66 }
67}
68
69impl Eq for OrderedValue {}
70
71impl std::hash::Hash for OrderedValue {
72 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
73 use shape_value::NanTag;
74 match self.0.tag() {
75 NanTag::F64 | NanTag::I48 => {
76 state.write_u8(0);
77 if let Some(n) = self.0.as_f64() {
78 state.write_u64(n.to_bits());
79 }
80 }
81 NanTag::Heap => {
82 if let Some(s) = self.0.as_str() {
83 state.write_u8(1);
84 s.hash(state);
85 } else {
86 state.write_u8(255);
87 }
88 }
89 NanTag::Bool => {
90 state.write_u8(2);
91 if let Some(b) = self.0.as_bool() {
92 b.hash(state);
93 }
94 }
95 NanTag::None => {
96 state.write_u8(4);
97 }
98 _ => {
99 if let Some(t) = self.0.as_time() {
100 state.write_u8(3);
101 t.timestamp_nanos_opt().unwrap_or(0).hash(state);
102 } else {
103 state.write_u8(255);
104 }
105 }
106 }
107 }
108}
109
110impl WindowExecutor {
111 pub fn new() -> Self {
113 Self {
114 partitions: HashMap::new(),
115 }
116 }
117
118 pub fn execute(
120 &mut self,
121 rows: &[HashMap<String, ValueWord>],
122 window_expr: &WindowExpr,
123 evaluator: Option<&dyn crate::engine::ExpressionEvaluator>,
124 ctx: &mut ExecutionContext,
125 ) -> Result<Vec<ValueWord>> {
126 self.partitions.clear();
127
128 self.partition_rows(rows, &window_expr.over.partition_by, evaluator, ctx)?;
130
131 if let Some(ref order_by) = window_expr.over.order_by {
133 self.sort_partitions(order_by)?;
134 }
135
136 let mut results = vec![ValueWord::none(); rows.len()];
138
139 for partition in self.partitions.values() {
140 for (pos, row) in partition.iter().enumerate() {
141 let value = evaluate_window_function(
142 &window_expr.function,
143 partition,
144 pos,
145 &window_expr.over.frame,
146 evaluator,
147 ctx,
148 )?;
149 results[row.original_index] = value;
150 }
151 }
152
153 Ok(results)
154 }
155
156 fn partition_rows(
157 &mut self,
158 rows: &[HashMap<String, ValueWord>],
159 partition_by: &[Expr],
160 evaluator: Option<&dyn crate::engine::ExpressionEvaluator>,
161 ctx: &mut ExecutionContext,
162 ) -> Result<()> {
163 if partition_by.is_empty() {
164 let all_rows: Vec<_> = rows
166 .iter()
167 .enumerate()
168 .map(|(idx, row)| RowData {
169 original_index: idx,
170 values: row.clone(),
171 })
172 .collect();
173 self.partitions.insert(vec![], all_rows);
174 return Ok(());
175 }
176
177 for (idx, row) in rows.iter().enumerate() {
178 ctx.push_scope();
179 for (key, value) in row {
180 let _ = ctx.set_variable_nb(key, value.clone());
181 }
182
183 let mut key = Vec::with_capacity(partition_by.len());
184 for expr in partition_by {
185 let value = if let Some(eval) = evaluator {
186 eval.eval_expr(expr, ctx).unwrap_or(ValueWord::none())
187 } else {
188 ValueWord::none()
189 };
190 key.push(OrderedValue(value));
191 }
192
193 ctx.pop_scope();
194
195 self.partitions.entry(key).or_default().push(RowData {
196 original_index: idx,
197 values: row.clone(),
198 });
199 }
200
201 Ok(())
202 }
203
204 fn sort_partitions(&mut self, order_by: &shape_ast::ast::OrderByClause) -> Result<()> {
205 for partition in self.partitions.values_mut() {
206 partition.sort_by(|a, b| {
207 for (expr, direction) in &order_by.columns {
208 let a_val = extract_sort_value(&a.values, expr);
209 let b_val = extract_sort_value(&b.values, expr);
210
211 let cmp = compare_nb_values(&a_val, &b_val);
212 let cmp = match direction {
213 SortDirection::Ascending => cmp,
214 SortDirection::Descending => cmp.reverse(),
215 };
216
217 if cmp != std::cmp::Ordering::Equal {
218 return cmp;
219 }
220 }
221 std::cmp::Ordering::Equal
222 });
223 }
224 Ok(())
225 }
226}
227
228impl Default for WindowExecutor {
229 fn default() -> Self {
230 Self::new()
231 }
232}
233
234fn evaluate_window_function(
236 func: &WindowFunction,
237 partition: &[RowData],
238 current_idx: usize,
239 frame: &Option<WindowFrame>,
240 evaluator: Option<&dyn crate::engine::ExpressionEvaluator>,
241 ctx: &mut ExecutionContext,
242) -> Result<ValueWord> {
243 match func {
244 WindowFunction::RowNumber => Ok(ValueWord::from_f64((current_idx + 1) as f64)),
245
246 WindowFunction::Rank => {
247 let rank = calculate_rank(partition, current_idx, false);
248 Ok(ValueWord::from_f64(rank as f64))
249 }
250
251 WindowFunction::DenseRank => {
252 let rank = calculate_rank(partition, current_idx, true);
253 Ok(ValueWord::from_f64(rank as f64))
254 }
255
256 WindowFunction::Ntile(n) => {
257 let bucket = if partition.is_empty() {
258 1
259 } else {
260 (current_idx * *n / partition.len()) + 1
261 };
262 Ok(ValueWord::from_f64(bucket as f64))
263 }
264
265 WindowFunction::Lag {
266 expr,
267 offset,
268 default,
269 } => {
270 if let Some(target_idx) = current_idx.checked_sub(*offset) {
271 if target_idx < partition.len() {
272 return eval_expr_at(expr, &partition[target_idx], evaluator, ctx);
273 }
274 }
275 if let Some(def) = default {
276 if let Some(eval) = evaluator {
277 Ok(eval.eval_expr(def, ctx)?)
278 } else {
279 Ok(ValueWord::none())
280 }
281 } else {
282 Ok(ValueWord::none())
283 }
284 }
285
286 WindowFunction::Lead {
287 expr,
288 offset,
289 default,
290 } => {
291 let target_idx = current_idx + *offset;
292 if target_idx < partition.len() {
293 return eval_expr_at(expr, &partition[target_idx], evaluator, ctx);
294 }
295 if let Some(def) = default {
296 if let Some(eval) = evaluator {
297 Ok(eval.eval_expr(def, ctx)?)
298 } else {
299 Ok(ValueWord::none())
300 }
301 } else {
302 Ok(ValueWord::none())
303 }
304 }
305
306 WindowFunction::FirstValue(expr) => {
307 let (start, _) = get_frame_bounds(frame, partition.len(), current_idx);
308 eval_expr_at(expr, &partition[start], evaluator, ctx)
309 }
310
311 WindowFunction::LastValue(expr) => {
312 let (_, end) = get_frame_bounds(frame, partition.len(), current_idx);
313 eval_expr_at(expr, &partition[end], evaluator, ctx)
314 }
315
316 WindowFunction::NthValue(expr, n) => {
317 let (start, end) = get_frame_bounds(frame, partition.len(), current_idx);
318 let target_idx = start + n - 1;
319 if target_idx <= end && target_idx < partition.len() {
320 eval_expr_at(expr, &partition[target_idx], evaluator, ctx)
321 } else {
322 Ok(ValueWord::none())
323 }
324 }
325
326 WindowFunction::Sum(expr)
327 | WindowFunction::Avg(expr)
328 | WindowFunction::Min(expr)
329 | WindowFunction::Max(expr) => {
330 let (start, end) = get_frame_bounds(frame, partition.len(), current_idx);
331 let mut values = Vec::new();
332
333 for i in start..=end.min(partition.len().saturating_sub(1)) {
334 let nb = eval_expr_at(expr, &partition[i], evaluator, ctx)?;
335 if let Some(n) = nb.as_f64() {
336 values.push(n);
337 }
338 }
339
340 if values.is_empty() {
341 return Ok(ValueWord::none());
342 }
343
344 let result = match func {
345 WindowFunction::Sum(_) => values.iter().sum::<f64>(),
346 WindowFunction::Avg(_) => values.iter().sum::<f64>() / values.len() as f64,
347 WindowFunction::Min(_) => values
348 .iter()
349 .cloned()
350 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
351 .unwrap_or(f64::NAN),
352 WindowFunction::Max(_) => values
353 .iter()
354 .cloned()
355 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
356 .unwrap_or(f64::NAN),
357 _ => unreachable!(),
358 };
359
360 Ok(ValueWord::from_f64(result))
361 }
362
363 WindowFunction::Count(expr_opt) => {
364 let (start, end) = get_frame_bounds(frame, partition.len(), current_idx);
365
366 let count = if let Some(expr) = expr_opt {
367 (start..=end.min(partition.len().saturating_sub(1)))
368 .filter(|&i| {
369 eval_expr_at(expr, &partition[i], evaluator, ctx)
370 .map(|v| !v.is_none())
371 .unwrap_or(false)
372 })
373 .count()
374 } else {
375 end.min(partition.len().saturating_sub(1))
376 .saturating_sub(start)
377 + 1
378 };
379
380 Ok(ValueWord::from_f64(count as f64))
381 }
382 }
383}
384
385fn eval_expr_at(
387 expr: &Expr,
388 row: &RowData,
389 evaluator: Option<&dyn crate::engine::ExpressionEvaluator>,
390 ctx: &mut ExecutionContext,
391) -> Result<ValueWord> {
392 ctx.push_scope();
393 for (key, value) in &row.values {
394 let _ = ctx.set_variable_nb(key, value.clone());
395 }
396 let result = if let Some(eval) = evaluator {
397 Ok(eval.eval_expr(expr, ctx)?)
398 } else {
399 if let Expr::Identifier(name, _) = expr {
401 Ok(row.values.get(name).cloned().unwrap_or(ValueWord::none()))
402 } else {
403 Ok(ValueWord::none())
404 }
405 };
406 ctx.pop_scope();
407 result
408}
409
410fn calculate_rank(_partition: &[RowData], current_idx: usize, dense: bool) -> usize {
412 if current_idx == 0 {
413 return 1;
414 }
415 if dense {
418 current_idx + 1
419 } else {
420 current_idx + 1
421 }
422}
423
424fn get_frame_bounds(
426 frame: &Option<WindowFrame>,
427 partition_len: usize,
428 current_idx: usize,
429) -> (usize, usize) {
430 match frame {
431 Some(f) => {
432 let start = match &f.start {
433 WindowBound::UnboundedPreceding => 0,
434 WindowBound::CurrentRow => current_idx,
435 WindowBound::Preceding(n) => current_idx.saturating_sub(*n),
436 WindowBound::Following(n) => (current_idx + n).min(partition_len.saturating_sub(1)),
437 WindowBound::UnboundedFollowing => partition_len.saturating_sub(1),
438 };
439 let end = match &f.end {
440 WindowBound::UnboundedPreceding => 0,
441 WindowBound::CurrentRow => current_idx,
442 WindowBound::Preceding(n) => current_idx.saturating_sub(*n),
443 WindowBound::Following(n) => (current_idx + n).min(partition_len.saturating_sub(1)),
444 WindowBound::UnboundedFollowing => partition_len.saturating_sub(1),
445 };
446 (start, end)
447 }
448 None => (0, current_idx),
449 }
450}
451
452fn extract_sort_value(row: &HashMap<String, ValueWord>, expr: &Expr) -> ValueWord {
454 if let Expr::Identifier(name, _) = expr {
455 return row.get(name).cloned().unwrap_or(ValueWord::none());
456 }
457 ValueWord::none()
458}
459
460fn compare_nb_values(a: &ValueWord, b: &ValueWord) -> std::cmp::Ordering {
462 use shape_value::NanTag;
463 match (a.tag(), b.tag()) {
464 (NanTag::F64, NanTag::F64)
465 | (NanTag::I48, NanTag::I48)
466 | (NanTag::F64, NanTag::I48)
467 | (NanTag::I48, NanTag::F64) => match (a.as_f64(), b.as_f64()) {
468 (Some(an), Some(bn)) => an.partial_cmp(&bn).unwrap_or(std::cmp::Ordering::Equal),
469 _ => std::cmp::Ordering::Equal,
470 },
471 (NanTag::Heap, NanTag::Heap) => match (a.as_str(), b.as_str()) {
472 (Some(sa), Some(sb)) => sa.cmp(sb),
473 _ => std::cmp::Ordering::Equal,
474 },
475 (NanTag::Bool, NanTag::Bool) => match (a.as_bool(), b.as_bool()) {
476 (Some(ba), Some(bb)) => ba.cmp(&bb),
477 _ => std::cmp::Ordering::Equal,
478 },
479 (NanTag::None, NanTag::None) => std::cmp::Ordering::Equal,
480 (NanTag::None, _) => std::cmp::Ordering::Less,
481 (_, NanTag::None) => std::cmp::Ordering::Greater,
482 _ => match (a.as_time(), b.as_time()) {
483 (Some(ta), Some(tb)) => ta.cmp(&tb),
484 _ => std::cmp::Ordering::Equal,
485 },
486 }
487}
488
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns `(column, value)` pairs into owned row maps.
    fn make_rows(data: Vec<Vec<(&str, ValueWord)>>) -> Vec<HashMap<String, ValueWord>> {
        let mut rows = Vec::with_capacity(data.len());
        for cells in data {
            let mut row = HashMap::new();
            for (name, value) in cells {
                row.insert(name.to_string(), value);
            }
            rows.push(row);
        }
        rows
    }

    /// ROW_NUMBER over a single, unordered partition numbers the rows 1..=n
    /// in input order.
    #[test]
    fn test_row_number_simple() {
        let rows = make_rows(vec![
            vec![("x", ValueWord::from_f64(1.0))],
            vec![("x", ValueWord::from_f64(2.0))],
            vec![("x", ValueWord::from_f64(3.0))],
        ]);

        let window_expr = WindowExpr {
            function: WindowFunction::RowNumber,
            over: shape_ast::ast::WindowSpec {
                partition_by: vec![],
                order_by: None,
                frame: None,
            },
        };

        let mut ctx = ExecutionContext::new_empty();
        let mut executor = WindowExecutor::new();
        let results = executor
            .execute(&rows, &window_expr, None, &mut ctx)
            .unwrap();

        let expected = [1.0, 2.0, 3.0];
        assert_eq!(results.len(), expected.len());
        for (got, want) in results.iter().zip(expected) {
            assert_eq!(got.as_f64(), Some(want));
        }
    }
}