vortex_array/expr/pruning/
pruning_expr.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::iter;
5
6use itertools::Itertools;
7use vortex_dtype::{Field, FieldName, FieldPath, FieldPathSet};
8use vortex_utils::aliases::hash_map::HashMap;
9
10use super::relation::Relation;
11use crate::expr::exprs::get_item::get_item;
12use crate::expr::exprs::root::root;
13use crate::expr::{Expression, StatsCatalog};
14use crate::stats::Stat;
15
/// The statistics a pruning expression depends on: a relation from each referenced
/// [`FieldPath`] to the [`Stat`]s that were requested for it.
16pub type RequiredStats = Relation<FieldPath, Stat>;
17
18// A catalog that return a stat column whenever it is required, tracking all accessed
19// stats and returning them later.
20#[derive(Default)]
21struct TrackingStatsCatalog {
22    usage: HashMap<(FieldPath, Stat), Expression>,
23}
24
25impl TrackingStatsCatalog {
26    /// Consume the catalog, yielding a map of field statistics that were required
27    /// for each expression.
28    fn into_usages(self) -> HashMap<(FieldPath, Stat), Expression> {
29        self.usage
30    }
31}
32
33// A catalog that return a stat column if it exists in the given scope.
34struct ScopeStatsCatalog<'a> {
35    any_catalog: TrackingStatsCatalog,
36    available_stats: &'a FieldPathSet,
37}
38
39impl StatsCatalog for ScopeStatsCatalog<'_> {
40    fn stats_ref(&mut self, field_path: &FieldPath, stat: Stat) -> Option<Expression> {
41        let stat_path = field_path.clone().push(stat.name());
42
43        if self.available_stats.contains(&stat_path) {
44            self.any_catalog.stats_ref(field_path, stat)
45        } else {
46            None
47        }
48    }
49}
50
51impl StatsCatalog for TrackingStatsCatalog {
52    fn stats_ref(&mut self, field_path: &FieldPath, stat: Stat) -> Option<Expression> {
53        let mut expr = root();
54        let name = field_path_stat_field_name(field_path, stat);
55        expr = get_item(name, expr);
56        self.usage.insert((field_path.clone(), stat), expr.clone());
57        Some(expr)
58    }
59}
60
61#[doc(hidden)]
62pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldName {
63    field_path
64        .parts()
65        .iter()
66        .map(|f| match f {
67            Field::Name(n) => n.as_ref(),
68            Field::ElementType => todo!("element type not currently handled"),
69        })
70        .chain(iter::once(stat.name()))
71        .join("_")
72        .into()
73}
74
75/// Build a pruning expr mask, using an existing set of stats.
76/// The available stats are provided as a set of [`FieldPath`].
77///
78/// A pruning expression is one that returns `true` for all positions where the original expression
79/// cannot hold, and false if it cannot be determined from stats alone whether the positions can
80/// be pruned.
81///
82/// If the falsification logic attempts to access an unknown stat,
83/// this function will return `None`.
84pub fn checked_pruning_expr(
85    expr: &Expression,
86    available_stats: &FieldPathSet,
87) -> Option<(Expression, RequiredStats)> {
88    let mut catalog = ScopeStatsCatalog {
89        any_catalog: Default::default(),
90        available_stats,
91    };
92
93    let expr = expr.stat_falsification(&mut catalog)?;
94
95    // TODO(joe): filter access by used exprs
96    let mut relation: Relation<FieldPath, Stat> = Relation::new();
97    for ((field_path, stat), _) in catalog.any_catalog.into_usages() {
98        relation.insert(field_path, stat)
99    }
100
101    Some((expr, relation))
102}
103
#[cfg(test)]
mod tests {
    use rstest::{fixture, rstest};
    use vortex_dtype::{
        DType, FieldName, FieldNames, FieldPath, FieldPathSet, Nullability, StructFields,
    };
    use vortex_utils::aliases::hash_set::HashSet;

    use super::HashMap;
    use crate::compute::{BetweenOptions, StrictComparison};
    use crate::expr::exprs::between::between;
    use crate::expr::exprs::binary::{and, eq, gt, gt_eq, lt, lt_eq, not_eq, or};
    use crate::expr::exprs::cast::cast;
    use crate::expr::exprs::get_item::{col, get_item};
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;
    use crate::expr::pruning::{checked_pruning_expr, field_path_stat_field_name};
    use crate::stats::Stat;

    /// Name of the flattened stats column holding `stat` for the top-level column `column`.
    fn stat_column(column: &FieldName, stat: Stat) -> FieldName {
        field_path_stat_field_name(&FieldPath::from_name(column.clone()), stat)
    }

    /// Min and max stats for the two columns `a` and `b`.
    #[fixture]
    fn available_stats() -> FieldPathSet {
        FieldPathSet::from_iter(["a", "b"].into_iter().flat_map(|column| {
            let path = FieldPath::from_name(column);
            [
                path.clone().push(Stat::Min.name()),
                path.push(Stat::Max.name()),
            ]
        }))
    }

    #[rstest]
    pub fn pruning_equals(available_stats: FieldPathSet) {
        // expression   => a == 42
        // pruning expr => a.min > 42 OR 42 > a.max
        let column = FieldName::from("a");
        let forty_two = lit(42);
        let predicate = eq(get_item("a", root()), forty_two.clone());

        let (pruning, _required) = checked_pruning_expr(&predicate, &available_stats).unwrap();

        let min_above = gt(
            get_item(stat_column(&column, Stat::Min), root()),
            forty_two.clone(),
        );
        let max_below = gt(forty_two, col(stat_column(&column, Stat::Max)));
        assert_eq!(&pruning, &or(min_above, max_below));
    }

    #[rstest]
    pub fn pruning_equals_column(available_stats: FieldPathSet) {
        // expression   => a == b
        // pruning expr => a.min > b.max OR b.min > a.max
        let lhs = FieldName::from("a");
        let rhs = FieldName::from("b");
        let predicate = eq(col(lhs.clone()), col(rhs.clone()));

        let (pruning, required) = checked_pruning_expr(&predicate, &available_stats).unwrap();

        assert_eq!(
            required.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(lhs.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldPath::from_name(rhs.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );

        let expected = or(
            gt(
                col(stat_column(&lhs, Stat::Min)),
                col(stat_column(&rhs, Stat::Max)),
            ),
            gt(
                col(stat_column(&rhs, Stat::Min)),
                col(stat_column(&lhs, Stat::Max)),
            ),
        );
        assert_eq!(&pruning, &expected);
    }

    #[rstest]
    pub fn pruning_not_equals_column(available_stats: FieldPathSet) {
        // expression   => a != b
        // pruning expr => a.min == b.max AND a.max == b.min
        let lhs = FieldName::from("a");
        let rhs = FieldName::from("b");
        let predicate = not_eq(col(lhs.clone()), col(rhs.clone()));

        let (pruning, required) = checked_pruning_expr(&predicate, &available_stats).unwrap();

        assert_eq!(
            required.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(lhs.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldPath::from_name(rhs.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );

        let expected = and(
            eq(
                col(stat_column(&lhs, Stat::Min)),
                col(stat_column(&rhs, Stat::Max)),
            ),
            eq(
                col(stat_column(&lhs, Stat::Max)),
                col(stat_column(&rhs, Stat::Min)),
            ),
        );
        assert_eq!(&pruning, &expected);
    }

    #[rstest]
    pub fn pruning_gt_column(available_stats: FieldPathSet) {
        // expression   => a > b
        // pruning expr => a.max <= b.min
        let lhs = FieldName::from("a");
        let rhs = FieldName::from("b");
        let predicate = gt(col(lhs.clone()), col(rhs.clone()));

        let (pruning, required) = checked_pruning_expr(&predicate, &available_stats).unwrap();

        assert_eq!(
            required.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(lhs.clone()),
                    HashSet::from_iter([Stat::Max])
                ),
                (
                    FieldPath::from_name(rhs.clone()),
                    HashSet::from_iter([Stat::Min])
                )
            ])
        );

        let expected = lt_eq(
            col(stat_column(&lhs, Stat::Max)),
            col(stat_column(&rhs, Stat::Min)),
        );
        assert_eq!(&pruning, &expected);
    }

    #[rstest]
    pub fn pruning_gt_value(available_stats: FieldPathSet) {
        // expression   => a > 42
        // pruning expr => a.max <= 42
        let column = FieldName::from("a");
        let bound = lit(42);
        let predicate = gt(col(column.clone()), bound.clone());

        let (pruning, required) = checked_pruning_expr(&predicate, &available_stats).unwrap();

        assert_eq!(
            required.map(),
            &HashMap::from_iter([(
                FieldPath::from_name(column.clone()),
                HashSet::from_iter([Stat::Max])
            ),])
        );

        let expected = lt_eq(col(stat_column(&column, Stat::Max)), bound);
        assert_eq!(&pruning, &expected);
    }

    #[rstest]
    pub fn pruning_lt_column(available_stats: FieldPathSet) {
        // expression   => a < b
        // pruning expr => a.min >= b.max
        let lhs = FieldName::from("a");
        let rhs = FieldName::from("b");
        let predicate = lt(col(lhs.clone()), col(rhs.clone()));

        let (pruning, required) = checked_pruning_expr(&predicate, &available_stats).unwrap();

        assert_eq!(
            required.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(lhs.clone()),
                    HashSet::from_iter([Stat::Min])
                ),
                (
                    FieldPath::from_name(rhs.clone()),
                    HashSet::from_iter([Stat::Max])
                )
            ])
        );

        let expected = gt_eq(
            col(stat_column(&lhs, Stat::Min)),
            col(stat_column(&rhs, Stat::Max)),
        );
        assert_eq!(&pruning, &expected);
    }

    #[rstest]
    pub fn pruning_lt_value(available_stats: FieldPathSet) {
        // expression   => a < 42
        // pruning expr => a.min >= 42
        let predicate = lt(col("a"), lit(42));

        let (pruning, required) = checked_pruning_expr(&predicate, &available_stats).unwrap();

        let expected_required =
            HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from_iter([Stat::Min]))]);
        assert_eq!(required.map(), &expected_required);
        assert_eq!(&pruning, &gt_eq(col("a_min"), lit(42)));
    }

    #[rstest]
    fn pruning_identity(available_stats: FieldPathSet) {
        // expression   => a < 10 OR a > 50
        // pruning expr => a.min >= 10 AND a.max <= 50
        let predicate = or(lt(col("a"), lit(10)), gt(col("a"), lit(50)));

        let (pruning, _) = checked_pruning_expr(&predicate, &available_stats).unwrap();

        // Compared through the rendered form rather than structurally.
        let expected = and(gt_eq(col("a_min"), lit(10)), lt_eq(col("a_max"), lit(50)));
        assert_eq!(&pruning.to_string(), &expected.to_string());
    }

    #[rstest]
    pub fn pruning_and_or_operators(available_stats: FieldPathSet) {
        // expression   => a > 10 AND a < 50
        // pruning expr => a_max <= 10 OR a_min >= 50
        let column = FieldName::from("a");
        let predicate = and(gt(col(column.clone()), lit(10)), lt(col(column), lit(50)));

        let (pruning, _) = checked_pruning_expr(&predicate, &available_stats).unwrap();

        let expected = or(
            lt_eq(col(FieldName::from("a_max")), lit(10)),
            gt_eq(col(FieldName::from("a_min")), lit(50)),
        );
        assert_eq!(&pruning, &expected);
    }

    #[rstest]
    fn test_gt_eq_with_booleans(available_stats: FieldPathSet) {
        // A comparison whose operand is itself a comparison is unusual but valid (in Arrow,
        // BooleanArray implements ArrayOrd). Take the filter
        //     x > (y > z)
        // where x is Boolean-valued, y and z are numeric, and True > False.
        //
        // Suppose a Vortex zone has these min/max statistics:
        //     x: [True, True]
        //     y: [1, 2]
        //     z: [0, 2]
        //
        // A naive conversion would produce the pruning predicate
        //     x_max <= (y_min > z_min)
        // which on this zone evaluates as
        //     x_max <= (1 > 0)
        //     True  <= True
        //     True
        // and, as stated in PruningPredicate::evaluate, "a true value means the chunk can
        // be pruned".
        //
        // Yet the record {x: True, y: 1, z: 2} lies within those intervals and *passes* the
        // filter:
        //     True > (1 > 2)
        //     True > False
        //     True
        // so the zone must not be pruned, and no pruning expression may be produced here.
        //
        // NOTE(review): the expression built below uses `>=` on the outside while the
        // walkthrough above uses `>`. Also, the shared fixture only carries stats for `a`
        // and `b`, so the `None` may stem from the missing x/y/z stats rather than from the
        // pruner rejecting the nested comparison — confirm.
        let predicate = gt_eq(col("x"), gt(col("y"), col("z")));
        assert!(checked_pruning_expr(&predicate, &available_stats).is_none());
        // TODO(DK): a sufficiently complex pruner would produce: `x_max <= (y_max > z_min)`
    }

    /// Stats for one float column (which carries a NaN count) and one int column (which
    /// does not).
    #[fixture]
    fn available_stats_with_nans() -> FieldPathSet {
        let float_col = FieldPath::from_name("float_col");
        let int_col = FieldPath::from_name("int_col");

        // Float columns will have a NaNCount.
        let float_stats = [Stat::Min, Stat::Max, Stat::NaNCount]
            .map(|stat| float_col.clone().push(stat.name()));
        // Int columns will not have a NaNCount serialized into the layout.
        let int_stats = [Stat::Min, Stat::Max].map(|stat| int_col.clone().push(stat.name()));

        FieldPathSet::from_iter(float_stats.into_iter().chain(int_stats))
    }

    #[rstest]
    fn pruning_checks_nans(available_stats_with_nans: FieldPathSet) {
        // Case 1: comparing a float column against a NaN literal.
        let nan_predicate = gt_eq(col("float_col"), lit(f32::NAN));
        let (nan_pruning, _) =
            checked_pruning_expr(&nan_predicate, &available_stats_with_nans).unwrap();

        let nan_guard = and(
            eq(col("float_col_nan_count"), lit(0u64)),
            // NaNCount of NaN is 1
            eq(lit(1u64), lit(0u64)),
        );
        // The standard conversion of the >= operator. Comparing NAN to a max stat is
        // nonsensical, as min/max stats ignore NaNs, but the preceding nan_count check
        // should short-circuit it anyway.
        let nan_bound = lt(col("float_col_max"), lit(f32::NAN));
        assert_eq!(&nan_pruning, &and(nan_guard, nan_bound));

        // Case 2: one half of the expression requires the NaN count check, the other half
        // does not.
        let mixed_predicate = and(
            gt(col("float_col"), lit(10f32)),
            lt(col("int_col"), lit(10)),
        );
        let (mixed_pruning, _) =
            checked_pruning_expr(&mixed_predicate, &available_stats_with_nans).unwrap();

        // NaNCount check is enforced for the float column.
        let float_side = and(
            and(
                eq(col("float_col_nan_count"), lit(0u64)),
                // NaNCount of a non-NaN float literal is 0
                eq(lit(0u64), lit(0u64)),
            ),
            // We want the opposite: we can prune IF either one is false.
            lt_eq(col("float_col_max"), lit(10f32)),
        );
        // NaNCount check is skipped for the int column.
        let int_side = gt_eq(col("int_col_min"), lit(10));
        assert_eq!(&mixed_pruning, &or(float_side, int_side));
    }

    #[rstest]
    fn pruning_between(available_stats: FieldPathSet) {
        // expression   => 10 <= a <= 50
        // pruning expr => 10 > a.max OR a.min > 50
        let inclusive = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::NonStrict,
        };
        let predicate = between(col("a"), lit(10), lit(50), inclusive);

        let (pruning, required) = checked_pruning_expr(&predicate, &available_stats).unwrap();

        assert_eq!(
            required.map(),
            &HashMap::from_iter([(
                FieldPath::from_name("a"),
                HashSet::from_iter([Stat::Min, Stat::Max])
            )])
        );

        let expected = or(gt(lit(10), col("a_max")), gt(col("a_min"), lit(50)));
        assert_eq!(&pruning, &expected);
    }

    #[rstest]
    fn pruning_cast_get_item_eq(available_stats: FieldPathSet) {
        // Verifies that cast properly forwards analysis methods so that pruning still
        // works when the column is accessed through a cast of the root.
        let utf8 = DType::Utf8(Nullability::Nullable);
        let fields = StructFields::new(
            FieldNames::from([FieldName::from("a"), FieldName::from("b")]),
            vec![utf8.clone(), utf8],
        );
        let struct_dtype = DType::Struct(fields, Nullability::NonNullable);

        let predicate = eq(get_item("a", cast(root(), struct_dtype)), lit("value"));
        let (pruning, required) = checked_pruning_expr(&predicate, &available_stats).unwrap();

        assert_eq!(
            required.map(),
            &HashMap::from_iter([(
                FieldPath::from_name("a"),
                HashSet::from_iter([Stat::Min, Stat::Max])
            )])
        );

        let expected = or(
            gt(col("a_min"), lit("value")),
            gt(lit("value"), col("a_max")),
        );
        assert_eq!(&pruning, &expected);
    }
}