Skip to main content

vortex_array/optimizer/
rules.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::type_name;
5use std::fmt::Debug;
6use std::marker::PhantomData;
7
8use vortex_error::VortexResult;
9
10use crate::array::ArrayRef;
11use crate::matcher::Matcher;
12use crate::vtable::VTable;
13
14/// A rewrite rule that transforms arrays based on their own content
15pub trait ArrayReduceRule<V: VTable>: Debug + Send + Sync + 'static {
16    /// Attempt to rewrite this array.
17    ///
18    /// Returns:
19    /// - `Ok(Some(new_array))` if the rule applied successfully
20    /// - `Ok(None)` if the rule doesn't apply
21    /// - `Err(e)` if an error occurred
22    fn reduce(&self, array: &V::Array) -> VortexResult<Option<ArrayRef>>;
23}
24
25/// A rewrite rule that transforms arrays based on parent context
26pub trait ArrayParentReduceRule<V: VTable>: Debug + Send + Sync + 'static {
27    type Parent: Matcher;
28
29    /// Attempt to rewrite this child array given information about its parent.
30    ///
31    /// Returns:
32    /// - `Ok(Some(new_array))` if the rule applied successfully
33    /// - `Ok(None)` if the rule doesn't apply
34    /// - `Err(e)` if an error occurred
35    fn reduce_parent(
36        &self,
37        array: &V::Array,
38        parent: <Self::Parent as Matcher>::Match<'_>,
39        child_idx: usize,
40    ) -> VortexResult<Option<ArrayRef>>;
41}
42
43/// Dynamic trait for array parent reduce rules
44pub trait DynArrayParentReduceRule<V: VTable>: Debug + Send + Sync {
45    fn matches(&self, parent: &ArrayRef) -> bool;
46
47    fn reduce_parent(
48        &self,
49        array: &V::Array,
50        parent: &ArrayRef,
51        child_idx: usize,
52    ) -> VortexResult<Option<ArrayRef>>;
53}
54
/// Adapter exposing a statically-typed [`ArrayParentReduceRule`] as a
/// [`DynArrayParentReduceRule`].
///
/// `#[repr(transparent)]` guarantees this type has exactly the size, alignment and ABI of `R`
/// (`PhantomData<V>` is zero-sized with alignment 1). [`ParentRuleSet::lift`] relies on that
/// layout relationship when it reinterprets a `&R` as a `&ParentReduceRuleAdapter<V, R>`.
#[repr(transparent)]
pub struct ParentReduceRuleAdapter<V, R> {
    /// The wrapped rule.
    rule: R,
    /// Ties the adapter to the child encoding `V` without storing a value of it.
    _phantom: PhantomData<V>,
}
60
61impl<V: VTable, R: ArrayParentReduceRule<V>> Debug for ParentReduceRuleAdapter<V, R> {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        f.debug_struct("ArrayParentReduceRuleAdapter")
64            .field("parent", &type_name::<R::Parent>())
65            .field("rule", &self.rule)
66            .finish()
67    }
68}
69
70impl<V: VTable, K: ArrayParentReduceRule<V>> DynArrayParentReduceRule<V>
71    for ParentReduceRuleAdapter<V, K>
72{
73    fn matches(&self, parent: &ArrayRef) -> bool {
74        K::Parent::matches(parent)
75    }
76
77    fn reduce_parent(
78        &self,
79        child: &V::Array,
80        parent: &ArrayRef,
81        child_idx: usize,
82    ) -> VortexResult<Option<ArrayRef>> {
83        let Some(parent_view) = K::Parent::try_match(parent) else {
84            return Ok(None);
85        };
86        self.rule.reduce_parent(child, parent_view, child_idx)
87    }
88}
89
90pub struct ReduceRuleSet<V: VTable> {
91    rules: &'static [&'static dyn ArrayReduceRule<V>],
92}
93
94impl<V: VTable> ReduceRuleSet<V> {
95    /// Create a new reduction rule set with the given rules.
96    pub const fn new(rules: &'static [&'static dyn ArrayReduceRule<V>]) -> Self {
97        Self { rules }
98    }
99
100    /// Evaluate the reduction rules on the given array.
101    pub fn evaluate(&self, array: &V::Array) -> VortexResult<Option<ArrayRef>> {
102        for rule in self.rules.iter() {
103            if let Some(reduced) = rule.reduce(array)? {
104                return Ok(Some(reduced));
105            }
106        }
107        Ok(None)
108    }
109}
110
111/// A set of parent reduction rules for a specific child array encoding.
112pub struct ParentRuleSet<V: VTable> {
113    rules: &'static [&'static dyn DynArrayParentReduceRule<V>],
114}
115
116impl<V: VTable> ParentRuleSet<V> {
117    /// Create a new parent rule set with the given rules.
118    ///
119    /// Use [`ParentRuleSet::lift`] to lift static rules into dynamic trait objects.
120    pub const fn new(rules: &'static [&'static dyn DynArrayParentReduceRule<V>]) -> Self {
121        Self { rules }
122    }
123
124    /// Lift the given rule into a dynamic trait object.
125    pub const fn lift<R: ArrayParentReduceRule<V>>(
126        rule: &'static R,
127    ) -> &'static dyn DynArrayParentReduceRule<V> {
128        // Assert that self is zero-sized
129        const {
130            assert!(
131                !(size_of::<R>() != 0),
132                "Rule must be zero-sized to be lifted"
133            );
134        }
135        unsafe { &*(rule as *const R as *const ParentReduceRuleAdapter<V, R>) }
136    }
137
138    /// Evaluate the parent reduction rules on the given child and parent arrays.
139    pub fn evaluate(
140        &self,
141        child: &V::Array,
142        parent: &ArrayRef,
143        child_idx: usize,
144    ) -> VortexResult<Option<ArrayRef>> {
145        for rule in self.rules.iter() {
146            if !rule.matches(parent) {
147                continue;
148            }
149            if let Some(reduced) = rule.reduce_parent(child, parent, child_idx)? {
150                // Debug assertions because these checks are already run elsewhere.
151                #[cfg(debug_assertions)]
152                {
153                    vortex_error::vortex_ensure!(
154                        reduced.len() == parent.len(),
155                        "Reduced array length mismatch from {:?}\nFrom:\n{}\nTo:\n{}",
156                        rule,
157                        parent.encoding_id(),
158                        reduced.encoding_id()
159                    );
160                    vortex_error::vortex_ensure!(
161                        reduced.dtype() == parent.dtype(),
162                        "Reduced array dtype mismatch from {:?}\nFrom:\n{}\nTo:\n{}",
163                        rule,
164                        parent.encoding_id(),
165                        reduced.encoding_id()
166                    );
167                }
168
169                return Ok(Some(reduced));
170            }
171        }
172        Ok(None)
173    }
174}