1#![allow(dead_code)]
3
4use std::fmt::Display;
5use std::hash::Hash;
6
7use itertools::Itertools;
8use vortex_array::aliases::hash_map::HashMap;
9use vortex_array::aliases::hash_set::HashSet;
10use vortex_array::stats::Stat;
11use vortex_array::{Array, ArrayRef};
12use vortex_dtype::{FieldName, Nullability};
13use vortex_error::{VortexExpect as _, VortexResult};
14use vortex_scalar::Scalar;
15
16use crate::{
17 BinaryExpr, ExprRef, GetItem, Identity, Literal, Not, Operator, VortexExprExt, and, eq,
18 get_item, gt, ident, lit, not, or,
19};
20
21#[derive(Debug, Clone)]
22pub struct Relation<K, V> {
23 map: HashMap<K, HashSet<V>>,
24}
25
26impl<K: Display, V: Display> Display for Relation<K, V> {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 write!(
29 f,
30 "{}",
31 self.map.iter().format_with(",", |(k, v), fmt| {
32 fmt(&format_args!("{k}: {{{}}}", v.iter().format(",")))
33 })
34 )
35 }
36}
37
38impl<K: Hash + Eq, V: Hash + Eq> Default for Relation<K, V> {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl<K: Hash + Eq, V: Hash + Eq> Relation<K, V> {
45 pub fn new() -> Self {
46 Relation {
47 map: HashMap::new(),
48 }
49 }
50
51 pub fn union(mut iter: impl Iterator<Item = Relation<K, V>>) -> Relation<K, V> {
52 if let Some(mut x) = iter.next() {
53 for y in iter {
54 x.extend(y)
55 }
56 x
57 } else {
58 Relation::new()
59 }
60 }
61
62 pub fn extend(&mut self, other: Relation<K, V>) {
63 for (l, rs) in other.map.into_iter() {
64 self.map.entry(l).or_default().extend(rs.into_iter())
65 }
66 }
67
68 pub fn insert(&mut self, k: K, v: V) {
69 self.map.entry(k).or_default().insert(v);
70 }
71
72 pub fn into_map(self) -> HashMap<K, HashSet<V>> {
73 self.map
74 }
75}
76
77#[derive(Debug, Clone)]
78pub struct PruningPredicate {
79 expr: ExprRef,
80 required_stats: Relation<FieldOrIdentity, Stat>,
81}
82
83impl Display for PruningPredicate {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 write!(
86 f,
87 "PruningPredicate({}, {{{}}})",
88 self.expr, self.required_stats
89 )
90 }
91}
92
93impl PruningPredicate {
94 pub fn try_new(original_expr: &ExprRef) -> Option<Self> {
95 let (expr, required_stats) = convert_to_pruning_expression(original_expr);
96 if let Some(lexp) = expr.as_any().downcast_ref::<Literal>() {
97 if lexp
99 .value()
100 .as_bool_opt()
101 .and_then(|b| b.value())
102 .map(|b| !b)
103 .unwrap_or(false)
104 {
105 None
106 } else {
107 Some(Self {
108 expr,
109 required_stats,
110 })
111 }
112 } else {
113 Some(Self {
114 expr,
115 required_stats,
116 })
117 }
118 }
119
120 pub fn expr(&self) -> &ExprRef {
121 &self.expr
122 }
123
124 pub fn required_stats(&self) -> &HashMap<FieldOrIdentity, HashSet<Stat>> {
125 &self.required_stats.map
126 }
127
128 pub fn evaluate(&self, metadata: &dyn Array) -> VortexResult<Option<ArrayRef>> {
134 let known_stats = HashSet::from_iter(
135 metadata
136 .as_struct_typed()
137 .vortex_expect("metadata must be struct array")
138 .names()
139 .iter()
140 .map(|x| x.to_string()),
141 );
142 let required_stats = self
143 .required_stats()
144 .iter()
145 .flat_map(|(key, value)| value.iter().map(|stat| key.stat_field_name_string(*stat)))
146 .collect::<HashSet<_>>();
147 let missing_stats = required_stats.difference(&known_stats).collect::<Vec<_>>();
148
149 if !missing_stats.is_empty() {
150 return Ok(None);
151 }
152
153 Ok(Some(self.expr.evaluate(metadata)?))
154 }
155}
156
157fn not_prunable() -> PruningPredicateStats {
158 (
159 lit(Scalar::bool(false, Nullability::NonNullable)),
160 Relation::new(),
161 )
162}
163
164fn convert_to_pruning_expression(expr: &ExprRef) -> PruningPredicateStats {
167 if let Some(nexp) = expr.as_any().downcast_ref::<Not>() {
168 if let Some(get_item) = nexp.child().as_any().downcast_ref::<GetItem>() {
169 if get_item.child().as_any().is::<Identity>() {
170 return convert_access_reference(expr, true);
171 }
172 }
173 }
174
175 if let Some(get_item) = expr.as_any().downcast_ref::<GetItem>() {
176 if get_item.child().as_any().is::<Identity>() {
177 return convert_access_reference(expr, false);
178 }
179 }
180
181 if let Some(bexp) = expr.as_any().downcast_ref::<BinaryExpr>() {
182 if bexp.op() == Operator::Or || bexp.op() == Operator::And {
183 let (rewritten_left, mut refs_lhs) = convert_to_pruning_expression(bexp.lhs());
184 let (rewritten_right, refs_rhs) = convert_to_pruning_expression(bexp.rhs());
185 refs_lhs.extend(refs_rhs);
186 return (
187 BinaryExpr::new_expr(rewritten_left, bexp.op(), rewritten_right),
188 refs_lhs,
189 );
190 }
191
192 if let Some(get_item) = bexp.lhs().as_any().downcast_ref::<GetItem>() {
193 if get_item.child().as_any().is::<Identity>() {
194 return PruningPredicateRewriter::rewrite_binary_op(
195 FieldOrIdentity::Field(get_item.field().clone()),
196 bexp.op(),
197 bexp.rhs(),
198 );
199 }
200 };
201
202 if let Some(get_item) = bexp.rhs().as_any().downcast_ref::<GetItem>() {
203 if get_item.child().as_any().is::<Identity>() {
204 return PruningPredicateRewriter::rewrite_binary_op(
205 FieldOrIdentity::Field(get_item.field().clone()),
206 bexp.op().swap(),
207 bexp.lhs(),
208 );
209 }
210 }
211
212 if bexp.lhs().as_any().is::<Identity>() {
213 return PruningPredicateRewriter::rewrite_binary_op(
214 FieldOrIdentity::Identity,
215 bexp.op(),
216 bexp.rhs(),
217 );
218 };
219
220 if bexp.rhs().as_any().is::<Identity>() {
221 return PruningPredicateRewriter::rewrite_binary_op(
222 FieldOrIdentity::Identity,
223 bexp.op().swap(),
224 bexp.lhs(),
225 );
226 };
227 }
228
229 not_prunable()
230}
231
232fn convert_access_reference(expr: &ExprRef, invert: bool) -> PruningPredicateStats {
233 let mut refs = Relation::new();
234 let Some(min_expr) = replace_get_item_with_stat(expr, Stat::Min, &mut refs) else {
235 return not_prunable();
236 };
237 let Some(max_expr) = replace_get_item_with_stat(expr, Stat::Max, &mut refs) else {
238 return not_prunable();
239 };
240
241 let expr = if invert {
242 and(min_expr, max_expr)
243 } else {
244 not(or(min_expr, max_expr))
245 };
246
247 (expr, refs)
248}
249
250struct PruningPredicateRewriter<'a> {
251 access: FieldOrIdentity,
252 operator: Operator,
253 other_exp: &'a ExprRef,
254 stats_to_fetch: Relation<FieldOrIdentity, Stat>,
255}
256
257type PruningPredicateStats = (ExprRef, Relation<FieldOrIdentity, Stat>);
258
259impl<'a> PruningPredicateRewriter<'a> {
260 pub fn try_new(
261 access: FieldOrIdentity,
262 operator: Operator,
263 other_exp: &'a ExprRef,
264 ) -> Option<Self> {
265 if let FieldOrIdentity::Field(field) = &access {
268 if other_exp.references().contains(field) {
269 return None;
270 }
271 };
272
273 Some(Self {
274 access,
275 operator,
276 other_exp,
277 stats_to_fetch: Relation::new(),
278 })
279 }
280
281 pub fn rewrite_binary_op(
282 access: FieldOrIdentity,
283 operator: Operator,
284 other_exp: &'a ExprRef,
285 ) -> PruningPredicateStats {
286 Self::try_new(access, operator, other_exp)
287 .and_then(Self::rewrite)
288 .unwrap_or_else(not_prunable)
289 }
290
291 fn add_stat_reference(&mut self, stat: Stat) -> FieldName {
292 let new_field = self.access.stat_field_name(stat);
293 self.stats_to_fetch.insert(self.access.clone(), stat);
294 new_field
295 }
296
297 fn rewrite_other_exp(&mut self, stat: Stat) -> ExprRef {
298 replace_get_item_with_stat(self.other_exp, stat, &mut self.stats_to_fetch)
299 .unwrap_or_else(|| self.other_exp.clone())
300 }
301
302 fn rewrite(mut self) -> Option<PruningPredicateStats> {
303 let expr: Option<ExprRef> = match self.operator {
304 Operator::Eq => {
305 let min_col = get_item(self.add_stat_reference(Stat::Min), ident());
306 let max_col = get_item(self.add_stat_reference(Stat::Max), ident());
307 let replaced_max = self.rewrite_other_exp(Stat::Max);
308 let replaced_min = self.rewrite_other_exp(Stat::Min);
309
310 Some(or(gt(min_col, replaced_max), gt(replaced_min, max_col)))
311 }
312 Operator::NotEq => {
313 let min_col = get_item(self.add_stat_reference(Stat::Min), ident());
314 let max_col = get_item(self.add_stat_reference(Stat::Max), ident());
315 let replaced_max = self.rewrite_other_exp(Stat::Max);
316 let replaced_min = self.rewrite_other_exp(Stat::Min);
317
318 Some(and(eq(min_col, replaced_max), eq(max_col, replaced_min)))
319 }
320 Operator::Gt | Operator::Gte => {
321 let max_col = get_item(self.add_stat_reference(Stat::Max), ident());
322 let replaced_min = self.rewrite_other_exp(Stat::Min);
323
324 Some(BinaryExpr::new_expr(
325 max_col,
326 self.operator
327 .inverse()
328 .vortex_expect("inverse of gt & gt_eq defined"),
329 replaced_min,
330 ))
331 }
332 Operator::Lt | Operator::Lte => {
333 let min_col = get_item(self.add_stat_reference(Stat::Min), ident());
334 let replaced_max = self.rewrite_other_exp(Stat::Max);
335
336 Some(BinaryExpr::new_expr(
337 min_col,
338 self.operator
339 .inverse()
340 .vortex_expect("inverse of lt & lte defined"),
341 replaced_max,
342 ))
343 }
344 _ => None,
345 };
346 expr.map(|e| (e, self.stats_to_fetch))
347 }
348}
349
350fn replace_get_item_with_stat(
351 expr: &ExprRef,
352 stat: Stat,
353 stats_to_fetch: &mut Relation<FieldOrIdentity, Stat>,
354) -> Option<ExprRef> {
355 if let Some(get_i) = expr.as_any().downcast_ref::<GetItem>() {
356 if get_i.child().as_any().is::<Identity>() {
357 let new_field = stat_field_name(get_i.field(), stat);
358 stats_to_fetch.insert(FieldOrIdentity::Field(get_i.field().clone()), stat);
359 return Some(get_item(new_field, ident()));
360 }
361 }
362
363 if let Some(not_expr) = expr.as_any().downcast_ref::<Not>() {
364 let rewritten = replace_get_item_with_stat(not_expr.child(), stat, stats_to_fetch)?;
365 return Some(not(rewritten));
366 }
367
368 if let Some(bexp) = expr.as_any().downcast_ref::<BinaryExpr>() {
369 let rewritten_lhs = replace_get_item_with_stat(bexp.lhs(), stat, stats_to_fetch);
370 let rewritten_rhs = replace_get_item_with_stat(bexp.rhs(), stat, stats_to_fetch);
371 if rewritten_lhs.is_none() && rewritten_rhs.is_none() {
372 return None;
373 }
374
375 let lhs = rewritten_lhs.unwrap_or_else(|| bexp.lhs().clone());
376 let rhs = rewritten_rhs.unwrap_or_else(|| bexp.rhs().clone());
377
378 return Some(BinaryExpr::new_expr(lhs, bexp.op(), rhs));
379 }
380
381 None
382}
383
384#[derive(Debug, Clone, Hash, PartialEq, Eq)]
385pub enum FieldOrIdentity {
386 Field(FieldName),
387 Identity,
388}
389
390pub(crate) fn stat_field_name(field: &FieldName, stat: Stat) -> FieldName {
391 FieldName::from(stat_field_name_string(field, stat))
392}
393
394pub(crate) fn stat_field_name_string(field: &FieldName, stat: Stat) -> String {
395 format!("{field}_{stat}")
396}
397
398impl FieldOrIdentity {
399 pub(crate) fn stat_field_name(&self, stat: Stat) -> FieldName {
400 FieldName::from(self.stat_field_name_string(stat))
401 }
402
403 pub(crate) fn stat_field_name_string(&self, stat: Stat) -> String {
404 match self {
405 FieldOrIdentity::Field(field) => stat_field_name_string(field, stat),
406 FieldOrIdentity::Identity => stat.to_string(),
407 }
408 }
409}
410
411impl Display for FieldOrIdentity {
412 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
413 match self {
414 FieldOrIdentity::Field(field) => write!(f, "{}", field),
415 FieldOrIdentity::Identity => write!(f, "$[]"),
416 }
417 }
418}
419
420impl<T> From<T> for FieldOrIdentity
421where
422 FieldName: From<T>,
423{
424 fn from(value: T) -> Self {
425 FieldOrIdentity::Field(FieldName::from(value))
426 }
427}
428
429#[cfg(test)]
430mod tests {
431 use vortex_array::aliases::hash_map::HashMap;
432 use vortex_array::aliases::hash_set::HashSet;
433 use vortex_array::stats::Stat;
434 use vortex_dtype::FieldName;
435
436 use crate::pruning::{
437 FieldOrIdentity, PruningPredicate, convert_to_pruning_expression, stat_field_name,
438 };
439 use crate::{
440 and, eq, get_item, get_item_scope, gt, gt_eq, ident, lit, lt, lt_eq, not, not_eq, or,
441 };
442
443 #[test]
444 pub fn pruning_equals() {
445 let name = FieldName::from("a");
446 let literal_eq = lit(42);
447 let eq_expr = eq(get_item("a", ident()), literal_eq.clone());
448 let (converted, refs) = convert_to_pruning_expression(&eq_expr);
449 assert_eq!(
450 refs.into_map(),
451 HashMap::from_iter([(
452 FieldOrIdentity::Field(name.clone()),
453 HashSet::from_iter([Stat::Min, Stat::Max])
454 )])
455 );
456 let expected_expr = or(
457 gt(
458 get_item(stat_field_name(&name, Stat::Min), ident()),
459 literal_eq.clone(),
460 ),
461 gt(
462 literal_eq,
463 get_item_scope(stat_field_name(&name, Stat::Max)),
464 ),
465 );
466 assert_eq!(&converted, &expected_expr);
467 }
468
469 #[test]
470 pub fn pruning_equals_column() {
471 let column = FieldName::from("a");
472 let other_col = FieldName::from("b");
473 let eq_expr = eq(
474 get_item_scope(column.clone()),
475 get_item_scope(other_col.clone()),
476 );
477
478 let (converted, refs) = convert_to_pruning_expression(&eq_expr);
479 assert_eq!(
480 refs.into_map(),
481 HashMap::from_iter([
482 (
483 FieldOrIdentity::Field(column.clone()),
484 HashSet::from_iter([Stat::Min, Stat::Max])
485 ),
486 (
487 FieldOrIdentity::Field(other_col.clone()),
488 HashSet::from_iter([Stat::Max, Stat::Min])
489 )
490 ])
491 );
492 let expected_expr = or(
493 gt(
494 get_item_scope(stat_field_name(&column, Stat::Min)),
495 get_item_scope(stat_field_name(&other_col, Stat::Max)),
496 ),
497 gt(
498 get_item_scope(stat_field_name(&other_col, Stat::Min)),
499 get_item_scope(stat_field_name(&column, Stat::Max)),
500 ),
501 );
502 assert_eq!(&converted, &expected_expr);
503 }
504
505 #[test]
506 pub fn pruning_not_equals_column() {
507 let column = FieldName::from("a");
508 let other_col = FieldName::from("b");
509 let not_eq_expr = not_eq(
510 get_item_scope(column.clone()),
511 get_item_scope(other_col.clone()),
512 );
513
514 let (converted, refs) = convert_to_pruning_expression(¬_eq_expr);
515 assert_eq!(
516 refs.into_map(),
517 HashMap::from_iter([
518 (
519 FieldOrIdentity::Field(column.clone()),
520 HashSet::from_iter([Stat::Min, Stat::Max])
521 ),
522 (
523 FieldOrIdentity::Field(other_col.clone()),
524 HashSet::from_iter([Stat::Max, Stat::Min])
525 )
526 ])
527 );
528 let expected_expr = and(
529 eq(
530 get_item_scope(stat_field_name(&column, Stat::Min)),
531 get_item_scope(stat_field_name(&other_col, Stat::Max)),
532 ),
533 eq(
534 get_item_scope(stat_field_name(&column, Stat::Max)),
535 get_item_scope(stat_field_name(&other_col, Stat::Min)),
536 ),
537 );
538
539 assert_eq!(&converted, &expected_expr);
540 }
541
542 #[test]
543 pub fn pruning_gt_column() {
544 let column = FieldName::from("a");
545 let other_col = FieldName::from("b");
546 let other_expr = get_item_scope(other_col.clone());
547 let not_eq_expr = gt(get_item_scope(column.clone()), other_expr.clone());
548
549 let (converted, refs) = convert_to_pruning_expression(¬_eq_expr);
550 assert_eq!(
551 refs.into_map(),
552 HashMap::from_iter([
553 (
554 FieldOrIdentity::Field(column.clone()),
555 HashSet::from_iter([Stat::Max])
556 ),
557 (
558 FieldOrIdentity::Field(other_col.clone()),
559 HashSet::from_iter([Stat::Min])
560 )
561 ])
562 );
563 let expected_expr = lt_eq(
564 get_item_scope(stat_field_name(&column, Stat::Max)),
565 get_item_scope(stat_field_name(&other_col, Stat::Min)),
566 );
567 assert_eq!(&converted, &expected_expr);
568 }
569
570 #[test]
571 pub fn pruning_gt_value() {
572 let column = FieldName::from("a");
573 let other_col = lit(42);
574 let not_eq_expr = gt(get_item_scope(column.clone()), other_col.clone());
575
576 let (converted, refs) = convert_to_pruning_expression(¬_eq_expr);
577 assert_eq!(
578 refs.into_map(),
579 HashMap::from_iter([(
580 FieldOrIdentity::Field(column.clone()),
581 HashSet::from_iter([Stat::Max])
582 ),])
583 );
584 let expected_expr = lt_eq(
585 get_item_scope(stat_field_name(&column, Stat::Max)),
586 other_col.clone(),
587 );
588 assert_eq!(&converted, &expected_expr);
589 }
590
591 #[test]
592 pub fn pruning_lt_column() {
593 let column = FieldName::from("a");
594 let other_col = FieldName::from("b");
595 let other_expr = get_item_scope(other_col.clone());
596 let not_eq_expr = lt(get_item_scope(column.clone()), other_expr.clone());
597
598 let (converted, refs) = convert_to_pruning_expression(¬_eq_expr);
599 assert_eq!(
600 refs.into_map(),
601 HashMap::from_iter([
602 (
603 FieldOrIdentity::Field(column.clone()),
604 HashSet::from_iter([Stat::Min])
605 ),
606 (
607 FieldOrIdentity::Field(other_col.clone()),
608 HashSet::from_iter([Stat::Max])
609 )
610 ])
611 );
612 let expected_expr = gt_eq(
613 get_item_scope(stat_field_name(&column, Stat::Min)),
614 get_item_scope(stat_field_name(&other_col, Stat::Max)),
615 );
616 assert_eq!(&converted, &expected_expr);
617 }
618
619 #[test]
620 pub fn pruning_lt_value() {
621 let column = FieldName::from("a");
622 let other_col = lit(42);
623 let not_eq_expr = lt(get_item_scope(column.clone()), other_col.clone());
624
625 let (converted, refs) = convert_to_pruning_expression(¬_eq_expr);
626 assert_eq!(
627 refs.into_map(),
628 HashMap::from_iter([(
629 FieldOrIdentity::Field(column.clone()),
630 HashSet::from_iter([Stat::Min])
631 )])
632 );
633 let expected_expr = gt_eq(
634 get_item_scope(stat_field_name(&column, Stat::Min)),
635 other_col.clone(),
636 );
637 assert_eq!(&converted, &expected_expr);
638 }
639
640 #[test]
641 fn unprojectable_expr() {
642 let or_expr = not(lt(get_item_scope("a"), get_item_scope("b")));
643 assert!(PruningPredicate::try_new(&or_expr).is_none());
644 }
645
646 #[test]
647 fn display_pruning_predicate() {
648 let column = FieldName::from("a");
649 let other_col = lit(42);
650 let not_eq_expr = lt(get_item_scope(column), other_col);
651
652 assert_eq!(
653 PruningPredicate::try_new(¬_eq_expr).unwrap().to_string(),
654 "PruningPredicate(($.a_min >= 42_i32), {a: {min}})"
655 );
656 }
657
658 #[test]
659 fn or_required_stats_from_both_arms() {
660 let item = get_item_scope(FieldName::from("a"));
661 let expr = or(lt(item.clone(), lit(10)), gt(item, lit(50)));
662
663 let expected = HashMap::from([(
664 FieldOrIdentity::from("a"),
665 HashSet::from([Stat::Min, Stat::Max]),
666 )]);
667
668 assert_eq!(
669 PruningPredicate::try_new(&expr).unwrap().required_stats(),
670 &expected
671 );
672 }
673
674 #[test]
675 fn and_required_stats_from_both_arms() {
676 let item = get_item_scope(FieldName::from("a"));
677 let expr = and(gt(item.clone(), lit(50)), lt(item, lit(10)));
678
679 let expected = HashMap::from([(
680 FieldOrIdentity::from("a"),
681 HashSet::from([Stat::Min, Stat::Max]),
682 )]);
683
684 assert_eq!(
685 PruningPredicate::try_new(&expr).unwrap().required_stats(),
686 &expected
687 );
688 }
689
690 #[test]
691 fn pruning_identity() {
692 let expr = ident();
693 let expr = or(lt(expr.clone(), lit(10)), gt(expr.clone(), lit(50)));
694
695 let expected = HashMap::from([(
696 FieldOrIdentity::Identity,
697 HashSet::from([Stat::Min, Stat::Max]),
698 )]);
699
700 let predicate = PruningPredicate::try_new(&expr).unwrap();
701 assert_eq!(predicate.required_stats(), &expected);
702
703 let expected_expr = or(
704 gt_eq(get_item_scope(FieldName::from("min")), lit(10)),
705 lt_eq(get_item_scope(FieldName::from("max")), lit(50)),
706 );
707 assert_eq!(predicate.expr(), &expected_expr)
708 }
709}