vortex_expr/pruning/
pruning_expr.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::iter;
5
6use itertools::Itertools;
7use vortex_array::stats::Stat;
8use vortex_dtype::{Field, FieldName, FieldPath, FieldPathSet};
9use vortex_utils::aliases::hash_map::HashMap;
10
11use super::relation::Relation;
12use crate::{ExprRef, StatsCatalog, get_item, root};
13
/// The statistics read by a pruning expression: a relation from each [`FieldPath`] to the
/// [`Stat`]s that were looked up for it.
pub type RequiredStats = Relation<FieldPath, Stat>;
15
/// A catalog that returns a stat column whenever one is requested, recording every accessed
/// stat so the full set can be retrieved later via `into_usages`.
#[derive(Default)]
struct TrackingStatsCatalog {
    /// Each requested `(field path, stat)` pair, mapped to the expression returned for it.
    usage: HashMap<(FieldPath, Stat), ExprRef>,
}
22
23impl TrackingStatsCatalog {
24    /// Consume the catalog, yielding a map of field statistics that were required
25    /// for each expression.
26    fn into_usages(self) -> HashMap<(FieldPath, Stat), ExprRef> {
27        self.usage
28    }
29}
30
/// A catalog that returns a stat column only if that stat exists in the given scope
/// (`available_stats`); resolution and usage tracking are delegated to `any_catalog`.
struct ScopeStatsCatalog<'a> {
    /// Inner catalog that builds the stat expression and records each stat handed out.
    any_catalog: TrackingStatsCatalog,
    /// The stats in scope. Each entry is a field path with the stat's name appended as its
    /// final segment.
    available_stats: &'a FieldPathSet,
}
36
37impl StatsCatalog for ScopeStatsCatalog<'_> {
38    fn stats_ref(&mut self, field_path: &FieldPath, stat: Stat) -> Option<ExprRef> {
39        let stat_path = field_path.clone().push(stat.name());
40
41        if self.available_stats.contains(&stat_path) {
42            self.any_catalog.stats_ref(field_path, stat)
43        } else {
44            None
45        }
46    }
47}
48
49impl StatsCatalog for TrackingStatsCatalog {
50    fn stats_ref(&mut self, field_path: &FieldPath, stat: Stat) -> Option<ExprRef> {
51        let mut expr = root();
52        let name = field_path_stat_field_name(field_path, stat);
53        expr = get_item(name, expr);
54        self.usage.insert((field_path.clone(), stat), expr.clone());
55        Some(expr)
56    }
57}
58
59#[doc(hidden)]
60pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldName {
61    field_path
62        .parts()
63        .iter()
64        .map(|f| match f {
65            Field::Name(n) => n.as_ref(),
66            Field::ElementType => todo!("element type not currently handled"),
67        })
68        .chain(iter::once(stat.name()))
69        .join("_")
70        .into()
71}
72
73/// Build a pruning expr mask, using an existing set of stats.
74/// The available stats are provided as a set of [`FieldPath`].
75///
76/// A pruning expression is one that returns `true` for all positions where the original expression
77/// cannot hold, and false if it cannot be determined from stats alone whether the positions can
78/// be pruned.
79///
80/// If the falsification logic attempts to access an unknown stat,
81/// this function will return `None`.
82pub fn checked_pruning_expr(
83    expr: &ExprRef,
84    available_stats: &FieldPathSet,
85) -> Option<(ExprRef, RequiredStats)> {
86    let mut catalog = ScopeStatsCatalog {
87        any_catalog: Default::default(),
88        available_stats,
89    };
90
91    let expr = expr.stat_falsification(&mut catalog)?;
92
93    // TODO(joe): filter access by used exprs
94    let mut relation: Relation<FieldPath, Stat> = Relation::new();
95    for ((field_path, stat), _) in catalog.any_catalog.into_usages() {
96        relation.insert(field_path, stat)
97    }
98
99    Some((expr, relation))
100}
101
#[cfg(test)]
mod tests {
    use rstest::{fixture, rstest};
    use vortex_array::compute::{BetweenOptions, StrictComparison};
    use vortex_array::stats::Stat;
    use vortex_dtype::{FieldName, FieldPath, FieldPathSet};

    use crate::pruning::pruning_expr::HashMap;
    use crate::pruning::{checked_pruning_expr, field_path_stat_field_name};
    use crate::{
        HashSet, and, between, col, eq, get_item, gt, gt_eq, lit, lt, lt_eq, not_eq, or, root,
    };

    /// Stats available to most tests: `min` and `max` for columns `a` and `b` only.
    #[fixture]
    fn available_stats() -> FieldPathSet {
        let field_a = FieldPath::from_name("a");
        let field_b = FieldPath::from_name("b");

        FieldPathSet::from_iter([
            field_a.clone().push(Stat::Min.name()),
            field_a.push(Stat::Max.name()),
            field_b.clone().push(Stat::Min.name()),
            field_b.push(Stat::Max.name()),
        ])
    }

    #[rstest]
    pub fn pruning_equals(available_stats: FieldPathSet) {
        // expression   => a = 42
        // pruning expr => a_min > 42 OR 42 > a_max
        let name = FieldName::from("a");
        let literal_eq = lit(42);
        let eq_expr = eq(get_item("a", root()), literal_eq.clone());
        let (converted, _refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap();
        let expected_expr = or(
            gt(
                get_item(
                    field_path_stat_field_name(&FieldPath::from_name(name.clone()), Stat::Min),
                    root(),
                ),
                literal_eq.clone(),
            ),
            gt(
                literal_eq,
                col(field_path_stat_field_name(
                    &FieldPath::from_name(name),
                    Stat::Max,
                )),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_equals_column(available_stats: FieldPathSet) {
        // expression   => a = b
        // pruning expr => a_min > b_max OR b_min > a_max
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let eq_expr = eq(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = or(
            gt(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column.clone()),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col.clone()),
                    Stat::Max,
                )),
            ),
            gt(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column),
                    Stat::Max,
                )),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_not_equals_column(available_stats: FieldPathSet) {
        // expression   => a != b
        // pruning expr => a_min = b_max AND a_max = b_min
        // (i.e. both columns hold one and the same value throughout the zone)
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let not_eq_expr = not_eq(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&not_eq_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = and(
            eq(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column.clone()),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col.clone()),
                    Stat::Max,
                )),
            ),
            eq(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column),
                    Stat::Max,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col),
                    Stat::Min,
                )),
            ),
        );

        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_gt_column(available_stats: FieldPathSet) {
        // expression   => a > b
        // pruning expr => a_max <= b_min
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let gt_expr = gt(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&gt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Min])
                )
            ])
        );
        let expected_expr = lt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Max,
            )),
            col(field_path_stat_field_name(
                &FieldPath::from_name(other_col),
                Stat::Min,
            )),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_gt_value(available_stats: FieldPathSet) {
        // expression   => a > 42
        // pruning expr => a_max <= 42
        let column = FieldName::from("a");
        let literal = lit(42);
        let gt_expr = gt(col(column.clone()), literal.clone());

        let (converted, refs) = checked_pruning_expr(&gt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                FieldPath::from_name(column.clone()),
                HashSet::from_iter([Stat::Max])
            ),])
        );
        let expected_expr = lt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Max,
            )),
            literal,
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_lt_column(available_stats: FieldPathSet) {
        // expression   => a < b
        // pruning expr => a_min >= b_max
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let lt_expr = lt(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&lt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max])
                )
            ])
        );
        let expected_expr = gt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Min,
            )),
            col(field_path_stat_field_name(
                &FieldPath::from_name(other_col),
                Stat::Max,
            )),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_lt_value(available_stats: FieldPathSet) {
        // expression   => a < 42
        // pruning expr => a.min >= 42
        let expr = lt(col("a"), lit(42));

        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from_iter([Stat::Min]))])
        );
        assert_eq!(&converted, &gt_eq(col("a_min"), lit(42)));
    }

    #[rstest]
    fn pruning_identity(available_stats: FieldPathSet) {
        // expression   => a < 10 OR a > 50
        // pruning expr => a_min >= 10 AND a_max <= 50
        let expr = or(lt(col("a"), lit(10)), gt(col("a"), lit(50)));

        let (predicate, _) = checked_pruning_expr(&expr, &available_stats).unwrap();

        let expected_expr = and(gt_eq(col("a_min"), lit(10)), lt_eq(col("a_max"), lit(50)));
        assert_eq!(&predicate.to_string(), &expected_expr.to_string());
    }

    #[rstest]
    pub fn pruning_and_or_operators(available_stats: FieldPathSet) {
        // Test case: a > 10 AND a < 50
        let column = FieldName::from("a");
        let and_expr = and(gt(col(column.clone()), lit(10)), lt(col(column), lit(50)));
        let (predicate, _) = checked_pruning_expr(&and_expr, &available_stats).unwrap();

        // Expected: a_max <= 10 OR a_min >= 50
        assert_eq!(
            &predicate,
            &or(
                lt_eq(col(FieldName::from("a_max")), lit(10)),
                gt_eq(col(FieldName::from("a_min")), lit(50)),
            ),
        );
    }

    #[rstest]
    fn test_gt_eq_with_booleans(available_stats: FieldPathSet) {
        // Consider this unusual, but valid (in Arrow, BooleanArray implements ArrayOrd), filter
        // expression:
        //   x >= (y > z)
        // The x column is a Boolean-valued column. The y and z columns are numeric. True > False.
        // Suppose we had a Vortex zone whose min/max statistics for each column were:
        //   x: [False, False]
        //   y: [1, 2]
        //   z: [0, 2]
        // A naive pruner that substitutes min stats into the nested comparison would produce:
        //   x_max < (y_min > z_min)
        // Evaluating that pruning expression on our zone gives:
        //   x_max < (y_min > z_min)
        //   False < (1     > 0    )
        //   False < True
        //   True
        // A true pruning value means the zone can be pruned. But the following record lies
        // within the above intervals and *passes* the filter expression, so we cannot prune
        // this zone because we need this record:
        //   {x: False, y: 1, z: 2}
        //   x     >= (y > z)
        //   False >= (1 > 2)
        //   False >= False
        //   True
        //
        // NOTE(review): the `available_stats` fixture only provides stats for `a` and `b`, so
        // this presumably returns `None` because no stats for `x` are available, independently
        // of the reasoning above. Providing stats for x/y/z would exercise the intended path.
        let expr = gt_eq(col("x"), gt(col("y"), col("z")));
        assert!(checked_pruning_expr(&expr, &available_stats).is_none());
        // TODO(DK): a sufficiently complex pruner could produce the sound bound
        // `x_max < (y_min > z_max)`, since `y > z` holds for every row only when `y_min > z_max`.
    }

    /// Stats for a float column (with a NaN count) and an int column (without one).
    #[fixture]
    fn available_stats_with_nans() -> FieldPathSet {
        let float_col = FieldPath::from_name("float_col");
        let int_col = FieldPath::from_name("int_col");

        FieldPathSet::from_iter([
            // Float columns will have a NaNCount.
            float_col.clone().push(Stat::Min.name()),
            float_col.clone().push(Stat::Max.name()),
            float_col.push(Stat::NaNCount.name()),
            // int columns will not have a NanCount serialized into the layout
            int_col.clone().push(Stat::Min.name()),
            int_col.push(Stat::Max.name()),
        ])
    }

    #[rstest]
    fn pruning_checks_nans(available_stats_with_nans: FieldPathSet) {
        let expr = gt_eq(col("float_col"), lit(f32::NAN));
        let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap();
        assert_eq!(
            &converted,
            &and(
                and(
                    eq(col("float_col_nan_count"), lit(0u64)),
                    // NaNCount of NaN is 1
                    eq(lit(1u64), lit(0u64)),
                ),
                // This is the standard conversion of the >= operator. Comparing NAN to a max
                // stat is nonsensical, as min/max stats ignore NaNs, but this should be short-circuited
                // by the previous check for nan_count anyway.
                lt(col("float_col_max"), lit(f32::NAN)),
            )
        );

        // One half of the expression requires NAN count check, the other half does not.
        let expr = and(
            gt(col("float_col"), lit(10f32)),
            lt(col("int_col"), lit(10)),
        );

        let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap();

        assert_eq!(
            &converted,
            // AND in the filter becomes OR in the pruning expression: either side alone being
            // unsatisfiable is enough to prune.
            &or(
                // NaNCount check is enforced for the float column
                and(
                    and(
                        eq(col("float_col_nan_count"), lit(0u64)),
                        // NanCount of a non-NaN float literal is 0
                        eq(lit(0u64), lit(0u64)),
                    ),
                    // Standard conversion of `float_col > 10`; it only permits pruning when the
                    // NaN checks above also hold.
                    lt_eq(col("float_col_max"), lit(10f32)),
                ),
                // NanCount check is skipped for the int column
                gt_eq(col("int_col_min"), lit(10)),
            )
        )
    }

    #[rstest]
    fn pruning_between(available_stats: FieldPathSet) {
        // expression   => 10 <= a <= 50
        // pruning expr => 10 > a_max OR a_min > 50
        let expr = between(
            col("a"),
            lit(10),
            lit(50),
            BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
        );
        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                FieldPath::from_name("a"),
                HashSet::from_iter([Stat::Min, Stat::Max])
            )])
        );
        assert_eq!(
            &converted,
            &or(gt(lit(10), col("a_max")), gt(col("a_min"), lit(50)))
        );
    }
}