vortex_expr/
pruning.rs

// This code isn't used outside of tests yet; remove this `allow` once it is.
2#![allow(dead_code)]
3
4use std::fmt::Display;
5use std::hash::Hash;
6
7use itertools::Itertools;
8use vortex_array::aliases::hash_map::HashMap;
9use vortex_array::aliases::hash_set::HashSet;
10use vortex_array::stats::Stat;
11use vortex_array::{Array, ArrayRef};
12use vortex_dtype::{FieldName, Nullability};
13use vortex_error::{VortexExpect as _, VortexResult};
14use vortex_scalar::Scalar;
15
16use crate::{
17    BinaryExpr, ExprRef, GetItem, Identity, Literal, Not, Operator, VortexExprExt, and, eq,
18    get_item, gt, ident, lit, not, or,
19};
20
/// A one-to-many relation: every key is associated with a set of values.
///
/// Backed by a map from `K` to a set of `V`, so a given `(key, value)` pair is stored at most
/// once.
#[derive(Debug, Clone)]
pub struct Relation<K, V> {
    // Values recorded for each key.
    map: HashMap<K, HashSet<V>>,
}
25
26impl<K: Display, V: Display> Display for Relation<K, V> {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        write!(
29            f,
30            "{}",
31            self.map.iter().format_with(",", |(k, v), fmt| {
32                fmt(&format_args!("{k}: {{{}}}", v.iter().format(",")))
33            })
34        )
35    }
36}
37
38impl<K: Hash + Eq, V: Hash + Eq> Default for Relation<K, V> {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44impl<K: Hash + Eq, V: Hash + Eq> Relation<K, V> {
45    pub fn new() -> Self {
46        Relation {
47            map: HashMap::new(),
48        }
49    }
50
51    pub fn union(mut iter: impl Iterator<Item = Relation<K, V>>) -> Relation<K, V> {
52        if let Some(mut x) = iter.next() {
53            for y in iter {
54                x.extend(y)
55            }
56            x
57        } else {
58            Relation::new()
59        }
60    }
61
62    pub fn extend(&mut self, other: Relation<K, V>) {
63        for (l, rs) in other.map.into_iter() {
64            self.map.entry(l).or_default().extend(rs.into_iter())
65        }
66    }
67
68    pub fn insert(&mut self, k: K, v: V) {
69        self.map.entry(k).or_default().insert(v);
70    }
71
72    pub fn into_map(self) -> HashMap<K, HashSet<V>> {
73        self.map
74    }
75}
76
/// A predicate rewritten to run against per-chunk statistics instead of row data.
///
/// Evaluating `expr` yields `true` for chunks that can be pruned (see
/// [`PruningPredicate::evaluate`]).
#[derive(Debug, Clone)]
pub struct PruningPredicate {
    // The rewritten expression, referencing statistics fields such as `<field>_<stat>`.
    expr: ExprRef,
    // For every accessed field (or the identity), the statistics `expr` reads.
    required_stats: Relation<FieldOrIdentity, Stat>,
}
82
83impl Display for PruningPredicate {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        write!(
86            f,
87            "PruningPredicate({}, {{{}}})",
88            self.expr, self.required_stats
89        )
90    }
91}
92
93impl PruningPredicate {
94    pub fn try_new(original_expr: &ExprRef) -> Option<Self> {
95        let (expr, required_stats) = convert_to_pruning_expression(original_expr);
96        if let Some(lexp) = expr.as_any().downcast_ref::<Literal>() {
97            // Is the expression constant false, i.e. prune nothing
98            if lexp
99                .value()
100                .as_bool_opt()
101                .and_then(|b| b.value())
102                .map(|b| !b)
103                .unwrap_or(false)
104            {
105                None
106            } else {
107                Some(Self {
108                    expr,
109                    required_stats,
110                })
111            }
112        } else {
113            Some(Self {
114                expr,
115                required_stats,
116            })
117        }
118    }
119
120    pub fn expr(&self) -> &ExprRef {
121        &self.expr
122    }
123
124    pub fn required_stats(&self) -> &HashMap<FieldOrIdentity, HashSet<Stat>> {
125        &self.required_stats.map
126    }
127
128    /// Evaluate this predicate against a per-chunk statistics table.
129    ///
130    /// Returns Ok(None) if any of the required statistics are not present in metadata.
131    /// If it returns Ok(Some(array)), the array is a boolean array with the same length as the
132    /// metadata, and a true value means the chunk _can_ be pruned.
133    pub fn evaluate(&self, metadata: &dyn Array) -> VortexResult<Option<ArrayRef>> {
134        let known_stats = HashSet::from_iter(
135            metadata
136                .as_struct_typed()
137                .vortex_expect("metadata must be struct array")
138                .names()
139                .iter()
140                .map(|x| x.to_string()),
141        );
142        let required_stats = self
143            .required_stats()
144            .iter()
145            .flat_map(|(key, value)| value.iter().map(|stat| key.stat_field_name_string(*stat)))
146            .collect::<HashSet<_>>();
147        let missing_stats = required_stats.difference(&known_stats).collect::<Vec<_>>();
148
149        if !missing_stats.is_empty() {
150            return Ok(None);
151        }
152
153        Ok(Some(self.expr.evaluate(metadata)?))
154    }
155}
156
157fn not_prunable() -> PruningPredicateStats {
158    (
159        lit(Scalar::bool(false, Nullability::NonNullable)),
160        Relation::new(),
161    )
162}
163
164// Anything that can't be translated has to be represented as
165// boolean true expression, i.e. the value might be in that chunk
166fn convert_to_pruning_expression(expr: &ExprRef) -> PruningPredicateStats {
167    if let Some(nexp) = expr.as_any().downcast_ref::<Not>() {
168        if let Some(get_item) = nexp.child().as_any().downcast_ref::<GetItem>() {
169            if get_item.child().as_any().is::<Identity>() {
170                return convert_access_reference(expr, true);
171            }
172        }
173    }
174
175    if let Some(get_item) = expr.as_any().downcast_ref::<GetItem>() {
176        if get_item.child().as_any().is::<Identity>() {
177            return convert_access_reference(expr, false);
178        }
179    }
180
181    if let Some(bexp) = expr.as_any().downcast_ref::<BinaryExpr>() {
182        if bexp.op() == Operator::Or || bexp.op() == Operator::And {
183            let (rewritten_left, mut refs_lhs) = convert_to_pruning_expression(bexp.lhs());
184            let (rewritten_right, refs_rhs) = convert_to_pruning_expression(bexp.rhs());
185            refs_lhs.extend(refs_rhs);
186            return (
187                BinaryExpr::new_expr(rewritten_left, bexp.op(), rewritten_right),
188                refs_lhs,
189            );
190        }
191
192        if let Some(get_item) = bexp.lhs().as_any().downcast_ref::<GetItem>() {
193            if get_item.child().as_any().is::<Identity>() {
194                return PruningPredicateRewriter::rewrite_binary_op(
195                    FieldOrIdentity::Field(get_item.field().clone()),
196                    bexp.op(),
197                    bexp.rhs(),
198                );
199            }
200        };
201
202        if let Some(get_item) = bexp.rhs().as_any().downcast_ref::<GetItem>() {
203            if get_item.child().as_any().is::<Identity>() {
204                return PruningPredicateRewriter::rewrite_binary_op(
205                    FieldOrIdentity::Field(get_item.field().clone()),
206                    bexp.op().swap(),
207                    bexp.lhs(),
208                );
209            }
210        }
211
212        if bexp.lhs().as_any().is::<Identity>() {
213            return PruningPredicateRewriter::rewrite_binary_op(
214                FieldOrIdentity::Identity,
215                bexp.op(),
216                bexp.rhs(),
217            );
218        };
219
220        if bexp.rhs().as_any().is::<Identity>() {
221            return PruningPredicateRewriter::rewrite_binary_op(
222                FieldOrIdentity::Identity,
223                bexp.op().swap(),
224                bexp.lhs(),
225            );
226        };
227    }
228
229    not_prunable()
230}
231
232fn convert_access_reference(expr: &ExprRef, invert: bool) -> PruningPredicateStats {
233    let mut refs = Relation::new();
234    let Some(min_expr) = replace_get_item_with_stat(expr, Stat::Min, &mut refs) else {
235        return not_prunable();
236    };
237    let Some(max_expr) = replace_get_item_with_stat(expr, Stat::Max, &mut refs) else {
238        return not_prunable();
239    };
240
241    let expr = if invert {
242        and(min_expr, max_expr)
243    } else {
244        not(or(min_expr, max_expr))
245    };
246
247    (expr, refs)
248}
249
/// State for rewriting one comparison `access <operator> other_exp` into an expression over
/// statistics.
struct PruningPredicateRewriter<'a> {
    // The side of the comparison being replaced by its statistics.
    access: FieldOrIdentity,
    // The comparison operator, oriented so that `access` is the left-hand side.
    operator: Operator,
    // The other side of the comparison.
    other_exp: &'a ExprRef,
    // Statistics referenced so far by the rewritten expression.
    stats_to_fetch: Relation<FieldOrIdentity, Stat>,
}
256
/// A rewritten pruning expression together with the statistics it reads.
type PruningPredicateStats = (ExprRef, Relation<FieldOrIdentity, Stat>);
258
259impl<'a> PruningPredicateRewriter<'a> {
260    pub fn try_new(
261        access: FieldOrIdentity,
262        operator: Operator,
263        other_exp: &'a ExprRef,
264    ) -> Option<Self> {
265        // TODO(robert): Simplify expression to guarantee that each column is not compared to itself
266        //  For majority of cases self column references are likely not prunable
267        if let FieldOrIdentity::Field(field) = &access {
268            if other_exp.references().contains(field) {
269                return None;
270            }
271        };
272
273        Some(Self {
274            access,
275            operator,
276            other_exp,
277            stats_to_fetch: Relation::new(),
278        })
279    }
280
281    pub fn rewrite_binary_op(
282        access: FieldOrIdentity,
283        operator: Operator,
284        other_exp: &'a ExprRef,
285    ) -> PruningPredicateStats {
286        Self::try_new(access, operator, other_exp)
287            .and_then(Self::rewrite)
288            .unwrap_or_else(not_prunable)
289    }
290
291    fn add_stat_reference(&mut self, stat: Stat) -> FieldName {
292        let new_field = self.access.stat_field_name(stat);
293        self.stats_to_fetch.insert(self.access.clone(), stat);
294        new_field
295    }
296
297    fn rewrite_other_exp(&mut self, stat: Stat) -> ExprRef {
298        replace_get_item_with_stat(self.other_exp, stat, &mut self.stats_to_fetch)
299            .unwrap_or_else(|| self.other_exp.clone())
300    }
301
302    fn rewrite(mut self) -> Option<PruningPredicateStats> {
303        let expr: Option<ExprRef> = match self.operator {
304            Operator::Eq => {
305                let min_col = get_item(self.add_stat_reference(Stat::Min), ident());
306                let max_col = get_item(self.add_stat_reference(Stat::Max), ident());
307                let replaced_max = self.rewrite_other_exp(Stat::Max);
308                let replaced_min = self.rewrite_other_exp(Stat::Min);
309
310                Some(or(gt(min_col, replaced_max), gt(replaced_min, max_col)))
311            }
312            Operator::NotEq => {
313                let min_col = get_item(self.add_stat_reference(Stat::Min), ident());
314                let max_col = get_item(self.add_stat_reference(Stat::Max), ident());
315                let replaced_max = self.rewrite_other_exp(Stat::Max);
316                let replaced_min = self.rewrite_other_exp(Stat::Min);
317
318                Some(and(eq(min_col, replaced_max), eq(max_col, replaced_min)))
319            }
320            Operator::Gt | Operator::Gte => {
321                let max_col = get_item(self.add_stat_reference(Stat::Max), ident());
322                let replaced_min = self.rewrite_other_exp(Stat::Min);
323
324                Some(BinaryExpr::new_expr(
325                    max_col,
326                    self.operator
327                        .inverse()
328                        .vortex_expect("inverse of gt & gt_eq defined"),
329                    replaced_min,
330                ))
331            }
332            Operator::Lt | Operator::Lte => {
333                let min_col = get_item(self.add_stat_reference(Stat::Min), ident());
334                let replaced_max = self.rewrite_other_exp(Stat::Max);
335
336                Some(BinaryExpr::new_expr(
337                    min_col,
338                    self.operator
339                        .inverse()
340                        .vortex_expect("inverse of lt & lte defined"),
341                    replaced_max,
342                ))
343            }
344            _ => None,
345        };
346        expr.map(|e| (e, self.stats_to_fetch))
347    }
348}
349
350fn replace_get_item_with_stat(
351    expr: &ExprRef,
352    stat: Stat,
353    stats_to_fetch: &mut Relation<FieldOrIdentity, Stat>,
354) -> Option<ExprRef> {
355    if let Some(get_i) = expr.as_any().downcast_ref::<GetItem>() {
356        if get_i.child().as_any().is::<Identity>() {
357            let new_field = stat_field_name(get_i.field(), stat);
358            stats_to_fetch.insert(FieldOrIdentity::Field(get_i.field().clone()), stat);
359            return Some(get_item(new_field, ident()));
360        }
361    }
362
363    if let Some(not_expr) = expr.as_any().downcast_ref::<Not>() {
364        let rewritten = replace_get_item_with_stat(not_expr.child(), stat, stats_to_fetch)?;
365        return Some(not(rewritten));
366    }
367
368    if let Some(bexp) = expr.as_any().downcast_ref::<BinaryExpr>() {
369        let rewritten_lhs = replace_get_item_with_stat(bexp.lhs(), stat, stats_to_fetch);
370        let rewritten_rhs = replace_get_item_with_stat(bexp.rhs(), stat, stats_to_fetch);
371        if rewritten_lhs.is_none() && rewritten_rhs.is_none() {
372            return None;
373        }
374
375        let lhs = rewritten_lhs.unwrap_or_else(|| bexp.lhs().clone());
376        let rhs = rewritten_rhs.unwrap_or_else(|| bexp.rhs().clone());
377
378        return Some(BinaryExpr::new_expr(lhs, bexp.op(), rhs));
379    }
380
381    None
382}
383
/// What a statistic belongs to: a named field, or the identity expression itself.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum FieldOrIdentity {
    /// A named field; its statistics columns are named `<field>_<stat>`.
    Field(FieldName),
    /// The identity expression; its statistics columns are named after the statistic alone.
    Identity,
}
389
390pub(crate) fn stat_field_name(field: &FieldName, stat: Stat) -> FieldName {
391    FieldName::from(stat_field_name_string(field, stat))
392}
393
394pub(crate) fn stat_field_name_string(field: &FieldName, stat: Stat) -> String {
395    format!("{field}_{stat}")
396}
397
398impl FieldOrIdentity {
399    pub(crate) fn stat_field_name(&self, stat: Stat) -> FieldName {
400        FieldName::from(self.stat_field_name_string(stat))
401    }
402
403    pub(crate) fn stat_field_name_string(&self, stat: Stat) -> String {
404        match self {
405            FieldOrIdentity::Field(field) => stat_field_name_string(field, stat),
406            FieldOrIdentity::Identity => stat.to_string(),
407        }
408    }
409}
410
411impl Display for FieldOrIdentity {
412    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
413        match self {
414            FieldOrIdentity::Field(field) => write!(f, "{}", field),
415            FieldOrIdentity::Identity => write!(f, "$[]"),
416        }
417    }
418}
419
420impl<T> From<T> for FieldOrIdentity
421where
422    FieldName: From<T>,
423{
424    fn from(value: T) -> Self {
425        FieldOrIdentity::Field(FieldName::from(value))
426    }
427}
428
#[cfg(test)]
mod tests {
    // These tests check that row-level predicates are rewritten into the expected expressions
    // over per-chunk min/max statistics and that the required statistics are reported.
    // A pruning expression is `true` when a chunk cannot contain a matching row.

    use vortex_array::aliases::hash_map::HashMap;
    use vortex_array::aliases::hash_set::HashSet;
    use vortex_array::stats::Stat;
    use vortex_dtype::FieldName;

    use crate::pruning::{
        FieldOrIdentity, PruningPredicate, convert_to_pruning_expression, stat_field_name,
    };
    use crate::{
        and, eq, get_item, get_item_scope, gt, gt_eq, ident, lit, lt, lt_eq, not, not_eq, or,
    };

    #[test]
    pub fn pruning_equals() {
        let name = FieldName::from("a");
        let literal_eq = lit(42);
        let eq_expr = eq(get_item("a", ident()), literal_eq.clone());
        let (converted, refs) = convert_to_pruning_expression(&eq_expr);
        assert_eq!(
            refs.into_map(),
            HashMap::from_iter([(
                FieldOrIdentity::Field(name.clone()),
                HashSet::from_iter([Stat::Min, Stat::Max])
            )])
        );
        // `a == 42` cannot match when 42 lies outside of [min, max].
        let expected_expr = or(
            gt(
                get_item(stat_field_name(&name, Stat::Min), ident()),
                literal_eq.clone(),
            ),
            gt(
                literal_eq,
                get_item_scope(stat_field_name(&name, Stat::Max)),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[test]
    pub fn pruning_equals_column() {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let eq_expr = eq(
            get_item_scope(column.clone()),
            get_item_scope(other_col.clone()),
        );

        let (converted, refs) = convert_to_pruning_expression(&eq_expr);
        assert_eq!(
            refs.into_map(),
            HashMap::from_iter([
                (
                    FieldOrIdentity::Field(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldOrIdentity::Field(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = or(
            gt(
                get_item_scope(stat_field_name(&column, Stat::Min)),
                get_item_scope(stat_field_name(&other_col, Stat::Max)),
            ),
            gt(
                get_item_scope(stat_field_name(&other_col, Stat::Min)),
                get_item_scope(stat_field_name(&column, Stat::Max)),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[test]
    pub fn pruning_not_equals_column() {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let not_eq_expr = not_eq(
            get_item_scope(column.clone()),
            get_item_scope(other_col.clone()),
        );

        let (converted, refs) = convert_to_pruning_expression(&not_eq_expr);
        assert_eq!(
            refs.into_map(),
            HashMap::from_iter([
                (
                    FieldOrIdentity::Field(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldOrIdentity::Field(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = and(
            eq(
                get_item_scope(stat_field_name(&column, Stat::Min)),
                get_item_scope(stat_field_name(&other_col, Stat::Max)),
            ),
            eq(
                get_item_scope(stat_field_name(&column, Stat::Max)),
                get_item_scope(stat_field_name(&other_col, Stat::Min)),
            ),
        );

        assert_eq!(&converted, &expected_expr);
    }

    #[test]
    pub fn pruning_gt_column() {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let other_expr = get_item_scope(other_col.clone());
        let gt_expr = gt(get_item_scope(column.clone()), other_expr.clone());

        let (converted, refs) = convert_to_pruning_expression(&gt_expr);
        assert_eq!(
            refs.into_map(),
            HashMap::from_iter([
                (
                    FieldOrIdentity::Field(column.clone()),
                    HashSet::from_iter([Stat::Max])
                ),
                (
                    FieldOrIdentity::Field(other_col.clone()),
                    HashSet::from_iter([Stat::Min])
                )
            ])
        );
        let expected_expr = lt_eq(
            get_item_scope(stat_field_name(&column, Stat::Max)),
            get_item_scope(stat_field_name(&other_col, Stat::Min)),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[test]
    pub fn pruning_gt_value() {
        let column = FieldName::from("a");
        let value = lit(42);
        let gt_expr = gt(get_item_scope(column.clone()), value.clone());

        let (converted, refs) = convert_to_pruning_expression(&gt_expr);
        assert_eq!(
            refs.into_map(),
            HashMap::from_iter([(
                FieldOrIdentity::Field(column.clone()),
                HashSet::from_iter([Stat::Max])
            ),])
        );
        let expected_expr = lt_eq(
            get_item_scope(stat_field_name(&column, Stat::Max)),
            value.clone(),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[test]
    pub fn pruning_lt_column() {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let other_expr = get_item_scope(other_col.clone());
        let lt_expr = lt(get_item_scope(column.clone()), other_expr.clone());

        let (converted, refs) = convert_to_pruning_expression(&lt_expr);
        assert_eq!(
            refs.into_map(),
            HashMap::from_iter([
                (
                    FieldOrIdentity::Field(column.clone()),
                    HashSet::from_iter([Stat::Min])
                ),
                (
                    FieldOrIdentity::Field(other_col.clone()),
                    HashSet::from_iter([Stat::Max])
                )
            ])
        );
        let expected_expr = gt_eq(
            get_item_scope(stat_field_name(&column, Stat::Min)),
            get_item_scope(stat_field_name(&other_col, Stat::Max)),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[test]
    pub fn pruning_lt_value() {
        let column = FieldName::from("a");
        let value = lit(42);
        let lt_expr = lt(get_item_scope(column.clone()), value.clone());

        let (converted, refs) = convert_to_pruning_expression(&lt_expr);
        assert_eq!(
            refs.into_map(),
            HashMap::from_iter([(
                FieldOrIdentity::Field(column.clone()),
                HashSet::from_iter([Stat::Min])
            )])
        );
        let expected_expr = gt_eq(
            get_item_scope(stat_field_name(&column, Stat::Min)),
            value.clone(),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[test]
    pub fn pruning_negated_field() {
        let name = FieldName::from("a");
        let negated = not(get_item_scope(name.clone()));

        let (converted, refs) = convert_to_pruning_expression(&negated);
        assert_eq!(
            refs.into_map(),
            HashMap::from_iter([(
                FieldOrIdentity::Field(name.clone()),
                HashSet::from_iter([Stat::Min, Stat::Max])
            )])
        );
        // `!a` has no matching row only when every value of `a` is true, i.e. when both the
        // minimum and the maximum are true.
        let expected_expr = and(
            get_item_scope(stat_field_name(&name, Stat::Min)),
            get_item_scope(stat_field_name(&name, Stat::Max)),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[test]
    fn unprojectable_expr() {
        let negated_cmp = not(lt(get_item_scope("a"), get_item_scope("b")));
        assert!(PruningPredicate::try_new(&negated_cmp).is_none());
    }

    #[test]
    fn display_pruning_predicate() {
        let column = FieldName::from("a");
        let value = lit(42);
        let lt_expr = lt(get_item_scope(column), value);

        assert_eq!(
            PruningPredicate::try_new(&lt_expr).unwrap().to_string(),
            "PruningPredicate(($.a_min >= 42_i32), {a: {min}})"
        );
    }

    #[test]
    fn or_required_stats_from_both_arms() {
        let item = get_item_scope(FieldName::from("a"));
        let expr = or(lt(item.clone(), lit(10)), gt(item, lit(50)));

        let expected = HashMap::from([(
            FieldOrIdentity::from("a"),
            HashSet::from([Stat::Min, Stat::Max]),
        )]);

        assert_eq!(
            PruningPredicate::try_new(&expr).unwrap().required_stats(),
            &expected
        );
    }

    #[test]
    fn and_required_stats_from_both_arms() {
        let item = get_item_scope(FieldName::from("a"));
        let expr = and(gt(item.clone(), lit(50)), lt(item, lit(10)));

        let expected = HashMap::from([(
            FieldOrIdentity::from("a"),
            HashSet::from([Stat::Min, Stat::Max]),
        )]);

        assert_eq!(
            PruningPredicate::try_new(&expr).unwrap().required_stats(),
            &expected
        );
    }

    #[test]
    fn pruning_and_combines_arms_with_or() {
        let name = FieldName::from("a");
        let item = get_item_scope(name.clone());
        let expr = and(gt(item.clone(), lit(50)), lt(item, lit(10)));

        let predicate = PruningPredicate::try_new(&expr).unwrap();

        // `a > 50 AND a < 10` has no matching row as soon as either arm has none.
        let expected_expr = or(
            lt_eq(get_item_scope(stat_field_name(&name, Stat::Max)), lit(50)),
            gt_eq(get_item_scope(stat_field_name(&name, Stat::Min)), lit(10)),
        );
        assert_eq!(predicate.expr(), &expected_expr)
    }

    #[test]
    fn pruning_identity() {
        let expr = ident();
        let expr = or(lt(expr.clone(), lit(10)), gt(expr.clone(), lit(50)));

        let expected = HashMap::from([(
            FieldOrIdentity::Identity,
            HashSet::from([Stat::Min, Stat::Max]),
        )]);

        let predicate = PruningPredicate::try_new(&expr).unwrap();
        assert_eq!(predicate.required_stats(), &expected);

        // `$ < 10 OR $ > 50` has no matching row only when both arms have none, so the arms
        // of the pruning expression are combined with AND. Combining them with OR would, for
        // example, prune a chunk with min = 0 and max = 40 although it contains values < 10.
        let expected_expr = and(
            gt_eq(get_item_scope(FieldName::from("min")), lit(10)),
            lt_eq(get_item_scope(FieldName::from("max")), lit(50)),
        );
        assert_eq!(predicate.expr(), &expected_expr)
    }
}