Skip to main content

vortex_array/aggregate_fn/fns/is_constant/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod bool;
5mod decimal;
6mod extension;
7mod fixed_size_list;
8mod list;
9pub mod primitive;
10mod struct_;
11mod varbin;
12
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_mask::Mask;
17
18use self::bool::check_bool_constant;
19use self::decimal::check_decimal_constant;
20use self::extension::check_extension_constant;
21use self::fixed_size_list::check_fixed_size_list_constant;
22use self::list::check_listview_constant;
23use self::primitive::check_primitive_constant;
24use self::struct_::check_struct_constant;
25use self::varbin::check_varbinview_constant;
26use crate::ArrayRef;
27use crate::Canonical;
28use crate::Columnar;
29use crate::DynArray;
30use crate::ExecutionCtx;
31use crate::IntoArray;
32use crate::aggregate_fn::Accumulator;
33use crate::aggregate_fn::AggregateFnId;
34use crate::aggregate_fn::AggregateFnVTable;
35use crate::aggregate_fn::DynAccumulator;
36use crate::aggregate_fn::EmptyOptions;
37use crate::arrays::Constant;
38use crate::arrays::Null;
39use crate::builtins::ArrayBuiltins;
40use crate::dtype::DType;
41use crate::dtype::FieldNames;
42use crate::dtype::Nullability;
43use crate::dtype::StructFields;
44use crate::expr::stats::Precision;
45use crate::expr::stats::Stat;
46use crate::expr::stats::StatsProvider;
47use crate::expr::stats::StatsProviderExt;
48use crate::scalar::Scalar;
49use crate::scalar_fn::fns::operators::Operator;
50
51/// Check if two arrays of the same length have equal values at every position (null-safe).
52///
53/// Two positions are considered equal if they are both null, or both non-null with the same value.
54///
55// TODO(ngates): move this function out when we have any/all aggregate functions.
56fn arrays_value_equal(a: &ArrayRef, b: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
57    debug_assert_eq!(a.len(), b.len());
58    if a.is_empty() {
59        return Ok(true);
60    }
61
62    // Check validity masks match (null positions must be identical).
63    let a_mask = a.validity_mask()?;
64    let b_mask = b.validity_mask()?;
65    if a_mask != b_mask {
66        return Ok(false);
67    }
68
69    let valid_count = a_mask.true_count();
70    if valid_count == 0 {
71        // Both all-null → equal.
72        return Ok(true);
73    }
74
75    // Compare values element-wise. Result is null where both inputs are null,
76    // true/false where both are valid.
77    let eq_result = a.binary(b.clone(), Operator::Eq)?;
78    let eq_result = eq_result.execute::<Mask>(ctx)?;
79
80    Ok(eq_result.true_count() == valid_count)
81}
82
83/// Compute whether an array has constant values.
84///
85/// An array is constant IFF at least one of the following conditions apply:
86/// 1. It has at least one element (**Note** - an empty array isn't constant).
87/// 2. It's encoded as a [`ConstantArray`](crate::arrays::ConstantArray) or [`NullArray`](crate::arrays::NullArray)
88/// 3. Has an exact statistic attached to it, saying its constant.
89/// 4. Is all invalid.
90/// 5. Is all valid AND has minimum and maximum statistics that are equal.
91pub fn is_constant(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
92    // Short-circuit using cached array statistics.
93    if let Some(Precision::Exact(value)) = array.statistics().get_as::<bool>(Stat::IsConstant) {
94        return Ok(value);
95    }
96
97    // Empty arrays are not constant.
98    if array.is_empty() {
99        return Ok(false);
100    }
101
102    // Array of length 1 is always constant.
103    if array.len() == 1 {
104        array
105            .statistics()
106            .set(Stat::IsConstant, Precision::Exact(true.into()));
107        return Ok(true);
108    }
109
110    // Constant and null arrays are always constant.
111    if array.is::<Constant>() || array.is::<Null>() {
112        array
113            .statistics()
114            .set(Stat::IsConstant, Precision::Exact(true.into()));
115        return Ok(true);
116    }
117
118    let all_invalid = array.all_invalid()?;
119    if all_invalid {
120        array
121            .statistics()
122            .set(Stat::IsConstant, Precision::Exact(true.into()));
123        return Ok(true);
124    }
125
126    let all_valid = array.all_valid()?;
127
128    // If we have some nulls but not all nulls, array can't be constant.
129    if !all_valid && !all_invalid {
130        array
131            .statistics()
132            .set(Stat::IsConstant, Precision::Exact(false.into()));
133        return Ok(false);
134    }
135
136    // We already know here that the array is all valid, so we check for min/max stats.
137    let min = array.statistics().get(Stat::Min);
138    let max = array.statistics().get(Stat::Max);
139
140    if let Some((min, max)) = min.zip(max)
141        && min.is_exact()
142        && min == max
143        && (Stat::NaNCount.dtype(array.dtype()).is_none()
144            || array.statistics().get_as::<u64>(Stat::NaNCount) == Some(Precision::exact(0u64)))
145    {
146        array
147            .statistics()
148            .set(Stat::IsConstant, Precision::Exact(true.into()));
149        return Ok(true);
150    }
151
152    // Short-circuit for unsupported dtypes.
153    if IsConstant
154        .return_dtype(&EmptyOptions, array.dtype())
155        .is_none()
156    {
157        // Null dtype - vacuously false for empty
158        return Ok(false);
159    }
160
161    // Compute using Accumulator<IsConstant>.
162    let mut acc = Accumulator::try_new(IsConstant, EmptyOptions, array.dtype().clone())?;
163    acc.accumulate(array, ctx)?;
164    let result_scalar = acc.finish()?;
165
166    let result = result_scalar.as_bool().value().unwrap_or(false);
167
168    // Cache the computed is_constant as a statistic.
169    array
170        .statistics()
171        .set(Stat::IsConstant, Precision::Exact(result.into()));
172
173    Ok(result)
174}
175
176/// Compute whether an array is constant.
177///
178/// Returns a `Bool(NonNullable)` scalar.
179/// The partial state is a nullable struct `{is_constant: Bool(NN), value: input_dtype?}`.
180/// A null struct means the accumulator has seen no data yet (empty).
181#[derive(Clone, Debug)]
182pub struct IsConstant;
183
184impl IsConstant {
185    /// Build a partial scalar from a kernel's `is_constant` result.
186    ///
187    /// Kernels that compute `is_constant` by delegating to child arrays can call this
188    /// to package the boolean result into the partial struct format expected by the
189    /// accumulator, avoiding duplicated boilerplate.
190    pub fn make_partial(batch: &ArrayRef, is_constant: bool) -> VortexResult<Scalar> {
191        let partial_dtype = make_is_constant_partial_dtype(batch.dtype());
192        if is_constant {
193            if batch.is_empty() {
194                return Ok(Scalar::null(partial_dtype));
195            }
196            let first_value = batch.scalar_at(0)?.into_nullable();
197            Ok(Scalar::struct_(
198                partial_dtype,
199                vec![Scalar::bool(true, Nullability::NonNullable), first_value],
200            ))
201        } else {
202            Ok(Scalar::struct_(
203                partial_dtype,
204                vec![
205                    Scalar::bool(false, Nullability::NonNullable),
206                    Scalar::null(batch.dtype().as_nullable()),
207                ],
208            ))
209        }
210    }
211}
212
213/// Partial accumulator state for is_constant.
214pub struct IsConstantPartial {
215    is_constant: bool,
216    /// None = empty (no values seen), Some(null) = all nulls, Some(v) = first value seen.
217    first_value: Option<Scalar>,
218    element_dtype: DType,
219}
220
221impl IsConstantPartial {
222    fn check_value(&mut self, value: Scalar) {
223        if !self.is_constant {
224            return;
225        }
226        match &self.first_value {
227            None => {
228                self.first_value = Some(value);
229            }
230            Some(first) => {
231                if *first != value {
232                    self.is_constant = false;
233                }
234            }
235        }
236    }
237}
238
239static NAMES: std::sync::LazyLock<FieldNames> =
240    std::sync::LazyLock::new(|| FieldNames::from(["is_constant", "value"]));
241
242pub fn make_is_constant_partial_dtype(element_dtype: &DType) -> DType {
243    DType::Struct(
244        StructFields::new(
245            NAMES.clone(),
246            vec![
247                DType::Bool(Nullability::NonNullable),
248                element_dtype.as_nullable(),
249            ],
250        ),
251        Nullability::Nullable,
252    )
253}
254
255impl AggregateFnVTable for IsConstant {
256    type Options = EmptyOptions;
257    type Partial = IsConstantPartial;
258
259    fn id(&self) -> AggregateFnId {
260        AggregateFnId::new_ref("vortex.is_constant")
261    }
262
263    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
264        Ok(Some(vec![]))
265    }
266
267    fn deserialize(
268        &self,
269        _metadata: &[u8],
270        _session: &vortex_session::VortexSession,
271    ) -> VortexResult<Self::Options> {
272        Ok(EmptyOptions)
273    }
274
275    fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
276        match input_dtype {
277            DType::Null | DType::Variant(..) => None,
278            _ => Some(DType::Bool(Nullability::NonNullable)),
279        }
280    }
281
282    fn partial_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
283        match input_dtype {
284            DType::Null | DType::Variant(..) => None,
285            _ => Some(make_is_constant_partial_dtype(input_dtype)),
286        }
287    }
288
289    fn empty_partial(
290        &self,
291        _options: &Self::Options,
292        input_dtype: &DType,
293    ) -> VortexResult<Self::Partial> {
294        Ok(IsConstantPartial {
295            is_constant: true,
296            first_value: None,
297            element_dtype: input_dtype.clone(),
298        })
299    }
300
301    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
302        if !partial.is_constant {
303            return Ok(());
304        }
305
306        // Null struct means the other accumulator was empty, skip it.
307        if other.is_null() {
308            return Ok(());
309        }
310
311        let other_is_constant = other
312            .as_struct()
313            .field_by_idx(0)
314            .map(|s| s.as_bool().value().unwrap_or(false))
315            .unwrap_or(false);
316
317        if !other_is_constant {
318            partial.is_constant = false;
319            return Ok(());
320        }
321
322        let other_value = other.as_struct().field_by_idx(1);
323
324        if let Some(other_val) = other_value {
325            partial.check_value(other_val);
326        }
327
328        Ok(())
329    }
330
331    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
332        let dtype = make_is_constant_partial_dtype(&partial.element_dtype);
333        Ok(match &partial.first_value {
334            None => {
335                // Empty accumulator — return null struct.
336                Scalar::null(dtype)
337            }
338            Some(first_value) => Scalar::struct_(
339                dtype,
340                vec![
341                    Scalar::bool(partial.is_constant, Nullability::NonNullable),
342                    first_value
343                        .clone()
344                        .cast(&partial.element_dtype.as_nullable())?,
345                ],
346            ),
347        })
348    }
349
350    fn reset(&self, partial: &mut Self::Partial) {
351        partial.is_constant = true;
352        partial.first_value = None;
353    }
354
355    #[inline]
356    fn is_saturated(&self, partial: &Self::Partial) -> bool {
357        !partial.is_constant
358    }
359
360    fn accumulate(
361        &self,
362        partial: &mut Self::Partial,
363        batch: &Columnar,
364        ctx: &mut ExecutionCtx,
365    ) -> VortexResult<()> {
366        if !partial.is_constant {
367            return Ok(());
368        }
369
370        match batch {
371            Columnar::Constant(c) => {
372                partial.check_value(c.scalar().clone().into_nullable());
373                Ok(())
374            }
375            Columnar::Canonical(c) => {
376                if c.is_empty() {
377                    return Ok(());
378                }
379
380                // Convert to ArrayRef for DynArray methods.
381                let array_ref = c.clone().into_array();
382
383                let all_invalid = array_ref.all_invalid()?;
384                if all_invalid {
385                    partial.check_value(Scalar::null(partial.element_dtype.as_nullable()));
386                    return Ok(());
387                }
388
389                let all_valid = array_ref.all_valid()?;
390                // Mixed nulls → not constant.
391                if !all_valid && !all_invalid {
392                    partial.is_constant = false;
393                    return Ok(());
394                }
395
396                // All valid from here. Check batch-level constancy.
397                if c.len() == 1 {
398                    partial.check_value(array_ref.scalar_at(0)?.into_nullable());
399                    return Ok(());
400                }
401
402                let batch_is_constant = match c {
403                    Canonical::Primitive(p) => check_primitive_constant(p),
404                    Canonical::Bool(b) => check_bool_constant(b),
405                    Canonical::VarBinView(v) => check_varbinview_constant(v),
406                    Canonical::Decimal(d) => check_decimal_constant(d),
407                    Canonical::Struct(s) => check_struct_constant(s, ctx)?,
408                    Canonical::Extension(e) => check_extension_constant(e, ctx)?,
409                    Canonical::List(l) => check_listview_constant(l, ctx)?,
410                    Canonical::FixedSizeList(f) => check_fixed_size_list_constant(f, ctx)?,
411                    Canonical::Null(_) => true,
412                    Canonical::Variant(_) => {
413                        vortex_bail!("Variant arrays don't support IsConstant")
414                    }
415                };
416
417                if !batch_is_constant {
418                    partial.is_constant = false;
419                    return Ok(());
420                }
421
422                partial.check_value(array_ref.scalar_at(0)?.into_nullable());
423                Ok(())
424            }
425        }
426    }
427
428    fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
429        partials.get_item(NAMES.get(0).vortex_expect("out of bounds").clone())
430    }
431
432    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
433        if partial.first_value.is_none() {
434            // Empty accumulator → return false.
435            return Ok(Scalar::bool(false, Nullability::NonNullable));
436        }
437        Ok(Scalar::bool(partial.is_constant, Nullability::NonNullable))
438    }
439}
440
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray as _;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::fns::is_constant::is_constant;
    use crate::arrays::BoolArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::DecimalArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::stats::Stat;
    use crate::validity::Validity;

    // Tests migrated from compute/is_constant.rs
    #[test]
    fn is_constant_min_max_no_nan() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        let arr = buffer![0, 1].into_array();
        arr.statistics().compute_all(&[Stat::Min, Stat::Max])?;
        assert!(!is_constant(&arr, &mut ctx)?);

        let arr = buffer![0, 0].into_array();
        arr.statistics().compute_all(&[Stat::Min, Stat::Max])?;
        assert!(is_constant(&arr, &mut ctx)?);

        let arr = PrimitiveArray::from_option_iter([Some(0), Some(0)]).into_array();
        assert!(is_constant(&arr, &mut ctx)?);
        Ok(())
    }

    #[test]
    fn is_constant_min_max_with_nan() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        // Min and max are both 0.0 here, yet the NaN makes the array non-constant.
        let arr = PrimitiveArray::from_iter([0.0, 0.0, f32::NAN]).into_array();
        arr.statistics().compute_all(&[Stat::Min, Stat::Max])?;
        assert!(!is_constant(&arr, &mut ctx)?);

        let arr =
            PrimitiveArray::from_option_iter([Some(f32::NEG_INFINITY), Some(f32::NEG_INFINITY)])
                .into_array();
        arr.statistics().compute_all(&[Stat::Min, Stat::Max])?;
        assert!(is_constant(&arr, &mut ctx)?);
        Ok(())
    }

    /// Null handling and emptiness:
    /// - all-null arrays are constant,
    /// - arrays mixing nulls and values are not,
    /// - empty arrays are never constant.
    #[test]
    fn is_constant_null_and_empty() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        let all_null = PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array();
        assert!(is_constant(&all_null, &mut ctx)?);

        let mixed = PrimitiveArray::from_option_iter([Some(1i32), None, Some(1)]).into_array();
        assert!(!is_constant(&mixed, &mut ctx)?);

        let empty = Buffer::<i32>::empty().into_array();
        assert!(!is_constant(&empty, &mut ctx)?);
        Ok(())
    }

    // Tests migrated from arrays/bool/compute/is_constant.rs
    #[rstest]
    #[case(vec![true], true)]
    #[case(vec![false; 65], true)]
    #[case({
        let mut v = vec![true; 64];
        v.push(false);
        v
    }, false)]
    fn test_bool_is_constant(#[case] input: Vec<bool>, #[case] expected: bool) -> VortexResult<()> {
        let array = BoolArray::from_iter(input);
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(is_constant(&array.into_array(), &mut ctx)?, expected);
        Ok(())
    }

    // Tests migrated from arrays/chunked/compute/is_constant.rs
    #[test]
    fn empty_chunk_is_constant() -> VortexResult<()> {
        let chunked = ChunkedArray::try_new(
            vec![
                Buffer::<u8>::empty().into_array(),
                Buffer::<u8>::empty().into_array(),
                buffer![255u8, 255].into_array(),
                Buffer::<u8>::empty().into_array(),
                buffer![255u8, 255].into_array(),
            ],
            DType::Primitive(PType::U8, Nullability::NonNullable),
        )?
        .into_array();

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert!(is_constant(&chunked, &mut ctx)?);
        Ok(())
    }

    // Tests migrated from arrays/decimal/compute/is_constant.rs
    #[test]
    fn test_decimal_is_constant() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        let array = DecimalArray::new(
            buffer![0i128, 1i128, 2i128],
            DecimalDType::new(19, 0),
            Validity::NonNullable,
        );
        assert!(!is_constant(&array.into_array(), &mut ctx)?);

        let array = DecimalArray::new(
            buffer![100i128, 100i128, 100i128],
            DecimalDType::new(19, 0),
            Validity::NonNullable,
        );
        assert!(is_constant(&array.into_array(), &mut ctx)?);
        Ok(())
    }

    // Tests migrated from arrays/list/compute/is_constant.rs
    #[test]
    fn test_is_constant_nested_list() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        // Two identical lists: [0, 1], [0, 1].
        let xs = ListArray::try_new(
            buffer![0i32, 1, 0, 1].into_array(),
            buffer![0u32, 2, 4].into_array(),
            Validity::NonNullable,
        )?;

        let struct_of_lists = StructArray::try_new(
            FieldNames::from(["xs"]),
            vec![xs.into_array()],
            2,
            Validity::NonNullable,
        )?;
        assert!(is_constant(
            &struct_of_lists.clone().into_array(),
            &mut ctx
        )?);
        // Ask again on the original (un-cloned) array. The answer must be stable whether or not
        // the first call left a cached IsConstant statistic behind.
        assert!(is_constant(&struct_of_lists.into_array(), &mut ctx)?);
        Ok(())
    }

    #[rstest]
    #[case(
        // [1,2], [1, 2], [1, 2]
        vec![1i32, 2, 1, 2, 1, 2],
        vec![0u32, 2, 4, 6],
        true
    )]
    #[case(
        // [1, 2], [3], [4, 5]
        vec![1i32, 2, 3, 4, 5],
        vec![0u32, 2, 3, 5],
        false
    )]
    #[case(
        // [1, 2], [3, 4]
        vec![1i32, 2, 3, 4],
        vec![0u32, 2, 4],
        false
    )]
    #[case(
        // [], [], []
        vec![],
        vec![0u32, 0, 0, 0],
        true
    )]
    fn test_list_is_constant(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) -> VortexResult<()> {
        let list_array = ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        )?;

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(is_constant(&list_array.into_array(), &mut ctx)?, expected);
        Ok(())
    }

    #[test]
    fn test_list_is_constant_nested_lists() -> VortexResult<()> {
        let inner_elements = buffer![1i32, 2, 1, 2].into_array();
        let inner_offsets = buffer![0u32, 1, 2, 3, 4].into_array();
        let inner_lists = ListArray::try_new(inner_elements, inner_offsets, Validity::NonNullable)?;

        let outer_offsets = buffer![0u32, 2, 4].into_array();
        let outer_list = ListArray::try_new(
            inner_lists.into_array(),
            outer_offsets,
            Validity::NonNullable,
        )?;

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        // Both outer lists contain [[1], [2]], so should be constant
        assert!(is_constant(&outer_list.into_array(), &mut ctx)?);
        Ok(())
    }

    #[rstest]
    #[case(
        // 100 identical [1, 2] lists
        [1i32, 2].repeat(100),
        (0..101).map(|i| (i * 2) as u32).collect(),
        true
    )]
    #[case(
        // Difference after threshold: 64 identical [1, 2] + one [3, 4]
        {
            let mut elements = [1i32, 2].repeat(64);
            elements.extend_from_slice(&[3, 4]);
            elements
        },
        (0..66).map(|i| (i * 2) as u32).collect(),
        false
    )]
    #[case(
        // Difference in first 64: first 63 identical [1, 2] + one [3, 4] + rest identical [1, 2]
        {
            let mut elements = [1i32, 2].repeat(63);
            elements.extend_from_slice(&[3, 4]);
            elements.extend([1i32, 2].repeat(37));
            elements
        },
        (0..101).map(|i| (i * 2) as u32).collect(),
        false
    )]
    fn test_list_is_constant_with_threshold(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) -> VortexResult<()> {
        let list_array = ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        )?;

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(is_constant(&list_array.into_array(), &mut ctx)?, expected);
        Ok(())
    }
}