Skip to main content

vortex_array/aggregate_fn/
vtable.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt;
5use std::fmt::Debug;
6use std::fmt::Display;
7use std::fmt::Formatter;
8use std::hash::Hash;
9
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::Columnar;
18use crate::ExecutionCtx;
19use crate::aggregate_fn::AggregateFn;
20use crate::aggregate_fn::AggregateFnId;
21use crate::aggregate_fn::AggregateFnRef;
22use crate::aggregate_fn::AggregateFnSatisfaction;
23use crate::dtype::DType;
24use crate::scalar::Scalar;
25
26/// Defines the interface for aggregate function vtables.
27///
28/// This trait is non-object-safe and allows the implementer to make use of associated types
29/// for improved type safety, while allowing Vortex to enforce runtime checks on the inputs and
30/// outputs of each function.
31///
32/// The [`AggregateFnVTable`] trait should be implemented for a struct that holds global data across
33/// all instances of the aggregate. In almost all cases, this struct will be an empty unit
34/// struct, since most aggregates do not require any global state.
35pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync {
36    /// Options for this aggregate function.
37    type Options: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Eq + Hash;
38
39    /// The partial accumulator state for a single group.
40    type Partial: 'static + Send;
41
42    /// Returns the ID of the aggregate function vtable.
43    fn id(&self) -> AggregateFnId;
44
45    /// Serialize the options for this aggregate function.
46    ///
47    /// Should return `Ok(None)` if the function is not serializable, and `Ok(vec![])` if it is
48    /// serializable but has no metadata.
49    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
50        _ = options;
51        Ok(None)
52    }
53
54    /// Deserialize the options of this aggregate function.
55    fn deserialize(
56        &self,
57        _metadata: &[u8],
58        _session: &VortexSession,
59    ) -> VortexResult<Self::Options> {
60        vortex_bail!("Aggregate function {} is not deserializable", self.id());
61    }
62
63    /// Coerce the input type for this aggregate function.
64    ///
65    /// This is optionally used by Vortex users when performing type coercion over a Vortex
66    /// expression. The default implementation returns the input type unchanged.
67    fn coerce_args(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
68        let _ = options;
69        Ok(input_dtype.clone())
70    }
71
72    /// Return whether this stored aggregate can satisfy `requested`.
73    ///
74    /// The default implementation only treats exactly equal aggregate functions as satisfying the
75    /// request. Approximate pruning aggregates can override this to expose looser-but-sound bounds.
76    fn can_satisfy(
77        &self,
78        options: &Self::Options,
79        requested: &AggregateFnRef,
80    ) -> AggregateFnSatisfaction {
81        if requested
82            .as_opt::<Self>()
83            .is_some_and(|other| other == options)
84        {
85            AggregateFnSatisfaction::Exact
86        } else {
87            AggregateFnSatisfaction::No
88        }
89    }
90
91    /// The return [`DType`] of the aggregate.
92    ///
93    /// Returns `None` if the aggregate function cannot be applied to the input dtype.
94    fn return_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType>;
95
96    /// DType of the intermediate partial accumulator state.
97    ///
98    /// Use a struct dtype when multiple fields are needed
99    /// (e.g., Mean: `Struct { sum: f64, count: u64 }`).
100    ///
101    /// Returns `None` if the aggregate function cannot be applied to the input dtype.
102    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType>;
103
104    /// Return the partial accumulator state for an empty group.
105    fn empty_partial(
106        &self,
107        options: &Self::Options,
108        input_dtype: &DType,
109    ) -> VortexResult<Self::Partial>;
110
111    /// Combine partial scalar state into the accumulator.
112    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()>;
113
114    /// Convert the partial state into a partial scalar.
115    ///
116    /// The returned scalar must have the same DType as specified by `partial_dtype` for the
117    /// options and input dtype used to construct the state.
118    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar>;
119
120    /// Reset the state of the accumulator to an empty group.
121    fn reset(&self, partial: &mut Self::Partial);
122
123    /// Is the partial accumulator state is "saturated", i.e. has it reached a state where the
124    /// final result is fully determined.
125    fn is_saturated(&self, state: &Self::Partial) -> bool;
126
127    /// Try to accumulate the raw array before decompression.
128    ///
129    /// Returns `true` if the array was handled, `false` to fall through to
130    /// the default kernel dispatch and canonicalization path.
131    ///
132    /// This is useful for aggregates that only depend on array metadata (e.g., validity)
133    /// rather than the encoded data, avoiding unnecessary decompression.
134    fn try_accumulate(
135        &self,
136        _state: &mut Self::Partial,
137        _batch: &ArrayRef,
138        _ctx: &mut ExecutionCtx,
139    ) -> VortexResult<bool> {
140        Ok(false)
141    }
142
143    /// Accumulate a new canonical array into the accumulator state.
144    fn accumulate(
145        &self,
146        state: &mut Self::Partial,
147        batch: &Columnar,
148        ctx: &mut ExecutionCtx,
149    ) -> VortexResult<()>;
150
151    /// Finalize an array of accumulator states into an array of aggregate results.
152    ///
153    /// The provides `states` array has dtype as specified by `state_dtype`, the result array
154    /// must have dtype as specified by `return_dtype`.
155    fn finalize(&self, states: ArrayRef) -> VortexResult<ArrayRef>;
156
157    /// Finalize a scalar accumulator state into an aggregate result.
158    ///
159    /// The provided `state` has dtype as specified by `state_dtype`, the result scalar must have
160    /// dtype as specified by `return_dtype`.
161    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar>;
162}
163
/// Options type for aggregate functions that take no options.
///
/// Its [`Display`] output is the empty string.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EmptyOptions;

impl Display for EmptyOptions {
    /// Writes nothing: there are no options to render.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("")
    }
}
171
/// NaN-handling options for aggregate functions over primitive numeric inputs.
///
/// The single flag decides what happens to NaN values found in floating-point arrays:
///
/// * `skip_nans == true` (the default): NaN values are treated as missing. They contribute
///   nothing to `sum`/`min`/`max`/`mean` and are excluded from `count`.
/// * `skip_nans == false`: NaN values participate in the aggregate. `count` includes them,
///   while any NaN value poisons the result of `sum`/`min`/`max`/`mean` to NaN.
///
/// Non-float inputs are unaffected by this option.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct NumericalAggregateOpts {
    /// `true` to skip NaN values (treat them as missing) during aggregation.
    pub skip_nans: bool,
}
187
188impl NumericalAggregateOpts {
189    /// Options that skip NaN values, treating them as missing during aggregation.
190    ///
191    /// This is the default configuration; see [`NumericalAggregateOpts::include_nans`] for the
192    /// NaN-including variant.
193    pub const fn skip_nans() -> Self {
194        Self { skip_nans: true }
195    }
196
197    /// Options that include NaN values in the aggregate: `count` counts them, while any NaN
198    /// poisons the result of `sum`/`min`/`max`/`mean` to NaN.
199    ///
200    /// See [`NumericalAggregateOpts::skip_nans`] for the default NaN-skipping variant.
201    pub const fn include_nans() -> Self {
202        Self { skip_nans: false }
203    }
204
205    /// Serialize these options to protobuf-encoded metadata bytes.
206    pub fn serialize(&self) -> Vec<u8> {
207        pb::NumericalAggregateOpts {
208            skip_nans: self.skip_nans,
209        }
210        .encode_to_vec()
211    }
212
213    /// Deserialize these options from protobuf-encoded metadata bytes.
214    pub fn deserialize(metadata: &[u8]) -> VortexResult<Self> {
215        let opts = pb::NumericalAggregateOpts::decode(metadata)?;
216        Ok(Self {
217            skip_nans: opts.skip_nans,
218        })
219    }
220}
221
222impl Default for NumericalAggregateOpts {
223    fn default() -> Self {
224        Self::skip_nans()
225    }
226}
227
228impl Display for NumericalAggregateOpts {
229    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
230        // Only the non-default configuration is displayed, so that aggregates with default
231        // options render identically to their pre-options form, e.g. `vortex.sum()`.
232        if !self.skip_nans {
233            write!(f, "skip_nans=false")?;
234        }
235        Ok(())
236    }
237}
238
239/// Factory functions for aggregate vtables.
240pub trait AggregateFnVTableExt: AggregateFnVTable {
241    /// Bind this vtable with the given options into an [`AggregateFnRef`].
242    fn bind(&self, options: Self::Options) -> AggregateFnRef {
243        AggregateFn::new(self.clone(), options).erased()
244    }
245}
246impl<V: AggregateFnVTable> AggregateFnVTableExt for V {}