vortex_array/aggregate_fn/vtable.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt;
5use std::fmt::Debug;
6use std::fmt::Display;
7use std::fmt::Formatter;
8use std::hash::Hash;
9
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::Columnar;
18use crate::ExecutionCtx;
19use crate::aggregate_fn::AggregateFn;
20use crate::aggregate_fn::AggregateFnId;
21use crate::aggregate_fn::AggregateFnRef;
22use crate::aggregate_fn::AggregateFnSatisfaction;
23use crate::dtype::DType;
24use crate::scalar::Scalar;
25
26/// Defines the interface for aggregate function vtables.
27///
28/// This trait is non-object-safe and allows the implementer to make use of associated types
29/// for improved type safety, while allowing Vortex to enforce runtime checks on the inputs and
30/// outputs of each function.
31///
32/// The [`AggregateFnVTable`] trait should be implemented for a struct that holds global data across
33/// all instances of the aggregate. In almost all cases, this struct will be an empty unit
34/// struct, since most aggregates do not require any global state.
35pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync {
36 /// Options for this aggregate function.
37 type Options: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Eq + Hash;
38
39 /// The partial accumulator state for a single group.
40 type Partial: 'static + Send;
41
42 /// Returns the ID of the aggregate function vtable.
43 fn id(&self) -> AggregateFnId;
44
45 /// Serialize the options for this aggregate function.
46 ///
47 /// Should return `Ok(None)` if the function is not serializable, and `Ok(vec![])` if it is
48 /// serializable but has no metadata.
49 fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
50 _ = options;
51 Ok(None)
52 }
53
54 /// Deserialize the options of this aggregate function.
55 fn deserialize(
56 &self,
57 _metadata: &[u8],
58 _session: &VortexSession,
59 ) -> VortexResult<Self::Options> {
60 vortex_bail!("Aggregate function {} is not deserializable", self.id());
61 }
62
63 /// Coerce the input type for this aggregate function.
64 ///
65 /// This is optionally used by Vortex users when performing type coercion over a Vortex
66 /// expression. The default implementation returns the input type unchanged.
67 fn coerce_args(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
68 let _ = options;
69 Ok(input_dtype.clone())
70 }
71
72 /// Return whether this stored aggregate can satisfy `requested`.
73 ///
74 /// The default implementation only treats exactly equal aggregate functions as satisfying the
75 /// request. Approximate pruning aggregates can override this to expose looser-but-sound bounds.
76 fn can_satisfy(
77 &self,
78 options: &Self::Options,
79 requested: &AggregateFnRef,
80 ) -> AggregateFnSatisfaction {
81 if requested
82 .as_opt::<Self>()
83 .is_some_and(|other| other == options)
84 {
85 AggregateFnSatisfaction::Exact
86 } else {
87 AggregateFnSatisfaction::No
88 }
89 }
90
91 /// The return [`DType`] of the aggregate.
92 ///
93 /// Returns `None` if the aggregate function cannot be applied to the input dtype.
94 fn return_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType>;
95
96 /// DType of the intermediate partial accumulator state.
97 ///
98 /// Use a struct dtype when multiple fields are needed
99 /// (e.g., Mean: `Struct { sum: f64, count: u64 }`).
100 ///
101 /// Returns `None` if the aggregate function cannot be applied to the input dtype.
102 fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType>;
103
104 /// Return the partial accumulator state for an empty group.
105 fn empty_partial(
106 &self,
107 options: &Self::Options,
108 input_dtype: &DType,
109 ) -> VortexResult<Self::Partial>;
110
111 /// Combine partial scalar state into the accumulator.
112 fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()>;
113
114 /// Convert the partial state into a partial scalar.
115 ///
116 /// The returned scalar must have the same DType as specified by `partial_dtype` for the
117 /// options and input dtype used to construct the state.
118 fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar>;
119
120 /// Reset the state of the accumulator to an empty group.
121 fn reset(&self, partial: &mut Self::Partial);
122
123 /// Is the partial accumulator state is "saturated", i.e. has it reached a state where the
124 /// final result is fully determined.
125 fn is_saturated(&self, state: &Self::Partial) -> bool;
126
127 /// Try to accumulate the raw array before decompression.
128 ///
129 /// Returns `true` if the array was handled, `false` to fall through to
130 /// the default kernel dispatch and canonicalization path.
131 ///
132 /// This is useful for aggregates that only depend on array metadata (e.g., validity)
133 /// rather than the encoded data, avoiding unnecessary decompression.
134 fn try_accumulate(
135 &self,
136 _state: &mut Self::Partial,
137 _batch: &ArrayRef,
138 _ctx: &mut ExecutionCtx,
139 ) -> VortexResult<bool> {
140 Ok(false)
141 }
142
143 /// Accumulate a new canonical array into the accumulator state.
144 fn accumulate(
145 &self,
146 state: &mut Self::Partial,
147 batch: &Columnar,
148 ctx: &mut ExecutionCtx,
149 ) -> VortexResult<()>;
150
151 /// Finalize an array of accumulator states into an array of aggregate results.
152 ///
153 /// The provides `states` array has dtype as specified by `state_dtype`, the result array
154 /// must have dtype as specified by `return_dtype`.
155 fn finalize(&self, states: ArrayRef) -> VortexResult<ArrayRef>;
156
157 /// Finalize a scalar accumulator state into an aggregate result.
158 ///
159 /// The provided `state` has dtype as specified by `state_dtype`, the result scalar must have
160 /// dtype as specified by `return_dtype`.
161 fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar>;
162}
163
/// Options type for aggregate functions that take no options.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EmptyOptions;

impl Display for EmptyOptions {
    /// Renders as the empty string, since there are no options to show.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("")
    }
}
171
/// Options shared by aggregate functions over primitive numeric inputs. They control how NaN
/// values found in floating-point arrays are treated.
///
/// With `skip_nans` set to `true` (the default), NaN values are treated as missing: they
/// contribute nothing to `sum`/`min`/`max`/`mean` and are excluded from `count`.
///
/// With `skip_nans` set to `false`, NaN values participate in the aggregate: `count` includes
/// them, while any NaN value poisons the result of `sum`/`min`/`max`/`mean` to NaN.
///
/// For non-float inputs the option has no effect.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NumericalAggregateOpts {
    /// `true` when NaN values are skipped (treated as missing) during aggregation.
    pub skip_nans: bool,
}
187
188impl NumericalAggregateOpts {
189 /// Options that skip NaN values, treating them as missing during aggregation.
190 ///
191 /// This is the default configuration; see [`NumericalAggregateOpts::include_nans`] for the
192 /// NaN-including variant.
193 pub const fn skip_nans() -> Self {
194 Self { skip_nans: true }
195 }
196
197 /// Options that include NaN values in the aggregate: `count` counts them, while any NaN
198 /// poisons the result of `sum`/`min`/`max`/`mean` to NaN.
199 ///
200 /// See [`NumericalAggregateOpts::skip_nans`] for the default NaN-skipping variant.
201 pub const fn include_nans() -> Self {
202 Self { skip_nans: false }
203 }
204
205 /// Serialize these options to protobuf-encoded metadata bytes.
206 pub fn serialize(&self) -> Vec<u8> {
207 pb::NumericalAggregateOpts {
208 skip_nans: self.skip_nans,
209 }
210 .encode_to_vec()
211 }
212
213 /// Deserialize these options from protobuf-encoded metadata bytes.
214 pub fn deserialize(metadata: &[u8]) -> VortexResult<Self> {
215 let opts = pb::NumericalAggregateOpts::decode(metadata)?;
216 Ok(Self {
217 skip_nans: opts.skip_nans,
218 })
219 }
220}
221
222impl Default for NumericalAggregateOpts {
223 fn default() -> Self {
224 Self::skip_nans()
225 }
226}
227
228impl Display for NumericalAggregateOpts {
229 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
230 // Only the non-default configuration is displayed, so that aggregates with default
231 // options render identically to their pre-options form, e.g. `vortex.sum()`.
232 if !self.skip_nans {
233 write!(f, "skip_nans=false")?;
234 }
235 Ok(())
236 }
237}
238
239/// Factory functions for aggregate vtables.
240pub trait AggregateFnVTableExt: AggregateFnVTable {
241 /// Bind this vtable with the given options into an [`AggregateFnRef`].
242 fn bind(&self, options: Self::Options) -> AggregateFnRef {
243 AggregateFn::new(self.clone(), options).erased()
244 }
245}
246impl<V: AggregateFnVTable> AggregateFnVTableExt for V {}