1use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7
8use itertools::Itertools;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_utils::aliases::hash_map::HashMap;
12
13use crate::dtype::DType;
14use crate::dtype::FieldName;
15use crate::dtype::FieldNames;
16use crate::dtype::Nullability;
17use crate::dtype::StructFields;
18use crate::expr::Expression;
19use crate::expr::analysis::Annotation;
20use crate::expr::analysis::AnnotationFn;
21use crate::expr::analysis::Annotations;
22use crate::expr::analysis::descendent_annotations;
23use crate::expr::get_item;
24use crate::expr::pack;
25use crate::expr::root;
26use crate::expr::traversal::NodeExt;
27use crate::expr::traversal::NodeRewriter;
28use crate::expr::traversal::Transformed;
29use crate::expr::traversal::TraversalOrder;
30
31pub fn partition<A: AnnotationFn>(
43 expr: Expression,
44 scope: &DType,
45 annotate_fn: A,
46) -> VortexResult<PartitionedExpr<A::Annotation>>
47where
48 A::Annotation: Display,
49 FieldName: From<A::Annotation>,
50{
51 let annotations = descendent_annotations(&expr, annotate_fn);
53 partition_annotations(expr.clone(), scope, annotations)
54}
55
56pub fn partition_annotations<A>(
57 expr: Expression,
58 scope: &DType,
59 annotations: Annotations<A>,
60) -> VortexResult<PartitionedExpr<A>>
61where
62 A: Display + Clone + Eq + Hash,
63 FieldName: From<A>,
64{
65 let mut splitter = StructFieldExpressionSplitter::<A>::new(&annotations);
68 let root = expr.rewrite(&mut splitter)?.value;
69
70 let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
71 let mut partition_annotations = Vec::with_capacity(splitter.sub_expressions.len());
72 let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
73
74 for (annotation, exprs) in splitter.sub_expressions.into_iter() {
75 let expr = pack(
77 exprs.into_iter().enumerate().map(|(idx, expr)| {
78 (
79 StructFieldExpressionSplitter::field_name(&annotation, idx),
80 expr,
81 )
82 }),
83 Nullability::NonNullable,
84 );
85
86 let expr = expr.optimize_recursive(scope)?;
87 let expr_dtype = expr.return_dtype(scope)?;
88
89 partitions.push(expr);
90 partition_annotations.push(annotation);
91 partition_dtypes.push(expr_dtype);
92 }
93
94 let partition_names = partition_annotations
95 .iter()
96 .map(|id| FieldName::from(id.clone()))
97 .collect::<FieldNames>();
98 let root_scope = DType::Struct(
99 StructFields::new(partition_names.clone(), partition_dtypes.clone()),
100 Nullability::NonNullable,
101 );
102
103 Ok(PartitionedExpr {
104 root: root.optimize_recursive(&root_scope)?,
105 partitions: partitions.into_boxed_slice(),
106 partition_names,
107 partition_dtypes: partition_dtypes.into_boxed_slice(),
108 partition_annotations: partition_annotations.into_boxed_slice(),
109 })
110}
111
/// The result of partitioning an expression by annotation.
///
/// As built by `partition_annotations`, the `partitions`, `partition_names`,
/// `partition_dtypes` and `partition_annotations` collections are parallel: entry `i` of each
/// one describes the same partition.
#[derive(Debug)]
pub struct PartitionedExpr<A> {
    /// The rewritten root expression, in which extracted sub-expressions have been replaced by
    /// `get_item` lookups into the partitions.
    pub root: Expression,
    /// One packed expression per annotation, holding the sub-expressions extracted for it.
    pub partitions: Box<[Expression]>,
    /// The name of each partition, obtained by converting its annotation into a `FieldName`.
    pub partition_names: FieldNames,
    /// The return dtype of each partition expression.
    pub partition_dtypes: Box<[DType]>,
    /// The annotation each partition was extracted for.
    pub partition_annotations: Box<[A]>,
}
126
127impl<A: Display> Display for PartitionedExpr<A> {
128 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
129 write!(
130 f,
131 "root: {} {{{}}}",
132 self.root,
133 self.partition_names
134 .iter()
135 .zip(self.partitions.iter())
136 .map(|(name, partition)| format!("{name}: {partition}"))
137 .join(", ")
138 )
139 }
140}
141
142impl<A: Annotation> PartitionedExpr<A>
143where
144 FieldName: From<A>,
145{
146 pub fn find_partition(&self, id: &A) -> Option<&Expression> {
149 let id = FieldName::from(id.clone());
150 self.partition_names
151 .iter()
152 .position(|field| field == id)
153 .map(|idx| &self.partitions[idx])
154 }
155}
156
/// Node rewriter that extracts the sub-expressions carrying exactly one annotation and groups
/// them by that annotation.
#[derive(Debug)]
struct StructFieldExpressionSplitter<'a, A: Annotation> {
    /// The annotation sets that are looked up for every visited expression node.
    annotations: &'a Annotations<'a, A>,
    /// The sub-expressions extracted so far, grouped by annotation in extraction order.
    sub_expressions: HashMap<A, Vec<Expression>>,
}
162
163impl<'a, A: Annotation + Display> StructFieldExpressionSplitter<'a, A> {
164 fn new(annotations: &'a Annotations<'a, A>) -> Self {
165 Self {
166 sub_expressions: HashMap::new(),
167 annotations,
168 }
169 }
170
171 fn field_name(annotation: &A, idx: usize) -> FieldName {
174 format!("{annotation}_{idx}").into()
175 }
176}
177
178impl<A: Annotation + Display> NodeRewriter for StructFieldExpressionSplitter<'_, A>
179where
180 FieldName: From<A>,
181{
182 type NodeTy = Expression;
183
184 fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
185 match self.annotations.get(&node) {
186 Some(annotations) if annotations.len() == 1 => {
188 let annotation = annotations
189 .iter()
190 .next()
191 .vortex_expect("expected one field");
192 let sub_exprs = self.sub_expressions.entry(annotation.clone()).or_default();
193 let idx = sub_exprs.len();
194 sub_exprs.push(node.clone());
195 let value = get_item(
196 StructFieldExpressionSplitter::field_name(annotation, idx),
197 get_item(FieldName::from(annotation.clone()), root()),
198 );
199 Ok(Transformed {
200 value,
201 changed: true,
202 order: TraversalOrder::Skip,
203 })
204 }
205
206 _ => Ok(Transformed::no(node)),
208 }
209 }
210
211 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
212 Ok(Transformed::no(node))
213 }
214}
215
#[cfg(test)]
mod tests {
    use rstest::fixture;
    use rstest::rstest;

    use super::*;
    use crate::dtype::DType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType::I32;
    use crate::dtype::StructFields;
    use crate::expr::analysis::make_free_field_annotator;
    use crate::expr::and;
    use crate::expr::col;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::merge;
    use crate::expr::pack;
    use crate::expr::root;
    use crate::expr::transform::replace::replace_root_fields;

    /// Scope shared by all tests: a non-nullable struct `{a: {x, y}, b, c}` whose leaf dtypes
    /// are all derived from `I32`.
    #[fixture]
    fn dtype() -> DType {
        let nested = DType::Struct(
            StructFields::from_iter([("x", I32.into()), ("y", DType::from(I32))]),
            NonNullable,
        );
        DType::Struct(
            StructFields::from_iter([("a", nested), ("b", I32.into()), ("c", I32.into())]),
            NonNullable,
        )
    }

    /// A bare `root()` produces no partitions and keeps its root; after `replace_root_fields`
    /// there is one partition per top-level field.
    #[rstest]
    fn test_expr_top_level_ref(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let bare_root = root();
        let unsplit =
            partition(bare_root.clone(), &dtype, make_free_field_annotator(fields)).unwrap();
        assert_eq!(unsplit.partitions.len(), 0);
        assert_eq!(&unsplit.root, &root());

        let expanded = replace_root_fields(bare_root, fields);
        let split = partition(expanded, &dtype, make_free_field_annotator(fields)).unwrap();
        assert_eq!(split.partitions.len(), fields.names().len());
    }

    /// A nested field access is extracted as a whole; the root reads it back as `a.a_0`.
    #[rstest]
    fn test_expr_top_level_ref_get_item_and_split(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let nested_access = get_item("y", get_item("a", root()));
        let result = partition(nested_access, &dtype, make_free_field_annotator(fields)).unwrap();

        let expected_root = get_item("a_0", get_item("a", root()));
        assert_eq!(&result.root, &expected_root);
    }

    /// Two accesses into `a` end up in the same partition as fields `a_0` and `a_1`.
    #[rstest]
    fn test_expr_top_level_ref_get_item_and_split_pack(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let packed = pack(
            [
                ("x", get_item("x", get_item("a", root()))),
                ("y", get_item("y", get_item("a", root()))),
                ("c", get_item("c", root())),
            ],
            NonNullable,
        );
        let result = partition(packed, &dtype, make_free_field_annotator(fields)).unwrap();

        let split_a = result.find_partition(&"a".into()).unwrap();
        let expected_a = pack(
            [
                ("a_0", get_item("x", get_item("a", root()))),
                ("a_1", get_item("y", get_item("a", root()))),
            ],
            NonNullable,
        );
        assert_eq!(&split_a.optimize_recursive(&dtype).unwrap(), &expected_a);
    }

    /// Combining a field access with a literal yields a single partition.
    #[rstest]
    fn test_expr_top_level_ref_get_item_add(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let field_and_literal = and(get_item("y", get_item("a", root())), lit(1));
        let result =
            partition(field_and_literal, &dtype, make_free_field_annotator(fields)).unwrap();

        assert_eq!(result.partitions.len(), 1);
    }

    /// Combining accesses to two different fields yields two partitions.
    #[rstest]
    fn test_expr_top_level_ref_get_item_add_cannot_split(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let two_fields = and(get_item("y", get_item("a", root())), get_item("b", root()));
        let result = partition(two_fields, &dtype, make_free_field_annotator(fields)).unwrap();

        assert_eq!(result.partitions.len(), 2);
    }

    /// A `merge` over `a` and a pack of `b` is split into partitions `a` and `b`, and the root
    /// re-assembles the merged struct from them.
    #[rstest]
    fn test_expr_merge(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let merged = merge([col("a"), pack([("b", col("b"))], NonNullable)]);
        let partitioned = partition(merged, &dtype, make_free_field_annotator(fields)).unwrap();

        let expected = pack(
            [
                ("x", get_item("x", get_item("a_0", col("a")))),
                ("y", get_item("y", get_item("a_0", col("a")))),
                ("b", get_item("b", get_item("b_0", col("b")))),
            ],
            NonNullable,
        );
        assert_eq!(
            &partitioned.root, &expected,
            "{} {}",
            partitioned.root, expected
        );
        assert_eq!(partitioned.partitions.len(), 2);

        let part_a = partitioned.find_partition(&"a".into()).unwrap();
        let expected_a = pack([("a_0", col("a"))], NonNullable);
        assert_eq!(part_a, &expected_a, "{part_a} {expected_a}");

        let part_b = partitioned.find_partition(&"b".into()).unwrap();
        let expected_b = pack([("b_0", pack([("b", col("b"))], NonNullable))], NonNullable);
        assert_eq!(part_b, &expected_b, "{part_b} {expected_b}");
    }
}