1use std::fmt::{Display, Formatter};
2use std::hash::{BuildHasher, Hash, Hasher};
3use std::sync::LazyLock;
4
5use itertools::Itertools;
6use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
7use vortex_error::{VortexExpect, VortexResult, vortex_bail};
8use vortex_utils::aliases::hash_map::{DefaultHashBuilder, HashMap};
9
10use crate::transform::immediate_access::{FieldAccesses, immediate_scope_accesses};
11use crate::transform::simplify_typed::simplify_typed;
12use crate::traversal::{FoldDown, FoldUp, FolderMut, MutNodeVisitor, Node, TransformResult};
13use crate::{ExprRef, GetItem, ScopeDType, get_item, is_root, pack, root};
14
/// Hash state shared by every call to `StructFieldExpressionSplitter::field_idx_name`.
///
/// It is initialised once, so every synthesized name for a given `(field, idx)` pair is
/// identical for the lifetime of the process. NOTE(review): whether names are also stable
/// across processes depends on how `DefaultHashBuilder::default` seeds itself — do not rely
/// on it without confirming.
static SPLITTER_RANDOM_STATE: LazyLock<DefaultHashBuilder> =
    LazyLock::new(DefaultHashBuilder::default);
17
18pub fn partition(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
34 if !matches!(dtype, DType::Struct(..)) {
35 vortex_bail!("Expected a struct dtype, got {:?}", dtype);
36 }
37 StructFieldExpressionSplitter::split(expr, dtype)
38}
39
/// The result of [`partition`]: a `root` expression plus one sub-expression per
/// top-level field of the scope that the original expression reads.
#[derive(Debug)]
pub struct PartitionedExpr {
    /// Expression that recombines the partition results. NOTE(review): presumably evaluated
    /// against a struct whose field `partition_names[i]` holds the output of
    /// `partitions[i]` — confirm against the evaluator.
    pub root: ExprRef,
    /// One expression per partition, each typed against the dtype of the top-level field
    /// named at the same index in `partition_names`.
    pub partitions: Box<[ExprRef]>,
    /// The top-level field that each entry of `partitions` was carved from; index-aligned
    /// with `partitions`.
    pub partition_names: FieldNames,
    /// Return dtype of each entry of `partitions`, computed against its field's dtype;
    /// index-aligned with `partitions`.
    pub partition_dtypes: Box<[DType]>,
}
53
54impl Display for PartitionedExpr {
55 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
56 write!(
57 f,
58 "root: {} {{{}}}",
59 self.root,
60 self.partition_names
61 .iter()
62 .zip(self.partitions.iter())
63 .map(|(name, partition)| format!("{name}: {partition}"))
64 .join(", ")
65 )
66 }
67}
68
69impl PartitionedExpr {
70 pub fn find_partition(&self, field: &FieldName) -> Option<&ExprRef> {
72 self.partition_names
73 .iter()
74 .position(|name| name == field)
75 .map(|idx| &self.partitions[idx])
76 }
77}
78
/// Folder that carves an expression into per-field sub-expressions, replacing each carved
/// subtree with an access into the corresponding partition's output.
#[derive(Debug)]
struct StructFieldExpressionSplitter<'a> {
    /// Sub-expressions collected for each top-level field, in the order they were met during
    /// the fold; the element at position `idx` is later addressed as
    /// `field_idx_name(field, idx)`.
    sub_expressions: HashMap<FieldName, Vec<ExprRef>>,
    /// Precomputed accesses from `immediate_scope_accesses`; presumably maps each node to
    /// the set of scope fields it reads.
    accesses: &'a FieldAccesses<'a>,
    /// Fields of the struct scope the expression is evaluated against.
    scope_dtype: &'a StructFields,
}
85
86impl<'a> StructFieldExpressionSplitter<'a> {
87 fn new(accesses: &'a FieldAccesses<'a>, scope_dtype: &'a StructFields) -> Self {
88 Self {
89 sub_expressions: HashMap::new(),
90 accesses,
91 scope_dtype,
92 }
93 }
94
95 pub(crate) fn field_idx_name(field: &FieldName, idx: usize) -> FieldName {
96 let mut hasher = SPLITTER_RANDOM_STATE.build_hasher();
97 field.hash(&mut hasher);
98 idx.hash(&mut hasher);
99 hasher.finish().to_string().into()
100 }
101
102 fn split(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
103 let scope_dtype = match dtype {
104 DType::Struct(scope_dtype, _) => scope_dtype,
105 _ => vortex_bail!("Expected a struct dtype, got {:?}", dtype),
106 };
107
108 let field_accesses = immediate_scope_accesses(&expr, scope_dtype)?;
109
110 let mut splitter = StructFieldExpressionSplitter::new(&field_accesses, scope_dtype);
111
112 let split = expr.clone().transform_with_context(&mut splitter, ())?;
113
114 let mut remove_accesses: Vec<FieldName> = Vec::new();
115
116 let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
118 let mut partition_names = Vec::with_capacity(splitter.sub_expressions.len());
119 let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
120 for (name, exprs) in splitter.sub_expressions.into_iter() {
121 let field_dtype = scope_dtype.field(&name)?;
122 let expr = if exprs.len() == 1 {
125 remove_accesses.push(Self::field_idx_name(&name, 0));
126 exprs.first().vortex_expect("exprs is non-empty").clone()
127 } else {
128 pack(
129 exprs
130 .into_iter()
131 .enumerate()
132 .map(|(idx, expr)| (Self::field_idx_name(&name, idx), expr)),
133 Nullability::NonNullable,
134 )
135 };
136
137 let field_ctx = ScopeDType::new(field_dtype);
138 let expr = simplify_typed(expr.clone(), &field_ctx)?;
139 let expr_dtype = expr.return_dtype(&field_ctx)?;
140
141 partitions.push(expr);
142 partition_names.push(name);
143 partition_dtypes.push(expr_dtype);
144 }
145
146 let expression_access_counts = field_accesses.get(&expr).map(|ac| ac.len());
147 assert!(expression_access_counts.unwrap_or(0) <= partitions.len());
149 debug_assert_eq!(expression_access_counts.unwrap_or(0), partitions.len());
152
153 let split = split
154 .result()
155 .transform(&mut ReplaceAccessesWithChild(remove_accesses))?;
156
157 let ctx = ScopeDType::new(dtype.clone());
158
159 Ok(PartitionedExpr {
160 root: simplify_typed(split.into_inner(), &ctx)?,
161 partitions: partitions.into_boxed_slice(),
162 partition_names: partition_names.into(),
163 partition_dtypes: partition_dtypes.into_boxed_slice(),
164 })
165 }
166}
167
168impl FolderMut for StructFieldExpressionSplitter<'_> {
169 type NodeTy = ExprRef;
170 type Out = ExprRef;
171 type Context = ();
172
173 fn visit_down(
174 &mut self,
175 node: &Self::NodeTy,
176 _context: Self::Context,
177 ) -> VortexResult<FoldDown<ExprRef, Self::Context>> {
178 let access = self.accesses.get(node);
180 if access.as_ref().is_some_and(|a| a.len() == 1) {
181 let field_name = access
182 .vortex_expect("access is non-empty")
183 .iter()
184 .next()
185 .vortex_expect("expected one field");
186
187 let sub_exprs = self.sub_expressions.entry(field_name.clone()).or_default();
190 let idx = sub_exprs.len();
191
192 let replaced = node
194 .clone()
195 .transform(&mut ScopeStepIntoFieldExpr(field_name.clone()))?;
196 sub_exprs.push(replaced.into_inner());
197
198 let access = get_item(
199 Self::field_idx_name(field_name, idx),
200 get_item(field_name.clone(), root()),
201 );
202
203 return Ok(FoldDown::SkipChildren(access));
204 };
205
206 if is_root(node) {
208 let field_names = self.scope_dtype.names();
209
210 let mut elements = Vec::with_capacity(field_names.len());
211
212 for field_name in field_names.iter() {
213 let sub_exprs = self
214 .sub_expressions
215 .entry(field_name.clone())
216 .or_insert_with(Vec::new);
217
218 let idx = sub_exprs.len();
219
220 sub_exprs.push(root());
221
222 elements.push((
223 field_name.clone(),
224 get_item(
226 Self::field_idx_name(field_name, idx),
227 get_item(field_name.clone(), root()),
228 ),
229 ));
230 }
231
232 return Ok(FoldDown::SkipChildren(pack(
233 elements,
234 Nullability::NonNullable,
235 )));
236 }
237
238 Ok(FoldDown::Continue(()))
240 }
241
242 fn visit_up(
243 &mut self,
244 node: Self::NodeTy,
245 _context: Self::Context,
246 children: Vec<Self::Out>,
247 ) -> VortexResult<FoldUp<Self::Out>> {
248 Ok(FoldUp::Continue(node.replacing_children(children)))
249 }
250}
251
252struct ScopeStepIntoFieldExpr(FieldName);
253
254impl MutNodeVisitor for ScopeStepIntoFieldExpr {
255 type NodeTy = ExprRef;
256
257 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
258 if is_root(&node) {
259 Ok(TransformResult::yes(pack(
260 [(self.0.clone(), root())],
261 Nullability::NonNullable,
262 )))
263 } else {
264 Ok(TransformResult::no(node))
265 }
266 }
267}
268
269struct ReplaceAccessesWithChild(Vec<FieldName>);
270
271impl MutNodeVisitor for ReplaceAccessesWithChild {
272 type NodeTy = ExprRef;
273
274 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
275 if let Some(item) = node.as_any().downcast_ref::<GetItem>() {
276 if self.0.contains(item.field()) {
277 return Ok(TransformResult::yes(item.child().clone()));
278 }
279 }
280 Ok(TransformResult::no(node))
281 }
282}
283
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, StructFields};

    use super::*;
    use crate::transform::simplify::simplify;
    use crate::transform::simplify_typed::simplify_typed;
    use crate::{Pack, and, get_item, lit, pack, root, select};

    /// Test scope: a non-nullable struct `{a: {a: I32, b: I32}, b: I32, c: I32}`, where the
    /// nested struct `a` is also non-nullable.
    fn dtype() -> DType {
        DType::Struct(
            Arc::new(StructFields::from_iter([
                (
                    "a",
                    DType::Struct(
                        Arc::new(StructFields::from_iter([
                            ("a", I32.into()),
                            ("b", DType::from(I32)),
                        ])),
                        NonNullable,
                    ),
                ),
                ("b", I32.into()),
                ("c", I32.into()),
            ])),
            NonNullable,
        )
    }

    // A bare `root()` reads every field: expect one partition per scope field and a root
    // that is a pack reassembling them.
    #[test]
    fn test_expr_top_level_ref() {
        let dtype = dtype();

        let expr = root();

        let split = StructFieldExpressionSplitter::split(expr, &dtype);

        assert!(split.is_ok());

        let partitioned = split.unwrap();

        assert!(partitioned.root.as_any().is::<Pack>());
        assert_eq!(
            partitioned.partitions.len(),
            dtype.as_struct().unwrap().names().len()
        )
    }

    // `root().a.b` reads only field `a`. With a single sub-expression, the synthetic access
    // is stripped, so the root is `root().a` and the partition simplifies to `root().b`.
    #[test]
    fn test_expr_top_level_ref_get_item_and_split() {
        let dtype = dtype();

        let expr = get_item("b", get_item("a", root()));

        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();
        let split_a = partitioned.find_partition(&"a".into());
        assert!(split_a.is_some());
        let split_a = split_a.unwrap();

        assert_eq!(&partitioned.root, &get_item("a", root()));
        assert_eq!(&simplify(split_a.clone()).unwrap(), &get_item("b", root()));
    }

    // Field `a` is read twice, so its partition is a pack keyed by the synthetic
    // `field_idx_name` names. Field `c` is read once and its partition is just `root()`.
    #[test]
    fn test_expr_top_level_ref_get_item_and_split_pack() {
        let dtype = dtype();

        let expr = pack(
            [
                ("a", get_item("a", get_item("a", root()))),
                ("b", get_item("b", get_item("a", root()))),
                ("c", get_item("c", root())),
            ],
            NonNullable,
        );
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        let split_a = partitioned.find_partition(&"a".into()).unwrap();
        assert_eq!(
            &simplify(split_a.clone()).unwrap(),
            &pack(
                [
                    (
                        StructFieldExpressionSplitter::field_idx_name(&"a".into(), 0),
                        get_item("a", root())
                    ),
                    (
                        StructFieldExpressionSplitter::field_idx_name(&"a".into(), 1),
                        get_item("b", root())
                    )
                ],
                NonNullable
            )
        );
        let split_c = partitioned.find_partition(&"c".into()).unwrap();
        assert_eq!(&simplify(split_c.clone()).unwrap(), &root())
    }

    // A literal operand reads no field, so the whole expression produces one partition.
    #[test]
    fn test_expr_top_level_ref_get_item_add() {
        let dtype = dtype();

        let expr = and(get_item("b", get_item("a", root())), lit(1));
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        assert_eq!(partitioned.partitions.len(), 1);
    }

    // Operands that read different fields (`a` and `b`) produce two partitions.
    #[test]
    fn test_expr_top_level_ref_get_item_add_cannot_split() {
        let dtype = dtype();

        let expr = and(get_item("b", get_item("a", root())), get_item("b", root()));
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);
    }

    // After simplification, `select` reads `a` and `b` individually. Field `a` is therefore
    // referenced twice (indices 0 and 1 of its pack). Field `b` is referenced once, and its
    // synthetic access is stripped from the root.
    #[test]
    fn test_expr_partition_many_occurrences_of_field() {
        let dtype = dtype();

        let expr = and(
            get_item("b", get_item("a", root())),
            select(vec!["a".into(), "b".into()], root()),
        );
        let expr = simplify_typed(expr, &ScopeDType::new(dtype.clone())).unwrap();
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);

        assert_eq!(
            &partitioned.root,
            &and(
                get_item(
                    StructFieldExpressionSplitter::field_idx_name(&"a".into(), 0),
                    get_item("a", root())
                ),
                pack(
                    [
                        (
                            "a",
                            get_item(
                                StructFieldExpressionSplitter::field_idx_name(&"a".into(), 1),
                                get_item("a", root())
                            )
                        ),
                        ("b", get_item("b", root()))
                    ],
                    NonNullable
                )
            )
        )
    }
}