1use std::fmt::{Display, Formatter};
2use std::hash::{BuildHasher, Hash, Hasher};
3use std::sync::LazyLock;
4
5use itertools::Itertools;
6use vortex_array::aliases::hash_map::{DefaultHashBuilder, HashMap};
7use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructDType};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail};
9
10use crate::transform::immediate_access::{FieldAccesses, immediate_scope_accesses};
11use crate::transform::simplify_typed::simplify_typed;
12use crate::traversal::{FoldDown, FoldUp, FolderMut, MutNodeVisitor, Node, TransformResult};
13use crate::{ExprRef, GetItem, Identity, get_item, ident, pack};
14
/// Hasher state shared by every call to `StructFieldExpressionSplitter::field_idx_name`, so a
/// given `(field, idx)` pair maps to the same synthetic field name for the lifetime of the
/// process (the `LazyLock` initializes it exactly once).
///
/// NOTE(review): `DefaultHashBuilder` may be randomly seeded, in which case the synthetic names
/// differ between processes — they should never be persisted.
static SPLITTER_RANDOM_STATE: LazyLock<DefaultHashBuilder> =
    LazyLock::new(DefaultHashBuilder::default);
17
18pub fn partition(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
34 if !matches!(dtype, DType::Struct(..)) {
35 vortex_bail!("Expected a struct dtype, got {:?}", dtype);
36 }
37 StructFieldExpressionSplitter::split(expr, dtype)
38}
39
/// The result of [`partition`]: a root expression plus one partition per top-level scope field
/// that the original expression reads.
///
/// `partitions`, `partition_names` and `partition_dtypes` are parallel arrays: entry `i` of each
/// describes the same partition.
#[derive(Debug)]
pub struct PartitionedExpr {
    /// Expression that recombines the partition results; it reads the result of the partition
    /// for field `name` through `get_item(name, ident())`.
    pub root: ExprRef,
    /// One expression per partition, evaluated against the value of the field it is named after.
    pub partitions: Box<[ExprRef]>,
    /// Name of the top-level scope field each partition was extracted from.
    pub partition_names: FieldNames,
    /// Return dtype of each partition when evaluated against its field's dtype.
    pub partition_dtypes: Box<[DType]>,
}
52
53impl Display for PartitionedExpr {
54 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
55 write!(
56 f,
57 "root: {} {{{}}}",
58 self.root,
59 self.partition_names
60 .iter()
61 .zip(self.partitions.iter())
62 .map(|(name, partition)| format!("{}: {}", name, partition))
63 .join(", ")
64 )
65 }
66}
67
68impl PartitionedExpr {
69 pub fn find_partition(&self, field: &FieldName) -> Option<&ExprRef> {
71 self.partition_names
72 .iter()
73 .position(|name| name == field)
74 .map(|idx| &self.partitions[idx])
75 }
76}
77
/// Folder that pulls single-field sub-expressions out of an expression, records them per
/// top-level field, and replaces them in the tree with references to synthetic field names.
#[derive(Debug)]
struct StructFieldExpressionSplitter<'a> {
    /// Sub-expressions extracted so far, keyed by the top-level field they read, in the order
    /// they were recorded; a sub-expression's position is the `idx` passed to `field_idx_name`.
    sub_expressions: HashMap<FieldName, Vec<ExprRef>>,
    /// Top-level fields read by each node of the expression being split, as computed by
    /// `immediate_scope_accesses`.
    accesses: &'a FieldAccesses<'a>,
    /// Struct dtype of the scope the expression is evaluated against.
    scope_dtype: &'a StructDType,
}
84
85impl<'a> StructFieldExpressionSplitter<'a> {
86 fn new(accesses: &'a FieldAccesses<'a>, scope_dtype: &'a StructDType) -> Self {
87 Self {
88 sub_expressions: HashMap::new(),
89 accesses,
90 scope_dtype,
91 }
92 }
93
94 pub(crate) fn field_idx_name(field: &FieldName, idx: usize) -> FieldName {
95 let mut hasher = SPLITTER_RANDOM_STATE.build_hasher();
96 field.hash(&mut hasher);
97 idx.hash(&mut hasher);
98 hasher.finish().to_string().into()
99 }
100
101 fn split(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
102 let scope_dtype = match dtype {
103 DType::Struct(scope_dtype, _) => scope_dtype,
104 _ => vortex_bail!("Expected a struct dtype, got {:?}", dtype),
105 };
106
107 let field_accesses = immediate_scope_accesses(&expr, scope_dtype)?;
108
109 let mut splitter = StructFieldExpressionSplitter::new(&field_accesses, scope_dtype);
110
111 let split = expr.clone().transform_with_context(&mut splitter, ())?;
112
113 let mut remove_accesses: Vec<FieldName> = Vec::new();
114
115 let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
117 let mut partition_names = Vec::with_capacity(splitter.sub_expressions.len());
118 let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
119 for (name, exprs) in splitter.sub_expressions.into_iter() {
120 let field_dtype = scope_dtype.field(&name)?;
121 let expr = if exprs.len() == 1 {
124 remove_accesses.push(Self::field_idx_name(&name, 0));
125 exprs.first().vortex_expect("exprs is non-empty").clone()
126 } else {
127 pack(
128 exprs
129 .into_iter()
130 .enumerate()
131 .map(|(idx, expr)| (Self::field_idx_name(&name, idx), expr)),
132 Nullability::NonNullable,
133 )
134 };
135
136 let expr = simplify_typed(expr.clone(), &field_dtype)?;
137 let expr_dtype = expr.return_dtype(&field_dtype)?;
138
139 partitions.push(expr);
140 partition_names.push(name);
141 partition_dtypes.push(expr_dtype);
142 }
143
144 let expression_access_counts = field_accesses.get(&expr).map(|ac| ac.len());
145 assert!(expression_access_counts.unwrap_or(0) <= partitions.len());
147 debug_assert_eq!(expression_access_counts.unwrap_or(0), partitions.len());
150
151 let split = split
152 .result()
153 .transform(&mut ReplaceAccessesWithChild(remove_accesses))?;
154
155 Ok(PartitionedExpr {
156 root: simplify_typed(split.result, dtype)?,
157 partitions: partitions.into_boxed_slice(),
158 partition_names: partition_names.into(),
159 partition_dtypes: partition_dtypes.into_boxed_slice(),
160 })
161 }
162}
163
164impl FolderMut for StructFieldExpressionSplitter<'_> {
165 type NodeTy = ExprRef;
166 type Out = ExprRef;
167 type Context = ();
168
169 fn visit_down(
170 &mut self,
171 node: &Self::NodeTy,
172 _context: Self::Context,
173 ) -> VortexResult<FoldDown<ExprRef, Self::Context>> {
174 let access = self.accesses.get(node);
176 if access.as_ref().is_some_and(|a| a.len() == 1) {
177 let field_name = access
178 .vortex_expect("access is non-empty")
179 .iter()
180 .next()
181 .vortex_expect("expected one field");
182
183 let sub_exprs = self.sub_expressions.entry(field_name.clone()).or_default();
186 let idx = sub_exprs.len();
187
188 let replaced = node
190 .clone()
191 .transform(&mut ScopeStepIntoFieldExpr(field_name.clone()))?;
192 sub_exprs.push(replaced.result);
193
194 let access = get_item(
195 Self::field_idx_name(field_name, idx),
196 get_item(field_name.clone(), ident()),
197 );
198
199 return Ok(FoldDown::SkipChildren(access));
200 };
201
202 if node.as_any().is::<Identity>() {
204 let field_names = self.scope_dtype.names();
205
206 let mut elements = Vec::with_capacity(field_names.len());
207
208 for field_name in field_names.iter() {
209 let sub_exprs = self
210 .sub_expressions
211 .entry(field_name.clone())
212 .or_insert_with(Vec::new);
213
214 let idx = sub_exprs.len();
215
216 sub_exprs.push(ident());
217
218 elements.push((
219 field_name.clone(),
220 get_item(
222 Self::field_idx_name(field_name, idx),
223 get_item(field_name.clone(), ident()),
224 ),
225 ));
226 }
227
228 return Ok(FoldDown::SkipChildren(pack(
229 elements,
230 Nullability::NonNullable,
231 )));
232 }
233
234 Ok(FoldDown::Continue(()))
236 }
237
238 fn visit_up(
239 &mut self,
240 node: Self::NodeTy,
241 _context: Self::Context,
242 children: Vec<Self::Out>,
243 ) -> VortexResult<FoldUp<Self::Out>> {
244 Ok(FoldUp::Continue(node.replacing_children(children)))
245 }
246}
247
248struct ScopeStepIntoFieldExpr(FieldName);
249
250impl MutNodeVisitor for ScopeStepIntoFieldExpr {
251 type NodeTy = ExprRef;
252
253 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
254 if node.as_any().is::<Identity>() {
255 Ok(TransformResult::yes(pack(
256 [(self.0.clone(), ident())],
257 Nullability::NonNullable,
258 )))
259 } else {
260 Ok(TransformResult::no(node))
261 }
262 }
263}
264
265struct ReplaceAccessesWithChild(Vec<FieldName>);
266
267impl MutNodeVisitor for ReplaceAccessesWithChild {
268 type NodeTy = ExprRef;
269
270 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
271 if let Some(item) = node.as_any().downcast_ref::<GetItem>() {
272 if self.0.contains(item.field()) {
273 return Ok(TransformResult::yes(item.child().clone()));
274 }
275 }
276 Ok(TransformResult::no(node))
277 }
278}
279
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, StructDType};

    use super::*;
    use crate::transform::simplify::simplify;
    use crate::transform::simplify_typed::simplify_typed;
    use crate::{Pack, and, get_item, ident, lit, pack, select};

    /// Scope shared by all tests: `{a: {a: i32, b: i32}, b: i32, c: i32}`, with the outer and
    /// inner structs non-nullable.
    fn dtype() -> DType {
        DType::Struct(
            Arc::new(StructDType::from_iter([
                (
                    "a",
                    DType::Struct(
                        Arc::new(StructDType::from_iter([
                            ("a", I32.into()),
                            ("b", DType::from(I32)),
                        ])),
                        NonNullable,
                    ),
                ),
                ("b", I32.into()),
                ("c", I32.into()),
            ])),
            NonNullable,
        )
    }

    /// A bare `ident()` reads every field: the root becomes a pack and there is one partition
    /// per top-level field.
    #[test]
    fn test_expr_top_level_ref() {
        let dtype = dtype();

        let expr = ident();

        let split = StructFieldExpressionSplitter::split(expr, &dtype);

        assert!(split.is_ok());

        let partitioned = split.unwrap();

        assert!(partitioned.root.as_any().is::<Pack>());
        assert_eq!(
            partitioned.partitions.len(),
            dtype.as_struct().unwrap().names().len()
        )
    }

    /// `$.a.b` only reads `a`: partition `a` is `$.b`, and because it is a single
    /// sub-expression, the root collapses to `$.a`.
    #[test]
    fn test_expr_top_level_ref_get_item_and_split() {
        let dtype = dtype();

        let expr = get_item("b", get_item("a", ident()));

        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();
        let split_a = partitioned.find_partition(&"a".into());
        assert!(split_a.is_some());
        let split_a = split_a.unwrap();

        assert_eq!(&partitioned.root, &get_item("a", ident()));
        assert_eq!(&simplify(split_a.clone()).unwrap(), &get_item("b", ident()));
    }

    /// Two distinct sub-expressions read `a`, so partition `a` packs them under the synthetic
    /// names for indices 0 and 1 (traversal order). `c` is read once, as-is, so its partition
    /// is the identity.
    #[test]
    fn test_expr_top_level_ref_get_item_and_split_pack() {
        let dtype = dtype();

        let expr = pack(
            [
                ("a", get_item("a", get_item("a", ident()))),
                ("b", get_item("b", get_item("a", ident()))),
                ("c", get_item("c", ident())),
            ],
            NonNullable,
        );
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        let split_a = partitioned.find_partition(&"a".into()).unwrap();
        assert_eq!(
            &simplify(split_a.clone()).unwrap(),
            &pack(
                [
                    (
                        StructFieldExpressionSplitter::field_idx_name(&"a".into(), 0),
                        get_item("a", ident())
                    ),
                    (
                        StructFieldExpressionSplitter::field_idx_name(&"a".into(), 1),
                        get_item("b", ident())
                    )
                ],
                NonNullable
            )
        );
        let split_c = partitioned.find_partition(&"c".into()).unwrap();
        assert_eq!(&simplify(split_c.clone()).unwrap(), &ident())
    }

    /// Combining a single-field access with a literal still reads only `a`, so exactly one
    /// partition is produced.
    #[test]
    fn test_expr_top_level_ref_get_item_add() {
        let dtype = dtype();

        let expr = and(get_item("b", get_item("a", ident())), lit(1));
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        assert_eq!(partitioned.partitions.len(), 1);
    }

    /// The `and` reads both `a` and `b`, so it stays in the root and each operand becomes its
    /// own partition.
    #[test]
    fn test_expr_top_level_ref_get_item_add_cannot_split() {
        let dtype = dtype();

        let expr = and(
            get_item("b", get_item("a", ident())),
            get_item("b", ident()),
        );
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);
    }

    /// `a` is read twice, getting synthetic indices 0 and 1 in the root. `b` is read once, so
    /// its synthetic access is stripped from the root, leaving `$.b`.
    #[test]
    fn test_expr_partition_many_occurrences_of_field() {
        let dtype = dtype();

        let expr = and(
            get_item("b", get_item("a", ident())),
            select(vec!["a".into(), "b".into()], ident()),
        );
        let expr = simplify_typed(expr, &dtype).unwrap();
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);

        assert_eq!(
            &partitioned.root,
            &and(
                get_item(
                    StructFieldExpressionSplitter::field_idx_name(&"a".into(), 0),
                    get_item("a", ident())
                ),
                pack(
                    [
                        (
                            "a",
                            get_item(
                                StructFieldExpressionSplitter::field_idx_name(&"a".into(), 1),
                                get_item("a", ident())
                            )
                        ),
                        ("b", get_item("b", ident()))
                    ],
                    NonNullable
                )
            )
        )
    }
}