vortex_array/expr/session/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod rewrite;
5
6pub use rewrite::RewriteRuleRegistry;
7use vortex_session::Ref;
8use vortex_session::SessionExt;
9use vortex_session::registry::Registry;
10
11use crate::expr::ExprVTable;
12use crate::expr::VTable;
13use crate::expr::exprs::between::Between;
14use crate::expr::exprs::binary::Binary;
15use crate::expr::exprs::cast::Cast;
16use crate::expr::exprs::get_item::GetItem;
17use crate::expr::exprs::get_item::transform::PackGetItemRule;
18use crate::expr::exprs::is_null::IsNull;
19use crate::expr::exprs::like::Like;
20use crate::expr::exprs::list_contains::ListContains;
21use crate::expr::exprs::literal::Literal;
22use crate::expr::exprs::merge::Merge;
23use crate::expr::exprs::merge::transform::RemoveMergeRule;
24use crate::expr::exprs::not::Not;
25use crate::expr::exprs::pack::Pack;
26use crate::expr::exprs::root::Root;
27use crate::expr::exprs::select::Select;
28use crate::expr::exprs::select::transform::RemoveSelectRule;
29use crate::expr::transform::rules::Any;
30use crate::expr::transform::rules::Exact;
31use crate::expr::transform::rules::ParentReduceRule;
32use crate::expr::transform::rules::ReduceRule;
33use crate::expr::transform::rules::RuleContext;
34use crate::expr::transform::rules::TypedRuleContext;
35
36/// Registry of expression vtables.
37pub type ExprRegistry = Registry<ExprVTable>;
38
39/// Session state for expression vtables and rewrite rules.
40#[derive(Debug)]
41pub struct ExprSession {
42    registry: ExprRegistry,
43    rewrite_rules: RewriteRuleRegistry,
44}
45
46impl ExprSession {
47    pub fn registry(&self) -> &ExprRegistry {
48        &self.registry
49    }
50
51    /// Get the rewrite rule registry.
52    pub fn rewrite_rules(&self) -> &RewriteRuleRegistry {
53        &self.rewrite_rules
54    }
55
56    /// Register an expression vtable in the session, replacing any existing vtable with the same ID.
57    pub fn register(&self, expr: ExprVTable) {
58        self.registry.register(expr)
59    }
60
61    /// Register expression vtables in the session, replacing any existing vtables with the same IDs.
62    pub fn register_many(&self, exprs: impl IntoIterator<Item = ExprVTable>) {
63        self.registry.register_many(exprs);
64    }
65
66    /// Register a generic reduce rule that uses Typed context.
67    /// Use this for rules that need access to dtype information.
68    pub fn register_typed_reduce_rule<V, R>(&mut self, vtable: &'static V, rule: R)
69    where
70        V: VTable,
71        R: 'static,
72        R: ReduceRule<V, TypedRuleContext>,
73    {
74        self.rewrite_rules
75            .register_typed_reduce_rule::<V, R>(vtable, rule);
76    }
77
78    /// Register a reduce rule that uses Untyped context.
79    /// Use this for rules that don't need access to dtype information.
80    pub fn register_reduce_rule<V, R>(&mut self, vtable: &'static V, rule: R)
81    where
82        V: VTable,
83        R: 'static,
84        R: ReduceRule<V, RuleContext>,
85    {
86        self.rewrite_rules
87            .register_reduce_rule::<V, R>(vtable, rule);
88    }
89
90    /// Register a parent reduce rule for a specific parent type.
91    pub fn register_parent_rule<Child, Parent, R>(
92        &mut self,
93        child_vtable: &'static Child,
94        parent_vtable: &'static Parent,
95        rule: R,
96    ) where
97        Child: VTable,
98        Parent: VTable,
99        R: 'static,
100        R: ParentReduceRule<Child, Exact<Parent>, RuleContext>,
101    {
102        self.rewrite_rules
103            .register_parent_rule_specific::<Child, Parent, R>(child_vtable, parent_vtable, rule);
104    }
105
106    /// Register a parent rule that matches ANY parent type (wildcard).
107    pub fn register_any_parent_rule<Child, R>(&mut self, child_vtable: &'static Child, rule: R)
108    where
109        Child: VTable,
110        R: 'static,
111        R: ParentReduceRule<Child, Any, RuleContext>,
112    {
113        self.rewrite_rules
114            .register_parent_rule_any::<Child, R>(child_vtable, rule);
115    }
116
117    /// Register a typed parent reduce rule for a specific parent type.
118    pub fn register_typed_parent_rule<Child, Parent, R>(
119        &mut self,
120        child_vtable: &'static Child,
121        parent_vtable: &'static Parent,
122        rule: R,
123    ) where
124        Child: VTable,
125        Parent: VTable,
126        R: 'static,
127        R: ParentReduceRule<Child, Exact<Parent>, TypedRuleContext>,
128    {
129        self.rewrite_rules
130            .register_typed_parent_rule_specific::<Child, Parent, R>(
131                child_vtable,
132                parent_vtable,
133                rule,
134            );
135    }
136
137    /// Register a typed parent rule that matches ANY parent type (wildcard).
138    pub fn register_typed_any_parent_rule<Child, R>(
139        &mut self,
140        child_vtable: &'static Child,
141        rule: R,
142    ) where
143        Child: VTable,
144        R: 'static,
145        R: ParentReduceRule<Child, Any, TypedRuleContext>,
146    {
147        self.rewrite_rules
148            .register_typed_parent_rule_any::<Child, R>(child_vtable, rule);
149    }
150}
151
152impl Default for ExprSession {
153    fn default() -> Self {
154        let expressions = ExprRegistry::default();
155
156        // Register built-in expressions here if needed.
157        expressions.register_many([
158            ExprVTable::new_static(&Between),
159            ExprVTable::new_static(&Binary),
160            ExprVTable::new_static(&Cast),
161            ExprVTable::new_static(&GetItem),
162            ExprVTable::new_static(&IsNull),
163            ExprVTable::new_static(&Like),
164            ExprVTable::new_static(&ListContains),
165            ExprVTable::new_static(&Literal),
166            ExprVTable::new_static(&Merge),
167            ExprVTable::new_static(&Not),
168            ExprVTable::new_static(&Pack),
169            ExprVTable::new_static(&Root),
170            ExprVTable::new_static(&Select),
171        ]);
172
173        // Register built-in rewrite rules
174        let mut rewrite_rules = RewriteRuleRegistry::new();
175        rewrite_rules
176            .register_typed_reduce_rule::<Select, RemoveSelectRule>(&Select, RemoveSelectRule);
177        rewrite_rules.register_typed_reduce_rule::<Merge, RemoveMergeRule>(&Merge, RemoveMergeRule);
178        rewrite_rules.register_reduce_rule::<GetItem, PackGetItemRule>(&GetItem, PackGetItemRule);
179
180        Self {
181            registry: expressions,
182            rewrite_rules,
183        }
184    }
185}
186
187/// Extension trait for accessing expression session data.
188pub trait ExprSessionExt: SessionExt {
189    /// Returns the expression vtable registry.
190    fn expressions(&self) -> Ref<'_, ExprSession> {
191        self.get::<ExprSession>()
192    }
193}
194impl<S: SessionExt> ExprSessionExt for S {}