1use std::fmt::{Display, Formatter};
2
3use itertools::Itertools;
4use vortex_array::aliases::hash_map::HashMap;
5use vortex_dtype::{DType, FieldName, FieldNames, StructDType};
6use vortex_error::{VortexExpect, VortexResult, vortex_bail};
7
8use crate::transform::immediate_access::{FieldAccesses, immediate_scope_accesses};
9use crate::transform::simplify_typed::simplify_typed;
10use crate::traversal::{FoldDown, FoldUp, FolderMut, MutNodeVisitor, Node, TransformResult};
11use crate::{ExprRef, GetItem, Identity, get_item, ident, pack};
12
13pub fn partition(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
29 if !matches!(dtype, DType::Struct(..)) {
30 vortex_bail!("Expected a struct dtype, got {:?}", dtype);
31 }
32 StructFieldExpressionSplitter::split(expr, dtype)
33}
34
35#[derive(Debug)]
37pub struct PartitionedExpr {
38 pub root: ExprRef,
40 pub partitions: Box<[ExprRef]>,
42 pub partition_names: FieldNames,
44 pub partition_dtypes: Box<[DType]>,
46}
47
48impl Display for PartitionedExpr {
49 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
50 write!(
51 f,
52 "root: {} {{{}}}",
53 self.root,
54 self.partition_names
55 .iter()
56 .zip(self.partitions.iter())
57 .map(|(name, partition)| format!("{}: {}", name, partition))
58 .join(", ")
59 )
60 }
61}
62
63impl PartitionedExpr {
64 pub fn find_partition(&self, field: &FieldName) -> Option<&ExprRef> {
66 self.partition_names
67 .iter()
68 .position(|name| name == field)
69 .map(|idx| &self.partitions[idx])
70 }
71}
72
73#[derive(Debug)]
74struct StructFieldExpressionSplitter<'a> {
75 sub_expressions: HashMap<FieldName, Vec<ExprRef>>,
76 accesses: &'a FieldAccesses<'a>,
77 scope_dtype: &'a StructDType,
78}
79
80impl<'a> StructFieldExpressionSplitter<'a> {
81 fn new(accesses: &'a FieldAccesses<'a>, scope_dtype: &'a StructDType) -> Self {
82 Self {
83 sub_expressions: HashMap::new(),
84 accesses,
85 scope_dtype,
86 }
87 }
88
89 pub(crate) fn field_idx_name(field: &FieldName, idx: usize) -> FieldName {
90 format!("__e__{}.{}", field, idx).into()
91 }
92
93 fn split(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
94 let scope_dtype = match dtype {
95 DType::Struct(scope_dtype, _) => scope_dtype,
96 _ => vortex_bail!("Expected a struct dtype, got {:?}", dtype),
97 };
98
99 let field_accesses = immediate_scope_accesses(&expr, scope_dtype)?;
100
101 let mut splitter = StructFieldExpressionSplitter::new(&field_accesses, scope_dtype);
102
103 let split = expr.clone().transform_with_context(&mut splitter, ())?;
104
105 let mut remove_accesses: Vec<FieldName> = Vec::new();
106
107 let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
109 let mut partition_names = Vec::with_capacity(splitter.sub_expressions.len());
110 let mut partition_dtypes = Vec::with_capacity(splitter.sub_expressions.len());
111 for (name, exprs) in splitter.sub_expressions.into_iter() {
112 let field_dtype = scope_dtype.field(&name)?;
113 let expr = if exprs.len() == 1 {
116 remove_accesses.push(Self::field_idx_name(&name, 0));
117 exprs.first().vortex_expect("exprs is non-empty").clone()
118 } else {
119 pack(
120 exprs
121 .into_iter()
122 .enumerate()
123 .map(|(idx, expr)| (Self::field_idx_name(&name, idx), expr)),
124 )
125 };
126
127 let expr = simplify_typed(expr.clone(), &field_dtype)?;
128 let expr_dtype = expr.return_dtype(&field_dtype)?;
129
130 partitions.push(expr);
131 partition_names.push(name);
132 partition_dtypes.push(expr_dtype);
133 }
134
135 let expression_access_counts = field_accesses.get(&expr).map(|ac| ac.len());
136 assert!(expression_access_counts.unwrap_or(0) <= partitions.len());
138 debug_assert_eq!(expression_access_counts.unwrap_or(0), partitions.len());
141
142 let split = split
143 .result()
144 .transform(&mut ReplaceAccessesWithChild(remove_accesses))?;
145
146 Ok(PartitionedExpr {
147 root: simplify_typed(split.result, dtype)?,
148 partitions: partitions.into_boxed_slice(),
149 partition_names: partition_names.into(),
150 partition_dtypes: partition_dtypes.into_boxed_slice(),
151 })
152 }
153}
154
155impl FolderMut for StructFieldExpressionSplitter<'_> {
156 type NodeTy = ExprRef;
157 type Out = ExprRef;
158 type Context = ();
159
160 fn visit_down(
161 &mut self,
162 node: &Self::NodeTy,
163 _context: Self::Context,
164 ) -> VortexResult<FoldDown<ExprRef, Self::Context>> {
165 let access = self.accesses.get(node);
167 if access.as_ref().is_some_and(|a| a.len() == 1) {
168 let field_name = access
169 .vortex_expect("access is non-empty")
170 .iter()
171 .next()
172 .vortex_expect("expected one field");
173
174 let sub_exprs = self.sub_expressions.entry(field_name.clone()).or_default();
177 let idx = sub_exprs.len();
178
179 let replaced = node
181 .clone()
182 .transform(&mut ScopeStepIntoFieldExpr(field_name.clone()))?;
183 sub_exprs.push(replaced.result);
184
185 let access = get_item(
186 Self::field_idx_name(field_name, idx),
187 get_item(field_name.clone(), ident()),
188 );
189
190 return Ok(FoldDown::SkipChildren(access));
191 };
192
193 if node.as_any().is::<Identity>() {
195 let field_names = self.scope_dtype.names();
196
197 let mut elements = Vec::with_capacity(field_names.len());
198
199 for field_name in field_names.iter() {
200 let sub_exprs = self
201 .sub_expressions
202 .entry(field_name.clone())
203 .or_insert_with(Vec::new);
204
205 let idx = sub_exprs.len();
206
207 sub_exprs.push(ident());
208
209 elements.push((
210 field_name.clone(),
211 get_item(
213 Self::field_idx_name(field_name, idx),
214 get_item(field_name.clone(), ident()),
215 ),
216 ));
217 }
218
219 return Ok(FoldDown::SkipChildren(pack(elements)));
220 }
221
222 Ok(FoldDown::Continue(()))
224 }
225
226 fn visit_up(
227 &mut self,
228 node: Self::NodeTy,
229 _context: Self::Context,
230 children: Vec<Self::Out>,
231 ) -> VortexResult<FoldUp<Self::Out>> {
232 Ok(FoldUp::Continue(node.replacing_children(children)))
233 }
234}
235
236struct ScopeStepIntoFieldExpr(FieldName);
237
238impl MutNodeVisitor for ScopeStepIntoFieldExpr {
239 type NodeTy = ExprRef;
240
241 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
242 if node.as_any().is::<Identity>() {
243 Ok(TransformResult::yes(pack([(self.0.clone(), ident())])))
244 } else {
245 Ok(TransformResult::no(node))
246 }
247 }
248}
249
250struct ReplaceAccessesWithChild(Vec<FieldName>);
251
252impl MutNodeVisitor for ReplaceAccessesWithChild {
253 type NodeTy = ExprRef;
254
255 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
256 if let Some(item) = node.as_any().downcast_ref::<GetItem>() {
257 if self.0.contains(item.field()) {
258 return Ok(TransformResult::yes(item.child().clone()));
259 }
260 }
261 Ok(TransformResult::no(node))
262 }
263}
264
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, StructDType};

    use super::*;
    use crate::transform::simplify::simplify;
    use crate::transform::simplify_typed::simplify_typed;
    use crate::{Pack, and, get_item, ident, lit, pack, select};

    /// Scope shared by the tests: `{a: {a: i32, b: i32}, b: i32, c: i32}`.
    fn dtype() -> DType {
        DType::Struct(
            Arc::new(StructDType::from_iter([
                (
                    "a",
                    DType::Struct(
                        Arc::new(StructDType::from_iter([
                            ("a", I32.into()),
                            ("b", DType::from(I32)),
                        ])),
                        NonNullable,
                    ),
                ),
                ("b", I32.into()),
                ("c", I32.into()),
            ])),
            NonNullable,
        )
    }

    #[test]
    fn test_expr_top_level_ref() {
        let dtype = dtype();

        // `unwrap` (rather than `assert!(is_ok())`) surfaces the error if the split fails.
        let partitioned = StructFieldExpressionSplitter::split(ident(), &dtype).unwrap();

        assert!(partitioned.root.as_any().is::<Pack>());

        // A bare identity reads every scope field, so every field gets its own partition.
        let field_names = dtype.as_struct().unwrap().names();
        assert_eq!(partitioned.partitions.len(), field_names.len());
        for name in field_names.iter() {
            assert!(
                partitioned.find_partition(name).is_some(),
                "missing partition for {name}"
            );
        }
    }

    #[test]
    fn test_expr_top_level_ref_get_item_and_split() {
        let dtype = dtype();

        let expr = get_item("b", get_item("a", ident()));

        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();
        let split_a = partitioned
            .find_partition(&"a".into())
            .expect("expected a partition for field `a`");

        assert_eq!(&partitioned.root, &get_item("a", ident()));
        assert_eq!(&simplify(split_a.clone()).unwrap(), &get_item("b", ident()));
    }

    #[test]
    fn test_expr_top_level_ref_get_item_and_split_pack() {
        let dtype = dtype();

        let expr = pack([
            ("a", get_item("a", get_item("a", ident()))),
            ("b", get_item("b", get_item("a", ident()))),
            ("c", get_item("c", ident())),
        ]);
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        let split_a = partitioned.find_partition(&"a".into()).unwrap();
        assert_eq!(
            &simplify(split_a.clone()).unwrap(),
            &pack([
                (
                    StructFieldExpressionSplitter::field_idx_name(&"a".into(), 0),
                    get_item("a", ident())
                ),
                (
                    StructFieldExpressionSplitter::field_idx_name(&"a".into(), 1),
                    get_item("b", ident())
                )
            ])
        );
        let split_c = partitioned.find_partition(&"c".into()).unwrap();
        assert_eq!(&simplify(split_c.clone()).unwrap(), &ident())
    }

    #[test]
    fn test_expr_top_level_ref_get_item_add() {
        let dtype = dtype();

        // The whole `and` only reads field `a`, so it is pushed down into a single partition.
        let expr = and(get_item("b", get_item("a", ident())), lit(1));
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        assert_eq!(partitioned.partitions.len(), 1);
        assert!(partitioned.find_partition(&"a".into()).is_some());
    }

    #[test]
    fn test_expr_top_level_ref_get_item_add_cannot_split() {
        let dtype = dtype();

        // The `and` reads both `a` and `b`, so each operand lands in its own partition.
        let expr = and(
            get_item("b", get_item("a", ident())),
            get_item("b", ident()),
        );
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);
        assert!(partitioned.find_partition(&"a".into()).is_some());
        assert!(partitioned.find_partition(&"b".into()).is_some());
    }

    #[test]
    fn test_expr_partition_many_occurrences_of_field() {
        let dtype = dtype();

        let expr = and(
            get_item("b", get_item("a", ident())),
            select(vec!["a".into(), "b".into()], ident()),
        );
        let expr = simplify_typed(expr, &dtype).unwrap();
        let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);

        // Field `a` is read twice, so the root addresses both of its packed sub-expressions.
        assert_eq!(
            &partitioned.root,
            &and(
                get_item(
                    StructFieldExpressionSplitter::field_idx_name(&"a".into(), 0),
                    get_item("a", ident())
                ),
                pack([
                    (
                        "a",
                        get_item(
                            StructFieldExpressionSplitter::field_idx_name(&"a".into(), 1),
                            get_item("a", ident())
                        )
                    ),
                    ("b", get_item("b", ident()))
                ])
            )
        )
    }
}