vortex_array/optimizer/
rules.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::type_name;
5use std::fmt::Debug;
6use std::marker::PhantomData;
7
8use vortex_error::VortexResult;
9
10use crate::array::ArrayRef;
11use crate::matchers::MatchKey;
12use crate::matchers::Matcher;
13use crate::vtable::VTable;
14
15/// A rewrite rule that transforms arrays based on their own content
16pub trait ArrayReduceRule<V: VTable>: Debug + Send + Sync + 'static {
17    /// Attempt to rewrite this array.
18    ///
19    /// Returns:
20    /// - `Ok(Some(new_array))` if the rule applied successfully
21    /// - `Ok(None)` if the rule doesn't apply
22    /// - `Err(e)` if an error occurred
23    fn reduce(&self, array: &V::Array) -> VortexResult<Option<ArrayRef>>;
24}
25
26/// A rewrite rule that transforms arrays based on parent context
27pub trait ArrayParentReduceRule<V: VTable>: Debug + Send + Sync + 'static {
28    type Parent: Matcher;
29
30    /// Returns the matcher for the parent array
31    fn parent(&self) -> Self::Parent;
32
33    /// Attempt to rewrite this child array given information about its parent.
34    ///
35    /// Returns:
36    /// - `Ok(Some(new_array))` if the rule applied successfully
37    /// - `Ok(None)` if the rule doesn't apply
38    /// - `Err(e)` if an error occurred
39    fn reduce_parent(
40        &self,
41        array: &V::Array,
42        parent: <Self::Parent as Matcher>::View<'_>,
43        child_idx: usize,
44    ) -> VortexResult<Option<ArrayRef>>;
45}
46
47/// Dynamic trait for array parent reduce rules
48pub trait DynArrayParentReduceRule<V: VTable>: Debug + Send + Sync {
49    fn parent_key(&self) -> MatchKey;
50
51    fn reduce_parent(
52        &self,
53        array: &V::Array,
54        parent: &ArrayRef,
55        child_idx: usize,
56    ) -> VortexResult<Option<ArrayRef>>;
57}
58
/// Adapter exposing a static [`ArrayParentReduceRule`] through the object-safe
/// [`DynArrayParentReduceRule`] interface.
///
/// [`ParentRuleSet::lift`] produces references to this type by casting a `&'static R` pointer,
/// so the wrapper must have exactly the same layout as `R`. `#[repr(transparent)]` guarantees
/// that: `_phantom` is a zero-sized, 1-aligned marker, which leaves `rule` as the only field
/// that determines size, alignment and ABI. Without it, the default representation only
/// promises an alignment of *at least* that of the fields.
#[repr(transparent)]
pub struct ParentReduceRuleAdapter<V, R> {
    /// The wrapped static rule.
    rule: R,
    /// Ties the adapter to the child encoding `V` without storing a value of it.
    _phantom: PhantomData<V>,
}
64
65impl<V: VTable, R: ArrayParentReduceRule<V>> Debug for ParentReduceRuleAdapter<V, R> {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        f.debug_struct("ArrayParentReduceRuleAdapter")
68            .field("parent", &type_name::<R::Parent>())
69            .field("rule", &self.rule)
70            .finish()
71    }
72}
73
74impl<V: VTable, R: ArrayParentReduceRule<V>> DynArrayParentReduceRule<V>
75    for ParentReduceRuleAdapter<V, R>
76{
77    fn parent_key(&self) -> MatchKey {
78        self.rule.parent().key()
79    }
80
81    fn reduce_parent(
82        &self,
83        child: &V::Array,
84        parent: &ArrayRef,
85        child_idx: usize,
86    ) -> VortexResult<Option<ArrayRef>> {
87        let Some(parent_view) = self.rule.parent().try_match(parent) else {
88            return Ok(None);
89        };
90        self.rule.reduce_parent(child, parent_view, child_idx)
91    }
92}
93
94pub struct ReduceRuleSet<V: VTable> {
95    rules: &'static [&'static dyn ArrayReduceRule<V>],
96}
97
98impl<V: VTable> ReduceRuleSet<V> {
99    /// Create a new reduction rule set with the given rules.
100    pub const fn new(rules: &'static [&'static dyn ArrayReduceRule<V>]) -> Self {
101        Self { rules }
102    }
103
104    /// Evaluate the reduction rules on the given array.
105    pub fn evaluate(&self, array: &V::Array) -> VortexResult<Option<ArrayRef>> {
106        for rule in self.rules.iter() {
107            if let Some(reduced) = rule.reduce(array)? {
108                return Ok(Some(reduced));
109            }
110        }
111        Ok(None)
112    }
113}
114
115/// A set of parent reduction rules for a specific child array encoding.
116pub struct ParentRuleSet<V: VTable> {
117    rules: &'static [&'static dyn DynArrayParentReduceRule<V>],
118}
119
120impl<V: VTable> ParentRuleSet<V> {
121    /// Create a new parent rule set with the given rules.
122    ///
123    /// Use [`ParentRuleSet::lift`] to lift static rules into dynamic trait objects.
124    pub const fn new(rules: &'static [&'static dyn DynArrayParentReduceRule<V>]) -> Self {
125        Self { rules }
126    }
127
128    /// Lift the given rule into a dynamic trait object.
129    pub const fn lift<R: ArrayParentReduceRule<V>>(
130        rule: &'static R,
131    ) -> &'static dyn DynArrayParentReduceRule<V> {
132        // Assert that self is zero-sized
133        const {
134            assert!(
135                !(size_of::<R>() != 0),
136                "Rule must be zero-sized to be lifted"
137            );
138        }
139        unsafe { &*(rule as *const R as *const ParentReduceRuleAdapter<V, R>) }
140    }
141
142    /// Evaluate the parent reduction rules on the given child and parent arrays.
143    pub fn evaluate(
144        &self,
145        child: &V::Array,
146        parent: &ArrayRef,
147        child_idx: usize,
148    ) -> VortexResult<Option<ArrayRef>> {
149        for rule in self.rules.iter() {
150            if let MatchKey::Array(id) = rule.parent_key()
151                && parent.encoding_id() != id
152            {
153                continue;
154            }
155            if let Some(reduced) = rule.reduce_parent(child, parent, child_idx)? {
156                // Debug assertions because these checks are already run elsewhere.
157                #[cfg(debug_assertions)]
158                {
159                    vortex_error::vortex_ensure!(
160                        reduced.len() == parent.len(),
161                        "Reduced array length mismatch from {:?}\nFrom:\n{}\nTo:\n{}",
162                        rule,
163                        parent.display_tree(),
164                        reduced.display_tree()
165                    );
166                    vortex_error::vortex_ensure!(
167                        reduced.dtype() == parent.dtype(),
168                        "Reduced array dtype mismatch from {:?}\nFrom:\n{}\nTo:\n{}",
169                        rule,
170                        parent.display_tree(),
171                        reduced.display_tree()
172                    );
173                }
174
175                return Ok(Some(reduced));
176            }
177        }
178        Ok(None)
179    }
180}