1use std::fmt::Display;
5use std::fmt::Formatter;
6
7use itertools::Itertools;
8use vortex_dtype::DType;
9use vortex_dtype::FieldName;
10use vortex_dtype::FieldNames;
11use vortex_dtype::Nullability;
12use vortex_dtype::StructFields;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_utils::aliases::hash_map::HashMap;
16
17use crate::expr::Expression;
18use crate::expr::analysis::Annotation;
19use crate::expr::analysis::AnnotationFn;
20use crate::expr::analysis::Annotations;
21use crate::expr::analysis::descendent_annotations;
22use crate::expr::exprs::get_item::get_item;
23use crate::expr::exprs::pack::pack;
24use crate::expr::exprs::root::root;
25use crate::expr::traversal::NodeExt;
26use crate::expr::traversal::NodeRewriter;
27use crate::expr::traversal::Transformed;
28use crate::expr::traversal::TraversalOrder;
29
30pub fn partition<A: AnnotationFn>(
42 expr: Expression,
43 scope: &DType,
44 annotate_fn: A,
45) -> VortexResult<PartitionedExpr<A::Annotation>>
46where
47 A::Annotation: Display,
48 FieldName: From<A::Annotation>,
49{
50 let annotations = descendent_annotations(&expr, annotate_fn);
52
53 let mut splitter = StructFieldExpressionSplitter::<A::Annotation>::new(&annotations);
56 let root = expr.clone().rewrite(&mut splitter)?.value;
57
58 let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
59 let mut partition_annotations = Vec::with_capacity(splitter.sub_expressions.len());
60 let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
61
62 for (annotation, exprs) in splitter.sub_expressions.into_iter() {
63 let expr = pack(
65 exprs.into_iter().enumerate().map(|(idx, expr)| {
66 (
67 StructFieldExpressionSplitter::field_name(&annotation, idx),
68 expr,
69 )
70 }),
71 Nullability::NonNullable,
72 );
73
74 let expr = expr.optimize_recursive(scope)?;
75 let expr_dtype = expr.return_dtype(scope)?;
76
77 partitions.push(expr);
78 partition_annotations.push(annotation);
79 partition_dtypes.push(expr_dtype);
80 }
81
82 let partition_names = partition_annotations
83 .iter()
84 .map(|id| FieldName::from(id.clone()))
85 .collect::<FieldNames>();
86 let root_scope = DType::Struct(
87 StructFields::new(partition_names.clone(), partition_dtypes.clone()),
88 Nullability::NonNullable,
89 );
90
91 Ok(PartitionedExpr {
92 root: root.optimize_recursive(&root_scope)?,
93 partitions: partitions.into_boxed_slice(),
94 partition_names,
95 partition_dtypes: partition_dtypes.into_boxed_slice(),
96 partition_annotations: partition_annotations.into_boxed_slice(),
97 })
98}
99
100#[derive(Debug)]
102pub struct PartitionedExpr<A> {
103 pub root: Expression,
105 pub partitions: Box<[Expression]>,
107 pub partition_names: FieldNames,
109 pub partition_dtypes: Box<[DType]>,
111 pub partition_annotations: Box<[A]>,
113}
114
115impl<A: Display> Display for PartitionedExpr<A> {
116 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
117 write!(
118 f,
119 "root: {} {{{}}}",
120 self.root,
121 self.partition_names
122 .iter()
123 .zip(self.partitions.iter())
124 .map(|(name, partition)| format!("{name}: {partition}"))
125 .join(", ")
126 )
127 }
128}
129
130impl<A: Annotation> PartitionedExpr<A>
131where
132 FieldName: From<A>,
133{
134 pub fn find_partition(&self, id: &A) -> Option<&Expression> {
137 let id = FieldName::from(id.clone());
138 self.partition_names
139 .iter()
140 .position(|field| field == id)
141 .map(|idx| &self.partitions[idx])
142 }
143}
144
145#[derive(Debug)]
146struct StructFieldExpressionSplitter<'a, A: Annotation> {
147 annotations: &'a Annotations<'a, A>,
148 sub_expressions: HashMap<A, Vec<Expression>>,
149}
150
151impl<'a, A: Annotation + Display> StructFieldExpressionSplitter<'a, A> {
152 fn new(annotations: &'a Annotations<'a, A>) -> Self {
153 Self {
154 sub_expressions: HashMap::new(),
155 annotations,
156 }
157 }
158
159 fn field_name(annotation: &A, idx: usize) -> FieldName {
162 format!("{annotation}_{idx}").into()
163 }
164}
165
166impl<A: Annotation + Display> NodeRewriter for StructFieldExpressionSplitter<'_, A>
167where
168 FieldName: From<A>,
169{
170 type NodeTy = Expression;
171
172 fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
173 match self.annotations.get(&node) {
174 Some(annotations) if annotations.len() == 1 => {
176 let annotation = annotations
177 .iter()
178 .next()
179 .vortex_expect("expected one field");
180 let sub_exprs = self.sub_expressions.entry(annotation.clone()).or_default();
181 let idx = sub_exprs.len();
182 sub_exprs.push(node.clone());
183 let value = get_item(
184 StructFieldExpressionSplitter::field_name(annotation, idx),
185 get_item(FieldName::from(annotation.clone()), root()),
186 );
187 Ok(Transformed {
188 value,
189 changed: true,
190 order: TraversalOrder::Skip,
191 })
192 }
193
194 _ => Ok(Transformed::no(node)),
196 }
197 }
198
199 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
200 Ok(Transformed::no(node))
201 }
202}
203
#[cfg(test)]
mod tests {
    use rstest::fixture;
    use rstest::rstest;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::StructFields;

    use super::*;
    use crate::expr::analysis::annotate_scope_access;
    use crate::expr::exprs::binary::and;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::merge::merge;
    use crate::expr::exprs::pack::pack;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::select;
    use crate::expr::transform::replace::replace_root_fields;

    /// Scope shared by every test: a struct `{a: {x, y}, b, c}` whose leaves are built from `I32`.
    #[fixture]
    fn dtype() -> DType {
        let nested = DType::Struct(
            StructFields::from_iter([("x", DType::from(I32)), ("y", DType::from(I32))]),
            NonNullable,
        );
        DType::Struct(
            StructFields::from_iter([("a", nested), ("b", I32.into()), ("c", I32.into())]),
            NonNullable,
        )
    }

    /// A bare `root()` yields no partitions; expanding its fields yields one per top-level field.
    #[rstest]
    fn test_expr_top_level_ref(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let bare_root = root();
        let partitioned =
            partition(bare_root.clone(), &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), 0);
        assert_eq!(partitioned.root, root());

        let expanded = replace_root_fields(bare_root, fields);
        let partitioned = partition(expanded, &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), fields.names().len());
    }

    /// A nested field access is replaced in the root by a read of the first `a` partition field.
    #[rstest]
    fn test_expr_top_level_ref_get_item_and_split(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let nested_access = get_item("y", get_item("a", root()));
        let partitioned =
            partition(nested_access, &dtype, annotate_scope_access(fields)).unwrap();

        let expected_root = get_item("a_0", get_item("a", root()));
        assert_eq!(partitioned.root, expected_root);
    }

    /// Two accesses under `a` end up as fields `a_0` and `a_1` of the `a` partition.
    #[rstest]
    fn test_expr_top_level_ref_get_item_and_split_pack(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = pack(
            [
                ("x", get_item("x", get_item("a", root()))),
                ("y", get_item("y", get_item("a", root()))),
                ("c", get_item("c", root())),
            ],
            NonNullable,
        );
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        let split_a = partitioned.find_partition(&"a".into()).unwrap();
        let expected_a = pack(
            [
                ("a_0", get_item("x", get_item("a", root()))),
                ("a_1", get_item("y", get_item("a", root()))),
            ],
            NonNullable,
        );
        assert_eq!(split_a.optimize_recursive(&dtype).unwrap(), expected_a);
    }

    /// Combining one field access with a literal produces a single partition.
    #[rstest]
    fn test_expr_top_level_ref_get_item_add(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let field_and_literal = and(get_item("y", get_item("a", root())), lit(1));
        let partitioned =
            partition(field_and_literal, &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), 1);
    }

    /// Combining accesses to two different top-level fields produces two partitions.
    #[rstest]
    fn test_expr_top_level_ref_get_item_add_cannot_split(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let two_fields = and(get_item("y", get_item("a", root())), get_item("b", root()));
        let partitioned = partition(two_fields, &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);
    }

    /// A field referenced twice gets two entries in its partition, indexed `0` and `1`.
    #[rstest]
    fn test_expr_partition_many_occurrences_of_field(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = and(
            get_item("y", get_item("a", root())),
            select(["a", "b"], root()),
        );
        let expr = expr.optimize_recursive(&dtype).unwrap();
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);

        let second_a_field =
            StructFieldExpressionSplitter::<FieldName>::field_name(&"a".into(), 1);
        let expected_root = and(
            get_item("a_0", get_item("a", root())),
            pack(
                [
                    ("a", get_item(second_a_field, get_item("a", root()))),
                    ("b", get_item("b_0", get_item("b", root()))),
                ],
                NonNullable,
            ),
        );
        assert_eq!(partitioned.root, expected_root);
    }

    /// Partitioning a `merge` yields a packed root plus one partition each for `a` and `b`.
    #[rstest]
    fn test_expr_merge(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let merged = merge([col("a"), pack([("b", col("b"))], NonNullable)]);
        let partitioned = partition(merged, &dtype, annotate_scope_access(fields)).unwrap();

        let expected = pack(
            [
                ("x", get_item("x", get_item("a_0", col("a")))),
                ("y", get_item("y", get_item("a_0", col("a")))),
                ("b", get_item("b", get_item("b_0", col("b")))),
            ],
            NonNullable,
        );
        assert_eq!(
            &partitioned.root, &expected,
            "{} {}",
            partitioned.root, expected
        );

        assert_eq!(partitioned.partitions.len(), 2);

        let part_a = partitioned.find_partition(&"a".into()).unwrap();
        let expected_a = pack([("a_0", col("a"))], NonNullable);
        assert_eq!(part_a, &expected_a, "{part_a} {expected_a}");

        let part_b = partitioned.find_partition(&"b".into()).unwrap();
        let inner_b = pack([("b", col("b"))], NonNullable);
        let expected_b = pack([("b_0", inner_b)], NonNullable);
        assert_eq!(part_b, &expected_b, "{part_b} {expected_b}");
    }
}