1use std::fmt::{Display, Formatter};
2use std::hash::{BuildHasher, Hash, Hasher};
3use std::sync::LazyLock;
4
5use itertools::Itertools;
6use vortex_dtype::{FieldName, Nullability};
7use vortex_error::{VortexExpect, VortexResult};
8use vortex_utils::aliases::hash_map::{DefaultHashBuilder, HashMap};
9
10use crate::transform::access_analysis::{Accesses, variable_scope_accesses};
11use crate::transform::partition::ReplaceAccessesWithChild;
12use crate::traversal::{FoldDown, FoldUp, FolderMut, Node};
13use crate::{ExprRef, Identifier, get_item, pack, var};
14
15static SPLITTER_RANDOM_STATE: LazyLock<DefaultHashBuilder> =
16 LazyLock::new(DefaultHashBuilder::default);
17
18pub fn var_partitions(expr: &ExprRef) -> VortexResult<VarPartitionedExpr> {
20 VariableExpressionSplitter::split_all(expr)
21}
22
23pub fn var_partitions_with_map(
28 expr: &ExprRef,
29 f: impl Fn(&Identifier) -> Identifier,
30) -> VortexResult<VarPartitionedExpr> {
31 VariableExpressionSplitter::split(expr, f)
32}
33
/// An expression split into a root expression plus one sub-expression ("partition") per
/// variable identifier.
///
/// `partitions[i]` belongs to the identifier `partition_names[i]`. The order of the entries is
/// the iteration order of an internal hash map, so callers should look partitions up by name
/// (see [`VarPartitionedExpr::find_partition`]) rather than rely on position.
#[derive(Debug)]
pub struct VarPartitionedExpr {
    /// The original expression with every hoisted sub-expression replaced by an access into
    /// `var(<identifier>)`. When a partition holds a single expression, that access is rewritten
    /// by `ReplaceAccessesWithChild` (presumably collapsing it to the bare variable).
    pub root: ExprRef,
    /// One expression per identifier. It is either the sole hoisted sub-expression for that
    /// identifier, or a non-nullable `pack` of all of them keyed by synthetic field names.
    pub partitions: Box<[ExprRef]>,
    /// Identifier for each entry of `partitions`, at the same index.
    pub partition_names: Box<[Identifier]>,
}
45
46impl Display for VarPartitionedExpr {
47 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
48 write!(
49 f,
50 "root: {} {{{}}}",
51 self.root,
52 self.partition_names
53 .iter()
54 .zip(self.partitions.iter())
55 .map(|(name, partition)| format!("{name}: {partition}"))
56 .join(", ")
57 )
58 }
59}
60
61impl VarPartitionedExpr {
62 pub fn find_partition(&self, field: &Identifier) -> Option<&ExprRef> {
64 self.partition_names
65 .iter()
66 .position(|name| name == field)
67 .map(|idx| &self.partitions[idx])
68 }
69}
70
/// Top-down folder that hoists every maximal sub-expression accessing exactly one variable
/// identifier out of an expression, replacing it with an indexed access into that variable.
#[derive(Debug)]
struct VariableExpressionSplitter<'a> {
    /// Hoisted sub-expressions per identifier, in the order they were encountered. A
    /// sub-expression's position in its `Vec` is the `idx` fed to `field_idx_name`.
    sub_expressions: HashMap<Identifier, Vec<ExprRef>>,
    /// Precomputed set of identifiers accessed by each node, as produced by
    /// `variable_scope_accesses`.
    accesses: &'a Accesses<'a, Identifier>,
}
76
77impl<'a> VariableExpressionSplitter<'a> {
78 fn new(accesses: &'a Accesses<'a, Identifier>) -> Self {
79 Self {
80 sub_expressions: HashMap::new(),
81 accesses,
82 }
83 }
84
85 pub(crate) fn field_idx_name(field: &Identifier, idx: usize) -> FieldName {
86 let mut hasher = SPLITTER_RANDOM_STATE.build_hasher();
87 field.hash(&mut hasher);
88 idx.hash(&mut hasher);
89 hasher.finish().to_string().into()
90 }
91
92 fn split_all(expr: &ExprRef) -> VortexResult<VarPartitionedExpr> {
93 Self::split(expr, Clone::clone)
94 }
95
96 fn split(
97 expr: &ExprRef,
98 f: impl Fn(&Identifier) -> Identifier,
99 ) -> VortexResult<VarPartitionedExpr> {
100 let field_accesses = variable_scope_accesses(expr, f)?;
101
102 let mut splitter = VariableExpressionSplitter::new(&field_accesses);
103 let split = expr.clone().transform_with_context(&mut splitter, ())?;
104 let mut remove_accesses: Vec<FieldName> = Vec::new();
105
106 let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
107 let mut partition_names = Vec::with_capacity(splitter.sub_expressions.len());
108 for (name, exprs) in splitter.sub_expressions.into_iter() {
109 let expr = if exprs.len() == 1 {
112 remove_accesses.push(Self::field_idx_name(&name, 0));
113 exprs.first().vortex_expect("exprs is non-empty").clone()
114 } else {
115 pack(
116 exprs
117 .into_iter()
118 .enumerate()
119 .map(|(idx, expr)| (Self::field_idx_name(&name, idx), expr)),
120 Nullability::NonNullable,
121 )
122 };
123
124 partitions.push(expr);
125 partition_names.push(name);
126 }
127
128 let expression_access_counts = field_accesses.get(&expr).map(|ac| ac.len());
129 assert!(expression_access_counts.unwrap_or(0) <= partitions.len());
131 debug_assert_eq!(expression_access_counts.unwrap_or(0), partitions.len());
134
135 let split = split
136 .result()
137 .transform(&mut ReplaceAccessesWithChild::new(remove_accesses))?;
138
139 Ok(VarPartitionedExpr {
140 root: split.into_inner(),
141 partitions: partitions.into_boxed_slice(),
142 partition_names: partition_names.into(),
143 })
144 }
145}
146
147impl FolderMut for VariableExpressionSplitter<'_> {
148 type NodeTy = ExprRef;
149 type Out = ExprRef;
150 type Context = ();
151
152 fn visit_down(
153 &mut self,
154 node: &Self::NodeTy,
155 _context: Self::Context,
156 ) -> VortexResult<FoldDown<ExprRef, Self::Context>> {
157 let access = self.accesses.get(node);
159 if access.as_ref().is_some_and(|a| a.len() == 1) {
160 let field_name = access
161 .vortex_expect("access is non-empty")
162 .iter()
163 .next()
164 .vortex_expect("expected one field");
165
166 let sub_exprs = self.sub_expressions.entry(field_name.clone()).or_default();
167 let idx = sub_exprs.len();
168
169 sub_exprs.push(node.clone());
170
171 let access = get_item(
172 Self::field_idx_name(field_name, idx),
173 var(field_name.clone()),
174 );
175
176 return Ok(FoldDown::SkipChildren(access));
177 };
178
179 Ok(FoldDown::Continue(()))
181 }
182
183 fn visit_up(
184 &mut self,
185 node: Self::NodeTy,
186 _context: Self::Context,
187 children: Vec<Self::Out>,
188 ) -> VortexResult<FoldUp<Self::Out>> {
189 Ok(FoldUp::Continue(node.replacing_children(children)))
190 }
191}
192
#[cfg(test)]
mod tests {
    use vortex_dtype::Nullability::NonNullable;

    use super::*;
    use crate::{Pack, Var, and, root, var};

    // A bare `root()` yields one partition, and the root collapses back to a variable.
    #[test]
    fn test_expr_top_level_ref() {
        let result = VariableExpressionSplitter::split_all(&root());
        assert!(result.is_ok());

        let partitioned = result.unwrap();
        assert!(partitioned.root.as_any().is::<Var>());
        assert_eq!(partitioned.partitions.len(), 1)
    }

    // Two distinct variables produce two single-expression partitions.
    #[test]
    fn test_expr_top_level_ref_get_item_and_split() {
        let packed = pack([("root", root()), ("x", var("x"))], NonNullable);
        let partitioned = VariableExpressionSplitter::split_all(&packed).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);
        assert_eq!(partitioned.find_partition(&"".into()), Some(&root()));
        assert_eq!(partitioned.find_partition(&"x".into()), Some(&var("x")));
    }

    // Mapping `y` onto the root identifier merges it into the root partition, which is packed.
    #[test]
    fn test_partition_var_split_with() {
        let packed = pack(
            [("root", root()), ("x", var("x")), ("y", var("y"))],
            NonNullable,
        );
        let partitioned = VariableExpressionSplitter::split(&packed, |id| {
            if id == "x" { id.clone() } else { "".into() }
        })
        .unwrap();

        assert_eq!(partitioned.partitions.len(), 2);
        let root_partition = partitioned.find_partition(&"".into()).unwrap();
        assert!(root_partition.as_any().is::<Pack>());
        assert_eq!(partitioned.find_partition(&"x".into()), Some(&var("x")));
    }

    // Repeated accesses to `x` alongside `root()` still give one partition per identifier.
    #[test]
    fn test_expr_top_level_ref_get_item_and_split_pack() {
        let conjunction = and(and(var("x"), root()), var("x"));
        let partitioned = VariableExpressionSplitter::split_all(&conjunction).unwrap();
        assert_eq!(partitioned.partitions.len(), 2);
    }
}