vortex_array/expr/session/
rewrite.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::marker::PhantomData;
6use std::sync::Arc;
7
8use vortex_error::VortexResult;
9use vortex_utils::aliases::dash_map::DashMap;
10
11use crate::expr::ExprId;
12use crate::expr::Expression;
13use crate::expr::VTable;
14use crate::expr::transform::rules::Any;
15use crate::expr::transform::rules::DynParentReduceRule;
16use crate::expr::transform::rules::DynReduceRule;
17use crate::expr::transform::rules::DynTypedParentReduceRule;
18use crate::expr::transform::rules::DynTypedReduceRule;
19use crate::expr::transform::rules::Exact;
20use crate::expr::transform::rules::Matcher;
21use crate::expr::transform::rules::ParentReduceRule;
22use crate::expr::transform::rules::ReduceRule;
23use crate::expr::transform::rules::RuleContext;
24use crate::expr::transform::rules::TypedRuleContext;
25
26/// Adapter for ReduceRule
27struct ReduceRuleAdapter<V: VTable, R> {
28    rule: R,
29    _phantom: PhantomData<V>,
30}
31
32impl<V: VTable, R: Debug> Debug for ReduceRuleAdapter<V, R> {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("ReduceRuleAdapter")
35            .field("rule", &self.rule)
36            .finish()
37    }
38}
39
40/// Adapter for ParentReduceRule
41struct ReduceParentRuleAdapter<Child: VTable, Parent: Matcher, R> {
42    rule: R,
43    _phantom: PhantomData<(Child, Parent)>,
44}
45
46impl<Child: VTable, Parent: Matcher, R: Debug> Debug for ReduceParentRuleAdapter<Child, Parent, R> {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        f.debug_struct("ReduceParentRuleAdapter")
49            .field("rule", &self.rule)
50            .finish()
51    }
52}
53
54impl<V, R> DynReduceRule for ReduceRuleAdapter<V, R>
55where
56    V: VTable,
57    R: Debug + Send + Sync + 'static + ReduceRule<V, RuleContext>,
58{
59    fn reduce(&self, expr: &Expression, ctx: &RuleContext) -> VortexResult<Option<Expression>> {
60        let Some(view) = expr.as_opt::<V>() else {
61            return Ok(None);
62        };
63        self.rule.reduce(&view, ctx)
64    }
65}
66
67impl<V, R> DynTypedReduceRule for ReduceRuleAdapter<V, R>
68where
69    V: VTable,
70    R: Debug + Send + Sync + 'static + ReduceRule<V, TypedRuleContext>,
71{
72    fn reduce(
73        &self,
74        expr: &Expression,
75        ctx: &TypedRuleContext,
76    ) -> VortexResult<Option<Expression>> {
77        let Some(view) = expr.as_opt::<V>() else {
78            return Ok(None);
79        };
80        self.rule.reduce(&view, ctx)
81    }
82}
83
84impl<Child, Parent, R> DynParentReduceRule for ReduceParentRuleAdapter<Child, Parent, R>
85where
86    Child: VTable,
87    Parent: Matcher,
88    R: Debug + Send + Sync + 'static + ParentReduceRule<Child, Parent, RuleContext>,
89{
90    fn reduce_parent(
91        &self,
92        expr: &Expression,
93        parent: &Expression,
94        child_idx: usize,
95        ctx: &RuleContext,
96    ) -> VortexResult<Option<Expression>> {
97        let Some(view) = expr.as_opt::<Child>() else {
98            return Ok(None);
99        };
100        let Some(parent_view) = Parent::try_match(parent) else {
101            return Ok(None);
102        };
103        self.rule.reduce_parent(&view, parent_view, child_idx, ctx)
104    }
105}
106
107impl<Child, Parent, R> DynTypedParentReduceRule for ReduceParentRuleAdapter<Child, Parent, R>
108where
109    Child: VTable,
110    Parent: Matcher,
111    R: Debug + Send + Sync + 'static + ParentReduceRule<Child, Parent, TypedRuleContext>,
112{
113    fn reduce_parent(
114        &self,
115        expr: &Expression,
116        parent: &Expression,
117        child_idx: usize,
118        ctx: &TypedRuleContext,
119    ) -> VortexResult<Option<Expression>> {
120        let Some(view) = expr.as_opt::<Child>() else {
121            return Ok(None);
122        };
123        let Some(parent_view) = Parent::try_match(parent) else {
124            return Ok(None);
125        };
126        self.rule.reduce_parent(&view, parent_view, child_idx, ctx)
127    }
128}
129
130type RuleRegistry<Rule> = DashMap<ExprId, Vec<Arc<Rule>>>;
131type ParentRuleRegistry<Rule> = DashMap<(ExprId, ExprId), Vec<Arc<Rule>>>;
132
133/// Inner struct that holds all the rule registries.
134/// Wrapped in a single Arc by RewriteRuleRegistry for efficient cloning.
135#[derive(Default, Debug)]
136struct RewriteRuleRegistryInner {
137    /// Typed reduce rules (require TypedRewriteContext), indexed by expression ID
138    typed_reduce_rules: RuleRegistry<dyn DynTypedReduceRule>,
139    /// Untyped reduce rules (require only RewriteContext), indexed by expression ID
140    reduce_rules: RuleRegistry<dyn DynReduceRule>,
141    /// Parent reduce rules for specific parent types, indexed by (child_id, parent_id)
142    typed_parent_rules: ParentRuleRegistry<dyn DynTypedParentReduceRule>,
143    /// Parent reduce rules for specific parent types, indexed by (child_id, parent_id)
144    parent_rules: ParentRuleRegistry<dyn DynParentReduceRule>,
145    /// Wildcard parent rules (match any parent), indexed by child_id only
146    typed_any_parent_rules: RuleRegistry<dyn DynTypedParentReduceRule>,
147    /// Wildcard parent rules (match any parent), indexed by child_id only
148    any_parent_rules: RuleRegistry<dyn DynParentReduceRule>,
149}
150
151/// Registry of expression rewrite rules.
152///
153/// Stores rewrite rules indexed by the expression ID they apply to.
154/// Typed and untyped rules are stored separately for better organization.
155#[derive(Clone, Debug)]
156pub struct RewriteRuleRegistry {
157    inner: Arc<RewriteRuleRegistryInner>,
158}
159
160impl Default for RewriteRuleRegistry {
161    fn default() -> Self {
162        Self {
163            inner: Arc::new(RewriteRuleRegistryInner::default()),
164        }
165    }
166}
167
168impl RewriteRuleRegistry {
169    pub fn new() -> Self {
170        Self::default()
171    }
172
173    /// Register a generic reduce rule that uses Typed context.
174    /// Use this for rules that need access to dtype information.
175    pub fn register_typed_reduce_rule<V, R>(&mut self, vtable: &'static V, rule: R)
176    where
177        V: VTable,
178        R: 'static + ReduceRule<V, TypedRuleContext>,
179    {
180        let adapter = ReduceRuleAdapter {
181            rule,
182            _phantom: PhantomData,
183        };
184        self.inner
185            .typed_reduce_rules
186            .entry(vtable.id())
187            .or_default()
188            .push(Arc::new(adapter));
189    }
190
191    /// Register a generic reduce rule that only uses Untyped context.
192    /// Use this for rules that don't need access to dtype information.
193    pub fn register_reduce_rule<V, R>(&mut self, vtable: &'static V, rule: R)
194    where
195        V: VTable,
196        R: 'static + ReduceRule<V, RuleContext>,
197    {
198        let adapter = ReduceRuleAdapter {
199            rule,
200            _phantom: PhantomData,
201        };
202        self.inner
203            .reduce_rules
204            .entry(vtable.id())
205            .or_default()
206            .push(Arc::new(adapter));
207    }
208
209    /// Register a parent rule for a specific parent type.
210    pub fn register_parent_rule_specific<Child, Parent, R>(
211        &mut self,
212        child_vtable: &'static Child,
213        parent_vtable: &'static Parent,
214        rule: R,
215    ) where
216        Child: VTable,
217        Parent: VTable,
218        R: 'static + ParentReduceRule<Child, Exact<Parent>, RuleContext>,
219    {
220        let adapter = ReduceParentRuleAdapter {
221            rule,
222            _phantom: PhantomData,
223        };
224        self.inner
225            .parent_rules
226            .entry((child_vtable.id(), parent_vtable.id()))
227            .or_default()
228            .push(Arc::new(adapter));
229    }
230
231    /// Register a parent rule that matches ANY parent type (wildcard).
232    pub fn register_parent_rule_any<Child, R>(&mut self, child_vtable: &'static Child, rule: R)
233    where
234        Child: VTable,
235        R: 'static + ParentReduceRule<Child, Any, RuleContext>,
236    {
237        let adapter = ReduceParentRuleAdapter {
238            rule,
239            _phantom: PhantomData,
240        };
241        self.inner
242            .any_parent_rules
243            .entry(child_vtable.id())
244            .or_default()
245            .push(Arc::new(adapter));
246    }
247
248    /// Register a typed parent rule for a specific parent type.
249    pub fn register_typed_parent_rule_specific<Child, Parent, R>(
250        &mut self,
251        child_vtable: &'static Child,
252        parent_vtable: &'static Parent,
253        rule: R,
254    ) where
255        Child: VTable,
256        Parent: VTable,
257        R: 'static + ParentReduceRule<Child, Exact<Parent>, TypedRuleContext>,
258    {
259        let adapter = ReduceParentRuleAdapter {
260            rule,
261            _phantom: PhantomData,
262        };
263        self.inner
264            .typed_parent_rules
265            .entry((child_vtable.id(), parent_vtable.id()))
266            .or_default()
267            .push(Arc::new(adapter));
268    }
269
270    /// Register a typed parent rule that matches ANY parent type (wildcard).
271    pub fn register_typed_parent_rule_any<Child, R>(
272        &mut self,
273        child_vtable: &'static Child,
274        rule: R,
275    ) where
276        Child: VTable,
277        R: 'static + ParentReduceRule<Child, Any, TypedRuleContext>,
278    {
279        let adapter = ReduceParentRuleAdapter {
280            rule,
281            _phantom: PhantomData,
282        };
283        self.inner
284            .typed_any_parent_rules
285            .entry(child_vtable.id())
286            .or_default()
287            .push(Arc::new(adapter));
288    }
289
290    /// Execute a callback with all typed reduce rules for a given expression ID.
291    pub(crate) fn with_typed_reduce_rules<F, R>(&self, id: &ExprId, f: F) -> R
292    where
293        F: FnOnce(&mut dyn Iterator<Item = &dyn DynTypedReduceRule>) -> R,
294    {
295        f(&mut self
296            .inner
297            .typed_reduce_rules
298            .get(id)
299            .iter()
300            .flat_map(|v| v.value())
301            .map(|arc| arc.as_ref()))
302    }
303
304    /// Execute a callback with all untyped reduce rules for a given expression ID.
305    pub(crate) fn with_reduce_rules<F, R>(&self, id: &ExprId, f: F) -> R
306    where
307        F: FnOnce(&mut dyn Iterator<Item = &dyn DynReduceRule>) -> R,
308    {
309        f(&mut self
310            .inner
311            .reduce_rules
312            .get(id)
313            .iter()
314            .flat_map(|v| v.value())
315            .map(|arc| arc.as_ref()))
316    }
317
318    /// Execute a callback with all untyped parent reduce rules for a given child and parent expression ID.
319    ///
320    /// Returns rules from both specific parent rules (if parent_id provided) and "any parent" wildcard rules.
321    pub(crate) fn with_parent_rules<F, R>(
322        &self,
323        child_id: &ExprId,
324        parent_id: Option<&ExprId>,
325        f: F,
326    ) -> R
327    where
328        F: FnOnce(&mut dyn Iterator<Item = &dyn DynParentReduceRule>) -> R,
329    {
330        let specific_entry = parent_id.and_then(|pid| {
331            self.inner
332                .parent_rules
333                .get(&(child_id.clone(), pid.clone()))
334        });
335        let wildcard_entry = self.inner.any_parent_rules.get(child_id);
336
337        f(&mut specific_entry
338            .iter()
339            .flat_map(|v| v.value())
340            .chain(wildcard_entry.iter().flat_map(|v| v.value()))
341            .map(|arc| arc.as_ref()))
342    }
343
344    /// Execute a callback with all typed parent reduce rules for a given child and parent expression ID.
345    ///
346    /// Returns rules from both specific parent rules (if parent_id provided) and "any parent" wildcard rules.
347    pub(crate) fn with_typed_parent_rules<F, R>(
348        &self,
349        child_id: &ExprId,
350        parent_id: Option<&ExprId>,
351        f: F,
352    ) -> R
353    where
354        F: FnOnce(&mut dyn Iterator<Item = &dyn DynTypedParentReduceRule>) -> R,
355    {
356        let specific_entry = parent_id.and_then(|pid| {
357            self.inner
358                .typed_parent_rules
359                .get(&(child_id.clone(), pid.clone()))
360        });
361        let wildcard_entry = self.inner.typed_any_parent_rules.get(child_id);
362
363        f(&mut specific_entry
364            .iter()
365            .flat_map(|v| v.value())
366            .chain(wildcard_entry.iter().flat_map(|v| v.value()))
367            .map(|arc| arc.as_ref()))
368    }
369}