vortex_expr/transform/
partition.rs

1use std::fmt::{Display, Formatter};
2use std::hash::{BuildHasher, Hash, Hasher};
3use std::sync::LazyLock;
4
5use itertools::Itertools;
6use vortex_array::aliases::hash_map::{DefaultHashBuilder, HashMap};
7use vortex_dtype::{DType, FieldName, FieldNames, StructDType};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail};
9
10use crate::transform::immediate_access::{FieldAccesses, immediate_scope_accesses};
11use crate::transform::simplify_typed::simplify_typed;
12use crate::traversal::{FoldDown, FoldUp, FolderMut, MutNodeVisitor, Node, TransformResult};
13use crate::{ExprRef, GetItem, Identity, get_item, ident, pack};
14
15static SPLITTER_RANDOM_STATE: LazyLock<DefaultHashBuilder> =
16    LazyLock::new(DefaultHashBuilder::default);
17
18/// Partition an expression over the fields of the scope.
19///
20/// This returns a partitioned expression that can be push-down over each field of the scope.
21/// The results of each partition can then be recombined to reproduce the result of the original
22/// expression.
23///
24/// ## Note
25///
26/// This function currently respects the validity of each field in the scope, but the not validity
27/// of the scope itself. The fix would be for the returned `PartitionedExpr` to include a partition
28/// expression for computing the validity, or to include that expression as part of the root.
29///
30/// See <https://github.com/spiraldb/vortex/issues/1907>.
31///
32// TODO(ngates): document the behaviour of conflicting `Field::Index` and `Field::Name`.
33pub fn partition(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
34    if !matches!(dtype, DType::Struct(..)) {
35        vortex_bail!("Expected a struct dtype, got {:?}", dtype);
36    }
37    StructFieldExpressionSplitter::split(expr, dtype)
38}
39
40/// The result of partitioning an expression.
41#[derive(Debug)]
42pub struct PartitionedExpr {
43    /// The root expression used to re-assemble the results.
44    pub root: ExprRef,
45    /// The partitions of the expression.
46    pub partitions: Box<[ExprRef]>,
47    /// The field names for the partitions
48    pub partition_names: FieldNames,
49    /// The return DTypes of each partition.
50    pub partition_dtypes: Box<[DType]>,
51}
52
53impl Display for PartitionedExpr {
54    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
55        write!(
56            f,
57            "root: {} {{{}}}",
58            self.root,
59            self.partition_names
60                .iter()
61                .zip(self.partitions.iter())
62                .map(|(name, partition)| format!("{}: {}", name, partition))
63                .join(", ")
64        )
65    }
66}
67
68impl PartitionedExpr {
69    /// Return the partition for a given field, if it exists.
70    pub fn find_partition(&self, field: &FieldName) -> Option<&ExprRef> {
71        self.partition_names
72            .iter()
73            .position(|name| name == field)
74            .map(|idx| &self.partitions[idx])
75    }
76}
77
78#[derive(Debug)]
79struct StructFieldExpressionSplitter<'a> {
80    sub_expressions: HashMap<FieldName, Vec<ExprRef>>,
81    accesses: &'a FieldAccesses<'a>,
82    scope_dtype: &'a StructDType,
83}
84
85impl<'a> StructFieldExpressionSplitter<'a> {
86    fn new(accesses: &'a FieldAccesses<'a>, scope_dtype: &'a StructDType) -> Self {
87        Self {
88            sub_expressions: HashMap::new(),
89            accesses,
90            scope_dtype,
91        }
92    }
93
94    pub(crate) fn field_idx_name(field: &FieldName, idx: usize) -> FieldName {
95        let mut hasher = SPLITTER_RANDOM_STATE.build_hasher();
96        field.hash(&mut hasher);
97        idx.hash(&mut hasher);
98        hasher.finish().to_string().into()
99    }
100
101    fn split(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
102        let scope_dtype = match dtype {
103            DType::Struct(scope_dtype, _) => scope_dtype,
104            _ => vortex_bail!("Expected a struct dtype, got {:?}", dtype),
105        };
106
107        let field_accesses = immediate_scope_accesses(&expr, scope_dtype)?;
108
109        let mut splitter = StructFieldExpressionSplitter::new(&field_accesses, scope_dtype);
110
111        let split = expr.clone().transform_with_context(&mut splitter, ())?;
112
113        let mut remove_accesses: Vec<FieldName> = Vec::new();
114
115        // Create partitions which can be passed to layout fields
116        let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
117        let mut partition_names = Vec::with_capacity(splitter.sub_expressions.len());
118        let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
119        for (name, exprs) in splitter.sub_expressions.into_iter() {
120            let field_dtype = scope_dtype.field(&name)?;
121            // If there is a single expr then we don't need to `pack` this, and we must update
122            // the root expr removing this access.
123            let expr = if exprs.len() == 1 {
124                remove_accesses.push(Self::field_idx_name(&name, 0));
125                exprs.first().vortex_expect("exprs is non-empty").clone()
126            } else {
127                pack(
128                    exprs
129                        .into_iter()
130                        .enumerate()
131                        .map(|(idx, expr)| (Self::field_idx_name(&name, idx), expr)),
132                )
133            };
134
135            let expr = simplify_typed(expr.clone(), &field_dtype)?;
136            let expr_dtype = expr.return_dtype(&field_dtype)?;
137
138            partitions.push(expr);
139            partition_names.push(name);
140            partition_dtypes.push(expr_dtype);
141        }
142
143        let expression_access_counts = field_accesses.get(&expr).map(|ac| ac.len());
144        // Ensure that there are not more accesses than partitions, we missed something
145        assert!(expression_access_counts.unwrap_or(0) <= partitions.len());
146        // Ensure that there are as many partitions as there are accesses/fields in the scope,
147        // this will affect performance, not correctness.
148        debug_assert_eq!(expression_access_counts.unwrap_or(0), partitions.len());
149
150        let split = split
151            .result()
152            .transform(&mut ReplaceAccessesWithChild(remove_accesses))?;
153
154        Ok(PartitionedExpr {
155            root: simplify_typed(split.result, dtype)?,
156            partitions: partitions.into_boxed_slice(),
157            partition_names: partition_names.into(),
158            partition_dtypes: partition_dtypes.into_boxed_slice(),
159        })
160    }
161}
162
163impl FolderMut for StructFieldExpressionSplitter<'_> {
164    type NodeTy = ExprRef;
165    type Out = ExprRef;
166    type Context = ();
167
168    fn visit_down(
169        &mut self,
170        node: &Self::NodeTy,
171        _context: Self::Context,
172    ) -> VortexResult<FoldDown<ExprRef, Self::Context>> {
173        // If this expression only accesses a single field, then we can skip the children
174        let access = self.accesses.get(node);
175        if access.as_ref().is_some_and(|a| a.len() == 1) {
176            let field_name = access
177                .vortex_expect("access is non-empty")
178                .iter()
179                .next()
180                .vortex_expect("expected one field");
181
182            // TODO(joe): dedup the sub_expression, if there are two expressions that are the same
183            // only create one entry here and reuse it.
184            let sub_exprs = self.sub_expressions.entry(field_name.clone()).or_default();
185            let idx = sub_exprs.len();
186
187            // Need to replace get_item(f, ident) with ident, making the expr relative to the child.
188            let replaced = node
189                .clone()
190                .transform(&mut ScopeStepIntoFieldExpr(field_name.clone()))?;
191            sub_exprs.push(replaced.result);
192
193            let access = get_item(
194                Self::field_idx_name(field_name, idx),
195                get_item(field_name.clone(), ident()),
196            );
197
198            return Ok(FoldDown::SkipChildren(access));
199        };
200
201        // If the expression is an identity, then we need to partition it into the fields of the scope.
202        if node.as_any().is::<Identity>() {
203            let field_names = self.scope_dtype.names();
204
205            let mut elements = Vec::with_capacity(field_names.len());
206
207            for field_name in field_names.iter() {
208                let sub_exprs = self
209                    .sub_expressions
210                    .entry(field_name.clone())
211                    .or_insert_with(Vec::new);
212
213                let idx = sub_exprs.len();
214
215                sub_exprs.push(ident());
216
217                elements.push((
218                    field_name.clone(),
219                    // Partitions are packed into a struct of field name -> occurrence idx -> array
220                    get_item(
221                        Self::field_idx_name(field_name, idx),
222                        get_item(field_name.clone(), ident()),
223                    ),
224                ));
225            }
226
227            return Ok(FoldDown::SkipChildren(pack(elements)));
228        }
229
230        // Otherwise, continue traversing.
231        Ok(FoldDown::Continue(()))
232    }
233
234    fn visit_up(
235        &mut self,
236        node: Self::NodeTy,
237        _context: Self::Context,
238        children: Vec<Self::Out>,
239    ) -> VortexResult<FoldUp<Self::Out>> {
240        Ok(FoldUp::Continue(node.replacing_children(children)))
241    }
242}
243
244struct ScopeStepIntoFieldExpr(FieldName);
245
246impl MutNodeVisitor for ScopeStepIntoFieldExpr {
247    type NodeTy = ExprRef;
248
249    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
250        if node.as_any().is::<Identity>() {
251            Ok(TransformResult::yes(pack([(self.0.clone(), ident())])))
252        } else {
253            Ok(TransformResult::no(node))
254        }
255    }
256}
257
258struct ReplaceAccessesWithChild(Vec<FieldName>);
259
260impl MutNodeVisitor for ReplaceAccessesWithChild {
261    type NodeTy = ExprRef;
262
263    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
264        if let Some(item) = node.as_any().downcast_ref::<GetItem>() {
265            if self.0.contains(item.field()) {
266                return Ok(TransformResult::yes(item.child().clone()));
267            }
268        }
269        Ok(TransformResult::no(node))
270    }
271}
272
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, StructDType};

    use super::*;
    use crate::transform::simplify::simplify;
    use crate::transform::simplify_typed::simplify_typed;
    use crate::{Pack, and, get_item, ident, lit, pack, select};

    /// Scope used by every test: `{a: {a: i32, b: i32}, b: i32, c: i32}`.
    fn dtype() -> DType {
        let nested = StructDType::from_iter([("a", I32.into()), ("b", DType::from(I32))]);
        let scope = StructDType::from_iter([
            ("a", DType::Struct(Arc::new(nested), NonNullable)),
            ("b", I32.into()),
            ("c", I32.into()),
        ]);
        DType::Struct(Arc::new(scope), NonNullable)
    }

    /// Generated name of the `idx`-th sub-expression of `field`.
    fn idx_name(field: &str, idx: usize) -> FieldName {
        StructFieldExpressionSplitter::field_idx_name(&field.into(), idx)
    }

    #[test]
    fn test_expr_top_level_ref() {
        let dtype = dtype();

        let partitioned = StructFieldExpressionSplitter::split(ident(), &dtype).unwrap();

        assert!(partitioned.root.as_any().is::<Pack>());
        // Have a single top level pack with all fields in dtype
        assert_eq!(
            partitioned.partitions.len(),
            dtype.as_struct().unwrap().names().len()
        )
    }

    #[test]
    fn test_expr_top_level_ref_get_item_and_split() {
        let dtype = dtype();
        let expr = get_item("b", get_item("a", ident()));

        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        let split_a = partitioned
            .find_partition(&"a".into())
            .expect("partition for field `a`");

        assert_eq!(&partitioned.root, &get_item("a", ident()));
        assert_eq!(&simplify(split_a.clone()).unwrap(), &get_item("b", ident()));
    }

    #[test]
    fn test_expr_top_level_ref_get_item_and_split_pack() {
        let dtype = dtype();
        let expr = pack([
            ("a", get_item("a", get_item("a", ident()))),
            ("b", get_item("b", get_item("a", ident()))),
            ("c", get_item("c", ident())),
        ]);

        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        // Field `a` is accessed twice, so its partition packs both occurrences.
        let expected_a = pack([
            (idx_name("a", 0), get_item("a", ident())),
            (idx_name("a", 1), get_item("b", ident())),
        ]);
        let split_a = partitioned.find_partition(&"a".into()).unwrap();
        assert_eq!(&simplify(split_a.clone()).unwrap(), &expected_a);

        // Field `c` is accessed once, so its partition is the bare identity.
        let split_c = partitioned.find_partition(&"c".into()).unwrap();
        assert_eq!(&simplify(split_c.clone()).unwrap(), &ident())
    }

    #[test]
    fn test_expr_top_level_ref_get_item_add() {
        let dtype = dtype();
        let expr = and(get_item("b", get_item("a", ident())), lit(1));

        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        // Whole expr is a single split
        assert_eq!(partitioned.partitions.len(), 1);
    }

    #[test]
    fn test_expr_top_level_ref_get_item_add_cannot_split() {
        let dtype = dtype();
        let expr = and(
            get_item("b", get_item("a", ident())),
            get_item("b", ident()),
        );

        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        // One for id.a and id.b
        assert_eq!(partitioned.partitions.len(), 2);
    }

    // Test that `simplify_typed` removes the select, so that partitioning is precise.
    #[test]
    fn test_expr_partition_many_occurrences_of_field() {
        let dtype = dtype();
        let expr = and(
            get_item("b", get_item("a", ident())),
            select(vec!["a".into(), "b".into()], ident()),
        );
        let expr = simplify_typed(expr, &dtype).unwrap();

        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        // One for id.a and id.b
        assert_eq!(partitioned.partitions.len(), 2);

        // This fetches [].$c which is unused, however a previous optimisation should replace select
        // with get_item and pack removing this field.
        let expected_root = and(
            get_item(idx_name("a", 0), get_item("a", ident())),
            pack([
                ("a", get_item(idx_name("a", 1), get_item("a", ident()))),
                ("b", get_item("b", ident())),
            ]),
        );
        assert_eq!(&partitioned.root, &expected_root)
    }
}
430        )
431    }
432}