1use std::iter;
2
3use itertools::Itertools;
4use vortex_array::stats::Stat;
5use vortex_array::{Array, ArrayRef};
6use vortex_dtype::{Field, FieldName, FieldPath};
7use vortex_error::{VortexExpect, VortexResult};
8use vortex_utils::aliases::hash_map::HashMap;
9use vortex_utils::aliases::hash_set::HashSet;
10
11use super::relation::Relation;
12use crate::{AccessPath, ExprRef, Scope, StatsCatalog, get_item, var};
13
14#[derive(Debug, Clone)]
15pub struct PruningPredicate {
16 expr: ExprRef,
17 required_stats: Relation<AccessPath, Stat>,
18}
19
20impl PruningPredicate {
21 pub fn try_new(original_expr: &ExprRef) -> Option<Self> {
22 let (expr, required_stats) = pruning_expr(original_expr)?;
23
24 Some(Self {
25 expr,
26 required_stats,
27 })
28 }
29
30 pub fn expr(&self) -> &ExprRef {
31 &self.expr
32 }
33
34 pub fn required_stats(&self) -> &HashMap<AccessPath, HashSet<Stat>> {
35 self.required_stats.map()
36 }
37
38 pub fn evaluate(&self, metadata: &Scope) -> VortexResult<Option<ArrayRef>> {
44 let known_stats = metadata
47 .iter()
48 .flat_map(|(access_path, stat)| {
49 stat.dtype()
50 .as_struct()
51 .vortex_expect("metadata must be struct array")
52 .names()
53 .iter()
54 .map(|n| {
55 (
56 AccessPath::new(FieldPath::root(), access_path.clone()),
57 n.clone(),
58 )
59 })
60 })
61 .collect::<HashSet<(AccessPath, FieldName)>>();
62 let required_stats = self
63 .required_stats()
64 .iter()
65 .flat_map(|(path, stats)| stats.iter().map(|s| (path.clone(), s.name().into())))
66 .collect::<HashSet<(AccessPath, FieldName)>>();
67
68 let missing_stats = required_stats.difference(&known_stats).collect::<Vec<_>>();
69
70 if !missing_stats.is_empty() {
71 return Ok(None);
72 }
73
74 self.expr.evaluate(metadata).map(Some)
75 }
76}
77
78#[derive(Default)]
79struct FileStatsCatalog {
80 usage: HashMap<(AccessPath, Stat), ExprRef>,
81}
82
83impl StatsCatalog for FileStatsCatalog {
84 fn stats_ref(&mut self, access_path: &AccessPath, stat: Stat) -> Option<ExprRef> {
85 let mut expr = var(access_path.identifier().clone());
86 let name = access_path_stat_field_name(access_path, stat);
87 expr = get_item(name, expr);
88 self.usage.insert((access_path.clone(), stat), expr.clone());
89 Some(expr)
90 }
91}
92
93pub fn access_path_stat_field_name(access_path: &AccessPath, stat: Stat) -> FieldName {
94 access_path
95 .field_path
96 .path()
97 .iter()
98 .map(|f| match f {
99 Field::Name(n) => n.as_ref(),
100 Field::ElementType => todo!("element type not currently handled"),
101 })
102 .chain(iter::once(stat.name()))
103 .join("_")
104 .into()
105}
106
107#[allow(clippy::type_complexity)]
108pub fn pruning_expr(expr: &ExprRef) -> Option<(ExprRef, Relation<AccessPath, Stat>)> {
110 let mut catalog = FileStatsCatalog {
111 ..Default::default()
112 };
113 let expr = expr.stat_falsification(&mut catalog)?;
114
115 let mut relation: Relation<AccessPath, Stat> = Relation::new();
116 for ((field_path, stat), _) in catalog.usage.into_iter() {
117 relation.insert(field_path, stat)
118 }
119
120 Some((expr, relation))
121}
122
#[cfg(test)]
mod tests {
    use vortex_array::stats::Stat;
    use vortex_dtype::FieldName;

    use crate::pruning::pruning_predicate::{HashMap, pruning_expr};
    use crate::pruning::{PruningPredicate, access_path_stat_field_name};
    use crate::{
        AccessPath, HashSet, and, col, eq, get_item, get_item_scope, gt, gt_eq, lit, lt, lt_eq,
        not_eq, or, root,
    };

    /// `a == 42` is rewritten to `a_min > 42 OR 42 > a_max`.
    #[test]
    pub fn pruning_equals() {
        let name = FieldName::from("a");
        let literal_eq = lit(42);
        let eq_expr = eq(get_item("a", root()), literal_eq.clone());
        let (converted, _refs) = pruning_expr(&eq_expr).unwrap();
        let expected_expr = or(
            gt(
                get_item(
                    access_path_stat_field_name(&AccessPath::root_field(name.clone()), Stat::Min),
                    root(),
                ),
                literal_eq.clone(),
            ),
            gt(
                literal_eq,
                get_item_scope(access_path_stat_field_name(
                    &AccessPath::root_field(name),
                    Stat::Max,
                )),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// `a == b` needs min and max of both columns and is rewritten to
    /// `a_min > b_max OR b_min > a_max`.
    #[test]
    pub fn pruning_equals_column() {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let eq_expr = eq(
            get_item_scope(column.clone()),
            get_item_scope(other_col.clone()),
        );

        let (converted, refs) = pruning_expr(&eq_expr).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    AccessPath::root_field(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    AccessPath::root_field(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = or(
            gt(
                get_item_scope(access_path_stat_field_name(
                    &AccessPath::root_field(column.clone()),
                    Stat::Min,
                )),
                get_item_scope(access_path_stat_field_name(
                    &AccessPath::root_field(other_col.clone()),
                    Stat::Max,
                )),
            ),
            gt(
                get_item_scope(access_path_stat_field_name(
                    &AccessPath::root_field(other_col),
                    Stat::Min,
                )),
                get_item_scope(access_path_stat_field_name(
                    &AccessPath::root_field(column),
                    Stat::Max,
                )),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// `a != b` needs min and max of both columns and is rewritten to
    /// `a_min == b_max AND a_max == b_min`.
    #[test]
    pub fn pruning_not_equals_column() {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let not_eq_expr = not_eq(
            get_item_scope(column.clone()),
            get_item_scope(other_col.clone()),
        );

        let (converted, refs) = pruning_expr(&not_eq_expr).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    AccessPath::root_field(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    AccessPath::root_field(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = and(
            eq(
                get_item_scope(access_path_stat_field_name(
                    &AccessPath::root_field(column.clone()),
                    Stat::Min,
                )),
                get_item_scope(access_path_stat_field_name(
                    &AccessPath::root_field(other_col.clone()),
                    Stat::Max,
                )),
            ),
            eq(
                get_item_scope(access_path_stat_field_name(
                    &AccessPath::root_field(column),
                    Stat::Max,
                )),
                get_item_scope(access_path_stat_field_name(
                    &AccessPath::root_field(other_col),
                    Stat::Min,
                )),
            ),
        );

        assert_eq!(&converted, &expected_expr);
    }

    /// `a > b` needs `a`'s max and `b`'s min and is rewritten to `a_max <= b_min`.
    #[test]
    pub fn pruning_gt_column() {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let gt_expr = gt(
            get_item_scope(column.clone()),
            get_item_scope(other_col.clone()),
        );

        let (converted, refs) = pruning_expr(&gt_expr).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    AccessPath::root_field(column.clone()),
                    HashSet::from_iter([Stat::Max])
                ),
                (
                    AccessPath::root_field(other_col.clone()),
                    HashSet::from_iter([Stat::Min])
                )
            ])
        );
        let expected_expr = lt_eq(
            get_item_scope(access_path_stat_field_name(
                &AccessPath::root_field(column),
                Stat::Max,
            )),
            get_item_scope(access_path_stat_field_name(
                &AccessPath::root_field(other_col),
                Stat::Min,
            )),
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// `a > 42` needs only `a`'s max and is rewritten to `a_max <= 42`.
    #[test]
    pub fn pruning_gt_value() {
        let column = FieldName::from("a");
        let literal = lit(42);
        let gt_expr = gt(get_item_scope(column.clone()), literal.clone());

        let (converted, refs) = pruning_expr(&gt_expr).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                AccessPath::root_field(column.clone()),
                HashSet::from_iter([Stat::Max])
            )])
        );
        let expected_expr = lt_eq(
            get_item_scope(access_path_stat_field_name(
                &AccessPath::root_field(column),
                Stat::Max,
            )),
            literal,
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// `a < b` needs `a`'s min and `b`'s max and is rewritten to `a_min >= b_max`.
    #[test]
    pub fn pruning_lt_column() {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let lt_expr = lt(
            get_item_scope(column.clone()),
            get_item_scope(other_col.clone()),
        );

        let (converted, refs) = pruning_expr(&lt_expr).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    AccessPath::root_field(column.clone()),
                    HashSet::from_iter([Stat::Min])
                ),
                (
                    AccessPath::root_field(other_col.clone()),
                    HashSet::from_iter([Stat::Max])
                )
            ])
        );
        let expected_expr = gt_eq(
            get_item_scope(access_path_stat_field_name(
                &AccessPath::root_field(column),
                Stat::Min,
            )),
            get_item_scope(access_path_stat_field_name(
                &AccessPath::root_field(other_col),
                Stat::Max,
            )),
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// `a < 42` needs only `a`'s min and is rewritten to `a_min >= 42`.
    #[test]
    pub fn pruning_lt_value() {
        let column = FieldName::from("a");
        let literal = lit(42);
        let lt_expr = lt(get_item_scope(column.clone()), literal.clone());

        let (converted, refs) = pruning_expr(&lt_expr).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                AccessPath::root_field(column.clone()),
                HashSet::from_iter([Stat::Min])
            )])
        );
        let expected_expr = gt_eq(
            get_item_scope(access_path_stat_field_name(
                &AccessPath::root_field(column),
                Stat::Min,
            )),
            literal,
        );
        assert_eq!(&converted, &expected_expr);
    }

    /// Comparisons on the root itself use the bare stat names (`min`, `max`), and an
    /// `OR` of comparisons is rewritten to an `AND` of the rewritten operands.
    #[test]
    fn pruning_identity() {
        let expr = or(lt(root(), lit(10)), gt(root(), lit(50)));

        let (predicate, _) = pruning_expr(&expr).unwrap();

        let expected_expr = and(
            gt_eq(get_item_scope(FieldName::from("min")), lit(10)),
            lt_eq(get_item_scope(FieldName::from("max")), lit(50)),
        );
        assert_eq!(&predicate, &expected_expr);
    }

    /// `a > 10 AND a < 50` is rewritten to `a_max <= 10 OR a_min >= 50`.
    #[test]
    pub fn pruning_and_or_operators() {
        let column = FieldName::from("a");
        let and_expr = and(
            gt(get_item_scope(column.clone()), lit(10)),
            lt(get_item_scope(column), lit(50)),
        );
        let (predicate, _) = pruning_expr(&and_expr).unwrap();

        assert_eq!(
            &predicate,
            &or(
                lt_eq(get_item_scope(FieldName::from("a_max")), lit(10)),
                gt_eq(get_item_scope(FieldName::from("a_min")), lit(50))
            ),
        );
    }

    /// No pruning predicate is produced for `x >= (y > z)`.
    #[test]
    fn test_gt_eq_with_booleans() {
        let expr = gt_eq(col("x"), gt(col("y"), col("z")));
        assert!(PruningPredicate::try_new(&expr).is_none());
    }
}