1#![allow(dead_code)]
3
4use std::fmt::Display;
5use std::hash::Hash;
6
7use itertools::Itertools;
8use vortex_array::aliases::hash_map::HashMap;
9use vortex_array::aliases::hash_set::HashSet;
10use vortex_array::stats::Stat;
11use vortex_array::{Array, ArrayRef};
12use vortex_dtype::{FieldName, Nullability};
13use vortex_error::{VortexExpect, VortexResult};
14use vortex_scalar::Scalar;
15
16use crate::between::Between;
17use crate::{
18 BinaryExpr, ExprRef, GetItem, Identity, Literal, Not, Operator, VortexExprExt, and, eq,
19 get_item, gt, ident, lit, not, or,
20};
21
/// A one-to-many relation: every key maps to the set of values recorded for it.
#[derive(Debug, Clone)]
pub struct Relation<K, V> {
    /// Backing storage; each key owns the set of values inserted under it.
    map: HashMap<K, HashSet<V>>,
}
26
27impl<K: Display, V: Display> Display for Relation<K, V> {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 write!(
30 f,
31 "{}",
32 self.map.iter().format_with(",", |(k, v), fmt| {
33 fmt(&format_args!("{k}: {{{}}}", v.iter().format(",")))
34 })
35 )
36 }
37}
38
39impl<K: Hash + Eq, V: Hash + Eq> Default for Relation<K, V> {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl<K: Hash + Eq, V: Hash + Eq> Relation<K, V> {
46 pub fn new() -> Self {
47 Relation {
48 map: HashMap::new(),
49 }
50 }
51
52 pub fn union(mut iter: impl Iterator<Item = Relation<K, V>>) -> Relation<K, V> {
53 if let Some(mut x) = iter.next() {
54 for y in iter {
55 x.extend(y)
56 }
57 x
58 } else {
59 Relation::new()
60 }
61 }
62
63 pub fn extend(&mut self, other: Relation<K, V>) {
64 for (l, rs) in other.map.into_iter() {
65 self.map.entry(l).or_default().extend(rs.into_iter())
66 }
67 }
68
69 pub fn insert(&mut self, k: K, v: V) {
70 self.map.entry(k).or_default().insert(v);
71 }
72
73 pub fn into_map(self) -> HashMap<K, HashSet<V>> {
74 self.map
75 }
76}
77
/// An expression rewritten by `convert_to_pruning_expression`, bundled with the statistics
/// that the rewritten expression reads.
#[derive(Debug, Clone)]
pub struct PruningPredicate {
    /// The rewritten expression; it refers to stat fields rather than the original fields.
    expr: ExprRef,
    /// For each accessed field (or the identity), the stats `expr` references.
    required_stats: Relation<FieldOrIdentity, Stat>,
}
83
84impl Display for PruningPredicate {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 write!(
87 f,
88 "PruningPredicate({}, {{{}}})",
89 self.expr, self.required_stats
90 )
91 }
92}
93
94impl PruningPredicate {
95 pub fn try_new(original_expr: &ExprRef) -> Option<Self> {
96 let (expr, required_stats) = convert_to_pruning_expression(original_expr);
97 if let Some(lexp) = expr.as_any().downcast_ref::<Literal>() {
98 if lexp
100 .value()
101 .as_bool_opt()
102 .and_then(|b| b.value())
103 .map(|b| !b)
104 .unwrap_or(false)
105 {
106 None
107 } else {
108 Some(Self {
109 expr,
110 required_stats,
111 })
112 }
113 } else {
114 Some(Self {
115 expr,
116 required_stats,
117 })
118 }
119 }
120
121 pub fn expr(&self) -> &ExprRef {
122 &self.expr
123 }
124
125 pub fn required_stats(&self) -> &HashMap<FieldOrIdentity, HashSet<Stat>> {
126 &self.required_stats.map
127 }
128
129 pub fn evaluate(&self, metadata: &dyn Array) -> VortexResult<Option<ArrayRef>> {
135 let known_stats = HashSet::from_iter(
136 metadata
137 .dtype()
138 .as_struct()
139 .vortex_expect("metadata must be struct array")
140 .names()
141 .iter()
142 .map(|x| x.to_string()),
143 );
144 let required_stats = self
145 .required_stats()
146 .iter()
147 .flat_map(|(key, value)| value.iter().map(|stat| key.stat_field_name_string(*stat)))
148 .collect::<HashSet<_>>();
149 let missing_stats = required_stats.difference(&known_stats).collect::<Vec<_>>();
150
151 if !missing_stats.is_empty() {
152 return Ok(None);
153 }
154
155 Ok(Some(self.expr.evaluate(metadata)?))
156 }
157}
158
159fn not_prunable() -> PruningPredicateStats {
160 (
161 lit(Scalar::bool(false, Nullability::NonNullable)),
162 Relation::new(),
163 )
164}
165
166fn convert_to_pruning_expression(expr: &ExprRef) -> PruningPredicateStats {
169 if let Some(nexp) = expr.as_any().downcast_ref::<Not>() {
170 if let Some(get_item) = nexp.child().as_any().downcast_ref::<GetItem>() {
171 if get_item.child().as_any().is::<Identity>() {
172 return convert_access_reference(expr, true);
173 }
174 }
175 }
176
177 if let Some(get_item) = expr.as_any().downcast_ref::<GetItem>() {
178 if get_item.child().as_any().is::<Identity>() {
179 return convert_access_reference(expr, false);
180 }
181 }
182
183 if let Some(bexp) = expr.as_any().downcast_ref::<BinaryExpr>() {
184 if bexp.op() == Operator::Or || bexp.op() == Operator::And {
185 let (rewritten_left, mut refs_lhs) = convert_to_pruning_expression(bexp.lhs());
186 let (rewritten_right, refs_rhs) = convert_to_pruning_expression(bexp.rhs());
187 refs_lhs.extend(refs_rhs);
188 let flipped_op = bexp
189 .op()
190 .logical_inverse()
191 .vortex_expect("Can not be any other operator than and / or");
192 return (
193 BinaryExpr::new_expr(rewritten_left, flipped_op, rewritten_right),
194 refs_lhs,
195 );
196 }
197
198 if let Some(get_item) = bexp.lhs().as_any().downcast_ref::<GetItem>() {
199 if get_item.child().as_any().is::<Identity>() {
200 return PruningPredicateRewriter::rewrite_binary_op(
201 FieldOrIdentity::Field(get_item.field().clone()),
202 bexp.op(),
203 bexp.rhs(),
204 );
205 }
206 };
207
208 if let Some(get_item) = bexp.rhs().as_any().downcast_ref::<GetItem>() {
209 if get_item.child().as_any().is::<Identity>() {
210 return PruningPredicateRewriter::rewrite_binary_op(
211 FieldOrIdentity::Field(get_item.field().clone()),
212 bexp.op().swap(),
213 bexp.lhs(),
214 );
215 }
216 }
217
218 if bexp.lhs().as_any().is::<Identity>() {
219 return PruningPredicateRewriter::rewrite_binary_op(
220 FieldOrIdentity::Identity,
221 bexp.op(),
222 bexp.rhs(),
223 );
224 };
225
226 if bexp.rhs().as_any().is::<Identity>() {
227 return PruningPredicateRewriter::rewrite_binary_op(
228 FieldOrIdentity::Identity,
229 bexp.op().swap(),
230 bexp.lhs(),
231 );
232 };
233 }
234
235 if let Some(between_expr) = expr.as_any().downcast_ref::<Between>() {
236 return convert_to_pruning_expression(&between_expr.to_binary_expr());
237 }
238
239 not_prunable()
240}
241
242fn convert_access_reference(expr: &ExprRef, invert: bool) -> PruningPredicateStats {
243 let mut refs = Relation::new();
244 let Some(min_expr) = replace_get_item_with_stat(expr, Stat::Min, &mut refs) else {
245 return not_prunable();
246 };
247 let Some(max_expr) = replace_get_item_with_stat(expr, Stat::Max, &mut refs) else {
248 return not_prunable();
249 };
250
251 let expr = if invert {
252 and(min_expr, max_expr)
253 } else {
254 not(or(min_expr, max_expr))
255 };
256
257 (expr, refs)
258}
259
/// State for rewriting a single binary comparison `access <operator> other_exp` into an
/// expression over statistics.
struct PruningPredicateRewriter<'a> {
    /// The field (or the identity) whose stat fields replace it in the rewritten expression.
    access: FieldOrIdentity,
    /// The comparison operator, oriented so that `access` is the left-hand operand.
    operator: Operator,
    /// The other operand of the comparison.
    other_exp: &'a ExprRef,
    /// Stats referenced so far by the rewritten expression.
    stats_to_fetch: Relation<FieldOrIdentity, Stat>,
}
266
/// A rewritten expression paired with the stats that expression references.
type PruningPredicateStats = (ExprRef, Relation<FieldOrIdentity, Stat>);
268
269impl<'a> PruningPredicateRewriter<'a> {
270 pub fn try_new(
271 access: FieldOrIdentity,
272 operator: Operator,
273 other_exp: &'a ExprRef,
274 ) -> Option<Self> {
275 if let FieldOrIdentity::Field(field) = &access {
278 if other_exp.references().contains(field) {
279 return None;
280 }
281 };
282
283 Some(Self {
284 access,
285 operator,
286 other_exp,
287 stats_to_fetch: Relation::new(),
288 })
289 }
290
291 pub fn rewrite_binary_op(
292 access: FieldOrIdentity,
293 operator: Operator,
294 other_exp: &'a ExprRef,
295 ) -> PruningPredicateStats {
296 Self::try_new(access, operator, other_exp)
297 .and_then(Self::rewrite)
298 .unwrap_or_else(not_prunable)
299 }
300
301 fn add_stat_reference(&mut self, stat: Stat) -> FieldName {
302 let new_field = self.access.stat_field_name(stat);
303 self.stats_to_fetch.insert(self.access.clone(), stat);
304 new_field
305 }
306
307 fn rewrite_other_exp(&mut self, stat: Stat) -> ExprRef {
308 replace_get_item_with_stat(self.other_exp, stat, &mut self.stats_to_fetch)
309 .unwrap_or_else(|| self.other_exp.clone())
310 }
311
312 fn rewrite(mut self) -> Option<PruningPredicateStats> {
313 let expr: Option<ExprRef> = match self.operator {
314 Operator::Eq => {
315 let min_col = get_item(self.add_stat_reference(Stat::Min), ident());
316 let max_col = get_item(self.add_stat_reference(Stat::Max), ident());
317 let replaced_max = self.rewrite_other_exp(Stat::Max);
318 let replaced_min = self.rewrite_other_exp(Stat::Min);
319
320 Some(or(gt(min_col, replaced_max), gt(replaced_min, max_col)))
321 }
322 Operator::NotEq => {
323 let min_col = get_item(self.add_stat_reference(Stat::Min), ident());
324 let max_col = get_item(self.add_stat_reference(Stat::Max), ident());
325 let replaced_max = self.rewrite_other_exp(Stat::Max);
326 let replaced_min = self.rewrite_other_exp(Stat::Min);
327
328 Some(and(eq(min_col, replaced_max), eq(max_col, replaced_min)))
329 }
330 Operator::Gt | Operator::Gte => {
331 let max_col = get_item(self.add_stat_reference(Stat::Max), ident());
332 let replaced_min = self.rewrite_other_exp(Stat::Min);
333
334 Some(BinaryExpr::new_expr(
335 max_col,
336 self.operator
337 .inverse()
338 .vortex_expect("inverse of gt & gt_eq defined"),
339 replaced_min,
340 ))
341 }
342 Operator::Lt | Operator::Lte => {
343 let min_col = get_item(self.add_stat_reference(Stat::Min), ident());
344 let replaced_max = self.rewrite_other_exp(Stat::Max);
345
346 Some(BinaryExpr::new_expr(
347 min_col,
348 self.operator
349 .inverse()
350 .vortex_expect("inverse of lt & lte defined"),
351 replaced_max,
352 ))
353 }
354 _ => None,
355 };
356 expr.map(|e| (e, self.stats_to_fetch))
357 }
358}
359
360fn replace_get_item_with_stat(
361 expr: &ExprRef,
362 stat: Stat,
363 stats_to_fetch: &mut Relation<FieldOrIdentity, Stat>,
364) -> Option<ExprRef> {
365 if let Some(get_i) = expr.as_any().downcast_ref::<GetItem>() {
366 if get_i.child().as_any().is::<Identity>() {
367 let new_field = stat_field_name(get_i.field(), stat);
368 stats_to_fetch.insert(FieldOrIdentity::Field(get_i.field().clone()), stat);
369 return Some(get_item(new_field, ident()));
370 }
371 }
372
373 if let Some(not_expr) = expr.as_any().downcast_ref::<Not>() {
374 let rewritten = replace_get_item_with_stat(not_expr.child(), stat, stats_to_fetch)?;
375 return Some(not(rewritten));
376 }
377
378 if let Some(bexp) = expr.as_any().downcast_ref::<BinaryExpr>() {
379 let rewritten_lhs = replace_get_item_with_stat(bexp.lhs(), stat, stats_to_fetch);
380 let rewritten_rhs = replace_get_item_with_stat(bexp.rhs(), stat, stats_to_fetch);
381 if rewritten_lhs.is_none() && rewritten_rhs.is_none() {
382 return None;
383 }
384
385 let lhs = rewritten_lhs.unwrap_or_else(|| bexp.lhs().clone());
386 let rhs = rewritten_rhs.unwrap_or_else(|| bexp.rhs().clone());
387
388 return Some(BinaryExpr::new_expr(lhs, bexp.op(), rhs));
389 }
390
391 None
392}
393
/// What a statistic is attached to: a named field, or the identity expression itself.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum FieldOrIdentity {
    /// A field accessed by name on the identity.
    Field(FieldName),
    /// The identity expression.
    Identity,
}
399
400pub(crate) fn stat_field_name(field: &FieldName, stat: Stat) -> FieldName {
401 FieldName::from(stat_field_name_string(field, stat))
402}
403
404pub(crate) fn stat_field_name_string(field: &FieldName, stat: Stat) -> String {
405 format!("{field}_{stat}")
406}
407
408impl FieldOrIdentity {
409 pub(crate) fn stat_field_name(&self, stat: Stat) -> FieldName {
410 FieldName::from(self.stat_field_name_string(stat))
411 }
412
413 pub(crate) fn stat_field_name_string(&self, stat: Stat) -> String {
414 match self {
415 FieldOrIdentity::Field(field) => stat_field_name_string(field, stat),
416 FieldOrIdentity::Identity => stat.to_string(),
417 }
418 }
419}
420
421impl Display for FieldOrIdentity {
422 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
423 match self {
424 FieldOrIdentity::Field(field) => write!(f, "{}", field),
425 FieldOrIdentity::Identity => write!(f, "$[]"),
426 }
427 }
428}
429
430impl<T> From<T> for FieldOrIdentity
431where
432 FieldName: From<T>,
433{
434 fn from(value: T) -> Self {
435 FieldOrIdentity::Field(FieldName::from(value))
436 }
437}
438
439#[cfg(test)]
440mod tests {
441 use vortex_array::aliases::hash_map::HashMap;
442 use vortex_array::aliases::hash_set::HashSet;
443 use vortex_array::stats::Stat;
444 use vortex_dtype::FieldName;
445
446 use crate::pruning::{
447 FieldOrIdentity, PruningPredicate, convert_to_pruning_expression, stat_field_name,
448 };
449 use crate::{
450 and, eq, get_item, get_item_scope, gt, gt_eq, ident, lit, lt, lt_eq, not, not_eq, or,
451 };
452
453 #[test]
454 pub fn pruning_equals() {
455 let name = FieldName::from("a");
456 let literal_eq = lit(42);
457 let eq_expr = eq(get_item("a", ident()), literal_eq.clone());
458 let (converted, refs) = convert_to_pruning_expression(&eq_expr);
459 assert_eq!(
460 refs.into_map(),
461 HashMap::from_iter([(
462 FieldOrIdentity::Field(name.clone()),
463 HashSet::from_iter([Stat::Min, Stat::Max])
464 )])
465 );
466 let expected_expr = or(
467 gt(
468 get_item(stat_field_name(&name, Stat::Min), ident()),
469 literal_eq.clone(),
470 ),
471 gt(
472 literal_eq,
473 get_item_scope(stat_field_name(&name, Stat::Max)),
474 ),
475 );
476 assert_eq!(&converted, &expected_expr);
477 }
478
479 #[test]
480 pub fn pruning_equals_column() {
481 let column = FieldName::from("a");
482 let other_col = FieldName::from("b");
483 let eq_expr = eq(
484 get_item_scope(column.clone()),
485 get_item_scope(other_col.clone()),
486 );
487
488 let (converted, refs) = convert_to_pruning_expression(&eq_expr);
489 assert_eq!(
490 refs.into_map(),
491 HashMap::from_iter([
492 (
493 FieldOrIdentity::Field(column.clone()),
494 HashSet::from_iter([Stat::Min, Stat::Max])
495 ),
496 (
497 FieldOrIdentity::Field(other_col.clone()),
498 HashSet::from_iter([Stat::Max, Stat::Min])
499 )
500 ])
501 );
502 let expected_expr = or(
503 gt(
504 get_item_scope(stat_field_name(&column, Stat::Min)),
505 get_item_scope(stat_field_name(&other_col, Stat::Max)),
506 ),
507 gt(
508 get_item_scope(stat_field_name(&other_col, Stat::Min)),
509 get_item_scope(stat_field_name(&column, Stat::Max)),
510 ),
511 );
512 assert_eq!(&converted, &expected_expr);
513 }
514
515 #[test]
516 pub fn pruning_not_equals_column() {
517 let column = FieldName::from("a");
518 let other_col = FieldName::from("b");
519 let not_eq_expr = not_eq(
520 get_item_scope(column.clone()),
521 get_item_scope(other_col.clone()),
522 );
523
524 let (converted, refs) = convert_to_pruning_expression(¬_eq_expr);
525 assert_eq!(
526 refs.into_map(),
527 HashMap::from_iter([
528 (
529 FieldOrIdentity::Field(column.clone()),
530 HashSet::from_iter([Stat::Min, Stat::Max])
531 ),
532 (
533 FieldOrIdentity::Field(other_col.clone()),
534 HashSet::from_iter([Stat::Max, Stat::Min])
535 )
536 ])
537 );
538 let expected_expr = and(
539 eq(
540 get_item_scope(stat_field_name(&column, Stat::Min)),
541 get_item_scope(stat_field_name(&other_col, Stat::Max)),
542 ),
543 eq(
544 get_item_scope(stat_field_name(&column, Stat::Max)),
545 get_item_scope(stat_field_name(&other_col, Stat::Min)),
546 ),
547 );
548
549 assert_eq!(&converted, &expected_expr);
550 }
551
552 #[test]
553 pub fn pruning_gt_column() {
554 let column = FieldName::from("a");
555 let other_col = FieldName::from("b");
556 let other_expr = get_item_scope(other_col.clone());
557 let not_eq_expr = gt(get_item_scope(column.clone()), other_expr.clone());
558
559 let (converted, refs) = convert_to_pruning_expression(¬_eq_expr);
560 assert_eq!(
561 refs.into_map(),
562 HashMap::from_iter([
563 (
564 FieldOrIdentity::Field(column.clone()),
565 HashSet::from_iter([Stat::Max])
566 ),
567 (
568 FieldOrIdentity::Field(other_col.clone()),
569 HashSet::from_iter([Stat::Min])
570 )
571 ])
572 );
573 let expected_expr = lt_eq(
574 get_item_scope(stat_field_name(&column, Stat::Max)),
575 get_item_scope(stat_field_name(&other_col, Stat::Min)),
576 );
577 assert_eq!(&converted, &expected_expr);
578 }
579
580 #[test]
581 pub fn pruning_gt_value() {
582 let column = FieldName::from("a");
583 let other_col = lit(42);
584 let not_eq_expr = gt(get_item_scope(column.clone()), other_col.clone());
585
586 let (converted, refs) = convert_to_pruning_expression(¬_eq_expr);
587 assert_eq!(
588 refs.into_map(),
589 HashMap::from_iter([(
590 FieldOrIdentity::Field(column.clone()),
591 HashSet::from_iter([Stat::Max])
592 ),])
593 );
594 let expected_expr = lt_eq(
595 get_item_scope(stat_field_name(&column, Stat::Max)),
596 other_col.clone(),
597 );
598 assert_eq!(&converted, &expected_expr);
599 }
600
601 #[test]
602 pub fn pruning_lt_column() {
603 let column = FieldName::from("a");
604 let other_col = FieldName::from("b");
605 let other_expr = get_item_scope(other_col.clone());
606 let not_eq_expr = lt(get_item_scope(column.clone()), other_expr.clone());
607
608 let (converted, refs) = convert_to_pruning_expression(¬_eq_expr);
609 assert_eq!(
610 refs.into_map(),
611 HashMap::from_iter([
612 (
613 FieldOrIdentity::Field(column.clone()),
614 HashSet::from_iter([Stat::Min])
615 ),
616 (
617 FieldOrIdentity::Field(other_col.clone()),
618 HashSet::from_iter([Stat::Max])
619 )
620 ])
621 );
622 let expected_expr = gt_eq(
623 get_item_scope(stat_field_name(&column, Stat::Min)),
624 get_item_scope(stat_field_name(&other_col, Stat::Max)),
625 );
626 assert_eq!(&converted, &expected_expr);
627 }
628
629 #[test]
630 pub fn pruning_lt_value() {
631 let column = FieldName::from("a");
632 let other_col = lit(42);
633 let not_eq_expr = lt(get_item_scope(column.clone()), other_col.clone());
634
635 let (converted, refs) = convert_to_pruning_expression(¬_eq_expr);
636 assert_eq!(
637 refs.into_map(),
638 HashMap::from_iter([(
639 FieldOrIdentity::Field(column.clone()),
640 HashSet::from_iter([Stat::Min])
641 )])
642 );
643 let expected_expr = gt_eq(
644 get_item_scope(stat_field_name(&column, Stat::Min)),
645 other_col.clone(),
646 );
647 assert_eq!(&converted, &expected_expr);
648 }
649
650 #[test]
651 fn unprojectable_expr() {
652 let or_expr = not(lt(get_item_scope("a"), get_item_scope("b")));
653 assert!(PruningPredicate::try_new(&or_expr).is_none());
654 }
655
656 #[test]
657 fn display_pruning_predicate() {
658 let column = FieldName::from("a");
659 let other_col = lit(42);
660 let not_eq_expr = lt(get_item_scope(column), other_col);
661
662 assert_eq!(
663 PruningPredicate::try_new(¬_eq_expr).unwrap().to_string(),
664 "PruningPredicate(($.a_min >= 42i32), {a: {min}})"
665 );
666 }
667
668 #[test]
669 fn or_required_stats_from_both_arms() {
670 let item = get_item_scope(FieldName::from("a"));
671 let expr = or(lt(item.clone(), lit(10)), gt(item, lit(50)));
672
673 let expected = HashMap::from([(
674 FieldOrIdentity::from("a"),
675 HashSet::from([Stat::Min, Stat::Max]),
676 )]);
677
678 assert_eq!(
679 PruningPredicate::try_new(&expr).unwrap().required_stats(),
680 &expected
681 );
682 }
683
684 #[test]
685 fn and_required_stats_from_both_arms() {
686 let item = get_item_scope(FieldName::from("a"));
687 let expr = and(gt(item.clone(), lit(50)), lt(item, lit(10)));
688
689 let expected = HashMap::from([(
690 FieldOrIdentity::from("a"),
691 HashSet::from([Stat::Min, Stat::Max]),
692 )]);
693
694 assert_eq!(
695 PruningPredicate::try_new(&expr).unwrap().required_stats(),
696 &expected
697 );
698 }
699
700 #[test]
701 fn pruning_identity() {
702 let expr = ident();
703 let expr = or(lt(expr.clone(), lit(10)), gt(expr.clone(), lit(50)));
704
705 let expected = HashMap::from([(
706 FieldOrIdentity::Identity,
707 HashSet::from([Stat::Min, Stat::Max]),
708 )]);
709
710 let predicate = PruningPredicate::try_new(&expr).unwrap();
711 assert_eq!(predicate.required_stats(), &expected);
712
713 let expected_expr = and(
714 gt_eq(get_item_scope(FieldName::from("min")), lit(10)),
715 lt_eq(get_item_scope(FieldName::from("max")), lit(50)),
716 );
717 assert_eq!(predicate.expr(), &expected_expr)
718 }
719 #[test]
720 pub fn pruning_and_or_operators() {
721 let column = FieldName::from("a");
723 let and_expr = and(
724 gt(get_item_scope(column.clone()), lit(10)),
725 lt(get_item_scope(column), lit(50)),
726 );
727 let pruned = PruningPredicate::try_new(&and_expr).unwrap();
728
729 assert_eq!(
731 pruned.expr(),
732 &or(
733 lt_eq(get_item_scope(FieldName::from("a_max")), lit(10)),
734 gt_eq(get_item_scope(FieldName::from("a_min")), lit(50))
735 ),
736 );
737 }
738}