vortex_expr/transform/
partition.rs

1use std::fmt::{Display, Formatter};
2
3use itertools::Itertools;
4use vortex_array::aliases::hash_map::HashMap;
5use vortex_dtype::{DType, FieldName, FieldNames, StructDType};
6use vortex_error::{VortexExpect, VortexResult, vortex_bail};
7
8use crate::transform::immediate_access::{FieldAccesses, immediate_scope_accesses};
9use crate::transform::simplify_typed::simplify_typed;
10use crate::traversal::{FoldDown, FoldUp, FolderMut, MutNodeVisitor, Node, TransformResult};
11use crate::{ExprRef, GetItem, Identity, get_item, ident, pack};
12
13/// Partition an expression over the fields of the scope.
14///
15/// This returns a partitioned expression that can be push-down over each field of the scope.
16/// The results of each partition can then be recombined to reproduce the result of the original
17/// expression.
18///
19/// ## Note
20///
21/// This function currently respects the validity of each field in the scope, but the not validity
22/// of the scope itself. The fix would be for the returned `PartitionedExpr` to include a partition
23/// expression for computing the validity, or to include that expression as part of the root.
24///
25/// See <https://github.com/spiraldb/vortex/issues/1907>.
26///
27// TODO(ngates): document the behaviour of conflicting `Field::Index` and `Field::Name`.
28pub fn partition(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
29    if !matches!(dtype, DType::Struct(..)) {
30        vortex_bail!("Expected a struct dtype, got {:?}", dtype);
31    }
32    StructFieldExpressionSplitter::split(expr, dtype)
33}
34
35/// The result of partitioning an expression.
36#[derive(Debug)]
37pub struct PartitionedExpr {
38    /// The root expression used to re-assemble the results.
39    pub root: ExprRef,
40    /// The partitions of the expression.
41    pub partitions: Box<[ExprRef]>,
42    /// The field names for the partitions
43    pub partition_names: FieldNames,
44    /// The return DTypes of each partition.
45    pub partition_dtypes: Box<[DType]>,
46}
47
48impl Display for PartitionedExpr {
49    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
50        write!(
51            f,
52            "root: {} {{{}}}",
53            self.root,
54            self.partition_names
55                .iter()
56                .zip(self.partitions.iter())
57                .map(|(name, partition)| format!("{}: {}", name, partition))
58                .join(", ")
59        )
60    }
61}
62
63impl PartitionedExpr {
64    /// Return the partition for a given field, if it exists.
65    pub fn find_partition(&self, field: &FieldName) -> Option<&ExprRef> {
66        self.partition_names
67            .iter()
68            .position(|name| name == field)
69            .map(|idx| &self.partitions[idx])
70    }
71}
72
73#[derive(Debug)]
74struct StructFieldExpressionSplitter<'a> {
75    sub_expressions: HashMap<FieldName, Vec<ExprRef>>,
76    accesses: &'a FieldAccesses<'a>,
77    scope_dtype: &'a StructDType,
78}
79
80impl<'a> StructFieldExpressionSplitter<'a> {
81    fn new(accesses: &'a FieldAccesses<'a>, scope_dtype: &'a StructDType) -> Self {
82        Self {
83            sub_expressions: HashMap::new(),
84            accesses,
85            scope_dtype,
86        }
87    }
88
89    pub(crate) fn field_idx_name(field: &FieldName, idx: usize) -> FieldName {
90        format!("__e__{}.{}", field, idx).into()
91    }
92
93    fn split(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
94        let scope_dtype = match dtype {
95            DType::Struct(scope_dtype, _) => scope_dtype,
96            _ => vortex_bail!("Expected a struct dtype, got {:?}", dtype),
97        };
98
99        let field_accesses = immediate_scope_accesses(&expr, scope_dtype)?;
100
101        let mut splitter = StructFieldExpressionSplitter::new(&field_accesses, scope_dtype);
102
103        let split = expr.clone().transform_with_context(&mut splitter, ())?;
104
105        let mut remove_accesses: Vec<FieldName> = Vec::new();
106
107        // Create partitions which can be passed to layout fields
108        let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
109        let mut partition_names = Vec::with_capacity(splitter.sub_expressions.len());
110        let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
111        for (name, exprs) in splitter.sub_expressions.into_iter() {
112            let field_dtype = scope_dtype.field(&name)?;
113            // If there is a single expr then we don't need to `pack` this, and we must update
114            // the root expr removing this access.
115            let expr = if exprs.len() == 1 {
116                remove_accesses.push(Self::field_idx_name(&name, 0));
117                exprs.first().vortex_expect("exprs is non-empty").clone()
118            } else {
119                pack(
120                    exprs
121                        .into_iter()
122                        .enumerate()
123                        .map(|(idx, expr)| (Self::field_idx_name(&name, idx), expr)),
124                )
125            };
126
127            let expr = simplify_typed(expr.clone(), &field_dtype)?;
128            let expr_dtype = expr.return_dtype(&field_dtype)?;
129
130            partitions.push(expr);
131            partition_names.push(name);
132            partition_dtypes.push(expr_dtype);
133        }
134
135        let expression_access_counts = field_accesses.get(&expr).map(|ac| ac.len());
136        // Ensure that there are not more accesses than partitions, we missed something
137        assert!(expression_access_counts.unwrap_or(0) <= partitions.len());
138        // Ensure that there are as many partitions as there are accesses/fields in the scope,
139        // this will affect performance, not correctness.
140        debug_assert_eq!(expression_access_counts.unwrap_or(0), partitions.len());
141
142        let split = split
143            .result()
144            .transform(&mut ReplaceAccessesWithChild(remove_accesses))?;
145
146        Ok(PartitionedExpr {
147            root: simplify_typed(split.result, dtype)?,
148            partitions: partitions.into_boxed_slice(),
149            partition_names: partition_names.into(),
150            partition_dtypes: partition_dtypes.into_boxed_slice(),
151        })
152    }
153}
154
155impl FolderMut for StructFieldExpressionSplitter<'_> {
156    type NodeTy = ExprRef;
157    type Out = ExprRef;
158    type Context = ();
159
160    fn visit_down(
161        &mut self,
162        node: &Self::NodeTy,
163        _context: Self::Context,
164    ) -> VortexResult<FoldDown<ExprRef, Self::Context>> {
165        // If this expression only accesses a single field, then we can skip the children
166        let access = self.accesses.get(node);
167        if access.as_ref().is_some_and(|a| a.len() == 1) {
168            let field_name = access
169                .vortex_expect("access is non-empty")
170                .iter()
171                .next()
172                .vortex_expect("expected one field");
173
174            // TODO(joe): dedup the sub_expression, if there are two expressions that are the same
175            // only create one entry here and reuse it.
176            let sub_exprs = self.sub_expressions.entry(field_name.clone()).or_default();
177            let idx = sub_exprs.len();
178
179            // Need to replace get_item(f, ident) with ident, making the expr relative to the child.
180            let replaced = node
181                .clone()
182                .transform(&mut ScopeStepIntoFieldExpr(field_name.clone()))?;
183            sub_exprs.push(replaced.result);
184
185            let access = get_item(
186                Self::field_idx_name(field_name, idx),
187                get_item(field_name.clone(), ident()),
188            );
189
190            return Ok(FoldDown::SkipChildren(access));
191        };
192
193        // If the expression is an identity, then we need to partition it into the fields of the scope.
194        if node.as_any().is::<Identity>() {
195            let field_names = self.scope_dtype.names();
196
197            let mut elements = Vec::with_capacity(field_names.len());
198
199            for field_name in field_names.iter() {
200                let sub_exprs = self
201                    .sub_expressions
202                    .entry(field_name.clone())
203                    .or_insert_with(Vec::new);
204
205                let idx = sub_exprs.len();
206
207                sub_exprs.push(ident());
208
209                elements.push((
210                    field_name.clone(),
211                    // Partitions are packed into a struct of field name -> occurrence idx -> array
212                    get_item(
213                        Self::field_idx_name(field_name, idx),
214                        get_item(field_name.clone(), ident()),
215                    ),
216                ));
217            }
218
219            return Ok(FoldDown::SkipChildren(pack(elements)));
220        }
221
222        // Otherwise, continue traversing.
223        Ok(FoldDown::Continue(()))
224    }
225
226    fn visit_up(
227        &mut self,
228        node: Self::NodeTy,
229        _context: Self::Context,
230        children: Vec<Self::Out>,
231    ) -> VortexResult<FoldUp<Self::Out>> {
232        Ok(FoldUp::Continue(node.replacing_children(children)))
233    }
234}
235
236struct ScopeStepIntoFieldExpr(FieldName);
237
238impl MutNodeVisitor for ScopeStepIntoFieldExpr {
239    type NodeTy = ExprRef;
240
241    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
242        if node.as_any().is::<Identity>() {
243            Ok(TransformResult::yes(pack([(self.0.clone(), ident())])))
244        } else {
245            Ok(TransformResult::no(node))
246        }
247    }
248}
249
250struct ReplaceAccessesWithChild(Vec<FieldName>);
251
252impl MutNodeVisitor for ReplaceAccessesWithChild {
253    type NodeTy = ExprRef;
254
255    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
256        if let Some(item) = node.as_any().downcast_ref::<GetItem>() {
257            if self.0.contains(item.field()) {
258                return Ok(TransformResult::yes(item.child().clone()));
259            }
260        }
261        Ok(TransformResult::no(node))
262    }
263}
264
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, StructDType};

    use super::*;
    use crate::transform::simplify::simplify;
    use crate::transform::simplify_typed::simplify_typed;
    use crate::{Pack, and, get_item, ident, lit, pack, select};

    /// Scope shared by the tests: `{a: {a: i32, b: i32}, b: i32, c: i32}`.
    fn dtype() -> DType {
        let nested = DType::Struct(
            Arc::new(StructDType::from_iter([
                ("a", I32.into()),
                ("b", DType::from(I32)),
            ])),
            NonNullable,
        );

        DType::Struct(
            Arc::new(StructDType::from_iter([
                ("a", nested),
                ("b", I32.into()),
                ("c", I32.into()),
            ])),
            NonNullable,
        )
    }

    #[test]
    fn test_expr_top_level_ref() {
        let dtype = dtype();

        let partitioned = StructFieldExpressionSplitter::split(ident(), &dtype).unwrap();

        assert!(partitioned.root.as_any().is::<Pack>());
        // Have a single top level pack with all fields in dtype
        let field_count = dtype.as_struct().unwrap().names().len();
        assert_eq!(partitioned.partitions.len(), field_count)
    }

    #[test]
    fn test_expr_top_level_ref_get_item_and_split() {
        let dtype = dtype();

        let expr = get_item("b", get_item("a", ident()));
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        // The only accessed field is `a`, so it must have a partition.
        let split_a = partitioned.find_partition(&"a".into()).unwrap();

        assert_eq!(&partitioned.root, &get_item("a", ident()));
        assert_eq!(&simplify(split_a.clone()).unwrap(), &get_item("b", ident()));
    }

    #[test]
    fn test_expr_top_level_ref_get_item_and_split_pack() {
        let dtype = dtype();

        let expr = pack([
            ("a", get_item("a", get_item("a", ident()))),
            ("b", get_item("b", get_item("a", ident()))),
            ("c", get_item("c", ident())),
        ]);
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        // Field `a` is accessed twice, so its partition packs both occurrences.
        let expected_a = pack([
            (
                StructFieldExpressionSplitter::field_idx_name(&"a".into(), 0),
                get_item("a", ident()),
            ),
            (
                StructFieldExpressionSplitter::field_idx_name(&"a".into(), 1),
                get_item("b", ident()),
            ),
        ]);
        let split_a = partitioned.find_partition(&"a".into()).unwrap();
        assert_eq!(&simplify(split_a.clone()).unwrap(), &expected_a);

        // Field `c` is accessed once, so its partition is the field itself.
        let split_c = partitioned.find_partition(&"c".into()).unwrap();
        assert_eq!(&simplify(split_c.clone()).unwrap(), &ident())
    }

    #[test]
    fn test_expr_top_level_ref_get_item_add() {
        let dtype = dtype();

        let expr = and(get_item("b", get_item("a", ident())), lit(1));
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        // Whole expr is a single split
        assert_eq!(partitioned.partitions.len(), 1);
    }

    #[test]
    fn test_expr_top_level_ref_get_item_add_cannot_split() {
        let dtype = dtype();

        let lhs = get_item("b", get_item("a", ident()));
        let rhs = get_item("b", ident());
        let partitioned = StructFieldExpressionSplitter::split(and(lhs, rhs), &dtype).unwrap();

        // One for id.a and id.b
        assert_eq!(partitioned.partitions.len(), 2);
    }

    // Test that `simplify_typed` removes the select so that the partitioning is precise.
    #[test]
    fn test_expr_partition_many_occurrences_of_field() {
        let dtype = dtype();

        let expr = and(
            get_item("b", get_item("a", ident())),
            select(vec!["a".into(), "b".into()], ident()),
        );
        let expr = simplify_typed(expr, &dtype).unwrap();
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        // One for id.a and id.b
        assert_eq!(partitioned.partitions.len(), 2);

        // This fetches [].$c which is unused, however a previous optimisation should replace select
        // with get_item and pack removing this field.
        let expected_root = and(
            get_item(
                StructFieldExpressionSplitter::field_idx_name(&"a".into(), 0),
                get_item("a", ident()),
            ),
            pack([
                (
                    "a",
                    get_item(
                        StructFieldExpressionSplitter::field_idx_name(&"a".into(), 1),
                        get_item("a", ident()),
                    ),
                ),
                ("b", get_item("b", ident())),
            ]),
        );
        assert_eq!(&partitioned.root, &expected_root)
    }
}