Skip to main content

vortex_array/aggregate_fn/
combined.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Generic adapter for aggregates whose result is computed from two child
5//! aggregate functions, e.g. `Mean = Sum / Count`.
6
7use std::fmt::Debug;
8use std::fmt::Display;
9use std::fmt::Formatter;
10use std::fmt::{self};
11use std::hash::Hash;
12
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_err;
16use vortex_session::VortexSession;
17
18use crate::ArrayRef;
19use crate::Columnar;
20use crate::ExecutionCtx;
21use crate::aggregate_fn::AggregateFnId;
22use crate::aggregate_fn::AggregateFnVTable;
23use crate::builtins::ArrayBuiltins;
24use crate::dtype::DType;
25use crate::dtype::FieldName;
26use crate::dtype::FieldNames;
27use crate::dtype::Nullability;
28use crate::dtype::StructFields;
29use crate::scalar::Scalar;
30
/// Options for both children of a [`BinaryCombined`] aggregate, left child first.
///
/// A plain `(L, R)` tuple cannot be used in its place: the
/// [`AggregateFnVTable::Options`] bound requires `Display`, and tuples do not
/// implement it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PairOptions<L, R>(pub L, pub R);

impl<L: Display, R: Display> Display for PairOptions<L, R> {
    /// Renders the pair as `(left, right)`, using each side's own `Display` output.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self(left, right) = self;
        write!(f, "({left}, {right})")
    }
}
43
44// Convenience aliases so signatures stay readable.
45type LeftOptions<T> = <<T as BinaryCombined>::Left as AggregateFnVTable>::Options;
46type RightOptions<T> = <<T as BinaryCombined>::Right as AggregateFnVTable>::Options;
47type LeftPartial<T> = <<T as BinaryCombined>::Left as AggregateFnVTable>::Partial;
48type RightPartial<T> = <<T as BinaryCombined>::Right as AggregateFnVTable>::Partial;
49/// Combined options for a [`BinaryCombined`] aggregate.
50pub type CombinedOptions<T> = PairOptions<LeftOptions<T>, RightOptions<T>>;
51
52/// Declare an aggregate function in terms of two child aggregates.
53pub trait BinaryCombined: 'static + Send + Sync + Clone {
54    /// The left child aggregate vtable.
55    type Left: AggregateFnVTable;
56    /// The right child aggregate vtable.
57    type Right: AggregateFnVTable;
58
59    /// Stable identifier for the combined aggregate.
60    fn id(&self) -> AggregateFnId;
61
62    /// Construct the left child vtable.
63    fn left(&self) -> Self::Left;
64
65    /// Construct the right child vtable.
66    fn right(&self) -> Self::Right;
67
68    /// Field name for the left child in the partial struct dtype.
69    fn left_name(&self) -> &'static str {
70        "left"
71    }
72
73    /// Field name for the right child in the partial struct dtype.
74    fn right_name(&self) -> &'static str {
75        "right"
76    }
77
78    /// Return type of the combined aggregate.
79    fn return_dtype(&self, input_dtype: &DType) -> Option<DType>;
80
81    /// Combine the finalized left and right results into the final aggregate.
82    fn finalize(&self, left: ArrayRef, right: ArrayRef) -> VortexResult<ArrayRef>;
83
84    fn finalize_scalar(&self, left_scalar: Scalar, right_scalar: Scalar) -> VortexResult<Scalar>;
85
86    /// Serialize the options for this combined aggregate. Default: not serializable.
87    fn serialize(&self, options: &CombinedOptions<Self>) -> VortexResult<Option<Vec<u8>>> {
88        let _ = options;
89        Ok(None)
90    }
91
92    /// Deserialize the options for this combined aggregate. Default: bails.
93    fn deserialize(
94        &self,
95        metadata: &[u8],
96        session: &VortexSession,
97    ) -> VortexResult<CombinedOptions<Self>> {
98        let _ = (metadata, session);
99        vortex_bail!(
100            "Combined aggregate function {} is not deserializable",
101            BinaryCombined::id(self)
102        );
103    }
104
105    /// Coerce the input type. Default: chains `right.coerce_args(left.coerce_args(input))`.
106    fn coerce_args(
107        &self,
108        options: &CombinedOptions<Self>,
109        input_dtype: &DType,
110    ) -> VortexResult<DType> {
111        let left_coerced = self.left().coerce_args(&options.0, input_dtype)?;
112        self.right().coerce_args(&options.1, &left_coerced)
113    }
114
115    /// Build the partial struct dtype that wraps the two child partials.
116    fn partial_struct_dtype(&self, left: DType, right: DType) -> DType {
117        DType::Struct(
118            StructFields::new(
119                FieldNames::from_iter([
120                    FieldName::from(self.left_name()),
121                    FieldName::from(self.right_name()),
122                ]),
123                vec![left, right],
124            ),
125            Nullability::NonNullable,
126        )
127    }
128}
129
130/// Adapter that exposes any [`BinaryCombined`] as an [`AggregateFnVTable`].
131#[derive(Clone, Debug)]
132pub struct Combined<T: BinaryCombined>(pub T);
133
134impl<T: BinaryCombined> Combined<T> {
135    /// Construct a new combined aggregate vtable.
136    pub fn new(inner: T) -> Self {
137        Self(inner)
138    }
139}
140
141impl<T: BinaryCombined> AggregateFnVTable for Combined<T> {
142    type Options = CombinedOptions<T>;
143    type Partial = (LeftPartial<T>, RightPartial<T>);
144
145    fn id(&self) -> AggregateFnId {
146        self.0.id()
147    }
148
149    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
150        BinaryCombined::serialize(&self.0, options)
151    }
152
153    fn deserialize(&self, metadata: &[u8], session: &VortexSession) -> VortexResult<Self::Options> {
154        BinaryCombined::deserialize(&self.0, metadata, session)
155    }
156
157    fn coerce_args(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
158        BinaryCombined::coerce_args(&self.0, options, input_dtype)
159    }
160
161    fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
162        BinaryCombined::return_dtype(&self.0, input_dtype)
163    }
164
165    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
166        let l = self.0.left().partial_dtype(&options.0, input_dtype)?;
167        let r = self.0.right().partial_dtype(&options.1, input_dtype)?;
168        Some(self.0.partial_struct_dtype(l, r))
169    }
170
171    fn empty_partial(
172        &self,
173        options: &Self::Options,
174        input_dtype: &DType,
175    ) -> VortexResult<Self::Partial> {
176        Ok((
177            self.0.left().empty_partial(&options.0, input_dtype)?,
178            self.0.right().empty_partial(&options.1, input_dtype)?,
179        ))
180    }
181
182    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
183        if other.is_null() {
184            return Ok(());
185        }
186        let s = other.as_struct();
187        let lname = self.0.left_name();
188        let rname = self.0.right_name();
189        let l_field = s
190            .field(lname)
191            .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", lname))?;
192        let r_field = s
193            .field(rname)
194            .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", rname))?;
195        self.0.left().combine_partials(&mut partial.0, l_field)?;
196        self.0.right().combine_partials(&mut partial.1, r_field)?;
197        Ok(())
198    }
199
200    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
201        let l_scalar = self.0.left().to_scalar(&partial.0)?;
202        let r_scalar = self.0.right().to_scalar(&partial.1)?;
203        let dtype = self
204            .0
205            .partial_struct_dtype(l_scalar.dtype().clone(), r_scalar.dtype().clone());
206        Ok(Scalar::struct_(dtype, vec![l_scalar, r_scalar]))
207    }
208
209    fn reset(&self, partial: &mut Self::Partial) {
210        self.0.left().reset(&mut partial.0);
211        self.0.right().reset(&mut partial.1);
212    }
213
214    fn is_saturated(&self, partial: &Self::Partial) -> bool {
215        self.0.left().is_saturated(&partial.0) && self.0.right().is_saturated(&partial.1)
216    }
217
218    /// Fans out to each child's `try_accumulate`, falling back to `accumulate`
219    /// against a lazily-canonicalized batch. We always claim to handle the
220    /// batch ourselves so [`Self::accumulate`] is unreachable — this is the
221    /// same trick `Count` uses to opt out of the canonicalization path.
222    fn try_accumulate(
223        &self,
224        state: &mut Self::Partial,
225        batch: &ArrayRef,
226        ctx: &mut ExecutionCtx,
227    ) -> VortexResult<bool> {
228        let mut canonical: Option<Columnar> = None;
229        if !self.0.left().try_accumulate(&mut state.0, batch, ctx)? {
230            let c = canonical.insert(batch.clone().execute::<Columnar>(ctx)?);
231            self.0.left().accumulate(&mut state.0, c, ctx)?;
232        }
233        if !self.0.right().try_accumulate(&mut state.1, batch, ctx)? {
234            let c = match canonical.as_ref() {
235                Some(c) => c,
236                None => canonical.insert(batch.clone().execute::<Columnar>(ctx)?),
237            };
238            self.0.right().accumulate(&mut state.1, c, ctx)?;
239        }
240        Ok(true)
241    }
242
243    fn accumulate(
244        &self,
245        _state: &mut Self::Partial,
246        _batch: &Columnar,
247        _ctx: &mut ExecutionCtx,
248    ) -> VortexResult<()> {
249        unreachable!("Combined::try_accumulate handles all batches")
250    }
251
252    fn finalize(&self, states: ArrayRef) -> VortexResult<ArrayRef> {
253        let l_field = states.get_item(FieldName::from(self.0.left_name()))?;
254        let r_field = states.get_item(FieldName::from(self.0.right_name()))?;
255        let l_finalized = self.0.left().finalize(l_field)?;
256        let r_finalized = self.0.right().finalize(r_field)?;
257        BinaryCombined::finalize(&self.0, l_finalized, r_finalized)
258    }
259
260    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
261        let l_scalar = self.0.left().finalize_scalar(&partial.0)?;
262        let r_scalar = self.0.right().finalize_scalar(&partial.1)?;
263        BinaryCombined::finalize_scalar(&self.0, l_scalar, r_scalar)
264    }
265}