vortex_array/stats/
bind.rs1use vortex_error::VortexResult;
17
18use crate::aggregate_fn::AggregateFnRef;
19use crate::dtype::DType;
20use crate::expr::Expression;
21use crate::expr::lit;
22use crate::expr::traversal::NodeExt;
23use crate::expr::traversal::Transformed;
24use crate::scalar::Scalar;
25use crate::scalar_fn::fns::stat::StatFn;
26
27pub trait StatBinder {
34 fn scope(&self) -> &DType;
36
37 fn bind_aggregate(
42 &self,
43 input: &Expression,
44 aggregate_fn: &AggregateFnRef,
45 stat_dtype: &DType,
46 ) -> VortexResult<Option<Expression>>;
47
48 fn missing_stat(&self, dtype: DType) -> VortexResult<Expression> {
53 Ok(null_expr(dtype))
54 }
55}
56
57pub fn bind_stats<B: StatBinder + ?Sized>(
63 predicate: Expression,
64 binder: &B,
65) -> VortexResult<Expression> {
66 let scope = binder.scope().clone();
67 Ok(predicate
68 .transform_down(|expr| {
69 if !expr.is::<StatFn>() {
70 return Ok(Transformed::no(expr));
71 }
72
73 match bind_stat_fn(&expr, &scope, binder)? {
74 Some(bound) => Ok(Transformed::yes(bound)),
75 None => {
76 let dtype = expr.return_dtype(&scope)?;
77 Ok(Transformed::yes(binder.missing_stat(dtype)?))
78 }
79 }
80 })?
81 .into_inner())
82}
83
84fn bind_stat_fn(
85 expr: &Expression,
86 scope: &DType,
87 binder: &(impl StatBinder + ?Sized),
88) -> VortexResult<Option<Expression>> {
89 let options = expr.as_::<StatFn>();
90 let aggregate_fn = options.aggregate_fn();
91 let input = expr.child(0);
93
94 let stat_dtype = expr.return_dtype(scope)?;
95 binder.bind_aggregate(input, aggregate_fn, &stat_dtype)
96}
97
98fn null_expr(dtype: DType) -> Expression {
99 lit(Scalar::null(dtype.as_nullable()))
100}
101
#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;

    use super::*;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::and;
    use crate::expr::col;
    use crate::expr::get_item;
    use crate::expr::is_null;
    use crate::expr::or;
    use crate::expr::root;
    use crate::expr::stats::Stat;
    use crate::stats::all_non_nan;
    use crate::stats::nan_count;

    /// Binder over a struct scope with a single non-nullable `f32` field `f`.
    ///
    /// It can only ever bind the NaN-count stat, and only when
    /// `bind_nan_count` is set.
    struct TestBinder {
        input_scope: DType,
        bind_nan_count: bool,
    }

    impl TestBinder {
        fn new(bind_nan_count: bool) -> Self {
            let field_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
            let fields = StructFields::from_iter([("f", field_dtype)]);
            Self {
                input_scope: DType::Struct(fields, Nullability::NonNullable),
                bind_nan_count,
            }
        }
    }

    impl StatBinder for TestBinder {
        fn scope(&self) -> &DType {
            &self.input_scope
        }

        fn bind_aggregate(
            &self,
            _input: &Expression,
            aggregate_fn: &AggregateFnRef,
            _stat_dtype: &DType,
        ) -> VortexResult<Option<Expression>> {
            // Only the NaN-count stat is bindable, and only when enabled.
            let bound = match Stat::from_aggregate_fn(aggregate_fn) {
                Some(stat) if stat == Stat::NaNCount && self.bind_nan_count => {
                    Some(get_item("f_nan_count", root()))
                }
                _ => None,
            };
            Ok(bound)
        }
    }

    /// The nullable-bool null literal expected wherever a stat is missing.
    fn null_bool() -> Expression {
        lit(Scalar::null(DType::Bool(Nullability::Nullable)))
    }

    #[test]
    fn nan_count_binds_to_direct_stat_slot() -> VortexResult<()> {
        let binder = TestBinder::new(true);
        let predicate = nan_count(col("f"));

        let bound = bind_stats(predicate, &binder)?;

        assert_eq!(bound, col("f_nan_count"));
        Ok(())
    }

    #[test]
    fn all_non_nan_does_not_derive_from_nan_count() -> VortexResult<()> {
        let binder = TestBinder::new(true);
        let predicate = all_non_nan(col("f"));

        let bound = bind_stats(predicate, &binder)?;

        assert_eq!(bound, null_bool());
        Ok(())
    }

    #[test]
    fn missing_stats_bind_to_null_without_reducing() -> VortexResult<()> {
        let binder = TestBinder::new(false);

        // `false AND <missing>` keeps both operands.
        let conjunction = bind_stats(and(lit(false), all_non_nan(col("f"))), &binder)?;
        assert_eq!(conjunction, and(lit(false), null_bool()));

        // `true OR <missing>` keeps both operands as well.
        let disjunction = bind_stats(or(lit(true), all_non_nan(col("f"))), &binder)?;
        assert_eq!(disjunction, or(lit(true), null_bool()));

        Ok(())
    }

    #[test]
    fn unrelated_expressions_do_not_request_nan_count() -> VortexResult<()> {
        let binder = TestBinder::new(false);
        let predicate = is_null(col("f"));

        let bound = bind_stats(predicate, &binder)?;

        assert_eq!(bound, is_null(col("f")));
        Ok(())
    }
}