// vortex_expr/transform/partition.rs

1use std::fmt::{Display, Formatter};
2use std::hash::{BuildHasher, Hash, Hasher};
3use std::sync::LazyLock;
4
5use itertools::Itertools;
6use vortex_array::aliases::hash_map::{DefaultHashBuilder, HashMap};
7use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructDType};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail};
9
10use crate::transform::immediate_access::{FieldAccesses, immediate_scope_accesses};
11use crate::transform::simplify_typed::simplify_typed;
12use crate::traversal::{FoldDown, FoldUp, FolderMut, MutNodeVisitor, Node, TransformResult};
13use crate::{ExprRef, GetItem, Identity, get_item, ident, pack};
14
15static SPLITTER_RANDOM_STATE: LazyLock<DefaultHashBuilder> =
16    LazyLock::new(DefaultHashBuilder::default);
17
18/// Partition an expression over the fields of the scope.
19///
20/// This returns a partitioned expression that can be push-down over each field of the scope.
21/// The results of each partition can then be recombined to reproduce the result of the original
22/// expression.
23///
24/// ## Note
25///
26/// This function currently respects the validity of each field in the scope, but the not validity
27/// of the scope itself. The fix would be for the returned `PartitionedExpr` to include a partition
28/// expression for computing the validity, or to include that expression as part of the root.
29///
30/// See <https://github.com/vortex-data/vortex/issues/1907>.
31///
32// TODO(ngates): document the behaviour of conflicting `Field::Index` and `Field::Name`.
33pub fn partition(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
34    if !matches!(dtype, DType::Struct(..)) {
35        vortex_bail!("Expected a struct dtype, got {:?}", dtype);
36    }
37    StructFieldExpressionSplitter::split(expr, dtype)
38}
39
40/// The result of partitioning an expression.
41#[derive(Debug)]
42pub struct PartitionedExpr {
43    /// The root expression used to re-assemble the results.
44    pub root: ExprRef,
45    /// The partitions of the expression.
46    pub partitions: Box<[ExprRef]>,
47    /// The field names for the partitions
48    pub partition_names: FieldNames,
49    /// The return DTypes of each partition.
50    pub partition_dtypes: Box<[DType]>,
51}
52
53impl Display for PartitionedExpr {
54    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
55        write!(
56            f,
57            "root: {} {{{}}}",
58            self.root,
59            self.partition_names
60                .iter()
61                .zip(self.partitions.iter())
62                .map(|(name, partition)| format!("{name}: {partition}"))
63                .join(", ")
64        )
65    }
66}
67
68impl PartitionedExpr {
69    /// Return the partition for a given field, if it exists.
70    pub fn find_partition(&self, field: &FieldName) -> Option<&ExprRef> {
71        self.partition_names
72            .iter()
73            .position(|name| name == field)
74            .map(|idx| &self.partitions[idx])
75    }
76}
77
78#[derive(Debug)]
79struct StructFieldExpressionSplitter<'a> {
80    sub_expressions: HashMap<FieldName, Vec<ExprRef>>,
81    accesses: &'a FieldAccesses<'a>,
82    scope_dtype: &'a StructDType,
83}
84
85impl<'a> StructFieldExpressionSplitter<'a> {
86    fn new(accesses: &'a FieldAccesses<'a>, scope_dtype: &'a StructDType) -> Self {
87        Self {
88            sub_expressions: HashMap::new(),
89            accesses,
90            scope_dtype,
91        }
92    }
93
94    pub(crate) fn field_idx_name(field: &FieldName, idx: usize) -> FieldName {
95        let mut hasher = SPLITTER_RANDOM_STATE.build_hasher();
96        field.hash(&mut hasher);
97        idx.hash(&mut hasher);
98        hasher.finish().to_string().into()
99    }
100
101    fn split(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
102        let scope_dtype = match dtype {
103            DType::Struct(scope_dtype, _) => scope_dtype,
104            _ => vortex_bail!("Expected a struct dtype, got {:?}", dtype),
105        };
106
107        let field_accesses = immediate_scope_accesses(&expr, scope_dtype)?;
108
109        let mut splitter = StructFieldExpressionSplitter::new(&field_accesses, scope_dtype);
110
111        let split = expr.clone().transform_with_context(&mut splitter, ())?;
112
113        let mut remove_accesses: Vec<FieldName> = Vec::new();
114
115        // Create partitions which can be passed to layout fields
116        let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
117        let mut partition_names = Vec::with_capacity(splitter.sub_expressions.len());
118        let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
119        for (name, exprs) in splitter.sub_expressions.into_iter() {
120            let field_dtype = scope_dtype.field(&name)?;
121            // If there is a single expr then we don't need to `pack` this, and we must update
122            // the root expr removing this access.
123            let expr = if exprs.len() == 1 {
124                remove_accesses.push(Self::field_idx_name(&name, 0));
125                exprs.first().vortex_expect("exprs is non-empty").clone()
126            } else {
127                pack(
128                    exprs
129                        .into_iter()
130                        .enumerate()
131                        .map(|(idx, expr)| (Self::field_idx_name(&name, idx), expr)),
132                    Nullability::NonNullable,
133                )
134            };
135
136            let expr = simplify_typed(expr.clone(), &field_dtype)?;
137            let expr_dtype = expr.return_dtype(&field_dtype)?;
138
139            partitions.push(expr);
140            partition_names.push(name);
141            partition_dtypes.push(expr_dtype);
142        }
143
144        let expression_access_counts = field_accesses.get(&expr).map(|ac| ac.len());
145        // Ensure that there are not more accesses than partitions, we missed something
146        assert!(expression_access_counts.unwrap_or(0) <= partitions.len());
147        // Ensure that there are as many partitions as there are accesses/fields in the scope,
148        // this will affect performance, not correctness.
149        debug_assert_eq!(expression_access_counts.unwrap_or(0), partitions.len());
150
151        let split = split
152            .result()
153            .transform(&mut ReplaceAccessesWithChild(remove_accesses))?;
154
155        Ok(PartitionedExpr {
156            root: simplify_typed(split.result, dtype)?,
157            partitions: partitions.into_boxed_slice(),
158            partition_names: partition_names.into(),
159            partition_dtypes: partition_dtypes.into_boxed_slice(),
160        })
161    }
162}
163
164impl FolderMut for StructFieldExpressionSplitter<'_> {
165    type NodeTy = ExprRef;
166    type Out = ExprRef;
167    type Context = ();
168
169    fn visit_down(
170        &mut self,
171        node: &Self::NodeTy,
172        _context: Self::Context,
173    ) -> VortexResult<FoldDown<ExprRef, Self::Context>> {
174        // If this expression only accesses a single field, then we can skip the children
175        let access = self.accesses.get(node);
176        if access.as_ref().is_some_and(|a| a.len() == 1) {
177            let field_name = access
178                .vortex_expect("access is non-empty")
179                .iter()
180                .next()
181                .vortex_expect("expected one field");
182
183            // TODO(joe): dedup the sub_expression, if there are two expressions that are the same
184            // only create one entry here and reuse it.
185            let sub_exprs = self.sub_expressions.entry(field_name.clone()).or_default();
186            let idx = sub_exprs.len();
187
188            // Need to replace get_item(f, ident) with ident, making the expr relative to the child.
189            let replaced = node
190                .clone()
191                .transform(&mut ScopeStepIntoFieldExpr(field_name.clone()))?;
192            sub_exprs.push(replaced.result);
193
194            let access = get_item(
195                Self::field_idx_name(field_name, idx),
196                get_item(field_name.clone(), ident()),
197            );
198
199            return Ok(FoldDown::SkipChildren(access));
200        };
201
202        // If the expression is an identity, then we need to partition it into the fields of the scope.
203        if node.as_any().is::<Identity>() {
204            let field_names = self.scope_dtype.names();
205
206            let mut elements = Vec::with_capacity(field_names.len());
207
208            for field_name in field_names.iter() {
209                let sub_exprs = self
210                    .sub_expressions
211                    .entry(field_name.clone())
212                    .or_insert_with(Vec::new);
213
214                let idx = sub_exprs.len();
215
216                sub_exprs.push(ident());
217
218                elements.push((
219                    field_name.clone(),
220                    // Partitions are packed into a struct of field name -> occurrence idx -> array
221                    get_item(
222                        Self::field_idx_name(field_name, idx),
223                        get_item(field_name.clone(), ident()),
224                    ),
225                ));
226            }
227
228            return Ok(FoldDown::SkipChildren(pack(
229                elements,
230                Nullability::NonNullable,
231            )));
232        }
233
234        // Otherwise, continue traversing.
235        Ok(FoldDown::Continue(()))
236    }
237
238    fn visit_up(
239        &mut self,
240        node: Self::NodeTy,
241        _context: Self::Context,
242        children: Vec<Self::Out>,
243    ) -> VortexResult<FoldUp<Self::Out>> {
244        Ok(FoldUp::Continue(node.replacing_children(children)))
245    }
246}
247
248struct ScopeStepIntoFieldExpr(FieldName);
249
250impl MutNodeVisitor for ScopeStepIntoFieldExpr {
251    type NodeTy = ExprRef;
252
253    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
254        if node.as_any().is::<Identity>() {
255            Ok(TransformResult::yes(pack(
256                [(self.0.clone(), ident())],
257                Nullability::NonNullable,
258            )))
259        } else {
260            Ok(TransformResult::no(node))
261        }
262    }
263}
264
265struct ReplaceAccessesWithChild(Vec<FieldName>);
266
267impl MutNodeVisitor for ReplaceAccessesWithChild {
268    type NodeTy = ExprRef;
269
270    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
271        if let Some(item) = node.as_any().downcast_ref::<GetItem>() {
272            if self.0.contains(item.field()) {
273                return Ok(TransformResult::yes(item.child().clone()));
274            }
275        }
276        Ok(TransformResult::no(node))
277    }
278}
279
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, StructDType};

    use super::*;
    use crate::transform::simplify::simplify;
    use crate::transform::simplify_typed::simplify_typed;
    use crate::{Pack, and, get_item, ident, lit, pack, select};

    /// Scope shared by the tests: `{a: {a: i32, b: i32}, b: i32, c: i32}`.
    fn dtype() -> DType {
        let nested = DType::Struct(
            Arc::new(StructDType::from_iter([
                ("a", DType::from(I32)),
                ("b", DType::from(I32)),
            ])),
            NonNullable,
        );

        DType::Struct(
            Arc::new(StructDType::from_iter([
                ("a", nested),
                ("b", DType::from(I32)),
                ("c", DType::from(I32)),
            ])),
            NonNullable,
        )
    }

    /// Split `expr` over the shared test scope, panicking if the split fails.
    fn split_over_scope(expr: ExprRef) -> PartitionedExpr {
        StructFieldExpressionSplitter::split(expr, &dtype()).unwrap()
    }

    /// Generated name of the `idx`-th sub-expression of `field`.
    fn idx_name(field: &str, idx: usize) -> FieldName {
        StructFieldExpressionSplitter::field_idx_name(&field.into(), idx)
    }

    #[test]
    fn test_expr_top_level_ref() {
        let dtype = dtype();
        let partitioned = StructFieldExpressionSplitter::split(ident(), &dtype).unwrap();

        assert!(partitioned.root.as_any().is::<Pack>());

        // Have a single top level pack with all fields in dtype
        let field_count = dtype.as_struct().unwrap().names().len();
        assert_eq!(partitioned.partitions.len(), field_count);
    }

    #[test]
    fn test_expr_top_level_ref_get_item_and_split() {
        let partitioned = split_over_scope(get_item("b", get_item("a", ident())));
        let split_a = partitioned.find_partition(&"a".into()).unwrap();

        assert_eq!(&partitioned.root, &get_item("a", ident()));
        assert_eq!(&simplify(split_a.clone()).unwrap(), &get_item("b", ident()));
    }

    #[test]
    fn test_expr_top_level_ref_get_item_and_split_pack() {
        let partitioned = split_over_scope(pack(
            [
                ("a", get_item("a", get_item("a", ident()))),
                ("b", get_item("b", get_item("a", ident()))),
                ("c", get_item("c", ident())),
            ],
            NonNullable,
        ));

        // Field `a` is read twice, so its partition packs both reads under generated names.
        let expected_a = pack(
            [
                (idx_name("a", 0), get_item("a", ident())),
                (idx_name("a", 1), get_item("b", ident())),
            ],
            NonNullable,
        );
        let split_a = partitioned.find_partition(&"a".into()).unwrap();
        assert_eq!(&simplify(split_a.clone()).unwrap(), &expected_a);

        // Field `c` is read once, so its partition is the field itself.
        let split_c = partitioned.find_partition(&"c".into()).unwrap();
        assert_eq!(&simplify(split_c.clone()).unwrap(), &ident());
    }

    #[test]
    fn test_expr_top_level_ref_get_item_add() {
        let partitioned = split_over_scope(and(get_item("b", get_item("a", ident())), lit(1)));

        // Whole expr is a single split
        assert_eq!(partitioned.partitions.len(), 1);
    }

    #[test]
    fn test_expr_top_level_ref_get_item_add_cannot_split() {
        let partitioned = split_over_scope(and(
            get_item("b", get_item("a", ident())),
            get_item("b", ident()),
        ));

        // One for id.a and id.b
        assert_eq!(partitioned.partitions.len(), 2);
    }

    // Test that typed_simplify removes select and partition precise
    #[test]
    fn test_expr_partition_many_occurrences_of_field() {
        let dtype = dtype();

        let expr = and(
            get_item("b", get_item("a", ident())),
            select(vec!["a".into(), "b".into()], ident()),
        );
        let expr = simplify_typed(expr, &dtype).unwrap();
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        // One for id.a and id.b
        assert_eq!(partitioned.partitions.len(), 2);

        // This fetches [].$c which is unused, however a previous optimisation should replace select
        // with get_item and pack removing this field.
        let expected_root = and(
            get_item(idx_name("a", 0), get_item("a", ident())),
            pack(
                [
                    ("a", get_item(idx_name("a", 1), get_item("a", ident()))),
                    ("b", get_item("b", ident())),
                ],
                NonNullable,
            ),
        );
        assert_eq!(&partitioned.root, &expected_root);
    }
}