Skip to main content

vortex_array/stats/
rewrite.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Session-registered rewrite rules for aggregate-backed stats expressions.
5
6use std::fmt::Debug;
7use std::sync::Arc;
8
9use vortex_error::VortexResult;
10use vortex_error::vortex_ensure;
11use vortex_session::VortexSession;
12
13use crate::dtype::DType;
14use crate::expr::Expression;
15use crate::expr::or_collect;
16use crate::scalar_fn::ScalarFnId;
17use crate::stats::session::StatsSessionExt;
18
19mod builtins;
20
21pub(crate) use builtins::register_builtins;
22
/// Shared reference to a stats rewrite rule.
///
/// This is an [`Arc`] around a type-erased [`StatsRewriteRule`] trait object, so cloning the
/// handle does not clone the rule itself.
pub type StatsRewriteRuleRef = Arc<dyn StatsRewriteRule>;
25
26/// A plugin-provided rule for predicates whose root scalar function matches this rule.
27///
28/// Rules do not produce expressions equivalent to `expr`. They produce optional sufficient
29/// conditions over stats for the current scope:
30///
31/// - a falsifier evaluating to `true` proves that `expr` is false for every row in the scope;
32/// - a satisfier evaluating to `true` proves that `expr` is true for every row in the scope.
33///
34/// Returning `None` means this rule cannot prove anything for the expression. A returned proof
35/// expression that evaluates to `false` or `null` is also inconclusive.
36///
37/// Multiple rules may be registered for the same scalar function. Their proofs are combined with
38/// `OR`, so every proof returned by an individual rule must be sound on its own.
39///
40/// `expr` is the full predicate expression whose root scalar function id is
41/// [`Self::scalar_fn_id`]. Use [`StatsRewriteCtx`] to resolve dtypes and recursively rewrite child
42/// predicates.
43pub trait StatsRewriteRule: Debug + Send + Sync + 'static {
44    /// Returns the scalar function id handled by this rule.
45    fn scalar_fn_id(&self) -> ScalarFnId;
46
47    /// Returns a stats-backed proof that `expr` is false for the current scope.
48    ///
49    /// If the returned expression evaluates to `true` against the scope's stats, then `expr` is
50    /// guaranteed to be false for every row in that scope. A returned proof expression that
51    /// evaluates to `false` or `null` is inconclusive.
52    ///
53    /// Returns `Ok(None)` when this rule cannot construct a sound falsity proof for `expr`.
54    fn falsify(
55        &self,
56        expr: &Expression,
57        ctx: &StatsRewriteCtx<'_>,
58    ) -> VortexResult<Option<Expression>> {
59        _ = expr;
60        _ = ctx;
61        Ok(None)
62    }
63
64    /// Returns a stats-backed proof that `expr` is true for the current scope.
65    ///
66    /// If the returned expression evaluates to `true` against the scope's stats, then `expr` is
67    /// guaranteed to be true for every row in that scope. A returned proof expression that
68    /// evaluates to `false` or `null` is inconclusive.
69    ///
70    /// This is not the complement of [`Self::falsify`]; both methods are one-way proofs and may be
71    /// implemented independently.
72    ///
73    /// Returns `Ok(None)` when this rule cannot construct a sound truth proof for `expr`.
74    fn satisfy(
75        &self,
76        expr: &Expression,
77        ctx: &StatsRewriteCtx<'_>,
78    ) -> VortexResult<Option<Expression>> {
79        _ = expr;
80        _ = ctx;
81        Ok(None)
82    }
83}
84
/// Context passed to stats rewrite rules.
///
/// Bundles the session whose registered rewrite rules are consulted with the dtype that
/// expressions are resolved against. Both are borrowed for the lifetime `'a`.
pub struct StatsRewriteCtx<'a> {
    // Session whose stats registry is queried for rewrite rules.
    session: &'a VortexSession,
    // Dtype used as the scope when computing an expression's return dtype.
    scope: &'a DType,
}
90
91impl<'a> StatsRewriteCtx<'a> {
92    /// Create a rewrite context for `session`.
93    pub fn new(session: &'a VortexSession, scope: &'a DType) -> Self {
94        Self { session, scope }
95    }
96
97    /// Returns the session that owns the rewrite registry.
98    pub fn session(&self) -> &'a VortexSession {
99        self.session
100    }
101
102    /// Return the dtype of `expr` within this rewrite scope.
103    pub fn return_dtype(&self, expr: &Expression) -> VortexResult<DType> {
104        expr.return_dtype(self.scope)
105    }
106
107    /// Rewrite `expr` into a stats-backed falsifier.
108    pub fn falsify(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
109        self.ensure_predicate(expr)?;
110        rewrite(expr, self, StatsRewriteRule::falsify)
111    }
112
113    /// Rewrite `expr` into a stats-backed satisfier.
114    pub fn satisfy(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
115        self.ensure_predicate(expr)?;
116        rewrite(expr, self, StatsRewriteRule::satisfy)
117    }
118
119    fn ensure_predicate(&self, expr: &Expression) -> VortexResult<()> {
120        let dtype = self.return_dtype(expr)?;
121        vortex_ensure!(
122            matches!(dtype, DType::Bool(_)),
123            "Stats rewrites require a boolean predicate, got {dtype}",
124        );
125        Ok(())
126    }
127}
128
129fn rewrite(
130    expr: &Expression,
131    ctx: &StatsRewriteCtx<'_>,
132    apply: fn(
133        &dyn StatsRewriteRule,
134        &Expression,
135        &StatsRewriteCtx<'_>,
136    ) -> VortexResult<Option<Expression>>,
137) -> VortexResult<Option<Expression>> {
138    let rules = ctx
139        .session()
140        .stats()
141        .rewrite_rules_for(expr.scalar_fn().id());
142    let Some(rules) = rules else {
143        return Ok(None);
144    };
145
146    let mut rewrites = Vec::new();
147    for rule in rules.iter() {
148        if let Some(rewrite) = apply(rule.as_ref(), expr, ctx)? {
149            rewrites.push(rewrite);
150        }
151    }
152
153    Ok(or_collect(rewrites))
154}
155
#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;

    use super::StatsRewriteCtx;
    use super::StatsRewriteRule;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::Expression;
    use crate::expr::lit;
    use crate::expr::or;
    use crate::scalar_fn::ScalarFnId;
    use crate::scalar_fn::ScalarFnVTable;
    use crate::scalar_fn::fns::literal::Literal;
    use crate::stats::session::StatsSessionExt;

    /// Test rule attached to the literal scalar function that returns fixed, pre-built proofs.
    #[derive(Debug)]
    struct StaticLiteralRule {
        falsifier: Option<Expression>,
        satisfier: Option<Expression>,
    }

    impl StaticLiteralRule {
        /// A rule that only ever offers `proof` as a falsifier.
        fn falsifying(proof: Expression) -> Self {
            Self {
                falsifier: Some(proof),
                satisfier: None,
            }
        }

        /// A rule that only ever offers `proof` as a satisfier.
        fn satisfying(proof: Expression) -> Self {
            Self {
                falsifier: None,
                satisfier: Some(proof),
            }
        }
    }

    impl StatsRewriteRule for StaticLiteralRule {
        fn scalar_fn_id(&self) -> ScalarFnId {
            Literal.id()
        }

        fn falsify(
            &self,
            _expr: &Expression,
            _ctx: &StatsRewriteCtx<'_>,
        ) -> VortexResult<Option<Expression>> {
            Ok(self.falsifier.clone())
        }

        fn satisfy(
            &self,
            _expr: &Expression,
            _ctx: &StatsRewriteCtx<'_>,
        ) -> VortexResult<Option<Expression>> {
            Ok(self.satisfier.clone())
        }
    }

    /// The scope dtype shared by every test in this module.
    fn i32_scope() -> DType {
        DType::Primitive(PType::I32, Nullability::NonNullable)
    }

    #[test]
    fn combines_multiple_falsifiers_with_or() -> VortexResult<()> {
        let session = crate::array_session();
        let scope = i32_scope();
        // Registration order matters: the proofs are expected in the same order below.
        for proof in [lit(false), lit(true)] {
            session
                .stats()
                .register_rewrite(StaticLiteralRule::falsifying(proof));
        }

        let combined = lit(true).falsify(&scope, &session)?;
        assert_eq!(combined, Some(or(lit(false), lit(true))));
        Ok(())
    }

    #[test]
    fn combines_multiple_satisfiers_with_or() -> VortexResult<()> {
        let session = crate::array_session();
        let scope = i32_scope();
        // Registration order matters: the proofs are expected in the same order below.
        for proof in [lit(false), lit(true)] {
            session
                .stats()
                .register_rewrite(StaticLiteralRule::satisfying(proof));
        }

        let combined = lit(true).satisfy(&scope, &session)?;
        assert_eq!(combined, Some(or(lit(false), lit(true))));
        Ok(())
    }

    #[test]
    fn unregistered_expression_has_no_rewrite() -> VortexResult<()> {
        let session = crate::array_session();
        let scope = i32_scope();
        let predicate = lit(true);

        assert_eq!(predicate.falsify(&scope, &session)?, None);
        assert_eq!(predicate.satisfy(&scope, &session)?, None);
        Ok(())
    }

    #[test]
    fn non_predicate_expression_errors() {
        let session = crate::array_session();
        let scope = i32_scope();

        // An integer literal is not a boolean predicate, so both rewrites must fail.
        assert!(lit(7).falsify(&scope, &session).is_err());
        assert!(lit(7).satisfy(&scope, &session).is_err());
    }
}