vortex_expr/transform/
partition.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Display, Formatter};
5
6use itertools::Itertools;
7use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
8use vortex_error::{VortexExpect, VortexResult};
9use vortex_utils::aliases::hash_map::HashMap;
10
11use crate::transform::annotations::{
12    Annotation, AnnotationFn, Annotations, descendent_annotations,
13};
14use crate::transform::simplify_typed::simplify_typed;
15use crate::traversal::{NodeExt, NodeRewriter, Transformed, TraversalOrder};
16use crate::{ExprRef, get_item, pack, root};
17
18/// Partition an expression into sub-expressions that are uniquely associated with an annotation.
19/// A root expression is also returned that can be used to recombine the results of the partitions
20/// into the result of the original expression.
21///
22/// ## Note
23///
24/// This function currently respects the validity of each field in the scope, but the not validity
25/// of the scope itself. The fix would be for the returned `PartitionedExpr` to include a partition
26/// expression for computing the validity, or to include that expression as part of the root.
27///
28/// See <https://github.com/vortex-data/vortex/issues/1907>.
29pub fn partition<A: AnnotationFn>(
30    expr: ExprRef,
31    scope: &DType,
32    annotate_fn: A,
33) -> VortexResult<PartitionedExpr<A::Annotation>>
34where
35    A::Annotation: Display,
36{
37    // Annotate each expression with the annotations that any of its descendent expressions have.
38    let annotations = descendent_annotations(&expr, annotate_fn);
39
40    // Now we split the original expression into sub-expressions based on the annotations, and
41    // generate a root expression to re-assemble the results.
42    let mut splitter = StructFieldExpressionSplitter::<A::Annotation>::new(&annotations);
43    let root = expr.clone().rewrite(&mut splitter)?.value;
44
45    let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
46    let mut partition_annotations = Vec::with_capacity(splitter.sub_expressions.len());
47    let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
48
49    for (annotation, exprs) in splitter.sub_expressions.into_iter() {
50        // We pack all sub-expressions for the same annotation into a single expression.
51        let expr = pack(
52            exprs.into_iter().enumerate().map(|(idx, expr)| {
53                (
54                    StructFieldExpressionSplitter::field_name(&annotation, idx),
55                    expr,
56                )
57            }),
58            Nullability::NonNullable,
59        );
60
61        let expr = simplify_typed(expr.clone(), scope)?;
62        let expr_dtype = expr.return_dtype(scope)?;
63
64        partitions.push(expr);
65        partition_annotations.push(annotation);
66        partition_dtypes.push(expr_dtype);
67    }
68
69    let partition_names =
70        FieldNames::from_iter(partition_annotations.iter().map(|id| id.to_string()));
71    let root_scope = DType::Struct(
72        StructFields::new(partition_names.clone(), partition_dtypes.clone()),
73        Nullability::NonNullable,
74    );
75
76    Ok(PartitionedExpr {
77        root: simplify_typed(root, &root_scope)?,
78        partitions: partitions.into_boxed_slice(),
79        partition_names,
80        partition_dtypes: partition_dtypes.into_boxed_slice(),
81        partition_annotations: partition_annotations.into_boxed_slice(),
82    })
83}
84
85/// The result of partitioning an expression.
86#[derive(Debug)]
87pub struct PartitionedExpr<A> {
88    /// The root expression used to re-assemble the results.
89    pub root: ExprRef,
90    /// The partition expressions themselves.
91    pub partitions: Box<[ExprRef]>,
92    /// The field name of each partition as referenced in the root expression.
93    pub partition_names: FieldNames,
94    /// The return dtype of each partition expression.
95    pub partition_dtypes: Box<[DType]>,
96    /// The annotation associated with each partition.
97    pub partition_annotations: Box<[A]>,
98}
99
100impl<A: Display> Display for PartitionedExpr<A> {
101    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
102        write!(
103            f,
104            "root: {} {{{}}}",
105            self.root,
106            self.partition_names
107                .iter()
108                .zip(self.partitions.iter())
109                .map(|(name, partition)| format!("{name}: {partition}"))
110                .join(", ")
111        )
112    }
113}
114
115impl<A: Annotation + Display> PartitionedExpr<A> {
116    /// Return the partition for a given field, if it exists.
117    // FIXME(ngates): this should return an iterator since an annotation may have multiple partitions.
118    pub fn find_partition(&self, id: &A) -> Option<&ExprRef> {
119        let id = FieldName::from(id.to_string());
120        self.partition_names
121            .iter()
122            .position(|field| field == &id)
123            .map(|idx| &self.partitions[idx])
124    }
125}
126
127#[derive(Debug)]
128struct StructFieldExpressionSplitter<'a, A: Annotation> {
129    annotations: &'a Annotations<'a, A>,
130    sub_expressions: HashMap<A, Vec<ExprRef>>,
131}
132
133impl<'a, A: Annotation + Display> StructFieldExpressionSplitter<'a, A> {
134    fn new(annotations: &'a Annotations<'a, A>) -> Self {
135        Self {
136            sub_expressions: HashMap::new(),
137            annotations,
138        }
139    }
140
141    /// Each annotation may be associated with multiple sub-expressions, so we need to
142    /// a unique name for each sub-expression.
143    fn field_name(annotation: &A, idx: usize) -> FieldName {
144        format!("{annotation}_{idx}").into()
145    }
146}
147
148impl<A: Annotation + Display> NodeRewriter for StructFieldExpressionSplitter<'_, A> {
149    type NodeTy = ExprRef;
150
151    fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
152        match self.annotations.get(&node) {
153            // If this expression only accesses a single field, then we can skip the children
154            Some(annotations) if annotations.len() == 1 => {
155                let annotation = annotations
156                    .iter()
157                    .next()
158                    .vortex_expect("expected one field");
159                let sub_exprs = self.sub_expressions.entry(annotation.clone()).or_default();
160                let idx = sub_exprs.len();
161                sub_exprs.push(node.clone());
162                let value = get_item(
163                    StructFieldExpressionSplitter::field_name(annotation, idx),
164                    get_item(FieldName::from(annotation.to_string()), root()),
165                );
166                Ok(Transformed {
167                    value,
168                    changed: true,
169                    order: TraversalOrder::Skip,
170                })
171            }
172
173            // Otherwise, continue traversing.
174            _ => Ok(Transformed::no(node)),
175        }
176    }
177
178    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
179        Ok(Transformed::no(node))
180    }
181}
182
#[cfg(test)]
mod tests {
    use rstest::{fixture, rstest};
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, StructFields};

    use super::*;
    use crate::transform::immediate_access::annotate_scope_access;
    use crate::transform::replace::replace_root_fields;
    use crate::transform::simplify::simplify;
    use crate::transform::simplify_typed::simplify_typed;
    use crate::{and, col, get_item, lit, merge, pack, root, select};

    /// Scope shared by every test: a struct `{a: {x, y}, b, c}` with i32 leaves.
    #[fixture]
    fn dtype() -> DType {
        let a = DType::Struct(
            StructFields::from_iter([("x", I32.into()), ("y", DType::from(I32))]),
            NonNullable,
        );
        DType::Struct(
            StructFields::from_iter([("a", a), ("b", I32.into()), ("c", I32.into())]),
            NonNullable,
        )
    }

    #[rstest]
    fn test_expr_top_level_ref(dtype: DType) {
        let fields = dtype.as_struct().unwrap();

        // An un-expanded root expression is annotated by all fields, but since it is a single
        // node it cannot be split: no partitions are produced and the root is left unchanged.
        let unexpanded = root();
        let partitioned =
            partition(unexpanded.clone(), &dtype, annotate_scope_access(fields)).unwrap();
        assert_eq!(partitioned.partitions.len(), 0);
        assert_eq!(&partitioned.root, &root());

        // Callers must expand the root expression themselves; each field then gets a partition.
        let expanded = replace_root_fields(unexpanded, fields);
        let partitioned = partition(expanded, &dtype, annotate_scope_access(fields)).unwrap();
        assert_eq!(partitioned.partitions.len(), fields.names().len());
    }

    #[rstest]
    fn test_expr_top_level_ref_get_item_and_split(dtype: DType) {
        let fields = dtype.as_struct().unwrap();
        let expr = get_item("y", get_item("a", root()));

        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        // The whole expression only depends on `a`, so it is the first sub-expression of `a`.
        let expected = get_item("a_0", get_item("a", root()));
        assert_eq!(&partitioned.root, &expected);
    }

    #[rstest]
    fn test_expr_top_level_ref_get_item_and_split_pack(dtype: DType) {
        let fields = dtype.as_struct().unwrap();
        let a_x = get_item("x", get_item("a", root()));
        let a_y = get_item("y", get_item("a", root()));

        let expr = pack(
            [
                ("x", a_x.clone()),
                ("y", a_y.clone()),
                ("c", get_item("c", root())),
            ],
            NonNullable,
        );
        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        // Both accesses into `a` are packed into the single `a` partition.
        let split_a = partitioned.find_partition(&"a".into()).unwrap();
        let expected = pack([("a_0", a_x), ("a_1", a_y)], NonNullable);
        assert_eq!(&simplify(split_a.clone()).unwrap(), &expected);
    }

    #[rstest]
    fn test_expr_top_level_ref_get_item_add(dtype: DType) {
        let fields = dtype.as_struct().unwrap();
        let expr = and(get_item("y", get_item("a", root())), lit(1));

        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        // The whole expression is a single split.
        assert_eq!(partitioned.partitions.len(), 1);
    }

    #[rstest]
    fn test_expr_top_level_ref_get_item_add_cannot_split(dtype: DType) {
        let fields = dtype.as_struct().unwrap();
        let expr = and(get_item("y", get_item("a", root())), get_item("b", root()));

        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        // One partition for `a` and one for `b`.
        assert_eq!(partitioned.partitions.len(), 2);
    }

    // `simplify_typed` removes the `select`, which lets the partitioning be precise.
    #[rstest]
    fn test_expr_partition_many_occurrences_of_field(dtype: DType) {
        let fields = dtype.as_struct().unwrap();
        let expr = and(
            get_item("y", get_item("a", root())),
            select(vec!["a".into(), "b".into()], root()),
        );
        let expr = simplify_typed(expr, &dtype).unwrap();

        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        // One partition for `a` and one for `b`.
        assert_eq!(partitioned.partitions.len(), 2);

        // `a` is referenced twice (`a_0`, `a_1`) and `b` once (`b_0`).
        let second_a = StructFieldExpressionSplitter::<FieldName>::field_name(&"a".into(), 1);
        let expected = and(
            get_item("a_0", get_item("a", root())),
            pack(
                [
                    ("a", get_item(second_a, get_item("a", root()))),
                    ("b", get_item("b_0", get_item("b", root()))),
                ],
                NonNullable,
            ),
        );
        assert_eq!(&partitioned.root, &expected);
    }

    #[rstest]
    fn test_expr_merge(dtype: DType) {
        let fields = dtype.as_struct().unwrap();
        let expr = merge(
            [col("a"), pack([("b", col("b"))], NonNullable)],
            NonNullable,
        );

        let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();

        let expected_root = pack(
            [
                ("x", get_item("x", get_item("a_0", col("a")))),
                ("y", get_item("y", get_item("a_0", col("a")))),
                ("b", get_item("b", get_item("b_0", col("b")))),
            ],
            NonNullable,
        );
        assert_eq!(
            &partitioned.root, &expected_root,
            "{} {}",
            partitioned.root, expected_root
        );

        assert_eq!(partitioned.partitions.len(), 2);

        let part_a = partitioned.find_partition(&"a".into()).unwrap();
        let expected_a = pack([("a_0", col("a"))], NonNullable);
        assert_eq!(part_a, &expected_a, "{part_a} {expected_a}");

        let part_b = partitioned.find_partition(&"b".into()).unwrap();
        let expected_b = pack([("b_0", pack([("b", col("b"))], NonNullable))], NonNullable);
        assert_eq!(part_b, &expected_b, "{part_b} {expected_b}");
    }
}