vortex_array/expr/transform/
partition.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use itertools::Itertools;
8use vortex_dtype::DType;
9use vortex_dtype::FieldName;
10use vortex_dtype::FieldNames;
11use vortex_dtype::Nullability;
12use vortex_dtype::StructFields;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_utils::aliases::hash_map::HashMap;
16
17use crate::expr::Expression;
18use crate::expr::analysis::Annotation;
19use crate::expr::analysis::AnnotationFn;
20use crate::expr::analysis::Annotations;
21use crate::expr::analysis::descendent_annotations;
22use crate::expr::exprs::get_item::get_item;
23use crate::expr::exprs::pack::pack;
24use crate::expr::exprs::root::root;
25use crate::expr::traversal::NodeExt;
26use crate::expr::traversal::NodeRewriter;
27use crate::expr::traversal::Transformed;
28use crate::expr::traversal::TraversalOrder;
29
30/// Partition an expression into sub-expressions that are uniquely associated with an annotation.
31/// A root expression is also returned that can be used to recombine the results of the partitions
32/// into the result of the original expression.
33///
34/// ## Note
35///
36/// This function currently respects the validity of each field in the scope, but the not validity
37/// of the scope itself. The fix would be for the returned `PartitionedExpr` to include a partition
38/// expression for computing the validity, or to include that expression as part of the root.
39///
40/// See <https://github.com/vortex-data/vortex/issues/1907>.
41pub fn partition<A: AnnotationFn>(
42    expr: Expression,
43    scope: &DType,
44    annotate_fn: A,
45) -> VortexResult<PartitionedExpr<A::Annotation>>
46where
47    A::Annotation: Display,
48    FieldName: From<A::Annotation>,
49{
50    // Annotate each expression with the annotations that any of its descendent expressions have.
51    let annotations = descendent_annotations(&expr, annotate_fn);
52
53    // Now we split the original expression into sub-expressions based on the annotations, and
54    // generate a root expression to re-assemble the results.
55    let mut splitter = StructFieldExpressionSplitter::<A::Annotation>::new(&annotations);
56    let root = expr.clone().rewrite(&mut splitter)?.value;
57
58    let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
59    let mut partition_annotations = Vec::with_capacity(splitter.sub_expressions.len());
60    let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
61
62    for (annotation, exprs) in splitter.sub_expressions.into_iter() {
63        // We pack all sub-expressions for the same annotation into a single expression.
64        let expr = pack(
65            exprs.into_iter().enumerate().map(|(idx, expr)| {
66                (
67                    StructFieldExpressionSplitter::field_name(&annotation, idx),
68                    expr,
69                )
70            }),
71            Nullability::NonNullable,
72        );
73
74        let expr = expr.optimize_recursive(scope)?;
75        let expr_dtype = expr.return_dtype(scope)?;
76
77        partitions.push(expr);
78        partition_annotations.push(annotation);
79        partition_dtypes.push(expr_dtype);
80    }
81
82    let partition_names = partition_annotations
83        .iter()
84        .map(|id| FieldName::from(id.clone()))
85        .collect::<FieldNames>();
86    let root_scope = DType::Struct(
87        StructFields::new(partition_names.clone(), partition_dtypes.clone()),
88        Nullability::NonNullable,
89    );
90
91    Ok(PartitionedExpr {
92        root: root.optimize_recursive(&root_scope)?,
93        partitions: partitions.into_boxed_slice(),
94        partition_names,
95        partition_dtypes: partition_dtypes.into_boxed_slice(),
96        partition_annotations: partition_annotations.into_boxed_slice(),
97    })
98}
99
/// The result of partitioning an expression.
///
/// `partitions`, `partition_names`, `partition_dtypes` and `partition_annotations` are parallel
/// collections: entry `i` of each describes the same partition.
#[derive(Debug)]
pub struct PartitionedExpr<A> {
    /// The root expression used to re-assemble the results. Its scope is a non-nullable struct
    /// with one field per partition, named by `partition_names` and typed by `partition_dtypes`.
    pub root: Expression,
    /// The partition expressions themselves, each evaluated against the original scope.
    pub partitions: Box<[Expression]>,
    /// The field name of each partition as referenced in the root expression.
    pub partition_names: FieldNames,
    /// The return dtype of each partition expression.
    pub partition_dtypes: Box<[DType]>,
    /// The annotation associated with each partition.
    pub partition_annotations: Box<[A]>,
}
114
115impl<A: Display> Display for PartitionedExpr<A> {
116    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
117        write!(
118            f,
119            "root: {} {{{}}}",
120            self.root,
121            self.partition_names
122                .iter()
123                .zip(self.partitions.iter())
124                .map(|(name, partition)| format!("{name}: {partition}"))
125                .join(", ")
126        )
127    }
128}
129
130impl<A: Annotation> PartitionedExpr<A>
131where
132    FieldName: From<A>,
133{
134    /// Return the partition for a given field, if it exists.
135    // FIXME(ngates): this should return an iterator since an annotation may have multiple partitions.
136    pub fn find_partition(&self, id: &A) -> Option<&Expression> {
137        let id = FieldName::from(id.clone());
138        self.partition_names
139            .iter()
140            .position(|field| field == id)
141            .map(|idx| &self.partitions[idx])
142    }
143}
144
145#[derive(Debug)]
146struct StructFieldExpressionSplitter<'a, A: Annotation> {
147    annotations: &'a Annotations<'a, A>,
148    sub_expressions: HashMap<A, Vec<Expression>>,
149}
150
151impl<'a, A: Annotation + Display> StructFieldExpressionSplitter<'a, A> {
152    fn new(annotations: &'a Annotations<'a, A>) -> Self {
153        Self {
154            sub_expressions: HashMap::new(),
155            annotations,
156        }
157    }
158
159    /// Each annotation may be associated with multiple sub-expressions, so we need to
160    /// a unique name for each sub-expression.
161    fn field_name(annotation: &A, idx: usize) -> FieldName {
162        format!("{annotation}_{idx}").into()
163    }
164}
165
166impl<A: Annotation + Display> NodeRewriter for StructFieldExpressionSplitter<'_, A>
167where
168    FieldName: From<A>,
169{
170    type NodeTy = Expression;
171
172    fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
173        match self.annotations.get(&node) {
174            // If this expression only accesses a single field, then we can skip the children
175            Some(annotations) if annotations.len() == 1 => {
176                let annotation = annotations
177                    .iter()
178                    .next()
179                    .vortex_expect("expected one field");
180                let sub_exprs = self.sub_expressions.entry(annotation.clone()).or_default();
181                let idx = sub_exprs.len();
182                sub_exprs.push(node.clone());
183                let value = get_item(
184                    StructFieldExpressionSplitter::field_name(annotation, idx),
185                    get_item(FieldName::from(annotation.clone()), root()),
186                );
187                Ok(Transformed {
188                    value,
189                    changed: true,
190                    order: TraversalOrder::Skip,
191                })
192            }
193
194            // Otherwise, continue traversing.
195            _ => Ok(Transformed::no(node)),
196        }
197    }
198
199    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
200        Ok(Transformed::no(node))
201    }
202}
203
#[cfg(test)]
mod tests {
    use rstest::fixture;
    use rstest::rstest;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::StructFields;

    use super::*;
    use crate::expr::analysis::annotate_scope_access;
    use crate::expr::exprs::binary::and;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::merge::merge;
    use crate::expr::exprs::pack::pack;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::select;
    use crate::expr::transform::replace::replace_root_fields;

    /// Scope shared by the tests: a non-nullable struct `{a: {x, y}, b, c}` with `I32` leaves,
    /// where `a` is itself a non-nullable struct.
    #[fixture]
    fn dtype() -> DType {
        DType::Struct(
            StructFields::from_iter([
                (
                    "a",
                    DType::Struct(
                        StructFields::from_iter([("x", I32.into()), ("y", DType::from(I32))]),
                        NonNullable,
                    ),
                ),
                ("b", I32.into()),
                ("c", I32.into()),
            ]),
            NonNullable,
        )
    }

    // A bare `root()` is not split unless the caller first expands it into per-field accesses.
    #[rstest]
    fn test_expr_top_level_ref(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = root();
        let partitioned = partition(expr.clone(), &dtype, annotate_scope_access(fields)).unwrap();

        // An un-expanded root expression is annotated by all fields, but since it is a single
        // node it cannot be split: no partitions are produced and the root is returned as-is.
        assert_eq!(partitioned.partitions.len(), 0);
        assert_eq!(&partitioned.root, &root());

        // Instead, callers must expand the root expression themselves; afterwards every
        // top-level field gets its own partition.
        let expr = replace_root_fields(expr, fields);
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), fields.names().len());
    }

    // A nested field access touches a single top-level field, so the whole expression is lifted
    // and the root just reads it back from partition `a`.
    #[rstest]
    fn test_expr_top_level_ref_get_item_and_split(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = get_item("y", get_item("a", root()));

        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
        assert_eq!(&partitioned.root, &get_item("a_0", get_item("a", root())));
    }

    // Several sub-expressions on the same field are packed into one partition with fields
    // `a_0`, `a_1`, ... in traversal order.
    #[rstest]
    fn test_expr_top_level_ref_get_item_and_split_pack(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = pack(
            [
                ("x", get_item("x", get_item("a", root()))),
                ("y", get_item("y", get_item("a", root()))),
                ("c", get_item("c", root())),
            ],
            NonNullable,
        );
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        let split_a = partitioned.find_partition(&"a".into()).unwrap();
        assert_eq!(
            &split_a.optimize_recursive(&dtype).unwrap(),
            &pack(
                [
                    ("a_0", get_item("x", get_item("a", root()))),
                    ("a_1", get_item("y", get_item("a", root())))
                ],
                NonNullable
            )
        );
    }

    // Literals carry no annotation, so combining a field access with a literal stays in that
    // field's partition.
    #[rstest]
    fn test_expr_top_level_ref_get_item_add(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = and(get_item("y", get_item("a", root())), lit(1));
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        // Whole expr is a single split
        assert_eq!(partitioned.partitions.len(), 1);
    }

    // An expression that combines two different fields cannot be lifted as a whole; each side
    // goes to its own partition and the `and` stays in the root.
    #[rstest]
    fn test_expr_top_level_ref_get_item_add_cannot_split(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = and(get_item("y", get_item("a", root())), get_item("b", root()));
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        // One partition for `a` and one for `b`.
        assert_eq!(partitioned.partitions.len(), 2);
    }

    // Tests that optimizing before partitioning removes the `select`, so that partitioning only
    // touches the fields that are actually used.
    #[rstest]
    fn test_expr_partition_many_occurrences_of_field(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = and(
            get_item("y", get_item("a", root())),
            select(["a", "b"], root()),
        );
        let expr = expr.optimize_recursive(&dtype).unwrap();
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        // One partition for `a` and one for `b`.
        assert_eq!(partitioned.partitions.len(), 2);

        // The earlier optimization replaced `select` with a `pack` of `get_item`s, so the root
        // only references `a` (twice, as `a_0` and `a_1`) and `b`; the unused field `c` is not
        // fetched.
        assert_eq!(
            &partitioned.root,
            &and(
                get_item("a_0", get_item("a", root())),
                pack(
                    [
                        (
                            "a",
                            get_item(
                                StructFieldExpressionSplitter::<FieldName>::field_name(
                                    &"a".into(),
                                    1
                                ),
                                get_item("a", root())
                            )
                        ),
                        ("b", get_item("b_0", get_item("b", root())))
                    ],
                    NonNullable
                )
            )
        )
    }

    // `merge` spanning two fields is split into one partition per field, and the optimized root
    // re-expands the merged struct field by field.
    #[rstest]
    fn test_expr_merge(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = merge([col("a"), pack([("b", col("b"))], NonNullable)]);

        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
        let expected = pack(
            [
                ("x", get_item("x", get_item("a_0", col("a")))),
                ("y", get_item("y", get_item("a_0", col("a")))),
                ("b", get_item("b", get_item("b_0", col("b")))),
            ],
            NonNullable,
        );
        assert_eq!(
            &partitioned.root, &expected,
            "{} {}",
            partitioned.root, expected
        );

        assert_eq!(partitioned.partitions.len(), 2);

        let part_a = partitioned.find_partition(&"a".into()).unwrap();
        let expected_a = pack([("a_0", col("a"))], NonNullable);
        assert_eq!(part_a, &expected_a, "{part_a} {expected_a}");

        let part_b = partitioned.find_partition(&"b".into()).unwrap();
        let expected_b = pack([("b_0", pack([("b", col("b"))], NonNullable))], NonNullable);
        assert_eq!(part_b, &expected_b, "{part_b} {expected_b}");
    }
}