Skip to main content

vortex_array/expr/transform/
partition.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7
8use itertools::Itertools;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_utils::aliases::hash_map::HashMap;
12
13use crate::dtype::DType;
14use crate::dtype::FieldName;
15use crate::dtype::FieldNames;
16use crate::dtype::Nullability;
17use crate::dtype::StructFields;
18use crate::expr::Expression;
19use crate::expr::analysis::Annotation;
20use crate::expr::analysis::AnnotationFn;
21use crate::expr::analysis::Annotations;
22use crate::expr::analysis::descendent_annotations;
23use crate::expr::get_item;
24use crate::expr::pack;
25use crate::expr::root;
26use crate::expr::traversal::NodeExt;
27use crate::expr::traversal::NodeRewriter;
28use crate::expr::traversal::Transformed;
29use crate::expr::traversal::TraversalOrder;
30
31/// Partition an expression into sub-expressions that are uniquely associated with an annotation.
32/// A root expression is also returned that can be used to recombine the results of the partitions
33/// into the result of the original expression.
34///
35/// ## Note
36///
37/// This function currently respects the validity of each field in the scope, but the not validity
38/// of the scope itself. The fix would be for the returned `PartitionedExpr` to include a partition
39/// expression for computing the validity, or to include that expression as part of the root.
40///
41/// See <https://github.com/vortex-data/vortex/issues/1907>.
42pub fn partition<A: AnnotationFn>(
43    expr: Expression,
44    scope: &DType,
45    annotate_fn: A,
46) -> VortexResult<PartitionedExpr<A::Annotation>>
47where
48    A::Annotation: Display,
49    FieldName: From<A::Annotation>,
50{
51    // Annotate each expression with the annotations that any of its descendent expressions have.
52    let annotations = descendent_annotations(&expr, annotate_fn);
53    partition_annotations(expr.clone(), scope, annotations)
54}
55
56pub fn partition_annotations<A>(
57    expr: Expression,
58    scope: &DType,
59    annotations: Annotations<A>,
60) -> VortexResult<PartitionedExpr<A>>
61where
62    A: Display + Clone + Eq + Hash,
63    FieldName: From<A>,
64{
65    // Now we split the original expression into sub-expressions based on the annotations, and
66    // generate a root expression to re-assemble the results.
67    let mut splitter = StructFieldExpressionSplitter::<A>::new(&annotations);
68    let root = expr.rewrite(&mut splitter)?.value;
69
70    let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
71    let mut partition_annotations = Vec::with_capacity(splitter.sub_expressions.len());
72    let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
73
74    for (annotation, exprs) in splitter.sub_expressions.into_iter() {
75        // We pack all sub-expressions for the same annotation into a single expression.
76        let expr = pack(
77            exprs.into_iter().enumerate().map(|(idx, expr)| {
78                (
79                    StructFieldExpressionSplitter::field_name(&annotation, idx),
80                    expr,
81                )
82            }),
83            Nullability::NonNullable,
84        );
85
86        let expr = expr.optimize_recursive(scope)?;
87        let expr_dtype = expr.return_dtype(scope)?;
88
89        partitions.push(expr);
90        partition_annotations.push(annotation);
91        partition_dtypes.push(expr_dtype);
92    }
93
94    let partition_names = partition_annotations
95        .iter()
96        .map(|id| FieldName::from(id.clone()))
97        .collect::<FieldNames>();
98    let root_scope = DType::Struct(
99        StructFields::new(partition_names.clone(), partition_dtypes.clone()),
100        Nullability::NonNullable,
101    );
102
103    Ok(PartitionedExpr {
104        root: root.optimize_recursive(&root_scope)?,
105        partitions: partitions.into_boxed_slice(),
106        partition_names,
107        partition_dtypes: partition_dtypes.into_boxed_slice(),
108        partition_annotations: partition_annotations.into_boxed_slice(),
109    })
110}
111
/// The result of partitioning an expression.
///
/// The `partitions`, `partition_names`, `partition_dtypes` and `partition_annotations` slices are
/// index-aligned: entry `i` of each describes the same partition.
#[derive(Debug)]
pub struct PartitionedExpr<A> {
    /// The root expression used to re-assemble the results.
    ///
    /// It refers to the partition outputs through a non-nullable struct scope whose fields are
    /// named by `partition_names` and typed by `partition_dtypes`.
    pub root: Expression,
    /// The partition expressions themselves, optimized and typed against the original input scope.
    pub partitions: Box<[Expression]>,
    /// The field name of each partition as referenced in the root expression
    /// (derived from the partition's annotation via `FieldName::from`).
    pub partition_names: FieldNames,
    /// The return dtype of each partition expression.
    pub partition_dtypes: Box<[DType]>,
    /// The annotation associated with each partition.
    pub partition_annotations: Box<[A]>,
}
126
127impl<A: Display> Display for PartitionedExpr<A> {
128    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
129        write!(
130            f,
131            "root: {} {{{}}}",
132            self.root,
133            self.partition_names
134                .iter()
135                .zip(self.partitions.iter())
136                .map(|(name, partition)| format!("{name}: {partition}"))
137                .join(", ")
138        )
139    }
140}
141
142impl<A: Annotation> PartitionedExpr<A>
143where
144    FieldName: From<A>,
145{
146    /// Return the partition for a given field, if it exists.
147    // FIXME(ngates): this should return an iterator since an annotation may have multiple partitions.
148    pub fn find_partition(&self, id: &A) -> Option<&Expression> {
149        let id = FieldName::from(id.clone());
150        self.partition_names
151            .iter()
152            .position(|field| field == id)
153            .map(|idx| &self.partitions[idx])
154    }
155}
156
/// Expression rewriter that lifts out every outermost sub-expression whose annotation set has
/// exactly one element, replacing it with a reference into that annotation's partition.
#[derive(Debug)]
struct StructFieldExpressionSplitter<'a, A: Annotation> {
    /// Annotation set for each node, looked up by expression. When called from `partition`,
    /// these come from `descendent_annotations`.
    annotations: &'a Annotations<'a, A>,
    /// Sub-expressions lifted out for each annotation, in the order they were encountered; a
    /// sub-expression's position here is its field index in the partition's packed struct.
    sub_expressions: HashMap<A, Vec<Expression>>,
}
162
163impl<'a, A: Annotation + Display> StructFieldExpressionSplitter<'a, A> {
164    fn new(annotations: &'a Annotations<'a, A>) -> Self {
165        Self {
166            sub_expressions: HashMap::new(),
167            annotations,
168        }
169    }
170
171    /// Each annotation may be associated with multiple sub-expressions, so we need to
172    /// a unique name for each sub-expression.
173    fn field_name(annotation: &A, idx: usize) -> FieldName {
174        format!("{annotation}_{idx}").into()
175    }
176}
177
178impl<A: Annotation + Display> NodeRewriter for StructFieldExpressionSplitter<'_, A>
179where
180    FieldName: From<A>,
181{
182    type NodeTy = Expression;
183
184    fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
185        match self.annotations.get(&node) {
186            // If this expression only accesses a single field, then we can skip the children
187            Some(annotations) if annotations.len() == 1 => {
188                let annotation = annotations
189                    .iter()
190                    .next()
191                    .vortex_expect("expected one field");
192                let sub_exprs = self.sub_expressions.entry(annotation.clone()).or_default();
193                let idx = sub_exprs.len();
194                sub_exprs.push(node.clone());
195                let value = get_item(
196                    StructFieldExpressionSplitter::field_name(annotation, idx),
197                    get_item(FieldName::from(annotation.clone()), root()),
198                );
199                Ok(Transformed {
200                    value,
201                    changed: true,
202                    order: TraversalOrder::Skip,
203                })
204            }
205
206            // Otherwise, continue traversing.
207            _ => Ok(Transformed::no(node)),
208        }
209    }
210
211    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
212        Ok(Transformed::no(node))
213    }
214}
215
#[cfg(test)]
mod tests {
    use rstest::fixture;
    use rstest::rstest;

    use super::*;
    use crate::dtype::DType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType::I32;
    use crate::dtype::StructFields;
    use crate::expr::analysis::make_free_field_annotator;
    use crate::expr::and;
    use crate::expr::col;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::merge;
    use crate::expr::pack;
    use crate::expr::root;
    use crate::expr::transform::replace::replace_root_fields;

    /// Scope shared by the tests: `{a: {x: I32, y: I32}, b: I32, c: I32}`, with both struct
    /// levels non-nullable.
    #[fixture]
    fn dtype() -> DType {
        DType::Struct(
            StructFields::from_iter([
                (
                    "a",
                    DType::Struct(
                        StructFields::from_iter([("x", I32.into()), ("y", DType::from(I32))]),
                        NonNullable,
                    ),
                ),
                ("b", I32.into()),
                ("c", I32.into()),
            ]),
            NonNullable,
        )
    }

    #[rstest]
    fn test_expr_top_level_ref(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = root();
        let partitioned =
            partition(expr.clone(), &dtype, make_free_field_annotator(fields)).unwrap();

        // An un-expanded root expression is annotated by all fields, but since it is a single node
        // with more than one annotation it cannot be split: no partitions are produced and the
        // root expression is returned unchanged.
        assert_eq!(partitioned.partitions.len(), 0);
        assert_eq!(&partitioned.root, &root());

        // Instead, callers must expand the root expression themselves.
        let expr = replace_root_fields(expr, fields);
        let partitioned = partition(expr, &dtype, make_free_field_annotator(fields)).unwrap();

        // Once expanded, one partition is produced per top-level field.
        assert_eq!(partitioned.partitions.len(), fields.names().len());
    }

    #[rstest]
    fn test_expr_top_level_ref_get_item_and_split(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = get_item("y", get_item("a", root()));

        let partitioned = partition(expr, &dtype, make_free_field_annotator(fields)).unwrap();
        // The whole expression only reads field `a`, so it is lifted into partition `a` as its
        // first (`a_0`) sub-expression.
        assert_eq!(&partitioned.root, &get_item("a_0", get_item("a", root())));
    }

    #[rstest]
    fn test_expr_top_level_ref_get_item_and_split_pack(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = pack(
            [
                ("x", get_item("x", get_item("a", root()))),
                ("y", get_item("y", get_item("a", root()))),
                ("c", get_item("c", root())),
            ],
            NonNullable,
        );
        let partitioned = partition(expr, &dtype, make_free_field_annotator(fields)).unwrap();

        // Both sub-expressions over `a` land in the same partition, numbered in encounter order.
        let split_a = partitioned.find_partition(&"a".into()).unwrap();
        assert_eq!(
            &split_a.optimize_recursive(&dtype).unwrap(),
            &pack(
                [
                    ("a_0", get_item("x", get_item("a", root()))),
                    ("a_1", get_item("y", get_item("a", root())))
                ],
                NonNullable
            )
        );
    }

    #[rstest]
    fn test_expr_top_level_ref_get_item_add(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = and(get_item("y", get_item("a", root())), lit(1));
        let partitioned = partition(expr, &dtype, make_free_field_annotator(fields)).unwrap();

        // Whole expr is a single split
        assert_eq!(partitioned.partitions.len(), 1);
    }

    #[rstest]
    fn test_expr_top_level_ref_get_item_add_cannot_split(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = and(get_item("y", get_item("a", root())), get_item("b", root()));
        let partitioned = partition(expr, &dtype, make_free_field_annotator(fields)).unwrap();

        // One partition for `a` (from `a.y`) and one for `b`; the `and` stays in the root.
        assert_eq!(partitioned.partitions.len(), 2);
    }

    #[rstest]
    fn test_expr_merge(dtype: DType) {
        let fields = dtype.as_struct_fields_opt().unwrap();

        let expr = merge([col("a"), pack([("b", col("b"))], NonNullable)]);

        let partitioned = partition(expr, &dtype, make_free_field_annotator(fields)).unwrap();
        let expected = pack(
            [
                ("x", get_item("x", get_item("a_0", col("a")))),
                ("y", get_item("y", get_item("a_0", col("a")))),
                ("b", get_item("b", get_item("b_0", col("b")))),
            ],
            NonNullable,
        );
        assert_eq!(
            &partitioned.root, &expected,
            "{} {}",
            partitioned.root, expected
        );

        assert_eq!(partitioned.partitions.len(), 2);

        let part_a = partitioned.find_partition(&"a".into()).unwrap();
        let expected_a = pack([("a_0", col("a"))], NonNullable);
        assert_eq!(part_a, &expected_a, "{part_a} {expected_a}");

        let part_b = partitioned.find_partition(&"b".into()).unwrap();
        let expected_b = pack([("b_0", pack([("b", col("b"))], NonNullable))], NonNullable);
        assert_eq!(part_b, &expected_b, "{part_b} {expected_b}");
    }
}