1use std::fmt::{Display, Formatter};
5
6use itertools::Itertools;
7use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
8use vortex_error::{VortexExpect, VortexResult};
9use vortex_utils::aliases::hash_map::HashMap;
10
11use crate::transform::annotations::{
12 Annotation, AnnotationFn, Annotations, descendent_annotations,
13};
14use crate::transform::simplify_typed::simplify_typed;
15use crate::traversal::{NodeExt, NodeRewriter, Transformed, TraversalOrder};
16use crate::{ExprRef, get_item, pack, root};
17
18pub fn partition<A: AnnotationFn>(
30 expr: ExprRef,
31 scope: &DType,
32 annotate_fn: A,
33) -> VortexResult<PartitionedExpr<A::Annotation>>
34where
35 A::Annotation: Display,
36 FieldName: From<A::Annotation>,
37{
38 let annotations = descendent_annotations(&expr, annotate_fn);
40
41 let mut splitter = StructFieldExpressionSplitter::<A::Annotation>::new(&annotations);
44 let root = expr.clone().rewrite(&mut splitter)?.value;
45
46 let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
47 let mut partition_annotations = Vec::with_capacity(splitter.sub_expressions.len());
48 let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
49
50 for (annotation, exprs) in splitter.sub_expressions.into_iter() {
51 let expr = pack(
53 exprs.into_iter().enumerate().map(|(idx, expr)| {
54 (
55 StructFieldExpressionSplitter::field_name(&annotation, idx),
56 expr,
57 )
58 }),
59 Nullability::NonNullable,
60 );
61
62 let expr = simplify_typed(expr.clone(), scope)?;
63 let expr_dtype = expr.return_dtype(scope)?;
64
65 partitions.push(expr);
66 partition_annotations.push(annotation);
67 partition_dtypes.push(expr_dtype);
68 }
69
70 let partition_names = partition_annotations
71 .iter()
72 .map(|id| FieldName::from(id.clone()))
73 .collect::<FieldNames>();
74 let root_scope = DType::Struct(
75 StructFields::new(partition_names.clone(), partition_dtypes.clone()),
76 Nullability::NonNullable,
77 );
78
79 Ok(PartitionedExpr {
80 root: simplify_typed(root, &root_scope)?,
81 partitions: partitions.into_boxed_slice(),
82 partition_names,
83 partition_dtypes: partition_dtypes.into_boxed_slice(),
84 partition_annotations: partition_annotations.into_boxed_slice(),
85 })
86}
87
/// Result of [`partition`]: a root expression plus one expression per partition.
///
/// As produced by [`partition`], `partitions`, `partition_names`, `partition_dtypes` and
/// `partition_annotations` are parallel: index `i` in each describes the same partition.
#[derive(Debug)]
pub struct PartitionedExpr<A> {
    /// Expression that recombines the partitions. It is typed against a struct scope whose
    /// fields are `partition_names` with types `partition_dtypes`, and reads each partition's
    /// outputs via `get_item`.
    pub root: ExprRef,
    /// One expression per partition, typed against the scope originally passed to [`partition`].
    pub partitions: Box<[ExprRef]>,
    /// Field name of each partition in the root's scope.
    pub partition_names: FieldNames,
    /// Return dtype of each partition expression.
    pub partition_dtypes: Box<[DType]>,
    /// The annotation each partition was built for.
    pub partition_annotations: Box<[A]>,
}
102
103impl<A: Display> Display for PartitionedExpr<A> {
104 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
105 write!(
106 f,
107 "root: {} {{{}}}",
108 self.root,
109 self.partition_names
110 .iter()
111 .zip(self.partitions.iter())
112 .map(|(name, partition)| format!("{name}: {partition}"))
113 .join(", ")
114 )
115 }
116}
117
118impl<A: Annotation> PartitionedExpr<A>
119where
120 FieldName: From<A>,
121{
122 pub fn find_partition(&self, id: &A) -> Option<&ExprRef> {
125 let id = FieldName::from(id.clone());
126 self.partition_names
127 .iter()
128 .position(|field| field == id)
129 .map(|idx| &self.partitions[idx])
130 }
131}
132
/// Rewriter used by [`partition`] to lift single-annotation sub-expressions out of an
/// expression tree, grouping them by annotation.
#[derive(Debug)]
struct StructFieldExpressionSplitter<'a, A: Annotation> {
    /// Per-node annotations, as computed by `descendent_annotations` in [`partition`].
    annotations: &'a Annotations<'a, A>,
    /// Lifted sub-expressions grouped by annotation, in the order the rewrite encountered them.
    /// A sub-expression's position in its `Vec` is the `<idx>` in its field name
    /// `<annotation>_<idx>`.
    sub_expressions: HashMap<A, Vec<ExprRef>>,
}
138
139impl<'a, A: Annotation + Display> StructFieldExpressionSplitter<'a, A> {
140 fn new(annotations: &'a Annotations<'a, A>) -> Self {
141 Self {
142 sub_expressions: HashMap::new(),
143 annotations,
144 }
145 }
146
147 fn field_name(annotation: &A, idx: usize) -> FieldName {
150 format!("{annotation}_{idx}").into()
151 }
152}
153
154impl<A: Annotation + Display> NodeRewriter for StructFieldExpressionSplitter<'_, A>
155where
156 FieldName: From<A>,
157{
158 type NodeTy = ExprRef;
159
160 fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
161 match self.annotations.get(&node) {
162 Some(annotations) if annotations.len() == 1 => {
164 let annotation = annotations
165 .iter()
166 .next()
167 .vortex_expect("expected one field");
168 let sub_exprs = self.sub_expressions.entry(annotation.clone()).or_default();
169 let idx = sub_exprs.len();
170 sub_exprs.push(node.clone());
171 let value = get_item(
172 StructFieldExpressionSplitter::field_name(annotation, idx),
173 get_item(FieldName::from(annotation.clone()), root()),
174 );
175 Ok(Transformed {
176 value,
177 changed: true,
178 order: TraversalOrder::Skip,
179 })
180 }
181
182 _ => Ok(Transformed::no(node)),
184 }
185 }
186
187 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
188 Ok(Transformed::no(node))
189 }
190}
191
#[cfg(test)]
mod tests {
    use rstest::{fixture, rstest};
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, StructFields};

    use super::*;
    use crate::transform::immediate_access::annotate_scope_access;
    use crate::transform::replace::replace_root_fields;
    use crate::transform::simplify::simplify;
    use crate::transform::simplify_typed::simplify_typed;
    use crate::{and, col, get_item, lit, merge, pack, root, select};

    /// Scope shared by the tests: a non-nullable struct `{a: {x: i32, y: i32}, b: i32, c: i32}`.
    #[fixture]
    fn dtype() -> DType {
        DType::Struct(
            StructFields::from_iter([
                (
                    "a",
                    DType::Struct(
                        StructFields::from_iter([("x", I32.into()), ("y", DType::from(I32))]),
                        NonNullable,
                    ),
                ),
                ("b", I32.into()),
                ("c", I32.into()),
            ]),
            NonNullable,
        )
    }

    #[rstest]
    fn test_expr_top_level_ref(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        // A bare `root()` is not lifted into any partition and stays as the root unchanged.
        let expr = root();
        let partitioned = partition(expr.clone(), &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), 0);
        assert_eq!(&partitioned.root, &root());

        // After `replace_root_fields` rewrites the root in terms of the scope's fields, each
        // top-level field ends up in its own partition.
        let expr = replace_root_fields(expr.clone(), fields);
        let partitioned = partition(expr.clone(), &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), fields.names().len());
    }

    #[rstest]
    fn test_expr_top_level_ref_get_item_and_split(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        // `a.y` touches only field `a`, so the whole expression is lifted and the root just
        // reads field `a_0` of partition `a`.
        let expr = get_item("y", get_item("a", root()));

        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
        assert_eq!(&partitioned.root, &get_item("a_0", get_item("a", root())));
    }

    #[rstest]
    fn test_expr_top_level_ref_get_item_and_split_pack(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = pack(
            [
                ("x", get_item("x", get_item("a", root()))),
                ("y", get_item("y", get_item("a", root()))),
                ("c", get_item("c", root())),
            ],
            NonNullable,
        );
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        // The two sub-expressions under `a` are numbered in traversal order: `a.x` -> a_0,
        // `a.y` -> a_1.
        let split_a = partitioned.find_partition(&"a".into()).unwrap();
        assert_eq!(
            &simplify(split_a.clone()).unwrap(),
            &pack(
                [
                    ("a_0", get_item("x", get_item("a", root()))),
                    ("a_1", get_item("y", get_item("a", root())))
                ],
                NonNullable
            )
        );
    }

    #[rstest]
    fn test_expr_top_level_ref_get_item_add(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        // Combining a single field access with a literal still yields only one partition.
        let expr = and(get_item("y", get_item("a", root())), lit(1));
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), 1);
    }

    #[rstest]
    fn test_expr_top_level_ref_get_item_add_cannot_split(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        // Accessing two different top-level fields produces one partition per field.
        let expr = and(get_item("y", get_item("a", root())), get_item("b", root()));
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);
    }

    #[rstest]
    fn test_expr_partition_many_occurrences_of_field(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        // Field `a` is used twice: directly via `a.y` and again through the `select`.
        let expr = and(
            get_item("y", get_item("a", root())),
            select(["a", "b"], root()),
        );
        let expr = simplify_typed(expr, &dtype).unwrap();
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);

        // The expected root shows the `select` as a pack of per-field reads (presumably
        // produced by `simplify_typed` above). The second use of `a` gets index 1.
        assert_eq!(
            &partitioned.root,
            &and(
                get_item("a_0", get_item("a", root())),
                pack(
                    [
                        (
                            "a",
                            get_item(
                                StructFieldExpressionSplitter::<FieldName>::field_name(
                                    &"a".into(),
                                    1
                                ),
                                get_item("a", root())
                            )
                        ),
                        ("b", get_item("b_0", get_item("b", root())))
                    ],
                    NonNullable
                )
            )
        )
    }

    #[rstest]
    fn test_expr_merge(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        // Merge the struct `a` with a single-field pack built from `b`.
        let expr = merge([col("a"), pack([("b", col("b"))], NonNullable)]);

        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
        // In the root, `col("a")`/`col("b")` name partitions. `x` and `y` are read from
        // partition `a`'s `a_0`, and `b` from partition `b`'s `b_0`.
        let expected = pack(
            [
                ("x", get_item("x", get_item("a_0", col("a")))),
                ("y", get_item("y", get_item("a_0", col("a")))),
                ("b", get_item("b", get_item("b_0", col("b")))),
            ],
            NonNullable,
        );
        assert_eq!(
            &partitioned.root, &expected,
            "{} {}",
            partitioned.root, expected
        );

        assert_eq!(partitioned.partitions.len(), 2);

        let part_a = partitioned.find_partition(&"a".into()).unwrap();
        let expected_a = pack([("a_0", col("a"))], NonNullable);
        assert_eq!(part_a, &expected_a, "{part_a} {expected_a}");

        // The whole `pack(b: $.b)` operand depends only on `b`, so it is lifted as one unit.
        let part_b = partitioned.find_partition(&"b".into()).unwrap();
        let expected_b = pack([("b_0", pack([("b", col("b"))], NonNullable))], NonNullable);
        assert_eq!(part_b, &expected_b, "{part_b} {expected_b}");
    }
}