1use std::cell::RefCell;
5use std::iter;
6
7use itertools::Itertools;
8use vortex_dtype::Field;
9use vortex_dtype::FieldName;
10use vortex_dtype::FieldPath;
11use vortex_dtype::FieldPathSet;
12use vortex_utils::aliases::hash_map::HashMap;
13
14use super::relation::Relation;
15use crate::expr::Expression;
16use crate::expr::StatsCatalog;
17use crate::expr::exprs::get_item::get_item;
18use crate::expr::exprs::root::root;
19use crate::expr::stats::Stat;
20
/// The statistics a pruning expression reads: a relation from each referenced field path to
/// the [`Stat`]s that were requested for it while building the expression.
pub type RequiredStats = Relation<FieldPath, Stat>;
22
23#[derive(Default)]
26struct TrackingStatsCatalog {
27 usage: RefCell<HashMap<(FieldPath, Stat), Expression>>,
28}
29
30impl TrackingStatsCatalog {
31 fn into_usages(self) -> HashMap<(FieldPath, Stat), Expression> {
34 self.usage.into_inner()
35 }
36}
37
38struct ScopeStatsCatalog<'a> {
40 any_catalog: TrackingStatsCatalog,
41 available_stats: &'a FieldPathSet,
42}
43
44impl StatsCatalog for ScopeStatsCatalog<'_> {
45 fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option<Expression> {
46 let stat_path = field_path.clone().push(stat.name());
47
48 if self.available_stats.contains(&stat_path) {
49 self.any_catalog.stats_ref(field_path, stat)
50 } else {
51 None
52 }
53 }
54}
55
56impl StatsCatalog for TrackingStatsCatalog {
57 fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option<Expression> {
58 let mut expr = root();
59 let name = field_path_stat_field_name(field_path, stat);
60 expr = get_item(name, expr);
61 self.usage
62 .borrow_mut()
63 .insert((field_path.clone(), stat), expr.clone());
64 Some(expr)
65 }
66}
67
68#[doc(hidden)]
69pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldName {
70 field_path
71 .parts()
72 .iter()
73 .map(|f| match f {
74 Field::Name(n) => n.as_ref(),
75 Field::ElementType => todo!("element type not currently handled"),
76 })
77 .chain(iter::once(stat.name()))
78 .join("_")
79 .into()
80}
81
82pub fn checked_pruning_expr(
92 expr: &Expression,
93 available_stats: &FieldPathSet,
94) -> Option<(Expression, RequiredStats)> {
95 let catalog = ScopeStatsCatalog {
96 any_catalog: Default::default(),
97 available_stats,
98 };
99
100 let expr = expr.stat_falsification(&catalog)?;
101
102 let mut relation: Relation<FieldPath, Stat> = Relation::new();
104 for ((field_path, stat), _) in catalog.any_catalog.into_usages() {
105 relation.insert(field_path, stat)
106 }
107
108 Some((expr, relation))
109}
110
#[cfg(test)]
mod tests {
    use rstest::fixture;
    use rstest::rstest;
    use vortex_dtype::DType;
    use vortex_dtype::FieldName;
    use vortex_dtype::FieldNames;
    use vortex_dtype::FieldPath;
    use vortex_dtype::FieldPathSet;
    use vortex_dtype::Nullability;
    use vortex_dtype::StructFields;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::HashMap;
    use crate::compute::BetweenOptions;
    use crate::compute::StrictComparison;
    use crate::expr::exprs::between::between;
    use crate::expr::exprs::binary::and;
    use crate::expr::exprs::binary::eq;
    use crate::expr::exprs::binary::gt;
    use crate::expr::exprs::binary::gt_eq;
    use crate::expr::exprs::binary::lt;
    use crate::expr::exprs::binary::lt_eq;
    use crate::expr::exprs::binary::not_eq;
    use crate::expr::exprs::binary::or;
    use crate::expr::exprs::cast::cast;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::expr::pruning::field_path_stat_field_name;
    use crate::expr::stats::Stat;

    /// Min/max statistics for the top-level columns `a` and `b`.
    #[fixture]
    fn available_stats() -> FieldPathSet {
        let field_a = FieldPath::from_name("a");
        let field_b = FieldPath::from_name("b");

        FieldPathSet::from_iter([
            field_a.clone().push(Stat::Min.name()),
            field_a.push(Stat::Max.name()),
            field_b.clone().push(Stat::Min.name()),
            field_b.push(Stat::Max.name()),
        ])
    }

    #[test]
    fn stat_field_name_joins_nested_path() {
        // Every path segment contributes, in order, before the stat name.
        let nested = FieldPath::from_name("a").push("b");
        assert_eq!(
            col(field_path_stat_field_name(&nested, Stat::Min)),
            col("a_b_min")
        );
    }

    #[rstest]
    pub fn pruning_equals(available_stats: FieldPathSet) {
        let name = FieldName::from("a");
        let literal_eq = lit(42);
        let eq_expr = eq(get_item("a", root()), literal_eq.clone());
        let (converted, _refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap();
        let expected_expr = or(
            gt(
                get_item(
                    field_path_stat_field_name(&FieldPath::from_name(name.clone()), Stat::Min),
                    root(),
                ),
                literal_eq.clone(),
            ),
            gt(
                literal_eq,
                col(field_path_stat_field_name(
                    &FieldPath::from_name(name),
                    Stat::Max,
                )),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_equals_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let eq_expr = eq(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = or(
            gt(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column.clone()),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col.clone()),
                    Stat::Max,
                )),
            ),
            gt(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column),
                    Stat::Max,
                )),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_not_equals_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let not_eq_expr = not_eq(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&not_eq_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = and(
            eq(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column.clone()),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col.clone()),
                    Stat::Max,
                )),
            ),
            eq(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column),
                    Stat::Max,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col),
                    Stat::Min,
                )),
            ),
        );

        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_gt_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let other_expr = col(other_col.clone());
        let gt_expr = gt(col(column.clone()), other_expr);

        let (converted, refs) = checked_pruning_expr(&gt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Min])
                )
            ])
        );
        let expected_expr = lt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Max,
            )),
            col(field_path_stat_field_name(
                &FieldPath::from_name(other_col),
                Stat::Min,
            )),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_gt_value(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let value = lit(42);
        let gt_expr = gt(col(column.clone()), value.clone());

        let (converted, refs) = checked_pruning_expr(&gt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                FieldPath::from_name(column.clone()),
                HashSet::from_iter([Stat::Max])
            ),])
        );
        let expected_expr = lt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Max,
            )),
            value,
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_lt_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let other_expr = col(other_col.clone());
        let lt_expr = lt(col(column.clone()), other_expr);

        let (converted, refs) = checked_pruning_expr(&lt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max])
                )
            ])
        );
        let expected_expr = gt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Min,
            )),
            col(field_path_stat_field_name(
                &FieldPath::from_name(other_col),
                Stat::Max,
            )),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_lt_value(available_stats: FieldPathSet) {
        let expr = lt(col("a"), lit(42));

        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from_iter([Stat::Min]))])
        );
        assert_eq!(&converted, &gt_eq(col("a_min"), lit(42)));
    }

    #[rstest]
    fn pruning_identity(available_stats: FieldPathSet) {
        let expr = or(lt(col("a"), lit(10)), gt(col("a"), lit(50)));

        let (predicate, _) = checked_pruning_expr(&expr, &available_stats).unwrap();

        let expected_expr = and(gt_eq(col("a_min"), lit(10)), lt_eq(col("a_max"), lit(50)));
        assert_eq!(&predicate.to_string(), &expected_expr.to_string());
    }

    #[rstest]
    pub fn pruning_and_or_operators(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let and_expr = and(gt(col(column.clone()), lit(10)), lt(col(column), lit(50)));
        let (predicate, _) = checked_pruning_expr(&and_expr, &available_stats).unwrap();

        assert_eq!(
            &predicate,
            &or(
                lt_eq(col(FieldName::from("a_max")), lit(10)),
                gt_eq(col(FieldName::from("a_min")), lit(50)),
            ),
        );
    }

    #[rstest]
    fn test_gt_eq_with_booleans(available_stats: FieldPathSet) {
        let expr = gt_eq(col("x"), gt(col("y"), col("z")));
        assert!(checked_pruning_expr(&expr, &available_stats).is_none());
    }

    /// Min/max/NaN-count statistics for a float column, and min/max for an int column.
    #[fixture]
    fn available_stats_with_nans() -> FieldPathSet {
        let float_col = FieldPath::from_name("float_col");
        let int_col = FieldPath::from_name("int_col");

        FieldPathSet::from_iter([
            float_col.clone().push(Stat::Min.name()),
            float_col.clone().push(Stat::Max.name()),
            float_col.push(Stat::NaNCount.name()),
            int_col.clone().push(Stat::Min.name()),
            int_col.push(Stat::Max.name()),
        ])
    }

    #[rstest]
    fn pruning_checks_nans(available_stats_with_nans: FieldPathSet) {
        let expr = gt_eq(col("float_col"), lit(f32::NAN));
        let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap();
        assert_eq!(
            &converted,
            &and(
                and(
                    eq(col("float_col_nan_count"), lit(0u64)),
                    eq(lit(1u64), lit(0u64)),
                ),
                lt(col("float_col_max"), lit(f32::NAN)),
            )
        );

        let expr = and(
            gt(col("float_col"), lit(10f32)),
            lt(col("int_col"), lit(10)),
        );

        let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap();

        assert_eq!(
            &converted,
            &or(
                and(
                    and(
                        eq(col("float_col_nan_count"), lit(0u64)),
                        eq(lit(0u64), lit(0u64)),
                    ),
                    lt_eq(col("float_col_max"), lit(10f32)),
                ),
                gt_eq(col("int_col_min"), lit(10)),
            )
        )
    }

    #[rstest]
    fn pruning_between(available_stats: FieldPathSet) {
        let expr = between(
            col("a"),
            lit(10),
            lit(50),
            BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
        );
        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                FieldPath::from_name("a"),
                HashSet::from_iter([Stat::Min, Stat::Max])
            )])
        );
        assert_eq!(
            &converted,
            &or(gt(lit(10), col("a_max")), gt(col("a_min"), lit(50)))
        );
    }

    #[rstest]
    fn pruning_cast_get_item_eq(available_stats: FieldPathSet) {
        let struct_dtype = DType::Struct(
            StructFields::new(
                FieldNames::from([FieldName::from("a"), FieldName::from("b")]),
                vec![
                    DType::Utf8(Nullability::Nullable),
                    DType::Utf8(Nullability::Nullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let expr = eq(get_item("a", cast(root(), struct_dtype)), lit("value"));
        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                FieldPath::from_name("a"),
                HashSet::from_iter([Stat::Min, Stat::Max])
            )])
        );
        assert_eq!(
            &converted,
            &or(
                gt(col("a_min"), lit("value")),
                gt(lit("value"), col("a_max"))
            )
        );
    }
}