1use std::fmt::{Display, Formatter};
2use std::hash::{BuildHasher, Hash, Hasher};
3use std::sync::LazyLock;
4
5use itertools::Itertools;
6use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
7use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
8use vortex_utils::aliases::hash_map::{DefaultHashBuilder, HashMap};
9
10use crate::transform::immediate_access::{FieldAccesses, immediate_scope_accesses};
11use crate::transform::simplify_typed::simplify_typed;
12use crate::traversal::{FoldDown, FoldUp, FolderMut, MutNodeVisitor, Node, TransformResult};
13use crate::{ExprRef, GetItem, ScopeDType, get_item, is_root, pack, root};
14
/// Hasher builder shared by every call to
/// `StructFieldExpressionSplitter::field_idx_name`.
///
/// It is built lazily on first use and then reused, so all generated names are
/// hashed with the same builder instance.
static SPLITTER_RANDOM_STATE: LazyLock<DefaultHashBuilder> =
    LazyLock::new(DefaultHashBuilder::default);
17
18pub fn partition(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
34 if !matches!(dtype, DType::Struct(..)) {
35 vortex_bail!("Expected a struct dtype, got {:?}", dtype);
36 }
37 StructFieldExpressionSplitter::split(expr, dtype)
38}
39
/// The result of splitting an expression over the fields of a struct scope.
///
/// `partitions`, `partition_names` and `partition_dtypes` are parallel: the
/// entry at index `i` of each describes the same partition.
#[derive(Debug)]
pub struct PartitionedExpr {
    /// The expression that recombines the partition results; it refers to them
    /// through `get_item` accesses on the scope's fields.
    pub root: ExprRef,
    /// One expression per partitioned field, simplified against that field's
    /// dtype as its scope.
    pub partitions: Box<[ExprRef]>,
    /// The scope field each partition belongs to.
    pub partition_names: FieldNames,
    /// The return dtype of each partition, computed with the field's dtype as
    /// the scope.
    pub partition_dtypes: Box<[DType]>,
}
53
54impl Display for PartitionedExpr {
55 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
56 write!(
57 f,
58 "root: {} {{{}}}",
59 self.root,
60 self.partition_names
61 .iter()
62 .zip(self.partitions.iter())
63 .map(|(name, partition)| format!("{name}: {partition}"))
64 .join(", ")
65 )
66 }
67}
68
69impl PartitionedExpr {
70 pub fn find_partition(&self, field: &FieldName) -> Option<&ExprRef> {
72 self.partition_names
73 .iter()
74 .position(|name| name == field)
75 .map(|idx| &self.partitions[idx])
76 }
77}
78
/// Folder that walks an expression and collects, per scope field, the
/// sub-expressions that touch only that field.
#[derive(Debug)]
struct StructFieldExpressionSplitter<'a> {
    /// Sub-expressions gathered so far, keyed by the scope field they access.
    /// The position within the `Vec` is the index fed to `field_idx_name`.
    sub_expressions: HashMap<FieldName, Vec<ExprRef>>,
    /// For each expression node, the set of scope fields it accesses.
    accesses: &'a FieldAccesses<'a>,
    /// The struct fields of the scope the expression is evaluated against.
    scope_dtype: &'a StructFields,
}
85
86impl<'a> StructFieldExpressionSplitter<'a> {
87 fn new(accesses: &'a FieldAccesses<'a>, scope_dtype: &'a StructFields) -> Self {
88 Self {
89 sub_expressions: HashMap::new(),
90 accesses,
91 scope_dtype,
92 }
93 }
94
95 pub(crate) fn field_idx_name(field: &FieldName, idx: usize) -> FieldName {
96 let mut hasher = SPLITTER_RANDOM_STATE.build_hasher();
97 field.hash(&mut hasher);
98 idx.hash(&mut hasher);
99 hasher.finish().to_string().into()
100 }
101
102 fn split(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
103 let scope_dtype = match dtype {
104 DType::Struct(scope_dtype, _) => scope_dtype,
105 _ => vortex_bail!("Expected a struct dtype, got {:?}", dtype),
106 };
107
108 let field_accesses = immediate_scope_accesses(&expr, scope_dtype)?;
109
110 let mut splitter = StructFieldExpressionSplitter::new(&field_accesses, scope_dtype);
111
112 let split = expr.clone().transform_with_context(&mut splitter, ())?;
113
114 let mut remove_accesses: Vec<FieldName> = Vec::new();
115
116 let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
118 let mut partition_names = Vec::with_capacity(splitter.sub_expressions.len());
119 let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
120 for (name, exprs) in splitter.sub_expressions.into_iter() {
121 let expr = if exprs.len() == 1 {
124 remove_accesses.push(Self::field_idx_name(&name, 0));
125 exprs.first().vortex_expect("exprs is non-empty").clone()
126 } else {
127 pack(
128 exprs
129 .into_iter()
130 .enumerate()
131 .map(|(idx, expr)| (Self::field_idx_name(&name, idx), expr)),
132 Nullability::NonNullable,
133 )
134 };
135
136 let field_dtype = scope_dtype
137 .field(&name)
138 .ok_or_else(|| vortex_err!("Missing field {name}"))?;
139 let field_ctx = ScopeDType::new(field_dtype);
140 let expr = simplify_typed(expr.clone(), &field_ctx)?;
141 let expr_dtype = expr.return_dtype(&field_ctx)?;
142
143 partitions.push(expr);
144 partition_names.push(name);
145 partition_dtypes.push(expr_dtype);
146 }
147
148 let expression_access_counts = field_accesses.get(&expr).map(|ac| ac.len());
149 assert!(expression_access_counts.unwrap_or(0) <= partitions.len());
151 debug_assert_eq!(expression_access_counts.unwrap_or(0), partitions.len());
154
155 let split = split
156 .result()
157 .transform(&mut ReplaceAccessesWithChild(remove_accesses))?;
158
159 let ctx = ScopeDType::new(dtype.clone());
160
161 Ok(PartitionedExpr {
162 root: simplify_typed(split.into_inner(), &ctx)?,
163 partitions: partitions.into_boxed_slice(),
164 partition_names: partition_names.into(),
165 partition_dtypes: partition_dtypes.into_boxed_slice(),
166 })
167 }
168}
169
170impl FolderMut for StructFieldExpressionSplitter<'_> {
171 type NodeTy = ExprRef;
172 type Out = ExprRef;
173 type Context = ();
174
175 fn visit_down(
176 &mut self,
177 node: &Self::NodeTy,
178 _context: Self::Context,
179 ) -> VortexResult<FoldDown<ExprRef, Self::Context>> {
180 let access = self.accesses.get(node);
182 if access.as_ref().is_some_and(|a| a.len() == 1) {
183 let field_name = access
184 .vortex_expect("access is non-empty")
185 .iter()
186 .next()
187 .vortex_expect("expected one field");
188
189 let sub_exprs = self.sub_expressions.entry(field_name.clone()).or_default();
190 let idx = sub_exprs.len();
191
192 let replaced = node
194 .clone()
195 .transform(&mut ScopeStepIntoFieldExpr(field_name.clone()))?;
196 sub_exprs.push(replaced.into_inner());
197
198 let access = get_item(
199 Self::field_idx_name(field_name, idx),
200 get_item(field_name.clone(), root()),
201 );
202
203 return Ok(FoldDown::SkipChildren(access));
204 };
205
206 if is_root(node) {
208 let field_names = self.scope_dtype.names();
209
210 let mut elements = Vec::with_capacity(field_names.len());
211
212 for field_name in field_names.iter() {
213 let sub_exprs = self
214 .sub_expressions
215 .entry(field_name.clone())
216 .or_insert_with(Vec::new);
217
218 let idx = sub_exprs.len();
219
220 sub_exprs.push(root());
221
222 elements.push((
223 field_name.clone(),
224 get_item(
226 Self::field_idx_name(field_name, idx),
227 get_item(field_name.clone(), root()),
228 ),
229 ));
230 }
231
232 return Ok(FoldDown::SkipChildren(pack(
233 elements,
234 Nullability::NonNullable,
235 )));
236 }
237
238 Ok(FoldDown::Continue(()))
240 }
241
242 fn visit_up(
243 &mut self,
244 node: Self::NodeTy,
245 _context: Self::Context,
246 children: Vec<Self::Out>,
247 ) -> VortexResult<FoldUp<Self::Out>> {
248 Ok(FoldUp::Continue(node.replacing_children(children)))
249 }
250}
251
252struct ScopeStepIntoFieldExpr(FieldName);
253
254impl MutNodeVisitor for ScopeStepIntoFieldExpr {
255 type NodeTy = ExprRef;
256
257 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
258 if is_root(&node) {
259 Ok(TransformResult::yes(pack(
260 [(self.0.clone(), root())],
261 Nullability::NonNullable,
262 )))
263 } else {
264 Ok(TransformResult::no(node))
265 }
266 }
267}
268
269pub(crate) struct ReplaceAccessesWithChild(Vec<FieldName>);
270
271impl ReplaceAccessesWithChild {
272 pub(crate) fn new(field_names: Vec<FieldName>) -> Self {
273 Self(field_names)
274 }
275}
276
277impl MutNodeVisitor for ReplaceAccessesWithChild {
278 type NodeTy = ExprRef;
279
280 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
281 if let Some(item) = node.as_any().downcast_ref::<GetItem>() {
282 if self.0.contains(item.field()) {
283 return Ok(TransformResult::yes(item.child().clone()));
284 }
285 }
286 Ok(TransformResult::no(node))
287 }
288}
289
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, StructFields};

    use super::*;
    use crate::transform::simplify::simplify;
    use crate::transform::simplify_typed::simplify_typed;
    use crate::{Pack, and, get_item, lit, pack, root, select};

    /// Scope shared by every test: `{a: {a: i32, b: i32}, b: i32, c: i32}`.
    fn dtype() -> DType {
        let nested = DType::Struct(
            Arc::new(StructFields::from_iter([
                ("a", DType::from(I32)),
                ("b", DType::from(I32)),
            ])),
            NonNullable,
        );
        let fields = StructFields::from_iter([
            ("a", nested),
            ("b", DType::from(I32)),
            ("c", DType::from(I32)),
        ]);
        DType::Struct(Arc::new(fields), NonNullable)
    }

    /// Splits `expr` against [`dtype`], panicking if the split fails.
    fn split_over_dtype(expr: ExprRef) -> PartitionedExpr {
        StructFieldExpressionSplitter::split(expr, &dtype()).unwrap()
    }

    /// Shorthand for the generated name of the `idx`-th sub-expression of `field`.
    fn idx_name(field: &str, idx: usize) -> FieldName {
        StructFieldExpressionSplitter::field_idx_name(&field.into(), idx)
    }

    /// A bare root becomes a pack with one partition per scope field.
    #[test]
    fn test_expr_top_level_ref() {
        let scope = dtype();
        let partitioned = split_over_dtype(root());

        assert!(partitioned.root.as_any().is::<Pack>());

        let field_count = scope.as_struct().unwrap().names().len();
        assert_eq!(partitioned.partitions.len(), field_count);
    }

    /// `$.a.b` leaves `$.a` in the root and pushes `.b` into partition `a`.
    #[test]
    fn test_expr_top_level_ref_get_item_and_split() {
        let partitioned = split_over_dtype(get_item("b", get_item("a", root())));

        let split_a = partitioned.find_partition(&"a".into());
        assert!(split_a.is_some());
        let split_a = split_a.unwrap();

        assert_eq!(&partitioned.root, &get_item("a", root()));
        assert_eq!(&simplify(split_a.clone()).unwrap(), &get_item("b", root()));
    }

    /// Two accesses into field `a` are packed under generated names, while the
    /// single access to `c` stays a bare root.
    #[test]
    fn test_expr_top_level_ref_get_item_and_split_pack() {
        let expr = pack(
            [
                ("a", get_item("a", get_item("a", root()))),
                ("b", get_item("b", get_item("a", root()))),
                ("c", get_item("c", root())),
            ],
            NonNullable,
        );
        let partitioned = split_over_dtype(expr);

        let expected_a = pack(
            [
                (idx_name("a", 0), get_item("a", root())),
                (idx_name("a", 1), get_item("b", root())),
            ],
            NonNullable,
        );
        let split_a = partitioned.find_partition(&"a".into()).unwrap();
        assert_eq!(&simplify(split_a.clone()).unwrap(), &expected_a);

        let split_c = partitioned.find_partition(&"c".into()).unwrap();
        assert_eq!(&simplify(split_c.clone()).unwrap(), &root());
    }

    /// An expression touching only field `a` yields exactly one partition.
    #[test]
    fn test_expr_top_level_ref_get_item_add() {
        let expr = and(get_item("b", get_item("a", root())), lit(1));
        let partitioned = split_over_dtype(expr);

        assert_eq!(partitioned.partitions.len(), 1);
    }

    /// An expression touching fields `a` and `b` yields two partitions.
    #[test]
    fn test_expr_top_level_ref_get_item_add_cannot_split() {
        let expr = and(get_item("b", get_item("a", root())), get_item("b", root()));
        let partitioned = split_over_dtype(expr);

        assert_eq!(partitioned.partitions.len(), 2);
    }

    /// Field `a` is accessed twice (directly and through a select), so the root
    /// refers to it through two distinct generated names.
    #[test]
    fn test_expr_partition_many_occurrences_of_field() {
        let expr = and(
            get_item("b", get_item("a", root())),
            select(vec!["a".into(), "b".into()], root()),
        );
        let expr = simplify_typed(expr, &ScopeDType::new(dtype())).unwrap();
        let partitioned = split_over_dtype(expr);

        assert_eq!(partitioned.partitions.len(), 2);

        let expected_root = and(
            get_item(idx_name("a", 0), get_item("a", root())),
            pack(
                [
                    ("a", get_item(idx_name("a", 1), get_item("a", root()))),
                    ("b", get_item("b", root())),
                ],
                NonNullable,
            ),
        );
        assert_eq!(&partitioned.root, &expected_root);
    }
}