Skip to main content

vortex_array/aggregate_fn/fns/all_non_distinct/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod bool;
5mod decimal;
6mod extension;
7mod filter;
8mod fixed_size_list;
9mod list;
10mod primitive;
11mod struct_;
12#[cfg(test)]
13mod tests;
14mod varbin;
15mod variant;
16
17use std::sync::LazyLock;
18
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_error::vortex_err;
22use vortex_session::registry::CachedId;
23
24use self::bool::check_bool_identical;
25use self::decimal::check_decimal_identical;
26use self::extension::check_extension_identical;
27use self::filter::shared_validity_mask;
28use self::fixed_size_list::check_fixed_size_list_identical;
29use self::list::check_list_identical;
30use self::primitive::check_primitive_identical;
31use self::struct_::check_struct_identical;
32use self::varbin::check_varbinview_identical;
33use crate::ArrayRef;
34use crate::Canonical;
35use crate::Columnar;
36use crate::ExecutionCtx;
37use crate::IntoArray;
38use crate::aggregate_fn::Accumulator;
39use crate::aggregate_fn::AggregateFnId;
40use crate::aggregate_fn::AggregateFnVTable;
41use crate::aggregate_fn::DynAccumulator;
42use crate::aggregate_fn::EmptyOptions;
43use crate::aggregate_fn::fns::all_non_distinct::variant::check_variant_identical;
44use crate::arrays::StructArray;
45use crate::arrays::struct_::StructArrayExt;
46use crate::dtype::DType;
47use crate::dtype::FieldNames;
48use crate::dtype::Nullability;
49use crate::scalar::Scalar;
50use crate::validity::Validity;
51
52/// Check if two arrays are element-wise non-distinct, treating null == null as true.
53///
54/// Returns `true` if and only if:
55/// - Both arrays have the same dtype and length
56/// - At every position, both are null or both are non-null with the same value
57/// - The arrays are empty, vacuously
58///
59/// This is a fused `bool_all(non_distinct(lhs, rhs))` aggregate that allows early
60/// termination via accumulator saturation as soon as a mismatch is found.
61pub fn all_non_distinct(a: &ArrayRef, b: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
62    if a.dtype() != b.dtype() {
63        vortex_bail!(
64            "all_non_distinct: dtype mismatch: {} vs {}",
65            a.dtype(),
66            b.dtype()
67        );
68    }
69
70    if a.len() != b.len() {
71        vortex_bail!(
72            "all_non_distinct: length mismatch: {} vs {}",
73            a.len(),
74            b.len()
75        );
76    }
77
78    if a.is_empty() {
79        return Ok(true);
80    }
81
82    let Some(shared_validity) = shared_validity_mask(a, b, ctx)? else {
83        return Ok(false);
84    };
85    if shared_validity.true_count() == 0 {
86        return Ok(true);
87    }
88
89    let validity = Validity::from_mask(shared_validity, a.dtype().nullability());
90    let batch = StructArray::try_new(NAMES.clone(), vec![a.clone(), b.clone()], a.len(), validity)?
91        .into_array();
92
93    let mut acc = Accumulator::try_new(AllNonDistinct, EmptyOptions, batch.dtype().clone())?;
94    acc.accumulate(&batch, ctx)?;
95    let result = acc.finish()?;
96
97    Ok(result.as_bool().value().unwrap_or(false))
98}
99
100static NAMES: LazyLock<FieldNames> = LazyLock::new(|| FieldNames::from(["lhs", "rhs"]));
101
/// Aggregate answering "is every `(lhs, rhs)` pair non-distinct?".
///
/// It fuses a pairwise non-distinct comparison (under which `null` equals `null`) with a
/// boolean-all reduction. The running answer can only move from `true` to `false`, so the
/// accumulator saturates at the first distinct pair and any later batches are skipped.
///
/// An empty input yields `true`, the same vacuous truth as other `all` aggregates.
///
/// Input dtype: a struct with exactly two fields of the same dtype (conventionally named
/// `lhs` and `rhs`). Output dtype: non-nullable `Bool`.
#[derive(Debug, Clone)]
pub struct AllNonDistinct;
114
/// Running state of the `AllNonDistinct` aggregate: whether no distinct pair has been seen.
pub struct AllNonDistinctPartial {
    /// Starts `true`; flips to `false` at the first distinct pair and stays `false` until the
    /// partial is reset.
    all_non_distinct: bool,
}
119
120impl AggregateFnVTable for AllNonDistinct {
121    type Options = EmptyOptions;
122    type Partial = AllNonDistinctPartial;
123
124    fn id(&self) -> AggregateFnId {
125        static ID: CachedId = CachedId::new("vortex.all_non_distinct");
126        *ID
127    }
128
129    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
130        unimplemented!("AllNonDistinct is not yet serializable");
131    }
132
133    fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
134        match input_dtype {
135            DType::Struct(fields, _) if fields.nfields() == 2 => {
136                let lhs = fields.fields().next()?;
137                let rhs = fields.fields().nth(1)?;
138                (lhs == rhs).then(|| DType::Bool(Nullability::NonNullable))
139            }
140            _ => None,
141        }
142    }
143
144    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
145        self.return_dtype(options, input_dtype)
146    }
147
148    fn empty_partial(
149        &self,
150        _options: &Self::Options,
151        _input_dtype: &DType,
152    ) -> VortexResult<Self::Partial> {
153        Ok(AllNonDistinctPartial {
154            all_non_distinct: true,
155        })
156    }
157
158    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
159        if !partial.all_non_distinct {
160            return Ok(());
161        }
162
163        if !other.as_bool().value().unwrap_or(false) {
164            partial.all_non_distinct = false;
165        }
166        Ok(())
167    }
168
169    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
170        Ok(Scalar::bool(
171            partial.all_non_distinct,
172            Nullability::NonNullable,
173        ))
174    }
175
176    fn reset(&self, partial: &mut Self::Partial) {
177        partial.all_non_distinct = true;
178    }
179
180    #[inline]
181    fn is_saturated(&self, partial: &Self::Partial) -> bool {
182        !partial.all_non_distinct
183    }
184
185    fn accumulate(
186        &self,
187        partial: &mut Self::Partial,
188        batch: &Columnar,
189        ctx: &mut ExecutionCtx,
190    ) -> VortexResult<()> {
191        if !partial.all_non_distinct {
192            return Ok(());
193        }
194
195        match batch {
196            Columnar::Constant(c) => {
197                let _ = c;
198                Ok(())
199            }
200            Columnar::Canonical(c) => {
201                let Canonical::Struct(s) = c else {
202                    vortex_bail!(
203                        "AllNonDistinct expects a Struct canonical, got {:?}",
204                        c.dtype()
205                    );
206                };
207
208                // The struct-level validity represents the shared validity mask
209                // (positions where both lhs and rhs are non-null).
210                let struct_mask = s.validity()?.execute_mask(s.len(), ctx)?;
211                if struct_mask.true_count() == 0 {
212                    return Ok(());
213                }
214
215                let lhs = s.unmasked_field(0);
216                let rhs = s.unmasked_field(1);
217
218                // Filter to only valid rows if the struct has nulls.
219                let (lhs, rhs) = if struct_mask.true_count() == s.len() {
220                    (lhs.clone(), rhs.clone())
221                } else {
222                    (lhs.filter(struct_mask.clone())?, rhs.filter(struct_mask)?)
223                };
224
225                let lhs_canonical = lhs.execute::<Canonical>(ctx)?;
226                let rhs_canonical = rhs.execute::<Canonical>(ctx)?;
227
228                partial.all_non_distinct =
229                    check_canonical_identical(&lhs_canonical, &rhs_canonical, ctx)?;
230
231                Ok(())
232            }
233        }
234    }
235
236    fn finalize(&self, _partials: ArrayRef) -> VortexResult<ArrayRef> {
237        vortex_bail!("AllNonDistinct does not support array finalization");
238    }
239
240    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
241        Ok(Scalar::bool(
242            partial.all_non_distinct,
243            Nullability::NonNullable,
244        ))
245    }
246}
247
248fn check_canonical_identical(
249    lhs: &Canonical,
250    rhs: &Canonical,
251    ctx: &mut ExecutionCtx,
252) -> VortexResult<bool> {
253    match (lhs, rhs) {
254        (Canonical::Null(_), Canonical::Null(_)) => Ok(true),
255        (Canonical::Bool(lhs), Canonical::Bool(rhs)) => check_bool_identical(lhs, rhs),
256        (Canonical::Primitive(lhs), Canonical::Primitive(rhs)) => {
257            check_primitive_identical(lhs, rhs)
258        }
259        (Canonical::Decimal(lhs), Canonical::Decimal(rhs)) => check_decimal_identical(lhs, rhs),
260        (Canonical::VarBinView(lhs), Canonical::VarBinView(rhs)) => {
261            check_varbinview_identical(lhs, rhs)
262        }
263        (Canonical::Struct(lhs), Canonical::Struct(rhs)) => check_struct_identical(lhs, rhs, ctx),
264        (Canonical::List(lhs), Canonical::List(rhs)) => check_list_identical(lhs, rhs, ctx),
265        (Canonical::FixedSizeList(lhs), Canonical::FixedSizeList(rhs)) => {
266            check_fixed_size_list_identical(lhs, rhs, ctx)
267        }
268        (Canonical::Extension(lhs), Canonical::Extension(rhs)) => {
269            check_extension_identical(lhs, rhs, ctx)
270        }
271        (Canonical::Variant(lhs), Canonical::Variant(rhs)) => {
272            check_variant_identical(lhs, rhs, ctx)
273        }
274        _ => Err(vortex_err!(
275            "Canonical type mismatch in AllNonDistinct: {:?} vs {:?}",
276            lhs.dtype(),
277            rhs.dtype()
278        )),
279    }
280}