1use vortex_array::aliases::hash_map::HashMap;
2use vortex_dtype::{DType, FieldName, FieldNames, StructDType};
3use vortex_error::{VortexExpect, VortexResult, vortex_bail};
4
5use crate::transform::immediate_access::{FieldAccesses, immediate_scope_accesses};
6use crate::transform::simplify_typed::simplify_typed;
7use crate::traversal::{FoldDown, FoldUp, FolderMut, MutNodeVisitor, Node, TransformResult};
8use crate::{ExprRef, GetItem, Identity, get_item, ident, pack};
9
10pub fn partition(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
26 if !matches!(dtype, DType::Struct(..)) {
27 vortex_bail!("Expected a struct dtype, got {:?}", dtype);
28 }
29 StructFieldExpressionSplitter::split(expr, dtype)
30}
31
32#[derive(Debug)]
34pub struct PartitionedExpr {
35 pub root: ExprRef,
37 pub partitions: Box<[ExprRef]>,
39 pub partition_names: FieldNames,
41}
42
43impl PartitionedExpr {
44 pub fn find_partition(&self, field: &FieldName) -> Option<&ExprRef> {
46 self.partition_names
47 .iter()
48 .position(|name| name == field)
49 .map(|idx| &self.partitions[idx])
50 }
51}
52
53#[derive(Debug)]
54struct StructFieldExpressionSplitter<'a> {
55 sub_expressions: HashMap<FieldName, Vec<ExprRef>>,
56 accesses: &'a FieldAccesses<'a>,
57 scope_dtype: &'a StructDType,
58}
59
60impl<'a> StructFieldExpressionSplitter<'a> {
61 fn new(accesses: &'a FieldAccesses<'a>, scope_dtype: &'a StructDType) -> Self {
62 Self {
63 sub_expressions: HashMap::new(),
64 accesses,
65 scope_dtype,
66 }
67 }
68
69 pub(crate) fn field_idx_name(field: &FieldName, idx: usize) -> FieldName {
70 format!("__e__{}.{}", field, idx).into()
71 }
72
73 fn split(expr: ExprRef, dtype: &DType) -> VortexResult<PartitionedExpr> {
74 let scope_dtype = match dtype {
75 DType::Struct(scope_dtype, _) => scope_dtype,
76 _ => vortex_bail!("Expected a struct dtype, got {:?}", dtype),
77 };
78
79 let field_accesses = immediate_scope_accesses(&expr, scope_dtype)?;
80
81 let mut splitter = StructFieldExpressionSplitter::new(&field_accesses, scope_dtype);
82
83 let split = expr.clone().transform_with_context(&mut splitter, ())?;
84
85 let mut remove_accesses: Vec<FieldName> = Vec::new();
86
87 let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
89 let mut partition_names = Vec::with_capacity(splitter.sub_expressions.len());
90 for (name, exprs) in splitter.sub_expressions.into_iter() {
91 let field_dtype = scope_dtype.field(&name)?;
92 let expr = if exprs.len() == 1 {
95 remove_accesses.push(Self::field_idx_name(&name, 0));
96 exprs.first().vortex_expect("exprs is non-empty").clone()
97 } else {
98 pack(
99 exprs
100 .into_iter()
101 .enumerate()
102 .map(|(idx, expr)| (Self::field_idx_name(&name, idx), expr)),
103 )
104 };
105
106 partitions.push(simplify_typed(expr, &field_dtype)?);
107 partition_names.push(name);
108 }
109
110 let expression_access_counts = field_accesses.get(&expr).map(|ac| ac.len());
111 assert!(expression_access_counts.unwrap_or(0) <= partitions.len());
113 debug_assert_eq!(expression_access_counts.unwrap_or(0), partitions.len());
116
117 let split = split
118 .result()
119 .transform(&mut ReplaceAccessesWithChild(remove_accesses))?;
120
121 Ok(PartitionedExpr {
122 root: simplify_typed(split.result, dtype)?,
123 partitions: partitions.into_boxed_slice(),
124 partition_names: partition_names.into(),
125 })
126 }
127}
128
129impl FolderMut for StructFieldExpressionSplitter<'_> {
130 type NodeTy = ExprRef;
131 type Out = ExprRef;
132 type Context = ();
133
134 fn visit_down(
135 &mut self,
136 node: &Self::NodeTy,
137 _context: Self::Context,
138 ) -> VortexResult<FoldDown<ExprRef, Self::Context>> {
139 let access = self.accesses.get(node);
141 if access.as_ref().is_some_and(|a| a.len() == 1) {
142 let field_name = access
143 .vortex_expect("access is non-empty")
144 .iter()
145 .next()
146 .vortex_expect("expected one field");
147
148 let sub_exprs = self.sub_expressions.entry(field_name.clone()).or_default();
151 let idx = sub_exprs.len();
152
153 let replaced = node
155 .clone()
156 .transform(&mut ScopeStepIntoFieldExpr(field_name.clone()))?;
157 sub_exprs.push(replaced.result);
158
159 let access = get_item(
160 Self::field_idx_name(field_name, idx),
161 get_item(field_name.clone(), ident()),
162 );
163
164 return Ok(FoldDown::SkipChildren(access));
165 };
166
167 if node.as_any().is::<Identity>() {
169 let field_names = self.scope_dtype.names();
170
171 let mut elements = Vec::with_capacity(field_names.len());
172
173 for field_name in field_names.iter() {
174 let sub_exprs = self
175 .sub_expressions
176 .entry(field_name.clone())
177 .or_insert_with(Vec::new);
178
179 let idx = sub_exprs.len();
180
181 sub_exprs.push(ident());
182
183 elements.push((
184 field_name.clone(),
185 get_item(
187 Self::field_idx_name(field_name, idx),
188 get_item(field_name.clone(), ident()),
189 ),
190 ));
191 }
192
193 return Ok(FoldDown::SkipChildren(pack(elements)));
194 }
195
196 Ok(FoldDown::Continue(()))
198 }
199
200 fn visit_up(
201 &mut self,
202 node: Self::NodeTy,
203 _context: Self::Context,
204 children: Vec<Self::Out>,
205 ) -> VortexResult<FoldUp<Self::Out>> {
206 Ok(FoldUp::Continue(node.replacing_children(children)))
207 }
208}
209
210struct ScopeStepIntoFieldExpr(FieldName);
211
212impl MutNodeVisitor for ScopeStepIntoFieldExpr {
213 type NodeTy = ExprRef;
214
215 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
216 if node.as_any().is::<Identity>() {
217 Ok(TransformResult::yes(pack([(self.0.clone(), ident())])))
218 } else {
219 Ok(TransformResult::no(node))
220 }
221 }
222}
223
224struct ReplaceAccessesWithChild(Vec<FieldName>);
225
226impl MutNodeVisitor for ReplaceAccessesWithChild {
227 type NodeTy = ExprRef;
228
229 fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<ExprRef>> {
230 if let Some(item) = node.as_any().downcast_ref::<GetItem>() {
231 if self.0.contains(item.field()) {
232 return Ok(TransformResult::yes(item.child().clone()));
233 }
234 }
235 Ok(TransformResult::no(node))
236 }
237}
238
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, StructDType};

    use super::*;
    use crate::transform::simplify::simplify;
    use crate::transform::simplify_typed::simplify_typed;
    use crate::{Pack, and, get_item, ident, lit, pack, select};

    /// Scope used by every test: `{ a: { a: i32, b: i32 }, b: i32, c: i32 }`.
    fn dtype() -> DType {
        let nested = DType::Struct(
            Arc::new(StructDType::from_iter([
                ("a", DType::from(I32)),
                ("b", DType::from(I32)),
            ])),
            NonNullable,
        );

        DType::Struct(
            Arc::new(StructDType::from_iter([
                ("a", nested),
                ("b", I32.into()),
                ("c", I32.into()),
            ])),
            NonNullable,
        )
    }

    /// A bare identity is split into one partition per scope field, re-assembled by a pack.
    #[test]
    fn test_expr_top_level_ref() {
        let scope = dtype();

        let outcome = StructFieldExpressionSplitter::split(ident(), &scope);
        assert!(outcome.is_ok());
        let partitioned = outcome.unwrap();

        assert!(partitioned.root.as_any().is::<Pack>());

        let field_count = scope.as_struct().unwrap().names().len();
        assert_eq!(partitioned.partitions.len(), field_count)
    }

    /// A nested access `a.b` leaves `a` in the root and pushes `b` into the partition for `a`.
    #[test]
    fn test_expr_top_level_ref_get_item_and_split() {
        let scope = dtype();
        let nested_access = get_item("b", get_item("a", ident()));

        let partitioned = StructFieldExpressionSplitter::split(nested_access, &scope).unwrap();

        let partition_a = partitioned.find_partition(&"a".into());
        assert!(partition_a.is_some());
        let partition_a = partition_a.unwrap();

        assert_eq!(&partitioned.root, &get_item("a", ident()));
        assert_eq!(
            &simplify(partition_a.clone()).unwrap(),
            &get_item("b", ident())
        );
    }

    /// Two accesses into field `a` are packed under their synthetic per-occurrence names, while
    /// the single access to `c` becomes a plain identity partition.
    #[test]
    fn test_expr_top_level_ref_get_item_and_split_pack() {
        let scope = dtype();
        let packed = pack([
            ("a", get_item("a", get_item("a", ident()))),
            ("b", get_item("b", get_item("a", ident()))),
            ("c", get_item("c", ident())),
        ]);

        let partitioned = StructFieldExpressionSplitter::split(packed, &scope).unwrap();

        let first_a = StructFieldExpressionSplitter::field_idx_name(&"a".into(), 0);
        let second_a = StructFieldExpressionSplitter::field_idx_name(&"a".into(), 1);
        let expected_a = pack([
            (first_a, get_item("a", ident())),
            (second_a, get_item("b", ident())),
        ]);

        let partition_a = partitioned.find_partition(&"a".into()).unwrap();
        assert_eq!(&simplify(partition_a.clone()).unwrap(), &expected_a);

        let partition_c = partitioned.find_partition(&"c".into()).unwrap();
        assert_eq!(&simplify(partition_c.clone()).unwrap(), &ident())
    }

    /// An expression touching only field `a` (plus a literal) yields a single partition.
    #[test]
    fn test_expr_top_level_ref_get_item_add() {
        let scope = dtype();
        let single_field = and(get_item("b", get_item("a", ident())), lit(1));

        let partitioned = StructFieldExpressionSplitter::split(single_field, &scope).unwrap();

        assert_eq!(partitioned.partitions.len(), 1);
    }

    /// An expression combining fields `a` and `b` yields one partition for each of them.
    #[test]
    fn test_expr_top_level_ref_get_item_add_cannot_split() {
        let scope = dtype();
        let two_fields = and(
            get_item("b", get_item("a", ident())),
            get_item("b", ident()),
        );

        let partitioned = StructFieldExpressionSplitter::split(two_fields, &scope).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);
    }

    /// Field `a` occurs twice (once nested, once via the select) and `b` once: two partitions,
    /// with the root addressing both occurrences of `a` by their synthetic names.
    #[test]
    fn test_expr_partition_many_occurrences_of_field() {
        let scope = dtype();

        let combined = and(
            get_item("b", get_item("a", ident())),
            select(vec!["a".into(), "b".into()], ident()),
        );
        let combined = simplify_typed(combined, &scope).unwrap();

        let partitioned = StructFieldExpressionSplitter::split(combined, &scope).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);

        let first_a = StructFieldExpressionSplitter::field_idx_name(&"a".into(), 0);
        let second_a = StructFieldExpressionSplitter::field_idx_name(&"a".into(), 1);
        let expected_root = and(
            get_item(first_a, get_item("a", ident())),
            pack([
                ("a", get_item(second_a, get_item("a", ident()))),
                ("b", get_item("b", ident())),
            ]),
        );

        assert_eq!(&partitioned.root, &expected_root)
    }
}