1use std::cell::RefCell;
5use std::iter;
6
7use itertools::Itertools;
8use vortex_utils::aliases::hash_map::HashMap;
9
10use super::relation::Relation;
11use crate::dtype::Field;
12use crate::dtype::FieldName;
13use crate::dtype::FieldPath;
14use crate::dtype::FieldPathSet;
15use crate::expr::Expression;
16use crate::expr::StatsCatalog;
17use crate::expr::get_item;
18use crate::expr::root;
19use crate::expr::stats::Stat;
20
/// The statistics, grouped by field path, that a pruning expression produced by
/// [`checked_pruning_expr`] reads.
pub type RequiredStats = Relation<FieldPath, Stat>;
22
23#[derive(Default)]
26pub(crate) struct TrackingStatsCatalog {
27 usage: RefCell<HashMap<(FieldPath, Stat), Expression>>,
28}
29
30impl TrackingStatsCatalog {
31 fn into_usages(self) -> HashMap<(FieldPath, Stat), Expression> {
34 self.usage.into_inner()
35 }
36}
37
38struct ScopeStatsCatalog<'a> {
40 inner: TrackingStatsCatalog,
41 available_stats: &'a FieldPathSet,
42}
43
44impl StatsCatalog for ScopeStatsCatalog<'_> {
45 fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option<Expression> {
46 let stat_path = field_path.clone().push(stat.name());
47
48 if self.available_stats.contains(&stat_path) {
49 self.inner.stats_ref(field_path, stat)
50 } else {
51 None
52 }
53 }
54}
55
56impl StatsCatalog for TrackingStatsCatalog {
57 fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option<Expression> {
58 let mut expr = root();
59 let name = field_path_stat_field_name(field_path, stat);
60 expr = get_item(name, expr);
61 self.usage
62 .borrow_mut()
63 .insert((field_path.clone(), stat), expr.clone());
64 Some(expr)
65 }
66}
67
68#[doc(hidden)]
69pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldName {
70 field_path
71 .parts()
72 .iter()
73 .map(|f| match f {
74 Field::Name(n) => n.as_ref(),
75 Field::ElementType => todo!("element type not currently handled"),
76 })
77 .chain(iter::once(stat.name()))
78 .join("_")
79 .into()
80}
81
82pub fn checked_pruning_expr(
97 expr: &Expression,
98 available_stats: &FieldPathSet,
99) -> Option<(Expression, RequiredStats)> {
100 let catalog = ScopeStatsCatalog {
101 inner: Default::default(),
102 available_stats,
103 };
104
105 let expr = expr.stat_falsification(&catalog)?;
106
107 let mut relation: Relation<FieldPath, Stat> = Relation::new();
109 for ((field_path, stat), _) in catalog.inner.into_usages() {
110 relation.insert(field_path, stat)
111 }
112
113 Some((expr, relation))
114}
115
#[cfg(test)]
mod tests {
    use rstest::fixture;
    use rstest::rstest;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::HashMap;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::FieldPath;
    use crate::dtype::FieldPathSet;
    use crate::dtype::Nullability;
    use crate::dtype::StructFields;
    use crate::expr::and;
    use crate::expr::between;
    use crate::expr::cast;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::gt;
    use crate::expr::gt_eq;
    use crate::expr::lit;
    use crate::expr::lt;
    use crate::expr::lt_eq;
    use crate::expr::not_eq;
    use crate::expr::or;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::expr::pruning::field_path_stat_field_name;
    use crate::expr::root;
    use crate::expr::stats::Stat;
    use crate::scalar_fn::fns::between::BetweenOptions;
    use crate::scalar_fn::fns::between::StrictComparison;

    /// `min` and `max` statistics for the top-level columns `a` and `b`.
    #[fixture]
    fn available_stats() -> FieldPathSet {
        let field_a = FieldPath::from_name("a");
        let field_b = FieldPath::from_name("b");

        FieldPathSet::from_iter([
            field_a.clone().push(Stat::Min.name()),
            field_a.push(Stat::Max.name()),
            field_b.clone().push(Stat::Min.name()),
            field_b.push(Stat::Max.name()),
        ])
    }

    #[rstest]
    pub fn pruning_equals(available_stats: FieldPathSet) {
        let name = FieldName::from("a");
        let literal_eq = lit(42);
        let eq_expr = eq(get_item("a", root()), literal_eq.clone());
        let (converted, _refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap();
        let expected_expr = or(
            gt(
                get_item(
                    field_path_stat_field_name(&FieldPath::from_name(name.clone()), Stat::Min),
                    root(),
                ),
                literal_eq.clone(),
            ),
            gt(
                literal_eq,
                col(field_path_stat_field_name(
                    &FieldPath::from_name(name),
                    Stat::Max,
                )),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_equals_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let eq_expr = eq(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = or(
            gt(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column.clone()),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col.clone()),
                    Stat::Max,
                )),
            ),
            gt(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column),
                    Stat::Max,
                )),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_not_equals_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let not_eq_expr = not_eq(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&not_eq_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = and(
            eq(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column.clone()),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col.clone()),
                    Stat::Max,
                )),
            ),
            eq(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column),
                    Stat::Max,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col),
                    Stat::Min,
                )),
            ),
        );

        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_gt_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let other_expr = col(other_col.clone());
        let gt_expr = gt(col(column.clone()), other_expr);

        let (converted, refs) = checked_pruning_expr(&gt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Min])
                )
            ])
        );
        let expected_expr = lt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Max,
            )),
            col(field_path_stat_field_name(
                &FieldPath::from_name(other_col),
                Stat::Min,
            )),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_gt_value(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let value = lit(42);
        let gt_expr = gt(col(column.clone()), value.clone());

        let (converted, refs) = checked_pruning_expr(&gt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                FieldPath::from_name(column.clone()),
                HashSet::from_iter([Stat::Max])
            ),])
        );
        let expected_expr = lt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Max,
            )),
            value,
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_lt_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let other_expr = col(other_col.clone());
        let lt_expr = lt(col(column.clone()), other_expr);

        let (converted, refs) = checked_pruning_expr(&lt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max])
                )
            ])
        );
        let expected_expr = gt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Min,
            )),
            col(field_path_stat_field_name(
                &FieldPath::from_name(other_col),
                Stat::Max,
            )),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    pub fn pruning_lt_value(available_stats: FieldPathSet) {
        let expr = lt(col("a"), lit(42));

        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from_iter([Stat::Min]))])
        );
        assert_eq!(&converted, &gt_eq(col("a_min"), lit(42)));
    }

    #[rstest]
    fn pruning_identity(available_stats: FieldPathSet) {
        let expr = or(lt(col("a"), lit(10)), gt(col("a"), lit(50)));

        let (predicate, _) = checked_pruning_expr(&expr, &available_stats).unwrap();

        let expected_expr = and(gt_eq(col("a_min"), lit(10)), lt_eq(col("a_max"), lit(50)));
        assert_eq!(&predicate.to_string(), &expected_expr.to_string());
    }

    #[rstest]
    pub fn pruning_and_or_operators(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let and_expr = and(gt(col(column.clone()), lit(10)), lt(col(column), lit(50)));
        let (predicate, _) = checked_pruning_expr(&and_expr, &available_stats).unwrap();

        assert_eq!(
            &predicate,
            &or(
                lt_eq(col(FieldName::from("a_max")), lit(10)),
                gt_eq(col(FieldName::from("a_min")), lit(50)),
            ),
        );
    }

    #[rstest]
    fn test_gt_eq_with_booleans(available_stats: FieldPathSet) {
        // The fixture only provides statistics for `a` and `b`, none for `x`, `y` or `z`.
        let expr = gt_eq(col("x"), gt(col("y"), col("z")));
        assert!(checked_pruning_expr(&expr, &available_stats).is_none());
    }

    /// `min`, `max` and NaN-count statistics for `float_col`; `min` and `max` for `int_col`.
    #[fixture]
    fn available_stats_with_nans() -> FieldPathSet {
        let float_col = FieldPath::from_name("float_col");
        let int_col = FieldPath::from_name("int_col");

        FieldPathSet::from_iter([
            float_col.clone().push(Stat::Min.name()),
            float_col.clone().push(Stat::Max.name()),
            float_col.push(Stat::NaNCount.name()),
            int_col.clone().push(Stat::Min.name()),
            int_col.push(Stat::Max.name()),
        ])
    }

    #[rstest]
    fn pruning_checks_nans(available_stats_with_nans: FieldPathSet) {
        let expr = gt_eq(col("float_col"), lit(f32::NAN));
        let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap();
        assert_eq!(
            &converted,
            &and(
                and(
                    eq(col("float_col_nan_count"), lit(0u64)),
                    eq(lit(1u64), lit(0u64)),
                ),
                lt(col("float_col_max"), lit(f32::NAN)),
            )
        );

        let expr = and(
            gt(col("float_col"), lit(10f32)),
            lt(col("int_col"), lit(10)),
        );

        let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap();

        assert_eq!(
            &converted,
            &or(
                and(
                    and(
                        eq(col("float_col_nan_count"), lit(0u64)),
                        eq(lit(0u64), lit(0u64)),
                    ),
                    lt_eq(col("float_col_max"), lit(10f32)),
                ),
                gt_eq(col("int_col_min"), lit(10)),
            )
        )
    }

    #[rstest]
    fn pruning_between(available_stats: FieldPathSet) {
        let expr = between(
            col("a"),
            lit(10),
            lit(50),
            BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
        );
        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                FieldPath::from_name("a"),
                HashSet::from_iter([Stat::Min, Stat::Max])
            )])
        );
        assert_eq!(
            &converted,
            &or(gt(lit(10), col("a_max")), gt(col("a_min"), lit(50)))
        );
    }

    #[rstest]
    fn pruning_cast_get_item_eq(available_stats: FieldPathSet) {
        let struct_dtype = DType::Struct(
            StructFields::new(
                FieldNames::from([FieldName::from("a"), FieldName::from("b")]),
                vec![
                    DType::Utf8(Nullability::Nullable),
                    DType::Utf8(Nullability::Nullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let expr = eq(get_item("a", cast(root(), struct_dtype)), lit("value"));
        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                FieldPath::from_name("a"),
                HashSet::from_iter([Stat::Min, Stat::Max])
            )])
        );
        assert_eq!(
            &converted,
            &or(
                gt(col("a_min"), lit("value")),
                gt(lit("value"), col("a_max"))
            )
        );
    }
}