vortex_array/scalar_fn/fns/
stat.rs1use std::fmt::Display;
7use std::fmt::Formatter;
8
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_session::registry::CachedId;
12
13use crate::ArrayRef;
14use crate::ExecutionCtx;
15use crate::IntoArray;
16use crate::aggregate_fn::AggregateFnRef;
17use crate::aggregate_fn::fns::all_nan::AllNan;
18use crate::aggregate_fn::fns::all_non_nan::AllNonNan;
19use crate::aggregate_fn::fns::all_non_null::AllNonNull;
20use crate::aggregate_fn::fns::all_null::AllNull;
21use crate::arrays::ConstantArray;
22use crate::dtype::DType;
23use crate::expr::Expression;
24use crate::expr::stats::Precision;
25use crate::expr::stats::Stat;
26use crate::expr::stats::StatsProvider;
27use crate::expr::stats::StatsProviderExt;
28use crate::scalar::Scalar;
29use crate::scalar::ScalarValue;
30use crate::scalar_fn::Arity;
31use crate::scalar_fn::ChildName;
32use crate::scalar_fn::ExecutionArgs;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35
36#[derive(Clone, Debug, PartialEq, Eq, Hash)]
38pub struct StatOptions {
39 aggregate_fn: AggregateFnRef,
40}
41
42impl StatOptions {
43 pub fn new(aggregate_fn: AggregateFnRef) -> Self {
45 Self { aggregate_fn }
46 }
47
48 pub fn aggregate_fn(&self) -> &AggregateFnRef {
50 &self.aggregate_fn
51 }
52}
53
54impl Display for StatOptions {
55 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
56 Display::fmt(&self.aggregate_fn, f)
57 }
58}
59
60#[derive(Clone)]
80pub struct StatFn;
81
82impl ScalarFnVTable for StatFn {
83 type Options = StatOptions;
84
85 fn id(&self) -> ScalarFnId {
86 static ID: CachedId = CachedId::new("vortex.stat");
87 *ID
88 }
89
90 fn arity(&self, _options: &Self::Options) -> Arity {
91 Arity::Exact(1)
92 }
93
94 fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
95 match child_idx {
96 0 => ChildName::from("input"),
97 _ => unreachable!("Invalid child index {} for Stat expression", child_idx),
98 }
99 }
100
101 fn fmt_sql(
102 &self,
103 options: &Self::Options,
104 expr: &Expression,
105 f: &mut Formatter<'_>,
106 ) -> std::fmt::Result {
107 write!(f, "stat(")?;
108 expr.child(0).fmt_sql(f)?;
109 write!(f, ", {})", options.aggregate_fn())
110 }
111
112 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
113 stat_dtype(options.aggregate_fn(), &arg_dtypes[0])
114 }
115
116 fn execute(
117 &self,
118 options: &Self::Options,
119 args: &dyn ExecutionArgs,
120 _ctx: &mut ExecutionCtx,
121 ) -> VortexResult<ArrayRef> {
122 let input = args.get(0)?;
123 let dtype = stat_dtype(options.aggregate_fn(), input.dtype())?;
124 stat_array(&input, options.aggregate_fn(), dtype, args.row_count())
125 }
126}
127
128fn stat_dtype(aggregate_fn: &AggregateFnRef, input_dtype: &DType) -> VortexResult<DType> {
129 let Some(dtype) = aggregate_fn.state_dtype(input_dtype) else {
130 vortex_bail!(
131 "Aggregate function {} does not support input dtype {}",
132 aggregate_fn,
133 input_dtype
134 );
135 };
136 Ok(dtype.as_nullable())
137}
138
139fn stat_array(
140 array: &ArrayRef,
141 aggregate_fn: &AggregateFnRef,
142 dtype: DType,
143 len: usize,
144) -> VortexResult<ArrayRef> {
145 let value = if aggregate_fn.is::<AllNull>() {
146 let len = u64::try_from(len)?;
147 match array.statistics().get_as::<u64>(Stat::NullCount) {
148 Precision::Exact(count) => Some(count == len),
149 Precision::Inexact(count) => (count < len).then_some(false),
150 Precision::Absent => None,
151 }
152 .map(ScalarValue::Bool)
153 } else if aggregate_fn.is::<AllNonNull>() {
154 match array.statistics().get_as::<u64>(Stat::NullCount) {
155 Precision::Exact(count) => Some(count == 0),
156 Precision::Inexact(0) => Some(true),
157 Precision::Inexact(_) | Precision::Absent => None,
158 }
159 .map(ScalarValue::Bool)
160 } else if aggregate_fn.is::<AllNan>() {
161 let len = u64::try_from(len)?;
162 match array.statistics().get_as::<u64>(Stat::NaNCount) {
163 Precision::Exact(count) => Some(count == len),
164 Precision::Inexact(count) => (count < len).then_some(false),
165 Precision::Absent => None,
166 }
167 .map(ScalarValue::Bool)
168 } else if aggregate_fn.is::<AllNonNan>() {
169 match array.statistics().get_as::<u64>(Stat::NaNCount) {
170 Precision::Exact(count) => Some(count == 0),
171 Precision::Inexact(0) => Some(true),
172 Precision::Inexact(_) | Precision::Absent => None,
173 }
174 .map(ScalarValue::Bool)
175 } else if let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) {
176 array
177 .statistics()
178 .with_typed_stats_set(|stats| stats.get(stat))
179 .into_inner()
181 .and_then(Scalar::into_value)
182 } else {
183 tracing::trace!(
184 "No legacy Stat slot for aggregate {}; stat expression will resolve to null",
185 aggregate_fn
186 );
187 None
188 };
189
190 let scalar = Scalar::try_new(dtype, value)?;
191 Ok(ConstantArray::new(scalar, len).into_array())
192}