1use std::fmt::Debug;
5use std::marker::PhantomData;
6use std::sync::Arc;
7
8use vortex_error::VortexResult;
9use vortex_utils::aliases::dash_map::DashMap;
10
11use crate::expr::ExprId;
12use crate::expr::Expression;
13use crate::expr::VTable;
14use crate::expr::transform::rules::Any;
15use crate::expr::transform::rules::DynParentReduceRule;
16use crate::expr::transform::rules::DynReduceRule;
17use crate::expr::transform::rules::DynTypedParentReduceRule;
18use crate::expr::transform::rules::DynTypedReduceRule;
19use crate::expr::transform::rules::Exact;
20use crate::expr::transform::rules::Matcher;
21use crate::expr::transform::rules::ParentReduceRule;
22use crate::expr::transform::rules::ReduceRule;
23use crate::expr::transform::rules::RuleContext;
24use crate::expr::transform::rules::TypedRuleContext;
25
26struct ReduceRuleAdapter<V: VTable, R> {
28 rule: R,
29 _phantom: PhantomData<V>,
30}
31
32impl<V: VTable, R: Debug> Debug for ReduceRuleAdapter<V, R> {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("ReduceRuleAdapter")
35 .field("rule", &self.rule)
36 .finish()
37 }
38}
39
40struct ReduceParentRuleAdapter<Child: VTable, Parent: Matcher, R> {
42 rule: R,
43 _phantom: PhantomData<(Child, Parent)>,
44}
45
46impl<Child: VTable, Parent: Matcher, R: Debug> Debug for ReduceParentRuleAdapter<Child, Parent, R> {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 f.debug_struct("ReduceParentRuleAdapter")
49 .field("rule", &self.rule)
50 .finish()
51 }
52}
53
54impl<V, R> DynReduceRule for ReduceRuleAdapter<V, R>
55where
56 V: VTable,
57 R: Debug + Send + Sync + 'static + ReduceRule<V, RuleContext>,
58{
59 fn reduce(&self, expr: &Expression, ctx: &RuleContext) -> VortexResult<Option<Expression>> {
60 let Some(view) = expr.as_opt::<V>() else {
61 return Ok(None);
62 };
63 self.rule.reduce(&view, ctx)
64 }
65}
66
67impl<V, R> DynTypedReduceRule for ReduceRuleAdapter<V, R>
68where
69 V: VTable,
70 R: Debug + Send + Sync + 'static + ReduceRule<V, TypedRuleContext>,
71{
72 fn reduce(
73 &self,
74 expr: &Expression,
75 ctx: &TypedRuleContext,
76 ) -> VortexResult<Option<Expression>> {
77 let Some(view) = expr.as_opt::<V>() else {
78 return Ok(None);
79 };
80 self.rule.reduce(&view, ctx)
81 }
82}
83
84impl<Child, Parent, R> DynParentReduceRule for ReduceParentRuleAdapter<Child, Parent, R>
85where
86 Child: VTable,
87 Parent: Matcher,
88 R: Debug + Send + Sync + 'static + ParentReduceRule<Child, Parent, RuleContext>,
89{
90 fn reduce_parent(
91 &self,
92 expr: &Expression,
93 parent: &Expression,
94 child_idx: usize,
95 ctx: &RuleContext,
96 ) -> VortexResult<Option<Expression>> {
97 let Some(view) = expr.as_opt::<Child>() else {
98 return Ok(None);
99 };
100 let Some(parent_view) = Parent::try_match(parent) else {
101 return Ok(None);
102 };
103 self.rule.reduce_parent(&view, parent_view, child_idx, ctx)
104 }
105}
106
107impl<Child, Parent, R> DynTypedParentReduceRule for ReduceParentRuleAdapter<Child, Parent, R>
108where
109 Child: VTable,
110 Parent: Matcher,
111 R: Debug + Send + Sync + 'static + ParentReduceRule<Child, Parent, TypedRuleContext>,
112{
113 fn reduce_parent(
114 &self,
115 expr: &Expression,
116 parent: &Expression,
117 child_idx: usize,
118 ctx: &TypedRuleContext,
119 ) -> VortexResult<Option<Expression>> {
120 let Some(view) = expr.as_opt::<Child>() else {
121 return Ok(None);
122 };
123 let Some(parent_view) = Parent::try_match(parent) else {
124 return Ok(None);
125 };
126 self.rule.reduce_parent(&view, parent_view, child_idx, ctx)
127 }
128}
129
130type RuleRegistry<Rule> = DashMap<ExprId, Vec<Arc<Rule>>>;
131type ParentRuleRegistry<Rule> = DashMap<(ExprId, ExprId), Vec<Arc<Rule>>>;
132
133#[derive(Default, Debug)]
136struct RewriteRuleRegistryInner {
137 typed_reduce_rules: RuleRegistry<dyn DynTypedReduceRule>,
139 reduce_rules: RuleRegistry<dyn DynReduceRule>,
141 typed_parent_rules: ParentRuleRegistry<dyn DynTypedParentReduceRule>,
143 parent_rules: ParentRuleRegistry<dyn DynParentReduceRule>,
145 typed_any_parent_rules: RuleRegistry<dyn DynTypedParentReduceRule>,
147 any_parent_rules: RuleRegistry<dyn DynParentReduceRule>,
149}
150
151#[derive(Clone, Debug)]
156pub struct RewriteRuleRegistry {
157 inner: Arc<RewriteRuleRegistryInner>,
158}
159
160impl Default for RewriteRuleRegistry {
161 fn default() -> Self {
162 Self {
163 inner: Arc::new(RewriteRuleRegistryInner::default()),
164 }
165 }
166}
167
168impl RewriteRuleRegistry {
169 pub fn new() -> Self {
170 Self::default()
171 }
172
173 pub fn register_typed_reduce_rule<V, R>(&mut self, vtable: &'static V, rule: R)
176 where
177 V: VTable,
178 R: 'static + ReduceRule<V, TypedRuleContext>,
179 {
180 let adapter = ReduceRuleAdapter {
181 rule,
182 _phantom: PhantomData,
183 };
184 self.inner
185 .typed_reduce_rules
186 .entry(vtable.id())
187 .or_default()
188 .push(Arc::new(adapter));
189 }
190
191 pub fn register_reduce_rule<V, R>(&mut self, vtable: &'static V, rule: R)
194 where
195 V: VTable,
196 R: 'static + ReduceRule<V, RuleContext>,
197 {
198 let adapter = ReduceRuleAdapter {
199 rule,
200 _phantom: PhantomData,
201 };
202 self.inner
203 .reduce_rules
204 .entry(vtable.id())
205 .or_default()
206 .push(Arc::new(adapter));
207 }
208
209 pub fn register_parent_rule_specific<Child, Parent, R>(
211 &mut self,
212 child_vtable: &'static Child,
213 parent_vtable: &'static Parent,
214 rule: R,
215 ) where
216 Child: VTable,
217 Parent: VTable,
218 R: 'static + ParentReduceRule<Child, Exact<Parent>, RuleContext>,
219 {
220 let adapter = ReduceParentRuleAdapter {
221 rule,
222 _phantom: PhantomData,
223 };
224 self.inner
225 .parent_rules
226 .entry((child_vtable.id(), parent_vtable.id()))
227 .or_default()
228 .push(Arc::new(adapter));
229 }
230
231 pub fn register_parent_rule_any<Child, R>(&mut self, child_vtable: &'static Child, rule: R)
233 where
234 Child: VTable,
235 R: 'static + ParentReduceRule<Child, Any, RuleContext>,
236 {
237 let adapter = ReduceParentRuleAdapter {
238 rule,
239 _phantom: PhantomData,
240 };
241 self.inner
242 .any_parent_rules
243 .entry(child_vtable.id())
244 .or_default()
245 .push(Arc::new(adapter));
246 }
247
248 pub fn register_typed_parent_rule_specific<Child, Parent, R>(
250 &mut self,
251 child_vtable: &'static Child,
252 parent_vtable: &'static Parent,
253 rule: R,
254 ) where
255 Child: VTable,
256 Parent: VTable,
257 R: 'static + ParentReduceRule<Child, Exact<Parent>, TypedRuleContext>,
258 {
259 let adapter = ReduceParentRuleAdapter {
260 rule,
261 _phantom: PhantomData,
262 };
263 self.inner
264 .typed_parent_rules
265 .entry((child_vtable.id(), parent_vtable.id()))
266 .or_default()
267 .push(Arc::new(adapter));
268 }
269
270 pub fn register_typed_parent_rule_any<Child, R>(
272 &mut self,
273 child_vtable: &'static Child,
274 rule: R,
275 ) where
276 Child: VTable,
277 R: 'static + ParentReduceRule<Child, Any, TypedRuleContext>,
278 {
279 let adapter = ReduceParentRuleAdapter {
280 rule,
281 _phantom: PhantomData,
282 };
283 self.inner
284 .typed_any_parent_rules
285 .entry(child_vtable.id())
286 .or_default()
287 .push(Arc::new(adapter));
288 }
289
290 pub(crate) fn with_typed_reduce_rules<F, R>(&self, id: &ExprId, f: F) -> R
292 where
293 F: FnOnce(&mut dyn Iterator<Item = &dyn DynTypedReduceRule>) -> R,
294 {
295 f(&mut self
296 .inner
297 .typed_reduce_rules
298 .get(id)
299 .iter()
300 .flat_map(|v| v.value())
301 .map(|arc| arc.as_ref()))
302 }
303
304 pub(crate) fn with_reduce_rules<F, R>(&self, id: &ExprId, f: F) -> R
306 where
307 F: FnOnce(&mut dyn Iterator<Item = &dyn DynReduceRule>) -> R,
308 {
309 f(&mut self
310 .inner
311 .reduce_rules
312 .get(id)
313 .iter()
314 .flat_map(|v| v.value())
315 .map(|arc| arc.as_ref()))
316 }
317
318 pub(crate) fn with_parent_rules<F, R>(
322 &self,
323 child_id: &ExprId,
324 parent_id: Option<&ExprId>,
325 f: F,
326 ) -> R
327 where
328 F: FnOnce(&mut dyn Iterator<Item = &dyn DynParentReduceRule>) -> R,
329 {
330 let specific_entry = parent_id.and_then(|pid| {
331 self.inner
332 .parent_rules
333 .get(&(child_id.clone(), pid.clone()))
334 });
335 let wildcard_entry = self.inner.any_parent_rules.get(child_id);
336
337 f(&mut specific_entry
338 .iter()
339 .flat_map(|v| v.value())
340 .chain(wildcard_entry.iter().flat_map(|v| v.value()))
341 .map(|arc| arc.as_ref()))
342 }
343
344 pub(crate) fn with_typed_parent_rules<F, R>(
348 &self,
349 child_id: &ExprId,
350 parent_id: Option<&ExprId>,
351 f: F,
352 ) -> R
353 where
354 F: FnOnce(&mut dyn Iterator<Item = &dyn DynTypedParentReduceRule>) -> R,
355 {
356 let specific_entry = parent_id.and_then(|pid| {
357 self.inner
358 .typed_parent_rules
359 .get(&(child_id.clone(), pid.clone()))
360 });
361 let wildcard_entry = self.inner.typed_any_parent_rules.get(child_id);
362
363 f(&mut specific_entry
364 .iter()
365 .flat_map(|v| v.value())
366 .chain(wildcard_entry.iter().flat_map(|v| v.value()))
367 .map(|arc| arc.as_ref()))
368 }
369}