vortex_array/expr/transform/rules.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Traits for extensible expression rewrite rules.
5//!
6//! These traits allow external crates to define custom expression optimization rules
7//! that can be registered with the expression session.
8
9use std::fmt::Debug;
10use std::marker::PhantomData;
11
12use vortex_dtype::DType;
13use vortex_error::VortexResult;
14
15use crate::expr::Expression;
16use crate::expr::ExpressionView;
17use crate::expr::VTable;
18
19/// Trait that abstracts over matching on expression types.
20pub trait Matcher: Send + Sync + 'static {
21 /// The view type returned when matching succeeds.
22 type View<'a>;
23
24 /// Try to match/downcast the parent expression.
25 /// Returns Some if the parent matches this matcher's criteria, None otherwise.
26 fn try_match(parent: &Expression) -> Option<Self::View<'_>>;
27}
28
29/// Marker type representing "any" - matches all expressions.
30#[derive(Debug)]
31pub struct Any;
32impl Matcher for Any {
33 type View<'a> = &'a Expression;
34
35 fn try_match(parent: &Expression) -> Option<Self::View<'_>> {
36 Some(parent)
37 }
38}
39
40/// Marker type representing a specific VTable type as a matcher.
41#[derive(Debug)]
42pub struct Exact<V: VTable>(PhantomData<V>);
43impl<V: VTable> Matcher for Exact<V> {
44 type View<'a> = ExpressionView<'a, V>;
45
46 fn try_match(parent: &Expression) -> Option<Self::View<'_>> {
47 parent.as_opt::<V>()
48 }
49}
50
51/// A rewrite rule that transforms expressions without needing context.
52///
53/// Called during bottom-up traversal after children have been processed.
54/// This is useful for self-contained rewrites like: `select(...) -> pack(get_item(...))`
55///
56/// # Type Parameters
57/// * `V` - The VTable type this rule applies to. The rule will only be invoked for expressions
58/// with this vtable type, providing compile-time type safety.
59pub trait ReduceRule<V: VTable, C: RewriteContext>: Debug + Send + Sync + 'static {
60 /// Try to rewrite an expression.
61 ///
62 /// # Arguments
63 /// * `expr` - The expression to potentially rewrite (already downcast to type V)
64 /// * `ctx` - Context for the rewrite (dtype, etc.)
65 ///
66 /// # Returns
67 /// * `Some(new_expr)` if the rule applies and produces a rewritten expression
68 /// * `None` if the rule does not apply
69 fn reduce(&self, expr: &ExpressionView<V>, ctx: &C) -> VortexResult<Option<Expression>>;
70}
71
72/// A rewrite rule that can transform expressions based on parent context.
73///
74/// Called during top-down traversal from the root.
75/// This is useful for rules that need to know about the parent expression.
76///
77/// Note: This rule is only called for non-root expressions (i.e., when there is a parent).
78///
79/// # Type Parameters
80/// * `Child` - The VTable type this rule applies to (the child expression type). The rule will only
81/// be invoked for expressions with this vtable type, providing compile-time type safety.
82/// * `Parent` - The parent matcher. Can be a specific VTable type (e.g., `Binary`) for typed parent
83/// access, or `AnyParent` to match any parent type with untyped access.
84/// * `C` - The rewrite context type (RuleContext or TypedRuleContext)
85pub trait ParentReduceRule<Child: VTable, Parent: Matcher, C: RewriteContext>:
86 Debug + Send + Sync + 'static
87{
88 /// Try to rewrite an expression based on its parent.
89 ///
90 /// # Arguments
91 /// * `expr` - The expression to potentially rewrite (already downcast to type Child)
92 /// * `parent` - The parent view (type depends on Parent matcher - typed for specific VTables,
93 /// untyped `&Expression` for `AnyParent`)
94 /// * `child_idx` - The index of the child expression within the parent.
95 /// * `ctx` - Context for the rewrite (dtype, etc.)
96 ///
97 /// # Returns
98 /// * `Some(new_expr)` if the rule applies and produces a rewritten expression
99 /// * `None` if the rule does not apply
100 fn reduce_parent(
101 &self,
102 expr: &ExpressionView<Child>,
103 parent: Parent::View<'_>,
104 child_idx: usize,
105 ctx: &C,
106 ) -> VortexResult<Option<Expression>>;
107}
108
/// Sealed marker trait for the context passed to rewrite rules.
///
/// Its `private::Sealed` supertrait is declared in a private module, so this trait cannot be
/// implemented outside this module (the sealed-trait pattern). It is implemented by
/// [`TypedRuleContext`] and [`RuleContext`].
pub trait RewriteContext: private::Sealed {}

mod private {
    /// Sealing supertrait of `RewriteContext`. Because it lives in a private module, external
    /// code cannot name it and therefore cannot implement `RewriteContext`.
    pub trait Sealed {}
}
119
120/// Typed context for rewrite rules that need access to dtype information.
121#[derive(Debug, Clone)]
122pub struct TypedRuleContext {
123 /// This is the root dtype of the expression
124 dtype: DType,
125}
126
127impl TypedRuleContext {
128 pub fn new(dtype: DType) -> Self {
129 Self { dtype }
130 }
131
132 pub fn dtype(&self) -> &DType {
133 &self.dtype
134 }
135}
136
137impl private::Sealed for TypedRuleContext {}
138impl RewriteContext for TypedRuleContext {}
139
140/// A context for rewrite rules that don't need dtype information.
141#[derive(Debug, Clone, Copy, Default)]
142pub struct RuleContext;
143
144impl private::Sealed for RuleContext {}
145impl RewriteContext for RuleContext {}
146
147impl From<&TypedRuleContext> for RuleContext {
148 fn from(_value: &TypedRuleContext) -> Self {
149 RuleContext
150 }
151}
152
153/// Type-erased wrappers that allows dynamic dispatch.
154pub(crate) trait DynReduceRule: Debug + Send + Sync {
155 fn reduce(&self, expr: &Expression, ctx: &RuleContext) -> VortexResult<Option<Expression>>;
156}
157
158pub(crate) trait DynTypedReduceRule: Debug + Send + Sync {
159 fn reduce(&self, expr: &Expression, ctx: &TypedRuleContext)
160 -> VortexResult<Option<Expression>>;
161}
162
163pub(crate) trait DynParentReduceRule: Debug + Send + Sync {
164 fn reduce_parent(
165 &self,
166 expr: &Expression,
167 parent: &Expression,
168 child_idx: usize,
169 ctx: &RuleContext,
170 ) -> VortexResult<Option<Expression>>;
171}
172
173pub(crate) trait DynTypedParentReduceRule: Debug + Send + Sync {
174 fn reduce_parent(
175 &self,
176 expr: &Expression,
177 parent: &Expression,
178 child_idx: usize,
179 ctx: &TypedRuleContext,
180 ) -> VortexResult<Option<Expression>>;
181}