vortex_array/optimizer/rules.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::type_name;
5use std::fmt::Debug;
6use std::marker::PhantomData;
7
8use vortex_error::VortexResult;
9
10use crate::array::ArrayRef;
11use crate::vtable::ArrayId;
12use crate::vtable::VTable;
13
14/// Trait for matching array types in optimizer rules
15pub trait Matcher: Send + Sync + 'static {
16    type View<'a>;
17
18    /// Return the key for this matcher
19    fn key(&self) -> MatchKey;
20
21    /// Try to match the given array to this matcher type
22    fn try_match<'a>(&self, array: &'a ArrayRef) -> Option<Self::View<'a>>;
23}
24
25#[derive(Clone, Debug, PartialEq, Eq, Hash)]
26pub enum MatchKey {
27    Any,
28    Array(ArrayId),
29}
30
31/// Matches any array type (wildcard matcher)
32#[derive(Debug)]
33pub struct AnyArray;
34impl Matcher for AnyArray {
35    type View<'a> = &'a ArrayRef;
36
37    fn key(&self) -> MatchKey {
38        MatchKey::Any
39    }
40
41    fn try_match<'a>(&self, array: &'a ArrayRef) -> Option<Self::View<'a>> {
42        Some(array)
43    }
44}
45
46/// Matches a specific Array by its encoding ID.
47#[derive(Debug)]
48pub struct Exact<V: VTable> {
49    id: ArrayId,
50    _phantom: PhantomData<V>,
51}
52impl<V: VTable> Matcher for Exact<V> {
53    type View<'a> = &'a V::Array;
54
55    fn key(&self) -> MatchKey {
56        MatchKey::Array(self.id.clone())
57    }
58
59    fn try_match<'a>(&self, parent: &'a ArrayRef) -> Option<Self::View<'a>> {
60        parent.as_opt::<V>()
61    }
62}
63impl<V: VTable> Exact<V> {
64    /// Create a new Exact matcher for the given ArrayId.
65    ///
66    /// # Safety
67    ///
68    /// The optimizer will attempt to downcast the array to the type V when matching.
69    /// If an array with the given ID does not match type V, the rule will silently not be applied.
70    pub unsafe fn new_unchecked(id: ArrayId) -> Self {
71        Self {
72            id,
73            _phantom: PhantomData,
74        }
75    }
76}
77impl<V: VTable> From<&'static V> for Exact<V> {
78    fn from(vtable: &'static V) -> Self {
79        Self {
80            id: vtable.id(),
81            _phantom: PhantomData,
82        }
83    }
84}
85
86/// A rewrite rule that transforms arrays based on the array itself and its children
87pub trait ArrayReduceRule<M: Matcher>: Debug + Send + Sync + 'static {
88    /// Returns the matcher for this rule
89    fn matcher(&self) -> M;
90
91    /// Attempt to rewrite this array.
92    ///
93    /// Returns:
94    /// - `Ok(Some(new_array))` if the rule applied successfully
95    /// - `Ok(None)` if the rule doesn't apply
96    /// - `Err(e)` if an error occurred
97    fn reduce(&self, array: M::View<'_>) -> VortexResult<Option<ArrayRef>>;
98}
99
100/// A rewrite rule that transforms arrays based on parent context
101pub trait ArrayParentReduceRule<Child: Matcher, Parent: Matcher>:
102    Debug + Send + Sync + 'static
103{
104    /// Returns the matcher for the child array
105    fn child(&self) -> Child;
106
107    /// Returns the matcher for the parent array
108    fn parent(&self) -> Parent;
109
110    /// Attempt to rewrite this child array given information about its parent.
111    ///
112    /// Returns:
113    /// - `Ok(Some(new_array))` if the rule applied successfully
114    /// - `Ok(None)` if the rule doesn't apply
115    /// - `Err(e)` if an error occurred
116    fn reduce_parent(
117        &self,
118        child: Child::View<'_>,
119        parent: Parent::View<'_>,
120        child_idx: usize,
121    ) -> VortexResult<Option<ArrayRef>>;
122}
123
124/// Dynamic trait for array reduce rules
125pub trait DynArrayReduceRule: Debug + Send + Sync {
126    fn key(&self) -> MatchKey;
127
128    fn reduce(&self, array: &ArrayRef) -> VortexResult<Option<ArrayRef>>;
129}
130
131/// Dynamic trait for array parent reduce rules
132pub trait DynArrayParentReduceRule: Debug + Send + Sync {
133    fn child_key(&self) -> MatchKey;
134
135    fn parent_key(&self) -> MatchKey;
136
137    fn reduce_parent(
138        &self,
139        array: &ArrayRef,
140        parent: &ArrayRef,
141        child_idx: usize,
142    ) -> VortexResult<Option<ArrayRef>>;
143}
144
/// Wrapper that lets a statically-typed [`ArrayReduceRule`] be used as a
/// [`DynArrayReduceRule`].
///
/// The matcher type `M` is only tracked at the type level via `PhantomData`.
pub(crate) struct ReduceRuleAdapter<M, R> {
    rule: R,
    _phantom: PhantomData<M>,
}

impl<M, R> ReduceRuleAdapter<M, R> {
    /// Wraps `rule` in a new adapter.
    pub(crate) fn new(rule: R) -> Self {
        let _phantom = PhantomData;
        Self { rule, _phantom }
    }
}
159
160impl<M: Matcher, R: ArrayReduceRule<M>> Debug for ReduceRuleAdapter<M, R> {
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        f.debug_struct("ArrayReduceRuleAdapter")
163            .field("matcher", &type_name::<M>())
164            .field("rule", &self.rule)
165            .finish()
166    }
167}
168
/// Wrapper that lets a statically-typed [`ArrayParentReduceRule`] be used as a
/// [`DynArrayParentReduceRule`].
///
/// The child and parent matcher types are only tracked at the type level via
/// `PhantomData`. Trait bounds live on the impls that need them, not on the struct; this
/// is consistent with `ReduceRuleAdapter` and the Rust API guidelines (C-STRUCT-BOUNDS).
pub(crate) struct ParentReduceRuleAdapter<Child, Parent, R> {
    rule: R,
    _phantom: PhantomData<(Child, Parent)>,
}

impl<Child, Parent, R> ParentReduceRuleAdapter<Child, Parent, R> {
    /// Wraps `rule` in a new adapter.
    pub(crate) fn new(rule: R) -> Self {
        Self {
            rule,
            _phantom: PhantomData,
        }
    }
}
183
184impl<Child: Matcher, Parent: Matcher, R: Debug> Debug
185    for ParentReduceRuleAdapter<Child, Parent, R>
186{
187    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188        f.debug_struct("ArrayParentReduceRuleAdapter")
189            .field("child", &type_name::<Child>())
190            .field("parent", &type_name::<Parent>())
191            .field("rule", &self.rule)
192            .finish()
193    }
194}
195
196impl<M: Matcher, R: ArrayReduceRule<M>> DynArrayReduceRule for ReduceRuleAdapter<M, R> {
197    fn key(&self) -> MatchKey {
198        self.rule.matcher().key()
199    }
200
201    fn reduce(&self, array: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
202        let Some(view) = self.rule.matcher().try_match(array) else {
203            return Ok(None);
204        };
205        self.rule.reduce(view)
206    }
207}
208
209impl<Child, Parent, R> DynArrayParentReduceRule for ParentReduceRuleAdapter<Child, Parent, R>
210where
211    Child: Matcher,
212    Parent: Matcher,
213    R: ArrayParentReduceRule<Child, Parent>,
214{
215    fn child_key(&self) -> MatchKey {
216        self.rule.child().key()
217    }
218
219    fn parent_key(&self) -> MatchKey {
220        self.rule.parent().key()
221    }
222
223    fn reduce_parent(
224        &self,
225        child: &ArrayRef,
226        parent: &ArrayRef,
227        child_idx: usize,
228    ) -> VortexResult<Option<ArrayRef>> {
229        let Some(child_view) = self.rule.child().try_match(child) else {
230            return Ok(None);
231        };
232        let Some(parent_view) = self.rule.parent().try_match(parent) else {
233            return Ok(None);
234        };
235        self.rule.reduce_parent(child_view, parent_view, child_idx)
236    }
237}