vortex_expr/transform/
partition.rs

1use vortex_array::aliases::hash_map::HashMap;
2use vortex_dtype::{DType, FieldName, FieldNames, StructDType};
3use vortex_error::{VortexExpect, VortexResult, vortex_bail};
4
5use crate::transform::immediate_access::{FieldAccesses, immediate_scope_accesses};
6use crate::transform::simplify_typed::simplify_typed;
7use crate::traversal::{FoldDown, FoldUp, FolderMut, MutNodeVisitor, Node, TransformResult};
8use crate::{ExprRef, GetItem, Identity, get_item, ident, pack};
9
10/// Partition an expression over the fields of the scope.
11///
12/// This returns a partitioned expression that can be push-down over each field of the scope.
13/// The results of each partition can then be recombined to reproduce the result of the original
14/// expression.
15///
16/// ## Note
17///
18/// This function currently respects the validity of each field in the scope, but the not validity
19/// of the scope itself. The fix would be for the returned `PartitionedExpr` to include a partition
20/// expression for computing the validity, or to include that expression as part of the root.
21///
22/// See <https://github.com/spiraldb/vortex/issues/1907>.
23///
24// TODO(ngates): document the behaviour of conflicting `Field::Index` and `Field::Name`.
25pub fn partition(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
26    if !matches!(dtype, DType::Struct(..)) {
27        vortex_bail!("Expected a struct dtype, got {:?}", dtype);
28    }
29    StructFieldExpressionSplitter::split(expr, dtype)
30}
31
32/// The result of partitioning an expression.
33#[derive(Debug)]
34pub struct PartitionedExpr {
35    /// The root expression used to re-assemble the results.
36    pub root: ExprRef,
37    /// The partitions of the expression.
38    pub partitions: Box<[ExprRef]>,
39    /// The field names for the partitions
40    pub partition_names: FieldNames,
41}
42
43impl PartitionedExpr {
44    /// Return the partition for a given field, if it exists.
45    pub fn find_partition(&self, field: &FieldName) -> Option<&ExprRef> {
46        self.partition_names
47            .iter()
48            .position(|name| name == field)
49            .map(|idx| &self.partitions[idx])
50    }
51}
52
53#[derive(Debug)]
54struct StructFieldExpressionSplitter<'a> {
55    sub_expressions: HashMap<FieldName, Vec<ExprRef>>,
56    accesses: &'a FieldAccesses<'a>,
57    scope_dtype: &'a StructDType,
58}
59
60impl<'a> StructFieldExpressionSplitter<'a> {
61    fn new(accesses: &'a FieldAccesses<'a>, scope_dtype: &'a StructDType) -> Self {
62        Self {
63            sub_expressions: HashMap::new(),
64            accesses,
65            scope_dtype,
66        }
67    }
68
69    pub(crate) fn field_idx_name(field: &FieldName, idx: usize) -> FieldName {
70        format!("__e__{}.{}", field, idx).into()
71    }
72
73    fn split(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
74        let scope_dtype = match dtype {
75            DType::Struct(scope_dtype, _) => scope_dtype,
76            _ => vortex_bail!("Expected a struct dtype, got {:?}", dtype),
77        };
78
79        let field_accesses = immediate_scope_accesses(&expr, scope_dtype)?;
80
81        let mut splitter = StructFieldExpressionSplitter::new(&field_accesses, scope_dtype);
82
83        let split = expr.clone().transform_with_context(&mut splitter, ())?;
84
85        let mut remove_accesses: Vec<FieldName> = Vec::new();
86
87        // Create partitions which can be passed to layout fields
88        let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
89        let mut partition_names = Vec::with_capacity(splitter.sub_expressions.len());
90        for (name, exprs) in splitter.sub_expressions.into_iter() {
91            let field_dtype = scope_dtype.field(&name)?;
92            // If there is a single expr then we don't need to `pack` this, and we must update
93            // the root expr removing this access.
94            let expr = if exprs.len() == 1 {
95                remove_accesses.push(Self::field_idx_name(&name, 0));
96                exprs.first().vortex_expect("exprs is non-empty").clone()
97            } else {
98                pack(
99                    exprs
100                        .into_iter()
101                        .enumerate()
102                        .map(|(idx, expr)| (Self::field_idx_name(&name, idx), expr)),
103                )
104            };
105
106            partitions.push(simplify_typed(expr, &field_dtype)?);
107            partition_names.push(name);
108        }
109
110        let expression_access_counts = field_accesses.get(&expr).map(|ac| ac.len());
111        // Ensure that there are not more accesses than partitions, we missed something
112        assert!(expression_access_counts.unwrap_or(0) <= partitions.len());
113        // Ensure that there are as many partitions as there are accesses/fields in the scope,
114        // this will affect performance, not correctness.
115        debug_assert_eq!(expression_access_counts.unwrap_or(0), partitions.len());
116
117        let split = split
118            .result()
119            .transform(&mut ReplaceAccessesWithChild(remove_accesses))?;
120
121        Ok(PartitionedExpr {
122            root: simplify_typed(split.result, dtype)?,
123            partitions: partitions.into_boxed_slice(),
124            partition_names: partition_names.into(),
125        })
126    }
127}
128
129impl FolderMut for StructFieldExpressionSplitter<'_> {
130    type NodeTy = ExprRef;
131    type Out = ExprRef;
132    type Context = ();
133
134    fn visit_down(
135        &mut self,
136        node: &Self::NodeTy,
137        _context: Self::Context,
138    ) -> VortexResult<FoldDown<ExprRef, Self::Context>> {
139        // If this expression only accesses a single field, then we can skip the children
140        let access = self.accesses.get(node);
141        if access.as_ref().is_some_and(|a| a.len() == 1) {
142            let field_name = access
143                .vortex_expect("access is non-empty")
144                .iter()
145                .next()
146                .vortex_expect("expected one field");
147
148            // TODO(joe): dedup the sub_expression, if there are two expressions that are the same
149            // only create one entry here and reuse it.
150            let sub_exprs = self.sub_expressions.entry(field_name.clone()).or_default();
151            let idx = sub_exprs.len();
152
153            // Need to replace get_item(f, ident) with ident, making the expr relative to the child.
154            let replaced = node
155                .clone()
156                .transform(&mut ScopeStepIntoFieldExpr(field_name.clone()))?;
157            sub_exprs.push(replaced.result);
158
159            let access = get_item(
160                Self::field_idx_name(field_name, idx),
161                get_item(field_name.clone(), ident()),
162            );
163
164            return Ok(FoldDown::SkipChildren(access));
165        };
166
167        // If the expression is an identity, then we need to partition it into the fields of the scope.
168        if node.as_any().is::<Identity>() {
169            let field_names = self.scope_dtype.names();
170
171            let mut elements = Vec::with_capacity(field_names.len());
172
173            for field_name in field_names.iter() {
174                let sub_exprs = self
175                    .sub_expressions
176                    .entry(field_name.clone())
177                    .or_insert_with(Vec::new);
178
179                let idx = sub_exprs.len();
180
181                sub_exprs.push(ident());
182
183                elements.push((
184                    field_name.clone(),
185                    // Partitions are packed into a struct of field name -> occurrence idx -> array
186                    get_item(
187                        Self::field_idx_name(field_name, idx),
188                        get_item(field_name.clone(), ident()),
189                    ),
190                ));
191            }
192
193            return Ok(FoldDown::SkipChildren(pack(elements)));
194        }
195
196        // Otherwise, continue traversing.
197        Ok(FoldDown::Continue(()))
198    }
199
200    fn visit_up(
201        &mut self,
202        node: Self::NodeTy,
203        _context: Self::Context,
204        children: Vec<Self::Out>,
205    ) -> VortexResult<FoldUp<Self::Out>> {
206        Ok(FoldUp::Continue(node.replacing_children(children)))
207    }
208}
209
210struct ScopeStepIntoFieldExpr(FieldName);
211
212impl MutNodeVisitor for ScopeStepIntoFieldExpr {
213    type NodeTy = ExprRef;
214
215    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
216        if node.as_any().is::<Identity>() {
217            Ok(TransformResult::yes(pack([(self.0.clone(), ident())])))
218        } else {
219            Ok(TransformResult::no(node))
220        }
221    }
222}
223
224struct ReplaceAccessesWithChild(Vec<FieldName>);
225
226impl MutNodeVisitor for ReplaceAccessesWithChild {
227    type NodeTy = ExprRef;
228
229    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
230        if let Some(item) = node.as_any().downcast_ref::<GetItem>() {
231            if self.0.contains(item.field()) {
232                return Ok(TransformResult::yes(item.child().clone()));
233            }
234        }
235        Ok(TransformResult::no(node))
236    }
237}
238
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, StructDType};

    use super::*;
    use crate::transform::simplify::simplify;
    use crate::transform::simplify_typed::simplify_typed;
    use crate::{Pack, and, get_item, ident, lit, pack, select};

    /// Scope dtype shared by the tests: a struct `{a: {a: I32, b: I32}, b: I32, c: I32}`.
    fn dtype() -> DType {
        let nested = DType::Struct(
            Arc::new(StructDType::from_iter([
                ("a", I32.into()),
                ("b", DType::from(I32)),
            ])),
            NonNullable,
        );

        DType::Struct(
            Arc::new(StructDType::from_iter([
                ("a", nested),
                ("b", I32.into()),
                ("c", I32.into()),
            ])),
            NonNullable,
        )
    }

    #[test]
    fn test_expr_top_level_ref() {
        let dtype = dtype();

        let outcome = StructFieldExpressionSplitter::split(ident(), &dtype);
        assert!(outcome.is_ok());
        let partitioned = outcome.unwrap();

        // The root is a single top level pack ...
        assert!(partitioned.root.as_any().is::<Pack>());
        // ... and there is one partition for every field in the dtype.
        let scope_field_count = dtype.as_struct().unwrap().names().len();
        assert_eq!(partitioned.partitions.len(), scope_field_count)
    }

    #[test]
    fn test_expr_top_level_ref_get_item_and_split() {
        let dtype = dtype();

        let nested_access = get_item("b", get_item("a", ident()));
        let partitioned = StructFieldExpressionSplitter::split(nested_access, &dtype).unwrap();

        let partition_a = partitioned.find_partition(&"a".into());
        assert!(partition_a.is_some());
        let partition_a = partition_a.unwrap();

        let expected_root = get_item("a", ident());
        let expected_partition = get_item("b", ident());
        assert_eq!(&partitioned.root, &expected_root);
        assert_eq!(&simplify(partition_a.clone()).unwrap(), &expected_partition);
    }

    #[test]
    fn test_expr_top_level_ref_get_item_and_split_pack() {
        let dtype = dtype();

        let packed = pack([
            ("a", get_item("a", get_item("a", ident()))),
            ("b", get_item("b", get_item("a", ident()))),
            ("c", get_item("c", ident())),
        ]);
        let partitioned = StructFieldExpressionSplitter::split(packed, &dtype).unwrap();

        // Field `a` is accessed twice, so its partition packs both occurrences.
        let expected_a = pack([
            (
                StructFieldExpressionSplitter::field_idx_name(&"a".into(), 0),
                get_item("a", ident()),
            ),
            (
                StructFieldExpressionSplitter::field_idx_name(&"a".into(), 1),
                get_item("b", ident()),
            ),
        ]);
        let partition_a = partitioned.find_partition(&"a".into()).unwrap();
        assert_eq!(&simplify(partition_a.clone()).unwrap(), &expected_a);

        // Field `c` is accessed once, so its partition is the bare identity.
        let partition_c = partitioned.find_partition(&"c".into()).unwrap();
        assert_eq!(&simplify(partition_c.clone()).unwrap(), &ident())
    }

    #[test]
    fn test_expr_top_level_ref_get_item_add() {
        let dtype = dtype();

        let single_field_expr = and(get_item("b", get_item("a", ident())), lit(1));
        let partitioned =
            StructFieldExpressionSplitter::split(single_field_expr, &dtype).unwrap();

        // Whole expr is a single split
        assert_eq!(partitioned.partitions.len(), 1);
    }

    #[test]
    fn test_expr_top_level_ref_get_item_add_cannot_split() {
        let dtype = dtype();

        let lhs = get_item("b", get_item("a", ident()));
        let rhs = get_item("b", ident());
        let partitioned = StructFieldExpressionSplitter::split(and(lhs, rhs), &dtype).unwrap();

        // One for id.a and id.b
        assert_eq!(partitioned.partitions.len(), 2);
    }

    // Test that typed_simplify removes select and partition precise
    #[test]
    fn test_expr_partition_many_occurrences_of_field() {
        let dtype = dtype();

        let unsimplified = and(
            get_item("b", get_item("a", ident())),
            select(vec!["a".into(), "b".into()], ident()),
        );
        let simplified = simplify_typed(unsimplified, &dtype).unwrap();
        let partitioned = StructFieldExpressionSplitter::split(simplified, &dtype).unwrap();

        // One for id.a and id.b
        assert_eq!(partitioned.partitions.len(), 2);

        // Field `c` is not part of the select, so the root must not reference it: the typed
        // simplification above is expected to have rewritten the select into get_item + pack.
        let expected_root = and(
            get_item(
                StructFieldExpressionSplitter::field_idx_name(&"a".into(), 0),
                get_item("a", ident()),
            ),
            pack([
                (
                    "a",
                    get_item(
                        StructFieldExpressionSplitter::field_idx_name(&"a".into(), 1),
                        get_item("a", ident()),
                    ),
                ),
                ("b", get_item("b", ident())),
            ]),
        );
        assert_eq!(&partitioned.root, &expected_root)
    }
}
398}