Skip to main content

vortex_array/expr/pruning/
pruning_expr.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cell::RefCell;
5use std::iter;
6
7use itertools::Itertools;
8use vortex_utils::aliases::hash_map::HashMap;
9
10use super::relation::Relation;
11use crate::dtype::Field;
12use crate::dtype::FieldName;
13use crate::dtype::FieldPath;
14use crate::dtype::FieldPathSet;
15use crate::expr::Expression;
16use crate::expr::StatsCatalog;
17use crate::expr::get_item;
18use crate::expr::root;
19use crate::expr::stats::Stat;
20
/// The statistics a pruning expression reads, expressed as a relation from each
/// [`FieldPath`] to the set of [`Stat`]s referenced on that path.
pub type RequiredStats = Relation<FieldPath, Stat>;
22
23// A catalog that return a stat column whenever it is required, tracking all accessed
24// stats and returning them later.
25#[derive(Default)]
26pub(crate) struct TrackingStatsCatalog {
27    usage: RefCell<HashMap<(FieldPath, Stat), Expression>>,
28}
29
30impl TrackingStatsCatalog {
31    /// Consume the catalog, yielding a map of field statistics that were required
32    /// for each expression.
33    fn into_usages(self) -> HashMap<(FieldPath, Stat), Expression> {
34        self.usage.into_inner()
35    }
36}
37
38// A catalog that return a stat column if it exists in the given scope.
39struct ScopeStatsCatalog<'a> {
40    inner: TrackingStatsCatalog,
41    available_stats: &'a FieldPathSet,
42}
43
44impl StatsCatalog for ScopeStatsCatalog<'_> {
45    fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option<Expression> {
46        let stat_path = field_path.clone().push(stat.name());
47
48        if self.available_stats.contains(&stat_path) {
49            self.inner.stats_ref(field_path, stat)
50        } else {
51            None
52        }
53    }
54}
55
56impl StatsCatalog for TrackingStatsCatalog {
57    fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option<Expression> {
58        let mut expr = root();
59        let name = field_path_stat_field_name(field_path, stat);
60        expr = get_item(name, expr);
61        self.usage
62            .borrow_mut()
63            .insert((field_path.clone(), stat), expr.clone());
64        Some(expr)
65    }
66}
67
68#[doc(hidden)]
69pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldName {
70    field_path
71        .parts()
72        .iter()
73        .map(|f| match f {
74            Field::Name(n) => n.as_ref(),
75            Field::ElementType => todo!("element type not currently handled"),
76        })
77        .chain(iter::once(stat.name()))
78        .join("_")
79        .into()
80}
81
82/// Build a pruning expr mask, using an existing set of stats.
83/// The available stats are provided as a set of [`FieldPath`].
84///
85/// A pruning expression is one that returns `true` for all positions where the original expression
86/// cannot hold, and false if it cannot be determined from stats alone whether the positions can
87/// be pruned.
88///
89/// Some rewrites, such as `is_not_null(...)`, emit
90/// [`row_count`][crate::scalar_fn::internal::row_count] placeholders. The evaluation layer must
91/// replace those placeholders with the row count for its current scope before
92/// executing the returned expression.
93///
94/// If the falsification logic attempts to access an unknown stat,
95/// this function will return `None`.
96pub fn checked_pruning_expr(
97    expr: &Expression,
98    available_stats: &FieldPathSet,
99) -> Option<(Expression, RequiredStats)> {
100    let catalog = ScopeStatsCatalog {
101        inner: Default::default(),
102        available_stats,
103    };
104
105    let expr = expr.stat_falsification(&catalog)?;
106
107    // TODO(joe): filter access by used exprs
108    let mut relation: Relation<FieldPath, Stat> = Relation::new();
109    for ((field_path, stat), _) in catalog.inner.into_usages() {
110        relation.insert(field_path, stat)
111    }
112
113    Some((expr, relation))
114}
115
#[cfg(test)]
mod tests {
    use rstest::fixture;
    use rstest::rstest;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::HashMap;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::FieldPath;
    use crate::dtype::FieldPathSet;
    use crate::dtype::Nullability;
    use crate::dtype::StructFields;
    use crate::expr::and;
    use crate::expr::between;
    use crate::expr::cast;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::gt;
    use crate::expr::gt_eq;
    use crate::expr::lit;
    use crate::expr::lt;
    use crate::expr::lt_eq;
    use crate::expr::not_eq;
    use crate::expr::or;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::expr::pruning::field_path_stat_field_name;
    use crate::expr::root;
    use crate::expr::stats::Stat;
    use crate::scalar_fn::fns::between::BetweenOptions;
    use crate::scalar_fn::fns::between::StrictComparison;

    /// Stats in scope for most tests: `min` and `max` for columns `a` and `b`, nothing else.
    #[fixture]
    fn available_stats() -> FieldPathSet {
        let field_a = FieldPath::from_name("a");
        let field_b = FieldPath::from_name("b");

        FieldPathSet::from_iter([
            field_a.clone().push(Stat::Min.name()),
            field_a.push(Stat::Max.name()),
            field_b.clone().push(Stat::Min.name()),
            field_b.push(Stat::Max.name()),
        ])
    }

    #[rstest]
    fn pruning_equals(available_stats: FieldPathSet) {
        let name = FieldName::from("a");
        let literal_eq = lit(42);
        let eq_expr = eq(get_item("a", root()), literal_eq.clone());
        let (converted, _refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap();
        let expected_expr = or(
            gt(
                get_item(
                    field_path_stat_field_name(&FieldPath::from_name(name.clone()), Stat::Min),
                    root(),
                ),
                literal_eq.clone(),
            ),
            gt(
                literal_eq,
                col(field_path_stat_field_name(
                    &FieldPath::from_name(name),
                    Stat::Max,
                )),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    fn pruning_equals_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let eq_expr = eq(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = or(
            gt(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column.clone()),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col.clone()),
                    Stat::Max,
                )),
            ),
            gt(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column),
                    Stat::Max,
                )),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    fn pruning_not_equals_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let not_eq_expr = not_eq(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&not_eq_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = and(
            eq(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column.clone()),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col.clone()),
                    Stat::Max,
                )),
            ),
            eq(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column),
                    Stat::Max,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col),
                    Stat::Min,
                )),
            ),
        );

        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    fn pruning_gt_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let other_expr = col(other_col.clone());
        let gt_expr = gt(col(column.clone()), other_expr);

        let (converted, refs) = checked_pruning_expr(&gt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Min])
                )
            ])
        );
        let expected_expr = lt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Max,
            )),
            col(field_path_stat_field_name(
                &FieldPath::from_name(other_col),
                Stat::Min,
            )),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    fn pruning_gt_value(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let literal = lit(42);
        let gt_expr = gt(col(column.clone()), literal.clone());

        let (converted, refs) = checked_pruning_expr(&gt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                FieldPath::from_name(column.clone()),
                HashSet::from_iter([Stat::Max])
            )])
        );
        let expected_expr = lt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Max,
            )),
            literal,
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    fn pruning_lt_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let other_expr = col(other_col.clone());
        let lt_expr = lt(col(column.clone()), other_expr);

        let (converted, refs) = checked_pruning_expr(&lt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max])
                )
            ])
        );
        let expected_expr = gt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Min,
            )),
            col(field_path_stat_field_name(
                &FieldPath::from_name(other_col),
                Stat::Max,
            )),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    fn pruning_lt_value(available_stats: FieldPathSet) {
        // expression   => a < 42
        // pruning expr => a.min >= 42
        let expr = lt(col("a"), lit(42));

        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from_iter([Stat::Min]))])
        );
        assert_eq!(&converted, &gt_eq(col("a_min"), lit(42)));
    }

    #[rstest]
    fn pruning_identity(available_stats: FieldPathSet) {
        let expr = or(lt(col("a"), lit(10)), gt(col("a"), lit(50)));

        let (predicate, _) = checked_pruning_expr(&expr, &available_stats).unwrap();

        let expected_expr = and(gt_eq(col("a_min"), lit(10)), lt_eq(col("a_max"), lit(50)));
        assert_eq!(&predicate.to_string(), &expected_expr.to_string());
    }

    #[rstest]
    fn pruning_and_or_operators(available_stats: FieldPathSet) {
        // Test case: a > 10 AND a < 50
        let column = FieldName::from("a");
        let and_expr = and(gt(col(column.clone()), lit(10)), lt(col(column), lit(50)));
        let (predicate, _) = checked_pruning_expr(&and_expr, &available_stats).unwrap();

        // Expected: a_max <= 10 OR a_min >= 50
        assert_eq!(
            &predicate,
            &or(
                lt_eq(col(FieldName::from("a_max")), lit(10)),
                gt_eq(col(FieldName::from("a_min")), lit(50)),
            ),
        );
    }

    #[rstest]
    fn test_gt_eq_with_booleans(available_stats: FieldPathSet) {
        // Consider this unusual, but valid (in Arrow, BooleanArray implements ArrayOrd), filter expression:
        // x > (y > z)
        // The x column is a Boolean-valued column. The y and z columns are numeric. True > False.
        // Suppose we had a Vortex zone whose min/max statistics for each column were:
        // x: [True, True]
        // y: [1, 2]
        // z: [0, 2]
        // The pruning predicate will convert the aforementioned expression into:
        // x_max <= (y_min > z_min)
        // If we evaluate that pruning expression on our zone we get:
        // x_max <= (y_min > z_min)
        // x_max <= (1     > 0    )
        // x_max <= True
        // True <= True
        // True
        // If a pruning predicate evaluates to true then, as stated in PruningPredicate::evaluate:
        // > a true value means the chunk can be pruned.
        // But, the following record lies within the above intervals and *passes* the filter expression! We
        // cannot prune this zone because we need this record!
        // {x: True, y: 1, z: 2}
        // x > (y > z)
        // True > (1 > 2)
        // True > False
        // True
        //
        // NOTE(review): the walkthrough above uses `>` but the expression below uses `>=`.
        // Also, the `available_stats` fixture only has stats for `a` and `b`, so this
        // currently returns `None` simply because the `x`/`y`/`z` stats are missing. A
        // fixture with `x`/`y`/`z` stats would be needed to exercise the boolean case itself.
        let expr = gt_eq(col("x"), gt(col("y"), col("z")));
        assert!(checked_pruning_expr(&expr, &available_stats).is_none());
        // TODO(DK): a sufficiently complex pruner would produce: `x_max <= (y_max > z_min)`
    }

    /// Stats in scope for NaN tests: a float column with min/max/NaN-count and an int column
    /// with only min/max.
    #[fixture]
    fn available_stats_with_nans() -> FieldPathSet {
        let float_col = FieldPath::from_name("float_col");
        let int_col = FieldPath::from_name("int_col");

        FieldPathSet::from_iter([
            // Float columns will have a NaNCount.
            float_col.clone().push(Stat::Min.name()),
            float_col.clone().push(Stat::Max.name()),
            float_col.push(Stat::NaNCount.name()),
            // int columns will not have a NanCount serialized into the layout
            int_col.clone().push(Stat::Min.name()),
            int_col.push(Stat::Max.name()),
        ])
    }

    #[rstest]
    fn pruning_checks_nans(available_stats_with_nans: FieldPathSet) {
        let expr = gt_eq(col("float_col"), lit(f32::NAN));
        let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap();
        assert_eq!(
            &converted,
            &and(
                and(
                    eq(col("float_col_nan_count"), lit(0u64)),
                    // NaNCount of NaN is 1
                    eq(lit(1u64), lit(0u64)),
                ),
                // This is the standard conversion of the >= operator. Comparing NAN to a max
                // stat is nonsensical, as min/max stats ignore NaNs, but this should be short-circuited
                // by the previous check for nan_count anyway.
                lt(col("float_col_max"), lit(f32::NAN)),
            )
        );

        // One half of the expression requires NAN count check, the other half does not.
        let expr = and(
            gt(col("float_col"), lit(10f32)),
            lt(col("int_col"), lit(10)),
        );

        let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap();

        assert_eq!(
            &converted,
            &or(
                // NaNCount check is enforced for the float column
                and(
                    and(
                        eq(col("float_col_nan_count"), lit(0u64)),
                        // NanCount of a non-NaN float literal is 0
                        eq(lit(0u64), lit(0u64)),
                    ),
                    // We want the opposite: we can prune IF either one is false.
                    lt_eq(col("float_col_max"), lit(10f32)),
                ),
                // NanCount check is skipped for the int column
                gt_eq(col("int_col_min"), lit(10)),
            )
        );
    }

    #[rstest]
    fn pruning_between(available_stats: FieldPathSet) {
        let expr = between(
            col("a"),
            lit(10),
            lit(50),
            BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
        );
        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                FieldPath::from_name("a"),
                HashSet::from_iter([Stat::Min, Stat::Max])
            )])
        );
        assert_eq!(
            &converted,
            &or(gt(lit(10), col("a_max")), gt(col("a_min"), lit(50)))
        );
    }

    #[rstest]
    fn pruning_cast_get_item_eq(available_stats: FieldPathSet) {
        // This test verifies that cast properly forwards analysis methods to
        // enable pruning.
        let struct_dtype = DType::Struct(
            StructFields::new(
                FieldNames::from([FieldName::from("a"), FieldName::from("b")]),
                vec![
                    DType::Utf8(Nullability::Nullable),
                    DType::Utf8(Nullability::Nullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let expr = eq(get_item("a", cast(root(), struct_dtype)), lit("value"));
        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                FieldPath::from_name("a"),
                HashSet::from_iter([Stat::Min, Stat::Max])
            )])
        );
        assert_eq!(
            &converted,
            &or(
                gt(col("a_min"), lit("value")),
                gt(lit("value"), col("a_max"))
            )
        );
    }
}