1use std::fmt::{Display, Formatter};
2use std::hash::{BuildHasher, Hash, Hasher};
3use std::sync::LazyLock;
4
5use itertools::Itertools;
6use vortex_array::aliases::hash_map::{DefaultHashBuilder, HashMap};
7use vortex_dtype::{DType, FieldName, FieldNames, StructDType};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail};
9
10use crate::transform::immediate_access::{FieldAccesses, immediate_scope_accesses};
11use crate::transform::simplify_typed::simplify_typed;
12use crate::traversal::{FoldDown, FoldUp, FolderMut, MutNodeVisitor, Node, TransformResult};
13use crate::{ExprRef, GetItem, Identity, get_item, ident, pack};
14
15static SPLITTER_RANDOM_STATE: LazyLock<DefaultHashBuilder> =
16 LazyLock::new(DefaultHashBuilder::default);
17
18pub fn partition(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
34 if !matches!(dtype, DType::Struct(..)) {
35 vortex_bail!("Expected a struct dtype, got {:?}", dtype);
36 }
37 StructFieldExpressionSplitter::split(expr, dtype)
38}
39
40#[derive(Debug)]
42pub struct PartitionedExpr {
43 pub root: ExprRef,
45 pub partitions: Box<[ExprRef]>,
47 pub partition_names: FieldNames,
49 pub partition_dtypes: Box<[DType]>,
51}
52
53impl Display for PartitionedExpr {
54 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
55 write!(
56 f,
57 "root: {} {{{}}}",
58 self.root,
59 self.partition_names
60 .iter()
61 .zip(self.partitions.iter())
62 .map(|(name, partition)| format!("{}: {}", name, partition))
63 .join(", ")
64 )
65 }
66}
67
68impl PartitionedExpr {
69 pub fn find_partition(&self, field: &FieldName) -> Option<&ExprRef> {
71 self.partition_names
72 .iter()
73 .position(|name| name == field)
74 .map(|idx| &self.partitions[idx])
75 }
76}
77
/// Folder state used by `partition`: collects, for each top-level field of the scope, the
/// sub-expressions that read only that field.
#[derive(Debug)]
struct StructFieldExpressionSplitter<'a> {
    /// Sub-expressions moved into each field's partition, in the order the fold encountered
    /// them. A sub-expression's position in its `Vec` is the `idx` passed to `field_idx_name`.
    sub_expressions: HashMap<FieldName, Vec<ExprRef>>,
    /// Top-level scope fields accessed by each node, as computed by `immediate_scope_accesses`.
    accesses: &'a FieldAccesses<'a>,
    /// Struct dtype of the scope the expression is evaluated against.
    scope_dtype: &'a StructDType,
}
84
85impl<'a> StructFieldExpressionSplitter<'a> {
86 fn new(accesses: &'a FieldAccesses<'a>, scope_dtype: &'a StructDType) -> Self {
87 Self {
88 sub_expressions: HashMap::new(),
89 accesses,
90 scope_dtype,
91 }
92 }
93
94 pub(crate) fn field_idx_name(field: &FieldName, idx: usize) -> FieldName {
95 let mut hasher = SPLITTER_RANDOM_STATE.build_hasher();
96 field.hash(&mut hasher);
97 idx.hash(&mut hasher);
98 hasher.finish().to_string().into()
99 }
100
101 fn split(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
102 let scope_dtype = match dtype {
103 DType::Struct(scope_dtype, _) => scope_dtype,
104 _ => vortex_bail!("Expected a struct dtype, got {:?}", dtype),
105 };
106
107 let field_accesses = immediate_scope_accesses(&expr, scope_dtype)?;
108
109 let mut splitter = StructFieldExpressionSplitter::new(&field_accesses, scope_dtype);
110
111 let split = expr.clone().transform_with_context(&mut splitter, ())?;
112
113 let mut remove_accesses: Vec<FieldName> = Vec::new();
114
115 let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
117 let mut partition_names = Vec::with_capacity(splitter.sub_expressions.len());
118 let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
119 for (name, exprs) in splitter.sub_expressions.into_iter() {
120 let field_dtype = scope_dtype.field(&name)?;
121 let expr = if exprs.len() == 1 {
124 remove_accesses.push(Self::field_idx_name(&name, 0));
125 exprs.first().vortex_expect("exprs is non-empty").clone()
126 } else {
127 pack(
128 exprs
129 .into_iter()
130 .enumerate()
131 .map(|(idx, expr)| (Self::field_idx_name(&name, idx), expr)),
132 )
133 };
134
135 let expr = simplify_typed(expr.clone(), &field_dtype)?;
136 let expr_dtype = expr.return_dtype(&field_dtype)?;
137
138 partitions.push(expr);
139 partition_names.push(name);
140 partition_dtypes.push(expr_dtype);
141 }
142
143 let expression_access_counts = field_accesses.get(&expr).map(|ac| ac.len());
144 assert!(expression_access_counts.unwrap_or(0) <= partitions.len());
146 debug_assert_eq!(expression_access_counts.unwrap_or(0), partitions.len());
149
150 let split = split
151 .result()
152 .transform(&mut ReplaceAccessesWithChild(remove_accesses))?;
153
154 Ok(PartitionedExpr {
155 root: simplify_typed(split.result, dtype)?,
156 partitions: partitions.into_boxed_slice(),
157 partition_names: partition_names.into(),
158 partition_dtypes: partition_dtypes.into_boxed_slice(),
159 })
160 }
161}
162
163impl FolderMut for StructFieldExpressionSplitter<'_> {
164 type NodeTy = ExprRef;
165 type Out = ExprRef;
166 type Context = ();
167
168 fn visit_down(
169 &mut self,
170 node: &Self::NodeTy,
171 _context: Self::Context,
172 ) -> VortexResult<FoldDown<ExprRef, Self::Context>> {
173 let access = self.accesses.get(node);
175 if access.as_ref().is_some_and(|a| a.len() == 1) {
176 let field_name = access
177 .vortex_expect("access is non-empty")
178 .iter()
179 .next()
180 .vortex_expect("expected one field");
181
182 let sub_exprs = self.sub_expressions.entry(field_name.clone()).or_default();
185 let idx = sub_exprs.len();
186
187 let replaced = node
189 .clone()
190 .transform(&mut ScopeStepIntoFieldExpr(field_name.clone()))?;
191 sub_exprs.push(replaced.result);
192
193 let access = get_item(
194 Self::field_idx_name(field_name, idx),
195 get_item(field_name.clone(), ident()),
196 );
197
198 return Ok(FoldDown::SkipChildren(access));
199 };
200
201 if node.as_any().is::<Identity>() {
203 let field_names = self.scope_dtype.names();
204
205 let mut elements = Vec::with_capacity(field_names.len());
206
207 for field_name in field_names.iter() {
208 let sub_exprs = self
209 .sub_expressions
210 .entry(field_name.clone())
211 .or_insert_with(Vec::new);
212
213 let idx = sub_exprs.len();
214
215 sub_exprs.push(ident());
216
217 elements.push((
218 field_name.clone(),
219 get_item(
221 Self::field_idx_name(field_name, idx),
222 get_item(field_name.clone(), ident()),
223 ),
224 ));
225 }
226
227 return Ok(FoldDown::SkipChildren(pack(elements)));
228 }
229
230 Ok(FoldDown::Continue(()))
232 }
233
234 fn visit_up(
235 &mut self,
236 node: Self::NodeTy,
237 _context: Self::Context,
238 children: Vec<Self::Out>,
239 ) -> VortexResult<FoldUp<Self::Out>> {
240 Ok(FoldUp::Continue(node.replacing_children(children)))
241 }
242}
243
244struct ScopeStepIntoFieldExpr(FieldName);
245
246impl MutNodeVisitor for ScopeStepIntoFieldExpr {
247 type NodeTy = ExprRef;
248
249 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
250 if node.as_any().is::<Identity>() {
251 Ok(TransformResult::yes(pack([(self.0.clone(), ident())])))
252 } else {
253 Ok(TransformResult::no(node))
254 }
255 }
256}
257
258struct ReplaceAccessesWithChild(Vec<FieldName>);
259
260impl MutNodeVisitor for ReplaceAccessesWithChild {
261 type NodeTy = ExprRef;
262
263 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
264 if let Some(item) = node.as_any().downcast_ref::<GetItem>() {
265 if self.0.contains(item.field()) {
266 return Ok(TransformResult::yes(item.child().clone()));
267 }
268 }
269 Ok(TransformResult::no(node))
270 }
271}
272
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, StructDType};

    use super::*;
    use crate::transform::simplify::simplify;
    use crate::transform::simplify_typed::simplify_typed;
    use crate::{Pack, and, get_item, ident, lit, pack, select};

    /// Scope shared by the tests: struct `{a: {a: i32, b: i32}, b: i32, c: i32}`; both struct
    /// levels are non-nullable.
    fn dtype() -> DType {
        DType::Struct(
            Arc::new(StructDType::from_iter([
                (
                    "a",
                    DType::Struct(
                        Arc::new(StructDType::from_iter([
                            ("a", I32.into()),
                            ("b", DType::from(I32)),
                        ])),
                        NonNullable,
                    ),
                ),
                ("b", I32.into()),
                ("c", I32.into()),
            ])),
            NonNullable,
        )
    }

    // A bare `ident()` reads every field: expect one partition per top-level field and a root
    // that re-packs them.
    #[test]
    fn test_expr_top_level_ref() {
        let dtype = dtype();

        let expr = ident();

        let split = StructFieldExpressionSplitter::split(expr, &dtype);

        assert!(split.is_ok());

        let partitioned = split.unwrap();

        assert!(partitioned.root.as_any().is::<Pack>());
        // One partition for each top-level field of the scope.
        assert_eq!(
            partitioned.partitions.len(),
            dtype.as_struct().unwrap().names().len()
        )
    }

    // `$.a.b` reads only field `a`: the root reduces to `$.a` and the partition for `a`
    // reduces to `$.b`.
    #[test]
    fn test_expr_top_level_ref_get_item_and_split() {
        let dtype = dtype();

        let expr = get_item("b", get_item("a", ident()));

        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();
        let split_a = partitioned.find_partition(&"a".into());
        assert!(split_a.is_some());
        let split_a = split_a.unwrap();

        assert_eq!(&partitioned.root, &get_item("a", ident()));
        assert_eq!(&simplify(split_a.clone()).unwrap(), &get_item("b", ident()));
    }

    // Two sub-expressions on field `a` are packed under their `field_idx_name`s. Field `c` has a
    // single sub-expression, which simplifies to the identity.
    #[test]
    fn test_expr_top_level_ref_get_item_and_split_pack() {
        let dtype = dtype();

        let expr = pack([
            ("a", get_item("a", get_item("a", ident()))),
            ("b", get_item("b", get_item("a", ident()))),
            ("c", get_item("c", ident())),
        ]);
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        let split_a = partitioned.find_partition(&"a".into()).unwrap();
        assert_eq!(
            &simplify(split_a.clone()).unwrap(),
            &pack([
                (
                    StructFieldExpressionSplitter::field_idx_name(&"a".into(), 0),
                    get_item("a", ident())
                ),
                (
                    StructFieldExpressionSplitter::field_idx_name(&"a".into(), 1),
                    get_item("b", ident())
                )
            ])
        );
        let split_c = partitioned.find_partition(&"c".into()).unwrap();
        assert_eq!(&simplify(split_c.clone()).unwrap(), &ident())
    }

    // Despite the name, this combines with `and`. Only field `a` is read (the literal reads no
    // field), so a single partition is expected.
    #[test]
    fn test_expr_top_level_ref_get_item_add() {
        let dtype = dtype();

        let expr = and(get_item("b", get_item("a", ident())), lit(1));
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        assert_eq!(partitioned.partitions.len(), 1);
    }

    // Fields `a` and `b` are both read, so the expression splits into two partitions.
    #[test]
    fn test_expr_top_level_ref_get_item_add_cannot_split() {
        let dtype = dtype();

        let expr = and(
            get_item("b", get_item("a", ident())),
            get_item("b", ident()),
        );
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);
    }

    // Field `a` is referenced twice: once directly and once inside the `select`. The asserted
    // root shows that `simplify_typed` turns the `select` into a pack of `$.a` and `$.b`.
    #[test]
    fn test_expr_partition_many_occurrences_of_field() {
        let dtype = dtype();

        let expr = and(
            get_item("b", get_item("a", ident())),
            select(vec!["a".into(), "b".into()], ident()),
        );
        let expr = simplify_typed(expr, &dtype).unwrap();
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);

        // Each occurrence of `a` is addressed through its own `field_idx_name` accessor.
        // `b` has a single sub-expression, so its synthetic accessor is stripped.
        assert_eq!(
            &partitioned.root,
            &and(
                get_item(
                    StructFieldExpressionSplitter::field_idx_name(&"a".into(), 0),
                    get_item("a", ident())
                ),
                pack([
                    (
                        "a",
                        get_item(
                            StructFieldExpressionSplitter::field_idx_name(&"a".into(), 1),
                            get_item("a", ident())
                        )
                    ),
                    ("b", get_item("b", ident()))
                ])
            )
        )
    }
}