Skip to main content

vortex_array/aggregate_fn/
vtable.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt;
5use std::fmt::Debug;
6use std::fmt::Display;
7use std::fmt::Formatter;
8use std::hash::Hash;
9
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_session::VortexSession;
13
14use crate::ArrayRef;
15use crate::Columnar;
16use crate::DynArray;
17use crate::ExecutionCtx;
18use crate::IntoArray;
19use crate::aggregate_fn::AggregateFn;
20use crate::aggregate_fn::AggregateFnId;
21use crate::aggregate_fn::AggregateFnRef;
22use crate::arrays::ConstantArray;
23use crate::dtype::DType;
24use crate::scalar::Scalar;
25
26/// Defines the interface for aggregate function vtables.
27///
28/// This trait is non-object-safe and allows the implementer to make use of associated types
29/// for improved type safety, while allowing Vortex to enforce runtime checks on the inputs and
30/// outputs of each function.
31///
32/// The [`AggregateFnVTable`] trait should be implemented for a struct that holds global data across
33/// all instances of the aggregate. In almost all cases, this struct will be an empty unit
34/// struct, since most aggregates do not require any global state.
35pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync {
36    /// Options for this aggregate function.
37    type Options: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Eq + Hash;
38
39    /// The partial accumulator state for a single group.
40    type Partial: 'static + Send;
41
42    /// Returns the ID of the aggregate function vtable.
43    fn id(&self) -> AggregateFnId;
44
45    /// Serialize the options for this aggregate function.
46    ///
47    /// Should return `Ok(None)` if the function is not serializable, and `Ok(vec![])` if it is
48    /// serializable but has no metadata.
49    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
50        _ = options;
51        Ok(None)
52    }
53
54    /// Deserialize the options of this aggregate function.
55    fn deserialize(
56        &self,
57        _metadata: &[u8],
58        _session: &VortexSession,
59    ) -> VortexResult<Self::Options> {
60        vortex_bail!("Aggregate function {} is not deserializable", self.id());
61    }
62
63    /// Coerce the input type for this aggregate function.
64    ///
65    /// This is optionally used by Vortex users when performing type coercion over a Vortex
66    /// expression. The default implementation returns the input type unchanged.
67    fn coerce_args(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
68        let _ = options;
69        Ok(input_dtype.clone())
70    }
71
72    /// The return [`DType`] of the aggregate.
73    fn return_dtype(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType>;
74
75    /// DType of the intermediate partial accumulator state.
76    ///
77    /// Use a struct dtype when multiple fields are needed
78    /// (e.g., Mean: `Struct { sum: f64, count: u64 }`).
79    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType>;
80
81    /// Return the partial accumulator state for an empty group.
82    fn empty_partial(
83        &self,
84        options: &Self::Options,
85        input_dtype: &DType,
86    ) -> VortexResult<Self::Partial>;
87
88    /// Combine partial scalar state into the accumulator.
89    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()>;
90
91    /// Flush the partial aggregate for the given accumulator state.
92    ///
93    /// The returned scalar must have the same DType as specified by `state_dtype` for the
94    /// options and input dtype used to construct the state.
95    ///
96    /// The internal state of the accumulator is reset to the empty state after flushing.
97    fn flush(&self, partial: &mut Self::Partial) -> VortexResult<Scalar>;
98
99    /// Is the partial accumulator state is "saturated", i.e. has it reached a state where the
100    /// final result is fully determined.
101    fn is_saturated(&self, state: &Self::Partial) -> bool;
102
103    /// Accumulate a new canonical array into the accumulator state.
104    fn accumulate(
105        &self,
106        state: &mut Self::Partial,
107        batch: &Columnar,
108        ctx: &mut ExecutionCtx,
109    ) -> VortexResult<()>;
110
111    /// Finalize an array of accumulator states into an array of aggregate results.
112    ///
113    /// The provides `states` array has dtype as specified by `state_dtype`, the result array
114    /// must have dtype as specified by `return_dtype`.
115    fn finalize(&self, states: ArrayRef) -> VortexResult<ArrayRef>;
116
117    /// Finalize a scalar accumulator state into an aggregate result.
118    ///
119    /// The provided `state` has dtype as specified by `state_dtype`, the result scalar must have
120    /// dtype as specified by `return_dtype`.
121    fn finalize_scalar(&self, state: Scalar) -> VortexResult<Scalar> {
122        let array = ConstantArray::new(state, 1).into_array();
123        let result = self.finalize(array)?;
124        result.scalar_at(0)
125    }
126}
127
/// Unit options type with no fields.
///
/// Its [`Display`] output is the empty string.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EmptyOptions;

impl fmt::Display for EmptyOptions {
    /// Writes nothing to the formatter: there are no options to render.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("")
    }
}
135
136/// Factory functions for aggregate vtables.
137pub trait AggregateFnVTableExt: AggregateFnVTable {
138    /// Bind this vtable with the given options into an [`AggregateFnRef`].
139    fn bind(&self, options: Self::Options) -> AggregateFnRef {
140        AggregateFn::new(self.clone(), options).erased()
141    }
142}
143impl<V: AggregateFnVTable> AggregateFnVTableExt for V {}