vortex_array/scalar_fn/fns/
stat.rs1use std::fmt::Display;
7use std::fmt::Formatter;
8
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::aggregate_fn::AggregateFnRef;
16use crate::aggregate_fn::fns::all_nan::AllNan;
17use crate::aggregate_fn::fns::all_non_nan::AllNonNan;
18use crate::aggregate_fn::fns::all_non_null::AllNonNull;
19use crate::aggregate_fn::fns::all_null::AllNull;
20use crate::arrays::ConstantArray;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::stats::Precision;
24use crate::expr::stats::Stat;
25use crate::expr::stats::StatsProvider;
26use crate::expr::stats::StatsProviderExt;
27use crate::scalar::Scalar;
28use crate::scalar::ScalarValue;
29use crate::scalar_fn::Arity;
30use crate::scalar_fn::ChildName;
31use crate::scalar_fn::ExecutionArgs;
32use crate::scalar_fn::ScalarFnId;
33use crate::scalar_fn::ScalarFnVTable;
34
/// Options for [`StatFn`]: the aggregate function whose statistic the
/// expression looks up on its input.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct StatOptions {
    /// Aggregate function identifying which statistic to resolve.
    aggregate_fn: AggregateFnRef,
}
40
41impl StatOptions {
42 pub fn new(aggregate_fn: AggregateFnRef) -> Self {
44 Self { aggregate_fn }
45 }
46
47 pub fn aggregate_fn(&self) -> &AggregateFnRef {
49 &self.aggregate_fn
50 }
51}
52
53impl Display for StatOptions {
54 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
55 Display::fmt(&self.aggregate_fn, f)
56 }
57}
58
/// Scalar function registered as `vortex.stat`.
///
/// Takes a single child (`input`) and evaluates to a constant array holding
/// the statistic selected by [`StatOptions`], resolved through the input
/// array's `statistics()` accessor. When the statistic cannot be resolved the
/// result is a null constant.
#[derive(Clone)]
pub struct StatFn;
80
81impl ScalarFnVTable for StatFn {
82 type Options = StatOptions;
83
84 fn id(&self) -> ScalarFnId {
85 ScalarFnId::new("vortex.stat")
86 }
87
88 fn arity(&self, _options: &Self::Options) -> Arity {
89 Arity::Exact(1)
90 }
91
92 fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
93 match child_idx {
94 0 => ChildName::from("input"),
95 _ => unreachable!("Invalid child index {} for Stat expression", child_idx),
96 }
97 }
98
99 fn fmt_sql(
100 &self,
101 options: &Self::Options,
102 expr: &Expression,
103 f: &mut Formatter<'_>,
104 ) -> std::fmt::Result {
105 write!(f, "stat(")?;
106 expr.child(0).fmt_sql(f)?;
107 write!(f, ", {})", options.aggregate_fn())
108 }
109
110 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
111 stat_dtype(options.aggregate_fn(), &arg_dtypes[0])
112 }
113
114 fn execute(
115 &self,
116 options: &Self::Options,
117 args: &dyn ExecutionArgs,
118 _ctx: &mut ExecutionCtx,
119 ) -> VortexResult<ArrayRef> {
120 let input = args.get(0)?;
121 let dtype = stat_dtype(options.aggregate_fn(), input.dtype())?;
122 stat_array(&input, options.aggregate_fn(), dtype, args.row_count())
123 }
124}
125
126fn stat_dtype(aggregate_fn: &AggregateFnRef, input_dtype: &DType) -> VortexResult<DType> {
127 let Some(dtype) = aggregate_fn.state_dtype(input_dtype) else {
128 vortex_bail!(
129 "Aggregate function {} does not support input dtype {}",
130 aggregate_fn,
131 input_dtype
132 );
133 };
134 Ok(dtype.as_nullable())
135}
136
137fn stat_array(
138 array: &ArrayRef,
139 aggregate_fn: &AggregateFnRef,
140 dtype: DType,
141 len: usize,
142) -> VortexResult<ArrayRef> {
143 let value = if aggregate_fn.is::<AllNull>() {
144 let len = u64::try_from(len)?;
145 match array.statistics().get_as::<u64>(Stat::NullCount) {
146 Precision::Exact(count) => Some(count == len),
147 Precision::Inexact(count) => (count < len).then_some(false),
148 Precision::Absent => None,
149 }
150 .map(ScalarValue::Bool)
151 } else if aggregate_fn.is::<AllNonNull>() {
152 match array.statistics().get_as::<u64>(Stat::NullCount) {
153 Precision::Exact(count) => Some(count == 0),
154 Precision::Inexact(0) => Some(true),
155 Precision::Inexact(_) | Precision::Absent => None,
156 }
157 .map(ScalarValue::Bool)
158 } else if aggregate_fn.is::<AllNan>() {
159 let len = u64::try_from(len)?;
160 match array.statistics().get_as::<u64>(Stat::NaNCount) {
161 Precision::Exact(count) => Some(count == len),
162 Precision::Inexact(count) => (count < len).then_some(false),
163 Precision::Absent => None,
164 }
165 .map(ScalarValue::Bool)
166 } else if aggregate_fn.is::<AllNonNan>() {
167 match array.statistics().get_as::<u64>(Stat::NaNCount) {
168 Precision::Exact(count) => Some(count == 0),
169 Precision::Inexact(0) => Some(true),
170 Precision::Inexact(_) | Precision::Absent => None,
171 }
172 .map(ScalarValue::Bool)
173 } else if let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) {
174 array
175 .statistics()
176 .with_typed_stats_set(|stats| stats.get(stat))
177 .into_inner()
179 .and_then(Scalar::into_value)
180 } else {
181 tracing::trace!(
182 "No legacy Stat slot for aggregate {}; stat expression will resolve to null",
183 aggregate_fn
184 );
185 None
186 };
187
188 let scalar = Scalar::try_new(dtype, value)?;
189 Ok(ConstantArray::new(scalar, len).into_array())
190}