1use std::iter;
5
6use itertools::Itertools;
7use vortex_array::stats::Stat;
8use vortex_dtype::{Field, FieldName, FieldPath, FieldPathSet};
9use vortex_utils::aliases::hash_map::HashMap;
10
11use super::relation::Relation;
12use crate::{ExprRef, StatsCatalog, get_item, root};
13
/// The statistics a pruning expression reads, related by the [`FieldPath`] they belong to.
///
/// This is the second element of the tuple returned by [`checked_pruning_expr`].
pub type RequiredStats = Relation<FieldPath, Stat>;
15
16#[derive(Default)]
19struct TrackingStatsCatalog {
20 usage: HashMap<(FieldPath, Stat), ExprRef>,
21}
22
23impl TrackingStatsCatalog {
24 fn into_usages(self) -> HashMap<(FieldPath, Stat), ExprRef> {
27 self.usage
28 }
29}
30
31struct ScopeStatsCatalog<'a> {
33 any_catalog: TrackingStatsCatalog,
34 available_stats: &'a FieldPathSet,
35}
36
37impl StatsCatalog for ScopeStatsCatalog<'_> {
38 fn stats_ref(&mut self, field_path: &FieldPath, stat: Stat) -> Option<ExprRef> {
39 let stat_path = field_path.clone().push(stat.name());
40
41 if self.available_stats.contains(&stat_path) {
42 self.any_catalog.stats_ref(field_path, stat)
43 } else {
44 None
45 }
46 }
47}
48
49impl StatsCatalog for TrackingStatsCatalog {
50 fn stats_ref(&mut self, field_path: &FieldPath, stat: Stat) -> Option<ExprRef> {
51 let mut expr = root();
52 let name = field_path_stat_field_name(field_path, stat);
53 expr = get_item(name, expr);
54 self.usage.insert((field_path.clone(), stat), expr.clone());
55 Some(expr)
56 }
57}
58
59#[doc(hidden)]
60pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldName {
61 field_path
62 .parts()
63 .iter()
64 .map(|f| match f {
65 Field::Name(n) => n.as_ref(),
66 Field::ElementType => todo!("element type not currently handled"),
67 })
68 .chain(iter::once(stat.name()))
69 .join("_")
70 .into()
71}
72
73pub fn checked_pruning_expr(
83 expr: &ExprRef,
84 available_stats: &FieldPathSet,
85) -> Option<(ExprRef, RequiredStats)> {
86 let mut catalog = ScopeStatsCatalog {
87 any_catalog: Default::default(),
88 available_stats,
89 };
90
91 let expr = expr.stat_falsification(&mut catalog)?;
92
93 let mut relation: Relation<FieldPath, Stat> = Relation::new();
95 for ((field_path, stat), _) in catalog.any_catalog.into_usages() {
96 relation.insert(field_path, stat)
97 }
98
99 Some((expr, relation))
100}
101
#[cfg(test)]
mod tests {
    use rstest::{fixture, rstest};
    use vortex_array::stats::Stat;
    use vortex_dtype::{FieldName, FieldPath, FieldPathSet};

    use crate::pruning::pruning_expr::HashMap;
    use crate::pruning::{checked_pruning_expr, field_path_stat_field_name};
    use crate::{
        ExprRef, HashSet, and, col, eq, get_item, gt, gt_eq, lit, lt, lt_eq, not_eq, or, root,
    };

    /// Column reference to the statistic `stat` of the top-level field `field`, named through
    /// [`field_path_stat_field_name`] so expectations follow the production naming scheme.
    fn stat_col(field: &FieldName, stat: Stat) -> ExprRef {
        col(field_path_stat_field_name(
            &FieldPath::from_name(field.clone()),
            stat,
        ))
    }

    /// Min and max statistics for the two top-level fields `a` and `b`.
    #[fixture]
    fn available_stats() -> FieldPathSet {
        let field_a = FieldPath::from_name("a");
        let field_b = FieldPath::from_name("b");

        FieldPathSet::from_iter([
            field_a.clone().push(Stat::Min.name()),
            field_a.push(Stat::Max.name()),
            field_b.clone().push(Stat::Min.name()),
            field_b.push(Stat::Max.name()),
        ])
    }

    /// `a == 42` is rewritten to `a_min > 42 OR 42 > a_max`.
    #[rstest]
    pub fn pruning_equals(available_stats: FieldPathSet) {
        let name = FieldName::from("a");
        let literal_eq = lit(42);
        let eq_expr = eq(get_item("a", root()), literal_eq.clone());
        let (converted, _refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap();
        let expected_expr = or(
            gt(
                // Spelled out with `get_item`/`root` rather than `col` on purpose, mirroring how
                // the input expression above is built.
                get_item(
                    field_path_stat_field_name(&FieldPath::from_name(name.clone()), Stat::Min),
                    root(),
                ),
                literal_eq.clone(),
            ),
            gt(literal_eq, stat_col(&name, Stat::Max)),
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// `a == b` is rewritten to `a_min > b_max OR b_min > a_max` and needs min/max of both.
    #[rstest]
    pub fn pruning_equals_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let eq_expr = eq(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = or(
            gt(
                stat_col(&column, Stat::Min),
                stat_col(&other_col, Stat::Max),
            ),
            gt(
                stat_col(&other_col, Stat::Min),
                stat_col(&column, Stat::Max),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// `a != b` is rewritten to `a_min == b_max AND a_max == b_min` and needs min/max of both.
    #[rstest]
    pub fn pruning_not_equals_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let not_eq_expr = not_eq(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&not_eq_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = and(
            eq(
                stat_col(&column, Stat::Min),
                stat_col(&other_col, Stat::Max),
            ),
            eq(
                stat_col(&column, Stat::Max),
                stat_col(&other_col, Stat::Min),
            ),
        );

        assert_eq!(&converted, &expected_expr);
    }

    /// `a > b` is rewritten to `a_max <= b_min`, needing only `a`'s max and `b`'s min.
    #[rstest]
    pub fn pruning_gt_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let gt_expr = gt(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&gt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Min])
                )
            ])
        );
        let expected_expr = lt_eq(
            stat_col(&column, Stat::Max),
            stat_col(&other_col, Stat::Min),
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// `a > 42` is rewritten to `a_max <= 42`, needing only `a`'s max.
    #[rstest]
    pub fn pruning_gt_value(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let literal = lit(42);
        let gt_expr = gt(col(column.clone()), literal.clone());

        let (converted, refs) = checked_pruning_expr(&gt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                FieldPath::from_name(column.clone()),
                HashSet::from_iter([Stat::Max])
            ),])
        );
        let expected_expr = lt_eq(stat_col(&column, Stat::Max), literal);
        assert_eq!(&converted, &expected_expr);
    }

    /// `a < b` is rewritten to `a_min >= b_max`, needing only `a`'s min and `b`'s max.
    #[rstest]
    pub fn pruning_lt_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let lt_expr = lt(col(column.clone()), col(other_col.clone()));

        let (converted, refs) = checked_pruning_expr(&lt_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max])
                )
            ])
        );
        let expected_expr = gt_eq(
            stat_col(&column, Stat::Min),
            stat_col(&other_col, Stat::Max),
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// `a < 42` is rewritten to `a_min >= 42`, needing only `a`'s min.
    #[rstest]
    pub fn pruning_lt_value(available_stats: FieldPathSet) {
        let expr = lt(col("a"), lit(42));

        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from_iter([Stat::Min]))])
        );
        assert_eq!(&converted, &gt_eq(col("a_min"), lit(42)));
    }

    /// `a < 10 OR a > 50` is rewritten to `a_min >= 10 AND a_max <= 50`.
    ///
    /// The comparison is made on the rendered strings rather than on the expressions themselves.
    #[rstest]
    fn pruning_identity(available_stats: FieldPathSet) {
        let expr = or(lt(col("a"), lit(10)), gt(col("a"), lit(50)));

        let (predicate, _) = checked_pruning_expr(&expr, &available_stats).unwrap();

        let expected_expr = and(gt_eq(col("a_min"), lit(10)), lt_eq(col("a_max"), lit(50)));
        assert_eq!(&predicate.to_string(), &expected_expr.to_string());
    }

    /// `a > 10 AND a < 50` is rewritten to `a_max <= 10 OR a_min >= 50`.
    #[rstest]
    pub fn pruning_and_or_operators(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let and_expr = and(gt(col(column.clone()), lit(10)), lt(col(column), lit(50)));
        let (predicate, _) = checked_pruning_expr(&and_expr, &available_stats).unwrap();

        assert_eq!(
            &predicate,
            &or(
                lt_eq(col(FieldName::from("a_max")), lit(10)),
                gt_eq(col(FieldName::from("a_min")), lit(50)),
            ),
        );
    }

    /// Comparing a column against a boolean-valued comparison yields no pruning expression.
    ///
    /// NOTE(review): the fixture lists no statistics for `x`, `y` or `z`, which may by itself
    /// explain the `None`; confirm which condition this test is meant to pin down.
    #[rstest]
    fn test_gt_eq_with_booleans(available_stats: FieldPathSet) {
        let expr = gt_eq(col("x"), gt(col("y"), col("z")));
        assert!(checked_pruning_expr(&expr, &available_stats).is_none());
    }

    /// Min/max for `float_col` and `int_col`, plus a NaN count for `float_col` only.
    #[fixture]
    fn available_stats_with_nans() -> FieldPathSet {
        let float_col = FieldPath::from_name("float_col");
        let int_col = FieldPath::from_name("int_col");

        FieldPathSet::from_iter([
            float_col.clone().push(Stat::Min.name()),
            float_col.clone().push(Stat::Max.name()),
            float_col.push(Stat::NaNCount.name()),
            int_col.clone().push(Stat::Min.name()),
            int_col.push(Stat::Max.name()),
        ])
    }

    /// Comparisons on a float column are guarded by a `nan_count == 0` check, while the integer
    /// column's comparison is rewritten without one.
    ///
    /// NOTE(review): the `lit(1u64) == lit(0u64)` / `lit(0u64) == lit(0u64)` terms are presumably
    /// the NaN count of the literal operand (1 for `NaN`, 0 for `10.0`) — confirm.
    #[rstest]
    fn pruning_checks_nans(available_stats_with_nans: FieldPathSet) {
        let expr = gt_eq(col("float_col"), lit(f32::NAN));
        let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap();
        assert_eq!(
            &converted,
            &and(
                and(
                    eq(col("float_col_nan_count"), lit(0u64)),
                    eq(lit(1u64), lit(0u64)),
                ),
                lt(col("float_col_max"), lit(f32::NAN)),
            )
        );

        let expr = and(
            gt(col("float_col"), lit(10f32)),
            lt(col("int_col"), lit(10)),
        );

        let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap();

        assert_eq!(
            &converted,
            &or(
                and(
                    and(
                        eq(col("float_col_nan_count"), lit(0u64)),
                        eq(lit(0u64), lit(0u64)),
                    ),
                    lt_eq(col("float_col_max"), lit(10f32)),
                ),
                gt_eq(col("int_col_min"), lit(10)),
            )
        );
    }
}