vortex_array/optimizer/
rules.rs1use std::any::type_name;
5use std::fmt::Debug;
6use std::marker::PhantomData;
7
8use vortex_error::VortexResult;
9
10use crate::array::ArrayRef;
11use crate::matchers::MatchKey;
12use crate::matchers::Matcher;
13use crate::vtable::VTable;
14
15pub trait ArrayReduceRule<V: VTable>: Debug + Send + Sync + 'static {
17 fn reduce(&self, array: &V::Array) -> VortexResult<Option<ArrayRef>>;
24}
25
26pub trait ArrayParentReduceRule<V: VTable>: Debug + Send + Sync + 'static {
28 type Parent: Matcher;
29
30 fn parent(&self) -> Self::Parent;
32
33 fn reduce_parent(
40 &self,
41 array: &V::Array,
42 parent: <Self::Parent as Matcher>::View<'_>,
43 child_idx: usize,
44 ) -> VortexResult<Option<ArrayRef>>;
45}
46
47pub trait DynArrayParentReduceRule<V: VTable>: Debug + Send + Sync {
49 fn parent_key(&self) -> MatchKey;
50
51 fn reduce_parent(
52 &self,
53 array: &V::Array,
54 parent: &ArrayRef,
55 child_idx: usize,
56 ) -> VortexResult<Option<ArrayRef>>;
57}
58
/// Adapts a statically-typed [`ArrayParentReduceRule`] to the object-safe
/// [`DynArrayParentReduceRule`] interface.
///
/// `ParentRuleSet::lift` reinterprets a `&R` as a `&ParentReduceRuleAdapter<V, R>` in place.
/// `#[repr(transparent)]` guarantees that the adapter has exactly the size and alignment of
/// `R`, because the only other field is a zero-sized, 1-aligned `PhantomData` marker. That
/// guarantee is what makes the reinterpretation layout-compatible.
#[repr(transparent)]
pub struct ParentReduceRuleAdapter<V, R> {
    /// The wrapped, statically-typed rule.
    rule: R,
    /// Ties the adapter to vtable `V` without storing a `V`.
    _phantom: PhantomData<V>,
}
64
65impl<V: VTable, R: ArrayParentReduceRule<V>> Debug for ParentReduceRuleAdapter<V, R> {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 f.debug_struct("ArrayParentReduceRuleAdapter")
68 .field("parent", &type_name::<R::Parent>())
69 .field("rule", &self.rule)
70 .finish()
71 }
72}
73
74impl<V: VTable, R: ArrayParentReduceRule<V>> DynArrayParentReduceRule<V>
75 for ParentReduceRuleAdapter<V, R>
76{
77 fn parent_key(&self) -> MatchKey {
78 self.rule.parent().key()
79 }
80
81 fn reduce_parent(
82 &self,
83 child: &V::Array,
84 parent: &ArrayRef,
85 child_idx: usize,
86 ) -> VortexResult<Option<ArrayRef>> {
87 let Some(parent_view) = self.rule.parent().try_match(parent) else {
88 return Ok(None);
89 };
90 self.rule.reduce_parent(child, parent_view, child_idx)
91 }
92}
93
94pub struct ReduceRuleSet<V: VTable> {
95 rules: &'static [&'static dyn ArrayReduceRule<V>],
96}
97
98impl<V: VTable> ReduceRuleSet<V> {
99 pub const fn new(rules: &'static [&'static dyn ArrayReduceRule<V>]) -> Self {
101 Self { rules }
102 }
103
104 pub fn evaluate(&self, array: &V::Array) -> VortexResult<Option<ArrayRef>> {
106 for rule in self.rules.iter() {
107 if let Some(reduced) = rule.reduce(array)? {
108 return Ok(Some(reduced));
109 }
110 }
111 Ok(None)
112 }
113}
114
115pub struct ParentRuleSet<V: VTable> {
117 rules: &'static [&'static dyn DynArrayParentReduceRule<V>],
118}
119
120impl<V: VTable> ParentRuleSet<V> {
121 pub const fn new(rules: &'static [&'static dyn DynArrayParentReduceRule<V>]) -> Self {
125 Self { rules }
126 }
127
128 pub const fn lift<R: ArrayParentReduceRule<V>>(
130 rule: &'static R,
131 ) -> &'static dyn DynArrayParentReduceRule<V> {
132 const {
134 assert!(
135 !(size_of::<R>() != 0),
136 "Rule must be zero-sized to be lifted"
137 );
138 }
139 unsafe { &*(rule as *const R as *const ParentReduceRuleAdapter<V, R>) }
140 }
141
142 pub fn evaluate(
144 &self,
145 child: &V::Array,
146 parent: &ArrayRef,
147 child_idx: usize,
148 ) -> VortexResult<Option<ArrayRef>> {
149 for rule in self.rules.iter() {
150 if let MatchKey::Array(id) = rule.parent_key()
151 && parent.encoding_id() != id
152 {
153 continue;
154 }
155 if let Some(reduced) = rule.reduce_parent(child, parent, child_idx)? {
156 #[cfg(debug_assertions)]
158 {
159 vortex_error::vortex_ensure!(
160 reduced.len() == parent.len(),
161 "Reduced array length mismatch from {:?}\nFrom:\n{}\nTo:\n{}",
162 rule,
163 parent.display_tree(),
164 reduced.display_tree()
165 );
166 vortex_error::vortex_ensure!(
167 reduced.dtype() == parent.dtype(),
168 "Reduced array dtype mismatch from {:?}\nFrom:\n{}\nTo:\n{}",
169 rule,
170 parent.display_tree(),
171 reduced.display_tree()
172 );
173 }
174
175 return Ok(Some(reduced));
176 }
177 }
178 Ok(None)
179 }
180}