vortex_expr/transform/
partition.rs

1use std::fmt::{Display, Formatter};
2use std::hash::{BuildHasher, Hash, Hasher};
3use std::sync::LazyLock;
4
5use itertools::Itertools;
6use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
7use vortex_error::{VortexExpect, VortexResult, vortex_bail};
8use vortex_utils::aliases::hash_map::{DefaultHashBuilder, HashMap};
9
10use crate::transform::immediate_access::{FieldAccesses, immediate_scope_accesses};
11use crate::transform::simplify_typed::simplify_typed;
12use crate::traversal::{FoldDown, FoldUp, FolderMut, MutNodeVisitor, Node, TransformResult};
13use crate::{ExprRef, GetItem, ScopeDType, get_item, is_root, pack, root};
14
15static SPLITTER_RANDOM_STATE: LazyLock<DefaultHashBuilder> =
16    LazyLock::new(DefaultHashBuilder::default);
17
18/// Partition an expression over the fields of the scope.
19///
20/// This returns a partitioned expression that can be push-down over each field of the scope.
21/// The results of each partition evaluation can then be recombined to reproduce the result of
22/// the original expression.
23///
24/// ## Note
25///
26/// This function currently respects the validity of each field in the scope, but the not validity
27/// of the scope itself. The fix would be for the returned `PartitionedExpr` to include a partition
28/// expression for computing the validity, or to include that expression as part of the root.
29///
30/// See <https://github.com/vortex-data/vortex/issues/1907>.
31///
32// TODO(ngates): document the behaviour of conflicting `Field::Index` and `Field::Name`.
33pub fn partition(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
34    if !matches!(dtype, DType::Struct(..)) {
35        vortex_bail!("Expected a struct dtype, got {:?}", dtype);
36    }
37    StructFieldExpressionSplitter::split(expr, dtype)
38}
39
40// TODO(joe): replace with let expressions.
41/// The result of partitioning an expression.
42#[derive(Debug)]
43pub struct PartitionedExpr {
44    /// The root expression used to re-assemble the results.
45    pub root: ExprRef,
46    /// The partitions of the expression.
47    pub partitions: Box<[ExprRef]>,
48    /// The field names for the partitions
49    pub partition_names: FieldNames,
50    /// The return DTypes of each partition.
51    pub partition_dtypes: Box<[DType]>,
52}
53
54impl Display for PartitionedExpr {
55    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
56        write!(
57            f,
58            "root: {} {{{}}}",
59            self.root,
60            self.partition_names
61                .iter()
62                .zip(self.partitions.iter())
63                .map(|(name, partition)| format!("{name}: {partition}"))
64                .join(", ")
65        )
66    }
67}
68
69impl PartitionedExpr {
70    /// Return the partition for a given field, if it exists.
71    pub fn find_partition(&self, field: &FieldName) -> Option<&ExprRef> {
72        self.partition_names
73            .iter()
74            .position(|name| name == field)
75            .map(|idx| &self.partitions[idx])
76    }
77}
78
79#[derive(Debug)]
80struct StructFieldExpressionSplitter<'a> {
81    sub_expressions: HashMap<FieldName, Vec<ExprRef>>,
82    accesses: &'a FieldAccesses<'a>,
83    scope_dtype: &'a StructFields,
84}
85
86impl<'a> StructFieldExpressionSplitter<'a> {
87    fn new(accesses: &'a FieldAccesses<'a>, scope_dtype: &'a StructFields) -> Self {
88        Self {
89            sub_expressions: HashMap::new(),
90            accesses,
91            scope_dtype,
92        }
93    }
94
95    pub(crate) fn field_idx_name(field: &FieldName, idx: usize) -> FieldName {
96        let mut hasher = SPLITTER_RANDOM_STATE.build_hasher();
97        field.hash(&mut hasher);
98        idx.hash(&mut hasher);
99        hasher.finish().to_string().into()
100    }
101
102    fn split(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
103        let scope_dtype = match dtype {
104            DType::Struct(scope_dtype, _) => scope_dtype,
105            _ => vortex_bail!("Expected a struct dtype, got {:?}", dtype),
106        };
107
108        let field_accesses = immediate_scope_accesses(&expr, scope_dtype)?;
109
110        let mut splitter = StructFieldExpressionSplitter::new(&field_accesses, scope_dtype);
111
112        let split = expr.clone().transform_with_context(&mut splitter, ())?;
113
114        let mut remove_accesses: Vec<FieldName> = Vec::new();
115
116        // Create partitions which can be passed to layout fields
117        let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
118        let mut partition_names = Vec::with_capacity(splitter.sub_expressions.len());
119        let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
120        for (name, exprs) in splitter.sub_expressions.into_iter() {
121            let field_dtype = scope_dtype.field(&name)?;
122            // If there is a single expr then we don't need to `pack` this, and we must update
123            // the root expr removing this access.
124            let expr = if exprs.len() == 1 {
125                remove_accesses.push(Self::field_idx_name(&name, 0));
126                exprs.first().vortex_expect("exprs is non-empty").clone()
127            } else {
128                pack(
129                    exprs
130                        .into_iter()
131                        .enumerate()
132                        .map(|(idx, expr)| (Self::field_idx_name(&name, idx), expr)),
133                    Nullability::NonNullable,
134                )
135            };
136
137            let field_ctx = ScopeDType::new(field_dtype);
138            let expr = simplify_typed(expr.clone(), &field_ctx)?;
139            let expr_dtype = expr.return_dtype(&field_ctx)?;
140
141            partitions.push(expr);
142            partition_names.push(name);
143            partition_dtypes.push(expr_dtype);
144        }
145
146        let expression_access_counts = field_accesses.get(&expr).map(|ac| ac.len());
147        // Ensure that there are not more accesses than partitions, we missed something
148        assert!(expression_access_counts.unwrap_or(0) <= partitions.len());
149        // Ensure that there are as many partitions as there are accesses/fields in the scope,
150        // this will affect performance, not correctness.
151        debug_assert_eq!(expression_access_counts.unwrap_or(0), partitions.len());
152
153        let split = split
154            .result()
155            .transform(&mut ReplaceAccessesWithChild(remove_accesses))?;
156
157        let ctx = ScopeDType::new(dtype.clone());
158
159        Ok(PartitionedExpr {
160            root: simplify_typed(split.into_inner(), &ctx)?,
161            partitions: partitions.into_boxed_slice(),
162            partition_names: partition_names.into(),
163            partition_dtypes: partition_dtypes.into_boxed_slice(),
164        })
165    }
166}
167
168impl FolderMut for StructFieldExpressionSplitter<'_> {
169    type NodeTy = ExprRef;
170    type Out = ExprRef;
171    type Context = ();
172
173    fn visit_down(
174        &mut self,
175        node: &Self::NodeTy,
176        _context: Self::Context,
177    ) -> VortexResult<FoldDown<ExprRef, Self::Context>> {
178        // If this expression only accesses a single field, then we can skip the children
179        let access = self.accesses.get(node);
180        if access.as_ref().is_some_and(|a| a.len() == 1) {
181            let field_name = access
182                .vortex_expect("access is non-empty")
183                .iter()
184                .next()
185                .vortex_expect("expected one field");
186
187            // TODO(joe): dedup the sub_expression, if there are two expressions that are the same
188            // only create one entry here and reuse it.
189            let sub_exprs = self.sub_expressions.entry(field_name.clone()).or_default();
190            let idx = sub_exprs.len();
191
192            // Need to replace get_item(f, ident) with ident, making the expr relative to the child.
193            let replaced = node
194                .clone()
195                .transform(&mut ScopeStepIntoFieldExpr(field_name.clone()))?;
196            sub_exprs.push(replaced.into_inner());
197
198            let access = get_item(
199                Self::field_idx_name(field_name, idx),
200                get_item(field_name.clone(), root()),
201            );
202
203            return Ok(FoldDown::SkipChildren(access));
204        };
205
206        // If the expression is an identity, then we need to partition it into the fields of the scope.
207        if is_root(node) {
208            let field_names = self.scope_dtype.names();
209
210            let mut elements = Vec::with_capacity(field_names.len());
211
212            for field_name in field_names.iter() {
213                let sub_exprs = self
214                    .sub_expressions
215                    .entry(field_name.clone())
216                    .or_insert_with(Vec::new);
217
218                let idx = sub_exprs.len();
219
220                sub_exprs.push(root());
221
222                elements.push((
223                    field_name.clone(),
224                    // Partitions are packed into a struct of field name -> occurrence idx -> array
225                    get_item(
226                        Self::field_idx_name(field_name, idx),
227                        get_item(field_name.clone(), root()),
228                    ),
229                ));
230            }
231
232            return Ok(FoldDown::SkipChildren(pack(
233                elements,
234                Nullability::NonNullable,
235            )));
236        }
237
238        // Otherwise, continue traversing.
239        Ok(FoldDown::Continue(()))
240    }
241
242    fn visit_up(
243        &mut self,
244        node: Self::NodeTy,
245        _context: Self::Context,
246        children: Vec<Self::Out>,
247    ) -> VortexResult<FoldUp<Self::Out>> {
248        Ok(FoldUp::Continue(node.replacing_children(children)))
249    }
250}
251
252struct ScopeStepIntoFieldExpr(FieldName);
253
254impl MutNodeVisitor for ScopeStepIntoFieldExpr {
255    type NodeTy = ExprRef;
256
257    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
258        if is_root(&node) {
259            Ok(TransformResult::yes(pack(
260                [(self.0.clone(), root())],
261                Nullability::NonNullable,
262            )))
263        } else {
264            Ok(TransformResult::no(node))
265        }
266    }
267}
268
269struct ReplaceAccessesWithChild(Vec<FieldName>);
270
271impl MutNodeVisitor for ReplaceAccessesWithChild {
272    type NodeTy = ExprRef;
273
274    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
275        if let Some(item) = node.as_any().downcast_ref::<GetItem>() {
276            if self.0.contains(item.field()) {
277                return Ok(TransformResult::yes(item.child().clone()));
278            }
279        }
280        Ok(TransformResult::no(node))
281    }
282}
283
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, StructFields};

    use super::*;
    use crate::transform::simplify::simplify;
    use crate::transform::simplify_typed::simplify_typed;
    use crate::{Pack, and, get_item, lit, pack, root, select};

    /// Scope shared by all tests: `{a: {a: i32, b: i32}, b: i32, c: i32}`.
    fn dtype() -> DType {
        let inner = StructFields::from_iter([("a", DType::from(I32)), ("b", DType::from(I32))]);
        let outer = StructFields::from_iter([
            ("a", DType::Struct(Arc::new(inner), NonNullable)),
            ("b", DType::from(I32)),
            ("c", DType::from(I32)),
        ]);
        DType::Struct(Arc::new(outer), NonNullable)
    }

    /// Split `expr` over the shared test scope, panicking if the split fails.
    fn split(expr: ExprRef) -> PartitionedExpr {
        StructFieldExpressionSplitter::split(expr, &dtype()).unwrap()
    }

    /// Shorthand for the generated name of occurrence `idx` of `field`.
    fn idx_name(field: &str, idx: usize) -> FieldName {
        let field: FieldName = field.into();
        StructFieldExpressionSplitter::field_idx_name(&field, idx)
    }

    #[test]
    fn test_expr_top_level_ref() {
        let scope = dtype();

        let partitioned = split(root());

        assert!(partitioned.root.as_any().is::<Pack>());
        // Have a single top level pack with all fields in dtype
        assert_eq!(
            partitioned.partitions.len(),
            scope.as_struct().unwrap().names().len()
        )
    }

    #[test]
    fn test_expr_top_level_ref_get_item_and_split() {
        let partitioned = split(get_item("b", get_item("a", root())));

        let split_a = partitioned.find_partition(&"a".into()).unwrap();

        assert_eq!(&partitioned.root, &get_item("a", root()));
        assert_eq!(&simplify(split_a.clone()).unwrap(), &get_item("b", root()));
    }

    #[test]
    fn test_expr_top_level_ref_get_item_and_split_pack() {
        let expr = pack(
            [
                ("a", get_item("a", get_item("a", root()))),
                ("b", get_item("b", get_item("a", root()))),
                ("c", get_item("c", root())),
            ],
            NonNullable,
        );
        let partitioned = split(expr);

        // Field `a` is accessed twice, so its partition packs both occurrences.
        let expected_a = pack(
            [
                (idx_name("a", 0), get_item("a", root())),
                (idx_name("a", 1), get_item("b", root())),
            ],
            NonNullable,
        );
        let split_a = partitioned.find_partition(&"a".into()).unwrap();
        assert_eq!(&simplify(split_a.clone()).unwrap(), &expected_a);

        // Field `c` is accessed once, so its partition is just the field itself.
        let split_c = partitioned.find_partition(&"c".into()).unwrap();
        assert_eq!(&simplify(split_c.clone()).unwrap(), &root())
    }

    #[test]
    fn test_expr_top_level_ref_get_item_add() {
        let partitioned = split(and(get_item("b", get_item("a", root())), lit(1)));

        // Whole expr is a single split
        assert_eq!(partitioned.partitions.len(), 1);
    }

    #[test]
    fn test_expr_top_level_ref_get_item_add_cannot_split() {
        let partitioned = split(and(
            get_item("b", get_item("a", root())),
            get_item("b", root()),
        ));

        // One for id.a and id.b
        assert_eq!(partitioned.partitions.len(), 2);
    }

    // Checks that typed simplification removes the `select`, so that only the fields that are
    // actually used get a partition.
    #[test]
    fn test_expr_partition_many_occurrences_of_field() {
        let scope = dtype();

        let expr = and(
            get_item("b", get_item("a", root())),
            select(vec!["a".into(), "b".into()], root()),
        );
        let expr = simplify_typed(expr, &ScopeDType::new(scope.clone())).unwrap();
        let partitioned = StructFieldExpressionSplitter::split(expr, &scope).unwrap();

        // One for id.a and id.b
        assert_eq!(partitioned.partitions.len(), 2);

        // Field `c` is never read: the simplification above should have replaced the select with
        // get_item and pack, dropping that field before partitioning.
        let expected_root = and(
            get_item(idx_name("a", 0), get_item("a", root())),
            pack(
                [
                    ("a", get_item(idx_name("a", 1), get_item("a", root()))),
                    ("b", get_item("b", root())),
                ],
                NonNullable,
            ),
        );
        assert_eq!(&partitioned.root, &expected_root)
    }
}