vortex_expr/transform/
partition.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Display, Formatter};
5
6use itertools::Itertools;
7use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
8use vortex_error::{VortexExpect, VortexResult};
9use vortex_utils::aliases::hash_map::HashMap;
10
11use crate::transform::annotations::{
12    Annotation, AnnotationFn, Annotations, descendent_annotations,
13};
14use crate::transform::simplify_typed::simplify_typed;
15use crate::traversal::{NodeExt, NodeRewriter, Transformed, TraversalOrder};
16use crate::{ExprRef, get_item, pack, root};
17
18/// Partition an expression into sub-expressions that are uniquely associated with an annotation.
19/// A root expression is also returned that can be used to recombine the results of the partitions
20/// into the result of the original expression.
21///
22/// ## Note
23///
24/// This function currently respects the validity of each field in the scope, but the not validity
25/// of the scope itself. The fix would be for the returned `PartitionedExpr` to include a partition
26/// expression for computing the validity, or to include that expression as part of the root.
27///
28/// See <https://github.com/vortex-data/vortex/issues/1907>.
29pub fn partition<A: AnnotationFn>(
30    expr: ExprRef,
31    scope: &DType,
32    annotate_fn: A,
33) -> VortexResult<PartitionedExpr<A::Annotation>>
34where
35    A::Annotation: Display,
36    FieldName: From<A::Annotation>,
37{
38    // Annotate each expression with the annotations that any of its descendent expressions have.
39    let annotations = descendent_annotations(&expr, annotate_fn);
40
41    // Now we split the original expression into sub-expressions based on the annotations, and
42    // generate a root expression to re-assemble the results.
43    let mut splitter = StructFieldExpressionSplitter::<A::Annotation>::new(&annotations);
44    let root = expr.clone().rewrite(&mut splitter)?.value;
45
46    let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
47    let mut partition_annotations = Vec::with_capacity(splitter.sub_expressions.len());
48    let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
49
50    for (annotation, exprs) in splitter.sub_expressions.into_iter() {
51        // We pack all sub-expressions for the same annotation into a single expression.
52        let expr = pack(
53            exprs.into_iter().enumerate().map(|(idx, expr)| {
54                (
55                    StructFieldExpressionSplitter::field_name(&annotation, idx),
56                    expr,
57                )
58            }),
59            Nullability::NonNullable,
60        );
61
62        let expr = simplify_typed(expr.clone(), scope)?;
63        let expr_dtype = expr.return_dtype(scope)?;
64
65        partitions.push(expr);
66        partition_annotations.push(annotation);
67        partition_dtypes.push(expr_dtype);
68    }
69
70    let partition_names = partition_annotations
71        .iter()
72        .map(|id| FieldName::from(id.clone()))
73        .collect::<FieldNames>();
74    let root_scope = DType::Struct(
75        StructFields::new(partition_names.clone(), partition_dtypes.clone()),
76        Nullability::NonNullable,
77    );
78
79    Ok(PartitionedExpr {
80        root: simplify_typed(root, &root_scope)?,
81        partitions: partitions.into_boxed_slice(),
82        partition_names,
83        partition_dtypes: partition_dtypes.into_boxed_slice(),
84        partition_annotations: partition_annotations.into_boxed_slice(),
85    })
86}
87
/// The result of partitioning an expression.
///
/// As produced by [`partition`], the `partitions`, `partition_names`, `partition_dtypes` and
/// `partition_annotations` slices are parallel: the entries at index `i` all describe the same
/// partition.
#[derive(Debug)]
pub struct PartitionedExpr<A> {
    /// The root expression used to re-assemble the results.
    ///
    /// [`partition`] simplifies it against a non-nullable struct scope whose fields are named by
    /// `partition_names` and typed by `partition_dtypes`, i.e. against the partition results.
    pub root: ExprRef,
    /// The partition expressions themselves.
    pub partitions: Box<[ExprRef]>,
    /// The field name of each partition as referenced in the root expression.
    pub partition_names: FieldNames,
    /// The return dtype of each partition expression.
    pub partition_dtypes: Box<[DType]>,
    /// The annotation associated with each partition.
    pub partition_annotations: Box<[A]>,
}
102
103impl<A: Display> Display for PartitionedExpr<A> {
104    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
105        write!(
106            f,
107            "root: {} {{{}}}",
108            self.root,
109            self.partition_names
110                .iter()
111                .zip(self.partitions.iter())
112                .map(|(name, partition)| format!("{name}: {partition}"))
113                .join(", ")
114        )
115    }
116}
117
118impl<A: Annotation> PartitionedExpr<A>
119where
120    FieldName: From<A>,
121{
122    /// Return the partition for a given field, if it exists.
123    // FIXME(ngates): this should return an iterator since an annotation may have multiple partitions.
124    pub fn find_partition(&self, id: &A) -> Option<&ExprRef> {
125        let id = FieldName::from(id.clone());
126        self.partition_names
127            .iter()
128            .position(|field| field == id)
129            .map(|idx| &self.partitions[idx])
130    }
131}
132
/// Expression rewriter that lifts every node carrying exactly one annotation out of the tree.
///
/// Each lifted node is recorded under its annotation and replaced with a reference into that
/// annotation's partition. Its children are then skipped, so only the outermost such nodes are
/// lifted.
#[derive(Debug)]
struct StructFieldExpressionSplitter<'a, A: Annotation> {
    /// Per-node annotation sets, consulted for each node visited on the way down.
    annotations: &'a Annotations<'a, A>,
    /// The lifted sub-expressions for each annotation, in the order they were encountered. A
    /// sub-expression's index in its `Vec` determines its field name within the partition.
    sub_expressions: HashMap<A, Vec<ExprRef>>,
}
138
139impl<'a, A: Annotation + Display> StructFieldExpressionSplitter<'a, A> {
140    fn new(annotations: &'a Annotations<'a, A>) -> Self {
141        Self {
142            sub_expressions: HashMap::new(),
143            annotations,
144        }
145    }
146
147    /// Each annotation may be associated with multiple sub-expressions, so we need to
148    /// a unique name for each sub-expression.
149    fn field_name(annotation: &A, idx: usize) -> FieldName {
150        format!("{annotation}_{idx}").into()
151    }
152}
153
154impl<A: Annotation + Display> NodeRewriter for StructFieldExpressionSplitter<'_, A>
155where
156    FieldName: From<A>,
157{
158    type NodeTy = ExprRef;
159
160    fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
161        match self.annotations.get(&node) {
162            // If this expression only accesses a single field, then we can skip the children
163            Some(annotations) if annotations.len() == 1 => {
164                let annotation = annotations
165                    .iter()
166                    .next()
167                    .vortex_expect("expected one field");
168                let sub_exprs = self.sub_expressions.entry(annotation.clone()).or_default();
169                let idx = sub_exprs.len();
170                sub_exprs.push(node.clone());
171                let value = get_item(
172                    StructFieldExpressionSplitter::field_name(annotation, idx),
173                    get_item(FieldName::from(annotation.clone()), root()),
174                );
175                Ok(Transformed {
176                    value,
177                    changed: true,
178                    order: TraversalOrder::Skip,
179                })
180            }
181
182            // Otherwise, continue traversing.
183            _ => Ok(Transformed::no(node)),
184        }
185    }
186
187    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
188        Ok(Transformed::no(node))
189    }
190}
191
#[cfg(test)]
mod tests {
    use rstest::{fixture, rstest};
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, StructFields};

    use super::*;
    use crate::transform::immediate_access::annotate_scope_access;
    use crate::transform::replace::replace_root_fields;
    use crate::transform::simplify::simplify;
    use crate::transform::simplify_typed::simplify_typed;
    use crate::{and, col, get_item, lit, merge, pack, root, select};

    // Scope shared by all tests. `a` is a non-nullable struct of two `I32` fields `x` and `y`;
    // `b` and `c` are `I32`.
    #[fixture]
    fn dtype() -> DType {
        DType::Struct(
            StructFields::from_iter([
                (
                    "a",
                    DType::Struct(
                        StructFields::from_iter([("x", I32.into()), ("y", DType::from(I32))]),
                        NonNullable,
                    ),
                ),
                ("b", I32.into()),
                ("c", I32.into()),
            ]),
            NonNullable,
        )
    }

    #[rstest]
    fn test_expr_top_level_ref(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = root();
        let partitioned = partition(expr.clone(), &dtype, annotate_scope_access(fields)).unwrap();

        // An un-expanded root expression is annotated by all fields. Because it carries more than
        // one annotation it cannot be lifted into a single partition, so no partitions are
        // produced and the root is returned unchanged.
        assert_eq!(partitioned.partitions.len(), 0);
        assert_eq!(&partitioned.root, &root());

        // Instead, callers must expand the root expression themselves (here with
        // `replace_root_fields`), which yields one partition per field.
        let expr = replace_root_fields(expr.clone(), fields);
        let partitioned = partition(expr.clone(), &dtype, annotate_scope_access(fields)).unwrap();

        assert_eq!(partitioned.partitions.len(), fields.names().len());
    }

    #[rstest]
    fn test_expr_top_level_ref_get_item_and_split(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = get_item("y", get_item("a", root()));

        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
        assert_eq!(&partitioned.root, &get_item("a_0", get_item("a", root())));
    }

    #[rstest]
    fn test_expr_top_level_ref_get_item_and_split_pack(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = pack(
            [
                ("x", get_item("x", get_item("a", root()))),
                ("y", get_item("y", get_item("a", root()))),
                ("c", get_item("c", root())),
            ],
            NonNullable,
        );
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        // Both accesses into `a` land in the same partition, as fields `a_0` and `a_1`.
        let split_a = partitioned.find_partition(&"a".into()).unwrap();
        assert_eq!(
            &simplify(split_a.clone()).unwrap(),
            &pack(
                [
                    ("a_0", get_item("x", get_item("a", root()))),
                    ("a_1", get_item("y", get_item("a", root())))
                ],
                NonNullable
            )
        );
    }

    #[rstest]
    fn test_expr_top_level_ref_get_item_add(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = and(get_item("y", get_item("a", root())), lit(1));
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        // Whole expr is a single split: the literal adds no annotation, so only `a` is accessed.
        assert_eq!(partitioned.partitions.len(), 1);
    }

    #[rstest]
    fn test_expr_top_level_ref_get_item_add_cannot_split(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = and(get_item("y", get_item("a", root())), get_item("b", root()));
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        // One partition for field `a` and one for field `b`.
        assert_eq!(partitioned.partitions.len(), 2);
    }

    // Test that `simplify_typed` rewrites `select` so that the partitioning stays precise.
    #[rstest]
    fn test_expr_partition_many_occurrences_of_field(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = and(
            get_item("y", get_item("a", root())),
            select(["a", "b"], root()),
        );
        let expr = simplify_typed(expr, &dtype).unwrap();
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        // One partition for field `a` and one for field `b`.
        assert_eq!(partitioned.partitions.len(), 2);

        // After `simplify_typed` the `select` is expressed as a `pack` of the selected fields, so
        // the root only references partitions `a` (twice: `a_0` and `a_1`) and `b`.
        assert_eq!(
            &partitioned.root,
            &and(
                get_item("a_0", get_item("a", root())),
                pack(
                    [
                        (
                            "a",
                            get_item(
                                StructFieldExpressionSplitter::<FieldName>::field_name(
                                    &"a".into(),
                                    1
                                ),
                                get_item("a", root())
                            )
                        ),
                        ("b", get_item("b_0", get_item("b", root())))
                    ],
                    NonNullable
                )
            )
        )
    }

    // `merge` of struct field `a` with a `pack` over `b`: each operand depends on a single field,
    // so each becomes its own partition and the root re-assembles the merged struct.
    #[rstest]
    fn test_expr_merge(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = merge([col("a"), pack([("b", col("b"))], NonNullable)]);

        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
        let expected = pack(
            [
                ("x", get_item("x", get_item("a_0", col("a")))),
                ("y", get_item("y", get_item("a_0", col("a")))),
                ("b", get_item("b", get_item("b_0", col("b")))),
            ],
            NonNullable,
        );
        assert_eq!(
            &partitioned.root, &expected,
            "{} {}",
            partitioned.root, expected
        );

        assert_eq!(partitioned.partitions.len(), 2);

        let part_a = partitioned.find_partition(&"a".into()).unwrap();
        let expected_a = pack([("a_0", col("a"))], NonNullable);
        assert_eq!(part_a, &expected_a, "{part_a} {expected_a}");

        let part_b = partitioned.find_partition(&"b".into()).unwrap();
        let expected_b = pack([("b_0", pack([("b", col("b"))], NonNullable))], NonNullable);
        assert_eq!(part_b, &expected_b, "{part_b} {expected_b}");
    }
}