Skip to main content

vortex_array/aggregate_fn/
combined.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Generic adapter for aggregates whose result is computed from two child
5//! aggregate functions, e.g. `Mean = Sum / Count`.
6
7use std::fmt::Debug;
8use std::fmt::Display;
9use std::fmt::Formatter;
10use std::fmt::{self};
11use std::hash::Hash;
12
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_err;
16use vortex_session::VortexSession;
17
18use crate::ArrayRef;
19use crate::Columnar;
20use crate::ExecutionCtx;
21use crate::aggregate_fn::Accumulator;
22use crate::aggregate_fn::AccumulatorRef;
23use crate::aggregate_fn::AggregateFnId;
24use crate::aggregate_fn::AggregateFnVTable;
25use crate::builtins::ArrayBuiltins;
26use crate::dtype::DType;
27use crate::dtype::FieldName;
28use crate::dtype::FieldNames;
29use crate::dtype::Nullability;
30use crate::dtype::StructFields;
31use crate::scalar::Scalar;
32
/// Options for the two children of a [`BinaryCombined`] aggregate, held side by side.
///
/// A plain `(L, R)` tuple cannot be used directly: the [`AggregateFnVTable::Options`] bound
/// requires `Display`, and tuples do not implement it. This newtype supplies that impl.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct PairOptions<L, R>(pub L, pub R);

impl<L: Display, R: Display> Display for PairOptions<L, R> {
    /// Renders as `(left, right)` using each side's own `Display` impl.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self(left, right) = self;
        write!(f, "({left}, {right})")
    }
}
45
46// Convenience aliases so signatures stay readable.
47type LeftOptions<T> = <<T as BinaryCombined>::Left as AggregateFnVTable>::Options;
48type RightOptions<T> = <<T as BinaryCombined>::Right as AggregateFnVTable>::Options;
49/// Combined options for a [`BinaryCombined`] aggregate.
50pub type CombinedOptions<T> = PairOptions<LeftOptions<T>, RightOptions<T>>;
51
52/// Declare an aggregate function in terms of two child aggregates.
53pub trait BinaryCombined: 'static + Send + Sync + Clone {
54    /// The left child aggregate vtable.
55    type Left: AggregateFnVTable;
56    /// The right child aggregate vtable.
57    type Right: AggregateFnVTable;
58
59    /// Stable identifier for the combined aggregate.
60    fn id(&self) -> AggregateFnId;
61
62    /// Construct the left child vtable.
63    fn left(&self) -> Self::Left;
64
65    /// Construct the right child vtable.
66    fn right(&self) -> Self::Right;
67
68    /// Field name for the left child in the partial struct dtype.
69    fn left_name(&self) -> &'static str {
70        "left"
71    }
72
73    /// Field name for the right child in the partial struct dtype.
74    fn right_name(&self) -> &'static str {
75        "right"
76    }
77
78    /// Return type of the combined aggregate.
79    fn return_dtype(&self, input_dtype: &DType) -> Option<DType>;
80
81    /// Combine the finalized left and right results into the final aggregate.
82    fn finalize(&self, left: ArrayRef, right: ArrayRef) -> VortexResult<ArrayRef>;
83
84    fn finalize_scalar(&self, left_scalar: Scalar, right_scalar: Scalar) -> VortexResult<Scalar>;
85
86    /// Serialize the options for this combined aggregate. Default: not serializable.
87    fn serialize(&self, options: &CombinedOptions<Self>) -> VortexResult<Option<Vec<u8>>> {
88        let _ = options;
89        Ok(None)
90    }
91
92    /// Deserialize the options for this combined aggregate. Default: bails.
93    fn deserialize(
94        &self,
95        metadata: &[u8],
96        session: &VortexSession,
97    ) -> VortexResult<CombinedOptions<Self>> {
98        let _ = (metadata, session);
99        vortex_bail!(
100            "Combined aggregate function {} is not deserializable",
101            BinaryCombined::id(self)
102        );
103    }
104
105    /// Build the partial struct dtype that wraps the two child partials.
106    fn partial_struct_dtype(&self, left: DType, right: DType) -> DType {
107        DType::Struct(
108            StructFields::new(
109                FieldNames::from_iter([
110                    FieldName::from(self.left_name()),
111                    FieldName::from(self.right_name()),
112                ]),
113                vec![left, right],
114            ),
115            Nullability::NonNullable,
116        )
117    }
118}
119
120/// Adapter that exposes any [`BinaryCombined`] as an [`AggregateFnVTable`].
121#[derive(Clone, Debug)]
122pub struct Combined<T: BinaryCombined>(pub T);
123
124impl<T: BinaryCombined> Combined<T> {
125    /// Construct a new combined aggregate vtable.
126    pub fn new(inner: T) -> Self {
127        Self(inner)
128    }
129}
130
131impl<T: BinaryCombined> AggregateFnVTable for Combined<T> {
132    type Options = CombinedOptions<T>;
133    // Each child is held as a fully-fledged `AccumulatorRef` so that batches dispatched through
134    // `try_accumulate` consult the kernel registry per-child (e.g. a `(Dict, Sum)` kernel fires
135    // for the inner `Sum` child of `Combined<Mean>`).
136    type Partial = (AccumulatorRef, AccumulatorRef);
137
138    fn id(&self) -> AggregateFnId {
139        self.0.id()
140    }
141
142    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
143        BinaryCombined::serialize(&self.0, options)
144    }
145
146    fn deserialize(&self, metadata: &[u8], session: &VortexSession) -> VortexResult<Self::Options> {
147        BinaryCombined::deserialize(&self.0, metadata, session)
148    }
149
150    fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
151        BinaryCombined::return_dtype(&self.0, input_dtype)
152    }
153
154    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
155        let l = self.0.left().partial_dtype(&options.0, input_dtype)?;
156        let r = self.0.right().partial_dtype(&options.1, input_dtype)?;
157        Some(self.0.partial_struct_dtype(l, r))
158    }
159
160    fn empty_partial(
161        &self,
162        options: &Self::Options,
163        input_dtype: &DType,
164    ) -> VortexResult<Self::Partial> {
165        let left = Accumulator::try_new(self.0.left(), options.0.clone(), input_dtype.clone())?;
166        let right = Accumulator::try_new(self.0.right(), options.1.clone(), input_dtype.clone())?;
167        Ok((
168            Box::new(left) as AccumulatorRef,
169            Box::new(right) as AccumulatorRef,
170        ))
171    }
172
173    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
174        if other.is_null() {
175            return Ok(());
176        }
177        let s = other.as_struct();
178        let lname = self.0.left_name();
179        let rname = self.0.right_name();
180        let l_field = s
181            .field(lname)
182            .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", lname))?;
183        let r_field = s
184            .field(rname)
185            .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", rname))?;
186        partial.0.combine_partials(l_field)?;
187        partial.1.combine_partials(r_field)?;
188        Ok(())
189    }
190
191    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
192        let l_scalar = partial.0.partial_scalar()?;
193        let r_scalar = partial.1.partial_scalar()?;
194        let dtype = self
195            .0
196            .partial_struct_dtype(l_scalar.dtype().clone(), r_scalar.dtype().clone());
197        Ok(Scalar::struct_(dtype, vec![l_scalar, r_scalar]))
198    }
199
200    fn reset(&self, partial: &mut Self::Partial) {
201        partial.0.reset();
202        partial.1.reset();
203    }
204
205    fn is_saturated(&self, partial: &Self::Partial) -> bool {
206        partial.0.is_saturated() && partial.1.is_saturated()
207    }
208
209    /// Delegate the batch to each child's `Accumulator::accumulate`, which consults the
210    /// kernel registry against the child's `aggregate_fn` id. This is what makes
211    /// `(encoding, Child)` kernels reachable through `Combined<Parent>` — without it, a
212    /// `(Dict, Sum)` kernel would be dead code for `Combined<Mean>`. We always return
213    /// `true` so [`Self::accumulate`] is unreachable.
214    fn try_accumulate(
215        &self,
216        state: &mut Self::Partial,
217        batch: &ArrayRef,
218        ctx: &mut ExecutionCtx,
219    ) -> VortexResult<bool> {
220        state.0.accumulate(batch, ctx)?;
221        state.1.accumulate(batch, ctx)?;
222        Ok(true)
223    }
224
225    fn accumulate(
226        &self,
227        _state: &mut Self::Partial,
228        _batch: &Columnar,
229        _ctx: &mut ExecutionCtx,
230    ) -> VortexResult<()> {
231        unreachable!("Combined::try_accumulate handles all batches")
232    }
233
234    fn finalize(&self, states: ArrayRef) -> VortexResult<ArrayRef> {
235        let l_field = states.get_item(FieldName::from(self.0.left_name()))?;
236        let r_field = states.get_item(FieldName::from(self.0.right_name()))?;
237        let l_finalized = self.0.left().finalize(l_field)?;
238        let r_finalized = self.0.right().finalize(r_field)?;
239        BinaryCombined::finalize(&self.0, l_finalized, r_finalized)
240    }
241
242    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
243        let l_scalar = partial.0.final_scalar()?;
244        let r_scalar = partial.1.final_scalar()?;
245        BinaryCombined::finalize_scalar(&self.0, l_scalar, r_scalar)
246    }
247}