vortex_array/expr/transform/
partition.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use itertools::Itertools;
8use vortex_dtype::DType;
9use vortex_dtype::FieldName;
10use vortex_dtype::FieldNames;
11use vortex_dtype::Nullability;
12use vortex_dtype::StructFields;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_utils::aliases::hash_map::HashMap;
16
17use crate::expr::Expression;
18use crate::expr::analysis::Annotation;
19use crate::expr::analysis::AnnotationFn;
20use crate::expr::analysis::Annotations;
21use crate::expr::analysis::descendent_annotations;
22use crate::expr::exprs::get_item::get_item;
23use crate::expr::exprs::pack::pack;
24use crate::expr::exprs::root::root;
25use crate::expr::transform::ExprOptimizer;
26use crate::expr::traversal::NodeExt;
27use crate::expr::traversal::NodeRewriter;
28use crate::expr::traversal::Transformed;
29use crate::expr::traversal::TraversalOrder;
30
31/// Partition an expression into sub-expressions that are uniquely associated with an annotation.
32/// A root expression is also returned that can be used to recombine the results of the partitions
33/// into the result of the original expression.
34///
35/// ## Note
36///
37/// This function currently respects the validity of each field in the scope, but the not validity
38/// of the scope itself. The fix would be for the returned `PartitionedExpr` to include a partition
39/// expression for computing the validity, or to include that expression as part of the root.
40///
41/// See <https://github.com/vortex-data/vortex/issues/1907>.
42pub fn partition<A: AnnotationFn>(
43    expr: Expression,
44    scope: &DType,
45    annotate_fn: A,
46    optimizer: &ExprOptimizer,
47) -> VortexResult<PartitionedExpr<A::Annotation>>
48where
49    A::Annotation: Display,
50    FieldName: From<A::Annotation>,
51{
52    // Annotate each expression with the annotations that any of its descendent expressions have.
53    let annotations = descendent_annotations(&expr, annotate_fn);
54
55    // Now we split the original expression into sub-expressions based on the annotations, and
56    // generate a root expression to re-assemble the results.
57    let mut splitter = StructFieldExpressionSplitter::<A::Annotation>::new(&annotations);
58    let root = expr.clone().rewrite(&mut splitter)?.value;
59
60    let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
61    let mut partition_annotations = Vec::with_capacity(splitter.sub_expressions.len());
62    let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
63
64    for (annotation, exprs) in splitter.sub_expressions.into_iter() {
65        // We pack all sub-expressions for the same annotation into a single expression.
66        let expr = pack(
67            exprs.into_iter().enumerate().map(|(idx, expr)| {
68                (
69                    StructFieldExpressionSplitter::field_name(&annotation, idx),
70                    expr,
71                )
72            }),
73            Nullability::NonNullable,
74        );
75
76        let expr = optimizer.optimize_typed(expr.clone(), scope)?;
77        let expr_dtype = expr.return_dtype(scope)?;
78
79        partitions.push(expr);
80        partition_annotations.push(annotation);
81        partition_dtypes.push(expr_dtype);
82    }
83
84    let partition_names = partition_annotations
85        .iter()
86        .map(|id| FieldName::from(id.clone()))
87        .collect::<FieldNames>();
88    let root_scope = DType::Struct(
89        StructFields::new(partition_names.clone(), partition_dtypes.clone()),
90        Nullability::NonNullable,
91    );
92
93    Ok(PartitionedExpr {
94        root: optimizer.optimize_typed(root, &root_scope)?,
95        partitions: partitions.into_boxed_slice(),
96        partition_names,
97        partition_dtypes: partition_dtypes.into_boxed_slice(),
98        partition_annotations: partition_annotations.into_boxed_slice(),
99    })
100}
101
102/// The result of partitioning an expression.
103#[derive(Debug)]
104pub struct PartitionedExpr<A> {
105    /// The root expression used to re-assemble the results.
106    pub root: Expression,
107    /// The partition expressions themselves.
108    pub partitions: Box<[Expression]>,
109    /// The field name of each partition as referenced in the root expression.
110    pub partition_names: FieldNames,
111    /// The return dtype of each partition expression.
112    pub partition_dtypes: Box<[DType]>,
113    /// The annotation associated with each partition.
114    pub partition_annotations: Box<[A]>,
115}
116
117impl<A: Display> Display for PartitionedExpr<A> {
118    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
119        write!(
120            f,
121            "root: {} {{{}}}",
122            self.root,
123            self.partition_names
124                .iter()
125                .zip(self.partitions.iter())
126                .map(|(name, partition)| format!("{name}: {partition}"))
127                .join(", ")
128        )
129    }
130}
131
132impl<A: Annotation> PartitionedExpr<A>
133where
134    FieldName: From<A>,
135{
136    /// Return the partition for a given field, if it exists.
137    // FIXME(ngates): this should return an iterator since an annotation may have multiple partitions.
138    pub fn find_partition(&self, id: &A) -> Option<&Expression> {
139        let id = FieldName::from(id.clone());
140        self.partition_names
141            .iter()
142            .position(|field| field == id)
143            .map(|idx| &self.partitions[idx])
144    }
145}
146
147#[derive(Debug)]
148struct StructFieldExpressionSplitter<'a, A: Annotation> {
149    annotations: &'a Annotations<'a, A>,
150    sub_expressions: HashMap<A, Vec<Expression>>,
151}
152
153impl<'a, A: Annotation + Display> StructFieldExpressionSplitter<'a, A> {
154    fn new(annotations: &'a Annotations<'a, A>) -> Self {
155        Self {
156            sub_expressions: HashMap::new(),
157            annotations,
158        }
159    }
160
161    /// Each annotation may be associated with multiple sub-expressions, so we need to
162    /// a unique name for each sub-expression.
163    fn field_name(annotation: &A, idx: usize) -> FieldName {
164        format!("{annotation}_{idx}").into()
165    }
166}
167
168impl<A: Annotation + Display> NodeRewriter for StructFieldExpressionSplitter<'_, A>
169where
170    FieldName: From<A>,
171{
172    type NodeTy = Expression;
173
174    fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
175        match self.annotations.get(&node) {
176            // If this expression only accesses a single field, then we can skip the children
177            Some(annotations) if annotations.len() == 1 => {
178                let annotation = annotations
179                    .iter()
180                    .next()
181                    .vortex_expect("expected one field");
182                let sub_exprs = self.sub_expressions.entry(annotation.clone()).or_default();
183                let idx = sub_exprs.len();
184                sub_exprs.push(node.clone());
185                let value = get_item(
186                    StructFieldExpressionSplitter::field_name(annotation, idx),
187                    get_item(FieldName::from(annotation.clone()), root()),
188                );
189                Ok(Transformed {
190                    value,
191                    changed: true,
192                    order: TraversalOrder::Skip,
193                })
194            }
195
196            // Otherwise, continue traversing.
197            _ => Ok(Transformed::no(node)),
198        }
199    }
200
201    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
202        Ok(Transformed::no(node))
203    }
204}
205
#[cfg(test)]
mod tests {
    use rstest::fixture;
    use rstest::rstest;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::StructFields;

    use super::*;
    use crate::expr::analysis::annotate_scope_access;
    use crate::expr::exprs::binary::and;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::merge::merge;
    use crate::expr::exprs::pack::pack;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::select;
    use crate::expr::session::ExprSession;
    use crate::expr::transform::replace::replace_root_fields;
    use crate::expr::transform::simplify_typed::simplify_typed;

    // Scope shared by all tests: a non-nullable struct `{a: {x, y}, b, c}` with `I32` leaves.
    #[fixture]
    fn dtype() -> DType {
        DType::Struct(
            StructFields::from_iter([
                (
                    "a",
                    DType::Struct(
                        StructFields::from_iter([("x", I32.into()), ("y", DType::from(I32))]),
                        NonNullable,
                    ),
                ),
                ("b", I32.into()),
                ("c", I32.into()),
            ]),
            NonNullable,
        )
    }

    // A bare `root()` is not partitioned; callers have to expand it into its fields first.
    #[rstest]
    fn test_expr_top_level_ref(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();
        let session = ExprSession::default();
        let optimizer = ExprOptimizer::new(&session);

        let expr = root();
        let partitioned = partition(
            expr.clone(),
            &dtype,
            annotate_scope_access(fields),
            &optimizer,
        )
        .unwrap();

        // An un-expanded root expression is a single node annotated by all fields, so it is
        // never split: no partitions are produced and the root is returned unchanged.
        assert_eq!(partitioned.partitions.len(), 0);
        assert_eq!(&partitioned.root, &root());

        // Instead, callers must expand the root expression themselves.
        let expr = replace_root_fields(expr, fields);
        let partitioned =
            partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();

        // Once expanded, there is one partition per top-level field.
        assert_eq!(partitioned.partitions.len(), fields.names().len());
    }

    // A nested field access is extracted as a whole and read back through `a.a_0` in the root.
    #[rstest]
    fn test_expr_top_level_ref_get_item_and_split(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();
        let session = ExprSession::default();
        let optimizer = ExprOptimizer::new(&session);

        let expr = get_item("y", get_item("a", root()));

        let partitioned =
            partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();
        assert_eq!(&partitioned.root, &get_item("a_0", get_item("a", root())));
    }

    // Two sub-expressions on the same field end up in the same partition as `a_0` and `a_1`.
    #[rstest]
    fn test_expr_top_level_ref_get_item_and_split_pack(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();
        let session = ExprSession::default();
        let optimizer = ExprOptimizer::new(&session);

        let expr = pack(
            [
                ("x", get_item("x", get_item("a", root()))),
                ("y", get_item("y", get_item("a", root()))),
                ("c", get_item("c", root())),
            ],
            NonNullable,
        );
        let partitioned =
            partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();

        let split_a = partitioned.find_partition(&"a".into()).unwrap();
        assert_eq!(
            &simplify_typed(
                split_a.clone(),
                &dtype,
                ExprSession::default().rewrite_rules()
            )
            .unwrap(),
            &pack(
                [
                    ("a_0", get_item("x", get_item("a", root()))),
                    ("a_1", get_item("y", get_item("a", root())))
                ],
                NonNullable
            )
        );
    }

    // Note: despite the `add` in the name, the expression under test is built with `and`.
    #[rstest]
    fn test_expr_top_level_ref_get_item_add(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();
        let session = ExprSession::default();
        let optimizer = ExprOptimizer::new(&session);

        let expr = and(get_item("y", get_item("a", root())), lit(1));
        let partitioned =
            partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();

        // Only field `a` is referenced, so the whole expression is a single split.
        assert_eq!(partitioned.partitions.len(), 1);
    }

    // Operands that reference different fields cannot share a partition.
    #[rstest]
    fn test_expr_top_level_ref_get_item_add_cannot_split(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();
        let session = ExprSession::default();
        let optimizer = ExprOptimizer::new(&session);

        let expr = and(get_item("y", get_item("a", root())), get_item("b", root()));
        let partitioned =
            partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();

        // One partition for field `a` and one for field `b`.
        assert_eq!(partitioned.partitions.len(), 2);
    }

    // Test that running `simplify_typed` first rewrites the `select`, so that partitioning
    // stays precise: only the fields that are actually referenced get a partition.
    #[rstest]
    fn test_expr_partition_many_occurrences_of_field(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();
        let session = ExprSession::default();
        let optimizer = ExprOptimizer::new(&session);

        let expr = and(
            get_item("y", get_item("a", root())),
            select(["a", "b"], root()),
        );
        let expr = simplify_typed(expr, &dtype, ExprSession::default().rewrite_rules()).unwrap();
        let partitioned =
            partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();

        // One partition for field `a` and one for field `b`.
        assert_eq!(partitioned.partitions.len(), 2);

        // Left as-is, the `select` would also fetch the unused field `c`. The earlier
        // simplification is expected to replace it with `get_item` + `pack`, dropping that field.
        assert_eq!(
            &partitioned.root,
            &and(
                get_item("a_0", get_item("a", root())),
                pack(
                    [
                        (
                            "a",
                            get_item(
                                StructFieldExpressionSplitter::<FieldName>::field_name(
                                    &"a".into(),
                                    1
                                ),
                                get_item("a", root())
                            )
                        ),
                        ("b", get_item("b_0", get_item("b", root())))
                    ],
                    NonNullable
                )
            )
        )
    }

    // A `merge` of a column and a packed column: checks the optimized root and both partitions.
    #[rstest]
    fn test_expr_merge(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();
        let session = ExprSession::default();
        let optimizer = ExprOptimizer::new(&session);

        let expr = merge([col("a"), pack([("b", col("b"))], NonNullable)]);

        let partitioned =
            partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();
        // The expected root is a flat `pack` over the merged fields, read from the partitions.
        let expected = pack(
            [
                ("x", get_item("x", get_item("a_0", col("a")))),
                ("y", get_item("y", get_item("a_0", col("a")))),
                ("b", get_item("b", get_item("b_0", col("b")))),
            ],
            NonNullable,
        );
        assert_eq!(
            &partitioned.root, &expected,
            "{} {}",
            partitioned.root, expected
        );

        assert_eq!(partitioned.partitions.len(), 2);

        let part_a = partitioned.find_partition(&"a".into()).unwrap();
        let expected_a = pack([("a_0", col("a"))], NonNullable);
        assert_eq!(part_a, &expected_a, "{part_a} {expected_a}");

        let part_b = partitioned.find_partition(&"b".into()).unwrap();
        let expected_b = pack([("b_0", pack([("b", col("b"))], NonNullable))], NonNullable);
        assert_eq!(part_b, &expected_b, "{part_b} {expected_b}");
    }
}