vortex_array/aggregate_fn/fns/all_non_distinct/
mod.rs1mod bool;
5mod decimal;
6mod extension;
7mod filter;
8mod fixed_size_list;
9mod list;
10mod primitive;
11mod struct_;
12#[cfg(test)]
13mod tests;
14mod varbin;
15
16use std::sync::LazyLock;
17
18use vortex_error::VortexResult;
19use vortex_error::vortex_bail;
20use vortex_error::vortex_err;
21
22use self::bool::check_bool_identical;
23use self::decimal::check_decimal_identical;
24use self::extension::check_extension_identical;
25use self::filter::shared_validity_mask;
26use self::fixed_size_list::check_fixed_size_list_identical;
27use self::list::check_list_identical;
28use self::primitive::check_primitive_identical;
29use self::struct_::check_struct_identical;
30use self::varbin::check_varbinview_identical;
31use crate::ArrayRef;
32use crate::Canonical;
33use crate::Columnar;
34use crate::ExecutionCtx;
35use crate::IntoArray;
36use crate::aggregate_fn::Accumulator;
37use crate::aggregate_fn::AggregateFnId;
38use crate::aggregate_fn::AggregateFnVTable;
39use crate::aggregate_fn::DynAccumulator;
40use crate::aggregate_fn::EmptyOptions;
41use crate::arrays::StructArray;
42use crate::arrays::struct_::StructArrayExt;
43use crate::dtype::DType;
44use crate::dtype::FieldNames;
45use crate::dtype::Nullability;
46use crate::scalar::Scalar;
47use crate::validity::Validity;
48
49pub fn all_non_distinct(a: &ArrayRef, b: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
58 if a.dtype() != b.dtype() {
59 vortex_bail!(
60 "all_non_distinct: dtype mismatch: {} vs {}",
61 a.dtype(),
62 b.dtype()
63 );
64 }
65
66 if a.len() != b.len() {
67 vortex_bail!(
68 "all_non_distinct: length mismatch: {} vs {}",
69 a.len(),
70 b.len()
71 );
72 }
73
74 if a.is_empty() {
75 return Ok(true);
76 }
77
78 let Some(shared_validity) = shared_validity_mask(a, b, ctx)? else {
79 return Ok(false);
80 };
81 if shared_validity.true_count() == 0 {
82 return Ok(true);
83 }
84
85 let validity = Validity::from_mask(shared_validity, a.dtype().nullability());
86 let batch = StructArray::try_new(NAMES.clone(), vec![a.clone(), b.clone()], a.len(), validity)?
87 .into_array();
88
89 let mut acc = Accumulator::try_new(AllNonDistinct, EmptyOptions, batch.dtype().clone())?;
90 acc.accumulate(&batch, ctx)?;
91 let result = acc.finish()?;
92
93 Ok(result.as_bool().value().unwrap_or(false))
94}
95
96static NAMES: LazyLock<FieldNames> = LazyLock::new(|| FieldNames::from(["lhs", "rhs"]));
97
/// Aggregate function reporting whether, on every valid row of a two-field struct batch,
/// field 0 (`lhs`) is non-distinct from field 1 (`rhs`).
///
/// The input must be a struct with exactly two fields of identical dtype; the result is a
/// non-nullable boolean.
#[derive(Clone, Debug)]
pub struct AllNonDistinct;

/// Running state of the [`AllNonDistinct`] aggregate.
#[derive(Debug)]
pub struct AllNonDistinctPartial {
    // Starts (and resets to) `true`; becomes `false` once a compared batch or a combined
    // partial turns out not to be all non-distinct, after which it never flips back.
    all_non_distinct: bool,
}
113
114impl AggregateFnVTable for AllNonDistinct {
115 type Options = EmptyOptions;
116 type Partial = AllNonDistinctPartial;
117
118 fn id(&self) -> AggregateFnId {
119 AggregateFnId::new("vortex.all_non_distinct")
120 }
121
122 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
123 unimplemented!("AllNonDistinct is not yet serializable");
124 }
125
126 fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
127 match input_dtype {
128 DType::Struct(fields, _) if fields.nfields() == 2 => {
129 let lhs = fields.fields().next()?;
130 let rhs = fields.fields().nth(1)?;
131 (lhs == rhs).then(|| DType::Bool(Nullability::NonNullable))
132 }
133 _ => None,
134 }
135 }
136
137 fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
138 self.return_dtype(options, input_dtype)
139 }
140
141 fn empty_partial(
142 &self,
143 _options: &Self::Options,
144 _input_dtype: &DType,
145 ) -> VortexResult<Self::Partial> {
146 Ok(AllNonDistinctPartial {
147 all_non_distinct: true,
148 })
149 }
150
151 fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
152 if !partial.all_non_distinct {
153 return Ok(());
154 }
155
156 if !other.as_bool().value().unwrap_or(false) {
157 partial.all_non_distinct = false;
158 }
159 Ok(())
160 }
161
162 fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
163 Ok(Scalar::bool(
164 partial.all_non_distinct,
165 Nullability::NonNullable,
166 ))
167 }
168
169 fn reset(&self, partial: &mut Self::Partial) {
170 partial.all_non_distinct = true;
171 }
172
173 #[inline]
174 fn is_saturated(&self, partial: &Self::Partial) -> bool {
175 !partial.all_non_distinct
176 }
177
178 fn accumulate(
179 &self,
180 partial: &mut Self::Partial,
181 batch: &Columnar,
182 ctx: &mut ExecutionCtx,
183 ) -> VortexResult<()> {
184 if !partial.all_non_distinct {
185 return Ok(());
186 }
187
188 match batch {
189 Columnar::Constant(c) => {
190 let _ = c;
191 Ok(())
192 }
193 Columnar::Canonical(c) => {
194 let Canonical::Struct(s) = c else {
195 vortex_bail!(
196 "AllNonDistinct expects a Struct canonical, got {:?}",
197 c.dtype()
198 );
199 };
200
201 let struct_mask = s.validity()?.execute_mask(s.len(), ctx)?;
204 if struct_mask.true_count() == 0 {
205 return Ok(());
206 }
207
208 let lhs = s.unmasked_field(0);
209 let rhs = s.unmasked_field(1);
210
211 let (lhs, rhs) = if struct_mask.true_count() == s.len() {
213 (lhs.clone(), rhs.clone())
214 } else {
215 (lhs.filter(struct_mask.clone())?, rhs.filter(struct_mask)?)
216 };
217
218 let lhs_canonical = lhs.execute::<Canonical>(ctx)?;
219 let rhs_canonical = rhs.execute::<Canonical>(ctx)?;
220
221 partial.all_non_distinct =
222 check_canonical_identical(&lhs_canonical, &rhs_canonical, ctx)?;
223
224 Ok(())
225 }
226 }
227 }
228
229 fn finalize(&self, _partials: ArrayRef) -> VortexResult<ArrayRef> {
230 vortex_bail!("AllNonDistinct does not support array finalization");
231 }
232
233 fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
234 Ok(Scalar::bool(
235 partial.all_non_distinct,
236 Nullability::NonNullable,
237 ))
238 }
239}
240
241fn check_canonical_identical(
242 lhs: &Canonical,
243 rhs: &Canonical,
244 ctx: &mut ExecutionCtx,
245) -> VortexResult<bool> {
246 match (lhs, rhs) {
247 (Canonical::Null(_), Canonical::Null(_)) => Ok(true),
248 (Canonical::Bool(lhs), Canonical::Bool(rhs)) => check_bool_identical(lhs, rhs),
249 (Canonical::Primitive(lhs), Canonical::Primitive(rhs)) => {
250 check_primitive_identical(lhs, rhs)
251 }
252 (Canonical::Decimal(lhs), Canonical::Decimal(rhs)) => check_decimal_identical(lhs, rhs),
253 (Canonical::VarBinView(lhs), Canonical::VarBinView(rhs)) => {
254 check_varbinview_identical(lhs, rhs)
255 }
256 (Canonical::Struct(lhs), Canonical::Struct(rhs)) => check_struct_identical(lhs, rhs, ctx),
257 (Canonical::List(lhs), Canonical::List(rhs)) => check_list_identical(lhs, rhs, ctx),
258 (Canonical::FixedSizeList(lhs), Canonical::FixedSizeList(rhs)) => {
259 check_fixed_size_list_identical(lhs, rhs, ctx)
260 }
261 (Canonical::Extension(lhs), Canonical::Extension(rhs)) => {
262 check_extension_identical(lhs, rhs, ctx)
263 }
264 (Canonical::Variant(_), _) | (_, Canonical::Variant(_)) => {
265 vortex_bail!("Variant arrays don't support AllNonDistinct")
266 }
267 _ => Err(vortex_err!(
268 "Canonical type mismatch in AllNonDistinct: {:?} vs {:?}",
269 lhs.dtype(),
270 rhs.dtype()
271 )),
272 }
273}