vortex_array/aggregate_fn/
combined.rs1use std::fmt::Debug;
8use std::fmt::Display;
9use std::fmt::Formatter;
10use std::fmt::{self};
11use std::hash::Hash;
12
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_err;
16use vortex_session::VortexSession;
17
18use crate::ArrayRef;
19use crate::Columnar;
20use crate::ExecutionCtx;
21use crate::aggregate_fn::AggregateFnId;
22use crate::aggregate_fn::AggregateFnVTable;
23use crate::builtins::ArrayBuiltins;
24use crate::dtype::DType;
25use crate::dtype::FieldName;
26use crate::dtype::FieldNames;
27use crate::dtype::Nullability;
28use crate::dtype::StructFields;
29use crate::scalar::Scalar;
30
/// Options for a pair of aggregate functions: `.0` belongs to the left function and `.1`
/// to the right one.
///
/// [`Combined`] hands `.0` to its left inner aggregate and `.1` to its right inner aggregate.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct PairOptions<L, R>(pub L, pub R);
37
38impl<L: Display, R: Display> Display for PairOptions<L, R> {
39 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
40 write!(f, "({}, {})", self.0, self.1)
41 }
42}
43
44type LeftOptions<T> = <<T as BinaryCombined>::Left as AggregateFnVTable>::Options;
46type RightOptions<T> = <<T as BinaryCombined>::Right as AggregateFnVTable>::Options;
47type LeftPartial<T> = <<T as BinaryCombined>::Left as AggregateFnVTable>::Partial;
48type RightPartial<T> = <<T as BinaryCombined>::Right as AggregateFnVTable>::Partial;
49pub type CombinedOptions<T> = PairOptions<LeftOptions<T>, RightOptions<T>>;
51
52pub trait BinaryCombined: 'static + Send + Sync + Clone {
54 type Left: AggregateFnVTable;
56 type Right: AggregateFnVTable;
58
59 fn id(&self) -> AggregateFnId;
61
62 fn left(&self) -> Self::Left;
64
65 fn right(&self) -> Self::Right;
67
68 fn left_name(&self) -> &'static str {
70 "left"
71 }
72
73 fn right_name(&self) -> &'static str {
75 "right"
76 }
77
78 fn return_dtype(&self, input_dtype: &DType) -> Option<DType>;
80
81 fn finalize(&self, left: ArrayRef, right: ArrayRef) -> VortexResult<ArrayRef>;
83
84 fn finalize_scalar(&self, left_scalar: Scalar, right_scalar: Scalar) -> VortexResult<Scalar>;
85
86 fn serialize(&self, options: &CombinedOptions<Self>) -> VortexResult<Option<Vec<u8>>> {
88 let _ = options;
89 Ok(None)
90 }
91
92 fn deserialize(
94 &self,
95 metadata: &[u8],
96 session: &VortexSession,
97 ) -> VortexResult<CombinedOptions<Self>> {
98 let _ = (metadata, session);
99 vortex_bail!(
100 "Combined aggregate function {} is not deserializable",
101 BinaryCombined::id(self)
102 );
103 }
104
105 fn coerce_args(
107 &self,
108 options: &CombinedOptions<Self>,
109 input_dtype: &DType,
110 ) -> VortexResult<DType> {
111 let left_coerced = self.left().coerce_args(&options.0, input_dtype)?;
112 self.right().coerce_args(&options.1, &left_coerced)
113 }
114
115 fn partial_struct_dtype(&self, left: DType, right: DType) -> DType {
117 DType::Struct(
118 StructFields::new(
119 FieldNames::from_iter([
120 FieldName::from(self.left_name()),
121 FieldName::from(self.right_name()),
122 ]),
123 vec![left, right],
124 ),
125 Nullability::NonNullable,
126 )
127 }
128}
129
130#[derive(Clone, Debug)]
132pub struct Combined<T: BinaryCombined>(pub T);
133
134impl<T: BinaryCombined> Combined<T> {
135 pub fn new(inner: T) -> Self {
137 Self(inner)
138 }
139}
140
141impl<T: BinaryCombined> AggregateFnVTable for Combined<T> {
142 type Options = CombinedOptions<T>;
143 type Partial = (LeftPartial<T>, RightPartial<T>);
144
145 fn id(&self) -> AggregateFnId {
146 self.0.id()
147 }
148
149 fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
150 BinaryCombined::serialize(&self.0, options)
151 }
152
153 fn deserialize(&self, metadata: &[u8], session: &VortexSession) -> VortexResult<Self::Options> {
154 BinaryCombined::deserialize(&self.0, metadata, session)
155 }
156
157 fn coerce_args(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
158 BinaryCombined::coerce_args(&self.0, options, input_dtype)
159 }
160
161 fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
162 BinaryCombined::return_dtype(&self.0, input_dtype)
163 }
164
165 fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
166 let l = self.0.left().partial_dtype(&options.0, input_dtype)?;
167 let r = self.0.right().partial_dtype(&options.1, input_dtype)?;
168 Some(self.0.partial_struct_dtype(l, r))
169 }
170
171 fn empty_partial(
172 &self,
173 options: &Self::Options,
174 input_dtype: &DType,
175 ) -> VortexResult<Self::Partial> {
176 Ok((
177 self.0.left().empty_partial(&options.0, input_dtype)?,
178 self.0.right().empty_partial(&options.1, input_dtype)?,
179 ))
180 }
181
182 fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
183 if other.is_null() {
184 return Ok(());
185 }
186 let s = other.as_struct();
187 let lname = self.0.left_name();
188 let rname = self.0.right_name();
189 let l_field = s
190 .field(lname)
191 .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", lname))?;
192 let r_field = s
193 .field(rname)
194 .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", rname))?;
195 self.0.left().combine_partials(&mut partial.0, l_field)?;
196 self.0.right().combine_partials(&mut partial.1, r_field)?;
197 Ok(())
198 }
199
200 fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
201 let l_scalar = self.0.left().to_scalar(&partial.0)?;
202 let r_scalar = self.0.right().to_scalar(&partial.1)?;
203 let dtype = self
204 .0
205 .partial_struct_dtype(l_scalar.dtype().clone(), r_scalar.dtype().clone());
206 Ok(Scalar::struct_(dtype, vec![l_scalar, r_scalar]))
207 }
208
209 fn reset(&self, partial: &mut Self::Partial) {
210 self.0.left().reset(&mut partial.0);
211 self.0.right().reset(&mut partial.1);
212 }
213
214 fn is_saturated(&self, partial: &Self::Partial) -> bool {
215 self.0.left().is_saturated(&partial.0) && self.0.right().is_saturated(&partial.1)
216 }
217
218 fn try_accumulate(
223 &self,
224 state: &mut Self::Partial,
225 batch: &ArrayRef,
226 ctx: &mut ExecutionCtx,
227 ) -> VortexResult<bool> {
228 let mut canonical: Option<Columnar> = None;
229 if !self.0.left().try_accumulate(&mut state.0, batch, ctx)? {
230 let c = canonical.insert(batch.clone().execute::<Columnar>(ctx)?);
231 self.0.left().accumulate(&mut state.0, c, ctx)?;
232 }
233 if !self.0.right().try_accumulate(&mut state.1, batch, ctx)? {
234 let c = match canonical.as_ref() {
235 Some(c) => c,
236 None => canonical.insert(batch.clone().execute::<Columnar>(ctx)?),
237 };
238 self.0.right().accumulate(&mut state.1, c, ctx)?;
239 }
240 Ok(true)
241 }
242
243 fn accumulate(
244 &self,
245 _state: &mut Self::Partial,
246 _batch: &Columnar,
247 _ctx: &mut ExecutionCtx,
248 ) -> VortexResult<()> {
249 unreachable!("Combined::try_accumulate handles all batches")
250 }
251
252 fn finalize(&self, states: ArrayRef) -> VortexResult<ArrayRef> {
253 let l_field = states.get_item(FieldName::from(self.0.left_name()))?;
254 let r_field = states.get_item(FieldName::from(self.0.right_name()))?;
255 let l_finalized = self.0.left().finalize(l_field)?;
256 let r_finalized = self.0.right().finalize(r_field)?;
257 BinaryCombined::finalize(&self.0, l_finalized, r_finalized)
258 }
259
260 fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
261 let l_scalar = self.0.left().finalize_scalar(&partial.0)?;
262 let r_scalar = self.0.right().finalize_scalar(&partial.1)?;
263 BinaryCombined::finalize_scalar(&self.0, l_scalar, r_scalar)
264 }
265}