vortex_expr/transform/
partition.rs

1use std::fmt::{Display, Formatter};
2use std::hash::{BuildHasher, Hash, Hasher};
3use std::sync::LazyLock;
4
5use itertools::Itertools;
6use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
7use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
8use vortex_utils::aliases::hash_map::{DefaultHashBuilder, HashMap};
9
10use crate::transform::immediate_access::{FieldAccesses, immediate_scope_accesses};
11use crate::transform::simplify_typed::simplify_typed;
12use crate::traversal::{FoldDown, FoldUp, FolderMut, MutNodeVisitor, Node, TransformResult};
13use crate::{ExprRef, GetItem, ScopeDType, get_item, is_root, pack, root};
14
/// Shared hasher builder used by `StructFieldExpressionSplitter::field_idx_name` to derive the
/// names of partition sub-expressions.
///
/// It is initialised lazily from `DefaultHashBuilder::default`, and every generated name is
/// hashed through this single instance.
static SPLITTER_RANDOM_STATE: LazyLock<DefaultHashBuilder> =
    LazyLock::new(DefaultHashBuilder::default);
17
18/// Partition an expression over the fields of the scope.
19///
20/// This returns a partitioned expression that can be push-down over each field of the scope.
21/// The results of each partition evaluation can then be recombined to reproduce the result of
22/// the original expression.
23///
24/// ## Note
25///
26/// This function currently respects the validity of each field in the scope, but the not validity
27/// of the scope itself. The fix would be for the returned `PartitionedExpr` to include a partition
28/// expression for computing the validity, or to include that expression as part of the root.
29///
30/// See <https://github.com/vortex-data/vortex/issues/1907>.
31///
32// TODO(ngates): document the behaviour of conflicting `Field::Index` and `Field::Name`.
33pub fn partition(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
34    if !matches!(dtype, DType::Struct(..)) {
35        vortex_bail!("Expected a struct dtype, got {:?}", dtype);
36    }
37    StructFieldExpressionSplitter::split(expr, dtype)
38}
39
// TODO(joe): replace with let expressions.
/// The result of partitioning an expression.
///
/// `partitions`, `partition_names` and `partition_dtypes` are parallel collections: entry `i` of
/// each describes the same partition.
#[derive(Debug)]
pub struct PartitionedExpr {
    /// The root expression used to re-assemble the results.
    ///
    /// It is simplified against a struct scope whose field names are `partition_names` and whose
    /// field dtypes are `partition_dtypes`.
    pub root: ExprRef,
    /// The partitions of the expression, each simplified against the dtype of the scope field it
    /// is named after.
    pub partitions: Box<[ExprRef]>,
    /// The field names for the partitions, i.e. the scope field each partition belongs to.
    pub partition_names: FieldNames,
    /// The return DTypes of each partition, computed against the dtype of its scope field.
    pub partition_dtypes: Box<[DType]>,
}
53
54impl Display for PartitionedExpr {
55    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
56        write!(
57            f,
58            "root: {} {{{}}}",
59            self.root,
60            self.partition_names
61                .iter()
62                .zip(self.partitions.iter())
63                .map(|(name, partition)| format!("{name}: {partition}"))
64                .join(", ")
65        )
66    }
67}
68
69impl PartitionedExpr {
70    /// Return the partition for a given field, if it exists.
71    pub fn find_partition(&self, field: &FieldName) -> Option<&ExprRef> {
72        self.partition_names
73            .iter()
74            .position(|name| name == field)
75            .map(|idx| &self.partitions[idx])
76    }
77}
78
/// Folder state used to split an expression into per-field sub-expressions of a struct scope.
#[derive(Debug)]
struct StructFieldExpressionSplitter<'a> {
    /// The sub-expressions collected so far, keyed by the scope field they are evaluated against.
    /// A sub-expression's position in its `Vec` is the occurrence index passed to
    /// `field_idx_name`.
    sub_expressions: HashMap<FieldName, Vec<ExprRef>>,
    /// The scope field accesses of the expression being split, as produced by
    /// `immediate_scope_accesses`.
    accesses: &'a FieldAccesses<'a>,
    /// The struct fields of the scope the expression is partitioned over.
    scope_dtype: &'a StructFields,
}
85
86impl<'a> StructFieldExpressionSplitter<'a> {
87    fn new(accesses: &'a FieldAccesses<'a>, scope_dtype: &'a StructFields) -> Self {
88        Self {
89            sub_expressions: HashMap::new(),
90            accesses,
91            scope_dtype,
92        }
93    }
94
95    pub(crate) fn field_idx_name(field: &FieldName, idx: usize) -> FieldName {
96        let mut hasher = SPLITTER_RANDOM_STATE.build_hasher();
97        field.hash(&mut hasher);
98        idx.hash(&mut hasher);
99        hasher.finish().to_string().into()
100    }
101
102    fn split(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
103        let scope_dtype = match dtype {
104            DType::Struct(scope_dtype, _) => scope_dtype,
105            _ => vortex_bail!("Expected a struct dtype, got {:?}", dtype),
106        };
107
108        let field_accesses = immediate_scope_accesses(&expr, scope_dtype)?;
109
110        let mut splitter = StructFieldExpressionSplitter::new(&field_accesses, scope_dtype);
111
112        let split = expr
113            .clone()
114            .transform_with_context(&mut splitter, ())?
115            .result();
116
117        let mut remove_accesses: Vec<FieldName> = Vec::new();
118
119        // Create partitions which can be passed to layout fields
120        let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
121        let mut partition_names = Vec::with_capacity(splitter.sub_expressions.len());
122        let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
123        for (name, exprs) in splitter.sub_expressions.into_iter() {
124            // If there is a single expr then we don't need to `pack` this, and we must update
125            // the root expr removing this access.
126            let expr = if exprs.len() == 1 {
127                remove_accesses.push(Self::field_idx_name(&name, 0));
128                exprs.first().vortex_expect("exprs is non-empty").clone()
129            } else {
130                pack(
131                    exprs
132                        .into_iter()
133                        .enumerate()
134                        .map(|(idx, expr)| (Self::field_idx_name(&name, idx), expr)),
135                    Nullability::NonNullable,
136                )
137            };
138
139            let field_dtype = scope_dtype
140                .field(&name)
141                .ok_or_else(|| vortex_err!("Missing field {name}"))?;
142            let field_ctx = ScopeDType::new(field_dtype);
143            let expr = simplify_typed(expr.clone(), &field_ctx)?;
144            let expr_dtype = expr.return_dtype(&field_ctx)?;
145
146            partitions.push(expr);
147            partition_names.push(name);
148            partition_dtypes.push(expr_dtype);
149        }
150
151        let expression_access_counts = field_accesses.get(&expr).map(|ac| ac.len());
152        // Ensure that there are not more accesses than partitions, we missed something
153        assert!(expression_access_counts.unwrap_or(0) <= partitions.len());
154        // Ensure that there are as many partitions as there are accesses/fields in the scope,
155        // this will affect performance, not correctness.
156        debug_assert_eq!(expression_access_counts.unwrap_or(0), partitions.len());
157
158        let split = split
159            .transform(&mut ReplaceAccessesWithChild(remove_accesses))?
160            .into_inner();
161
162        let ctx = ScopeDType::new(DType::Struct(
163            StructFields::new(
164                FieldNames::from(partition_names.clone()),
165                partition_dtypes.clone(),
166            ),
167            Nullability::NonNullable,
168        ));
169
170        Ok(PartitionedExpr {
171            root: simplify_typed(split, &ctx)?,
172            partitions: partitions.into_boxed_slice(),
173            partition_names: partition_names.into(),
174            partition_dtypes: partition_dtypes.into_boxed_slice(),
175        })
176    }
177}
178
179impl FolderMut for StructFieldExpressionSplitter<'_> {
180    type NodeTy = ExprRef;
181    type Out = ExprRef;
182    type Context = ();
183
184    fn visit_down(
185        &mut self,
186        node: &Self::NodeTy,
187        _context: Self::Context,
188    ) -> VortexResult<FoldDown<ExprRef, Self::Context>> {
189        // If this expression only accesses a single field, then we can skip the children
190        let access = self.accesses.get(node);
191        if access.as_ref().is_some_and(|a| a.len() == 1) {
192            let field_name = access
193                .vortex_expect("access is non-empty")
194                .iter()
195                .next()
196                .vortex_expect("expected one field");
197
198            let sub_exprs = self.sub_expressions.entry(field_name.clone()).or_default();
199            let idx = sub_exprs.len();
200
201            // Need to replace get_item(f, ident) with ident, making the expr relative to the child.
202            let replaced = node
203                .clone()
204                .transform(&mut ScopeStepIntoFieldExpr(field_name.clone()))?;
205            sub_exprs.push(replaced.into_inner());
206
207            let access = get_item(
208                Self::field_idx_name(field_name, idx),
209                get_item(field_name.clone(), root()),
210            );
211
212            return Ok(FoldDown::SkipChildren(access));
213        };
214
215        // If the expression is an identity, then we need to partition it into the fields of the scope.
216        if is_root(node) {
217            let field_names = self.scope_dtype.names();
218
219            let mut elements = Vec::with_capacity(field_names.len());
220
221            for field_name in field_names.iter() {
222                let sub_exprs = self
223                    .sub_expressions
224                    .entry(field_name.clone())
225                    .or_insert_with(Vec::new);
226
227                let idx = sub_exprs.len();
228
229                sub_exprs.push(root());
230
231                elements.push((
232                    field_name.clone(),
233                    // Partitions are packed into a struct of field name -> occurrence idx -> array
234                    get_item(
235                        Self::field_idx_name(field_name, idx),
236                        get_item(field_name.clone(), root()),
237                    ),
238                ));
239            }
240
241            return Ok(FoldDown::SkipChildren(pack(
242                elements,
243                Nullability::NonNullable,
244            )));
245        }
246
247        // Otherwise, continue traversing.
248        Ok(FoldDown::Continue(()))
249    }
250
251    fn visit_up(
252        &mut self,
253        node: Self::NodeTy,
254        _context: Self::Context,
255        children: Vec<Self::Out>,
256    ) -> VortexResult<FoldUp<Self::Out>> {
257        Ok(FoldUp::Continue(node.replacing_children(children)))
258    }
259}
260
261struct ScopeStepIntoFieldExpr(FieldName);
262
263impl MutNodeVisitor for ScopeStepIntoFieldExpr {
264    type NodeTy = ExprRef;
265
266    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
267        if is_root(&node) {
268            Ok(TransformResult::yes(pack(
269                [(self.0.clone(), root())],
270                Nullability::NonNullable,
271            )))
272        } else {
273            Ok(TransformResult::no(node))
274        }
275    }
276}
277
278pub(crate) struct ReplaceAccessesWithChild(Vec<FieldName>);
279
280impl ReplaceAccessesWithChild {
281    pub(crate) fn new(field_names: Vec<FieldName>) -> Self {
282        Self(field_names)
283    }
284}
285
286impl MutNodeVisitor for ReplaceAccessesWithChild {
287    type NodeTy = ExprRef;
288
289    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
290        if let Some(item) = node.as_any().downcast_ref::<GetItem>() {
291            if self.0.contains(item.field()) {
292                return Ok(TransformResult::yes(item.child().clone()));
293            }
294        }
295        Ok(TransformResult::no(node))
296    }
297}
298
#[cfg(test)]
mod tests {

    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, StructFields};
    use vortex_utils::aliases::hash_set::HashSet;

    use super::*;
    use crate::transform::simplify::simplify;
    use crate::transform::simplify_typed::simplify_typed;
    use crate::{Pack, and, col, get_item, lit, merge, pack, root, select};

    /// The scope shared by the tests: a struct `{a: {a, b}, b, c}` with `I32` leaves.
    fn dtype() -> DType {
        let nested = DType::Struct(
            StructFields::from_iter([("a", I32.into()), ("b", DType::from(I32))]),
            NonNullable,
        );

        DType::Struct(
            StructFields::from_iter([("a", nested), ("b", I32.into()), ("c", I32.into())]),
            NonNullable,
        )
    }

    /// Split `expr` over the shared test scope, panicking if the split fails.
    fn split_over_scope(expr: ExprRef) -> PartitionedExpr {
        StructFieldExpressionSplitter::split(expr, &dtype()).unwrap()
    }

    /// Splitting the bare root yields a `pack` root with one partition per scope field.
    #[test]
    fn test_expr_top_level_ref() {
        let scope = dtype();

        let outcome = StructFieldExpressionSplitter::split(root(), &scope);
        assert!(outcome.is_ok());
        let partitioned = outcome.unwrap();

        assert!(partitioned.root.as_any().is::<Pack>());

        // Have a single top level pack with all fields in dtype
        let scope_field_count = scope.as_struct().unwrap().names().len();
        assert_eq!(partitioned.partitions.len(), scope_field_count);
    }

    /// A nested access `root.a.b` is pushed entirely into the `a` partition.
    #[test]
    fn test_expr_top_level_ref_get_item_and_split() {
        let partitioned = split_over_scope(get_item("b", get_item("a", root())));

        let partition_a = partitioned.find_partition(&FieldName::from("a"));
        assert!(partition_a.is_some());
        let partition_a = partition_a.unwrap();

        assert_eq!(partitioned.root, get_item("a", root()));
        assert_eq!(
            simplify(partition_a.clone()).unwrap(),
            get_item("b", root())
        );
    }

    /// Two accesses to `a` are packed under their occurrence names; the single access to `c`
    /// stays unpacked.
    #[test]
    fn test_expr_top_level_ref_get_item_and_split_pack() {
        let expr = pack(
            [
                ("a", get_item("a", get_item("a", root()))),
                ("b", get_item("b", get_item("a", root()))),
                ("c", get_item("c", root())),
            ],
            NonNullable,
        );
        let partitioned = split_over_scope(expr);

        let field_a = FieldName::from("a");
        let expected_a = pack(
            [
                (
                    StructFieldExpressionSplitter::field_idx_name(&field_a, 0),
                    get_item("a", root()),
                ),
                (
                    StructFieldExpressionSplitter::field_idx_name(&field_a, 1),
                    get_item("b", root()),
                ),
            ],
            NonNullable,
        );
        let partition_a = partitioned.find_partition(&field_a).unwrap();
        assert_eq!(simplify(partition_a.clone()).unwrap(), expected_a);

        let partition_c = partitioned
            .find_partition(&FieldName::from("c"))
            .unwrap();
        assert_eq!(simplify(partition_c.clone()).unwrap(), root());
    }

    /// An expression touching only field `a` (plus a literal) forms a single partition.
    #[test]
    fn test_expr_top_level_ref_get_item_add() {
        let expr = and(get_item("b", get_item("a", root())), lit(1));
        let partitioned = split_over_scope(expr);

        // Whole expr is a single split
        assert_eq!(partitioned.partitions.len(), 1);
    }

    /// An expression over two different fields needs one partition per field.
    #[test]
    fn test_expr_top_level_ref_get_item_add_cannot_split() {
        let lhs = get_item("b", get_item("a", root()));
        let rhs = get_item("b", root());
        let partitioned = split_over_scope(and(lhs, rhs));

        // One for id.a and id.b
        assert_eq!(partitioned.partitions.len(), 2);
    }

    // Test that typed_simplify removes select and partition precise
    #[test]
    fn test_expr_partition_many_occurrences_of_field() {
        let scope = dtype();

        let unsimplified = and(
            get_item("b", get_item("a", root())),
            select(vec!["a".into(), "b".into()], root()),
        );
        let expr = simplify_typed(unsimplified, &ScopeDType::new(scope.clone())).unwrap();
        let partitioned = StructFieldExpressionSplitter::split(expr, &scope).unwrap();

        // One for id.a and id.b
        assert_eq!(partitioned.partitions.len(), 2);

        // This fetches [].$c which is unused, however a previous optimisation should replace select
        // with get_item and pack removing this field.
        let field_a = FieldName::from("a");
        let first_a = get_item(
            StructFieldExpressionSplitter::field_idx_name(&field_a, 0),
            get_item("a", root()),
        );
        let second_a = get_item(
            StructFieldExpressionSplitter::field_idx_name(&field_a, 1),
            get_item("a", root()),
        );
        let expected_root = and(
            first_a,
            pack(
                [("a", second_a), ("b", get_item("b", root()))],
                NonNullable,
            ),
        );
        assert_eq!(partitioned.root, expected_root);
    }

    /// A `merge` over `col("a")` and a pack of `col("b")` splits into one partition per column.
    #[test]
    fn test_expr_merge() {
        let expr = merge(
            [col("a"), pack([("b", col("b"))], NonNullable)],
            NonNullable,
        );
        let partitioned = split_over_scope(expr);

        let expected_root = pack(
            [
                ("a", get_item("a", col("a"))),
                ("b", get_item("b", col("b"))),
            ],
            NonNullable,
        );
        assert_eq!(
            &partitioned.root, &expected_root,
            "{} {}",
            partitioned.root, expected_root
        );

        // Partition order is not asserted, so compare as sets.
        let expected_partitions: HashSet<ExprRef> = [root(), pack([("b", root())], NonNullable)]
            .into_iter()
            .collect();
        let actual_partitions = partitioned
            .partitions
            .iter()
            .cloned()
            .collect::<HashSet<_>>();
        assert_eq!(
            &actual_partitions,
            &expected_partitions,
            "{} {}",
            partitioned.partitions.iter().join(";"),
            expected_partitions.iter().join(";")
        );
    }
}