Skip to main content

vortex_array/aggregate_fn/fns/is_constant/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod bool;
5mod decimal;
6mod extension;
7mod fixed_size_list;
8mod list;
9pub mod primitive;
10mod struct_;
11mod varbin;
12
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_mask::Mask;
17
18use self::bool::check_bool_constant;
19use self::decimal::check_decimal_constant;
20use self::extension::check_extension_constant;
21use self::fixed_size_list::check_fixed_size_list_constant;
22use self::list::check_listview_constant;
23use self::primitive::check_primitive_constant;
24use self::struct_::check_struct_constant;
25use self::varbin::check_varbinview_constant;
26use crate::ArrayRef;
27use crate::Canonical;
28use crate::Columnar;
29use crate::ExecutionCtx;
30use crate::IntoArray;
31use crate::aggregate_fn::Accumulator;
32use crate::aggregate_fn::AggregateFnId;
33use crate::aggregate_fn::AggregateFnVTable;
34use crate::aggregate_fn::DynAccumulator;
35use crate::aggregate_fn::EmptyOptions;
36use crate::arrays::Constant;
37use crate::arrays::Null;
38use crate::builtins::ArrayBuiltins;
39use crate::dtype::DType;
40use crate::dtype::FieldNames;
41use crate::dtype::Nullability;
42use crate::dtype::StructFields;
43use crate::expr::stats::Precision;
44use crate::expr::stats::Stat;
45use crate::expr::stats::StatsProvider;
46use crate::expr::stats::StatsProviderExt;
47use crate::scalar::Scalar;
48use crate::scalar_fn::fns::operators::Operator;
49
50/// Check if two arrays of the same length have equal values at every position (null-safe).
51///
52/// Two positions are considered equal if they are both null, or both non-null with the same value.
53///
54// TODO(ngates): move this function out when we have any/all aggregate functions.
55fn arrays_value_equal(a: &ArrayRef, b: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
56    debug_assert_eq!(a.len(), b.len());
57    if a.is_empty() {
58        return Ok(true);
59    }
60
61    // Check validity masks match (null positions must be identical).
62    let a_mask = a.validity()?.to_mask(a.len(), ctx)?;
63    let b_mask = b.validity()?.to_mask(b.len(), ctx)?;
64    if a_mask != b_mask {
65        return Ok(false);
66    }
67
68    let valid_count = a_mask.true_count();
69    if valid_count == 0 {
70        // Both all-null → equal.
71        return Ok(true);
72    }
73
74    // Compare values element-wise. Result is null where both inputs are null,
75    // true/false where both are valid.
76    let eq_result = a.binary(b.clone(), Operator::Eq)?;
77    let eq_result = eq_result.execute::<Mask>(ctx)?;
78
79    Ok(eq_result.true_count() == valid_count)
80}
81
82/// Compute whether an array has constant values.
83///
84/// An array is constant IFF at least one of the following conditions apply:
85/// 1. It has at least one element (**Note** - an empty array isn't constant).
86/// 2. It's encoded as a [`ConstantArray`](crate::arrays::ConstantArray) or [`NullArray`](crate::arrays::NullArray)
87/// 3. Has an exact statistic attached to it, saying its constant.
88/// 4. Is all invalid.
89/// 5. Is all valid AND has minimum and maximum statistics that are equal.
90pub fn is_constant(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
91    // Short-circuit using cached array statistics.
92    if let Some(Precision::Exact(value)) = array.statistics().get_as::<bool>(Stat::IsConstant) {
93        return Ok(value);
94    }
95
96    // Empty arrays are not constant.
97    if array.is_empty() {
98        return Ok(false);
99    }
100
101    // Array of length 1 is always constant.
102    if array.len() == 1 {
103        array
104            .statistics()
105            .set(Stat::IsConstant, Precision::Exact(true.into()));
106        return Ok(true);
107    }
108
109    // Constant and null arrays are always constant.
110    if array.is::<Constant>() || array.is::<Null>() {
111        array
112            .statistics()
113            .set(Stat::IsConstant, Precision::Exact(true.into()));
114        return Ok(true);
115    }
116
117    let all_invalid = array.all_invalid(ctx)?;
118    if all_invalid {
119        array
120            .statistics()
121            .set(Stat::IsConstant, Precision::Exact(true.into()));
122        return Ok(true);
123    }
124
125    let all_valid = array.all_valid(ctx)?;
126
127    // If we have some nulls but not all nulls, array can't be constant.
128    if !all_valid && !all_invalid {
129        array
130            .statistics()
131            .set(Stat::IsConstant, Precision::Exact(false.into()));
132        return Ok(false);
133    }
134
135    // We already know here that the array is all valid, so we check for min/max stats.
136    let min = array.statistics().get(Stat::Min);
137    let max = array.statistics().get(Stat::Max);
138
139    if let Some((min, max)) = min.zip(max)
140        && min.is_exact()
141        && min == max
142        && (Stat::NaNCount.dtype(array.dtype()).is_none()
143            || array.statistics().get_as::<u64>(Stat::NaNCount) == Some(Precision::exact(0u64)))
144    {
145        array
146            .statistics()
147            .set(Stat::IsConstant, Precision::Exact(true.into()));
148        return Ok(true);
149    }
150
151    // Short-circuit for unsupported dtypes.
152    if IsConstant
153        .return_dtype(&EmptyOptions, array.dtype())
154        .is_none()
155    {
156        // Null dtype - vacuously false for empty
157        return Ok(false);
158    }
159
160    // Compute using Accumulator<IsConstant>.
161    let mut acc = Accumulator::try_new(IsConstant, EmptyOptions, array.dtype().clone())?;
162    acc.accumulate(array, ctx)?;
163    let result_scalar = acc.finish()?;
164
165    let result = result_scalar.as_bool().value().unwrap_or(false);
166
167    // Cache the computed is_constant as a statistic.
168    array
169        .statistics()
170        .set(Stat::IsConstant, Precision::Exact(result.into()));
171
172    Ok(result)
173}
174
175/// Compute whether an array is constant.
176///
177/// Returns a `Bool(NonNullable)` scalar.
178/// The partial state is a nullable struct `{is_constant: Bool(NN), value: input_dtype?}`.
179/// A null struct means the accumulator has seen no data yet (empty).
180#[derive(Clone, Debug)]
181pub struct IsConstant;
182
183impl IsConstant {
184    /// Build a partial scalar from a kernel's `is_constant` result.
185    ///
186    /// Kernels that compute `is_constant` by delegating to child arrays can call this
187    /// to package the boolean result into the partial struct format expected by the
188    /// accumulator, avoiding duplicated boilerplate.
189    pub fn make_partial(
190        batch: &ArrayRef,
191        is_constant: bool,
192        ctx: &mut ExecutionCtx,
193    ) -> VortexResult<Scalar> {
194        let partial_dtype = make_is_constant_partial_dtype(batch.dtype());
195        if is_constant {
196            if batch.is_empty() {
197                return Ok(Scalar::null(partial_dtype));
198            }
199            let first_value = batch.execute_scalar(0, ctx)?.into_nullable();
200            Ok(Scalar::struct_(
201                partial_dtype,
202                vec![Scalar::bool(true, Nullability::NonNullable), first_value],
203            ))
204        } else {
205            Ok(Scalar::struct_(
206                partial_dtype,
207                vec![
208                    Scalar::bool(false, Nullability::NonNullable),
209                    Scalar::null(batch.dtype().as_nullable()),
210                ],
211            ))
212        }
213    }
214}
215
216/// Partial accumulator state for is_constant.
217pub struct IsConstantPartial {
218    is_constant: bool,
219    /// None = empty (no values seen), Some(null) = all nulls, Some(v) = first value seen.
220    first_value: Option<Scalar>,
221    element_dtype: DType,
222}
223
224impl IsConstantPartial {
225    fn check_value(&mut self, value: Scalar) {
226        if !self.is_constant {
227            return;
228        }
229        match &self.first_value {
230            None => {
231                self.first_value = Some(value);
232            }
233            Some(first) => {
234                if *first != value {
235                    self.is_constant = false;
236                }
237            }
238        }
239    }
240}
241
242static NAMES: std::sync::LazyLock<FieldNames> =
243    std::sync::LazyLock::new(|| FieldNames::from(["is_constant", "value"]));
244
245pub fn make_is_constant_partial_dtype(element_dtype: &DType) -> DType {
246    DType::Struct(
247        StructFields::new(
248            NAMES.clone(),
249            vec![
250                DType::Bool(Nullability::NonNullable),
251                element_dtype.as_nullable(),
252            ],
253        ),
254        Nullability::Nullable,
255    )
256}
257
258impl AggregateFnVTable for IsConstant {
259    type Options = EmptyOptions;
260    type Partial = IsConstantPartial;
261
262    fn id(&self) -> AggregateFnId {
263        AggregateFnId::new("vortex.is_constant")
264    }
265
266    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
267        unimplemented!("IsConstant is not yet serializable");
268    }
269
270    fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
271        match input_dtype {
272            DType::Null | DType::Variant(..) => None,
273            _ => Some(DType::Bool(Nullability::NonNullable)),
274        }
275    }
276
277    fn partial_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
278        match input_dtype {
279            DType::Null | DType::Variant(..) => None,
280            _ => Some(make_is_constant_partial_dtype(input_dtype)),
281        }
282    }
283
284    fn empty_partial(
285        &self,
286        _options: &Self::Options,
287        input_dtype: &DType,
288    ) -> VortexResult<Self::Partial> {
289        Ok(IsConstantPartial {
290            is_constant: true,
291            first_value: None,
292            element_dtype: input_dtype.clone(),
293        })
294    }
295
296    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
297        if !partial.is_constant {
298            return Ok(());
299        }
300
301        // Null struct means the other accumulator was empty, skip it.
302        if other.is_null() {
303            return Ok(());
304        }
305
306        let other_is_constant = other
307            .as_struct()
308            .field_by_idx(0)
309            .map(|s| s.as_bool().value().unwrap_or(false))
310            .unwrap_or(false);
311
312        if !other_is_constant {
313            partial.is_constant = false;
314            return Ok(());
315        }
316
317        let other_value = other.as_struct().field_by_idx(1);
318
319        if let Some(other_val) = other_value {
320            partial.check_value(other_val);
321        }
322
323        Ok(())
324    }
325
326    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
327        let dtype = make_is_constant_partial_dtype(&partial.element_dtype);
328        Ok(match &partial.first_value {
329            None => {
330                // Empty accumulator — return null struct.
331                Scalar::null(dtype)
332            }
333            Some(first_value) => Scalar::struct_(
334                dtype,
335                vec![
336                    Scalar::bool(partial.is_constant, Nullability::NonNullable),
337                    first_value
338                        .clone()
339                        .cast(&partial.element_dtype.as_nullable())?,
340                ],
341            ),
342        })
343    }
344
345    fn reset(&self, partial: &mut Self::Partial) {
346        partial.is_constant = true;
347        partial.first_value = None;
348    }
349
350    #[inline]
351    fn is_saturated(&self, partial: &Self::Partial) -> bool {
352        !partial.is_constant
353    }
354
355    fn accumulate(
356        &self,
357        partial: &mut Self::Partial,
358        batch: &Columnar,
359        ctx: &mut ExecutionCtx,
360    ) -> VortexResult<()> {
361        if !partial.is_constant {
362            return Ok(());
363        }
364
365        match batch {
366            Columnar::Constant(c) => {
367                partial.check_value(c.scalar().clone().into_nullable());
368                Ok(())
369            }
370            Columnar::Canonical(c) => {
371                if c.is_empty() {
372                    return Ok(());
373                }
374
375                // Convert to ArrayRef for DynArray methods.
376                let array_ref = c.clone().into_array();
377
378                let all_invalid = array_ref.all_invalid(ctx)?;
379                if all_invalid {
380                    partial.check_value(Scalar::null(partial.element_dtype.as_nullable()));
381                    return Ok(());
382                }
383
384                let all_valid = array_ref.all_valid(ctx)?;
385                // Mixed nulls → not constant.
386                if !all_valid && !all_invalid {
387                    partial.is_constant = false;
388                    return Ok(());
389                }
390
391                // All valid from here. Check batch-level constancy.
392                if c.len() == 1 {
393                    partial.check_value(array_ref.execute_scalar(0, ctx)?.into_nullable());
394                    return Ok(());
395                }
396
397                let batch_is_constant = match c {
398                    Canonical::Primitive(p) => check_primitive_constant(p),
399                    Canonical::Bool(b) => check_bool_constant(b),
400                    Canonical::VarBinView(v) => check_varbinview_constant(v),
401                    Canonical::Decimal(d) => check_decimal_constant(d),
402                    Canonical::Struct(s) => check_struct_constant(s, ctx)?,
403                    Canonical::Extension(e) => check_extension_constant(e, ctx)?,
404                    Canonical::List(l) => check_listview_constant(l, ctx)?,
405                    Canonical::FixedSizeList(f) => check_fixed_size_list_constant(f, ctx)?,
406                    Canonical::Null(_) => true,
407                    Canonical::Variant(_) => {
408                        vortex_bail!("Variant arrays don't support IsConstant")
409                    }
410                };
411
412                if !batch_is_constant {
413                    partial.is_constant = false;
414                    return Ok(());
415                }
416
417                partial.check_value(array_ref.execute_scalar(0, ctx)?.into_nullable());
418                Ok(())
419            }
420        }
421    }
422
423    fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
424        partials.get_item(NAMES.get(0).vortex_expect("out of bounds").clone())
425    }
426
427    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
428        if partial.first_value.is_none() {
429            // Empty accumulator → return false.
430            return Ok(Scalar::bool(false, Nullability::NonNullable));
431        }
432        Ok(Scalar::bool(partial.is_constant, Nullability::NonNullable))
433    }
434}
435
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray as _;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::fns::is_constant::is_constant;
    use crate::arrays::BoolArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::DecimalArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::stats::Stat;
    use crate::validity::Validity;

    // Tests migrated from compute/is_constant.rs

    /// Integer inputs with min/max statistics precomputed, plus a nullable all-valid input.
    #[test]
    fn is_constant_min_max_no_nan() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        let distinct = buffer![0, 1].into_array();
        distinct
            .statistics()
            .compute_all(&[Stat::Min, Stat::Max], &mut ctx)?;
        assert!(!is_constant(&distinct, &mut ctx)?);

        let repeated = buffer![0, 0].into_array();
        repeated
            .statistics()
            .compute_all(&[Stat::Min, Stat::Max], &mut ctx)?;
        assert!(is_constant(&repeated, &mut ctx)?);

        let nullable_repeated = PrimitiveArray::from_option_iter([Some(0), Some(0)]).into_array();
        assert!(is_constant(&nullable_repeated, &mut ctx)?);
        Ok(())
    }

    /// Float inputs with min/max statistics precomputed: one containing a NaN, one without.
    #[test]
    fn is_constant_min_max_with_nan() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        let with_nan = PrimitiveArray::from_iter([0.0, 0.0, f32::NAN]).into_array();
        with_nan
            .statistics()
            .compute_all(&[Stat::Min, Stat::Max], &mut ctx)?;
        assert!(!is_constant(&with_nan, &mut ctx)?);

        let neg_infinities =
            PrimitiveArray::from_option_iter([Some(f32::NEG_INFINITY), Some(f32::NEG_INFINITY)])
                .into_array();
        neg_infinities
            .statistics()
            .compute_all(&[Stat::Min, Stat::Max], &mut ctx)?;
        assert!(is_constant(&neg_infinities, &mut ctx)?);
        Ok(())
    }

    // Tests migrated from arrays/bool/compute/is_constant.rs

    /// Boolean inputs: a single value, 65 equal values, and 64 equal values plus one that differs.
    #[rstest]
    #[case(vec![true], true)]
    #[case(vec![false; 65], true)]
    #[case({
        let mut v = vec![true; 64];
        v.push(false);
        v
    }, false)]
    fn test_bool_is_constant(#[case] input: Vec<bool>, #[case] expected: bool) -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let bools = BoolArray::from_iter(input).into_array();
        assert_eq!(is_constant(&bools, &mut ctx)?, expected);
        Ok(())
    }

    // Tests migrated from arrays/chunked/compute/is_constant.rs

    /// Empty chunks interleaved with equal-valued chunks.
    #[test]
    fn empty_chunk_is_constant() -> VortexResult<()> {
        let chunks = vec![
            Buffer::<u8>::empty().into_array(),
            Buffer::<u8>::empty().into_array(),
            buffer![255u8, 255].into_array(),
            Buffer::<u8>::empty().into_array(),
            buffer![255u8, 255].into_array(),
        ];
        let chunked = ChunkedArray::try_new(
            chunks,
            DType::Primitive(PType::U8, Nullability::NonNullable),
        )?
        .into_array();

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert!(is_constant(&chunked, &mut ctx)?);
        Ok(())
    }

    // Tests migrated from arrays/decimal/compute/is_constant.rs

    /// Decimal inputs: three distinct values, then three equal values.
    #[test]
    fn test_decimal_is_constant() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        let distinct = DecimalArray::new(
            buffer![0i128, 1i128, 2i128],
            DecimalDType::new(19, 0),
            Validity::NonNullable,
        )
        .into_array();
        assert!(!is_constant(&distinct, &mut ctx)?);

        let repeated = DecimalArray::new(
            buffer![100i128, 100i128, 100i128],
            DecimalDType::new(19, 0),
            Validity::NonNullable,
        )
        .into_array();
        assert!(is_constant(&repeated, &mut ctx)?);
        Ok(())
    }

    // Tests migrated from arrays/list/compute/is_constant.rs

    /// A two-row struct whose only field is a list column holding `[0, 1]` twice.
    #[test]
    fn test_is_constant_nested_list() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        let list_elements = buffer![0i32, 1, 0, 1].into_array();
        let list_offsets = buffer![0u32, 2, 4].into_array();
        let xs = ListArray::try_new(list_elements, list_offsets, Validity::NonNullable)?;

        let struct_of_lists = StructArray::try_new(
            FieldNames::from(["xs"]),
            vec![xs.into_array()],
            2,
            Validity::NonNullable,
        )?;

        // Checked once on a clone and once on the original.
        let cloned = struct_of_lists.clone().into_array();
        assert!(is_constant(&cloned, &mut ctx)?);
        let original = struct_of_lists.into_array();
        assert!(is_constant(&original, &mut ctx)?);
        Ok(())
    }

    /// Flat lists described by their elements and offsets.
    #[rstest]
    #[case(
        // [1,2], [1, 2], [1, 2]
        vec![1i32, 2, 1, 2, 1, 2],
        vec![0u32, 2, 4, 6],
        true
    )]
    #[case(
        // [1, 2], [3], [4, 5]
        vec![1i32, 2, 3, 4, 5],
        vec![0u32, 2, 3, 5],
        false
    )]
    #[case(
        // [1, 2], [3, 4]
        vec![1i32, 2, 3, 4],
        vec![0u32, 2, 4],
        false
    )]
    #[case(
        // [], [], []
        vec![],
        vec![0u32, 0, 0, 0],
        true
    )]
    fn test_list_is_constant(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        let element_array = PrimitiveArray::from_iter(elements).into_array();
        let offset_array = PrimitiveArray::from_iter(offsets).into_array();
        let lists =
            ListArray::try_new(element_array, offset_array, Validity::NonNullable)?.into_array();

        assert_eq!(is_constant(&lists, &mut ctx)?, expected);
        Ok(())
    }

    /// A list of lists where both outer entries are `[[1], [2]]`.
    #[test]
    fn test_list_is_constant_nested_lists() -> VortexResult<()> {
        let inner_lists = ListArray::try_new(
            buffer![1i32, 2, 1, 2].into_array(),
            buffer![0u32, 1, 2, 3, 4].into_array(),
            Validity::NonNullable,
        )?;

        let outer_list = ListArray::try_new(
            inner_lists.into_array(),
            buffer![0u32, 2, 4].into_array(),
            Validity::NonNullable,
        )?
        .into_array();

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        // Both outer lists contain [[1], [2]], so should be constant
        assert!(is_constant(&outer_list, &mut ctx)?);
        Ok(())
    }

    /// Longer list inputs, with the differing entry placed at various positions
    /// relative to the 64-entry mark.
    #[rstest]
    #[case(
        // 100 identical [1, 2] lists
        [1i32, 2].repeat(100),
        (0..101).map(|i| (i * 2) as u32).collect(),
        true
    )]
    #[case(
        // Difference after threshold: 64 identical [1, 2] + one [3, 4]
        {
            let mut elements = [1i32, 2].repeat(64);
            elements.extend_from_slice(&[3, 4]);
            elements
        },
        (0..66).map(|i| (i * 2) as u32).collect(),
        false
    )]
    #[case(
        // Difference in first 64: first 63 identical [1, 2] + one [3, 4] + rest identical [1, 2]
        {
            let mut elements = [1i32, 2].repeat(63);
            elements.extend_from_slice(&[3, 4]);
            elements.extend([1i32, 2].repeat(37));
            elements
        },
        (0..101).map(|i| (i * 2) as u32).collect(),
        false
    )]
    fn test_list_is_constant_with_threshold(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        let element_array = PrimitiveArray::from_iter(elements).into_array();
        let offset_array = PrimitiveArray::from_iter(offsets).into_array();
        let lists =
            ListArray::try_new(element_array, offset_array, Validity::NonNullable)?.into_array();

        assert_eq!(is_constant(&lists, &mut ctx)?, expected);
        Ok(())
    }
}