vortex_array/expr/transform/
partition.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Display, Formatter};
5
6use itertools::Itertools;
7use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
8use vortex_error::{VortexExpect, VortexResult};
9use vortex_utils::aliases::hash_map::HashMap;
10
11use crate::expr::Expression;
12use crate::expr::exprs::get_item::get_item;
13use crate::expr::exprs::pack::pack;
14use crate::expr::exprs::root::root;
15use crate::expr::transform::annotations::{
16    Annotation, AnnotationFn, Annotations, descendent_annotations,
17};
18use crate::expr::transform::simplify_typed::simplify_typed;
19use crate::expr::traversal::{NodeExt, NodeRewriter, Transformed, TraversalOrder};
20
21/// Partition an expression into sub-expressions that are uniquely associated with an annotation.
22/// A root expression is also returned that can be used to recombine the results of the partitions
23/// into the result of the original expression.
24///
25/// ## Note
26///
27/// This function currently respects the validity of each field in the scope, but the not validity
28/// of the scope itself. The fix would be for the returned `PartitionedExpr` to include a partition
29/// expression for computing the validity, or to include that expression as part of the root.
30///
31/// See <https://github.com/vortex-data/vortex/issues/1907>.
32pub fn partition<A: AnnotationFn>(
33    expr: Expression,
34    scope: &DType,
35    annotate_fn: A,
36) -> VortexResult<PartitionedExpr<A::Annotation>>
37where
38    A::Annotation: Display,
39    FieldName: From<A::Annotation>,
40{
41    // Annotate each expression with the annotations that any of its descendent expressions have.
42    let annotations = descendent_annotations(&expr, annotate_fn);
43
44    // Now we split the original expression into sub-expressions based on the annotations, and
45    // generate a root expression to re-assemble the results.
46    let mut splitter = StructFieldExpressionSplitter::<A::Annotation>::new(&annotations);
47    let root = expr.clone().rewrite(&mut splitter)?.value;
48
49    let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
50    let mut partition_annotations = Vec::with_capacity(splitter.sub_expressions.len());
51    let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
52
53    for (annotation, exprs) in splitter.sub_expressions.into_iter() {
54        // We pack all sub-expressions for the same annotation into a single expression.
55        let expr = pack(
56            exprs.into_iter().enumerate().map(|(idx, expr)| {
57                (
58                    StructFieldExpressionSplitter::field_name(&annotation, idx),
59                    expr,
60                )
61            }),
62            Nullability::NonNullable,
63        );
64
65        let expr = simplify_typed(expr.clone(), scope)?;
66        let expr_dtype = expr.return_dtype(scope)?;
67
68        partitions.push(expr);
69        partition_annotations.push(annotation);
70        partition_dtypes.push(expr_dtype);
71    }
72
73    let partition_names = partition_annotations
74        .iter()
75        .map(|id| FieldName::from(id.clone()))
76        .collect::<FieldNames>();
77    let root_scope = DType::Struct(
78        StructFields::new(partition_names.clone(), partition_dtypes.clone()),
79        Nullability::NonNullable,
80    );
81
82    Ok(PartitionedExpr {
83        root: simplify_typed(root, &root_scope)?,
84        partitions: partitions.into_boxed_slice(),
85        partition_names,
86        partition_dtypes: partition_dtypes.into_boxed_slice(),
87        partition_annotations: partition_annotations.into_boxed_slice(),
88    })
89}
90
91/// The result of partitioning an expression.
92#[derive(Debug)]
93pub struct PartitionedExpr<A> {
94    /// The root expression used to re-assemble the results.
95    pub root: Expression,
96    /// The partition expressions themselves.
97    pub partitions: Box<[Expression]>,
98    /// The field name of each partition as referenced in the root expression.
99    pub partition_names: FieldNames,
100    /// The return dtype of each partition expression.
101    pub partition_dtypes: Box<[DType]>,
102    /// The annotation associated with each partition.
103    pub partition_annotations: Box<[A]>,
104}
105
106impl<A: Display> Display for PartitionedExpr<A> {
107    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
108        write!(
109            f,
110            "root: {} {{{}}}",
111            self.root,
112            self.partition_names
113                .iter()
114                .zip(self.partitions.iter())
115                .map(|(name, partition)| format!("{name}: {partition}"))
116                .join(", ")
117        )
118    }
119}
120
121impl<A: Annotation> PartitionedExpr<A>
122where
123    FieldName: From<A>,
124{
125    /// Return the partition for a given field, if it exists.
126    // FIXME(ngates): this should return an iterator since an annotation may have multiple partitions.
127    pub fn find_partition(&self, id: &A) -> Option<&Expression> {
128        let id = FieldName::from(id.clone());
129        self.partition_names
130            .iter()
131            .position(|field| field == id)
132            .map(|idx| &self.partitions[idx])
133    }
134}
135
136#[derive(Debug)]
137struct StructFieldExpressionSplitter<'a, A: Annotation> {
138    annotations: &'a Annotations<'a, A>,
139    sub_expressions: HashMap<A, Vec<Expression>>,
140}
141
142impl<'a, A: Annotation + Display> StructFieldExpressionSplitter<'a, A> {
143    fn new(annotations: &'a Annotations<'a, A>) -> Self {
144        Self {
145            sub_expressions: HashMap::new(),
146            annotations,
147        }
148    }
149
150    /// Each annotation may be associated with multiple sub-expressions, so we need to
151    /// a unique name for each sub-expression.
152    fn field_name(annotation: &A, idx: usize) -> FieldName {
153        format!("{annotation}_{idx}").into()
154    }
155}
156
157impl<A: Annotation + Display> NodeRewriter for StructFieldExpressionSplitter<'_, A>
158where
159    FieldName: From<A>,
160{
161    type NodeTy = Expression;
162
163    fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
164        match self.annotations.get(&node) {
165            // If this expression only accesses a single field, then we can skip the children
166            Some(annotations) if annotations.len() == 1 => {
167                let annotation = annotations
168                    .iter()
169                    .next()
170                    .vortex_expect("expected one field");
171                let sub_exprs = self.sub_expressions.entry(annotation.clone()).or_default();
172                let idx = sub_exprs.len();
173                sub_exprs.push(node.clone());
174                let value = get_item(
175                    StructFieldExpressionSplitter::field_name(annotation, idx),
176                    get_item(FieldName::from(annotation.clone()), root()),
177                );
178                Ok(Transformed {
179                    value,
180                    changed: true,
181                    order: TraversalOrder::Skip,
182                })
183            }
184
185            // Otherwise, continue traversing.
186            _ => Ok(Transformed::no(node)),
187        }
188    }
189
190    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
191        Ok(Transformed::no(node))
192    }
193}
194
195#[cfg(test)]
196mod tests {
    // Tests for `partition`. `annotate_scope_access` annotates expressions with the top-level
    // fields of the `dtype` fixture that they reference.
197    use rstest::{fixture, rstest};
198    use vortex_dtype::Nullability::NonNullable;
199    use vortex_dtype::PType::I32;
200    use vortex_dtype::{DType, StructFields};
201
202    use super::*;
203    use crate::expr::exprs::binary::and;
204    use crate::expr::exprs::get_item::{col, get_item};
205    use crate::expr::exprs::literal::lit;
206    use crate::expr::exprs::merge::merge;
207    use crate::expr::exprs::pack::pack;
208    use crate::expr::exprs::root::root;
209    use crate::expr::exprs::select::select;
210    use crate::expr::transform::immediate_access::annotate_scope_access;
211    use crate::expr::transform::replace::replace_root_fields;
212    use crate::expr::transform::simplify::simplify;
213    use crate::expr::transform::simplify_typed::simplify_typed;
214
    // Scope with a nested struct field `a` (fields `x`, `y`) and two `I32` fields `b` and `c`.
215    #[fixture]
216    fn dtype() -> DType {
217        DType::Struct(
218            StructFields::from_iter([
219                (
220                    "a",
221                    DType::Struct(
222                        StructFields::from_iter([("x", I32.into()), ("y", DType::from(I32))]),
223                        NonNullable,
224                    ),
225                ),
226                ("b", I32.into()),
227                ("c", I32.into()),
228            ]),
229            NonNullable,
230        )
231    }
232
233    #[rstest]
234    fn test_expr_top_level_ref(dtype: DType) {
235        let fields = dtype.as_struct_fields_opt().unwrap();
236
237        let expr = root();
238        let partitioned = partition(expr.clone(), &dtype, annotate_scope_access(fields)).unwrap();
239
240        // An un-expanded root expression is annotated by all fields, but since it is a single node it cannot be split, so no partitions are produced and the root is left unchanged.
241        assert_eq!(partitioned.partitions.len(), 0);
242        assert_eq!(&partitioned.root, &root());
243
244        // Instead, callers must expand the root expression themselves.
245        let expr = replace_root_fields(expr, fields);
246        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
247
248        assert_eq!(partitioned.partitions.len(), fields.names().len());
249    }
250
251    #[rstest]
252    fn test_expr_top_level_ref_get_item_and_split(dtype: DType) {
253        let fields = dtype.as_struct_fields_opt().unwrap();
254
255        let expr = get_item("y", get_item("a", root()));
256
257        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
258        assert_eq!(&partitioned.root, &get_item("a_0", get_item("a", root())));
259    }
260
261    #[rstest]
262    fn test_expr_top_level_ref_get_item_and_split_pack(dtype: DType) {
263        let fields = dtype.as_struct_fields_opt().unwrap();
264
265        let expr = pack(
266            [
267                ("x", get_item("x", get_item("a", root()))),
268                ("y", get_item("y", get_item("a", root()))),
269                ("c", get_item("c", root())),
270            ],
271            NonNullable,
272        );
273        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
274
        // Both accesses into `a` are packed into the single `a` partition as `a_0` and `a_1`.
275        let split_a = partitioned.find_partition(&"a".into()).unwrap();
276        assert_eq!(
277            &simplify(split_a.clone()).unwrap(),
278            &pack(
279                [
280                    ("a_0", get_item("x", get_item("a", root()))),
281                    ("a_1", get_item("y", get_item("a", root())))
282                ],
283                NonNullable
284            )
285        );
286    }
287
288    #[rstest]
289    fn test_expr_top_level_ref_get_item_add(dtype: DType) {
290        let fields = dtype.as_struct_fields_opt().unwrap();
291
292        let expr = and(get_item("y", get_item("a", root())), lit(1));
293        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
294
295        // The whole expression only references `a` (the literal has no field), so it forms a single partition.
296        assert_eq!(partitioned.partitions.len(), 1);
297    }
298
299    #[rstest]
300    fn test_expr_top_level_ref_get_item_add_cannot_split(dtype: DType) {
301        let fields = dtype.as_struct_fields_opt().unwrap();
302
303        let expr = and(get_item("y", get_item("a", root())), get_item("b", root()));
304        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
305
306        // One partition for `a` and one for `b`.
307        assert_eq!(partitioned.partitions.len(), 2);
308    }
309
310    // Test that `simplify_typed` removes `select`, so that the partitioning is precise.
311    #[rstest]
312    fn test_expr_partition_many_occurrences_of_field(dtype: DType) {
313        let fields = dtype.as_struct_fields_opt().unwrap();
314
315        let expr = and(
316            get_item("y", get_item("a", root())),
317            select(["a", "b"], root()),
318        );
319        let expr = simplify_typed(expr, &dtype).unwrap();
320        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
321
322        // One partition for `a` and one for `b`.
323        assert_eq!(partitioned.partitions.len(), 2);
324
325        // Because `simplify_typed` has already replaced `select` with `get_item`s inside a `pack`,
326        // the root only references `a` and `b`; the unused field `c` is never fetched.
327        assert_eq!(
328            &partitioned.root,
329            &and(
330                get_item("a_0", get_item("a", root())),
331                pack(
332                    [
333                        (
334                            "a",
335                            get_item(
336                                StructFieldExpressionSplitter::<FieldName>::field_name(
337                                    &"a".into(),
338                                    1
339                                ),
340                                get_item("a", root())
341                            )
342                        ),
343                        ("b", get_item("b_0", get_item("b", root())))
344                    ],
345                    NonNullable
346                )
347            )
348        )
349    }
350
351    #[rstest]
352    fn test_expr_merge(dtype: DType) {
353        let fields = dtype.as_struct_fields_opt().unwrap();
354
355        let expr = merge([col("a"), pack([("b", col("b"))], NonNullable)]);
356
357        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
        // After simplification the root rebuilds the merged struct field by field from the
        // `a_0` and `b_0` partition outputs.
358        let expected = pack(
359            [
360                ("x", get_item("x", get_item("a_0", col("a")))),
361                ("y", get_item("y", get_item("a_0", col("a")))),
362                ("b", get_item("b", get_item("b_0", col("b")))),
363            ],
364            NonNullable,
365        );
366        assert_eq!(
367            &partitioned.root, &expected,
368            "{} {}",
369            partitioned.root, expected
370        );
371
372        assert_eq!(partitioned.partitions.len(), 2);
373
374        let part_a = partitioned.find_partition(&"a".into()).unwrap();
375        let expected_a = pack([("a_0", col("a"))], NonNullable);
376        assert_eq!(part_a, &expected_a, "{part_a} {expected_a}");
377
378        let part_b = partitioned.find_partition(&"b".into()).unwrap();
379        let expected_b = pack([("b_0", pack([("b", col("b"))], NonNullable))], NonNullable);
380        assert_eq!(part_b, &expected_b, "{part_b} {expected_b}");
381    }
382}