1use std::iter;
5
6use itertools::Itertools;
7use vortex_dtype::{Field, FieldName, FieldPath, FieldPathSet};
8use vortex_utils::aliases::hash_map::HashMap;
9
10use super::relation::Relation;
11use crate::expr::exprs::get_item::get_item;
12use crate::expr::exprs::root::root;
13use crate::expr::{Expression, StatsCatalog};
14use crate::stats::Stat;
15
/// Stats that a pruning expression depends on: a relation from each referenced field path to
/// the [`Stat`]s requested for it.
pub type RequiredStats = Relation<FieldPath, Stat>;
17
/// A [`StatsCatalog`] that produces a reference expression for every requested stat and
/// remembers which `(field path, stat)` pairs were asked for.
#[derive(Default)]
struct TrackingStatsCatalog {
    /// Reference expression produced for each `(field path, stat)` pair requested so far.
    usage: HashMap<(FieldPath, Stat), Expression>,
}
24
25impl TrackingStatsCatalog {
26 fn into_usages(self) -> HashMap<(FieldPath, Stat), Expression> {
29 self.usage
30 }
31}
32
/// A [`StatsCatalog`] that only resolves stats listed in `available_stats` and delegates the
/// requests it accepts to an inner `TrackingStatsCatalog`.
struct ScopeStatsCatalog<'a> {
    /// Inner catalog that builds the stat references and records which ones were used.
    any_catalog: TrackingStatsCatalog,
    /// Stats that may be referenced; each entry is looked up as a field path with the stat's
    /// name pushed as the final segment.
    available_stats: &'a FieldPathSet,
}
38
39impl StatsCatalog for ScopeStatsCatalog<'_> {
40 fn stats_ref(&mut self, field_path: &FieldPath, stat: Stat) -> Option<Expression> {
41 let stat_path = field_path.clone().push(stat.name());
42
43 if self.available_stats.contains(&stat_path) {
44 self.any_catalog.stats_ref(field_path, stat)
45 } else {
46 None
47 }
48 }
49}
50
51impl StatsCatalog for TrackingStatsCatalog {
52 fn stats_ref(&mut self, field_path: &FieldPath, stat: Stat) -> Option<Expression> {
53 let mut expr = root();
54 let name = field_path_stat_field_name(field_path, stat);
55 expr = get_item(name, expr);
56 self.usage.insert((field_path.clone(), stat), expr.clone());
57 Some(expr)
58 }
59}
60
61#[doc(hidden)]
62pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldName {
63 field_path
64 .parts()
65 .iter()
66 .map(|f| match f {
67 Field::Name(n) => n.as_ref(),
68 Field::ElementType => todo!("element type not currently handled"),
69 })
70 .chain(iter::once(stat.name()))
71 .join("_")
72 .into()
73}
74
75pub fn checked_pruning_expr(
85 expr: &Expression,
86 available_stats: &FieldPathSet,
87) -> Option<(Expression, RequiredStats)> {
88 let mut catalog = ScopeStatsCatalog {
89 any_catalog: Default::default(),
90 available_stats,
91 };
92
93 let expr = expr.stat_falsification(&mut catalog)?;
94
95 let mut relation: Relation<FieldPath, Stat> = Relation::new();
97 for ((field_path, stat), _) in catalog.any_catalog.into_usages() {
98 relation.insert(field_path, stat)
99 }
100
101 Some((expr, relation))
102}
103
104#[cfg(test)]
105mod tests {
106 use rstest::{fixture, rstest};
107 use vortex_dtype::{
108 DType, FieldName, FieldNames, FieldPath, FieldPathSet, Nullability, StructFields,
109 };
110 use vortex_utils::aliases::hash_set::HashSet;
111
112 use super::HashMap;
113 use crate::compute::{BetweenOptions, StrictComparison};
114 use crate::expr::exprs::between::between;
115 use crate::expr::exprs::binary::{and, eq, gt, gt_eq, lt, lt_eq, not_eq, or};
116 use crate::expr::exprs::cast::cast;
117 use crate::expr::exprs::get_item::{col, get_item};
118 use crate::expr::exprs::literal::lit;
119 use crate::expr::exprs::root::root;
120 use crate::expr::pruning::{checked_pruning_expr, field_path_stat_field_name};
121 use crate::stats::Stat;
122
123 #[fixture]
125 fn available_stats() -> FieldPathSet {
126 let field_a = FieldPath::from_name("a");
127 let field_b = FieldPath::from_name("b");
128
129 FieldPathSet::from_iter([
130 field_a.clone().push(Stat::Min.name()),
131 field_a.push(Stat::Max.name()),
132 field_b.clone().push(Stat::Min.name()),
133 field_b.push(Stat::Max.name()),
134 ])
135 }
136
137 #[rstest]
138 pub fn pruning_equals(available_stats: FieldPathSet) {
139 let name = FieldName::from("a");
140 let literal_eq = lit(42);
141 let eq_expr = eq(get_item("a", root()), literal_eq.clone());
142 let (converted, _refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap();
143 let expected_expr = or(
144 gt(
145 get_item(
146 field_path_stat_field_name(&FieldPath::from_name(name.clone()), Stat::Min),
147 root(),
148 ),
149 literal_eq.clone(),
150 ),
151 gt(
152 literal_eq,
153 col(field_path_stat_field_name(
154 &FieldPath::from_name(name),
155 Stat::Max,
156 )),
157 ),
158 );
159 assert_eq!(&converted, &expected_expr);
160 }
161
162 #[rstest]
163 pub fn pruning_equals_column(available_stats: FieldPathSet) {
164 let column = FieldName::from("a");
165 let other_col = FieldName::from("b");
166 let eq_expr = eq(col(column.clone()), col(other_col.clone()));
167
168 let (converted, refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap();
169 assert_eq!(
170 refs.map(),
171 &HashMap::from_iter([
172 (
173 FieldPath::from_name(column.clone()),
174 HashSet::from_iter([Stat::Min, Stat::Max])
175 ),
176 (
177 FieldPath::from_name(other_col.clone()),
178 HashSet::from_iter([Stat::Max, Stat::Min])
179 )
180 ])
181 );
182 let expected_expr = or(
183 gt(
184 col(field_path_stat_field_name(
185 &FieldPath::from_name(column.clone()),
186 Stat::Min,
187 )),
188 col(field_path_stat_field_name(
189 &FieldPath::from_name(other_col.clone()),
190 Stat::Max,
191 )),
192 ),
193 gt(
194 col(field_path_stat_field_name(
195 &FieldPath::from_name(other_col),
196 Stat::Min,
197 )),
198 col(field_path_stat_field_name(
199 &FieldPath::from_name(column),
200 Stat::Max,
201 )),
202 ),
203 );
204 assert_eq!(&converted, &expected_expr);
205 }
206
207 #[rstest]
208 pub fn pruning_not_equals_column(available_stats: FieldPathSet) {
209 let column = FieldName::from("a");
210 let other_col = FieldName::from("b");
211 let not_eq_expr = not_eq(col(column.clone()), col(other_col.clone()));
212
213 let (converted, refs) = checked_pruning_expr(¬_eq_expr, &available_stats).unwrap();
214 assert_eq!(
215 refs.map(),
216 &HashMap::from_iter([
217 (
218 FieldPath::from_name(column.clone()),
219 HashSet::from_iter([Stat::Min, Stat::Max])
220 ),
221 (
222 FieldPath::from_name(other_col.clone()),
223 HashSet::from_iter([Stat::Max, Stat::Min])
224 )
225 ])
226 );
227 let expected_expr = and(
228 eq(
229 col(field_path_stat_field_name(
230 &FieldPath::from_name(column.clone()),
231 Stat::Min,
232 )),
233 col(field_path_stat_field_name(
234 &FieldPath::from_name(other_col.clone()),
235 Stat::Max,
236 )),
237 ),
238 eq(
239 col(field_path_stat_field_name(
240 &FieldPath::from_name(column),
241 Stat::Max,
242 )),
243 col(field_path_stat_field_name(
244 &FieldPath::from_name(other_col),
245 Stat::Min,
246 )),
247 ),
248 );
249
250 assert_eq!(&converted, &expected_expr);
251 }
252
253 #[rstest]
254 pub fn pruning_gt_column(available_stats: FieldPathSet) {
255 let column = FieldName::from("a");
256 let other_col = FieldName::from("b");
257 let other_expr = col(other_col.clone());
258 let not_eq_expr = gt(col(column.clone()), other_expr);
259
260 let (converted, refs) = checked_pruning_expr(¬_eq_expr, &available_stats).unwrap();
261 assert_eq!(
262 refs.map(),
263 &HashMap::from_iter([
264 (
265 FieldPath::from_name(column.clone()),
266 HashSet::from_iter([Stat::Max])
267 ),
268 (
269 FieldPath::from_name(other_col.clone()),
270 HashSet::from_iter([Stat::Min])
271 )
272 ])
273 );
274 let expected_expr = lt_eq(
275 col(field_path_stat_field_name(
276 &FieldPath::from_name(column),
277 Stat::Max,
278 )),
279 col(field_path_stat_field_name(
280 &FieldPath::from_name(other_col),
281 Stat::Min,
282 )),
283 );
284 assert_eq!(&converted, &expected_expr);
285 }
286
287 #[rstest]
288 pub fn pruning_gt_value(available_stats: FieldPathSet) {
289 let column = FieldName::from("a");
290 let other_col = lit(42);
291 let not_eq_expr = gt(col(column.clone()), other_col.clone());
292
293 let (converted, refs) = checked_pruning_expr(¬_eq_expr, &available_stats).unwrap();
294 assert_eq!(
295 refs.map(),
296 &HashMap::from_iter([(
297 FieldPath::from_name(column.clone()),
298 HashSet::from_iter([Stat::Max])
299 ),])
300 );
301 let expected_expr = lt_eq(
302 col(field_path_stat_field_name(
303 &FieldPath::from_name(column),
304 Stat::Max,
305 )),
306 other_col,
307 );
308 assert_eq!(&converted, &(expected_expr));
309 }
310
311 #[rstest]
312 pub fn pruning_lt_column(available_stats: FieldPathSet) {
313 let column = FieldName::from("a");
314 let other_col = FieldName::from("b");
315 let other_expr = col(other_col.clone());
316 let not_eq_expr = lt(col(column.clone()), other_expr);
317
318 let (converted, refs) = checked_pruning_expr(¬_eq_expr, &available_stats).unwrap();
319 assert_eq!(
320 refs.map(),
321 &HashMap::from_iter([
322 (
323 FieldPath::from_name(column.clone()),
324 HashSet::from_iter([Stat::Min])
325 ),
326 (
327 FieldPath::from_name(other_col.clone()),
328 HashSet::from_iter([Stat::Max])
329 )
330 ])
331 );
332 let expected_expr = gt_eq(
333 col(field_path_stat_field_name(
334 &FieldPath::from_name(column),
335 Stat::Min,
336 )),
337 col(field_path_stat_field_name(
338 &FieldPath::from_name(other_col),
339 Stat::Max,
340 )),
341 );
342 assert_eq!(&converted, &expected_expr);
343 }
344
345 #[rstest]
346 pub fn pruning_lt_value(available_stats: FieldPathSet) {
347 let expr = lt(col("a"), lit(42));
350
351 let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
352 assert_eq!(
353 refs.map(),
354 &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from_iter([Stat::Min]))])
355 );
356 assert_eq!(&converted, >_eq(col("a_min"), lit(42)));
357 }
358
359 #[rstest]
360 fn pruning_identity(available_stats: FieldPathSet) {
361 let expr = or(lt(col("a"), lit(10)), gt(col("a"), lit(50)));
362
363 let (predicate, _) = checked_pruning_expr(&expr, &available_stats).unwrap();
364
365 let expected_expr = and(gt_eq(col("a_min"), lit(10)), lt_eq(col("a_max"), lit(50)));
366 assert_eq!(&predicate.to_string(), &expected_expr.to_string());
367 }
368 #[rstest]
369 pub fn pruning_and_or_operators(available_stats: FieldPathSet) {
370 let column = FieldName::from("a");
372 let and_expr = and(gt(col(column.clone()), lit(10)), lt(col(column), lit(50)));
373 let (predicate, _) = checked_pruning_expr(&and_expr, &available_stats).unwrap();
374
375 assert_eq!(
377 &predicate,
378 &or(
379 lt_eq(col(FieldName::from("a_max")), lit(10)),
380 gt_eq(col(FieldName::from("a_min")), lit(50)),
381 ),
382 );
383 }
384
385 #[rstest]
386 fn test_gt_eq_with_booleans(available_stats: FieldPathSet) {
387 let expr = gt_eq(col("x"), gt(col("y"), col("z")));
412 assert!(checked_pruning_expr(&expr, &available_stats).is_none());
413 }
415
416 #[fixture]
417 fn available_stats_with_nans() -> FieldPathSet {
418 let float_col = FieldPath::from_name("float_col");
419 let int_col = FieldPath::from_name("int_col");
420
421 FieldPathSet::from_iter([
422 float_col.clone().push(Stat::Min.name()),
424 float_col.clone().push(Stat::Max.name()),
425 float_col.push(Stat::NaNCount.name()),
426 int_col.clone().push(Stat::Min.name()),
428 int_col.push(Stat::Max.name()),
429 ])
430 }
431
432 #[rstest]
433 fn pruning_checks_nans(available_stats_with_nans: FieldPathSet) {
434 let expr = gt_eq(col("float_col"), lit(f32::NAN));
435 let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap();
436 assert_eq!(
437 &converted,
438 &and(
439 and(
440 eq(col("float_col_nan_count"), lit(0u64)),
441 eq(lit(1u64), lit(0u64)),
443 ),
444 lt(col("float_col_max"), lit(f32::NAN)),
448 )
449 );
450
451 let expr = and(
453 gt(col("float_col"), lit(10f32)),
454 lt(col("int_col"), lit(10)),
455 );
456
457 let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap();
458
459 assert_eq!(
460 &converted,
461 &or(
462 and(
464 and(
465 eq(col("float_col_nan_count"), lit(0u64)),
466 eq(lit(0u64), lit(0u64)),
468 ),
469 lt_eq(col("float_col_max"), lit(10f32)),
471 ),
472 gt_eq(col("int_col_min"), lit(10)),
474 )
475 )
476 }
477
478 #[rstest]
479 fn pruning_between(available_stats: FieldPathSet) {
480 let expr = between(
481 col("a"),
482 lit(10),
483 lit(50),
484 BetweenOptions {
485 lower_strict: StrictComparison::NonStrict,
486 upper_strict: StrictComparison::NonStrict,
487 },
488 );
489 let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
490 assert_eq!(
491 refs.map(),
492 &HashMap::from_iter([(
493 FieldPath::from_name("a"),
494 HashSet::from_iter([Stat::Min, Stat::Max])
495 )])
496 );
497 assert_eq!(
498 &converted,
499 &or(gt(lit(10), col("a_max")), gt(col("a_min"), lit(50)))
500 );
501 }
502
503 #[rstest]
504 fn pruning_cast_get_item_eq(available_stats: FieldPathSet) {
505 let struct_dtype = DType::Struct(
508 StructFields::new(
509 FieldNames::from([FieldName::from("a"), FieldName::from("b")]),
510 vec![
511 DType::Utf8(Nullability::Nullable),
512 DType::Utf8(Nullability::Nullable),
513 ],
514 ),
515 Nullability::NonNullable,
516 );
517 let expr = eq(get_item("a", cast(root(), struct_dtype)), lit("value"));
518 let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
519 assert_eq!(
520 refs.map(),
521 &HashMap::from_iter([(
522 FieldPath::from_name("a"),
523 HashSet::from_iter([Stat::Min, Stat::Max])
524 )])
525 );
526 assert_eq!(
527 &converted,
528 &or(
529 gt(col("a_min"), lit("value")),
530 gt(lit("value"), col("a_max"))
531 )
532 );
533 }
534}