1use std::fmt::{Display, Formatter};
5
6use itertools::Itertools;
7use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
8use vortex_error::{VortexExpect, VortexResult};
9use vortex_utils::aliases::hash_map::HashMap;
10
11use crate::expr::Expression;
12use crate::expr::exprs::get_item::get_item;
13use crate::expr::exprs::pack::pack;
14use crate::expr::exprs::root::root;
15use crate::expr::transform::annotations::{
16 Annotation, AnnotationFn, Annotations, descendent_annotations,
17};
18use crate::expr::transform::simplify_typed::simplify_typed;
19use crate::expr::traversal::{NodeExt, NodeRewriter, Transformed, TraversalOrder};
20
21pub fn partition<A: AnnotationFn>(
33 expr: Expression,
34 scope: &DType,
35 annotate_fn: A,
36) -> VortexResult<PartitionedExpr<A::Annotation>>
37where
38 A::Annotation: Display,
39 FieldName: From<A::Annotation>,
40{
41 let annotations = descendent_annotations(&expr, annotate_fn);
43
44 let mut splitter = StructFieldExpressionSplitter::<A::Annotation>::new(&annotations);
47 let root = expr.clone().rewrite(&mut splitter)?.value;
48
49 let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
50 let mut partition_annotations = Vec::with_capacity(splitter.sub_expressions.len());
51 let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
52
53 for (annotation, exprs) in splitter.sub_expressions.into_iter() {
54 let expr = pack(
56 exprs.into_iter().enumerate().map(|(idx, expr)| {
57 (
58 StructFieldExpressionSplitter::field_name(&annotation, idx),
59 expr,
60 )
61 }),
62 Nullability::NonNullable,
63 );
64
65 let expr = simplify_typed(expr.clone(), scope)?;
66 let expr_dtype = expr.return_dtype(scope)?;
67
68 partitions.push(expr);
69 partition_annotations.push(annotation);
70 partition_dtypes.push(expr_dtype);
71 }
72
73 let partition_names = partition_annotations
74 .iter()
75 .map(|id| FieldName::from(id.clone()))
76 .collect::<FieldNames>();
77 let root_scope = DType::Struct(
78 StructFields::new(partition_names.clone(), partition_dtypes.clone()),
79 Nullability::NonNullable,
80 );
81
82 Ok(PartitionedExpr {
83 root: simplify_typed(root, &root_scope)?,
84 partitions: partitions.into_boxed_slice(),
85 partition_names,
86 partition_dtypes: partition_dtypes.into_boxed_slice(),
87 partition_annotations: partition_annotations.into_boxed_slice(),
88 })
89}
90
/// The result of [`partition`]: a root expression plus the partitions it reads from.
///
/// When produced by [`partition`], `partitions`, `partition_names`, `partition_dtypes` and
/// `partition_annotations` are index-aligned: entry `i` of each describes the same partition.
#[derive(Debug)]
pub struct PartitionedExpr<A> {
    /// Expression over the partition outputs. It is evaluated against a struct scope with
    /// one field per partition, named by `partition_names` and typed by `partition_dtypes`.
    pub root: Expression,
    /// One expression per partition, evaluated against the original (input) scope.
    pub partitions: Box<[Expression]>,
    /// The field name under which each partition appears in the scope of `root`.
    pub partition_names: FieldNames,
    /// The return dtype of each partition expression.
    pub partition_dtypes: Box<[DType]>,
    /// The annotation each partition was derived from.
    pub partition_annotations: Box<[A]>,
}
105
106impl<A: Display> Display for PartitionedExpr<A> {
107 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
108 write!(
109 f,
110 "root: {} {{{}}}",
111 self.root,
112 self.partition_names
113 .iter()
114 .zip(self.partitions.iter())
115 .map(|(name, partition)| format!("{name}: {partition}"))
116 .join(", ")
117 )
118 }
119}
120
121impl<A: Annotation> PartitionedExpr<A>
122where
123 FieldName: From<A>,
124{
125 pub fn find_partition(&self, id: &A) -> Option<&Expression> {
128 let id = FieldName::from(id.clone());
129 self.partition_names
130 .iter()
131 .position(|field| field == id)
132 .map(|idx| &self.partitions[idx])
133 }
134}
135
/// Expression rewriter used by [`partition`].
///
/// It lifts out sub-expressions that carry exactly one annotation and records them per
/// annotation.
#[derive(Debug)]
struct StructFieldExpressionSplitter<'a, A: Annotation> {
    /// Annotations computed for the expression being rewritten, looked up per node in
    /// `visit_down`.
    annotations: &'a Annotations<'a, A>,
    /// Lifted sub-expressions grouped by annotation. A sub-expression's position in its
    /// `Vec` is the `idx` used for its field name `field_name(annotation, idx)`.
    sub_expressions: HashMap<A, Vec<Expression>>,
}
141
142impl<'a, A: Annotation + Display> StructFieldExpressionSplitter<'a, A> {
143 fn new(annotations: &'a Annotations<'a, A>) -> Self {
144 Self {
145 sub_expressions: HashMap::new(),
146 annotations,
147 }
148 }
149
150 fn field_name(annotation: &A, idx: usize) -> FieldName {
153 format!("{annotation}_{idx}").into()
154 }
155}
156
157impl<A: Annotation + Display> NodeRewriter for StructFieldExpressionSplitter<'_, A>
158where
159 FieldName: From<A>,
160{
161 type NodeTy = Expression;
162
163 fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
164 match self.annotations.get(&node) {
165 Some(annotations) if annotations.len() == 1 => {
167 let annotation = annotations
168 .iter()
169 .next()
170 .vortex_expect("expected one field");
171 let sub_exprs = self.sub_expressions.entry(annotation.clone()).or_default();
172 let idx = sub_exprs.len();
173 sub_exprs.push(node.clone());
174 let value = get_item(
175 StructFieldExpressionSplitter::field_name(annotation, idx),
176 get_item(FieldName::from(annotation.clone()), root()),
177 );
178 Ok(Transformed {
179 value,
180 changed: true,
181 order: TraversalOrder::Skip,
182 })
183 }
184
185 _ => Ok(Transformed::no(node)),
187 }
188 }
189
190 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
191 Ok(Transformed::no(node))
192 }
193}
194
// Tests for `partition`, using `annotate_scope_access` over the `dtype` fixture as the
// annotation function.
#[cfg(test)]
mod tests {
    use rstest::{fixture, rstest};
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, StructFields};

    use super::*;
    use crate::expr::exprs::binary::and;
    use crate::expr::exprs::get_item::{col, get_item};
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::merge::merge;
    use crate::expr::exprs::pack::pack;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::select;
    use crate::expr::transform::immediate_access::annotate_scope_access;
    use crate::expr::transform::replace::replace_root_fields;
    use crate::expr::transform::simplify::simplify;
    use crate::expr::transform::simplify_typed::simplify_typed;

    // Scope shared by all tests: a non-nullable struct `{a: {x, y}, b, c}` whose leaves are
    // i32, and where the nested struct `a` is also non-nullable.
    #[fixture]
    fn dtype() -> DType {
        DType::Struct(
            StructFields::from_iter([
                (
                    "a",
                    DType::Struct(
                        StructFields::from_iter([("x", I32.into()), ("y", DType::from(I32))]),
                        NonNullable,
                    ),
                ),
                ("b", I32.into()),
                ("c", I32.into()),
            ]),
            NonNullable,
        )
    }

    // A bare `root()` is not split at all. Once the root is expanded into its fields, each
    // top-level field gets its own partition.
    #[rstest]
    fn test_expr_top_level_ref(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = root();
        let partitioned = partition(expr.clone(), &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), 0);
        // With nothing lifted, the root expression is left as `root()`.
        assert_eq!(&partitioned.root, &root());

        let expr = replace_root_fields(expr, fields);
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), fields.names().len());
    }

    // `$.a.y` is lifted whole into partition `a`; the root reads it back as `$.a.a_0`.
    #[rstest]
    fn test_expr_top_level_ref_get_item_and_split(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = get_item("y", get_item("a", root()));

        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
        assert_eq!(&partitioned.root, &get_item("a_0", get_item("a", root())));
    }

    // Two accesses into `a` inside one pack become fields `a_0` and `a_1` of partition `a`.
    #[rstest]
    fn test_expr_top_level_ref_get_item_and_split_pack(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = pack(
            [
                ("x", get_item("x", get_item("a", root()))),
                ("y", get_item("y", get_item("a", root()))),
                ("c", get_item("c", root())),
            ],
            NonNullable,
        );
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        let split_a = partitioned.find_partition(&"a".into()).unwrap();
        assert_eq!(
            &simplify(split_a.clone()).unwrap(),
            &pack(
                [
                    ("a_0", get_item("x", get_item("a", root()))),
                    ("a_1", get_item("y", get_item("a", root())))
                ],
                NonNullable
            )
        );
    }

    // Only field `a` is referenced (the other operand is a literal), so one partition results.
    #[rstest]
    fn test_expr_top_level_ref_get_item_add(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = and(get_item("y", get_item("a", root())), lit(1));
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), 1);
    }

    // The `and` spans fields `a` and `b`, so it cannot be lifted whole: each operand goes to
    // its own partition.
    #[rstest]
    fn test_expr_top_level_ref_get_item_add_cannot_split(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = and(get_item("y", get_item("a", root())), get_item("b", root()));
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);
    }

    // Field `a` is referenced twice (once directly, once through the select), so partition
    // `a` receives two entries, `a_0` and `a_1`.
    #[rstest]
    fn test_expr_partition_many_occurrences_of_field(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = and(
            get_item("y", get_item("a", root())),
            select(["a", "b"], root()),
        );
        let expr = simplify_typed(expr, &dtype).unwrap();
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);

        // The expected root below is the `select` after `simplify_typed` (asserted here as a
        // `pack`), with each field access replaced by a read from its partition.
        assert_eq!(
            &partitioned.root,
            &and(
                get_item("a_0", get_item("a", root())),
                pack(
                    [
                        (
                            "a",
                            get_item(
                                StructFieldExpressionSplitter::<FieldName>::field_name(
                                    &"a".into(),
                                    1
                                ),
                                get_item("a", root())
                            )
                        ),
                        ("b", get_item("b_0", get_item("b", root())))
                    ],
                    NonNullable
                )
            )
        )
    }

    // The merge itself stays in the root; each merged input is lifted into its own
    // partition.
    #[rstest]
    fn test_expr_merge(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = merge([col("a"), pack([("b", col("b"))], NonNullable)]);

        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
        let expected = pack(
            [
                ("x", get_item("x", get_item("a_0", col("a")))),
                ("y", get_item("y", get_item("a_0", col("a")))),
                ("b", get_item("b", get_item("b_0", col("b")))),
            ],
            NonNullable,
        );
        assert_eq!(
            &partitioned.root, &expected,
            "{} {}",
            partitioned.root, expected
        );

        assert_eq!(partitioned.partitions.len(), 2);

        let part_a = partitioned.find_partition(&"a".into()).unwrap();
        let expected_a = pack([("a_0", col("a"))], NonNullable);
        assert_eq!(part_a, &expected_a, "{part_a} {expected_a}");

        let part_b = partitioned.find_partition(&"b".into()).unwrap();
        let expected_b = pack([("b_0", pack([("b", col("b"))], NonNullable))], NonNullable);
        assert_eq!(part_b, &expected_b, "{part_b} {expected_b}");
    }
}