1use std::fmt::{Display, Formatter};
2use std::hash::{BuildHasher, Hash, Hasher};
3use std::sync::LazyLock;
4
5use itertools::Itertools;
6use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
7use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
8use vortex_utils::aliases::hash_map::{DefaultHashBuilder, HashMap};
9
10use crate::transform::immediate_access::{FieldAccesses, immediate_scope_accesses};
11use crate::transform::simplify_typed::simplify_typed;
12use crate::traversal::{FoldDown, FoldUp, FolderMut, MutNodeVisitor, Node, TransformResult};
13use crate::{ExprRef, GetItem, ScopeDType, get_item, is_root, pack, root};
14
/// Hasher builder shared by every call to `StructFieldExpressionSplitter::field_idx_name`.
///
/// A single lazily-initialised instance is used so that hashing the same `(field, idx)` pair
/// always goes through the same builder and therefore yields the same generated name.
static SPLITTER_RANDOM_STATE: LazyLock<DefaultHashBuilder> =
    LazyLock::new(DefaultHashBuilder::default);
17
18pub fn partition(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
34 if !matches!(dtype, DType::Struct(..)) {
35 vortex_bail!("Expected a struct dtype, got {:?}", dtype);
36 }
37 StructFieldExpressionSplitter::split(expr, dtype)
38}
39
/// The outcome of splitting an expression by the struct fields it accesses.
///
/// `partitions`, `partition_names` and `partition_dtypes` are filled in lockstep by the
/// splitter, so entries at the same index describe the same partition.
#[derive(Debug)]
pub struct PartitionedExpr {
    /// The rewritten top-level expression. It is simplified against a struct scope whose
    /// fields are `partition_names` with dtypes `partition_dtypes`.
    pub root: ExprRef,
    /// One expression per partition, each simplified against the dtype of its field.
    pub partitions: Box<[ExprRef]>,
    /// The scope field name each partition was extracted for.
    pub partition_names: FieldNames,
    /// The return dtype of each partition expression.
    pub partition_dtypes: Box<[DType]>,
}
53
54impl Display for PartitionedExpr {
55 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
56 write!(
57 f,
58 "root: {} {{{}}}",
59 self.root,
60 self.partition_names
61 .iter()
62 .zip(self.partitions.iter())
63 .map(|(name, partition)| format!("{name}: {partition}"))
64 .join(", ")
65 )
66 }
67}
68
69impl PartitionedExpr {
70 pub fn find_partition(&self, field: &FieldName) -> Option<&ExprRef> {
72 self.partition_names
73 .iter()
74 .position(|name| name == field)
75 .map(|idx| &self.partitions[idx])
76 }
77}
78
/// Fold state used to carve an expression into per-field sub-expressions.
#[derive(Debug)]
struct StructFieldExpressionSplitter<'a> {
    /// Sub-expressions extracted so far, keyed by the scope field they read. The position of
    /// an expression in its `Vec` is the index passed to `field_idx_name` for it.
    sub_expressions: HashMap<FieldName, Vec<ExprRef>>,
    /// Per-node field accesses (produced by `immediate_scope_accesses` in `split`), consulted
    /// for every node visited on the way down.
    accesses: &'a FieldAccesses<'a>,
    /// Fields of the scope struct; used to enumerate every field when the root is referenced
    /// directly.
    scope_dtype: &'a StructFields,
}
85
86impl<'a> StructFieldExpressionSplitter<'a> {
87 fn new(accesses: &'a FieldAccesses<'a>, scope_dtype: &'a StructFields) -> Self {
88 Self {
89 sub_expressions: HashMap::new(),
90 accesses,
91 scope_dtype,
92 }
93 }
94
95 pub(crate) fn field_idx_name(field: &FieldName, idx: usize) -> FieldName {
96 let mut hasher = SPLITTER_RANDOM_STATE.build_hasher();
97 field.hash(&mut hasher);
98 idx.hash(&mut hasher);
99 hasher.finish().to_string().into()
100 }
101
102 fn split(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
103 let scope_dtype = match dtype {
104 DType::Struct(scope_dtype, _) => scope_dtype,
105 _ => vortex_bail!("Expected a struct dtype, got {:?}", dtype),
106 };
107
108 let field_accesses = immediate_scope_accesses(&expr, scope_dtype)?;
109
110 let mut splitter = StructFieldExpressionSplitter::new(&field_accesses, scope_dtype);
111
112 let split = expr
113 .clone()
114 .transform_with_context(&mut splitter, ())?
115 .result();
116
117 let mut remove_accesses: Vec<FieldName> = Vec::new();
118
119 let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
121 let mut partition_names = Vec::with_capacity(splitter.sub_expressions.len());
122 let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
123 for (name, exprs) in splitter.sub_expressions.into_iter() {
124 let expr = if exprs.len() == 1 {
127 remove_accesses.push(Self::field_idx_name(&name, 0));
128 exprs.first().vortex_expect("exprs is non-empty").clone()
129 } else {
130 pack(
131 exprs
132 .into_iter()
133 .enumerate()
134 .map(|(idx, expr)| (Self::field_idx_name(&name, idx), expr)),
135 Nullability::NonNullable,
136 )
137 };
138
139 let field_dtype = scope_dtype
140 .field(&name)
141 .ok_or_else(|| vortex_err!("Missing field {name}"))?;
142 let field_ctx = ScopeDType::new(field_dtype);
143 let expr = simplify_typed(expr.clone(), &field_ctx)?;
144 let expr_dtype = expr.return_dtype(&field_ctx)?;
145
146 partitions.push(expr);
147 partition_names.push(name);
148 partition_dtypes.push(expr_dtype);
149 }
150
151 let expression_access_counts = field_accesses.get(&expr).map(|ac| ac.len());
152 assert!(expression_access_counts.unwrap_or(0) <= partitions.len());
154 debug_assert_eq!(expression_access_counts.unwrap_or(0), partitions.len());
157
158 let split = split
159 .transform(&mut ReplaceAccessesWithChild(remove_accesses))?
160 .into_inner();
161
162 let ctx = ScopeDType::new(DType::Struct(
163 StructFields::new(
164 FieldNames::from(partition_names.clone()),
165 partition_dtypes.clone(),
166 ),
167 Nullability::NonNullable,
168 ));
169
170 Ok(PartitionedExpr {
171 root: simplify_typed(split, &ctx)?,
172 partitions: partitions.into_boxed_slice(),
173 partition_names: partition_names.into(),
174 partition_dtypes: partition_dtypes.into_boxed_slice(),
175 })
176 }
177}
178
179impl FolderMut for StructFieldExpressionSplitter<'_> {
180 type NodeTy = ExprRef;
181 type Out = ExprRef;
182 type Context = ();
183
184 fn visit_down(
185 &mut self,
186 node: &Self::NodeTy,
187 _context: Self::Context,
188 ) -> VortexResult<FoldDown<ExprRef, Self::Context>> {
189 let access = self.accesses.get(node);
191 if access.as_ref().is_some_and(|a| a.len() == 1) {
192 let field_name = access
193 .vortex_expect("access is non-empty")
194 .iter()
195 .next()
196 .vortex_expect("expected one field");
197
198 let sub_exprs = self.sub_expressions.entry(field_name.clone()).or_default();
199 let idx = sub_exprs.len();
200
201 let replaced = node
203 .clone()
204 .transform(&mut ScopeStepIntoFieldExpr(field_name.clone()))?;
205 sub_exprs.push(replaced.into_inner());
206
207 let access = get_item(
208 Self::field_idx_name(field_name, idx),
209 get_item(field_name.clone(), root()),
210 );
211
212 return Ok(FoldDown::SkipChildren(access));
213 };
214
215 if is_root(node) {
217 let field_names = self.scope_dtype.names();
218
219 let mut elements = Vec::with_capacity(field_names.len());
220
221 for field_name in field_names.iter() {
222 let sub_exprs = self
223 .sub_expressions
224 .entry(field_name.clone())
225 .or_insert_with(Vec::new);
226
227 let idx = sub_exprs.len();
228
229 sub_exprs.push(root());
230
231 elements.push((
232 field_name.clone(),
233 get_item(
235 Self::field_idx_name(field_name, idx),
236 get_item(field_name.clone(), root()),
237 ),
238 ));
239 }
240
241 return Ok(FoldDown::SkipChildren(pack(
242 elements,
243 Nullability::NonNullable,
244 )));
245 }
246
247 Ok(FoldDown::Continue(()))
249 }
250
251 fn visit_up(
252 &mut self,
253 node: Self::NodeTy,
254 _context: Self::Context,
255 children: Vec<Self::Out>,
256 ) -> VortexResult<FoldUp<Self::Out>> {
257 Ok(FoldUp::Continue(node.replacing_children(children)))
258 }
259}
260
261struct ScopeStepIntoFieldExpr(FieldName);
262
263impl MutNodeVisitor for ScopeStepIntoFieldExpr {
264 type NodeTy = ExprRef;
265
266 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
267 if is_root(&node) {
268 Ok(TransformResult::yes(pack(
269 [(self.0.clone(), root())],
270 Nullability::NonNullable,
271 )))
272 } else {
273 Ok(TransformResult::no(node))
274 }
275 }
276}
277
278pub(crate) struct ReplaceAccessesWithChild(Vec<FieldName>);
279
280impl ReplaceAccessesWithChild {
281 pub(crate) fn new(field_names: Vec<FieldName>) -> Self {
282 Self(field_names)
283 }
284}
285
286impl MutNodeVisitor for ReplaceAccessesWithChild {
287 type NodeTy = ExprRef;
288
289 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
290 if let Some(item) = node.as_any().downcast_ref::<GetItem>() {
291 if self.0.contains(item.field()) {
292 return Ok(TransformResult::yes(item.child().clone()));
293 }
294 }
295 Ok(TransformResult::no(node))
296 }
297}
298
#[cfg(test)]
mod tests {

    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, StructFields};
    use vortex_utils::aliases::hash_set::HashSet;

    use super::*;
    use crate::transform::simplify::simplify;
    use crate::transform::simplify_typed::simplify_typed;
    use crate::{Pack, and, col, get_item, lit, merge, pack, root, select};

    /// Scope shared by the tests: a nested struct field `a` (with I32 fields `a` and `b`)
    /// next to the I32 fields `b` and `c`.
    fn dtype() -> DType {
        let nested = DType::Struct(
            StructFields::from_iter([("a", I32.into()), ("b", DType::from(I32))]),
            NonNullable,
        );
        DType::Struct(
            StructFields::from_iter([("a", nested), ("b", I32.into()), ("c", I32.into())]),
            NonNullable,
        )
    }

    /// Generated name of the `idx`-th sub-expression extracted for `field`.
    fn idx_name(field: &'static str, idx: usize) -> FieldName {
        StructFieldExpressionSplitter::field_idx_name(&field.into(), idx)
    }

    #[test]
    fn test_expr_top_level_ref() {
        let scope = dtype();

        let partitioned = StructFieldExpressionSplitter::split(root(), &scope).unwrap();

        // A bare root reference becomes a pack with one partition per scope field.
        assert!(partitioned.root.as_any().is::<Pack>());
        assert_eq!(
            partitioned.partitions.len(),
            scope.as_struct().unwrap().names().len()
        )
    }

    #[test]
    fn test_expr_top_level_ref_get_item_and_split() {
        let scope = dtype();
        let nested_access = get_item("b", get_item("a", root()));

        let partitioned = StructFieldExpressionSplitter::split(nested_access, &scope).unwrap();
        let split_a = partitioned.find_partition(&"a".into()).unwrap();

        assert_eq!(&partitioned.root, &get_item("a", root()));
        assert_eq!(&simplify(split_a.clone()).unwrap(), &get_item("b", root()));
    }

    #[test]
    fn test_expr_top_level_ref_get_item_and_split_pack() {
        let scope = dtype();
        let packed = pack(
            [
                ("a", get_item("a", get_item("a", root()))),
                ("b", get_item("b", get_item("a", root()))),
                ("c", get_item("c", root())),
            ],
            NonNullable,
        );

        let partitioned = StructFieldExpressionSplitter::split(packed, &scope).unwrap();

        // Field `a` is read twice, so its partition packs both reads under generated names.
        let expected_a = pack(
            [
                (idx_name("a", 0), get_item("a", root())),
                (idx_name("a", 1), get_item("b", root())),
            ],
            NonNullable,
        );
        let split_a = partitioned.find_partition(&"a".into()).unwrap();
        assert_eq!(&simplify(split_a.clone()).unwrap(), &expected_a);

        let split_c = partitioned.find_partition(&"c".into()).unwrap();
        assert_eq!(&simplify(split_c.clone()).unwrap(), &root())
    }

    #[test]
    fn test_expr_top_level_ref_get_item_add() {
        let scope = dtype();
        let conjunction = and(get_item("b", get_item("a", root())), lit(1));

        let partitioned = StructFieldExpressionSplitter::split(conjunction, &scope).unwrap();

        assert_eq!(partitioned.partitions.len(), 1);
    }

    #[test]
    fn test_expr_top_level_ref_get_item_add_cannot_split() {
        let scope = dtype();
        let conjunction = and(get_item("b", get_item("a", root())), get_item("b", root()));

        let partitioned = StructFieldExpressionSplitter::split(conjunction, &scope).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);
    }

    #[test]
    fn test_expr_partition_many_occurrences_of_field() {
        let scope = dtype();
        let raw = and(
            get_item("b", get_item("a", root())),
            select(vec!["a".into(), "b".into()], root()),
        );
        let simplified = simplify_typed(raw, &ScopeDType::new(scope.clone())).unwrap();

        let partitioned = StructFieldExpressionSplitter::split(simplified, &scope).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);

        let expected_root = and(
            get_item(idx_name("a", 0), get_item("a", root())),
            pack(
                [
                    ("a", get_item(idx_name("a", 1), get_item("a", root()))),
                    ("b", get_item("b", root())),
                ],
                NonNullable,
            ),
        );
        assert_eq!(&partitioned.root, &expected_root)
    }

    #[test]
    fn test_expr_merge() {
        let scope = dtype();
        let merged = merge(
            [col("a"), pack([("b", col("b"))], NonNullable)],
            NonNullable,
        );

        let partitioned = StructFieldExpressionSplitter::split(merged, &scope).unwrap();

        let expected_root = pack(
            [
                ("a", get_item("a", col("a"))),
                ("b", get_item("b", col("b"))),
            ],
            NonNullable,
        );
        assert_eq!(
            &partitioned.root, &expected_root,
            "{} {}",
            partitioned.root, expected_root
        );

        // Partition order is not fixed, so compare as sets.
        let expected_partitions = [root(), pack([("b", root())], NonNullable)]
            .into_iter()
            .collect::<HashSet<_>>();
        let actual_partitions = partitioned
            .partitions
            .iter()
            .cloned()
            .collect::<HashSet<_>>();
        assert_eq!(
            &actual_partitions,
            &expected_partitions,
            "{} {}",
            partitioned.partitions.iter().join(";"),
            expected_partitions.iter().join(";")
        );
    }
}