1use std::fmt::{Display, Formatter};
5
6use itertools::Itertools;
7use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
8use vortex_error::{VortexExpect, VortexResult};
9use vortex_utils::aliases::hash_map::HashMap;
10
11use crate::transform::annotations::{
12 Annotation, AnnotationFn, Annotations, descendent_annotations,
13};
14use crate::transform::simplify_typed::simplify_typed;
15use crate::traversal::{NodeExt, NodeRewriter, Transformed, TraversalOrder};
16use crate::{ExprRef, get_item, pack, root};
17
18pub fn partition<A: AnnotationFn>(
30 expr: ExprRef,
31 scope: &DType,
32 annotate_fn: A,
33) -> VortexResult<PartitionedExpr<A::Annotation>>
34where
35 A::Annotation: Display,
36{
37 let annotations = descendent_annotations(&expr, annotate_fn);
39
40 let mut splitter = StructFieldExpressionSplitter::<A::Annotation>::new(&annotations);
43 let root = expr.clone().rewrite(&mut splitter)?.value;
44
45 let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
46 let mut partition_annotations = Vec::with_capacity(splitter.sub_expressions.len());
47 let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
48
49 for (annotation, exprs) in splitter.sub_expressions.into_iter() {
50 let expr = pack(
52 exprs.into_iter().enumerate().map(|(idx, expr)| {
53 (
54 StructFieldExpressionSplitter::field_name(&annotation, idx),
55 expr,
56 )
57 }),
58 Nullability::NonNullable,
59 );
60
61 let expr = simplify_typed(expr.clone(), scope)?;
62 let expr_dtype = expr.return_dtype(scope)?;
63
64 partitions.push(expr);
65 partition_annotations.push(annotation);
66 partition_dtypes.push(expr_dtype);
67 }
68
69 let partition_names =
70 FieldNames::from_iter(partition_annotations.iter().map(|id| id.to_string()));
71 let root_scope = DType::Struct(
72 StructFields::new(partition_names.clone(), partition_dtypes.clone()),
73 Nullability::NonNullable,
74 );
75
76 Ok(PartitionedExpr {
77 root: simplify_typed(root, &root_scope)?,
78 partitions: partitions.into_boxed_slice(),
79 partition_names,
80 partition_dtypes: partition_dtypes.into_boxed_slice(),
81 partition_annotations: partition_annotations.into_boxed_slice(),
82 })
83}
84
/// The result of [`partition`]: a set of per-annotation partition expressions plus a root
/// expression that recombines them.
///
/// `partitions`, `partition_names`, `partition_dtypes` and `partition_annotations` are
/// index-aligned: entry `i` of each describes the same partition.
#[derive(Debug)]
pub struct PartitionedExpr<A> {
    /// Expression evaluated against a struct scope that has one field per partition, each
    /// field named by the matching entry of `partition_names`.
    pub root: ExprRef,
    /// One expression per partition, simplified and typed against the original input scope.
    pub partitions: Box<[ExprRef]>,
    /// Name of each partition; as built by [`partition`], this is the annotation's `Display`
    /// output.
    pub partition_names: FieldNames,
    /// Return dtype of each partition expression against the original input scope.
    pub partition_dtypes: Box<[DType]>,
    /// The annotation each partition was created for.
    pub partition_annotations: Box<[A]>,
}
99
100impl<A: Display> Display for PartitionedExpr<A> {
101 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
102 write!(
103 f,
104 "root: {} {{{}}}",
105 self.root,
106 self.partition_names
107 .iter()
108 .zip(self.partitions.iter())
109 .map(|(name, partition)| format!("{name}: {partition}"))
110 .join(", ")
111 )
112 }
113}
114
115impl<A: Annotation + Display> PartitionedExpr<A> {
116 pub fn find_partition(&self, id: &A) -> Option<&ExprRef> {
119 let id = FieldName::from(id.to_string());
120 self.partition_names
121 .iter()
122 .position(|field| field == &id)
123 .map(|idx| &self.partitions[idx])
124 }
125}
126
/// A [`NodeRewriter`] that extracts subtrees attributed to exactly one annotation.
///
/// Each extracted subtree is replaced in the rewritten expression by a reference into that
/// annotation's partition, and the splitter does not descend into it.
#[derive(Debug)]
struct StructFieldExpressionSplitter<'a, A: Annotation> {
    /// Per-node annotation sets; in `partition` these come from `descendent_annotations`.
    annotations: &'a Annotations<'a, A>,
    /// Extracted sub-expressions grouped by annotation, in extraction order. A sub-expression's
    /// position in its `Vec` is the `idx` used to build its field name.
    sub_expressions: HashMap<A, Vec<ExprRef>>,
}
132
133impl<'a, A: Annotation + Display> StructFieldExpressionSplitter<'a, A> {
134 fn new(annotations: &'a Annotations<'a, A>) -> Self {
135 Self {
136 sub_expressions: HashMap::new(),
137 annotations,
138 }
139 }
140
141 fn field_name(annotation: &A, idx: usize) -> FieldName {
144 format!("{annotation}_{idx}").into()
145 }
146}
147
148impl<A: Annotation + Display> NodeRewriter for StructFieldExpressionSplitter<'_, A> {
149 type NodeTy = ExprRef;
150
151 fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
152 match self.annotations.get(&node) {
153 Some(annotations) if annotations.len() == 1 => {
155 let annotation = annotations
156 .iter()
157 .next()
158 .vortex_expect("expected one field");
159 let sub_exprs = self.sub_expressions.entry(annotation.clone()).or_default();
160 let idx = sub_exprs.len();
161 sub_exprs.push(node.clone());
162 let value = get_item(
163 StructFieldExpressionSplitter::field_name(annotation, idx),
164 get_item(FieldName::from(annotation.to_string()), root()),
165 );
166 Ok(Transformed {
167 value,
168 changed: true,
169 order: TraversalOrder::Skip,
170 })
171 }
172
173 _ => Ok(Transformed::no(node)),
175 }
176 }
177
178 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
179 Ok(Transformed::no(node))
180 }
181}
182
#[cfg(test)]
mod tests {
    use rstest::{fixture, rstest};
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, StructFields};

    use super::*;
    use crate::transform::immediate_access::annotate_scope_access;
    use crate::transform::replace::replace_root_fields;
    use crate::transform::simplify::simplify;
    use crate::transform::simplify_typed::simplify_typed;
    use crate::{and, col, get_item, lit, merge, pack, root, select};

    /// Scope used by all tests: `{a: {x: i32, y: i32}, b: i32, c: i32}`, all non-nullable.
    #[fixture]
    fn dtype() -> DType {
        DType::Struct(
            StructFields::from_iter([
                (
                    "a",
                    DType::Struct(
                        StructFields::from_iter([("x", I32.into()), ("y", DType::from(I32))]),
                        NonNullable,
                    ),
                ),
                ("b", I32.into()),
                ("c", I32.into()),
            ]),
            NonNullable,
        )
    }

    /// A bare `root()` produces no partitions. Once the root is expanded into its
    /// top-level fields, each field becomes its own partition.
    #[rstest]
    fn test_expr_top_level_ref(dtype: DType) {
        let fields = dtype.as_struct().unwrap();

        let expr = root();
        let partitioned = partition(expr.clone(), &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), 0);
        // With nothing to extract, the root expression is left untouched.
        assert_eq!(&partitioned.root, &root());

        // Expand `root()` into explicit references to each top-level field.
        let expr = replace_root_fields(expr.clone(), fields);
        let partitioned = partition(expr.clone(), &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), fields.names().len());
    }

    /// An access into a single field is moved into that field's partition, and the root
    /// reads it back as `root.a.a_0`.
    #[rstest]
    fn test_expr_top_level_ref_get_item_and_split(dtype: DType) {
        let fields = dtype.as_struct().unwrap();

        let expr = get_item("y", get_item("a", root()));

        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
        assert_eq!(&partitioned.root, &get_item("a_0", get_item("a", root())));
    }

    /// Several accesses into the same field are collected into one partition as
    /// `a_0`, `a_1`, ... in visit order.
    #[rstest]
    fn test_expr_top_level_ref_get_item_and_split_pack(dtype: DType) {
        let fields = dtype.as_struct().unwrap();

        let expr = pack(
            [
                ("x", get_item("x", get_item("a", root()))),
                ("y", get_item("y", get_item("a", root()))),
                ("c", get_item("c", root())),
            ],
            NonNullable,
        );
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        let split_a = partitioned.find_partition(&"a".into()).unwrap();
        assert_eq!(
            &simplify(split_a.clone()).unwrap(),
            &pack(
                [
                    ("a_0", get_item("x", get_item("a", root()))),
                    ("a_1", get_item("y", get_item("a", root())))
                ],
                NonNullable
            )
        );
    }

    /// An expression combining one field with a literal yields a single partition.
    #[rstest]
    fn test_expr_top_level_ref_get_item_add(dtype: DType) {
        let fields = dtype.as_struct().unwrap();

        let expr = and(get_item("y", get_item("a", root())), lit(1));
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), 1);
    }

    /// An expression combining two different fields cannot be kept in one partition,
    /// so each field gets its own.
    #[rstest]
    fn test_expr_top_level_ref_get_item_add_cannot_split(dtype: DType) {
        let fields = dtype.as_struct().unwrap();

        let expr = and(get_item("y", get_item("a", root())), get_item("b", root()));
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);
    }

    /// The same field referenced from several places gets one partition with one
    /// indexed entry per occurrence.
    #[rstest]
    fn test_expr_partition_many_occurrences_of_field(dtype: DType) {
        let fields = dtype.as_struct().unwrap();

        let expr = and(
            get_item("y", get_item("a", root())),
            select(vec!["a".into(), "b".into()], root()),
        );
        let expr = simplify_typed(expr, &dtype).unwrap();
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        // One partition for `a` and one for `b`.
        assert_eq!(partitioned.partitions.len(), 2);

        assert_eq!(
            // The second occurrence of `a` (from the simplified `select`) is read back as
            // `a_1`; `b` only occurs once, as `b_0`.
            &partitioned.root,
            &and(
                get_item("a_0", get_item("a", root())),
                pack(
                    [
                        (
                            "a",
                            get_item(
                                StructFieldExpressionSplitter::<FieldName>::field_name(
                                    &"a".into(),
                                    1
                                ),
                                get_item("a", root())
                            )
                        ),
                        ("b", get_item("b_0", get_item("b", root())))
                    ],
                    NonNullable
                )
            )
        )
    }

    /// Partitioning a `merge`: each merged input becomes its own partition, and the root
    /// rebuilds the merged struct from the partition fields.
    #[rstest]
    fn test_expr_merge(dtype: DType) {
        let fields = dtype.as_struct().unwrap();

        let expr = merge(
            [col("a"), pack([("b", col("b"))], NonNullable)],
            NonNullable,
        );

        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
        let expected = pack(
            [
                ("x", get_item("x", get_item("a_0", col("a")))),
                ("y", get_item("y", get_item("a_0", col("a")))),
                ("b", get_item("b", get_item("b_0", col("b")))),
            ],
            NonNullable,
        );
        assert_eq!(
            &partitioned.root, &expected,
            "{} {}",
            partitioned.root, expected
        );

        assert_eq!(partitioned.partitions.len(), 2);

        let part_a = partitioned.find_partition(&"a".into()).unwrap();
        let expected_a = pack([("a_0", col("a"))], NonNullable);
        assert_eq!(part_a, &expected_a, "{part_a} {expected_a}");

        let part_b = partitioned.find_partition(&"b".into()).unwrap();
        let expected_b = pack([("b_0", pack([("b", col("b"))], NonNullable))], NonNullable);
        assert_eq!(part_b, &expected_b, "{part_b} {expected_b}");
    }
}