Skip to main content

vortex_array/aggregate_fn/fns/is_constant/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod bool;
5mod decimal;
6mod extension;
7mod fixed_size_list;
8mod list;
9pub mod primitive;
10mod struct_;
11mod varbin;
12
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_mask::Mask;
17
18use self::bool::check_bool_constant;
19use self::decimal::check_decimal_constant;
20use self::extension::check_extension_constant;
21use self::fixed_size_list::check_fixed_size_list_constant;
22use self::list::check_listview_constant;
23use self::primitive::check_primitive_constant;
24use self::struct_::check_struct_constant;
25use self::varbin::check_varbinview_constant;
26use crate::ArrayRef;
27use crate::Canonical;
28use crate::Columnar;
29use crate::ExecutionCtx;
30use crate::IntoArray;
31use crate::aggregate_fn::Accumulator;
32use crate::aggregate_fn::AggregateFnId;
33use crate::aggregate_fn::AggregateFnVTable;
34use crate::aggregate_fn::DynAccumulator;
35use crate::aggregate_fn::EmptyOptions;
36use crate::arrays::Constant;
37use crate::arrays::Null;
38use crate::builtins::ArrayBuiltins;
39use crate::dtype::DType;
40use crate::dtype::FieldNames;
41use crate::dtype::Nullability;
42use crate::dtype::StructFields;
43use crate::expr::stats::Precision;
44use crate::expr::stats::Stat;
45use crate::expr::stats::StatsProvider;
46use crate::expr::stats::StatsProviderExt;
47use crate::scalar::Scalar;
48use crate::scalar_fn::fns::operators::Operator;
49
50/// Check if two arrays of the same length have equal values at every position (null-safe).
51///
52/// Two positions are considered equal if they are both null, or both non-null with the same value.
53///
54// TODO(ngates): move this function out when we have any/all aggregate functions.
55fn arrays_value_equal(a: &ArrayRef, b: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
56    debug_assert_eq!(a.len(), b.len());
57    if a.is_empty() {
58        return Ok(true);
59    }
60
61    // Check validity masks match (null positions must be identical).
62    let a_mask = a.validity_mask()?;
63    let b_mask = b.validity_mask()?;
64    if a_mask != b_mask {
65        return Ok(false);
66    }
67
68    let valid_count = a_mask.true_count();
69    if valid_count == 0 {
70        // Both all-null → equal.
71        return Ok(true);
72    }
73
74    // Compare values element-wise. Result is null where both inputs are null,
75    // true/false where both are valid.
76    let eq_result = a.binary(b.clone(), Operator::Eq)?;
77    let eq_result = eq_result.execute::<Mask>(ctx)?;
78
79    Ok(eq_result.true_count() == valid_count)
80}
81
82/// Compute whether an array has constant values.
83///
84/// An array is constant IFF at least one of the following conditions apply:
85/// 1. It has at least one element (**Note** - an empty array isn't constant).
86/// 2. It's encoded as a [`ConstantArray`](crate::arrays::ConstantArray) or [`NullArray`](crate::arrays::NullArray)
87/// 3. Has an exact statistic attached to it, saying its constant.
88/// 4. Is all invalid.
89/// 5. Is all valid AND has minimum and maximum statistics that are equal.
90pub fn is_constant(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
91    // Short-circuit using cached array statistics.
92    if let Some(Precision::Exact(value)) = array.statistics().get_as::<bool>(Stat::IsConstant) {
93        return Ok(value);
94    }
95
96    // Empty arrays are not constant.
97    if array.is_empty() {
98        return Ok(false);
99    }
100
101    // Array of length 1 is always constant.
102    if array.len() == 1 {
103        array
104            .statistics()
105            .set(Stat::IsConstant, Precision::Exact(true.into()));
106        return Ok(true);
107    }
108
109    // Constant and null arrays are always constant.
110    if array.is::<Constant>() || array.is::<Null>() {
111        array
112            .statistics()
113            .set(Stat::IsConstant, Precision::Exact(true.into()));
114        return Ok(true);
115    }
116
117    let all_invalid = array.all_invalid()?;
118    if all_invalid {
119        array
120            .statistics()
121            .set(Stat::IsConstant, Precision::Exact(true.into()));
122        return Ok(true);
123    }
124
125    let all_valid = array.all_valid()?;
126
127    // If we have some nulls but not all nulls, array can't be constant.
128    if !all_valid && !all_invalid {
129        array
130            .statistics()
131            .set(Stat::IsConstant, Precision::Exact(false.into()));
132        return Ok(false);
133    }
134
135    // We already know here that the array is all valid, so we check for min/max stats.
136    let min = array.statistics().get(Stat::Min);
137    let max = array.statistics().get(Stat::Max);
138
139    if let Some((min, max)) = min.zip(max)
140        && min.is_exact()
141        && min == max
142        && (Stat::NaNCount.dtype(array.dtype()).is_none()
143            || array.statistics().get_as::<u64>(Stat::NaNCount) == Some(Precision::exact(0u64)))
144    {
145        array
146            .statistics()
147            .set(Stat::IsConstant, Precision::Exact(true.into()));
148        return Ok(true);
149    }
150
151    // Short-circuit for unsupported dtypes.
152    if IsConstant
153        .return_dtype(&EmptyOptions, array.dtype())
154        .is_none()
155    {
156        // Null dtype - vacuously false for empty
157        return Ok(false);
158    }
159
160    // Compute using Accumulator<IsConstant>.
161    let mut acc = Accumulator::try_new(IsConstant, EmptyOptions, array.dtype().clone())?;
162    acc.accumulate(array, ctx)?;
163    let result_scalar = acc.finish()?;
164
165    let result = result_scalar.as_bool().value().unwrap_or(false);
166
167    // Cache the computed is_constant as a statistic.
168    array
169        .statistics()
170        .set(Stat::IsConstant, Precision::Exact(result.into()));
171
172    Ok(result)
173}
174
175/// Compute whether an array is constant.
176///
177/// Returns a `Bool(NonNullable)` scalar.
178/// The partial state is a nullable struct `{is_constant: Bool(NN), value: input_dtype?}`.
179/// A null struct means the accumulator has seen no data yet (empty).
180#[derive(Clone, Debug)]
181pub struct IsConstant;
182
183impl IsConstant {
184    /// Build a partial scalar from a kernel's `is_constant` result.
185    ///
186    /// Kernels that compute `is_constant` by delegating to child arrays can call this
187    /// to package the boolean result into the partial struct format expected by the
188    /// accumulator, avoiding duplicated boilerplate.
189    pub fn make_partial(batch: &ArrayRef, is_constant: bool) -> VortexResult<Scalar> {
190        let partial_dtype = make_is_constant_partial_dtype(batch.dtype());
191        if is_constant {
192            if batch.is_empty() {
193                return Ok(Scalar::null(partial_dtype));
194            }
195            let first_value = batch.scalar_at(0)?.into_nullable();
196            Ok(Scalar::struct_(
197                partial_dtype,
198                vec![Scalar::bool(true, Nullability::NonNullable), first_value],
199            ))
200        } else {
201            Ok(Scalar::struct_(
202                partial_dtype,
203                vec![
204                    Scalar::bool(false, Nullability::NonNullable),
205                    Scalar::null(batch.dtype().as_nullable()),
206                ],
207            ))
208        }
209    }
210}
211
212/// Partial accumulator state for is_constant.
213pub struct IsConstantPartial {
214    is_constant: bool,
215    /// None = empty (no values seen), Some(null) = all nulls, Some(v) = first value seen.
216    first_value: Option<Scalar>,
217    element_dtype: DType,
218}
219
220impl IsConstantPartial {
221    fn check_value(&mut self, value: Scalar) {
222        if !self.is_constant {
223            return;
224        }
225        match &self.first_value {
226            None => {
227                self.first_value = Some(value);
228            }
229            Some(first) => {
230                if *first != value {
231                    self.is_constant = false;
232                }
233            }
234        }
235    }
236}
237
238static NAMES: std::sync::LazyLock<FieldNames> =
239    std::sync::LazyLock::new(|| FieldNames::from(["is_constant", "value"]));
240
241pub fn make_is_constant_partial_dtype(element_dtype: &DType) -> DType {
242    DType::Struct(
243        StructFields::new(
244            NAMES.clone(),
245            vec![
246                DType::Bool(Nullability::NonNullable),
247                element_dtype.as_nullable(),
248            ],
249        ),
250        Nullability::Nullable,
251    )
252}
253
254impl AggregateFnVTable for IsConstant {
255    type Options = EmptyOptions;
256    type Partial = IsConstantPartial;
257
258    fn id(&self) -> AggregateFnId {
259        AggregateFnId::new_ref("vortex.is_constant")
260    }
261
262    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
263        unimplemented!("IsConstant is not yet serializable");
264    }
265
266    fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
267        match input_dtype {
268            DType::Null | DType::Variant(..) => None,
269            _ => Some(DType::Bool(Nullability::NonNullable)),
270        }
271    }
272
273    fn partial_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
274        match input_dtype {
275            DType::Null | DType::Variant(..) => None,
276            _ => Some(make_is_constant_partial_dtype(input_dtype)),
277        }
278    }
279
280    fn empty_partial(
281        &self,
282        _options: &Self::Options,
283        input_dtype: &DType,
284    ) -> VortexResult<Self::Partial> {
285        Ok(IsConstantPartial {
286            is_constant: true,
287            first_value: None,
288            element_dtype: input_dtype.clone(),
289        })
290    }
291
292    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
293        if !partial.is_constant {
294            return Ok(());
295        }
296
297        // Null struct means the other accumulator was empty, skip it.
298        if other.is_null() {
299            return Ok(());
300        }
301
302        let other_is_constant = other
303            .as_struct()
304            .field_by_idx(0)
305            .map(|s| s.as_bool().value().unwrap_or(false))
306            .unwrap_or(false);
307
308        if !other_is_constant {
309            partial.is_constant = false;
310            return Ok(());
311        }
312
313        let other_value = other.as_struct().field_by_idx(1);
314
315        if let Some(other_val) = other_value {
316            partial.check_value(other_val);
317        }
318
319        Ok(())
320    }
321
322    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
323        let dtype = make_is_constant_partial_dtype(&partial.element_dtype);
324        Ok(match &partial.first_value {
325            None => {
326                // Empty accumulator — return null struct.
327                Scalar::null(dtype)
328            }
329            Some(first_value) => Scalar::struct_(
330                dtype,
331                vec![
332                    Scalar::bool(partial.is_constant, Nullability::NonNullable),
333                    first_value
334                        .clone()
335                        .cast(&partial.element_dtype.as_nullable())?,
336                ],
337            ),
338        })
339    }
340
341    fn reset(&self, partial: &mut Self::Partial) {
342        partial.is_constant = true;
343        partial.first_value = None;
344    }
345
346    #[inline]
347    fn is_saturated(&self, partial: &Self::Partial) -> bool {
348        !partial.is_constant
349    }
350
351    fn accumulate(
352        &self,
353        partial: &mut Self::Partial,
354        batch: &Columnar,
355        ctx: &mut ExecutionCtx,
356    ) -> VortexResult<()> {
357        if !partial.is_constant {
358            return Ok(());
359        }
360
361        match batch {
362            Columnar::Constant(c) => {
363                partial.check_value(c.scalar().clone().into_nullable());
364                Ok(())
365            }
366            Columnar::Canonical(c) => {
367                if c.is_empty() {
368                    return Ok(());
369                }
370
371                // Convert to ArrayRef for DynArray methods.
372                let array_ref = c.clone().into_array();
373
374                let all_invalid = array_ref.all_invalid()?;
375                if all_invalid {
376                    partial.check_value(Scalar::null(partial.element_dtype.as_nullable()));
377                    return Ok(());
378                }
379
380                let all_valid = array_ref.all_valid()?;
381                // Mixed nulls → not constant.
382                if !all_valid && !all_invalid {
383                    partial.is_constant = false;
384                    return Ok(());
385                }
386
387                // All valid from here. Check batch-level constancy.
388                if c.len() == 1 {
389                    partial.check_value(array_ref.scalar_at(0)?.into_nullable());
390                    return Ok(());
391                }
392
393                let batch_is_constant = match c {
394                    Canonical::Primitive(p) => check_primitive_constant(p),
395                    Canonical::Bool(b) => check_bool_constant(b),
396                    Canonical::VarBinView(v) => check_varbinview_constant(v),
397                    Canonical::Decimal(d) => check_decimal_constant(d),
398                    Canonical::Struct(s) => check_struct_constant(s, ctx)?,
399                    Canonical::Extension(e) => check_extension_constant(e, ctx)?,
400                    Canonical::List(l) => check_listview_constant(l, ctx)?,
401                    Canonical::FixedSizeList(f) => check_fixed_size_list_constant(f, ctx)?,
402                    Canonical::Null(_) => true,
403                    Canonical::Variant(_) => {
404                        vortex_bail!("Variant arrays don't support IsConstant")
405                    }
406                };
407
408                if !batch_is_constant {
409                    partial.is_constant = false;
410                    return Ok(());
411                }
412
413                partial.check_value(array_ref.scalar_at(0)?.into_nullable());
414                Ok(())
415            }
416        }
417    }
418
419    fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
420        partials.get_item(NAMES.get(0).vortex_expect("out of bounds").clone())
421    }
422
423    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
424        if partial.first_value.is_none() {
425            // Empty accumulator → return false.
426            return Ok(Scalar::bool(false, Nullability::NonNullable));
427        }
428        Ok(Scalar::bool(partial.is_constant, Nullability::NonNullable))
429    }
430}
431
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray as _;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::fns::is_constant::is_constant;
    use crate::arrays::BoolArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::DecimalArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::stats::Stat;
    use crate::validity::Validity;

    // Tests migrated from compute/is_constant.rs
    #[test]
    fn is_constant_min_max_no_nan() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        let arr = buffer![0, 1].into_array();
        arr.statistics().compute_all(&[Stat::Min, Stat::Max])?;
        assert!(!is_constant(&arr, &mut ctx)?);

        let arr = buffer![0, 0].into_array();
        arr.statistics().compute_all(&[Stat::Min, Stat::Max])?;
        assert!(is_constant(&arr, &mut ctx)?);

        let arr = PrimitiveArray::from_option_iter([Some(0), Some(0)]).into_array();
        assert!(is_constant(&arr, &mut ctx)?);
        Ok(())
    }

    #[test]
    fn is_constant_min_max_with_nan() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        let arr = PrimitiveArray::from_iter([0.0, 0.0, f32::NAN]).into_array();
        arr.statistics().compute_all(&[Stat::Min, Stat::Max])?;
        assert!(!is_constant(&arr, &mut ctx)?);

        let arr =
            PrimitiveArray::from_option_iter([Some(f32::NEG_INFINITY), Some(f32::NEG_INFINITY)])
                .into_array();
        arr.statistics().compute_all(&[Stat::Min, Stat::Max])?;
        assert!(is_constant(&arr, &mut ctx)?);
        Ok(())
    }

    #[test]
    fn is_constant_empty_and_null_handling() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        // Empty arrays are never constant.
        let arr = Buffer::<u8>::empty().into_array();
        assert!(!is_constant(&arr, &mut ctx)?);

        // All-null arrays are constant.
        let arr = PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array();
        assert!(is_constant(&arr, &mut ctx)?);

        // A mix of nulls and values is never constant, even if the valid values agree.
        let arr = PrimitiveArray::from_option_iter([Some(7i32), None, Some(7)]).into_array();
        assert!(!is_constant(&arr, &mut ctx)?);
        Ok(())
    }

    // Tests migrated from arrays/bool/compute/is_constant.rs
    #[rstest]
    #[case(vec![true], true)]
    #[case(vec![false; 65], true)]
    #[case({
        let mut v = vec![true; 64];
        v.push(false);
        v
    }, false)]
    fn test_bool_is_constant(#[case] input: Vec<bool>, #[case] expected: bool) -> VortexResult<()> {
        let array = BoolArray::from_iter(input);
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(is_constant(&array.into_array(), &mut ctx)?, expected);
        Ok(())
    }

    // Tests migrated from arrays/chunked/compute/is_constant.rs
    #[test]
    fn empty_chunk_is_constant() -> VortexResult<()> {
        let chunked = ChunkedArray::try_new(
            vec![
                Buffer::<u8>::empty().into_array(),
                Buffer::<u8>::empty().into_array(),
                buffer![255u8, 255].into_array(),
                Buffer::<u8>::empty().into_array(),
                buffer![255u8, 255].into_array(),
            ],
            DType::Primitive(PType::U8, Nullability::NonNullable),
        )?
        .into_array();

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert!(is_constant(&chunked, &mut ctx)?);
        Ok(())
    }

    #[test]
    fn non_constant_chunk_before_constant_chunks() -> VortexResult<()> {
        // The first chunk is non-constant on its own; the constant chunks after it must not
        // mask that.
        let chunked = ChunkedArray::try_new(
            vec![
                buffer![1u8, 2].into_array(),
                buffer![3u8, 3].into_array(),
                buffer![3u8, 3].into_array(),
            ],
            DType::Primitive(PType::U8, Nullability::NonNullable),
        )?
        .into_array();

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert!(!is_constant(&chunked, &mut ctx)?);
        Ok(())
    }

    // Tests migrated from arrays/decimal/compute/is_constant.rs
    #[test]
    fn test_decimal_is_constant() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        let array = DecimalArray::new(
            buffer![0i128, 1i128, 2i128],
            DecimalDType::new(19, 0),
            Validity::NonNullable,
        );
        assert!(!is_constant(&array.into_array(), &mut ctx)?);

        let array = DecimalArray::new(
            buffer![100i128, 100i128, 100i128],
            DecimalDType::new(19, 0),
            Validity::NonNullable,
        );
        assert!(is_constant(&array.into_array(), &mut ctx)?);
        Ok(())
    }

    // Tests migrated from arrays/list/compute/is_constant.rs
    #[test]
    fn test_is_constant_nested_list() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        let xs = ListArray::try_new(
            buffer![0i32, 1, 0, 1].into_array(),
            buffer![0u32, 2, 4].into_array(),
            Validity::NonNullable,
        )?;

        let struct_of_lists = StructArray::try_new(
            FieldNames::from(["xs"]),
            vec![xs.into_array()],
            2,
            Validity::NonNullable,
        )?;
        assert!(is_constant(
            &struct_of_lists.clone().into_array(),
            &mut ctx
        )?);
        assert!(is_constant(&struct_of_lists.into_array(), &mut ctx)?);
        Ok(())
    }

    #[rstest]
    #[case(
        // [1,2], [1, 2], [1, 2]
        vec![1i32, 2, 1, 2, 1, 2],
        vec![0u32, 2, 4, 6],
        true
    )]
    #[case(
        // [1, 2], [3], [4, 5]
        vec![1i32, 2, 3, 4, 5],
        vec![0u32, 2, 3, 5],
        false
    )]
    #[case(
        // [1, 2], [3, 4]
        vec![1i32, 2, 3, 4],
        vec![0u32, 2, 4],
        false
    )]
    #[case(
        // [], [], []
        vec![],
        vec![0u32, 0, 0, 0],
        true
    )]
    fn test_list_is_constant(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) -> VortexResult<()> {
        let list_array = ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        )?;

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(is_constant(&list_array.into_array(), &mut ctx)?, expected);
        Ok(())
    }

    #[test]
    fn test_list_is_constant_nested_lists() -> VortexResult<()> {
        let inner_elements = buffer![1i32, 2, 1, 2].into_array();
        let inner_offsets = buffer![0u32, 1, 2, 3, 4].into_array();
        let inner_lists = ListArray::try_new(inner_elements, inner_offsets, Validity::NonNullable)?;

        let outer_offsets = buffer![0u32, 2, 4].into_array();
        let outer_list = ListArray::try_new(
            inner_lists.into_array(),
            outer_offsets,
            Validity::NonNullable,
        )?;

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        // Both outer lists contain [[1], [2]], so should be constant
        assert!(is_constant(&outer_list.into_array(), &mut ctx)?);
        Ok(())
    }

    #[rstest]
    #[case(
        // 100 identical [1, 2] lists
        [1i32, 2].repeat(100),
        (0..101).map(|i| (i * 2) as u32).collect(),
        true
    )]
    #[case(
        // Difference after threshold: 64 identical [1, 2] + one [3, 4]
        {
            let mut elements = [1i32, 2].repeat(64);
            elements.extend_from_slice(&[3, 4]);
            elements
        },
        (0..66).map(|i| (i * 2) as u32).collect(),
        false
    )]
    #[case(
        // Difference in first 64: first 63 identical [1, 2] + one [3, 4] + rest identical [1, 2]
        {
            let mut elements = [1i32, 2].repeat(63);
            elements.extend_from_slice(&[3, 4]);
            elements.extend([1i32, 2].repeat(37));
            elements
        },
        (0..101).map(|i| (i * 2) as u32).collect(),
        false
    )]
    fn test_list_is_constant_with_threshold(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) -> VortexResult<()> {
        let list_array = ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        )?;

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(is_constant(&list_array.into_array(), &mut ctx)?, expected);
        Ok(())
    }
}