vortex_array/expr/pruning/
pruning_expr.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cell::RefCell;
5use std::iter;
6
7use itertools::Itertools;
8use vortex_dtype::Field;
9use vortex_dtype::FieldName;
10use vortex_dtype::FieldPath;
11use vortex_dtype::FieldPathSet;
12use vortex_utils::aliases::hash_map::HashMap;
13
14use super::relation::Relation;
15use crate::expr::Expression;
16use crate::expr::StatsCatalog;
17use crate::expr::exprs::get_item::get_item;
18use crate::expr::exprs::root::root;
19use crate::expr::stats::Stat;
20
21pub type RequiredStats = Relation<FieldPath, Stat>;
22
23// A catalog that return a stat column whenever it is required, tracking all accessed
24// stats and returning them later.
25#[derive(Default)]
26struct TrackingStatsCatalog {
27    usage: RefCell<HashMap<(FieldPath, Stat), Expression>>,
28}
29
30impl TrackingStatsCatalog {
31    /// Consume the catalog, yielding a map of field statistics that were required
32    /// for each expression.
33    fn into_usages(self) -> HashMap<(FieldPath, Stat), Expression> {
34        self.usage.into_inner()
35    }
36}
37
38// A catalog that return a stat column if it exists in the given scope.
39struct ScopeStatsCatalog<'a> {
40    any_catalog: TrackingStatsCatalog,
41    available_stats: &'a FieldPathSet,
42}
43
44impl StatsCatalog for ScopeStatsCatalog<'_> {
45    fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option<Expression> {
46        let stat_path = field_path.clone().push(stat.name());
47
48        if self.available_stats.contains(&stat_path) {
49            self.any_catalog.stats_ref(field_path, stat)
50        } else {
51            None
52        }
53    }
54}
55
56impl StatsCatalog for TrackingStatsCatalog {
57    fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option<Expression> {
58        let mut expr = root();
59        let name = field_path_stat_field_name(field_path, stat);
60        expr = get_item(name, expr);
61        self.usage
62            .borrow_mut()
63            .insert((field_path.clone(), stat), expr.clone());
64        Some(expr)
65    }
66}
67
68#[doc(hidden)]
69pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldName {
70    field_path
71        .parts()
72        .iter()
73        .map(|f| match f {
74            Field::Name(n) => n.as_ref(),
75            Field::ElementType => todo!("element type not currently handled"),
76        })
77        .chain(iter::once(stat.name()))
78        .join("_")
79        .into()
80}
81
82/// Build a pruning expr mask, using an existing set of stats.
83/// The available stats are provided as a set of [`FieldPath`].
84///
85/// A pruning expression is one that returns `true` for all positions where the original expression
86/// cannot hold, and false if it cannot be determined from stats alone whether the positions can
87/// be pruned.
88///
89/// If the falsification logic attempts to access an unknown stat,
90/// this function will return `None`.
91pub fn checked_pruning_expr(
92    expr: &Expression,
93    available_stats: &FieldPathSet,
94) -> Option<(Expression, RequiredStats)> {
95    let catalog = ScopeStatsCatalog {
96        any_catalog: Default::default(),
97        available_stats,
98    };
99
100    let expr = expr.stat_falsification(&catalog)?;
101
102    // TODO(joe): filter access by used exprs
103    let mut relation: Relation<FieldPath, Stat> = Relation::new();
104    for ((field_path, stat), _) in catalog.any_catalog.into_usages() {
105        relation.insert(field_path, stat)
106    }
107
108    Some((expr, relation))
109}
110
111#[cfg(test)]
112mod tests {
113    use rstest::fixture;
114    use rstest::rstest;
115    use vortex_dtype::DType;
116    use vortex_dtype::FieldName;
117    use vortex_dtype::FieldNames;
118    use vortex_dtype::FieldPath;
119    use vortex_dtype::FieldPathSet;
120    use vortex_dtype::Nullability;
121    use vortex_dtype::StructFields;
122    use vortex_utils::aliases::hash_set::HashSet;
123
124    use super::HashMap;
125    use crate::compute::BetweenOptions;
126    use crate::compute::StrictComparison;
127    use crate::expr::exprs::between::between;
128    use crate::expr::exprs::binary::and;
129    use crate::expr::exprs::binary::eq;
130    use crate::expr::exprs::binary::gt;
131    use crate::expr::exprs::binary::gt_eq;
132    use crate::expr::exprs::binary::lt;
133    use crate::expr::exprs::binary::lt_eq;
134    use crate::expr::exprs::binary::not_eq;
135    use crate::expr::exprs::binary::or;
136    use crate::expr::exprs::cast::cast;
137    use crate::expr::exprs::get_item::col;
138    use crate::expr::exprs::get_item::get_item;
139    use crate::expr::exprs::literal::lit;
140    use crate::expr::exprs::root::root;
141    use crate::expr::pruning::checked_pruning_expr;
142    use crate::expr::pruning::field_path_stat_field_name;
143    use crate::expr::stats::Stat;
144
145    // Implement some checked pruning expressions.
146    #[fixture]
147    fn available_stats() -> FieldPathSet {
148        let field_a = FieldPath::from_name("a");
149        let field_b = FieldPath::from_name("b");
150
151        FieldPathSet::from_iter([
152            field_a.clone().push(Stat::Min.name()),
153            field_a.push(Stat::Max.name()),
154            field_b.clone().push(Stat::Min.name()),
155            field_b.push(Stat::Max.name()),
156        ])
157    }
158
159    #[rstest]
160    pub fn pruning_equals(available_stats: FieldPathSet) {
161        let name = FieldName::from("a");
162        let literal_eq = lit(42);
163        let eq_expr = eq(get_item("a", root()), literal_eq.clone());
164        let (converted, _refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap();
165        let expected_expr = or(
166            gt(
167                get_item(
168                    field_path_stat_field_name(&FieldPath::from_name(name.clone()), Stat::Min),
169                    root(),
170                ),
171                literal_eq.clone(),
172            ),
173            gt(
174                literal_eq,
175                col(field_path_stat_field_name(
176                    &FieldPath::from_name(name),
177                    Stat::Max,
178                )),
179            ),
180        );
181        assert_eq!(&converted, &expected_expr);
182    }
183
184    #[rstest]
185    pub fn pruning_equals_column(available_stats: FieldPathSet) {
186        let column = FieldName::from("a");
187        let other_col = FieldName::from("b");
188        let eq_expr = eq(col(column.clone()), col(other_col.clone()));
189
190        let (converted, refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap();
191        assert_eq!(
192            refs.map(),
193            &HashMap::from_iter([
194                (
195                    FieldPath::from_name(column.clone()),
196                    HashSet::from_iter([Stat::Min, Stat::Max])
197                ),
198                (
199                    FieldPath::from_name(other_col.clone()),
200                    HashSet::from_iter([Stat::Max, Stat::Min])
201                )
202            ])
203        );
204        let expected_expr = or(
205            gt(
206                col(field_path_stat_field_name(
207                    &FieldPath::from_name(column.clone()),
208                    Stat::Min,
209                )),
210                col(field_path_stat_field_name(
211                    &FieldPath::from_name(other_col.clone()),
212                    Stat::Max,
213                )),
214            ),
215            gt(
216                col(field_path_stat_field_name(
217                    &FieldPath::from_name(other_col),
218                    Stat::Min,
219                )),
220                col(field_path_stat_field_name(
221                    &FieldPath::from_name(column),
222                    Stat::Max,
223                )),
224            ),
225        );
226        assert_eq!(&converted, &expected_expr);
227    }
228
229    #[rstest]
230    pub fn pruning_not_equals_column(available_stats: FieldPathSet) {
231        let column = FieldName::from("a");
232        let other_col = FieldName::from("b");
233        let not_eq_expr = not_eq(col(column.clone()), col(other_col.clone()));
234
235        let (converted, refs) = checked_pruning_expr(&not_eq_expr, &available_stats).unwrap();
236        assert_eq!(
237            refs.map(),
238            &HashMap::from_iter([
239                (
240                    FieldPath::from_name(column.clone()),
241                    HashSet::from_iter([Stat::Min, Stat::Max])
242                ),
243                (
244                    FieldPath::from_name(other_col.clone()),
245                    HashSet::from_iter([Stat::Max, Stat::Min])
246                )
247            ])
248        );
249        let expected_expr = and(
250            eq(
251                col(field_path_stat_field_name(
252                    &FieldPath::from_name(column.clone()),
253                    Stat::Min,
254                )),
255                col(field_path_stat_field_name(
256                    &FieldPath::from_name(other_col.clone()),
257                    Stat::Max,
258                )),
259            ),
260            eq(
261                col(field_path_stat_field_name(
262                    &FieldPath::from_name(column),
263                    Stat::Max,
264                )),
265                col(field_path_stat_field_name(
266                    &FieldPath::from_name(other_col),
267                    Stat::Min,
268                )),
269            ),
270        );
271
272        assert_eq!(&converted, &expected_expr);
273    }
274
275    #[rstest]
276    pub fn pruning_gt_column(available_stats: FieldPathSet) {
277        let column = FieldName::from("a");
278        let other_col = FieldName::from("b");
279        let other_expr = col(other_col.clone());
280        let not_eq_expr = gt(col(column.clone()), other_expr);
281
282        let (converted, refs) = checked_pruning_expr(&not_eq_expr, &available_stats).unwrap();
283        assert_eq!(
284            refs.map(),
285            &HashMap::from_iter([
286                (
287                    FieldPath::from_name(column.clone()),
288                    HashSet::from_iter([Stat::Max])
289                ),
290                (
291                    FieldPath::from_name(other_col.clone()),
292                    HashSet::from_iter([Stat::Min])
293                )
294            ])
295        );
296        let expected_expr = lt_eq(
297            col(field_path_stat_field_name(
298                &FieldPath::from_name(column),
299                Stat::Max,
300            )),
301            col(field_path_stat_field_name(
302                &FieldPath::from_name(other_col),
303                Stat::Min,
304            )),
305        );
306        assert_eq!(&converted, &expected_expr);
307    }
308
309    #[rstest]
310    pub fn pruning_gt_value(available_stats: FieldPathSet) {
311        let column = FieldName::from("a");
312        let other_col = lit(42);
313        let not_eq_expr = gt(col(column.clone()), other_col.clone());
314
315        let (converted, refs) = checked_pruning_expr(&not_eq_expr, &available_stats).unwrap();
316        assert_eq!(
317            refs.map(),
318            &HashMap::from_iter([(
319                FieldPath::from_name(column.clone()),
320                HashSet::from_iter([Stat::Max])
321            ),])
322        );
323        let expected_expr = lt_eq(
324            col(field_path_stat_field_name(
325                &FieldPath::from_name(column),
326                Stat::Max,
327            )),
328            other_col,
329        );
330        assert_eq!(&converted, &(expected_expr));
331    }
332
333    #[rstest]
334    pub fn pruning_lt_column(available_stats: FieldPathSet) {
335        let column = FieldName::from("a");
336        let other_col = FieldName::from("b");
337        let other_expr = col(other_col.clone());
338        let not_eq_expr = lt(col(column.clone()), other_expr);
339
340        let (converted, refs) = checked_pruning_expr(&not_eq_expr, &available_stats).unwrap();
341        assert_eq!(
342            refs.map(),
343            &HashMap::from_iter([
344                (
345                    FieldPath::from_name(column.clone()),
346                    HashSet::from_iter([Stat::Min])
347                ),
348                (
349                    FieldPath::from_name(other_col.clone()),
350                    HashSet::from_iter([Stat::Max])
351                )
352            ])
353        );
354        let expected_expr = gt_eq(
355            col(field_path_stat_field_name(
356                &FieldPath::from_name(column),
357                Stat::Min,
358            )),
359            col(field_path_stat_field_name(
360                &FieldPath::from_name(other_col),
361                Stat::Max,
362            )),
363        );
364        assert_eq!(&converted, &expected_expr);
365    }
366
367    #[rstest]
368    pub fn pruning_lt_value(available_stats: FieldPathSet) {
369        // expression   => a < 42
370        // pruning expr => a.min >= 42
371        let expr = lt(col("a"), lit(42));
372
373        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
374        assert_eq!(
375            refs.map(),
376            &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from_iter([Stat::Min]))])
377        );
378        assert_eq!(&converted, &gt_eq(col("a_min"), lit(42)));
379    }
380
381    #[rstest]
382    fn pruning_identity(available_stats: FieldPathSet) {
383        let expr = or(lt(col("a"), lit(10)), gt(col("a"), lit(50)));
384
385        let (predicate, _) = checked_pruning_expr(&expr, &available_stats).unwrap();
386
387        let expected_expr = and(gt_eq(col("a_min"), lit(10)), lt_eq(col("a_max"), lit(50)));
388        assert_eq!(&predicate.to_string(), &expected_expr.to_string());
389    }
390    #[rstest]
391    pub fn pruning_and_or_operators(available_stats: FieldPathSet) {
392        // Test case: a > 10 AND a < 50
393        let column = FieldName::from("a");
394        let and_expr = and(gt(col(column.clone()), lit(10)), lt(col(column), lit(50)));
395        let (predicate, _) = checked_pruning_expr(&and_expr, &available_stats).unwrap();
396
397        // Expected: a_max <= 10 OR a_min >= 50
398        assert_eq!(
399            &predicate,
400            &or(
401                lt_eq(col(FieldName::from("a_max")), lit(10)),
402                gt_eq(col(FieldName::from("a_min")), lit(50)),
403            ),
404        );
405    }
406
407    #[rstest]
408    fn test_gt_eq_with_booleans(available_stats: FieldPathSet) {
409        // Consider this unusual, but valid (in Arrow, BooleanArray implements ArrayOrd), filter expression:
410        // x > (y > z)
411        // The x column is a Boolean-valued column. The y and z columns are numeric. True > False.
412        // Suppose we had a Vortex zone whose min/max statistics for each column were:
413        // x: [True, True]
414        // y: [1, 2]
415        // z: [0, 2]
416        // The pruning predicate will convert the aforementioned expression into:
417        // x_max <= (y_min > z_min)
418        // If we evaluate that pruning expression on our zone we get:
419        // x_max <= (y_min > z_min)
420        // x_max <= (1     > 0    )
421        // x_max <= True
422        // True <= True
423        // True
424        // If a pruning predicate evaluates to true then, as stated in PruningPredicate::evaluate:
425        // > a true value means the chunk can be pruned.
426        // But, the following record lies within the above intervals and *passes* the filter expression! We
427        // cannot prune this zone because we need this record!
428        // {x: True, y: 1, z: 2}
429        // x > (y > z)
430        // True > (1 > 2)
431        // True > False
432        // True
433        let expr = gt_eq(col("x"), gt(col("y"), col("z")));
434        assert!(checked_pruning_expr(&expr, &available_stats).is_none());
435        // TODO(DK): a sufficiently complex pruner would produce: `x_max <= (y_max > z_min)`
436    }
437
438    #[fixture]
439    fn available_stats_with_nans() -> FieldPathSet {
440        let float_col = FieldPath::from_name("float_col");
441        let int_col = FieldPath::from_name("int_col");
442
443        FieldPathSet::from_iter([
444            // Float columns will have a NaNCount.
445            float_col.clone().push(Stat::Min.name()),
446            float_col.clone().push(Stat::Max.name()),
447            float_col.push(Stat::NaNCount.name()),
448            // int columns will not have a NanCount serialized into the layout
449            int_col.clone().push(Stat::Min.name()),
450            int_col.push(Stat::Max.name()),
451        ])
452    }
453
454    #[rstest]
455    fn pruning_checks_nans(available_stats_with_nans: FieldPathSet) {
456        let expr = gt_eq(col("float_col"), lit(f32::NAN));
457        let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap();
458        assert_eq!(
459            &converted,
460            &and(
461                and(
462                    eq(col("float_col_nan_count"), lit(0u64)),
463                    // NaNCount of NaN is 1
464                    eq(lit(1u64), lit(0u64)),
465                ),
466                // This is the standard conversion of the >= operator. Comparing NAN to a max
467                // stat is nonsensical, as min/max stats ignore NaNs, but this should be short-circuited
468                // by the previous check for nan_count anyway.
469                lt(col("float_col_max"), lit(f32::NAN)),
470            )
471        );
472
473        // One half of the expression requires NAN count check, the other half does not.
474        let expr = and(
475            gt(col("float_col"), lit(10f32)),
476            lt(col("int_col"), lit(10)),
477        );
478
479        let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap();
480
481        assert_eq!(
482            &converted,
483            &or(
484                // NaNCount check is enforced for the float column
485                and(
486                    and(
487                        eq(col("float_col_nan_count"), lit(0u64)),
488                        // NanCount of a non-NaN float literal is 0
489                        eq(lit(0u64), lit(0u64)),
490                    ),
491                    // We want the opposite: we can prune IF either one is false.
492                    lt_eq(col("float_col_max"), lit(10f32)),
493                ),
494                // NanCount check is skipped for the int column
495                gt_eq(col("int_col_min"), lit(10)),
496            )
497        )
498    }
499
500    #[rstest]
501    fn pruning_between(available_stats: FieldPathSet) {
502        let expr = between(
503            col("a"),
504            lit(10),
505            lit(50),
506            BetweenOptions {
507                lower_strict: StrictComparison::NonStrict,
508                upper_strict: StrictComparison::NonStrict,
509            },
510        );
511        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
512        assert_eq!(
513            refs.map(),
514            &HashMap::from_iter([(
515                FieldPath::from_name("a"),
516                HashSet::from_iter([Stat::Min, Stat::Max])
517            )])
518        );
519        assert_eq!(
520            &converted,
521            &or(gt(lit(10), col("a_max")), gt(col("a_min"), lit(50)))
522        );
523    }
524
525    #[rstest]
526    fn pruning_cast_get_item_eq(available_stats: FieldPathSet) {
527        // This test verifies that cast properly forwards analysis methods to
528        // enable pruning.
529        let struct_dtype = DType::Struct(
530            StructFields::new(
531                FieldNames::from([FieldName::from("a"), FieldName::from("b")]),
532                vec![
533                    DType::Utf8(Nullability::Nullable),
534                    DType::Utf8(Nullability::Nullable),
535                ],
536            ),
537            Nullability::NonNullable,
538        );
539        let expr = eq(get_item("a", cast(root(), struct_dtype)), lit("value"));
540        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
541        assert_eq!(
542            refs.map(),
543            &HashMap::from_iter([(
544                FieldPath::from_name("a"),
545                HashSet::from_iter([Stat::Min, Stat::Max])
546            )])
547        );
548        assert_eq!(
549            &converted,
550            &or(
551                gt(col("a_min"), lit("value")),
552                gt(lit("value"), col("a_max"))
553            )
554        );
555    }
556}