vortex_array/aggregate_fn/
combined.rs1use std::fmt::Debug;
8use std::fmt::Display;
9use std::fmt::Formatter;
10use std::fmt::{self};
11use std::hash::Hash;
12
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_err;
16use vortex_session::VortexSession;
17
18use crate::ArrayRef;
19use crate::Columnar;
20use crate::ExecutionCtx;
21use crate::aggregate_fn::Accumulator;
22use crate::aggregate_fn::AccumulatorRef;
23use crate::aggregate_fn::AggregateFnId;
24use crate::aggregate_fn::AggregateFnVTable;
25use crate::builtins::ArrayBuiltins;
26use crate::dtype::DType;
27use crate::dtype::FieldName;
28use crate::dtype::FieldNames;
29use crate::dtype::Nullability;
30use crate::dtype::StructFields;
31use crate::scalar::Scalar;
32
/// A pair of option values, one for each half of a [`BinaryCombined`] aggregate.
///
/// The first field holds the left half's options and the second the right half's.
/// Renders through [`Display`] as `(left, right)`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct PairOptions<L, R>(pub L, pub R);

impl<L: Display, R: Display> Display for PairOptions<L, R> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self(left, right) = self;
        // `write!` formats the halves with default options; width/fill flags applied to the
        // pair itself are not forwarded to the inner values.
        write!(f, "({left}, {right})")
    }
}
45
46type LeftOptions<T> = <<T as BinaryCombined>::Left as AggregateFnVTable>::Options;
48type RightOptions<T> = <<T as BinaryCombined>::Right as AggregateFnVTable>::Options;
49pub type CombinedOptions<T> = PairOptions<LeftOptions<T>, RightOptions<T>>;
51
52pub trait BinaryCombined: 'static + Send + Sync + Clone {
54 type Left: AggregateFnVTable;
56 type Right: AggregateFnVTable;
58
59 fn id(&self) -> AggregateFnId;
61
62 fn left(&self) -> Self::Left;
64
65 fn right(&self) -> Self::Right;
67
68 fn left_name(&self) -> &'static str {
70 "left"
71 }
72
73 fn right_name(&self) -> &'static str {
75 "right"
76 }
77
78 fn return_dtype(&self, input_dtype: &DType) -> Option<DType>;
80
81 fn finalize(&self, left: ArrayRef, right: ArrayRef) -> VortexResult<ArrayRef>;
83
84 fn finalize_scalar(&self, left_scalar: Scalar, right_scalar: Scalar) -> VortexResult<Scalar>;
85
86 fn serialize(&self, options: &CombinedOptions<Self>) -> VortexResult<Option<Vec<u8>>> {
88 let _ = options;
89 Ok(None)
90 }
91
92 fn deserialize(
94 &self,
95 metadata: &[u8],
96 session: &VortexSession,
97 ) -> VortexResult<CombinedOptions<Self>> {
98 let _ = (metadata, session);
99 vortex_bail!(
100 "Combined aggregate function {} is not deserializable",
101 BinaryCombined::id(self)
102 );
103 }
104
105 fn coerce_args(
107 &self,
108 options: &CombinedOptions<Self>,
109 input_dtype: &DType,
110 ) -> VortexResult<DType> {
111 let left_coerced = self.left().coerce_args(&options.0, input_dtype)?;
112 self.right().coerce_args(&options.1, &left_coerced)
113 }
114
115 fn partial_struct_dtype(&self, left: DType, right: DType) -> DType {
117 DType::Struct(
118 StructFields::new(
119 FieldNames::from_iter([
120 FieldName::from(self.left_name()),
121 FieldName::from(self.right_name()),
122 ]),
123 vec![left, right],
124 ),
125 Nullability::NonNullable,
126 )
127 }
128}
129
130#[derive(Clone, Debug)]
132pub struct Combined<T: BinaryCombined>(pub T);
133
134impl<T: BinaryCombined> Combined<T> {
135 pub fn new(inner: T) -> Self {
137 Self(inner)
138 }
139}
140
141impl<T: BinaryCombined> AggregateFnVTable for Combined<T> {
142 type Options = CombinedOptions<T>;
143 type Partial = (AccumulatorRef, AccumulatorRef);
147
148 fn id(&self) -> AggregateFnId {
149 self.0.id()
150 }
151
152 fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
153 BinaryCombined::serialize(&self.0, options)
154 }
155
156 fn deserialize(&self, metadata: &[u8], session: &VortexSession) -> VortexResult<Self::Options> {
157 BinaryCombined::deserialize(&self.0, metadata, session)
158 }
159
160 fn coerce_args(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
161 BinaryCombined::coerce_args(&self.0, options, input_dtype)
162 }
163
164 fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
165 BinaryCombined::return_dtype(&self.0, input_dtype)
166 }
167
168 fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
169 let l = self.0.left().partial_dtype(&options.0, input_dtype)?;
170 let r = self.0.right().partial_dtype(&options.1, input_dtype)?;
171 Some(self.0.partial_struct_dtype(l, r))
172 }
173
174 fn empty_partial(
175 &self,
176 options: &Self::Options,
177 input_dtype: &DType,
178 ) -> VortexResult<Self::Partial> {
179 let left = Accumulator::try_new(self.0.left(), options.0.clone(), input_dtype.clone())?;
180 let right = Accumulator::try_new(self.0.right(), options.1.clone(), input_dtype.clone())?;
181 Ok((
182 Box::new(left) as AccumulatorRef,
183 Box::new(right) as AccumulatorRef,
184 ))
185 }
186
187 fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
188 if other.is_null() {
189 return Ok(());
190 }
191 let s = other.as_struct();
192 let lname = self.0.left_name();
193 let rname = self.0.right_name();
194 let l_field = s
195 .field(lname)
196 .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", lname))?;
197 let r_field = s
198 .field(rname)
199 .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", rname))?;
200 partial.0.combine_partials(l_field)?;
201 partial.1.combine_partials(r_field)?;
202 Ok(())
203 }
204
205 fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
206 let l_scalar = partial.0.partial_scalar()?;
207 let r_scalar = partial.1.partial_scalar()?;
208 let dtype = self
209 .0
210 .partial_struct_dtype(l_scalar.dtype().clone(), r_scalar.dtype().clone());
211 Ok(Scalar::struct_(dtype, vec![l_scalar, r_scalar]))
212 }
213
214 fn reset(&self, partial: &mut Self::Partial) {
215 partial.0.reset();
216 partial.1.reset();
217 }
218
219 fn is_saturated(&self, partial: &Self::Partial) -> bool {
220 partial.0.is_saturated() && partial.1.is_saturated()
221 }
222
223 fn try_accumulate(
229 &self,
230 state: &mut Self::Partial,
231 batch: &ArrayRef,
232 ctx: &mut ExecutionCtx,
233 ) -> VortexResult<bool> {
234 state.0.accumulate(batch, ctx)?;
235 state.1.accumulate(batch, ctx)?;
236 Ok(true)
237 }
238
239 fn accumulate(
240 &self,
241 _state: &mut Self::Partial,
242 _batch: &Columnar,
243 _ctx: &mut ExecutionCtx,
244 ) -> VortexResult<()> {
245 unreachable!("Combined::try_accumulate handles all batches")
246 }
247
248 fn finalize(&self, states: ArrayRef) -> VortexResult<ArrayRef> {
249 let l_field = states.get_item(FieldName::from(self.0.left_name()))?;
250 let r_field = states.get_item(FieldName::from(self.0.right_name()))?;
251 let l_finalized = self.0.left().finalize(l_field)?;
252 let r_finalized = self.0.right().finalize(r_field)?;
253 BinaryCombined::finalize(&self.0, l_finalized, r_finalized)
254 }
255
256 fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
257 let l_scalar = partial.0.final_scalar()?;
258 let r_scalar = partial.1.final_scalar()?;
259 BinaryCombined::finalize_scalar(&self.0, l_scalar, r_scalar)
260 }
261}