Skip to main content

vortex_array/aggregate_fn/
combined.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Generic adapter for aggregates whose result is computed from two child
5//! aggregate functions, e.g. `Mean = Sum / Count`.
6
7use std::fmt::Debug;
8use std::fmt::Display;
9use std::fmt::Formatter;
10use std::fmt::{self};
11use std::hash::Hash;
12
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_err;
16use vortex_session::VortexSession;
17
18use crate::ArrayRef;
19use crate::Columnar;
20use crate::ExecutionCtx;
21use crate::aggregate_fn::Accumulator;
22use crate::aggregate_fn::AccumulatorRef;
23use crate::aggregate_fn::AggregateFnId;
24use crate::aggregate_fn::AggregateFnVTable;
25use crate::builtins::ArrayBuiltins;
26use crate::dtype::DType;
27use crate::dtype::FieldName;
28use crate::dtype::FieldNames;
29use crate::dtype::Nullability;
30use crate::dtype::StructFields;
31use crate::scalar::Scalar;
32
/// Options for the two children of a [`BinaryCombined`] aggregate, left child first.
///
/// This is a dedicated newtype rather than a bare `(L, R)` tuple because the
/// [`AggregateFnVTable::Options`] bound requires `Display`, which tuples do not implement.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct PairOptions<L, R>(pub L, pub R);

impl<L: Display, R: Display> Display for PairOptions<L, R> {
    /// Renders as `(left, right)`, formatting each side with its own `Display` impl.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self(left, right) = self;
        write!(f, "({left}, {right})")
    }
}
45
46// Convenience aliases so signatures stay readable.
47type LeftOptions<T> = <<T as BinaryCombined>::Left as AggregateFnVTable>::Options;
48type RightOptions<T> = <<T as BinaryCombined>::Right as AggregateFnVTable>::Options;
49/// Combined options for a [`BinaryCombined`] aggregate.
50pub type CombinedOptions<T> = PairOptions<LeftOptions<T>, RightOptions<T>>;
51
52/// Declare an aggregate function in terms of two child aggregates.
53pub trait BinaryCombined: 'static + Send + Sync + Clone {
54    /// The left child aggregate vtable.
55    type Left: AggregateFnVTable;
56    /// The right child aggregate vtable.
57    type Right: AggregateFnVTable;
58
59    /// Stable identifier for the combined aggregate.
60    fn id(&self) -> AggregateFnId;
61
62    /// Construct the left child vtable.
63    fn left(&self) -> Self::Left;
64
65    /// Construct the right child vtable.
66    fn right(&self) -> Self::Right;
67
68    /// Field name for the left child in the partial struct dtype.
69    fn left_name(&self) -> &'static str {
70        "left"
71    }
72
73    /// Field name for the right child in the partial struct dtype.
74    fn right_name(&self) -> &'static str {
75        "right"
76    }
77
78    /// Return type of the combined aggregate.
79    fn return_dtype(&self, input_dtype: &DType) -> Option<DType>;
80
81    /// Combine the finalized left and right results into the final aggregate.
82    fn finalize(&self, left: ArrayRef, right: ArrayRef) -> VortexResult<ArrayRef>;
83
84    fn finalize_scalar(&self, left_scalar: Scalar, right_scalar: Scalar) -> VortexResult<Scalar>;
85
86    /// Serialize the options for this combined aggregate. Default: not serializable.
87    fn serialize(&self, options: &CombinedOptions<Self>) -> VortexResult<Option<Vec<u8>>> {
88        let _ = options;
89        Ok(None)
90    }
91
92    /// Deserialize the options for this combined aggregate. Default: bails.
93    fn deserialize(
94        &self,
95        metadata: &[u8],
96        session: &VortexSession,
97    ) -> VortexResult<CombinedOptions<Self>> {
98        let _ = (metadata, session);
99        vortex_bail!(
100            "Combined aggregate function {} is not deserializable",
101            BinaryCombined::id(self)
102        );
103    }
104
105    /// Coerce the input type. Default: chains `right.coerce_args(left.coerce_args(input))`.
106    fn coerce_args(
107        &self,
108        options: &CombinedOptions<Self>,
109        input_dtype: &DType,
110    ) -> VortexResult<DType> {
111        let left_coerced = self.left().coerce_args(&options.0, input_dtype)?;
112        self.right().coerce_args(&options.1, &left_coerced)
113    }
114
115    /// Build the partial struct dtype that wraps the two child partials.
116    fn partial_struct_dtype(&self, left: DType, right: DType) -> DType {
117        DType::Struct(
118            StructFields::new(
119                FieldNames::from_iter([
120                    FieldName::from(self.left_name()),
121                    FieldName::from(self.right_name()),
122                ]),
123                vec![left, right],
124            ),
125            Nullability::NonNullable,
126        )
127    }
128}
129
130/// Adapter that exposes any [`BinaryCombined`] as an [`AggregateFnVTable`].
131#[derive(Clone, Debug)]
132pub struct Combined<T: BinaryCombined>(pub T);
133
134impl<T: BinaryCombined> Combined<T> {
135    /// Construct a new combined aggregate vtable.
136    pub fn new(inner: T) -> Self {
137        Self(inner)
138    }
139}
140
141impl<T: BinaryCombined> AggregateFnVTable for Combined<T> {
142    type Options = CombinedOptions<T>;
143    // Each child is held as a fully-fledged `AccumulatorRef` so that batches dispatched through
144    // `try_accumulate` consult the kernel registry per-child (e.g. a `(Dict, Sum)` kernel fires
145    // for the inner `Sum` child of `Combined<Mean>`).
146    type Partial = (AccumulatorRef, AccumulatorRef);
147
148    fn id(&self) -> AggregateFnId {
149        self.0.id()
150    }
151
152    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
153        BinaryCombined::serialize(&self.0, options)
154    }
155
156    fn deserialize(&self, metadata: &[u8], session: &VortexSession) -> VortexResult<Self::Options> {
157        BinaryCombined::deserialize(&self.0, metadata, session)
158    }
159
160    fn coerce_args(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
161        BinaryCombined::coerce_args(&self.0, options, input_dtype)
162    }
163
164    fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
165        BinaryCombined::return_dtype(&self.0, input_dtype)
166    }
167
168    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
169        let l = self.0.left().partial_dtype(&options.0, input_dtype)?;
170        let r = self.0.right().partial_dtype(&options.1, input_dtype)?;
171        Some(self.0.partial_struct_dtype(l, r))
172    }
173
174    fn empty_partial(
175        &self,
176        options: &Self::Options,
177        input_dtype: &DType,
178    ) -> VortexResult<Self::Partial> {
179        let left = Accumulator::try_new(self.0.left(), options.0.clone(), input_dtype.clone())?;
180        let right = Accumulator::try_new(self.0.right(), options.1.clone(), input_dtype.clone())?;
181        Ok((
182            Box::new(left) as AccumulatorRef,
183            Box::new(right) as AccumulatorRef,
184        ))
185    }
186
187    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
188        if other.is_null() {
189            return Ok(());
190        }
191        let s = other.as_struct();
192        let lname = self.0.left_name();
193        let rname = self.0.right_name();
194        let l_field = s
195            .field(lname)
196            .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", lname))?;
197        let r_field = s
198            .field(rname)
199            .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", rname))?;
200        partial.0.combine_partials(l_field)?;
201        partial.1.combine_partials(r_field)?;
202        Ok(())
203    }
204
205    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
206        let l_scalar = partial.0.partial_scalar()?;
207        let r_scalar = partial.1.partial_scalar()?;
208        let dtype = self
209            .0
210            .partial_struct_dtype(l_scalar.dtype().clone(), r_scalar.dtype().clone());
211        Ok(Scalar::struct_(dtype, vec![l_scalar, r_scalar]))
212    }
213
214    fn reset(&self, partial: &mut Self::Partial) {
215        partial.0.reset();
216        partial.1.reset();
217    }
218
219    fn is_saturated(&self, partial: &Self::Partial) -> bool {
220        partial.0.is_saturated() && partial.1.is_saturated()
221    }
222
223    /// Delegate the batch to each child's `Accumulator::accumulate`, which consults the
224    /// kernel registry against the child's `aggregate_fn` id. This is what makes
225    /// `(encoding, Child)` kernels reachable through `Combined<Parent>` — without it, a
226    /// `(Dict, Sum)` kernel would be dead code for `Combined<Mean>`. We always return
227    /// `true` so [`Self::accumulate`] is unreachable.
228    fn try_accumulate(
229        &self,
230        state: &mut Self::Partial,
231        batch: &ArrayRef,
232        ctx: &mut ExecutionCtx,
233    ) -> VortexResult<bool> {
234        state.0.accumulate(batch, ctx)?;
235        state.1.accumulate(batch, ctx)?;
236        Ok(true)
237    }
238
239    fn accumulate(
240        &self,
241        _state: &mut Self::Partial,
242        _batch: &Columnar,
243        _ctx: &mut ExecutionCtx,
244    ) -> VortexResult<()> {
245        unreachable!("Combined::try_accumulate handles all batches")
246    }
247
248    fn finalize(&self, states: ArrayRef) -> VortexResult<ArrayRef> {
249        let l_field = states.get_item(FieldName::from(self.0.left_name()))?;
250        let r_field = states.get_item(FieldName::from(self.0.right_name()))?;
251        let l_finalized = self.0.left().finalize(l_field)?;
252        let r_finalized = self.0.right().finalize(r_field)?;
253        BinaryCombined::finalize(&self.0, l_finalized, r_finalized)
254    }
255
256    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
257        let l_scalar = partial.0.final_scalar()?;
258        let r_scalar = partial.1.final_scalar()?;
259        BinaryCombined::finalize_scalar(&self.0, l_scalar, r_scalar)
260    }
261}