vortex_array/scalar_fn/fns/
stat.rs1use std::fmt::Display;
7use std::fmt::Formatter;
8
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::aggregate_fn::AggregateFnRef;
16use crate::arrays::ConstantArray;
17use crate::dtype::DType;
18use crate::expr::Expression;
19use crate::expr::stats::Stat;
20use crate::expr::stats::StatsProvider;
21use crate::scalar::Scalar;
22use crate::scalar_fn::Arity;
23use crate::scalar_fn::ChildName;
24use crate::scalar_fn::ExecutionArgs;
25use crate::scalar_fn::ScalarFnId;
26use crate::scalar_fn::ScalarFnVTable;
27
28#[derive(Clone, Debug, PartialEq, Eq, Hash)]
30pub struct StatOptions {
31 aggregate_fn: AggregateFnRef,
32}
33
34impl StatOptions {
35 pub fn new(aggregate_fn: AggregateFnRef) -> Self {
37 Self { aggregate_fn }
38 }
39
40 pub fn aggregate_fn(&self) -> &AggregateFnRef {
42 &self.aggregate_fn
43 }
44}
45
46impl Display for StatOptions {
47 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
48 Display::fmt(&self.aggregate_fn, f)
49 }
50}
51
/// Scalar function registered under the id `vortex.stat`.
///
/// It takes a single child named `input` and is configured through [`StatOptions`].
/// On execution it reads the requested statistic from the input array's statistics
/// and returns it as a constant array with one entry per row.
#[derive(Clone)]
pub struct StatFn;
73
74impl ScalarFnVTable for StatFn {
75 type Options = StatOptions;
76
77 fn id(&self) -> ScalarFnId {
78 ScalarFnId::new("vortex.stat")
79 }
80
81 fn arity(&self, _options: &Self::Options) -> Arity {
82 Arity::Exact(1)
83 }
84
85 fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
86 match child_idx {
87 0 => ChildName::from("input"),
88 _ => unreachable!("Invalid child index {} for Stat expression", child_idx),
89 }
90 }
91
92 fn fmt_sql(
93 &self,
94 options: &Self::Options,
95 expr: &Expression,
96 f: &mut Formatter<'_>,
97 ) -> std::fmt::Result {
98 write!(f, "stat(")?;
99 expr.child(0).fmt_sql(f)?;
100 write!(f, ", {})", options.aggregate_fn())
101 }
102
103 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
104 stat_dtype(options.aggregate_fn(), &arg_dtypes[0])
105 }
106
107 fn execute(
108 &self,
109 options: &Self::Options,
110 args: &dyn ExecutionArgs,
111 _ctx: &mut ExecutionCtx,
112 ) -> VortexResult<ArrayRef> {
113 let input = args.get(0)?;
114 let dtype = stat_dtype(options.aggregate_fn(), input.dtype())?;
115 stat_array(&input, options.aggregate_fn(), dtype, args.row_count())
116 }
117}
118
119fn stat_dtype(aggregate_fn: &AggregateFnRef, input_dtype: &DType) -> VortexResult<DType> {
120 let Some(dtype) = aggregate_fn.return_dtype(input_dtype) else {
121 vortex_bail!(
122 "Aggregate function {} does not support input dtype {}",
123 aggregate_fn,
124 input_dtype
125 );
126 };
127 Ok(dtype.as_nullable())
128}
129
130fn stat_array(
131 array: &ArrayRef,
132 aggregate_fn: &AggregateFnRef,
133 dtype: DType,
134 len: usize,
135) -> VortexResult<ArrayRef> {
136 let value = if let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) {
137 array
138 .statistics()
139 .with_typed_stats_set(|stats| stats.get(stat))
140 .map(|stat| stat.into_inner())
142 .and_then(Scalar::into_value)
143 } else {
144 tracing::trace!(
145 "No legacy Stat slot for aggregate {}; stat expression will resolve to null",
146 aggregate_fn
147 );
148 None
149 };
150
151 let scalar = Scalar::try_new(dtype, value)?;
152 Ok(ConstantArray::new(scalar, len).into_array())
153}