1use std::iter;
2
3use itertools::Itertools;
4use vortex_array::stats::Stat;
5use vortex_dtype::{Field, FieldName, FieldPath};
6use vortex_utils::aliases::hash_map::HashMap;
7
8use super::relation::Relation;
9use crate::{AccessPath, ExprRef, ScopeFieldPathSet, StatsCatalog, get_item, var};
10
11pub type RequiredStats = Relation<AccessPath, Stat>;
12
13#[derive(Default)]
15struct AnyStatsCatalog {
16 usage: HashMap<(AccessPath, Stat), ExprRef>,
17}
18
19struct ScopeStatsCatalog<'a> {
21 any_catalog: AnyStatsCatalog,
22 scope_field_paths: &'a ScopeFieldPathSet,
23}
24
25impl StatsCatalog for ScopeStatsCatalog<'_> {
26 fn stats_ref(&mut self, access_path: &AccessPath, stat: Stat) -> Option<ExprRef> {
27 let set = self.scope_field_paths.set(access_path.identifier())?;
28
29 let stat_path = access_path
30 .field_path
31 .clone()
32 .push(Field::Name(stat.name().into()));
33
34 if set.contains(&stat_path) {
35 self.any_catalog.stats_ref(access_path, stat)
36 } else {
37 None
38 }
39 }
40}
41
42impl StatsCatalog for AnyStatsCatalog {
43 fn stats_ref(&mut self, access_path: &AccessPath, stat: Stat) -> Option<ExprRef> {
44 let mut expr = var(access_path.identifier().clone());
45 let name = field_path_stat_field_name(access_path.field_path(), stat);
46 expr = get_item(name, expr);
47 self.usage.insert((access_path.clone(), stat), expr.clone());
48 Some(expr)
49 }
50}
51
52pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldName {
53 field_path
54 .path()
55 .iter()
56 .map(|f| match f {
57 Field::Name(n) => n.as_ref(),
58 Field::ElementType => todo!("element type not currently handled"),
59 })
60 .chain(iter::once(stat.name()))
61 .join("_")
62 .into()
63}
64
65pub fn pruning_expr(expr: &ExprRef) -> Option<(ExprRef, RequiredStats)> {
68 let mut catalog = AnyStatsCatalog {
69 ..Default::default()
70 };
71 let expr = expr.stat_falsification(&mut catalog)?;
72
73 let mut relation: Relation<AccessPath, Stat> = Relation::new();
75 for ((field_path, stat), _) in catalog.usage.into_iter() {
76 relation.insert(field_path, stat)
77 }
78
79 Some((expr, relation))
80}
81
82pub fn checked_pruning_expr(
87 expr: &ExprRef,
88 scope_field_paths: &ScopeFieldPathSet,
89) -> Option<(ExprRef, RequiredStats)> {
90 let mut catalog = ScopeStatsCatalog {
91 any_catalog: Default::default(),
92 scope_field_paths,
93 };
94
95 let expr = expr.stat_falsification(&mut catalog)?;
96
97 let mut relation: Relation<AccessPath, Stat> = Relation::new();
99 for ((field_path, stat), _) in catalog.any_catalog.usage.into_iter() {
100 relation.insert(field_path, stat)
101 }
102
103 Some((expr, relation))
104}
105
#[cfg(test)]
mod tests {
    use vortex_array::stats::Stat;
    use vortex_dtype::{FieldName, FieldPath};

    use crate::pruning::field_path_stat_field_name;
    use crate::pruning::pruning_expr::{HashMap, pruning_expr};
    use crate::{
        AccessPath, ExprRef, HashSet, and, col, eq, get_item, get_item_scope, gt, gt_eq, lit, lt,
        lt_eq, not_eq, or, root,
    };

    /// Expression the tests expect for reading `stat` of the top-level column `name`:
    /// a scope lookup of the flattened stat field name.
    fn stat_col(name: &FieldName, stat: Stat) -> ExprRef {
        get_item_scope(field_path_stat_field_name(
            &FieldPath::from_name(name.clone()),
            stat,
        ))
    }

    #[test]
    fn pruning_equals() {
        let name = FieldName::from("a");
        let literal_eq = lit(42);
        let eq_expr = eq(get_item("a", root()), literal_eq.clone());
        let (converted, _refs) = pruning_expr(&eq_expr).unwrap();
        let expected_expr = or(
            gt(
                // Deliberately spelled as `get_item(.., root())` instead of
                // `get_item_scope(..)`; the assertion compares it against the latter form.
                get_item(
                    field_path_stat_field_name(&FieldPath::from_name(name.clone()), Stat::Min),
                    root(),
                ),
                literal_eq.clone(),
            ),
            gt(literal_eq, stat_col(&name, Stat::Max)),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[test]
    fn pruning_equals_column() {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let eq_expr = eq(
            get_item_scope(column.clone()),
            get_item_scope(other_col.clone()),
        );

        let (converted, refs) = pruning_expr(&eq_expr).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    AccessPath::root_field(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    AccessPath::root_field(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = or(
            gt(
                stat_col(&column, Stat::Min),
                stat_col(&other_col, Stat::Max),
            ),
            gt(
                stat_col(&other_col, Stat::Min),
                stat_col(&column, Stat::Max),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[test]
    fn pruning_not_equals_column() {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let not_eq_expr = not_eq(
            get_item_scope(column.clone()),
            get_item_scope(other_col.clone()),
        );

        let (converted, refs) = pruning_expr(&not_eq_expr).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    AccessPath::root_field(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    AccessPath::root_field(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = and(
            eq(
                stat_col(&column, Stat::Min),
                stat_col(&other_col, Stat::Max),
            ),
            eq(
                stat_col(&column, Stat::Max),
                stat_col(&other_col, Stat::Min),
            ),
        );

        assert_eq!(&converted, &expected_expr);
    }

    #[test]
    fn pruning_gt_column() {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let gt_expr = gt(
            get_item_scope(column.clone()),
            get_item_scope(other_col.clone()),
        );

        let (converted, refs) = pruning_expr(&gt_expr).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    AccessPath::root_field(column.clone()),
                    HashSet::from_iter([Stat::Max])
                ),
                (
                    AccessPath::root_field(other_col.clone()),
                    HashSet::from_iter([Stat::Min])
                )
            ])
        );
        let expected_expr = lt_eq(
            stat_col(&column, Stat::Max),
            stat_col(&other_col, Stat::Min),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[test]
    fn pruning_gt_value() {
        let column = FieldName::from("a");
        let value = lit(42);
        let gt_expr = gt(get_item_scope(column.clone()), value.clone());

        let (converted, refs) = pruning_expr(&gt_expr).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                AccessPath::root_field(column.clone()),
                HashSet::from_iter([Stat::Max])
            )])
        );
        let expected_expr = lt_eq(stat_col(&column, Stat::Max), value);
        assert_eq!(&converted, &expected_expr);
    }

    #[test]
    fn pruning_lt_column() {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let lt_expr = lt(
            get_item_scope(column.clone()),
            get_item_scope(other_col.clone()),
        );

        let (converted, refs) = pruning_expr(&lt_expr).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    AccessPath::root_field(column.clone()),
                    HashSet::from_iter([Stat::Min])
                ),
                (
                    AccessPath::root_field(other_col.clone()),
                    HashSet::from_iter([Stat::Max])
                )
            ])
        );
        let expected_expr = gt_eq(
            stat_col(&column, Stat::Min),
            stat_col(&other_col, Stat::Max),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[test]
    fn pruning_lt_value() {
        let column = FieldName::from("a");
        let value = lit(42);
        let lt_expr = lt(get_item_scope(column.clone()), value.clone());

        let (converted, refs) = pruning_expr(&lt_expr).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                AccessPath::root_field(column.clone()),
                HashSet::from_iter([Stat::Min])
            )])
        );
        let expected_expr = gt_eq(stat_col(&column, Stat::Min), value);
        assert_eq!(&converted, &expected_expr);
    }

    #[test]
    fn pruning_identity() {
        // Comparisons made directly against the root: the stat fields carry no
        // column prefix, so they are plain `min` / `max`.
        let expr = or(lt(root(), lit(10)), gt(root(), lit(50)));

        let (predicate, _) = pruning_expr(&expr).unwrap();

        let expected_expr = and(
            gt_eq(get_item_scope(FieldName::from("min")), lit(10)),
            lt_eq(get_item_scope(FieldName::from("max")), lit(50)),
        );
        assert_eq!(&predicate, &expected_expr)
    }

    #[test]
    fn pruning_and_or_operators() {
        let column = FieldName::from("a");
        let and_expr = and(
            gt(get_item_scope(column.clone()), lit(10)),
            lt(get_item_scope(column), lit(50)),
        );
        let (predicate, _) = pruning_expr(&and_expr).unwrap();

        // The conjunction is expected to turn into a disjunction of the two
        // converted comparisons.
        assert_eq!(
            &predicate,
            &or(
                lt_eq(get_item_scope(FieldName::from("a_max")), lit(10)),
                gt_eq(get_item_scope(FieldName::from("a_min")), lit(50))
            ),
        );
    }

    #[test]
    fn test_gt_eq_with_booleans() {
        // `x >= (y > z)`: the right-hand side is itself a comparison. No pruning
        // expression is expected for this shape.
        let expr = gt_eq(col("x"), gt(col("y"), col("z")));
        assert!(pruning_expr(&expr).is_none());
    }
}