vortex_array/aggregate_fn/
combined.rs1use std::fmt::Debug;
8use std::fmt::Display;
9use std::fmt::Formatter;
10use std::fmt::{self};
11use std::hash::Hash;
12
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_err;
16use vortex_session::VortexSession;
17
18use crate::ArrayRef;
19use crate::Columnar;
20use crate::ExecutionCtx;
21use crate::aggregate_fn::Accumulator;
22use crate::aggregate_fn::AccumulatorRef;
23use crate::aggregate_fn::AggregateFnId;
24use crate::aggregate_fn::AggregateFnVTable;
25use crate::builtins::ArrayBuiltins;
26use crate::dtype::DType;
27use crate::dtype::FieldName;
28use crate::dtype::FieldNames;
29use crate::dtype::Nullability;
30use crate::dtype::StructFields;
31use crate::scalar::Scalar;
32
/// Options for a pair of aggregate functions, stored positionally:
/// `.0` configures the left aggregate and `.1` configures the right one.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct PairOptions<L, R>(pub L, pub R);

/// Renders the pair as `(left, right)` using each side's own `Display` impl.
impl<L: Display, R: Display> Display for PairOptions<L, R> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self(left, right) = self;
        write!(f, "({left}, {right})")
    }
}
45
46type LeftOptions<T> = <<T as BinaryCombined>::Left as AggregateFnVTable>::Options;
48type RightOptions<T> = <<T as BinaryCombined>::Right as AggregateFnVTable>::Options;
49pub type CombinedOptions<T> = PairOptions<LeftOptions<T>, RightOptions<T>>;
51
52pub trait BinaryCombined: 'static + Send + Sync + Clone {
54 type Left: AggregateFnVTable;
56 type Right: AggregateFnVTable;
58
59 fn id(&self) -> AggregateFnId;
61
62 fn left(&self) -> Self::Left;
64
65 fn right(&self) -> Self::Right;
67
68 fn left_name(&self) -> &'static str {
70 "left"
71 }
72
73 fn right_name(&self) -> &'static str {
75 "right"
76 }
77
78 fn return_dtype(&self, input_dtype: &DType) -> Option<DType>;
80
81 fn finalize(&self, left: ArrayRef, right: ArrayRef) -> VortexResult<ArrayRef>;
83
84 fn finalize_scalar(&self, left_scalar: Scalar, right_scalar: Scalar) -> VortexResult<Scalar>;
85
86 fn serialize(&self, options: &CombinedOptions<Self>) -> VortexResult<Option<Vec<u8>>> {
88 let _ = options;
89 Ok(None)
90 }
91
92 fn deserialize(
94 &self,
95 metadata: &[u8],
96 session: &VortexSession,
97 ) -> VortexResult<CombinedOptions<Self>> {
98 let _ = (metadata, session);
99 vortex_bail!(
100 "Combined aggregate function {} is not deserializable",
101 BinaryCombined::id(self)
102 );
103 }
104
105 fn partial_struct_dtype(&self, left: DType, right: DType) -> DType {
107 DType::Struct(
108 StructFields::new(
109 FieldNames::from_iter([
110 FieldName::from(self.left_name()),
111 FieldName::from(self.right_name()),
112 ]),
113 vec![left, right],
114 ),
115 Nullability::NonNullable,
116 )
117 }
118}
119
/// Adapter that exposes a [`BinaryCombined`] definition as an
/// [`AggregateFnVTable`].
///
/// The `BinaryCombined` requirement lives on the `impl` blocks that need it,
/// not on the struct itself. The wrapper therefore imposes no bounds on `T`
/// beyond those required by the derived `Clone`/`Debug` impls.
#[derive(Clone, Debug)]
pub struct Combined<T>(pub T);
123
124impl<T: BinaryCombined> Combined<T> {
125 pub fn new(inner: T) -> Self {
127 Self(inner)
128 }
129}
130
131impl<T: BinaryCombined> AggregateFnVTable for Combined<T> {
132 type Options = CombinedOptions<T>;
133 type Partial = (AccumulatorRef, AccumulatorRef);
137
138 fn id(&self) -> AggregateFnId {
139 self.0.id()
140 }
141
142 fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
143 BinaryCombined::serialize(&self.0, options)
144 }
145
146 fn deserialize(&self, metadata: &[u8], session: &VortexSession) -> VortexResult<Self::Options> {
147 BinaryCombined::deserialize(&self.0, metadata, session)
148 }
149
150 fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
151 BinaryCombined::return_dtype(&self.0, input_dtype)
152 }
153
154 fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
155 let l = self.0.left().partial_dtype(&options.0, input_dtype)?;
156 let r = self.0.right().partial_dtype(&options.1, input_dtype)?;
157 Some(self.0.partial_struct_dtype(l, r))
158 }
159
160 fn empty_partial(
161 &self,
162 options: &Self::Options,
163 input_dtype: &DType,
164 ) -> VortexResult<Self::Partial> {
165 let left = Accumulator::try_new(self.0.left(), options.0.clone(), input_dtype.clone())?;
166 let right = Accumulator::try_new(self.0.right(), options.1.clone(), input_dtype.clone())?;
167 Ok((
168 Box::new(left) as AccumulatorRef,
169 Box::new(right) as AccumulatorRef,
170 ))
171 }
172
173 fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
174 if other.is_null() {
175 return Ok(());
176 }
177 let s = other.as_struct();
178 let lname = self.0.left_name();
179 let rname = self.0.right_name();
180 let l_field = s
181 .field(lname)
182 .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", lname))?;
183 let r_field = s
184 .field(rname)
185 .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", rname))?;
186 partial.0.combine_partials(l_field)?;
187 partial.1.combine_partials(r_field)?;
188 Ok(())
189 }
190
191 fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
192 let l_scalar = partial.0.partial_scalar()?;
193 let r_scalar = partial.1.partial_scalar()?;
194 let dtype = self
195 .0
196 .partial_struct_dtype(l_scalar.dtype().clone(), r_scalar.dtype().clone());
197 Ok(Scalar::struct_(dtype, vec![l_scalar, r_scalar]))
198 }
199
200 fn reset(&self, partial: &mut Self::Partial) {
201 partial.0.reset();
202 partial.1.reset();
203 }
204
205 fn is_saturated(&self, partial: &Self::Partial) -> bool {
206 partial.0.is_saturated() && partial.1.is_saturated()
207 }
208
209 fn try_accumulate(
215 &self,
216 state: &mut Self::Partial,
217 batch: &ArrayRef,
218 ctx: &mut ExecutionCtx,
219 ) -> VortexResult<bool> {
220 state.0.accumulate(batch, ctx)?;
221 state.1.accumulate(batch, ctx)?;
222 Ok(true)
223 }
224
225 fn accumulate(
226 &self,
227 _state: &mut Self::Partial,
228 _batch: &Columnar,
229 _ctx: &mut ExecutionCtx,
230 ) -> VortexResult<()> {
231 unreachable!("Combined::try_accumulate handles all batches")
232 }
233
234 fn finalize(&self, states: ArrayRef) -> VortexResult<ArrayRef> {
235 let l_field = states.get_item(FieldName::from(self.0.left_name()))?;
236 let r_field = states.get_item(FieldName::from(self.0.right_name()))?;
237 let l_finalized = self.0.left().finalize(l_field)?;
238 let r_finalized = self.0.right().finalize(r_field)?;
239 BinaryCombined::finalize(&self.0, l_finalized, r_finalized)
240 }
241
242 fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
243 let l_scalar = partial.0.final_scalar()?;
244 let r_scalar = partial.1.final_scalar()?;
245 BinaryCombined::finalize_scalar(&self.0, l_scalar, r_scalar)
246 }
247}