Skip to main content

vortex_array/aggregate_fn/fns/all_non_distinct/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod bool;
5mod decimal;
6mod extension;
7mod filter;
8mod fixed_size_list;
9mod list;
10mod primitive;
11mod struct_;
12#[cfg(test)]
13mod tests;
14mod varbin;
15
16use std::sync::LazyLock;
17
18use vortex_error::VortexResult;
19use vortex_error::vortex_bail;
20use vortex_error::vortex_err;
21
22use self::bool::check_bool_identical;
23use self::decimal::check_decimal_identical;
24use self::extension::check_extension_identical;
25use self::filter::shared_validity_mask;
26use self::fixed_size_list::check_fixed_size_list_identical;
27use self::list::check_list_identical;
28use self::primitive::check_primitive_identical;
29use self::struct_::check_struct_identical;
30use self::varbin::check_varbinview_identical;
31use crate::ArrayRef;
32use crate::Canonical;
33use crate::Columnar;
34use crate::ExecutionCtx;
35use crate::IntoArray;
36use crate::aggregate_fn::Accumulator;
37use crate::aggregate_fn::AggregateFnId;
38use crate::aggregate_fn::AggregateFnVTable;
39use crate::aggregate_fn::DynAccumulator;
40use crate::aggregate_fn::EmptyOptions;
41use crate::arrays::StructArray;
42use crate::arrays::struct_::StructArrayExt;
43use crate::dtype::DType;
44use crate::dtype::FieldNames;
45use crate::dtype::Nullability;
46use crate::scalar::Scalar;
47use crate::validity::Validity;
48
49/// Check if two arrays are element-wise non-distinct, treating null == null as true.
50///
51/// Returns `true` if and only if:
52/// - Both arrays have the same dtype and length
53/// - At every position, both are null or both are non-null with the same value
54/// - The arrays are empty, vacuously
55///
56/// This is a fused `bool_all(non_distinct(lhs, rhs))` aggregate that allows early
57/// termination via accumulator saturation as soon as a mismatch is found.
58pub fn all_non_distinct(a: &ArrayRef, b: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
59    if a.dtype() != b.dtype() {
60        vortex_bail!(
61            "all_non_distinct: dtype mismatch: {} vs {}",
62            a.dtype(),
63            b.dtype()
64        );
65    }
66
67    if a.len() != b.len() {
68        vortex_bail!(
69            "all_non_distinct: length mismatch: {} vs {}",
70            a.len(),
71            b.len()
72        );
73    }
74
75    if a.is_empty() {
76        return Ok(true);
77    }
78
79    let Some(shared_validity) = shared_validity_mask(a, b, ctx)? else {
80        return Ok(false);
81    };
82    if shared_validity.true_count() == 0 {
83        return Ok(true);
84    }
85
86    let validity = Validity::from_mask(shared_validity, a.dtype().nullability());
87    let batch = StructArray::try_new(NAMES.clone(), vec![a.clone(), b.clone()], a.len(), validity)?
88        .into_array();
89
90    let mut acc = Accumulator::try_new(AllNonDistinct, EmptyOptions, batch.dtype().clone())?;
91    acc.accumulate(&batch, ctx)?;
92    let result = acc.finish()?;
93
94    Ok(result.as_bool().value().unwrap_or(false))
95}
96
97static NAMES: LazyLock<FieldNames> = LazyLock::new(|| FieldNames::from(["lhs", "rhs"]));
98
/// Aggregate answering "are `lhs` and `rhs` non-distinct on every row?".
///
/// It fuses a pairwise non-distinct comparison (a null matches a null) with a boolean-all
/// reduction. Because the two are fused, the accumulator saturates on the first distinct pair
/// it encounters and any later batches are skipped rather than compared.
///
/// An empty input yields `true`, matching the vacuous-truth convention of `all` aggregates.
///
/// Input dtype: `Struct{lhs: T, rhs: T}`. Output dtype: `Bool(NonNullable)`.
#[derive(Debug, Clone)]
pub struct AllNonDistinct;
111
/// Running state of the [`AllNonDistinct`] aggregate: one flag recording whether every
/// `{lhs, rhs}` pair observed so far has been non-distinct.
pub struct AllNonDistinctPartial {
    // `true` until a distinct pair is observed; the aggregate's `reset` sets it back to `true`.
    all_non_distinct: bool,
}
116
117impl AggregateFnVTable for AllNonDistinct {
118    type Options = EmptyOptions;
119    type Partial = AllNonDistinctPartial;
120
121    fn id(&self) -> AggregateFnId {
122        AggregateFnId::new("vortex.all_non_distinct")
123    }
124
125    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
126        unimplemented!("AllNonDistinct is not yet serializable");
127    }
128
129    fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
130        match input_dtype {
131            DType::Struct(fields, _) if fields.nfields() == 2 => {
132                let lhs = fields.fields().next()?;
133                let rhs = fields.fields().nth(1)?;
134                (lhs == rhs).then(|| DType::Bool(Nullability::NonNullable))
135            }
136            _ => None,
137        }
138    }
139
140    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
141        self.return_dtype(options, input_dtype)
142    }
143
144    fn empty_partial(
145        &self,
146        _options: &Self::Options,
147        _input_dtype: &DType,
148    ) -> VortexResult<Self::Partial> {
149        Ok(AllNonDistinctPartial {
150            all_non_distinct: true,
151        })
152    }
153
154    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
155        if !partial.all_non_distinct {
156            return Ok(());
157        }
158
159        if !other.as_bool().value().unwrap_or(false) {
160            partial.all_non_distinct = false;
161        }
162        Ok(())
163    }
164
165    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
166        Ok(Scalar::bool(
167            partial.all_non_distinct,
168            Nullability::NonNullable,
169        ))
170    }
171
172    fn reset(&self, partial: &mut Self::Partial) {
173        partial.all_non_distinct = true;
174    }
175
176    #[inline]
177    fn is_saturated(&self, partial: &Self::Partial) -> bool {
178        !partial.all_non_distinct
179    }
180
181    fn accumulate(
182        &self,
183        partial: &mut Self::Partial,
184        batch: &Columnar,
185        ctx: &mut ExecutionCtx,
186    ) -> VortexResult<()> {
187        if !partial.all_non_distinct {
188            return Ok(());
189        }
190
191        match batch {
192            Columnar::Constant(c) => {
193                let _ = c;
194                Ok(())
195            }
196            Columnar::Canonical(c) => {
197                let Canonical::Struct(s) = c else {
198                    vortex_bail!(
199                        "AllNonDistinct expects a Struct canonical, got {:?}",
200                        c.dtype()
201                    );
202                };
203
204                // The struct-level validity represents the shared validity mask
205                // (positions where both lhs and rhs are non-null).
206                let struct_mask = s.validity()?.execute_mask(s.len(), ctx)?;
207                if struct_mask.true_count() == 0 {
208                    return Ok(());
209                }
210
211                let lhs = s.unmasked_field(0);
212                let rhs = s.unmasked_field(1);
213
214                // Filter to only valid rows if the struct has nulls.
215                let (lhs, rhs) = if struct_mask.true_count() == s.len() {
216                    (lhs.clone(), rhs.clone())
217                } else {
218                    (lhs.filter(struct_mask.clone())?, rhs.filter(struct_mask)?)
219                };
220
221                let lhs_canonical = lhs.execute::<Canonical>(ctx)?;
222                let rhs_canonical = rhs.execute::<Canonical>(ctx)?;
223
224                partial.all_non_distinct =
225                    check_canonical_identical(&lhs_canonical, &rhs_canonical, ctx)?;
226
227                Ok(())
228            }
229        }
230    }
231
232    fn finalize(&self, _partials: ArrayRef) -> VortexResult<ArrayRef> {
233        vortex_bail!("AllNonDistinct does not support array finalization");
234    }
235
236    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
237        Ok(Scalar::bool(
238            partial.all_non_distinct,
239            Nullability::NonNullable,
240        ))
241    }
242}
243
244fn check_canonical_identical(
245    lhs: &Canonical,
246    rhs: &Canonical,
247    ctx: &mut ExecutionCtx,
248) -> VortexResult<bool> {
249    match (lhs, rhs) {
250        (Canonical::Null(_), Canonical::Null(_)) => Ok(true),
251        (Canonical::Bool(lhs), Canonical::Bool(rhs)) => check_bool_identical(lhs, rhs),
252        (Canonical::Primitive(lhs), Canonical::Primitive(rhs)) => {
253            check_primitive_identical(lhs, rhs)
254        }
255        (Canonical::Decimal(lhs), Canonical::Decimal(rhs)) => check_decimal_identical(lhs, rhs),
256        (Canonical::VarBinView(lhs), Canonical::VarBinView(rhs)) => {
257            check_varbinview_identical(lhs, rhs)
258        }
259        (Canonical::Struct(lhs), Canonical::Struct(rhs)) => check_struct_identical(lhs, rhs, ctx),
260        (Canonical::List(lhs), Canonical::List(rhs)) => check_list_identical(lhs, rhs, ctx),
261        (Canonical::FixedSizeList(lhs), Canonical::FixedSizeList(rhs)) => {
262            check_fixed_size_list_identical(lhs, rhs, ctx)
263        }
264        (Canonical::Extension(lhs), Canonical::Extension(rhs)) => {
265            check_extension_identical(lhs, rhs, ctx)
266        }
267        (Canonical::Variant(_), _) | (_, Canonical::Variant(_)) => {
268            vortex_bail!("Variant arrays don't support AllNonDistinct")
269        }
270        _ => Err(vortex_err!(
271            "Canonical type mismatch in AllNonDistinct: {:?} vs {:?}",
272            lhs.dtype(),
273            rhs.dtype()
274        )),
275    }
276}