1#![allow(dead_code)]
3
4use std::fmt::Display;
5use std::hash::Hash;
6
7use itertools::Itertools;
8use vortex_array::aliases::hash_map::HashMap;
9use vortex_array::aliases::hash_set::HashSet;
10use vortex_array::stats::Stat;
11use vortex_array::{Array, ArrayRef};
12use vortex_dtype::{FieldName, Nullability};
13use vortex_error::{VortexExpect, VortexResult};
14use vortex_scalar::Scalar;
15
16use crate::between::Between;
17use crate::{
18 BinaryExpr, ExprRef, GetItem, Identity, Literal, Not, Operator, VortexExprExt, and, eq,
19 get_item, gt, ident, lit, not, or,
20};
21
22#[derive(Debug, Clone)]
23pub struct Relation<K, V> {
24 map: HashMap<K, HashSet<V>>,
25}
26
27impl<K: Display, V: Display> Display for Relation<K, V> {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 write!(
30 f,
31 "{}",
32 self.map.iter().format_with(",", |(k, v), fmt| {
33 fmt(&format_args!("{k}: {{{}}}", v.iter().format(",")))
34 })
35 )
36 }
37}
38
39impl<K: Hash + Eq, V: Hash + Eq> Default for Relation<K, V> {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl<K: Hash + Eq, V: Hash + Eq> Relation<K, V> {
46 pub fn new() -> Self {
47 Relation {
48 map: HashMap::new(),
49 }
50 }
51
52 pub fn union(mut iter: impl Iterator<Item = Relation<K, V>>) -> Relation<K, V> {
53 if let Some(mut x) = iter.next() {
54 for y in iter {
55 x.extend(y)
56 }
57 x
58 } else {
59 Relation::new()
60 }
61 }
62
63 pub fn extend(&mut self, other: Relation<K, V>) {
64 for (l, rs) in other.map.into_iter() {
65 self.map.entry(l).or_default().extend(rs.into_iter())
66 }
67 }
68
69 pub fn insert(&mut self, k: K, v: V) {
70 self.map.entry(k).or_default().insert(v);
71 }
72
73 pub fn into_map(self) -> HashMap<K, HashSet<V>> {
74 self.map
75 }
76}
77
78#[derive(Debug, Clone)]
79pub struct PruningPredicate {
80 expr: ExprRef,
81 required_stats: Relation<FieldOrIdentity, Stat>,
82}
83
84impl Display for PruningPredicate {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 write!(
87 f,
88 "PruningPredicate({}, {{{}}})",
89 self.expr, self.required_stats
90 )
91 }
92}
93
94impl PruningPredicate {
95 pub fn try_new(original_expr: &ExprRef) -> Option<Self> {
96 let (expr, required_stats) = convert_to_pruning_expression(original_expr);
97 if let Some(lexp) = expr.as_any().downcast_ref::<Literal>() {
98 if lexp
100 .value()
101 .as_bool_opt()
102 .and_then(|b| b.value())
103 .map(|b| !b)
104 .unwrap_or(false)
105 {
106 None
107 } else {
108 Some(Self {
109 expr,
110 required_stats,
111 })
112 }
113 } else {
114 Some(Self {
115 expr,
116 required_stats,
117 })
118 }
119 }
120
121 pub fn expr(&self) -> &ExprRef {
122 &self.expr
123 }
124
125 pub fn required_stats(&self) -> &HashMap<FieldOrIdentity, HashSet<Stat>> {
126 &self.required_stats.map
127 }
128
129 pub fn evaluate(&self, metadata: &dyn Array) -> VortexResult<Option<ArrayRef>> {
135 let known_stats = HashSet::from_iter(
136 metadata
137 .as_struct_typed()
138 .vortex_expect("metadata must be struct array")
139 .names()
140 .iter()
141 .map(|x| x.to_string()),
142 );
143 let required_stats = self
144 .required_stats()
145 .iter()
146 .flat_map(|(key, value)| value.iter().map(|stat| key.stat_field_name_string(*stat)))
147 .collect::<HashSet<_>>();
148 let missing_stats = required_stats.difference(&known_stats).collect::<Vec<_>>();
149
150 if !missing_stats.is_empty() {
151 return Ok(None);
152 }
153
154 Ok(Some(self.expr.evaluate(metadata)?))
155 }
156}
157
158fn not_prunable() -> PruningPredicateStats {
159 (
160 lit(Scalar::bool(false, Nullability::NonNullable)),
161 Relation::new(),
162 )
163}
164
165fn convert_to_pruning_expression(expr: &ExprRef) -> PruningPredicateStats {
168 if let Some(nexp) = expr.as_any().downcast_ref::<Not>() {
169 if let Some(get_item) = nexp.child().as_any().downcast_ref::<GetItem>() {
170 if get_item.child().as_any().is::<Identity>() {
171 return convert_access_reference(expr, true);
172 }
173 }
174 }
175
176 if let Some(get_item) = expr.as_any().downcast_ref::<GetItem>() {
177 if get_item.child().as_any().is::<Identity>() {
178 return convert_access_reference(expr, false);
179 }
180 }
181
182 if let Some(bexp) = expr.as_any().downcast_ref::<BinaryExpr>() {
183 if bexp.op() == Operator::Or || bexp.op() == Operator::And {
184 let (rewritten_left, mut refs_lhs) = convert_to_pruning_expression(bexp.lhs());
185 let (rewritten_right, refs_rhs) = convert_to_pruning_expression(bexp.rhs());
186 refs_lhs.extend(refs_rhs);
187 let flipped_op = bexp
188 .op()
189 .logical_inverse()
190 .vortex_expect("Can not be any other operator than and / or");
191 return (
192 BinaryExpr::new_expr(rewritten_left, flipped_op, rewritten_right),
193 refs_lhs,
194 );
195 }
196
197 if let Some(get_item) = bexp.lhs().as_any().downcast_ref::<GetItem>() {
198 if get_item.child().as_any().is::<Identity>() {
199 return PruningPredicateRewriter::rewrite_binary_op(
200 FieldOrIdentity::Field(get_item.field().clone()),
201 bexp.op(),
202 bexp.rhs(),
203 );
204 }
205 };
206
207 if let Some(get_item) = bexp.rhs().as_any().downcast_ref::<GetItem>() {
208 if get_item.child().as_any().is::<Identity>() {
209 return PruningPredicateRewriter::rewrite_binary_op(
210 FieldOrIdentity::Field(get_item.field().clone()),
211 bexp.op().swap(),
212 bexp.lhs(),
213 );
214 }
215 }
216
217 if bexp.lhs().as_any().is::<Identity>() {
218 return PruningPredicateRewriter::rewrite_binary_op(
219 FieldOrIdentity::Identity,
220 bexp.op(),
221 bexp.rhs(),
222 );
223 };
224
225 if bexp.rhs().as_any().is::<Identity>() {
226 return PruningPredicateRewriter::rewrite_binary_op(
227 FieldOrIdentity::Identity,
228 bexp.op().swap(),
229 bexp.lhs(),
230 );
231 };
232 }
233
234 if let Some(between_expr) = expr.as_any().downcast_ref::<Between>() {
235 return convert_to_pruning_expression(&between_expr.to_binary_expr());
236 }
237
238 not_prunable()
239}
240
241fn convert_access_reference(expr: &ExprRef, invert: bool) -> PruningPredicateStats {
242 let mut refs = Relation::new();
243 let Some(min_expr) = replace_get_item_with_stat(expr, Stat::Min, &mut refs) else {
244 return not_prunable();
245 };
246 let Some(max_expr) = replace_get_item_with_stat(expr, Stat::Max, &mut refs) else {
247 return not_prunable();
248 };
249
250 let expr = if invert {
251 and(min_expr, max_expr)
252 } else {
253 not(or(min_expr, max_expr))
254 };
255
256 (expr, refs)
257}
258
259struct PruningPredicateRewriter<'a> {
260 access: FieldOrIdentity,
261 operator: Operator,
262 other_exp: &'a ExprRef,
263 stats_to_fetch: Relation<FieldOrIdentity, Stat>,
264}
265
266type PruningPredicateStats = (ExprRef, Relation<FieldOrIdentity, Stat>);
267
268impl<'a> PruningPredicateRewriter<'a> {
269 pub fn try_new(
270 access: FieldOrIdentity,
271 operator: Operator,
272 other_exp: &'a ExprRef,
273 ) -> Option<Self> {
274 if let FieldOrIdentity::Field(field) = &access {
277 if other_exp.references().contains(field) {
278 return None;
279 }
280 };
281
282 Some(Self {
283 access,
284 operator,
285 other_exp,
286 stats_to_fetch: Relation::new(),
287 })
288 }
289
290 pub fn rewrite_binary_op(
291 access: FieldOrIdentity,
292 operator: Operator,
293 other_exp: &'a ExprRef,
294 ) -> PruningPredicateStats {
295 Self::try_new(access, operator, other_exp)
296 .and_then(Self::rewrite)
297 .unwrap_or_else(not_prunable)
298 }
299
300 fn add_stat_reference(&mut self, stat: Stat) -> FieldName {
301 let new_field = self.access.stat_field_name(stat);
302 self.stats_to_fetch.insert(self.access.clone(), stat);
303 new_field
304 }
305
306 fn rewrite_other_exp(&mut self, stat: Stat) -> ExprRef {
307 replace_get_item_with_stat(self.other_exp, stat, &mut self.stats_to_fetch)
308 .unwrap_or_else(|| self.other_exp.clone())
309 }
310
311 fn rewrite(mut self) -> Option<PruningPredicateStats> {
312 let expr: Option<ExprRef> = match self.operator {
313 Operator::Eq => {
314 let min_col = get_item(self.add_stat_reference(Stat::Min), ident());
315 let max_col = get_item(self.add_stat_reference(Stat::Max), ident());
316 let replaced_max = self.rewrite_other_exp(Stat::Max);
317 let replaced_min = self.rewrite_other_exp(Stat::Min);
318
319 Some(or(gt(min_col, replaced_max), gt(replaced_min, max_col)))
320 }
321 Operator::NotEq => {
322 let min_col = get_item(self.add_stat_reference(Stat::Min), ident());
323 let max_col = get_item(self.add_stat_reference(Stat::Max), ident());
324 let replaced_max = self.rewrite_other_exp(Stat::Max);
325 let replaced_min = self.rewrite_other_exp(Stat::Min);
326
327 Some(and(eq(min_col, replaced_max), eq(max_col, replaced_min)))
328 }
329 Operator::Gt | Operator::Gte => {
330 let max_col = get_item(self.add_stat_reference(Stat::Max), ident());
331 let replaced_min = self.rewrite_other_exp(Stat::Min);
332
333 Some(BinaryExpr::new_expr(
334 max_col,
335 self.operator
336 .inverse()
337 .vortex_expect("inverse of gt & gt_eq defined"),
338 replaced_min,
339 ))
340 }
341 Operator::Lt | Operator::Lte => {
342 let min_col = get_item(self.add_stat_reference(Stat::Min), ident());
343 let replaced_max = self.rewrite_other_exp(Stat::Max);
344
345 Some(BinaryExpr::new_expr(
346 min_col,
347 self.operator
348 .inverse()
349 .vortex_expect("inverse of lt & lte defined"),
350 replaced_max,
351 ))
352 }
353 _ => None,
354 };
355 expr.map(|e| (e, self.stats_to_fetch))
356 }
357}
358
359fn replace_get_item_with_stat(
360 expr: &ExprRef,
361 stat: Stat,
362 stats_to_fetch: &mut Relation<FieldOrIdentity, Stat>,
363) -> Option<ExprRef> {
364 if let Some(get_i) = expr.as_any().downcast_ref::<GetItem>() {
365 if get_i.child().as_any().is::<Identity>() {
366 let new_field = stat_field_name(get_i.field(), stat);
367 stats_to_fetch.insert(FieldOrIdentity::Field(get_i.field().clone()), stat);
368 return Some(get_item(new_field, ident()));
369 }
370 }
371
372 if let Some(not_expr) = expr.as_any().downcast_ref::<Not>() {
373 let rewritten = replace_get_item_with_stat(not_expr.child(), stat, stats_to_fetch)?;
374 return Some(not(rewritten));
375 }
376
377 if let Some(bexp) = expr.as_any().downcast_ref::<BinaryExpr>() {
378 let rewritten_lhs = replace_get_item_with_stat(bexp.lhs(), stat, stats_to_fetch);
379 let rewritten_rhs = replace_get_item_with_stat(bexp.rhs(), stat, stats_to_fetch);
380 if rewritten_lhs.is_none() && rewritten_rhs.is_none() {
381 return None;
382 }
383
384 let lhs = rewritten_lhs.unwrap_or_else(|| bexp.lhs().clone());
385 let rhs = rewritten_rhs.unwrap_or_else(|| bexp.rhs().clone());
386
387 return Some(BinaryExpr::new_expr(lhs, bexp.op(), rhs));
388 }
389
390 None
391}
392
393#[derive(Debug, Clone, Hash, PartialEq, Eq)]
394pub enum FieldOrIdentity {
395 Field(FieldName),
396 Identity,
397}
398
399pub(crate) fn stat_field_name(field: &FieldName, stat: Stat) -> FieldName {
400 FieldName::from(stat_field_name_string(field, stat))
401}
402
403pub(crate) fn stat_field_name_string(field: &FieldName, stat: Stat) -> String {
404 format!("{field}_{stat}")
405}
406
407impl FieldOrIdentity {
408 pub(crate) fn stat_field_name(&self, stat: Stat) -> FieldName {
409 FieldName::from(self.stat_field_name_string(stat))
410 }
411
412 pub(crate) fn stat_field_name_string(&self, stat: Stat) -> String {
413 match self {
414 FieldOrIdentity::Field(field) => stat_field_name_string(field, stat),
415 FieldOrIdentity::Identity => stat.to_string(),
416 }
417 }
418}
419
420impl Display for FieldOrIdentity {
421 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
422 match self {
423 FieldOrIdentity::Field(field) => write!(f, "{}", field),
424 FieldOrIdentity::Identity => write!(f, "$[]"),
425 }
426 }
427}
428
429impl<T> From<T> for FieldOrIdentity
430where
431 FieldName: From<T>,
432{
433 fn from(value: T) -> Self {
434 FieldOrIdentity::Field(FieldName::from(value))
435 }
436}
437
#[cfg(test)]
mod tests {
    use vortex_array::aliases::hash_map::HashMap;
    use vortex_array::aliases::hash_set::HashSet;
    use vortex_array::stats::Stat;
    use vortex_dtype::FieldName;

    use crate::pruning::{
        FieldOrIdentity, PruningPredicate, convert_to_pruning_expression, stat_field_name,
    };
    use crate::{
        and, eq, get_item, get_item_scope, gt, gt_eq, ident, lit, lt, lt_eq, not, not_eq, or,
    };

    /// `a == 42` needs min and max of `a` and becomes `a_min > 42 OR 42 > a_max`.
    #[test]
    pub fn pruning_equals() {
        let name = FieldName::from("a");
        let literal_eq = lit(42);
        let eq_expr = eq(get_item("a", ident()), literal_eq.clone());
        let (converted, refs) = convert_to_pruning_expression(&eq_expr);
        assert_eq!(
            refs.into_map(),
            HashMap::from_iter([(
                FieldOrIdentity::Field(name.clone()),
                HashSet::from_iter([Stat::Min, Stat::Max])
            )])
        );
        let expected_expr = or(
            gt(
                get_item(stat_field_name(&name, Stat::Min), ident()),
                literal_eq.clone(),
            ),
            gt(
                literal_eq,
                get_item_scope(stat_field_name(&name, Stat::Max)),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// `a == b` needs min and max of both columns.
    #[test]
    pub fn pruning_equals_column() {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let eq_expr = eq(
            get_item_scope(column.clone()),
            get_item_scope(other_col.clone()),
        );

        let (converted, refs) = convert_to_pruning_expression(&eq_expr);
        assert_eq!(
            refs.into_map(),
            HashMap::from_iter([
                (
                    FieldOrIdentity::Field(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldOrIdentity::Field(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = or(
            gt(
                get_item_scope(stat_field_name(&column, Stat::Min)),
                get_item_scope(stat_field_name(&other_col, Stat::Max)),
            ),
            gt(
                get_item_scope(stat_field_name(&other_col, Stat::Min)),
                get_item_scope(stat_field_name(&column, Stat::Max)),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// `a != b` becomes `a_min == b_max AND a_max == b_min`.
    #[test]
    pub fn pruning_not_equals_column() {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let not_eq_expr = not_eq(
            get_item_scope(column.clone()),
            get_item_scope(other_col.clone()),
        );

        let (converted, refs) = convert_to_pruning_expression(&not_eq_expr);
        assert_eq!(
            refs.into_map(),
            HashMap::from_iter([
                (
                    FieldOrIdentity::Field(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldOrIdentity::Field(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = and(
            eq(
                get_item_scope(stat_field_name(&column, Stat::Min)),
                get_item_scope(stat_field_name(&other_col, Stat::Max)),
            ),
            eq(
                get_item_scope(stat_field_name(&column, Stat::Max)),
                get_item_scope(stat_field_name(&other_col, Stat::Min)),
            ),
        );

        assert_eq!(&converted, &expected_expr);
    }

    /// `a > b` becomes `a_max <= b_min`.
    #[test]
    pub fn pruning_gt_column() {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let other_expr = get_item_scope(other_col.clone());
        let gt_expr = gt(get_item_scope(column.clone()), other_expr.clone());

        let (converted, refs) = convert_to_pruning_expression(&gt_expr);
        assert_eq!(
            refs.into_map(),
            HashMap::from_iter([
                (
                    FieldOrIdentity::Field(column.clone()),
                    HashSet::from_iter([Stat::Max])
                ),
                (
                    FieldOrIdentity::Field(other_col.clone()),
                    HashSet::from_iter([Stat::Min])
                )
            ])
        );
        let expected_expr = lt_eq(
            get_item_scope(stat_field_name(&column, Stat::Max)),
            get_item_scope(stat_field_name(&other_col, Stat::Min)),
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// `a > 42` becomes `a_max <= 42`.
    #[test]
    pub fn pruning_gt_value() {
        let column = FieldName::from("a");
        let other_col = lit(42);
        let gt_expr = gt(get_item_scope(column.clone()), other_col.clone());

        let (converted, refs) = convert_to_pruning_expression(&gt_expr);
        assert_eq!(
            refs.into_map(),
            HashMap::from_iter([(
                FieldOrIdentity::Field(column.clone()),
                HashSet::from_iter([Stat::Max])
            ),])
        );
        let expected_expr = lt_eq(
            get_item_scope(stat_field_name(&column, Stat::Max)),
            other_col.clone(),
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// `a < b` becomes `a_min >= b_max`.
    #[test]
    pub fn pruning_lt_column() {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let other_expr = get_item_scope(other_col.clone());
        let lt_expr = lt(get_item_scope(column.clone()), other_expr.clone());

        let (converted, refs) = convert_to_pruning_expression(&lt_expr);
        assert_eq!(
            refs.into_map(),
            HashMap::from_iter([
                (
                    FieldOrIdentity::Field(column.clone()),
                    HashSet::from_iter([Stat::Min])
                ),
                (
                    FieldOrIdentity::Field(other_col.clone()),
                    HashSet::from_iter([Stat::Max])
                )
            ])
        );
        let expected_expr = gt_eq(
            get_item_scope(stat_field_name(&column, Stat::Min)),
            get_item_scope(stat_field_name(&other_col, Stat::Max)),
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// `a < 42` becomes `a_min >= 42`.
    #[test]
    pub fn pruning_lt_value() {
        let column = FieldName::from("a");
        let other_col = lit(42);
        let lt_expr = lt(get_item_scope(column.clone()), other_col.clone());

        let (converted, refs) = convert_to_pruning_expression(&lt_expr);
        assert_eq!(
            refs.into_map(),
            HashMap::from_iter([(
                FieldOrIdentity::Field(column.clone()),
                HashSet::from_iter([Stat::Min])
            )])
        );
        let expected_expr = gt_eq(
            get_item_scope(stat_field_name(&column, Stat::Min)),
            other_col.clone(),
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// A negated comparison is not rewritten, so no predicate is produced.
    #[test]
    fn unprojectable_expr() {
        let negated_expr = not(lt(get_item_scope("a"), get_item_scope("b")));
        assert!(PruningPredicate::try_new(&negated_expr).is_none());
    }

    #[test]
    fn display_pruning_predicate() {
        let column = FieldName::from("a");
        let other_col = lit(42);
        let lt_expr = lt(get_item_scope(column), other_col);

        assert_eq!(
            PruningPredicate::try_new(&lt_expr).unwrap().to_string(),
            "PruningPredicate(($.a_min >= 42_i32), {a: {min}})"
        );
    }

    #[test]
    fn or_required_stats_from_both_arms() {
        let item = get_item_scope(FieldName::from("a"));
        let expr = or(lt(item.clone(), lit(10)), gt(item, lit(50)));

        let expected = HashMap::from([(
            FieldOrIdentity::from("a"),
            HashSet::from([Stat::Min, Stat::Max]),
        )]);

        assert_eq!(
            PruningPredicate::try_new(&expr).unwrap().required_stats(),
            &expected
        );
    }

    #[test]
    fn and_required_stats_from_both_arms() {
        let item = get_item_scope(FieldName::from("a"));
        let expr = and(gt(item.clone(), lit(50)), lt(item, lit(10)));

        let expected = HashMap::from([(
            FieldOrIdentity::from("a"),
            HashSet::from([Stat::Min, Stat::Max]),
        )]);

        assert_eq!(
            PruningPredicate::try_new(&expr).unwrap().required_stats(),
            &expected
        );
    }

    /// Comparisons on the identity read the bare `min` / `max` columns.
    #[test]
    fn pruning_identity() {
        let expr = ident();
        let expr = or(lt(expr.clone(), lit(10)), gt(expr.clone(), lit(50)));

        let expected = HashMap::from([(
            FieldOrIdentity::Identity,
            HashSet::from([Stat::Min, Stat::Max]),
        )]);

        let predicate = PruningPredicate::try_new(&expr).unwrap();
        assert_eq!(predicate.required_stats(), &expected);

        let expected_expr = and(
            gt_eq(get_item_scope(FieldName::from("min")), lit(10)),
            lt_eq(get_item_scope(FieldName::from("max")), lit(50)),
        );
        assert_eq!(predicate.expr(), &expected_expr);
    }

    /// `a > 10 AND a < 50` becomes `a_max <= 10 OR a_min >= 50`.
    #[test]
    pub fn pruning_and_or_operators() {
        let column = FieldName::from("a");
        let and_expr = and(
            gt(get_item_scope(column.clone()), lit(10)),
            lt(get_item_scope(column), lit(50)),
        );
        let pruned = PruningPredicate::try_new(&and_expr).unwrap();

        assert_eq!(
            pruned.expr(),
            &or(
                lt_eq(get_item_scope(FieldName::from("a_max")), lit(10)),
                gt_eq(get_item_scope(FieldName::from("a_min")), lit(50))
            ),
        );
    }
}