vortex_array/optimizer/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use vortex_error::VortexResult;
7use vortex_utils::aliases::hash_map::HashMap;
8
9use crate::Array;
10use crate::ArrayVisitor;
11use crate::array::ArrayRef;
12use crate::optimizer::rules::AnyArray;
13use crate::optimizer::rules::ArrayParentReduceRule;
14use crate::optimizer::rules::ArrayReduceRule;
15use crate::optimizer::rules::DynArrayParentReduceRule;
16use crate::optimizer::rules::DynArrayReduceRule;
17use crate::optimizer::rules::MatchKey;
18use crate::optimizer::rules::Matcher;
19use crate::optimizer::rules::ParentReduceRuleAdapter;
20use crate::optimizer::rules::ReduceRuleAdapter;
21
22pub mod rules;
23
24#[cfg(test)]
25mod tests;
26
27/// Optimizer for arrays that applies registered rewrite rules.
28///
29/// This optimizer recursively traverses an array tree, applying reduce rules
30/// to transform arrays into more efficient representations.
31#[derive(Default, Debug, Clone)]
32pub struct ArrayOptimizer {
33    /// Reduce rules indexed by encoding ID
34    reduce_rules: HashMap<MatchKey, Vec<Arc<dyn DynArrayReduceRule>>>,
35    /// Parent reduce rules for specific parent types, indexed by (child, parent)
36    parent_rules: HashMap<(MatchKey, MatchKey), Vec<Arc<dyn DynArrayParentReduceRule>>>,
37}
38
39impl ArrayOptimizer {
40    /// Optimize the given array by applying registered rewrite rules.
41    ///
42    /// This performs two passes following the ExprSession pattern:
43    /// 1. Apply parent rules - bottom-up traversal checking parent-child relationships
44    /// 2. Apply reduce rules - bottom-up traversal applying transformations to each node
45    pub fn optimize_array(&self, array: ArrayRef) -> VortexResult<ArrayRef> {
46        // First pass: apply parent rules
47        let array = self.apply_parent_rules(array)?;
48
49        // Second pass: apply reduce rules
50        let array = self.apply_reduce_rules(array)?;
51
52        Ok(array)
53    }
54
55    /// Apply parent rules in a bottom-up traversal.
56    ///
57    /// For each array, recursively process children first, then check if any parent
58    /// rules apply to transform children based on their parent context.
59    fn apply_parent_rules(&self, array: ArrayRef) -> VortexResult<ArrayRef> {
60        // First, recursively apply parent rules to all children
61        let children = array.children();
62        if children.is_empty() {
63            return Ok(array);
64        }
65
66        let mut optimized_children = Vec::with_capacity(children.len());
67        let mut children_changed = false;
68
69        for child in children.iter() {
70            let optimized_child = self.apply_parent_rules(child.clone())?;
71            children_changed |= !Arc::ptr_eq(&optimized_child, child);
72            optimized_children.push(optimized_child);
73        }
74
75        // Reconstruct array with optimized children if any changed
76        let array = if children_changed {
77            array.with_children(&optimized_children)?
78        } else {
79            array
80        };
81
82        // Now try to apply parent rules to each optimized child in the context of this array
83        // Use the optimized_children list directly instead of re-fetching from array.children()
84        // let mut transformed_children = Vec::with_capacity(optimized_children.len());
85
86        for (idx, child) in optimized_children.iter().enumerate() {
87            let result = self.with_parent_rules(
88                child,
89                Some(&array),
90                |rules| -> VortexResult<Option<ArrayRef>> {
91                    for rule in rules {
92                        if let Some(new_array) = rule.reduce_parent(child, &array, idx)? {
93                            return Ok(Some(new_array));
94                        }
95                    }
96                    Ok(None)
97                },
98            )?;
99
100            if let Some(transformed) = result {
101                return Ok(transformed);
102            }
103        }
104
105        // Reconstruct array with transformed children if any rules matched
106        Ok(array)
107    }
108
109    /// Apply reduce rules in a bottom-up traversal.
110    ///
111    /// For each array, recursively process children first, then try to apply
112    /// reduce rules to transform the array itself.
113    fn apply_reduce_rules(&self, array: ArrayRef) -> VortexResult<ArrayRef> {
114        // First, recursively apply reduce rules to all children
115        let children = array.children();
116        if !children.is_empty() {
117            let mut new_children = Vec::with_capacity(children.len());
118            let mut changed = false;
119
120            for child in children.iter() {
121                let optimized_child = self.apply_reduce_rules(child.clone())?;
122                changed |= !Arc::ptr_eq(&optimized_child, child);
123                new_children.push(optimized_child);
124            }
125
126            // Reconstruct array with optimized children if any changed
127            let array = if changed {
128                array.with_children(&new_children)?
129            } else {
130                array
131            };
132
133            // Now try to apply reduce rules to this array
134            self.try_reduce(array)
135        } else {
136            // Leaf node - just try to reduce
137            self.try_reduce(array)
138        }
139    }
140
141    /// Try to apply reduce rules to a single array, recursively if a rule matches.
142    fn try_reduce(&self, array: ArrayRef) -> VortexResult<ArrayRef> {
143        let result = self.with_reduce_rules(&array, |rules| -> VortexResult<Option<ArrayRef>> {
144            for rule in rules {
145                if let Some(new_array) = rule.reduce(&array)? {
146                    return Ok(Some(new_array));
147                }
148            }
149            Ok(None)
150        })?;
151
152        if let Some(transformed) = result {
153            // Rule matched - recursively try to reduce the result
154            // self.try_reduce(transformed)
155            Ok(transformed)
156        } else {
157            Ok(array)
158        }
159    }
160
161    /// Register a reduce rule for a specific array encoding.
162    pub fn register_reduce_rule<M, R>(&mut self, rule: R)
163    where
164        M: Matcher,
165        R: ArrayReduceRule<M> + 'static,
166    {
167        let key = rule.matcher().key();
168        let adapter = ReduceRuleAdapter::new(rule);
169        self.reduce_rules
170            .entry(key)
171            .or_default()
172            .push(Arc::new(adapter));
173    }
174
175    /// Register a parent rule for a specific parent type.
176    pub fn register_parent_rule<Child, Parent, R>(&mut self, rule: R)
177    where
178        Child: Matcher,
179        Parent: Matcher,
180        R: ArrayParentReduceRule<Child, Parent> + 'static,
181    {
182        let key = (rule.child().key(), rule.parent().key());
183        let adapter = ParentReduceRuleAdapter::new(rule);
184        self.parent_rules
185            .entry(key)
186            .or_default()
187            .push(Arc::new(adapter));
188    }
189
190    /// Register a parent rule that matches ANY parent type (wildcard).
191    pub fn register_any_parent_rule<Child, R>(&mut self, rule: R)
192    where
193        Child: Matcher,
194        R: ArrayParentReduceRule<Child, AnyArray> + 'static,
195    {
196        let key = (rule.child().key(), MatchKey::Any);
197        let adapter = ParentReduceRuleAdapter::new(rule);
198        self.parent_rules
199            .entry(key)
200            .or_default()
201            .push(Arc::new(adapter));
202    }
203
204    /// Execute a callback with all reduce rules for a given encoding ID.
205    pub(crate) fn with_reduce_rules<F, R>(&self, array: &ArrayRef, f: F) -> R
206    where
207        F: FnOnce(&mut dyn Iterator<Item = &dyn DynArrayReduceRule>) -> R,
208    {
209        let exact = self.reduce_rules.get(&MatchKey::Array(array.encoding_id()));
210        let any = self.reduce_rules.get(&MatchKey::Any);
211        f(&mut exact
212            .iter()
213            .chain(any.iter())
214            .flat_map(|v| v.iter())
215            .map(|v| v.as_ref()))
216    }
217
218    /// Execute a callback with all parent reduce rules for a given child and parent encoding ID.
219    ///
220    /// Returns rules from both specific parent rules (if parent_id provided) and "any parent" wildcard rules.
221    pub(crate) fn with_parent_rules<F, R>(
222        &self,
223        child: &ArrayRef,
224        parent: Option<&ArrayRef>,
225        f: F,
226    ) -> R
227    where
228        F: FnOnce(&mut dyn Iterator<Item = &dyn DynArrayParentReduceRule>) -> R,
229    {
230        let exact = parent.and_then(|parent| {
231            self.parent_rules.get(&(
232                MatchKey::Array(child.encoding_id()),
233                MatchKey::Array(parent.encoding_id()),
234            ))
235        });
236        let any = self
237            .parent_rules
238            .get(&(MatchKey::Array(child.encoding_id()), MatchKey::Any));
239
240        f(&mut exact
241            .iter()
242            .chain(any.iter())
243            .flat_map(|v| v.iter())
244            .map(|arc| arc.as_ref()))
245    }
246}