vortex_expr/pruning/
pruning_predicate.rs

1use std::iter;
2
3use itertools::Itertools;
4use vortex_array::stats::Stat;
5use vortex_array::{Array, ArrayRef};
6use vortex_dtype::{Field, FieldName, FieldPath};
7use vortex_error::{VortexExpect, VortexResult};
8use vortex_utils::aliases::hash_map::HashMap;
9use vortex_utils::aliases::hash_set::HashSet;
10
11use super::relation::Relation;
12use crate::{AccessPath, ExprRef, Scope, StatsCatalog, get_item, var};
13
/// A filter expression rewritten so that it can be evaluated against per-chunk statistics
/// instead of the data itself.
#[derive(Debug, Clone)]
pub struct PruningPredicate {
    /// The rewritten expression, as returned by [`pruning_expr`]. It reads statistics fields
    /// rather than column values.
    expr: ExprRef,
    /// The statistics referenced by `expr`, grouped by the access path they describe.
    required_stats: Relation<AccessPath, Stat>,
}
19
20impl PruningPredicate {
21    pub fn try_new(original_expr: &ExprRef) -> Option<Self> {
22        let (expr, required_stats) = pruning_expr(original_expr)?;
23
24        Some(Self {
25            expr,
26            required_stats,
27        })
28    }
29
30    pub fn expr(&self) -> &ExprRef {
31        &self.expr
32    }
33
34    pub fn required_stats(&self) -> &HashMap<AccessPath, HashSet<Stat>> {
35        self.required_stats.map()
36    }
37
38    /// Evaluate this predicate against a per-chunk statistics table.
39    ///
40    /// Returns Ok(None) if any of the required statistics are not present in metadata.
41    /// If it returns Ok(Some(array)), the array is a boolean array with the same length as the
42    /// metadata, and a true value means the chunk _can_ be pruned.
43    pub fn evaluate(&self, metadata: &Scope) -> VortexResult<Option<ArrayRef>> {
44        // TODO(joe): Replace this with a StatsCatalog that contains all the available stats and
45        // build an expr using them.
46        let known_stats = metadata
47            .iter()
48            .flat_map(|(access_path, stat)| {
49                stat.dtype()
50                    .as_struct()
51                    .vortex_expect("metadata must be struct array")
52                    .names()
53                    .iter()
54                    .map(|n| {
55                        (
56                            AccessPath::new(FieldPath::root(), access_path.clone()),
57                            n.clone(),
58                        )
59                    })
60            })
61            .collect::<HashSet<(AccessPath, FieldName)>>();
62        let required_stats = self
63            .required_stats()
64            .iter()
65            .flat_map(|(path, stats)| stats.iter().map(|s| (path.clone(), s.name().into())))
66            .collect::<HashSet<(AccessPath, FieldName)>>();
67
68        let missing_stats = required_stats.difference(&known_stats).collect::<Vec<_>>();
69
70        if !missing_stats.is_empty() {
71            return Ok(None);
72        }
73
74        self.expr.evaluate(metadata).map(Some)
75    }
76}
77
/// A [`StatsCatalog`] that answers every request and remembers what was asked for.
#[derive(Default)]
struct FileStatsCatalog {
    /// Every (access path, statistic) pair requested through `stats_ref`, mapped to the
    /// expression that was handed back for it.
    usage: HashMap<(AccessPath, Stat), ExprRef>,
}
82
83impl StatsCatalog for FileStatsCatalog {
84    fn stats_ref(&mut self, access_path: &AccessPath, stat: Stat) -> Option<ExprRef> {
85        let mut expr = var(access_path.identifier().clone());
86        let name = access_path_stat_field_name(access_path, stat);
87        expr = get_item(name, expr);
88        self.usage.insert((access_path.clone(), stat), expr.clone());
89        Some(expr)
90    }
91}
92
93pub fn access_path_stat_field_name(access_path: &AccessPath, stat: Stat) -> FieldName {
94    access_path
95        .field_path
96        .path()
97        .iter()
98        .map(|f| match f {
99            Field::Name(n) => n.as_ref(),
100            Field::ElementType => todo!("element type not currently handled"),
101        })
102        .chain(iter::once(stat.name()))
103        .join("_")
104        .into()
105}
106
107#[allow(clippy::type_complexity)]
108// TODO: remove (Id, FieldPath) when updating FieldPath
109pub fn pruning_expr(expr: &ExprRef) -> Option<(ExprRef, Relation<AccessPath, Stat>)> {
110    let mut catalog = FileStatsCatalog {
111        ..Default::default()
112    };
113    let expr = expr.stat_falsification(&mut catalog)?;
114
115    let mut relation: Relation<AccessPath, Stat> = Relation::new();
116    for ((field_path, stat), _) in catalog.usage.into_iter() {
117        relation.insert(field_path, stat)
118    }
119
120    Some((expr, relation))
121}
122
#[cfg(test)]
mod tests {
    use vortex_array::stats::Stat;
    use vortex_dtype::FieldName;

    use crate::pruning::pruning_predicate::{HashMap, pruning_expr};
    use crate::pruning::{PruningPredicate, access_path_stat_field_name};
    use crate::{
        AccessPath, HashSet, and, col, eq, get_item, get_item_scope, gt, gt_eq, lit, lt, lt_eq,
        not_eq, or, root,
    };

    /// Expression reading `stat` of the root-level column `column` from the stats scope.
    fn stat_of(column: &FieldName, stat: Stat) -> crate::ExprRef {
        get_item_scope(access_path_stat_field_name(
            &AccessPath::root_field(column.clone()),
            stat,
        ))
    }

    #[test]
    pub fn pruning_equals() {
        let a = FieldName::from("a");
        let forty_two = lit(42);
        let filter = eq(get_item("a", root()), forty_two.clone());

        let (falsified, _required) = pruning_expr(&filter).unwrap();

        // `a == 42` cannot hold when `a_min > 42 OR 42 > a_max`.
        let min_above = gt(
            get_item(
                access_path_stat_field_name(&AccessPath::root_field(a.clone()), Stat::Min),
                root(),
            ),
            forty_two.clone(),
        );
        let max_below = gt(forty_two, stat_of(&a, Stat::Max));
        assert_eq!(&falsified, &or(min_above, max_below));
    }

    #[test]
    pub fn pruning_equals_column() {
        let a = FieldName::from("a");
        let b = FieldName::from("b");
        let filter = eq(get_item_scope(a.clone()), get_item_scope(b.clone()));

        let (falsified, required) = pruning_expr(&filter).unwrap();

        let expected_stats: HashMap<AccessPath, HashSet<Stat>> = HashMap::from_iter([
            (
                AccessPath::root_field(a.clone()),
                HashSet::from_iter([Stat::Min, Stat::Max]),
            ),
            (
                AccessPath::root_field(b.clone()),
                HashSet::from_iter([Stat::Max, Stat::Min]),
            ),
        ]);
        assert_eq!(required.map(), &expected_stats);

        // `a == b` cannot hold when `a_min > b_max OR b_min > a_max`.
        let expected_expr = or(
            gt(stat_of(&a, Stat::Min), stat_of(&b, Stat::Max)),
            gt(stat_of(&b, Stat::Min), stat_of(&a, Stat::Max)),
        );
        assert_eq!(&falsified, &expected_expr);
    }

    #[test]
    pub fn pruning_not_equals_column() {
        let a = FieldName::from("a");
        let b = FieldName::from("b");
        let filter = not_eq(get_item_scope(a.clone()), get_item_scope(b.clone()));

        let (falsified, required) = pruning_expr(&filter).unwrap();

        let expected_stats: HashMap<AccessPath, HashSet<Stat>> = HashMap::from_iter([
            (
                AccessPath::root_field(a.clone()),
                HashSet::from_iter([Stat::Min, Stat::Max]),
            ),
            (
                AccessPath::root_field(b.clone()),
                HashSet::from_iter([Stat::Max, Stat::Min]),
            ),
        ]);
        assert_eq!(required.map(), &expected_stats);

        // `a != b` cannot hold when `a_min == b_max AND a_max == b_min`.
        let expected_expr = and(
            eq(stat_of(&a, Stat::Min), stat_of(&b, Stat::Max)),
            eq(stat_of(&a, Stat::Max), stat_of(&b, Stat::Min)),
        );
        assert_eq!(&falsified, &expected_expr);
    }

    #[test]
    pub fn pruning_gt_column() {
        let a = FieldName::from("a");
        let b = FieldName::from("b");
        let filter = gt(get_item_scope(a.clone()), get_item_scope(b.clone()));

        let (falsified, required) = pruning_expr(&filter).unwrap();

        let expected_stats: HashMap<AccessPath, HashSet<Stat>> = HashMap::from_iter([
            (
                AccessPath::root_field(a.clone()),
                HashSet::from_iter([Stat::Max]),
            ),
            (
                AccessPath::root_field(b.clone()),
                HashSet::from_iter([Stat::Min]),
            ),
        ]);
        assert_eq!(required.map(), &expected_stats);

        // `a > b` cannot hold when `a_max <= b_min`.
        let expected_expr = lt_eq(stat_of(&a, Stat::Max), stat_of(&b, Stat::Min));
        assert_eq!(&falsified, &expected_expr);
    }

    #[test]
    pub fn pruning_gt_value() {
        let a = FieldName::from("a");
        let threshold = lit(42);
        let filter = gt(get_item_scope(a.clone()), threshold.clone());

        let (falsified, required) = pruning_expr(&filter).unwrap();

        let expected_stats: HashMap<AccessPath, HashSet<Stat>> = HashMap::from_iter([(
            AccessPath::root_field(a.clone()),
            HashSet::from_iter([Stat::Max]),
        )]);
        assert_eq!(required.map(), &expected_stats);

        // `a > 42` cannot hold when `a_max <= 42`.
        let expected_expr = lt_eq(stat_of(&a, Stat::Max), threshold);
        assert_eq!(&falsified, &expected_expr);
    }

    #[test]
    pub fn pruning_lt_column() {
        let a = FieldName::from("a");
        let b = FieldName::from("b");
        let filter = lt(get_item_scope(a.clone()), get_item_scope(b.clone()));

        let (falsified, required) = pruning_expr(&filter).unwrap();

        let expected_stats: HashMap<AccessPath, HashSet<Stat>> = HashMap::from_iter([
            (
                AccessPath::root_field(a.clone()),
                HashSet::from_iter([Stat::Min]),
            ),
            (
                AccessPath::root_field(b.clone()),
                HashSet::from_iter([Stat::Max]),
            ),
        ]);
        assert_eq!(required.map(), &expected_stats);

        // `a < b` cannot hold when `a_min >= b_max`.
        let expected_expr = gt_eq(stat_of(&a, Stat::Min), stat_of(&b, Stat::Max));
        assert_eq!(&falsified, &expected_expr);
    }

    #[test]
    pub fn pruning_lt_value() {
        let a = FieldName::from("a");
        let threshold = lit(42);
        let filter = lt(get_item_scope(a.clone()), threshold.clone());

        let (falsified, required) = pruning_expr(&filter).unwrap();

        let expected_stats: HashMap<AccessPath, HashSet<Stat>> = HashMap::from_iter([(
            AccessPath::root_field(a.clone()),
            HashSet::from_iter([Stat::Min]),
        )]);
        assert_eq!(required.map(), &expected_stats);

        // `a < 42` cannot hold when `a_min >= 42`.
        let expected_expr = gt_eq(stat_of(&a, Stat::Min), threshold);
        assert_eq!(&falsified, &expected_expr);
    }

    #[test]
    fn pruning_identity() {
        // The filter is applied to the root itself, so the stats carry no column prefix.
        let filter = or(lt(root(), lit(10)), gt(root(), lit(50)));

        let (falsified, _) = pruning_expr(&filter).unwrap();

        let min_at_least_10 = gt_eq(get_item_scope(FieldName::from("min")), lit(10));
        let max_at_most_50 = lt_eq(get_item_scope(FieldName::from("max")), lit(50));
        assert_eq!(&falsified, &and(min_at_least_10, max_at_most_50))
    }

    #[test]
    pub fn pruning_and_or_operators() {
        // Test case: a > 10 AND a < 50
        let a = FieldName::from("a");
        let filter = and(
            gt(get_item_scope(a.clone()), lit(10)),
            lt(get_item_scope(a), lit(50)),
        );

        let (falsified, _) = pruning_expr(&filter).unwrap();

        // Expected: a_max <= 10 OR a_min >= 50
        let expected_expr = or(
            lt_eq(get_item_scope(FieldName::from("a_max")), lit(10)),
            gt_eq(get_item_scope(FieldName::from("a_min")), lit(50)),
        );
        assert_eq!(&falsified, &expected_expr);
    }

    #[test]
    fn test_gt_eq_with_booleans() {
        // Booleans are ordered (True > False; in Arrow, BooleanArray implements ArrayOrd), so
        // this unusual filter is valid:
        //
        //     x > (y > z)
        //
        // where x is a Boolean-valued column and y, z are numeric. (This walkthrough uses a
        // strict outer `>`; the expression built below uses `>=` for the outer comparison.)
        //
        // Take a Vortex zone whose [min, max] statistics are:
        //
        //     x: [True, True]
        //     y: [1, 2]
        //     z: [0, 2]
        //
        // A pruning predicate that converted the filter into
        //
        //     x_max <= (y_min > z_min)
        //
        // would evaluate on this zone as:
        //
        //     x_max <= (1 > 0)
        //     x_max <= True
        //     True  <= True
        //     True
        //
        // and, per PruningPredicate::evaluate, a true value means the chunk can be pruned.
        // Yet this record lies within the intervals above and *passes* the filter:
        //
        //     {x: True, y: 1, z: 2}
        //
        //     x > (y > z)
        //     True > (1 > 2)
        //     True > False
        //     True
        //
        // Pruning the zone would lose a record we need, so no predicate may be produced.
        let filter = gt_eq(col("x"), gt(col("y"), col("z")));
        assert!(PruningPredicate::try_new(&filter).is_none());
        // TODO(DK): a sufficiently complex pruner would produce: `x_max <= (y_max > z_min)`
    }
}
447}