1use std::iter;
5
6use itertools::Itertools;
7use vortex_array::stats::Stat;
8use vortex_dtype::{Field, FieldName, FieldPath, FieldPathSet};
9use vortex_utils::aliases::hash_map::HashMap;
10
11use super::relation::Relation;
12use crate::{ExprRef, StatsCatalog, get_item, root};
13
/// The statistics needed per field: a relation from a field path to the stats referenced
/// for that path.
pub type RequiredStats = Relation<FieldPath, Stat>;
15
/// A stats catalog that answers every request with a stat column reference and records
/// which `(field path, stat)` pairs were asked for.
#[derive(Default)]
struct TrackingStatsCatalog {
    /// Every stat requested so far, mapped to the expression that was handed back for it.
    usage: HashMap<(FieldPath, Stat), ExprRef>,
}
22
23impl TrackingStatsCatalog {
24 fn into_usages(self) -> HashMap<(FieldPath, Stat), ExprRef> {
27 self.usage
28 }
29}
30
/// A stats catalog that only serves stats present in `available_stats` and forwards the
/// permitted requests to an inner [`TrackingStatsCatalog`].
struct ScopeStatsCatalog<'a> {
    /// Inner catalog that builds the stat expressions and records which ones were used.
    any_catalog: TrackingStatsCatalog,
    /// Set of permitted paths; each entry is looked up as a field path extended with the
    /// stat's name.
    available_stats: &'a FieldPathSet,
}
36
37impl StatsCatalog for ScopeStatsCatalog<'_> {
38 fn stats_ref(&mut self, field_path: &FieldPath, stat: Stat) -> Option<ExprRef> {
39 let stat_path = field_path.clone().push(stat.name());
40
41 if self.available_stats.contains(&stat_path) {
42 self.any_catalog.stats_ref(field_path, stat)
43 } else {
44 None
45 }
46 }
47}
48
49impl StatsCatalog for TrackingStatsCatalog {
50 fn stats_ref(&mut self, field_path: &FieldPath, stat: Stat) -> Option<ExprRef> {
51 let mut expr = root();
52 let name = field_path_stat_field_name(field_path, stat);
53 expr = get_item(name, expr);
54 self.usage.insert((field_path.clone(), stat), expr.clone());
55 Some(expr)
56 }
57}
58
59#[doc(hidden)]
60pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldName {
61 field_path
62 .parts()
63 .iter()
64 .map(|f| match f {
65 Field::Name(n) => n.as_ref(),
66 Field::ElementType => todo!("element type not currently handled"),
67 })
68 .chain(iter::once(stat.name()))
69 .join("_")
70 .into()
71}
72
73pub fn checked_pruning_expr(
83 expr: &ExprRef,
84 available_stats: &FieldPathSet,
85) -> Option<(ExprRef, RequiredStats)> {
86 let mut catalog = ScopeStatsCatalog {
87 any_catalog: Default::default(),
88 available_stats,
89 };
90
91 let expr = expr.stat_falsification(&mut catalog)?;
92
93 let mut relation: Relation<FieldPath, Stat> = Relation::new();
95 for ((field_path, stat), _) in catalog.any_catalog.into_usages() {
96 relation.insert(field_path, stat)
97 }
98
99 Some((expr, relation))
100}
101
#[cfg(test)]
mod tests {
    use rstest::{fixture, rstest};
    use vortex_array::compute::{BetweenOptions, StrictComparison};
    use vortex_array::stats::Stat;
    use vortex_dtype::{FieldName, FieldPath, FieldPathSet};

    use crate::pruning::pruning_expr::HashMap;
    use crate::pruning::{checked_pruning_expr, field_path_stat_field_name};
    use crate::{
        HashSet, and, between, col, eq, get_item, gt, gt_eq, lit, lt, lt_eq, not_eq, or, root,
    };

    /// Min and max stats for the columns `a` and `b`.
    #[fixture]
    fn available_stats() -> FieldPathSet {
        let field_a = FieldPath::from_name("a");
        let field_b = FieldPath::from_name("b");

        FieldPathSet::from_iter([
            field_a.clone().push(Stat::Min.name()),
            field_a.push(Stat::Max.name()),
            field_b.clone().push(Stat::Min.name()),
            field_b.push(Stat::Max.name()),
        ])
    }

    /// `a == 42` becomes `a_min > 42 OR 42 > a_max`.
    #[rstest]
    pub fn pruning_equals(available_stats: FieldPathSet) {
        let name = FieldName::from("a");
        let literal_eq = lit(42);
        let eq_expr = eq(get_item("a", root()), literal_eq.clone());
        let (converted, _refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap();
        let expected_expr = or(
            gt(
                get_item(
                    field_path_stat_field_name(&FieldPath::from_name(name.clone()), Stat::Min),
                    root(),
                ),
                literal_eq.clone(),
            ),
            gt(
                literal_eq,
                col(field_path_stat_field_name(
                    &FieldPath::from_name(name),
                    Stat::Max,
                )),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// `a == b` becomes `a_min > b_max OR b_min > a_max` and needs min/max of both columns.
    #[rstest]
    pub fn pruning_equals_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let eq_expr = eq(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = or(
            gt(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column.clone()),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col.clone()),
                    Stat::Max,
                )),
            ),
            gt(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column),
                    Stat::Max,
                )),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// `a != b` becomes `a_min == b_max AND a_max == b_min`.
    #[rstest]
    pub fn pruning_not_equals_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let not_eq_expr = not_eq(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&not_eq_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = and(
            eq(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column.clone()),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col.clone()),
                    Stat::Max,
                )),
            ),
            eq(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column),
                    Stat::Max,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col),
                    Stat::Min,
                )),
            ),
        );

        assert_eq!(&converted, &expected_expr);
    }

    /// `a > b` becomes `a_max <= b_min`, needing only `a`'s max and `b`'s min.
    #[rstest]
    pub fn pruning_gt_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let gt_expr = gt(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&gt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Min])
                )
            ])
        );
        let expected_expr = lt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Max,
            )),
            col(field_path_stat_field_name(
                &FieldPath::from_name(other_col),
                Stat::Min,
            )),
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// `a > 42` becomes `a_max <= 42`, needing only `a`'s max.
    #[rstest]
    pub fn pruning_gt_value(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let literal = lit(42);
        let gt_expr = gt(col(column.clone()), literal.clone());

        let (converted, refs) = checked_pruning_expr(&gt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                FieldPath::from_name(column.clone()),
                HashSet::from_iter([Stat::Max])
            ),])
        );
        let expected_expr = lt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Max,
            )),
            literal,
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// `a < b` becomes `a_min >= b_max`, needing only `a`'s min and `b`'s max.
    #[rstest]
    pub fn pruning_lt_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let lt_expr = lt(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&lt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max])
                )
            ])
        );
        let expected_expr = gt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Min,
            )),
            col(field_path_stat_field_name(
                &FieldPath::from_name(other_col),
                Stat::Max,
            )),
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// `a < 42` becomes `a_min >= 42`, needing only `a`'s min.
    #[rstest]
    pub fn pruning_lt_value(available_stats: FieldPathSet) {
        let expr = lt(col("a"), lit(42));

        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from_iter([Stat::Min]))])
        );
        assert_eq!(&converted, &gt_eq(col("a_min"), lit(42)));
    }

    /// `a < 10 OR a > 50` becomes `a_min >= 10 AND a_max <= 50` (compared by display form).
    #[rstest]
    fn pruning_identity(available_stats: FieldPathSet) {
        let expr = or(lt(col("a"), lit(10)), gt(col("a"), lit(50)));

        let (predicate, _) = checked_pruning_expr(&expr, &available_stats).unwrap();

        let expected_expr = and(gt_eq(col("a_min"), lit(10)), lt_eq(col("a_max"), lit(50)));
        assert_eq!(&predicate.to_string(), &expected_expr.to_string());
    }

    /// `a > 10 AND a < 50` becomes `a_max <= 10 OR a_min >= 50`.
    #[rstest]
    pub fn pruning_and_or_operators(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let and_expr = and(gt(col(column.clone()), lit(10)), lt(col(column), lit(50)));
        let (predicate, _) = checked_pruning_expr(&and_expr, &available_stats).unwrap();

        assert_eq!(
            &predicate,
            &or(
                lt_eq(col(FieldName::from("a_max")), lit(10)),
                gt_eq(col(FieldName::from("a_min")), lit(50)),
            ),
        );
    }

    /// A comparison whose right-hand side is itself a comparison yields no pruning
    /// expression with this fixture (which only carries stats for `a` and `b`).
    #[rstest]
    fn test_gt_eq_with_booleans(available_stats: FieldPathSet) {
        let expr = gt_eq(col("x"), gt(col("y"), col("z")));
        assert!(checked_pruning_expr(&expr, &available_stats).is_none());
    }

    /// Min/max/NaN-count stats for `float_col` and min/max stats for `int_col`.
    #[fixture]
    fn available_stats_with_nans() -> FieldPathSet {
        let float_col = FieldPath::from_name("float_col");
        let int_col = FieldPath::from_name("int_col");

        FieldPathSet::from_iter([
            float_col.clone().push(Stat::Min.name()),
            float_col.clone().push(Stat::Max.name()),
            float_col.push(Stat::NaNCount.name()),
            int_col.clone().push(Stat::Min.name()),
            int_col.push(Stat::Max.name()),
        ])
    }

    /// Float comparisons are additionally guarded by NaN-count checks; integer comparisons
    /// are not.
    #[rstest]
    fn pruning_checks_nans(available_stats_with_nans: FieldPathSet) {
        let expr = gt_eq(col("float_col"), lit(f32::NAN));
        let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap();
        assert_eq!(
            &converted,
            &and(
                and(
                    eq(col("float_col_nan_count"), lit(0u64)),
                    eq(lit(1u64), lit(0u64)),
                ),
                lt(col("float_col_max"), lit(f32::NAN)),
            )
        );

        let expr = and(
            gt(col("float_col"), lit(10f32)),
            lt(col("int_col"), lit(10)),
        );

        let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap();

        assert_eq!(
            &converted,
            &or(
                and(
                    and(
                        eq(col("float_col_nan_count"), lit(0u64)),
                        eq(lit(0u64), lit(0u64)),
                    ),
                    lt_eq(col("float_col_max"), lit(10f32)),
                ),
                gt_eq(col("int_col_min"), lit(10)),
            )
        );
    }

    /// Non-strict `a BETWEEN 10 AND 50` becomes `10 > a_max OR a_min > 50`.
    #[rstest]
    fn pruning_between(available_stats: FieldPathSet) {
        let expr = between(
            col("a"),
            lit(10),
            lit(50),
            BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
        );
        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                FieldPath::from_name("a"),
                HashSet::from_iter([Stat::Min, Stat::Max])
            )])
        );
        assert_eq!(
            &converted,
            &or(gt(lit(10), col("a_max")), gt(col("a_min"), lit(50)))
        );
    }
}