vortex_array/stats/
rewrite.rs1use std::fmt::Debug;
7use std::sync::Arc;
8
9use vortex_error::VortexResult;
10use vortex_error::vortex_ensure;
11use vortex_session::VortexSession;
12
13use crate::dtype::DType;
14use crate::expr::Expression;
15use crate::expr::or_collect;
16use crate::scalar_fn::ScalarFnId;
17use crate::stats::session::StatsSessionExt;
18
19mod builtins;
20
21pub(crate) use builtins::register_builtins;
22
23pub type StatsRewriteRuleRef = Arc<dyn StatsRewriteRule>;
25
26pub trait StatsRewriteRule: Debug + Send + Sync + 'static {
44 fn scalar_fn_id(&self) -> ScalarFnId;
46
47 fn falsify(
55 &self,
56 expr: &Expression,
57 ctx: &StatsRewriteCtx<'_>,
58 ) -> VortexResult<Option<Expression>> {
59 _ = expr;
60 _ = ctx;
61 Ok(None)
62 }
63
64 fn satisfy(
75 &self,
76 expr: &Expression,
77 ctx: &StatsRewriteCtx<'_>,
78 ) -> VortexResult<Option<Expression>> {
79 _ = expr;
80 _ = ctx;
81 Ok(None)
82 }
83}
84
85pub struct StatsRewriteCtx<'a> {
87 session: &'a VortexSession,
88 scope: &'a DType,
89}
90
91impl<'a> StatsRewriteCtx<'a> {
92 pub fn new(session: &'a VortexSession, scope: &'a DType) -> Self {
94 Self { session, scope }
95 }
96
97 pub fn session(&self) -> &'a VortexSession {
99 self.session
100 }
101
102 pub fn return_dtype(&self, expr: &Expression) -> VortexResult<DType> {
104 expr.return_dtype(self.scope)
105 }
106
107 pub fn falsify(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
109 self.ensure_predicate(expr)?;
110 rewrite(expr, self, StatsRewriteRule::falsify)
111 }
112
113 pub fn satisfy(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
115 self.ensure_predicate(expr)?;
116 rewrite(expr, self, StatsRewriteRule::satisfy)
117 }
118
119 fn ensure_predicate(&self, expr: &Expression) -> VortexResult<()> {
120 let dtype = self.return_dtype(expr)?;
121 vortex_ensure!(
122 matches!(dtype, DType::Bool(_)),
123 "Stats rewrites require a boolean predicate, got {dtype}",
124 );
125 Ok(())
126 }
127}
128
129fn rewrite(
130 expr: &Expression,
131 ctx: &StatsRewriteCtx<'_>,
132 apply: fn(
133 &dyn StatsRewriteRule,
134 &Expression,
135 &StatsRewriteCtx<'_>,
136 ) -> VortexResult<Option<Expression>>,
137) -> VortexResult<Option<Expression>> {
138 let rules = ctx
139 .session()
140 .stats()
141 .rewrite_rules_for(expr.scalar_fn().id());
142 let Some(rules) = rules else {
143 return Ok(None);
144 };
145
146 let mut rewrites = Vec::new();
147 for rule in rules.iter() {
148 if let Some(rewrite) = apply(rule.as_ref(), expr, ctx)? {
149 rewrites.push(rewrite);
150 }
151 }
152
153 Ok(or_collect(rewrites))
154}
155
#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;

    use super::StatsRewriteCtx;
    use super::StatsRewriteRule;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::Expression;
    use crate::expr::lit;
    use crate::expr::or;
    use crate::scalar_fn::ScalarFnId;
    use crate::scalar_fn::ScalarFnVTable;
    use crate::scalar_fn::fns::literal::Literal;
    use crate::stats::session::StatsSessionExt;

    /// Test rule keyed on literal expressions that hands back fixed, pre-built rewrites.
    #[derive(Debug)]
    struct StaticLiteralRule {
        falsifier: Option<Expression>,
        satisfier: Option<Expression>,
    }

    impl StaticLiteralRule {
        /// Rule that only answers `falsify`, always with `falsifier`.
        fn falsifying(falsifier: Expression) -> Self {
            Self {
                falsifier: Some(falsifier),
                satisfier: None,
            }
        }

        /// Rule that only answers `satisfy`, always with `satisfier`.
        fn satisfying(satisfier: Expression) -> Self {
            Self {
                falsifier: None,
                satisfier: Some(satisfier),
            }
        }
    }

    impl StatsRewriteRule for StaticLiteralRule {
        fn scalar_fn_id(&self) -> ScalarFnId {
            Literal.id()
        }

        fn falsify(
            &self,
            _: &Expression,
            _: &StatsRewriteCtx<'_>,
        ) -> VortexResult<Option<Expression>> {
            Ok(self.falsifier.clone())
        }

        fn satisfy(
            &self,
            _: &Expression,
            _: &StatsRewriteCtx<'_>,
        ) -> VortexResult<Option<Expression>> {
            Ok(self.satisfier.clone())
        }
    }

    /// Scope dtype shared by every test; the literal expressions under test do not read it.
    fn i32_scope() -> DType {
        DType::Primitive(PType::I32, Nullability::NonNullable)
    }

    #[test]
    fn combines_multiple_falsifiers_with_or() -> VortexResult<()> {
        let session = crate::array_session();
        let scope = i32_scope();
        // Registration order matters: the expected OR lists `false` before `true`.
        for value in [false, true] {
            session
                .stats()
                .register_rewrite(StaticLiteralRule::falsifying(lit(value)));
        }

        let rewritten = lit(true).falsify(&scope, &session)?;
        assert_eq!(rewritten, Some(or(lit(false), lit(true))));
        Ok(())
    }

    #[test]
    fn combines_multiple_satisfiers_with_or() -> VortexResult<()> {
        let session = crate::array_session();
        let scope = i32_scope();
        // Registration order matters: the expected OR lists `false` before `true`.
        for value in [false, true] {
            session
                .stats()
                .register_rewrite(StaticLiteralRule::satisfying(lit(value)));
        }

        let rewritten = lit(true).satisfy(&scope, &session)?;
        assert_eq!(rewritten, Some(or(lit(false), lit(true))));
        Ok(())
    }

    #[test]
    fn unregistered_expression_has_no_rewrite() -> VortexResult<()> {
        let session = crate::array_session();
        let scope = i32_scope();

        let falsified = lit(true).falsify(&scope, &session)?;
        let satisfied = lit(true).satisfy(&scope, &session)?;

        assert_eq!(falsified, None);
        assert_eq!(satisfied, None);
        Ok(())
    }

    #[test]
    fn non_predicate_expression_errors() {
        let session = crate::array_session();
        let scope = i32_scope();

        // An integer literal is not a boolean predicate, so both directions must be rejected.
        let falsified = lit(7).falsify(&scope, &session);
        let satisfied = lit(7).satisfy(&scope, &session);

        assert!(falsified.is_err());
        assert!(satisfied.is_err());
    }
}