vortex_array/optimizer/
rules.rs1use std::any::type_name;
5use std::fmt::Debug;
6use std::marker::PhantomData;
7
8use vortex_error::VortexResult;
9
10use crate::array::ArrayRef;
11use crate::matcher::Matcher;
12use crate::vtable::VTable;
13
14pub trait ArrayReduceRule<V: VTable>: Debug + Send + Sync + 'static {
16 fn reduce(&self, array: &V::Array) -> VortexResult<Option<ArrayRef>>;
23}
24
25pub trait ArrayParentReduceRule<V: VTable>: Debug + Send + Sync + 'static {
27 type Parent: Matcher;
28
29 fn reduce_parent(
36 &self,
37 array: &V::Array,
38 parent: <Self::Parent as Matcher>::Match<'_>,
39 child_idx: usize,
40 ) -> VortexResult<Option<ArrayRef>>;
41}
42
43pub trait DynArrayParentReduceRule<V: VTable>: Debug + Send + Sync {
45 fn matches(&self, parent: &ArrayRef) -> bool;
46
47 fn reduce_parent(
48 &self,
49 array: &V::Array,
50 parent: &ArrayRef,
51 child_idx: usize,
52 ) -> VortexResult<Option<ArrayRef>>;
53}
54
/// Wraps a typed [`ArrayParentReduceRule`] so it can be driven through the untyped
/// [`DynArrayParentReduceRule`] interface.
///
/// The type is `#[repr(transparent)]` over `rule`, so its size and alignment are guaranteed to be
/// exactly those of `R`. [`ParentRuleSet::lift`] reinterprets a `&'static R` as a reference to
/// this adapter; the default representation only promises an alignment of *at least* that of the
/// fields, which is not enough to justify such a cast.
#[repr(transparent)]
pub struct ParentReduceRuleAdapter<V, R> {
    /// The wrapped typed rule.
    rule: R,
    /// Records the `V` type parameter without storing a value of it (zero-sized, alignment 1).
    _phantom: PhantomData<V>,
}
60
61impl<V: VTable, R: ArrayParentReduceRule<V>> Debug for ParentReduceRuleAdapter<V, R> {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 f.debug_struct("ArrayParentReduceRuleAdapter")
64 .field("parent", &type_name::<R::Parent>())
65 .field("rule", &self.rule)
66 .finish()
67 }
68}
69
70impl<V: VTable, K: ArrayParentReduceRule<V>> DynArrayParentReduceRule<V>
71 for ParentReduceRuleAdapter<V, K>
72{
73 fn matches(&self, parent: &ArrayRef) -> bool {
74 K::Parent::matches(parent)
75 }
76
77 fn reduce_parent(
78 &self,
79 child: &V::Array,
80 parent: &ArrayRef,
81 child_idx: usize,
82 ) -> VortexResult<Option<ArrayRef>> {
83 let Some(parent_view) = K::Parent::try_match(parent) else {
84 return Ok(None);
85 };
86 self.rule.reduce_parent(child, parent_view, child_idx)
87 }
88}
89
90pub struct ReduceRuleSet<V: VTable> {
91 rules: &'static [&'static dyn ArrayReduceRule<V>],
92}
93
94impl<V: VTable> ReduceRuleSet<V> {
95 pub const fn new(rules: &'static [&'static dyn ArrayReduceRule<V>]) -> Self {
97 Self { rules }
98 }
99
100 pub fn evaluate(&self, array: &V::Array) -> VortexResult<Option<ArrayRef>> {
102 for rule in self.rules.iter() {
103 if let Some(reduced) = rule.reduce(array)? {
104 return Ok(Some(reduced));
105 }
106 }
107 Ok(None)
108 }
109}
110
111pub struct ParentRuleSet<V: VTable> {
113 rules: &'static [&'static dyn DynArrayParentReduceRule<V>],
114}
115
116impl<V: VTable> ParentRuleSet<V> {
117 pub const fn new(rules: &'static [&'static dyn DynArrayParentReduceRule<V>]) -> Self {
121 Self { rules }
122 }
123
124 pub const fn lift<R: ArrayParentReduceRule<V>>(
126 rule: &'static R,
127 ) -> &'static dyn DynArrayParentReduceRule<V> {
128 const {
130 assert!(
131 !(size_of::<R>() != 0),
132 "Rule must be zero-sized to be lifted"
133 );
134 }
135 unsafe { &*(rule as *const R as *const ParentReduceRuleAdapter<V, R>) }
136 }
137
138 pub fn evaluate(
140 &self,
141 child: &V::Array,
142 parent: &ArrayRef,
143 child_idx: usize,
144 ) -> VortexResult<Option<ArrayRef>> {
145 for rule in self.rules.iter() {
146 if !rule.matches(parent) {
147 continue;
148 }
149 if let Some(reduced) = rule.reduce_parent(child, parent, child_idx)? {
150 #[cfg(debug_assertions)]
152 {
153 vortex_error::vortex_ensure!(
154 reduced.len() == parent.len(),
155 "Reduced array length mismatch from {:?}\nFrom:\n{}\nTo:\n{}",
156 rule,
157 parent.encoding_id(),
158 reduced.encoding_id()
159 );
160 vortex_error::vortex_ensure!(
161 reduced.dtype() == parent.dtype(),
162 "Reduced array dtype mismatch from {:?}\nFrom:\n{}\nTo:\n{}",
163 rule,
164 parent.encoding_id(),
165 reduced.encoding_id()
166 );
167 }
168
169 return Ok(Some(reduced));
170 }
171 }
172 Ok(None)
173 }
174}