vortex_array/aggregate_fn/vtable.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt;
5use std::fmt::Debug;
6use std::fmt::Display;
7use std::fmt::Formatter;
8use std::hash::Hash;
9
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_session::VortexSession;
13
14use crate::ArrayRef;
15use crate::Columnar;
16use crate::ExecutionCtx;
17use crate::IntoArray;
18use crate::aggregate_fn::AggregateFn;
19use crate::aggregate_fn::AggregateFnId;
20use crate::aggregate_fn::AggregateFnRef;
21use crate::arrays::ConstantArray;
22use crate::dtype::DType;
23use crate::scalar::Scalar;
24
25/// Defines the interface for aggregate function vtables.
26///
27/// This trait is non-object-safe and allows the implementer to make use of associated types
28/// for improved type safety, while allowing Vortex to enforce runtime checks on the inputs and
29/// outputs of each function.
30///
31/// The [`AggregateFnVTable`] trait should be implemented for a struct that holds global data across
32/// all instances of the aggregate. In almost all cases, this struct will be an empty unit
33/// struct, since most aggregates do not require any global state.
34pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync {
35 /// Options for this aggregate function.
36 type Options: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Eq + Hash;
37
38 /// The partial accumulator state for a single group.
39 type Partial: 'static + Send;
40
41 /// Returns the ID of the aggregate function vtable.
42 fn id(&self) -> AggregateFnId;
43
44 /// Serialize the options for this aggregate function.
45 ///
46 /// Should return `Ok(None)` if the function is not serializable, and `Ok(vec![])` if it is
47 /// serializable but has no metadata.
48 fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
49 _ = options;
50 Ok(None)
51 }
52
53 /// Deserialize the options of this aggregate function.
54 fn deserialize(
55 &self,
56 _metadata: &[u8],
57 _session: &VortexSession,
58 ) -> VortexResult<Self::Options> {
59 vortex_bail!("Aggregate function {} is not deserializable", self.id());
60 }
61
62 /// Coerce the input type for this aggregate function.
63 ///
64 /// This is optionally used by Vortex users when performing type coercion over a Vortex
65 /// expression. The default implementation returns the input type unchanged.
66 fn coerce_args(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
67 let _ = options;
68 Ok(input_dtype.clone())
69 }
70
71 /// The return [`DType`] of the aggregate.
72 ///
73 /// Returns `None` if the aggregate function cannot be applied to the input dtype.
74 fn return_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType>;
75
76 /// DType of the intermediate partial accumulator state.
77 ///
78 /// Use a struct dtype when multiple fields are needed
79 /// (e.g., Mean: `Struct { sum: f64, count: u64 }`).
80 ///
81 /// Returns `None` if the aggregate function cannot be applied to the input dtype.
82 fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType>;
83
84 /// Return the partial accumulator state for an empty group.
85 fn empty_partial(
86 &self,
87 options: &Self::Options,
88 input_dtype: &DType,
89 ) -> VortexResult<Self::Partial>;
90
91 /// Combine partial scalar state into the accumulator.
92 fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()>;
93
94 /// Convert the partial state into a partial scalar.
95 ///
96 /// The returned scalar must have the same DType as specified by `partial_dtype` for the
97 /// options and input dtype used to construct the state.
98 fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar>;
99
100 /// Reset the state of the accumulator to an empty group.
101 fn reset(&self, partial: &mut Self::Partial);
102
103 /// Is the partial accumulator state is "saturated", i.e. has it reached a state where the
104 /// final result is fully determined.
105 fn is_saturated(&self, state: &Self::Partial) -> bool;
106
107 /// Try to accumulate the raw array before decompression.
108 ///
109 /// Returns `true` if the array was handled, `false` to fall through to
110 /// the default kernel dispatch and canonicalization path.
111 ///
112 /// This is useful for aggregates that only depend on array metadata (e.g., validity)
113 /// rather than the encoded data, avoiding unnecessary decompression.
114 fn try_accumulate(
115 &self,
116 _state: &mut Self::Partial,
117 _batch: &ArrayRef,
118 _ctx: &mut ExecutionCtx,
119 ) -> VortexResult<bool> {
120 Ok(false)
121 }
122
123 /// Accumulate a new canonical array into the accumulator state.
124 fn accumulate(
125 &self,
126 state: &mut Self::Partial,
127 batch: &Columnar,
128 ctx: &mut ExecutionCtx,
129 ) -> VortexResult<()>;
130
131 /// Finalize an array of accumulator states into an array of aggregate results.
132 ///
133 /// The provides `states` array has dtype as specified by `state_dtype`, the result array
134 /// must have dtype as specified by `return_dtype`.
135 fn finalize(&self, states: ArrayRef) -> VortexResult<ArrayRef>;
136
137 /// Finalize a scalar accumulator state into an aggregate result.
138 ///
139 /// The provided `state` has dtype as specified by `state_dtype`, the result scalar must have
140 /// dtype as specified by `return_dtype`.
141 fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
142 let scalar = self.to_scalar(partial)?;
143 let array = ConstantArray::new(scalar, 1).into_array();
144 let result = self.finalize(array)?;
145 result.scalar_at(0)
146 }
147}
148
/// A field-less unit type, usable as [`AggregateFnVTable::Options`] for aggregates that need no
/// configuration.
///
/// It displays as the empty string.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EmptyOptions;

impl Display for EmptyOptions {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // There are no options to render.
        f.write_str("")
    }
}
156
157/// Factory functions for aggregate vtables.
158pub trait AggregateFnVTableExt: AggregateFnVTable {
159 /// Bind this vtable with the given options into an [`AggregateFnRef`].
160 fn bind(&self, options: Self::Options) -> AggregateFnRef {
161 AggregateFn::new(self.clone(), options).erased()
162 }
163}
164impl<V: AggregateFnVTable> AggregateFnVTableExt for V {}