vortex_expr/pruning/
pruning_expr.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::iter;
5
6use itertools::Itertools;
7use vortex_array::stats::Stat;
8use vortex_dtype::{Field, FieldName, FieldPath, FieldPathSet};
9use vortex_utils::aliases::hash_map::HashMap;
10
11use super::relation::Relation;
12use crate::{ExprRef, StatsCatalog, get_item, root};
13
14pub type RequiredStats = Relation<FieldPath, Stat>;
15
16// A catalog that return a stat column whenever it is required, tracking all accessed
17// stats and returning them later.
18#[derive(Default)]
19struct TrackingStatsCatalog {
20    usage: HashMap<(FieldPath, Stat), ExprRef>,
21}
22
23impl TrackingStatsCatalog {
24    /// Consume the catalog, yielding a map of field statistics that were required
25    /// for each expression.
26    fn into_usages(self) -> HashMap<(FieldPath, Stat), ExprRef> {
27        self.usage
28    }
29}
30
31// A catalog that return a stat column if it exists in the given scope.
32struct ScopeStatsCatalog<'a> {
33    any_catalog: TrackingStatsCatalog,
34    available_stats: &'a FieldPathSet,
35}
36
37impl StatsCatalog for ScopeStatsCatalog<'_> {
38    fn stats_ref(&mut self, field_path: &FieldPath, stat: Stat) -> Option<ExprRef> {
39        let stat_path = field_path.clone().push(stat.name());
40
41        if self.available_stats.contains(&stat_path) {
42            self.any_catalog.stats_ref(field_path, stat)
43        } else {
44            None
45        }
46    }
47}
48
49impl StatsCatalog for TrackingStatsCatalog {
50    fn stats_ref(&mut self, field_path: &FieldPath, stat: Stat) -> Option<ExprRef> {
51        let mut expr = root();
52        let name = field_path_stat_field_name(field_path, stat);
53        expr = get_item(name, expr);
54        self.usage.insert((field_path.clone(), stat), expr.clone());
55        Some(expr)
56    }
57}
58
59#[doc(hidden)]
60pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldName {
61    field_path
62        .parts()
63        .iter()
64        .map(|f| match f {
65            Field::Name(n) => n.as_ref(),
66            Field::ElementType => todo!("element type not currently handled"),
67        })
68        .chain(iter::once(stat.name()))
69        .join("_")
70        .into()
71}
72
73/// Build a pruning expr mask, using an existing set of stats.
74/// The available stats are provided as a set of [`FieldPath`].
75///
76/// A pruning expression is one that returns `true` for all positions where the original expression
77/// cannot hold, and false if it cannot be determined from stats alone whether the positions can
78/// be pruned.
79///
80/// If the falsification logic attempts to access an unknown stat,
81/// this function will return `None`.
82pub fn checked_pruning_expr(
83    expr: &ExprRef,
84    available_stats: &FieldPathSet,
85) -> Option<(ExprRef, RequiredStats)> {
86    let mut catalog = ScopeStatsCatalog {
87        any_catalog: Default::default(),
88        available_stats,
89    };
90
91    let expr = expr.stat_falsification(&mut catalog)?;
92
93    // TODO(joe): filter access by used exprs
94    let mut relation: Relation<FieldPath, Stat> = Relation::new();
95    for ((field_path, stat), _) in catalog.any_catalog.into_usages() {
96        relation.insert(field_path, stat)
97    }
98
99    Some((expr, relation))
100}
101
#[cfg(test)]
mod tests {
    use rstest::{fixture, rstest};
    use vortex_array::stats::Stat;
    use vortex_dtype::{FieldName, FieldPath, FieldPathSet};

    use crate::pruning::pruning_expr::HashMap;
    use crate::pruning::{checked_pruning_expr, field_path_stat_field_name};
    use crate::{HashSet, and, col, eq, get_item, gt, gt_eq, lit, lt, lt_eq, not_eq, or, root};

    /// Stats available to most tests: `min` and `max` for columns `a` and `b`.
    #[fixture]
    fn available_stats() -> FieldPathSet {
        let field_a = FieldPath::from_name("a");
        let field_b = FieldPath::from_name("b");

        FieldPathSet::from_iter([
            field_a.clone().push(Stat::Min.name()),
            field_a.push(Stat::Max.name()),
            field_b.clone().push(Stat::Min.name()),
            field_b.push(Stat::Max.name()),
        ])
    }

    #[rstest]
    pub fn pruning_equals(available_stats: FieldPathSet) {
        // expression   => a == 42
        // pruning expr => a_min > 42 OR 42 > a_max
        let name = FieldName::from("a");
        let literal_eq = lit(42);
        let eq_expr = eq(get_item("a", root()), literal_eq.clone());
        let (converted, _refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap();
        let expected_expr = or(
            gt(
                get_item(
                    field_path_stat_field_name(&FieldPath::from_name(name.clone()), Stat::Min),
                    root(),
                ),
                literal_eq.clone(),
            ),
            gt(
                literal_eq,
                col(field_path_stat_field_name(
                    &FieldPath::from_name(name),
                    Stat::Max,
                )),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_equals_column(available_stats: FieldPathSet) {
        // expression   => a == b
        // pruning expr => a_min > b_max OR b_min > a_max
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let eq_expr = eq(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = or(
            gt(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column.clone()),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col.clone()),
                    Stat::Max,
                )),
            ),
            gt(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column),
                    Stat::Max,
                )),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_not_equals_column(available_stats: FieldPathSet) {
        // expression   => a != b
        // pruning expr => a_min == b_max AND a_max == b_min
        // (only prunable when both columns hold the same single value)
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let not_eq_expr = not_eq(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&not_eq_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = and(
            eq(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column.clone()),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col.clone()),
                    Stat::Max,
                )),
            ),
            eq(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column),
                    Stat::Max,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col),
                    Stat::Min,
                )),
            ),
        );

        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_gt_column(available_stats: FieldPathSet) {
        // expression   => a > b
        // pruning expr => a_max <= b_min
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let gt_expr = gt(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&gt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Min])
                )
            ])
        );
        let expected_expr = lt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Max,
            )),
            col(field_path_stat_field_name(
                &FieldPath::from_name(other_col),
                Stat::Min,
            )),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_gt_value(available_stats: FieldPathSet) {
        // expression   => a > 42
        // pruning expr => a_max <= 42
        let column = FieldName::from("a");
        let literal = lit(42);
        let gt_expr = gt(col(column.clone()), literal.clone());

        let (converted, refs) = checked_pruning_expr(&gt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                FieldPath::from_name(column.clone()),
                HashSet::from_iter([Stat::Max])
            ),])
        );
        let expected_expr = lt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Max,
            )),
            literal,
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_lt_column(available_stats: FieldPathSet) {
        // expression   => a < b
        // pruning expr => a_min >= b_max
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let lt_expr = lt(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&lt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max])
                )
            ])
        );
        let expected_expr = gt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Min,
            )),
            col(field_path_stat_field_name(
                &FieldPath::from_name(other_col),
                Stat::Max,
            )),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_lt_value(available_stats: FieldPathSet) {
        // expression   => a < 42
        // pruning expr => a.min >= 42
        let expr = lt(col("a"), lit(42));

        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from_iter([Stat::Min]))])
        );
        assert_eq!(&converted, &gt_eq(col("a_min"), lit(42)));
    }

    #[rstest]
    fn pruning_identity(available_stats: FieldPathSet) {
        // expression   => a < 10 OR a > 50
        // pruning expr => a_min >= 10 AND a_max <= 50
        let expr = or(lt(col("a"), lit(10)), gt(col("a"), lit(50)));

        let (predicate, _) = checked_pruning_expr(&expr, &available_stats).unwrap();

        let expected_expr = and(gt_eq(col("a_min"), lit(10)), lt_eq(col("a_max"), lit(50)));
        assert_eq!(&predicate.to_string(), &expected_expr.to_string());
    }

    #[rstest]
    pub fn pruning_and_or_operators(available_stats: FieldPathSet) {
        // Test case: a > 10 AND a < 50
        let column = FieldName::from("a");
        let and_expr = and(gt(col(column.clone()), lit(10)), lt(col(column), lit(50)));
        let (predicate, _) = checked_pruning_expr(&and_expr, &available_stats).unwrap();

        // Expected: a_max <= 10 OR a_min >= 50
        assert_eq!(
            &predicate,
            &or(
                lt_eq(col(FieldName::from("a_max")), lit(10)),
                gt_eq(col(FieldName::from("a_min")), lit(50)),
            ),
        );
    }

    #[rstest]
    fn test_gt_eq_with_booleans(available_stats: FieldPathSet) {
        // Consider this unusual, but valid (in Arrow, BooleanArray implements ArrayOrd), filter expression:
        // x >= (y > z)
        // The x column is a Boolean-valued column. The y and z columns are numeric. True > False.
        // Suppose we had a Vortex zone whose min/max statistics for each column were:
        // x: [False, False]
        // y: [1, 2]
        // z: [0, 2]
        // A naive pruner that substitutes min stats into the nested comparison would produce:
        // x_max < (y_min > z_min)
        // If we evaluate that pruning expression on our zone we get:
        // x_max < (y_min > z_min)
        // x_max < (1     > 0    )
        // x_max < True
        // False < True
        // True
        // If a pruning predicate evaluates to true then, as stated in PruningPredicate::evaluate:
        // > a true value means the chunk can be pruned.
        // But, the following record lies within the above intervals and *passes* the filter expression! We
        // cannot prune this zone because we need this record!
        // {x: False, y: 1, z: 2}
        // x >= (y > z)
        // False >= (1 > 2)
        // False >= False
        // True
        //
        // NOTE(review): the `available_stats` fixture has no stats for x, y or z, so `None` may be
        // returned simply because those stats are missing. A fixture that provides them would
        // exercise the nested-comparison case directly.
        let expr = gt_eq(col("x"), gt(col("y"), col("z")));
        assert!(checked_pruning_expr(&expr, &available_stats).is_none());
        // TODO(DK): a sufficiently complex pruner would produce `x_max < (y_min > z_max)`: the
        // smallest value `y > z` can take in the zone is `y_min > z_max`.
    }

    /// Stats for a float column (which has a NaN count) and an int column (which does not).
    #[fixture]
    fn available_stats_with_nans() -> FieldPathSet {
        let float_col = FieldPath::from_name("float_col");
        let int_col = FieldPath::from_name("int_col");

        FieldPathSet::from_iter([
            // Float columns will have a NaNCount.
            float_col.clone().push(Stat::Min.name()),
            float_col.clone().push(Stat::Max.name()),
            float_col.push(Stat::NaNCount.name()),
            // int columns will not have a NanCount serialized into the layout
            int_col.clone().push(Stat::Min.name()),
            int_col.push(Stat::Max.name()),
        ])
    }

    #[rstest]
    fn pruning_checks_nans(available_stats_with_nans: FieldPathSet) {
        let expr = gt_eq(col("float_col"), lit(f32::NAN));
        let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap();
        assert_eq!(
            &converted,
            &and(
                and(
                    eq(col("float_col_nan_count"), lit(0u64)),
                    // NaNCount of NaN is 1
                    eq(lit(1u64), lit(0u64)),
                ),
                // This is the standard conversion of the >= operator. Comparing NAN to a max
                // stat is nonsensical, as min/max stats ignore NaNs, but this should be short-circuited
                // by the previous check for nan_count anyway.
                lt(col("float_col_max"), lit(f32::NAN)),
            )
        );

        // One half of the expression requires NAN count check, the other half does not.
        let expr = and(
            gt(col("float_col"), lit(10f32)),
            lt(col("int_col"), lit(10)),
        );

        let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap();

        assert_eq!(
            &converted,
            &or(
                // NaNCount check is enforced for the float column
                and(
                    and(
                        eq(col("float_col_nan_count"), lit(0u64)),
                        // NanCount of a non-NaN float literal is 0
                        eq(lit(0u64), lit(0u64)),
                    ),
                    // Standard falsification of `>`: every non-NaN value is at most the literal.
                    lt_eq(col("float_col_max"), lit(10f32)),
                ),
                // NanCount check is skipped for the int column
                gt_eq(col("int_col_min"), lit(10)),
            )
        )
    }
}