vortex_expr/transform/
partition.rs

1use std::fmt::{Display, Formatter};
2use std::hash::{BuildHasher, Hash, Hasher};
3use std::sync::LazyLock;
4
5use itertools::Itertools;
6use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
7use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
8use vortex_utils::aliases::hash_map::{DefaultHashBuilder, HashMap};
9
10use crate::transform::immediate_access::{FieldAccesses, immediate_scope_accesses};
11use crate::transform::simplify_typed::simplify_typed;
12use crate::traversal::{FoldDown, FoldUp, FolderMut, MutNodeVisitor, Node, TransformResult};
13use crate::{ExprRef, GetItem, ScopeDType, get_item, is_root, pack, root};
14
15static SPLITTER_RANDOM_STATE: LazyLock<DefaultHashBuilder> =
16    LazyLock::new(DefaultHashBuilder::default);
17
18/// Partition an expression over the fields of the scope.
19///
20/// This returns a partitioned expression that can be push-down over each field of the scope.
21/// The results of each partition evaluation can then be recombined to reproduce the result of
22/// the original expression.
23///
24/// ## Note
25///
26/// This function currently respects the validity of each field in the scope, but the not validity
27/// of the scope itself. The fix would be for the returned `PartitionedExpr` to include a partition
28/// expression for computing the validity, or to include that expression as part of the root.
29///
30/// See <https://github.com/vortex-data/vortex/issues/1907>.
31///
32// TODO(ngates): document the behaviour of conflicting `Field::Index` and `Field::Name`.
33pub fn partition(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
34    if !matches!(dtype, DType::Struct(..)) {
35        vortex_bail!("Expected a struct dtype, got {:?}", dtype);
36    }
37    StructFieldExpressionSplitter::split(expr, dtype)
38}
39
40// TODO(joe): replace with let expressions.
41/// The result of partitioning an expression.
42#[derive(Debug)]
43pub struct PartitionedExpr {
44    /// The root expression used to re-assemble the results.
45    pub root: ExprRef,
46    /// The partitions of the expression.
47    pub partitions: Box<[ExprRef]>,
48    /// The field names for the partitions
49    pub partition_names: FieldNames,
50    /// The return DTypes of each partition.
51    pub partition_dtypes: Box<[DType]>,
52}
53
54impl Display for PartitionedExpr {
55    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
56        write!(
57            f,
58            "root: {} {{{}}}",
59            self.root,
60            self.partition_names
61                .iter()
62                .zip(self.partitions.iter())
63                .map(|(name, partition)| format!("{name}: {partition}"))
64                .join(", ")
65        )
66    }
67}
68
69impl PartitionedExpr {
70    /// Return the partition for a given field, if it exists.
71    pub fn find_partition(&self, field: &FieldName) -> Option<&ExprRef> {
72        self.partition_names
73            .iter()
74            .position(|name| name == field)
75            .map(|idx| &self.partitions[idx])
76    }
77}
78
79#[derive(Debug)]
80struct StructFieldExpressionSplitter<'a> {
81    sub_expressions: HashMap<FieldName, Vec<ExprRef>>,
82    accesses: &'a FieldAccesses<'a>,
83    scope_dtype: &'a StructFields,
84}
85
86impl<'a> StructFieldExpressionSplitter<'a> {
87    fn new(accesses: &'a FieldAccesses<'a>, scope_dtype: &'a StructFields) -> Self {
88        Self {
89            sub_expressions: HashMap::new(),
90            accesses,
91            scope_dtype,
92        }
93    }
94
95    pub(crate) fn field_idx_name(field: &FieldName, idx: usize) -> FieldName {
96        let mut hasher = SPLITTER_RANDOM_STATE.build_hasher();
97        field.hash(&mut hasher);
98        idx.hash(&mut hasher);
99        hasher.finish().to_string().into()
100    }
101
102    fn split(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
103        let scope_dtype = match dtype {
104            DType::Struct(scope_dtype, _) => scope_dtype,
105            _ => vortex_bail!("Expected a struct dtype, got {:?}", dtype),
106        };
107
108        let field_accesses = immediate_scope_accesses(&expr, scope_dtype)?;
109
110        let mut splitter = StructFieldExpressionSplitter::new(&field_accesses, scope_dtype);
111
112        let split = expr.clone().transform_with_context(&mut splitter, ())?;
113
114        let mut remove_accesses: Vec<FieldName> = Vec::new();
115
116        // Create partitions which can be passed to layout fields
117        let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
118        let mut partition_names = Vec::with_capacity(splitter.sub_expressions.len());
119        let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
120        for (name, exprs) in splitter.sub_expressions.into_iter() {
121            // If there is a single expr then we don't need to `pack` this, and we must update
122            // the root expr removing this access.
123            let expr = if exprs.len() == 1 {
124                remove_accesses.push(Self::field_idx_name(&name, 0));
125                exprs.first().vortex_expect("exprs is non-empty").clone()
126            } else {
127                pack(
128                    exprs
129                        .into_iter()
130                        .enumerate()
131                        .map(|(idx, expr)| (Self::field_idx_name(&name, idx), expr)),
132                    Nullability::NonNullable,
133                )
134            };
135
136            let field_dtype = scope_dtype
137                .field(&name)
138                .ok_or_else(|| vortex_err!("Missing field {name}"))?;
139            let field_ctx = ScopeDType::new(field_dtype);
140            let expr = simplify_typed(expr.clone(), &field_ctx)?;
141            let expr_dtype = expr.return_dtype(&field_ctx)?;
142
143            partitions.push(expr);
144            partition_names.push(name);
145            partition_dtypes.push(expr_dtype);
146        }
147
148        let expression_access_counts = field_accesses.get(&expr).map(|ac| ac.len());
149        // Ensure that there are not more accesses than partitions, we missed something
150        assert!(expression_access_counts.unwrap_or(0) <= partitions.len());
151        // Ensure that there are as many partitions as there are accesses/fields in the scope,
152        // this will affect performance, not correctness.
153        debug_assert_eq!(expression_access_counts.unwrap_or(0), partitions.len());
154
155        let split = split
156            .result()
157            .transform(&mut ReplaceAccessesWithChild(remove_accesses))?;
158
159        let ctx = ScopeDType::new(dtype.clone());
160
161        Ok(PartitionedExpr {
162            root: simplify_typed(split.into_inner(), &ctx)?,
163            partitions: partitions.into_boxed_slice(),
164            partition_names: partition_names.into(),
165            partition_dtypes: partition_dtypes.into_boxed_slice(),
166        })
167    }
168}
169
170impl FolderMut for StructFieldExpressionSplitter<'_> {
171    type NodeTy = ExprRef;
172    type Out = ExprRef;
173    type Context = ();
174
175    fn visit_down(
176        &mut self,
177        node: &Self::NodeTy,
178        _context: Self::Context,
179    ) -> VortexResult<FoldDown<ExprRef, Self::Context>> {
180        // If this expression only accesses a single field, then we can skip the children
181        let access = self.accesses.get(node);
182        if access.as_ref().is_some_and(|a| a.len() == 1) {
183            let field_name = access
184                .vortex_expect("access is non-empty")
185                .iter()
186                .next()
187                .vortex_expect("expected one field");
188
189            let sub_exprs = self.sub_expressions.entry(field_name.clone()).or_default();
190            let idx = sub_exprs.len();
191
192            // Need to replace get_item(f, ident) with ident, making the expr relative to the child.
193            let replaced = node
194                .clone()
195                .transform(&mut ScopeStepIntoFieldExpr(field_name.clone()))?;
196            sub_exprs.push(replaced.into_inner());
197
198            let access = get_item(
199                Self::field_idx_name(field_name, idx),
200                get_item(field_name.clone(), root()),
201            );
202
203            return Ok(FoldDown::SkipChildren(access));
204        };
205
206        // If the expression is an identity, then we need to partition it into the fields of the scope.
207        if is_root(node) {
208            let field_names = self.scope_dtype.names();
209
210            let mut elements = Vec::with_capacity(field_names.len());
211
212            for field_name in field_names.iter() {
213                let sub_exprs = self
214                    .sub_expressions
215                    .entry(field_name.clone())
216                    .or_insert_with(Vec::new);
217
218                let idx = sub_exprs.len();
219
220                sub_exprs.push(root());
221
222                elements.push((
223                    field_name.clone(),
224                    // Partitions are packed into a struct of field name -> occurrence idx -> array
225                    get_item(
226                        Self::field_idx_name(field_name, idx),
227                        get_item(field_name.clone(), root()),
228                    ),
229                ));
230            }
231
232            return Ok(FoldDown::SkipChildren(pack(
233                elements,
234                Nullability::NonNullable,
235            )));
236        }
237
238        // Otherwise, continue traversing.
239        Ok(FoldDown::Continue(()))
240    }
241
242    fn visit_up(
243        &mut self,
244        node: Self::NodeTy,
245        _context: Self::Context,
246        children: Vec<Self::Out>,
247    ) -> VortexResult<FoldUp<Self::Out>> {
248        Ok(FoldUp::Continue(node.replacing_children(children)))
249    }
250}
251
252struct ScopeStepIntoFieldExpr(FieldName);
253
254impl MutNodeVisitor for ScopeStepIntoFieldExpr {
255    type NodeTy = ExprRef;
256
257    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
258        if is_root(&node) {
259            Ok(TransformResult::yes(pack(
260                [(self.0.clone(), root())],
261                Nullability::NonNullable,
262            )))
263        } else {
264            Ok(TransformResult::no(node))
265        }
266    }
267}
268
269pub(crate) struct ReplaceAccessesWithChild(Vec<FieldName>);
270
271impl ReplaceAccessesWithChild {
272    pub(crate) fn new(field_names: Vec<FieldName>) -> Self {
273        Self(field_names)
274    }
275}
276
277impl MutNodeVisitor for ReplaceAccessesWithChild {
278    type NodeTy = ExprRef;
279
280    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
281        if let Some(item) = node.as_any().downcast_ref::<GetItem>() {
282            if self.0.contains(item.field()) {
283                return Ok(TransformResult::yes(item.child().clone()));
284            }
285        }
286        Ok(TransformResult::no(node))
287    }
288}
289
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, StructFields};

    use super::*;
    use crate::transform::simplify::simplify;
    use crate::transform::simplify_typed::simplify_typed;
    use crate::{Pack, and, get_item, lit, pack, root, select};

    /// Scope fixture shaped like `{a: {a: i32, b: i32}, b: i32, c: i32}`.
    fn dtype() -> DType {
        let nested = DType::Struct(
            Arc::new(StructFields::from_iter([
                ("a", DType::from(I32)),
                ("b", DType::from(I32)),
            ])),
            NonNullable,
        );

        DType::Struct(
            Arc::new(StructFields::from_iter([
                ("a", nested),
                ("b", DType::from(I32)),
                ("c", DType::from(I32)),
            ])),
            NonNullable,
        )
    }

    /// Shorthand for building a `FieldName` from a string literal.
    fn field(name: &str) -> FieldName {
        name.into()
    }

    /// Shorthand for the generated name of the `idx`-th sub-expression of `name`.
    fn idx_name(name: &str, idx: usize) -> FieldName {
        StructFieldExpressionSplitter::field_idx_name(&field(name), idx)
    }

    #[test]
    fn test_expr_top_level_ref() {
        let scope = dtype();

        let partitioned = StructFieldExpressionSplitter::split(root(), &scope).unwrap();

        assert!(partitioned.root.as_any().is::<Pack>());
        // Have a single top level pack with all fields in dtype
        let scope_field_count = scope.as_struct().unwrap().names().len();
        assert_eq!(partitioned.partitions.len(), scope_field_count)
    }

    #[test]
    fn test_expr_top_level_ref_get_item_and_split() {
        let scope = dtype();
        let nested_access = get_item("b", get_item("a", root()));

        let partitioned = StructFieldExpressionSplitter::split(nested_access, &scope).unwrap();
        let partition_a = partitioned.find_partition(&field("a")).unwrap();

        let expected_root = get_item("a", root());
        assert_eq!(&partitioned.root, &expected_root);

        let expected_partition = get_item("b", root());
        assert_eq!(&simplify(partition_a.clone()).unwrap(), &expected_partition);
    }

    #[test]
    fn test_expr_top_level_ref_get_item_and_split_pack() {
        let scope = dtype();
        let packed = pack(
            [
                ("a", get_item("a", get_item("a", root()))),
                ("b", get_item("b", get_item("a", root()))),
                ("c", get_item("c", root())),
            ],
            NonNullable,
        );

        let partitioned = StructFieldExpressionSplitter::split(packed, &scope).unwrap();

        // Field `a` is accessed twice, so its partition packs both sub-expressions.
        let partition_a = partitioned.find_partition(&field("a")).unwrap();
        let expected_a = pack(
            [
                (idx_name("a", 0), get_item("a", root())),
                (idx_name("a", 1), get_item("b", root())),
            ],
            NonNullable,
        );
        assert_eq!(&simplify(partition_a.clone()).unwrap(), &expected_a);

        // Field `c` is accessed once, so its partition is the bare root.
        let partition_c = partitioned.find_partition(&field("c")).unwrap();
        assert_eq!(&simplify(partition_c.clone()).unwrap(), &root())
    }

    #[test]
    fn test_expr_top_level_ref_get_item_add() {
        let scope = dtype();
        let conjunction = and(get_item("b", get_item("a", root())), lit(1));

        let partitioned = StructFieldExpressionSplitter::split(conjunction, &scope).unwrap();

        // Whole expr is a single split
        assert_eq!(partitioned.partitions.len(), 1);
    }

    #[test]
    fn test_expr_top_level_ref_get_item_add_cannot_split() {
        let scope = dtype();
        let conjunction = and(get_item("b", get_item("a", root())), get_item("b", root()));

        let partitioned = StructFieldExpressionSplitter::split(conjunction, &scope).unwrap();

        // One for id.a and id.b
        assert_eq!(partitioned.partitions.len(), 2);
    }

    // Test that typed_simplify removes select and partition precise
    #[test]
    fn test_expr_partition_many_occurrences_of_field() {
        let scope = dtype();
        let unsimplified = and(
            get_item("b", get_item("a", root())),
            select(vec!["a".into(), "b".into()], root()),
        );

        let simplified = simplify_typed(unsimplified, &ScopeDType::new(scope.clone())).unwrap();
        let partitioned = StructFieldExpressionSplitter::split(simplified, &scope).unwrap();

        // One for id.a and id.b
        assert_eq!(partitioned.partitions.len(), 2);

        // This fetches [].$c which is unused, however a previous optimisation should replace select
        // with get_item and pack removing this field.
        let expected_root = and(
            get_item(idx_name("a", 0), get_item("a", root())),
            pack(
                [
                    ("a", get_item(idx_name("a", 1), get_item("a", root()))),
                    ("b", get_item("b", root())),
                ],
                NonNullable,
            ),
        );
        assert_eq!(&partitioned.root, &expected_root)
    }
}
455}