Skip to main content

vortex_array/aggregate_fn/
vtable.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt;
5use std::fmt::Debug;
6use std::fmt::Display;
7use std::fmt::Formatter;
8use std::hash::Hash;
9
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_session::VortexSession;
13
14use crate::ArrayRef;
15use crate::Columnar;
16use crate::ExecutionCtx;
17use crate::aggregate_fn::AggregateFn;
18use crate::aggregate_fn::AggregateFnId;
19use crate::aggregate_fn::AggregateFnRef;
20use crate::aggregate_fn::AggregateFnSatisfaction;
21use crate::dtype::DType;
22use crate::scalar::Scalar;
23
24/// Defines the interface for aggregate function vtables.
25///
26/// This trait is non-object-safe and allows the implementer to make use of associated types
27/// for improved type safety, while allowing Vortex to enforce runtime checks on the inputs and
28/// outputs of each function.
29///
30/// The [`AggregateFnVTable`] trait should be implemented for a struct that holds global data across
31/// all instances of the aggregate. In almost all cases, this struct will be an empty unit
32/// struct, since most aggregates do not require any global state.
33pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync {
34    /// Options for this aggregate function.
35    type Options: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Eq + Hash;
36
37    /// The partial accumulator state for a single group.
38    type Partial: 'static + Send;
39
40    /// Returns the ID of the aggregate function vtable.
41    fn id(&self) -> AggregateFnId;
42
43    /// Serialize the options for this aggregate function.
44    ///
45    /// Should return `Ok(None)` if the function is not serializable, and `Ok(vec![])` if it is
46    /// serializable but has no metadata.
47    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
48        _ = options;
49        Ok(None)
50    }
51
52    /// Deserialize the options of this aggregate function.
53    fn deserialize(
54        &self,
55        _metadata: &[u8],
56        _session: &VortexSession,
57    ) -> VortexResult<Self::Options> {
58        vortex_bail!("Aggregate function {} is not deserializable", self.id());
59    }
60
61    /// Coerce the input type for this aggregate function.
62    ///
63    /// This is optionally used by Vortex users when performing type coercion over a Vortex
64    /// expression. The default implementation returns the input type unchanged.
65    fn coerce_args(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
66        let _ = options;
67        Ok(input_dtype.clone())
68    }
69
70    /// Return whether this stored aggregate can satisfy `requested`.
71    ///
72    /// The default implementation only treats exactly equal aggregate functions as satisfying the
73    /// request. Approximate pruning aggregates can override this to expose looser-but-sound bounds.
74    fn can_satisfy(
75        &self,
76        options: &Self::Options,
77        requested: &AggregateFnRef,
78    ) -> AggregateFnSatisfaction {
79        if requested
80            .as_opt::<Self>()
81            .is_some_and(|other| other == options)
82        {
83            AggregateFnSatisfaction::Exact
84        } else {
85            AggregateFnSatisfaction::No
86        }
87    }
88
89    /// The return [`DType`] of the aggregate.
90    ///
91    /// Returns `None` if the aggregate function cannot be applied to the input dtype.
92    fn return_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType>;
93
94    /// DType of the intermediate partial accumulator state.
95    ///
96    /// Use a struct dtype when multiple fields are needed
97    /// (e.g., Mean: `Struct { sum: f64, count: u64 }`).
98    ///
99    /// Returns `None` if the aggregate function cannot be applied to the input dtype.
100    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType>;
101
102    /// Return the partial accumulator state for an empty group.
103    fn empty_partial(
104        &self,
105        options: &Self::Options,
106        input_dtype: &DType,
107    ) -> VortexResult<Self::Partial>;
108
109    /// Combine partial scalar state into the accumulator.
110    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()>;
111
112    /// Convert the partial state into a partial scalar.
113    ///
114    /// The returned scalar must have the same DType as specified by `partial_dtype` for the
115    /// options and input dtype used to construct the state.
116    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar>;
117
118    /// Reset the state of the accumulator to an empty group.
119    fn reset(&self, partial: &mut Self::Partial);
120
121    /// Is the partial accumulator state is "saturated", i.e. has it reached a state where the
122    /// final result is fully determined.
123    fn is_saturated(&self, state: &Self::Partial) -> bool;
124
125    /// Try to accumulate the raw array before decompression.
126    ///
127    /// Returns `true` if the array was handled, `false` to fall through to
128    /// the default kernel dispatch and canonicalization path.
129    ///
130    /// This is useful for aggregates that only depend on array metadata (e.g., validity)
131    /// rather than the encoded data, avoiding unnecessary decompression.
132    fn try_accumulate(
133        &self,
134        _state: &mut Self::Partial,
135        _batch: &ArrayRef,
136        _ctx: &mut ExecutionCtx,
137    ) -> VortexResult<bool> {
138        Ok(false)
139    }
140
141    /// Accumulate a new canonical array into the accumulator state.
142    fn accumulate(
143        &self,
144        state: &mut Self::Partial,
145        batch: &Columnar,
146        ctx: &mut ExecutionCtx,
147    ) -> VortexResult<()>;
148
149    /// Finalize an array of accumulator states into an array of aggregate results.
150    ///
151    /// The provides `states` array has dtype as specified by `state_dtype`, the result array
152    /// must have dtype as specified by `return_dtype`.
153    fn finalize(&self, states: ArrayRef) -> VortexResult<ArrayRef>;
154
155    /// Finalize a scalar accumulator state into an aggregate result.
156    ///
157    /// The provided `state` has dtype as specified by `state_dtype`, the result scalar must have
158    /// dtype as specified by `return_dtype`.
159    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar>;
160}
161
/// Options placeholder for aggregate functions that take no configuration.
///
/// It derives and implements every trait required of [`AggregateFnVTable::Options`]. When
/// displayed, it renders as the empty string.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EmptyOptions;

impl Display for EmptyOptions {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Nothing to show: emit an empty string directly instead of going through `write!`.
        f.write_str("")
    }
}
169
170/// Factory functions for aggregate vtables.
171pub trait AggregateFnVTableExt: AggregateFnVTable {
172    /// Bind this vtable with the given options into an [`AggregateFnRef`].
173    fn bind(&self, options: Self::Options) -> AggregateFnRef {
174        AggregateFn::new(self.clone(), options).erased()
175    }
176}
177impl<V: AggregateFnVTable> AggregateFnVTableExt for V {}