Skip to main content

vortex_array/stats/
bind.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Bind abstract `vortex.stat` expressions to a concrete stats representation.
5//!
6//! Stats rewrite rules describe pruning in terms of `vortex.stat(input, aggregate_fn)` placeholders
7//! so the rewrite is independent of where statistics are stored. These stat placeholders are
8//! abstract because they name the statistic needed for a proof, but not how that statistic is
9//! represented by a specific layout or reader.
10//!
11//! Binding is the later pass that replaces each abstract placeholder with the representation used
12//! by a caller: zone-map field references, file-level stat literals, or typed nulls for missing
13//! stats. This lets all callers share the same falsification rules while keeping layout-specific
14//! stat storage behind [`StatBinder`].
15
16use vortex_error::VortexResult;
17
18use crate::aggregate_fn::AggregateFnRef;
19use crate::dtype::DType;
20use crate::expr::Expression;
21use crate::expr::lit;
22use crate::expr::traversal::NodeExt;
23use crate::expr::traversal::Transformed;
24use crate::scalar::Scalar;
25use crate::scalar_fn::fns::stat::StatFn;
26
27/// A target that can bind abstract statistics to concrete expressions.
28///
29/// Implementations define how a pruning proof should read stats from a specific backing
30/// representation. For example, a zone-map binder can translate a `max(col)` placeholder into a
31/// field reference in the per-zone stats table, while a file-stats binder can translate the same
32/// placeholder into a literal value from the file footer.
33pub trait StatBinder {
34    /// The dtype scope used to type-check expressions before stats are bound.
35    fn scope(&self) -> &DType;
36
37    /// Bind `aggregate_fn(input)` to a concrete expression.
38    ///
39    /// Implementations should return `Ok(None)` when the requested aggregate
40    /// statistic is unavailable in their backing representation.
41    fn bind_aggregate(
42        &self,
43        input: &Expression,
44        aggregate_fn: &AggregateFnRef,
45        stat_dtype: &DType,
46    ) -> VortexResult<Option<Expression>>;
47
48    /// Expression to use when a stat is unavailable.
49    ///
50    /// The default is a nullable null literal, which preserves three-valued
51    /// pruning semantics for stats-table execution.
52    fn missing_stat(&self, dtype: DType) -> VortexResult<Expression> {
53        Ok(null_expr(dtype))
54    }
55}
56
57/// Bind all `vortex.stat` expressions in `predicate`.
58///
59/// The predicate is usually the output of a stats rewrite rule. Rewrite rules
60/// are responsible for expressing stat semantics; binding maps aggregate-backed
61/// stat requests to the concrete stats representation supported by the binder.
62pub fn bind_stats<B: StatBinder + ?Sized>(
63    predicate: Expression,
64    binder: &B,
65) -> VortexResult<Expression> {
66    let scope = binder.scope().clone();
67    Ok(predicate
68        .transform_down(|expr| {
69            if !expr.is::<StatFn>() {
70                return Ok(Transformed::no(expr));
71            }
72
73            match bind_stat_fn(&expr, &scope, binder)? {
74                Some(bound) => Ok(Transformed::yes(bound)),
75                None => {
76                    let dtype = expr.return_dtype(&scope)?;
77                    Ok(Transformed::yes(binder.missing_stat(dtype)?))
78                }
79            }
80        })?
81        .into_inner())
82}
83
84fn bind_stat_fn(
85    expr: &Expression,
86    scope: &DType,
87    binder: &(impl StatBinder + ?Sized),
88) -> VortexResult<Option<Expression>> {
89    let options = expr.as_::<StatFn>();
90    let aggregate_fn = options.aggregate_fn();
91    // `StatFn` has exactly one child: the expression the aggregate statistic is computed over.
92    let input = expr.child(0);
93
94    let stat_dtype = expr.return_dtype(scope)?;
95    binder.bind_aggregate(input, aggregate_fn, &stat_dtype)
96}
97
98fn null_expr(dtype: DType) -> Expression {
99    lit(Scalar::null(dtype.as_nullable()))
100}
101
#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;

    use super::*;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::and;
    use crate::expr::col;
    use crate::expr::get_item;
    use crate::expr::is_null;
    use crate::expr::or;
    use crate::expr::root;
    use crate::expr::stats::Stat;
    use crate::stats::all_non_nan;
    use crate::stats::nan_count;

    /// Binder fixture whose scope is a struct with one non-nullable `f32` field `f`.
    ///
    /// It can bind only the NaN-count statistic, and only when `bind_nan_count` is set.
    struct TestBinder {
        input_scope: DType,
        bind_nan_count: bool,
    }

    impl TestBinder {
        /// Creates the fixture; `bind_nan_count` toggles whether NaN-count requests are served.
        fn new(bind_nan_count: bool) -> Self {
            let f_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
            let fields = StructFields::from_iter([("f", f_dtype)]);
            Self {
                input_scope: DType::Struct(fields, Nullability::NonNullable),
                bind_nan_count,
            }
        }
    }

    impl StatBinder for TestBinder {
        fn scope(&self) -> &DType {
            &self.input_scope
        }

        fn bind_aggregate(
            &self,
            _input: &Expression,
            aggregate_fn: &AggregateFnRef,
            _stat_dtype: &DType,
        ) -> VortexResult<Option<Expression>> {
            // Only an enabled NaN-count request resolves to a slot; everything else,
            // including aggregates that map to no `Stat`, is reported as unavailable.
            match Stat::from_aggregate_fn(aggregate_fn) {
                Some(stat) if stat == Stat::NaNCount && self.bind_nan_count => {
                    Ok(Some(get_item("f_nan_count", root())))
                }
                _ => Ok(None),
            }
        }
    }

    #[test]
    fn nan_count_binds_to_direct_stat_slot() -> VortexResult<()> {
        let result = bind_stats(nan_count(col("f")), &TestBinder::new(true))?;

        assert_eq!(result, col("f_nan_count"));
        Ok(())
    }

    #[test]
    fn all_non_nan_does_not_derive_from_nan_count() -> VortexResult<()> {
        // Even with NaN-count available, the placeholder must fall back to a nullable null.
        let expected = lit(Scalar::null(DType::Bool(Nullability::Nullable)));

        let result = bind_stats(all_non_nan(col("f")), &TestBinder::new(true))?;

        assert_eq!(result, expected);
        Ok(())
    }

    #[test]
    fn missing_stats_bind_to_null_without_reducing() -> VortexResult<()> {
        let binder = TestBinder::new(false);
        let missing = || lit(Scalar::null(DType::Bool(Nullability::Nullable)));

        // Binding must not simplify `false AND <missing>`.
        let conjunction = bind_stats(and(lit(false), all_non_nan(col("f"))), &binder)?;
        assert_eq!(conjunction, and(lit(false), missing()));

        // Nor `true OR <missing>`.
        let disjunction = bind_stats(or(lit(true), all_non_nan(col("f"))), &binder)?;
        assert_eq!(disjunction, or(lit(true), missing()));
        Ok(())
    }

    #[test]
    fn unrelated_expressions_do_not_request_nan_count() -> VortexResult<()> {
        let result = bind_stats(is_null(col("f")), &TestBinder::new(false))?;

        // An expression without stat placeholders passes through untouched.
        assert_eq!(result, is_null(col("f")));
        Ok(())
    }
}