vortex_expr/transform/
var_partition.rs

1use std::fmt::{Display, Formatter};
2use std::hash::{BuildHasher, Hash, Hasher};
3use std::sync::LazyLock;
4
5use itertools::Itertools;
6use vortex_dtype::{FieldName, Nullability};
7use vortex_error::{VortexExpect, VortexResult};
8use vortex_utils::aliases::hash_map::{DefaultHashBuilder, HashMap};
9
10use crate::transform::access_analysis::{Accesses, variable_scope_accesses};
11use crate::transform::partition::ReplaceAccessesWithChild;
12use crate::traversal::{FoldDown, FoldUp, FolderMut, Node};
13use crate::{ExprRef, Identifier, get_item, pack, var};
14
/// Hash builder shared by every call to `VariableExpressionSplitter::field_idx_name`.
///
/// Both the traversal (which emits `get_item` accesses) and the final assembly in
/// `split` (which names packed fields and lists accesses to strip) derive names
/// through this one lazily-initialised builder, so they hash with the same state.
static SPLITTER_RANDOM_STATE: LazyLock<DefaultHashBuilder> =
    LazyLock::new(DefaultHashBuilder::default);
17
18/// Partition an expression by the variable identifiers.
19pub fn var_partitions(expr: &ExprRef) -> VortexResult<VarPartitionedExpr> {
20    VariableExpressionSplitter::split_all(expr)
21}
22
23/// Partition an expression using the partition function `f`
24/// e.g. var(x) + var(y) + var(z), where f(x) = {x} and f(y | z) = {y}
25/// the partitioned expr will be
26/// root: var(x) + var(y).0 + var(y).1, { x: var(x), y: pack(0: var(y), 1: var(z) }
27pub fn var_partitions_with_map(
28    expr: &ExprRef,
29    f: impl Fn(&Identifier) -> Identifier,
30) -> VortexResult<VarPartitionedExpr> {
31    VariableExpressionSplitter::split(expr, f)
32}
33
// TODO(joe): replace with let expressions.
/// The result of partitioning an expression.
///
/// As built by the splitter in this file, `partitions` and `partition_names` are
/// index-aligned: the entry at position `i` of `partition_names` names the
/// expression at position `i` of `partitions`.
#[derive(Debug)]
pub struct VarPartitionedExpr {
    /// The root expression used to re-assemble the results.
    pub root: ExprRef,
    /// The partitions of the expression.
    pub partitions: Box<[ExprRef]>,
    /// The field names for the partitions
    pub partition_names: Box<[Identifier]>,
}
45
46impl Display for VarPartitionedExpr {
47    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
48        write!(
49            f,
50            "root: {} {{{}}}",
51            self.root,
52            self.partition_names
53                .iter()
54                .zip(self.partitions.iter())
55                .map(|(name, partition)| format!("{name}: {partition}"))
56                .join(", ")
57        )
58    }
59}
60
61impl VarPartitionedExpr {
62    /// Return the partition for a given field, if it exists.
63    pub fn find_partition(&self, field: &Identifier) -> Option<&ExprRef> {
64        self.partition_names
65            .iter()
66            .position(|name| name == field)
67            .map(|idx| &self.partitions[idx])
68    }
69}
70
/// Traversal state used to split an expression into per-identifier partitions.
#[derive(Debug)]
struct VariableExpressionSplitter<'a> {
    /// Sub-expressions collected so far, grouped by the identifier they access.
    /// A sub-expression's position in its `Vec` is the index fed to
    /// `field_idx_name` when generating the field it is accessed through.
    sub_expressions: HashMap<Identifier, Vec<ExprRef>>,
    /// Precomputed accesses, queried per node while traversing.
    accesses: &'a Accesses<'a, Identifier>,
}
76
77impl<'a> VariableExpressionSplitter<'a> {
78    fn new(accesses: &'a Accesses<'a, Identifier>) -> Self {
79        Self {
80            sub_expressions: HashMap::new(),
81            accesses,
82        }
83    }
84
85    pub(crate) fn field_idx_name(field: &Identifier, idx: usize) -> FieldName {
86        let mut hasher = SPLITTER_RANDOM_STATE.build_hasher();
87        field.hash(&mut hasher);
88        idx.hash(&mut hasher);
89        hasher.finish().to_string().into()
90    }
91
92    fn split_all(expr: &ExprRef) -> VortexResult<VarPartitionedExpr> {
93        Self::split(expr, Clone::clone)
94    }
95
96    fn split(
97        expr: &ExprRef,
98        f: impl Fn(&Identifier) -> Identifier,
99    ) -> VortexResult<VarPartitionedExpr> {
100        let field_accesses = variable_scope_accesses(expr, f)?;
101
102        let mut splitter = VariableExpressionSplitter::new(&field_accesses);
103        let split = expr.clone().transform_with_context(&mut splitter, ())?;
104        let mut remove_accesses: Vec<FieldName> = Vec::new();
105
106        let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
107        let mut partition_names = Vec::with_capacity(splitter.sub_expressions.len());
108        for (name, exprs) in splitter.sub_expressions.into_iter() {
109            // If there is a single expr then we don't need to `pack` this, and we must update
110            // the root expr removing this access.
111            let expr = if exprs.len() == 1 {
112                remove_accesses.push(Self::field_idx_name(&name, 0));
113                exprs.first().vortex_expect("exprs is non-empty").clone()
114            } else {
115                pack(
116                    exprs
117                        .into_iter()
118                        .enumerate()
119                        .map(|(idx, expr)| (Self::field_idx_name(&name, idx), expr)),
120                    Nullability::NonNullable,
121                )
122            };
123
124            partitions.push(expr);
125            partition_names.push(name);
126        }
127
128        let expression_access_counts = field_accesses.get(&expr).map(|ac| ac.len());
129        // Ensure that there are not more accesses than partitions, we missed something
130        assert!(expression_access_counts.unwrap_or(0) <= partitions.len());
131        // Ensure that there are as many partitions as there are accesses/fields in the scope,
132        // this will affect performance, not correctness.
133        debug_assert_eq!(expression_access_counts.unwrap_or(0), partitions.len());
134
135        let split = split
136            .result()
137            .transform(&mut ReplaceAccessesWithChild::new(remove_accesses))?;
138
139        Ok(VarPartitionedExpr {
140            root: split.into_inner(),
141            partitions: partitions.into_boxed_slice(),
142            partition_names: partition_names.into(),
143        })
144    }
145}
146
147impl FolderMut for VariableExpressionSplitter<'_> {
148    type NodeTy = ExprRef;
149    type Out = ExprRef;
150    type Context = ();
151
152    fn visit_down(
153        &mut self,
154        node: &Self::NodeTy,
155        _context: Self::Context,
156    ) -> VortexResult<FoldDown<ExprRef, Self::Context>> {
157        // If this expression only accesses a single field, then we can skip the children
158        let access = self.accesses.get(node);
159        if access.as_ref().is_some_and(|a| a.len() == 1) {
160            let field_name = access
161                .vortex_expect("access is non-empty")
162                .iter()
163                .next()
164                .vortex_expect("expected one field");
165
166            let sub_exprs = self.sub_expressions.entry(field_name.clone()).or_default();
167            let idx = sub_exprs.len();
168
169            sub_exprs.push(node.clone());
170
171            let access = get_item(
172                Self::field_idx_name(field_name, idx),
173                var(field_name.clone()),
174            );
175
176            return Ok(FoldDown::SkipChildren(access));
177        };
178
179        // Otherwise, continue traversing.
180        Ok(FoldDown::Continue(()))
181    }
182
183    fn visit_up(
184        &mut self,
185        node: Self::NodeTy,
186        _context: Self::Context,
187        children: Vec<Self::Out>,
188    ) -> VortexResult<FoldUp<Self::Out>> {
189        Ok(FoldUp::Continue(node.replacing_children(children)))
190    }
191}
192
#[cfg(test)]
mod tests {
    use vortex_dtype::Nullability::NonNullable;

    use super::*;
    use crate::{Pack, Var, and, root, var};

    /// A bare root reference splits successfully into a `Var` root and one partition.
    #[test]
    fn test_expr_top_level_ref() {
        let expr = root();

        let outcome = VariableExpressionSplitter::split_all(&expr);
        assert!(outcome.is_ok());

        let partitioned = outcome.unwrap();
        assert!(partitioned.root.as_any().is::<Var>());
        assert_eq!(partitioned.partitions.len(), 1)
    }

    /// Root and `x` referenced once each: two partitions, neither of them packed.
    #[test]
    fn test_expr_top_level_ref_get_item_and_split() {
        let fields = [("root", root()), ("x", var("x"))];
        let expr = pack(fields, NonNullable);

        let partitioned = VariableExpressionSplitter::split_all(&expr).unwrap();

        let root_id: Identifier = "".into();
        let x_id: Identifier = "x".into();
        assert_eq!(partitioned.partitions.len(), 2);
        assert_eq!(partitioned.find_partition(&root_id), Some(&root()));
        assert_eq!(partitioned.find_partition(&x_id), Some(&var("x")));
    }

    /// Mapping everything except `x` to `""` groups root and `y` into one packed
    /// partition while `x` stays on its own.
    #[test]
    fn test_partition_var_split_with() {
        let fields = [("root", root()), ("x", var("x")), ("y", var("y"))];
        let expr = pack(fields, NonNullable);

        let partitioned = VariableExpressionSplitter::split(&expr, |id| {
            if id == "x" { id.clone() } else { "".into() }
        })
        .unwrap();

        let root_id: Identifier = "".into();
        let x_id: Identifier = "x".into();
        assert_eq!(partitioned.partitions.len(), 2);

        let merged = partitioned.find_partition(&root_id).unwrap();
        assert!(merged.as_any().is::<Pack>());
        assert_eq!(partitioned.find_partition(&x_id), Some(&var("x")));
    }

    /// `x` referenced twice alongside root still yields one partition per identifier.
    #[test]
    fn test_expr_top_level_ref_get_item_and_split_pack() {
        let inner = and(var("x"), root());
        let expr = and(inner, var("x"));

        let partitioned = VariableExpressionSplitter::split_all(&expr).unwrap();

        assert_eq!(partitioned.partitions.len(), 2);
    }
}
253        assert_eq!(partitioned.partitions.len(), 2);
254    }
255}