1use std::fmt::Display;
5use std::fmt::Formatter;
6
7use itertools::Itertools;
8use vortex_dtype::DType;
9use vortex_dtype::FieldName;
10use vortex_dtype::FieldNames;
11use vortex_dtype::Nullability;
12use vortex_dtype::StructFields;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_utils::aliases::hash_map::HashMap;
16
17use crate::expr::Expression;
18use crate::expr::analysis::Annotation;
19use crate::expr::analysis::AnnotationFn;
20use crate::expr::analysis::Annotations;
21use crate::expr::analysis::descendent_annotations;
22use crate::expr::exprs::get_item::get_item;
23use crate::expr::exprs::pack::pack;
24use crate::expr::exprs::root::root;
25use crate::expr::transform::ExprOptimizer;
26use crate::expr::traversal::NodeExt;
27use crate::expr::traversal::NodeRewriter;
28use crate::expr::traversal::Transformed;
29use crate::expr::traversal::TraversalOrder;
30
31pub fn partition<A: AnnotationFn>(
43 expr: Expression,
44 scope: &DType,
45 annotate_fn: A,
46 optimizer: &ExprOptimizer,
47) -> VortexResult<PartitionedExpr<A::Annotation>>
48where
49 A::Annotation: Display,
50 FieldName: From<A::Annotation>,
51{
52 let annotations = descendent_annotations(&expr, annotate_fn);
54
55 let mut splitter = StructFieldExpressionSplitter::<A::Annotation>::new(&annotations);
58 let root = expr.clone().rewrite(&mut splitter)?.value;
59
60 let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
61 let mut partition_annotations = Vec::with_capacity(splitter.sub_expressions.len());
62 let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
63
64 for (annotation, exprs) in splitter.sub_expressions.into_iter() {
65 let expr = pack(
67 exprs.into_iter().enumerate().map(|(idx, expr)| {
68 (
69 StructFieldExpressionSplitter::field_name(&annotation, idx),
70 expr,
71 )
72 }),
73 Nullability::NonNullable,
74 );
75
76 let expr = optimizer.optimize_typed(expr.clone(), scope)?;
77 let expr_dtype = expr.return_dtype(scope)?;
78
79 partitions.push(expr);
80 partition_annotations.push(annotation);
81 partition_dtypes.push(expr_dtype);
82 }
83
84 let partition_names = partition_annotations
85 .iter()
86 .map(|id| FieldName::from(id.clone()))
87 .collect::<FieldNames>();
88 let root_scope = DType::Struct(
89 StructFields::new(partition_names.clone(), partition_dtypes.clone()),
90 Nullability::NonNullable,
91 );
92
93 Ok(PartitionedExpr {
94 root: optimizer.optimize_typed(root, &root_scope)?,
95 partitions: partitions.into_boxed_slice(),
96 partition_names,
97 partition_dtypes: partition_dtypes.into_boxed_slice(),
98 partition_annotations: partition_annotations.into_boxed_slice(),
99 })
100}
101
102#[derive(Debug)]
104pub struct PartitionedExpr<A> {
105 pub root: Expression,
107 pub partitions: Box<[Expression]>,
109 pub partition_names: FieldNames,
111 pub partition_dtypes: Box<[DType]>,
113 pub partition_annotations: Box<[A]>,
115}
116
117impl<A: Display> Display for PartitionedExpr<A> {
118 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
119 write!(
120 f,
121 "root: {} {{{}}}",
122 self.root,
123 self.partition_names
124 .iter()
125 .zip(self.partitions.iter())
126 .map(|(name, partition)| format!("{name}: {partition}"))
127 .join(", ")
128 )
129 }
130}
131
132impl<A: Annotation> PartitionedExpr<A>
133where
134 FieldName: From<A>,
135{
136 pub fn find_partition(&self, id: &A) -> Option<&Expression> {
139 let id = FieldName::from(id.clone());
140 self.partition_names
141 .iter()
142 .position(|field| field == id)
143 .map(|idx| &self.partitions[idx])
144 }
145}
146
147#[derive(Debug)]
148struct StructFieldExpressionSplitter<'a, A: Annotation> {
149 annotations: &'a Annotations<'a, A>,
150 sub_expressions: HashMap<A, Vec<Expression>>,
151}
152
153impl<'a, A: Annotation + Display> StructFieldExpressionSplitter<'a, A> {
154 fn new(annotations: &'a Annotations<'a, A>) -> Self {
155 Self {
156 sub_expressions: HashMap::new(),
157 annotations,
158 }
159 }
160
161 fn field_name(annotation: &A, idx: usize) -> FieldName {
164 format!("{annotation}_{idx}").into()
165 }
166}
167
168impl<A: Annotation + Display> NodeRewriter for StructFieldExpressionSplitter<'_, A>
169where
170 FieldName: From<A>,
171{
172 type NodeTy = Expression;
173
174 fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
175 match self.annotations.get(&node) {
176 Some(annotations) if annotations.len() == 1 => {
178 let annotation = annotations
179 .iter()
180 .next()
181 .vortex_expect("expected one field");
182 let sub_exprs = self.sub_expressions.entry(annotation.clone()).or_default();
183 let idx = sub_exprs.len();
184 sub_exprs.push(node.clone());
185 let value = get_item(
186 StructFieldExpressionSplitter::field_name(annotation, idx),
187 get_item(FieldName::from(annotation.clone()), root()),
188 );
189 Ok(Transformed {
190 value,
191 changed: true,
192 order: TraversalOrder::Skip,
193 })
194 }
195
196 _ => Ok(Transformed::no(node)),
198 }
199 }
200
201 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
202 Ok(Transformed::no(node))
203 }
204}
205
206#[cfg(test)]
207mod tests {
208 use rstest::fixture;
209 use rstest::rstest;
210 use vortex_dtype::DType;
211 use vortex_dtype::Nullability::NonNullable;
212 use vortex_dtype::PType::I32;
213 use vortex_dtype::StructFields;
214
215 use super::*;
216 use crate::expr::analysis::annotate_scope_access;
217 use crate::expr::exprs::binary::and;
218 use crate::expr::exprs::get_item::col;
219 use crate::expr::exprs::get_item::get_item;
220 use crate::expr::exprs::literal::lit;
221 use crate::expr::exprs::merge::merge;
222 use crate::expr::exprs::pack::pack;
223 use crate::expr::exprs::root::root;
224 use crate::expr::exprs::select::select;
225 use crate::expr::session::ExprSession;
226 use crate::expr::transform::replace::replace_root_fields;
227 use crate::expr::transform::simplify_typed::simplify_typed;
228
229 #[fixture]
230 fn dtype() -> DType {
231 DType::Struct(
232 StructFields::from_iter([
233 (
234 "a",
235 DType::Struct(
236 StructFields::from_iter([("x", I32.into()), ("y", DType::from(I32))]),
237 NonNullable,
238 ),
239 ),
240 ("b", I32.into()),
241 ("c", I32.into()),
242 ]),
243 NonNullable,
244 )
245 }
246
247 #[rstest]
248 fn test_expr_top_level_ref(dtype: DType) {
249 let fields = dtype.as_struct_fields_opt().unwrap();
250 let session = ExprSession::default();
251 let optimizer = ExprOptimizer::new(&session);
252
253 let expr = root();
254 let partitioned = partition(
255 expr.clone(),
256 &dtype,
257 annotate_scope_access(fields),
258 &optimizer,
259 )
260 .unwrap();
261
262 assert_eq!(partitioned.partitions.len(), 0);
264 assert_eq!(&partitioned.root, &root());
265
266 let expr = replace_root_fields(expr, fields);
268 let partitioned =
269 partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();
270
271 assert_eq!(partitioned.partitions.len(), fields.names().len());
272 }
273
274 #[rstest]
275 fn test_expr_top_level_ref_get_item_and_split(dtype: DType) {
276 let fields = dtype.as_struct_fields_opt().unwrap();
277 let session = ExprSession::default();
278 let optimizer = ExprOptimizer::new(&session);
279
280 let expr = get_item("y", get_item("a", root()));
281
282 let partitioned =
283 partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();
284 assert_eq!(&partitioned.root, &get_item("a_0", get_item("a", root())));
285 }
286
287 #[rstest]
288 fn test_expr_top_level_ref_get_item_and_split_pack(dtype: DType) {
289 let fields = dtype.as_struct_fields_opt().unwrap();
290 let session = ExprSession::default();
291 let optimizer = ExprOptimizer::new(&session);
292
293 let expr = pack(
294 [
295 ("x", get_item("x", get_item("a", root()))),
296 ("y", get_item("y", get_item("a", root()))),
297 ("c", get_item("c", root())),
298 ],
299 NonNullable,
300 );
301 let partitioned =
302 partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();
303
304 let split_a = partitioned.find_partition(&"a".into()).unwrap();
305 assert_eq!(
306 &simplify_typed(
307 split_a.clone(),
308 &dtype,
309 ExprSession::default().rewrite_rules()
310 )
311 .unwrap(),
312 &pack(
313 [
314 ("a_0", get_item("x", get_item("a", root()))),
315 ("a_1", get_item("y", get_item("a", root())))
316 ],
317 NonNullable
318 )
319 );
320 }
321
322 #[rstest]
323 fn test_expr_top_level_ref_get_item_add(dtype: DType) {
324 let fields = dtype.as_struct_fields_opt().unwrap();
325 let session = ExprSession::default();
326 let optimizer = ExprOptimizer::new(&session);
327
328 let expr = and(get_item("y", get_item("a", root())), lit(1));
329 let partitioned =
330 partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();
331
332 assert_eq!(partitioned.partitions.len(), 1);
334 }
335
336 #[rstest]
337 fn test_expr_top_level_ref_get_item_add_cannot_split(dtype: DType) {
338 let fields = dtype.as_struct_fields_opt().unwrap();
339 let session = ExprSession::default();
340 let optimizer = ExprOptimizer::new(&session);
341
342 let expr = and(get_item("y", get_item("a", root())), get_item("b", root()));
343 let partitioned =
344 partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();
345
346 assert_eq!(partitioned.partitions.len(), 2);
348 }
349
350 #[rstest]
352 fn test_expr_partition_many_occurrences_of_field(dtype: DType) {
353 let fields = dtype.as_struct_fields_opt().unwrap();
354 let session = ExprSession::default();
355 let optimizer = ExprOptimizer::new(&session);
356
357 let expr = and(
358 get_item("y", get_item("a", root())),
359 select(["a", "b"], root()),
360 );
361 let expr = simplify_typed(expr, &dtype, ExprSession::default().rewrite_rules()).unwrap();
362 let partitioned =
363 partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();
364
365 assert_eq!(partitioned.partitions.len(), 2);
367
368 assert_eq!(
371 &partitioned.root,
372 &and(
373 get_item("a_0", get_item("a", root())),
374 pack(
375 [
376 (
377 "a",
378 get_item(
379 StructFieldExpressionSplitter::<FieldName>::field_name(
380 &"a".into(),
381 1
382 ),
383 get_item("a", root())
384 )
385 ),
386 ("b", get_item("b_0", get_item("b", root())))
387 ],
388 NonNullable
389 )
390 )
391 )
392 }
393
394 #[rstest]
395 fn test_expr_merge(dtype: DType) {
396 let fields = dtype.as_struct_fields_opt().unwrap();
397 let session = ExprSession::default();
398 let optimizer = ExprOptimizer::new(&session);
399
400 let expr = merge([col("a"), pack([("b", col("b"))], NonNullable)]);
401
402 let partitioned =
403 partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();
404 let expected = pack(
405 [
406 ("x", get_item("x", get_item("a_0", col("a")))),
407 ("y", get_item("y", get_item("a_0", col("a")))),
408 ("b", get_item("b", get_item("b_0", col("b")))),
409 ],
410 NonNullable,
411 );
412 assert_eq!(
413 &partitioned.root, &expected,
414 "{} {}",
415 partitioned.root, expected
416 );
417
418 assert_eq!(partitioned.partitions.len(), 2);
419
420 let part_a = partitioned.find_partition(&"a".into()).unwrap();
421 let expected_a = pack([("a_0", col("a"))], NonNullable);
422 assert_eq!(part_a, &expected_a, "{part_a} {expected_a}");
423
424 let part_b = partitioned.find_partition(&"b".into()).unwrap();
425 let expected_b = pack([("b_0", pack([("b", col("b"))], NonNullable))], NonNullable);
426 assert_eq!(part_b, &expected_b, "{part_b} {expected_b}");
427 }
428}