Skip to main content

vortex_array/aggregate_fn/fns/is_constant/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod bool;
5mod decimal;
6mod extension;
7mod fixed_size_list;
8mod list;
9pub mod primitive;
10mod struct_;
11mod varbin;
12
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_session::registry::CachedId;
17
18use self::bool::check_bool_constant;
19use self::decimal::check_decimal_constant;
20use self::extension::check_extension_constant;
21use self::fixed_size_list::check_fixed_size_list_constant;
22use self::list::check_listview_constant;
23use self::primitive::check_primitive_constant;
24use self::struct_::check_struct_constant;
25use self::varbin::check_varbinview_constant;
26use crate::ArrayRef;
27use crate::Canonical;
28use crate::Columnar;
29use crate::ExecutionCtx;
30use crate::IntoArray;
31use crate::aggregate_fn::Accumulator;
32use crate::aggregate_fn::AggregateFnId;
33use crate::aggregate_fn::AggregateFnVTable;
34use crate::aggregate_fn::DynAccumulator;
35use crate::aggregate_fn::EmptyOptions;
36use crate::arrays::Constant;
37use crate::arrays::Null;
38use crate::builtins::ArrayBuiltins;
39use crate::dtype::DType;
40use crate::dtype::FieldNames;
41use crate::dtype::Nullability;
42use crate::dtype::StructFields;
43use crate::expr::stats::Precision;
44use crate::expr::stats::Stat;
45use crate::expr::stats::StatsProvider;
46use crate::expr::stats::StatsProviderExt;
47use crate::scalar::Scalar;
48use crate::scalar_fn::fns::operators::Operator;
49
50/// Check if two arrays of the same length have equal values at every position (null-safe).
51///
52/// Two positions are considered equal if they are both null, or both non-null with the same value.
53// TODO(ngates): move this function out when we have any/all aggregate functions.
54fn arrays_value_equal(a: &ArrayRef, b: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
55    debug_assert_eq!(a.len(), b.len());
56    if a.is_empty() {
57        return Ok(true);
58    }
59
60    // Check validity masks match (null positions must be identical).
61    let a_mask = a.validity()?.execute_mask(a.len(), ctx)?;
62    let b_mask = b.validity()?.execute_mask(b.len(), ctx)?;
63    if a_mask != b_mask {
64        return Ok(false);
65    }
66
67    let valid_count = a_mask.true_count();
68    if valid_count == 0 {
69        // Both all-null → equal.
70        return Ok(true);
71    }
72
73    // Compare values element-wise. Result is null where both inputs are null,
74    // true/false where both are valid.
75    let eq_result = a.binary(b.clone(), Operator::Eq)?;
76    let eq_result = eq_result.null_as_false().execute(ctx)?;
77
78    Ok(eq_result.true_count() == valid_count)
79}
80
81/// Compute whether an array has constant values.
82///
83/// An array is constant IFF at least one of the following conditions apply:
84/// 1. It has at least one element (**Note** - an empty array isn't constant).
85/// 2. It's encoded as a [`ConstantArray`](crate::arrays::ConstantArray) or [`NullArray`](crate::arrays::NullArray)
86/// 3. Has an exact statistic attached to it, saying its constant.
87/// 4. Is all invalid.
88/// 5. Is all valid AND has minimum and maximum statistics that are equal.
89pub fn is_constant(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
90    // Short-circuit using cached array statistics.
91    if let Precision::Exact(value) = array.statistics().get_as::<bool>(Stat::IsConstant) {
92        return Ok(value);
93    }
94
95    // Empty arrays are not constant.
96    if array.is_empty() {
97        return Ok(false);
98    }
99
100    // Array of length 1 is always constant.
101    if array.len() == 1 {
102        array
103            .statistics()
104            .set(Stat::IsConstant, Precision::Exact(true.into()));
105        return Ok(true);
106    }
107
108    // Constant and null arrays are always constant.
109    if array.is::<Constant>() || array.is::<Null>() {
110        array
111            .statistics()
112            .set(Stat::IsConstant, Precision::Exact(true.into()));
113        return Ok(true);
114    }
115
116    let all_invalid = array.all_invalid(ctx)?;
117    if all_invalid {
118        array
119            .statistics()
120            .set(Stat::IsConstant, Precision::Exact(true.into()));
121        return Ok(true);
122    }
123
124    let all_valid = array.all_valid(ctx)?;
125
126    // If we have some nulls but not all nulls, array can't be constant.
127    if !all_valid && !all_invalid {
128        array
129            .statistics()
130            .set(Stat::IsConstant, Precision::Exact(false.into()));
131        return Ok(false);
132    }
133
134    // We already know here that the array is all valid, so we check for min/max stats.
135    let min_stat = array.statistics().get(Stat::Min);
136    let max_stat = array.statistics().get(Stat::Max);
137
138    if let Precision::Exact(min) = min_stat.as_ref()
139        && let Precision::Exact(max) = max_stat.as_ref()
140        && min == max
141        && (Stat::NaNCount.dtype(array.dtype()).is_none()
142            || array.statistics().get_as::<u64>(Stat::NaNCount) == Precision::exact(0u64))
143    {
144        array
145            .statistics()
146            .set(Stat::IsConstant, Precision::Exact(true.into()));
147        return Ok(true);
148    }
149
150    // Short-circuit for unsupported dtypes.
151    if IsConstant
152        .return_dtype(&EmptyOptions, array.dtype())
153        .is_none()
154    {
155        // Null dtype - vacuously false for empty
156        return Ok(false);
157    }
158
159    // Compute using Accumulator<IsConstant>.
160    let mut acc = Accumulator::try_new(IsConstant, EmptyOptions, array.dtype().clone())?;
161    acc.accumulate(array, ctx)?;
162    let result_scalar = acc.finish()?;
163
164    let result = result_scalar.as_bool().value().unwrap_or(false);
165
166    // Cache the computed is_constant as a statistic.
167    array
168        .statistics()
169        .set(Stat::IsConstant, Precision::Exact(result.into()));
170
171    Ok(result)
172}
173
174/// Compute whether an array is constant.
175///
176/// Returns a `Bool(NonNullable)` scalar.
177/// The partial state is a nullable struct `{is_constant: Bool(NN), value: input_dtype?}`.
178/// A null struct means the accumulator has seen no data yet (empty).
179#[derive(Clone, Debug)]
180pub struct IsConstant;
181
182impl IsConstant {
183    /// Build a partial scalar from a kernel's `is_constant` result.
184    ///
185    /// Kernels that compute `is_constant` by delegating to child arrays can call this
186    /// to package the boolean result into the partial struct format expected by the
187    /// accumulator, avoiding duplicated boilerplate.
188    pub fn make_partial(
189        batch: &ArrayRef,
190        is_constant: bool,
191        ctx: &mut ExecutionCtx,
192    ) -> VortexResult<Scalar> {
193        let partial_dtype = make_is_constant_partial_dtype(batch.dtype());
194        if is_constant {
195            if batch.is_empty() {
196                return Ok(Scalar::null(partial_dtype));
197            }
198            let first_value = batch.execute_scalar(0, ctx)?.into_nullable();
199            Ok(Scalar::struct_(
200                partial_dtype,
201                vec![Scalar::bool(true, Nullability::NonNullable), first_value],
202            ))
203        } else {
204            Ok(Scalar::struct_(
205                partial_dtype,
206                vec![
207                    Scalar::bool(false, Nullability::NonNullable),
208                    Scalar::null(batch.dtype().as_nullable()),
209                ],
210            ))
211        }
212    }
213}
214
215/// Partial accumulator state for is_constant.
216pub struct IsConstantPartial {
217    is_constant: bool,
218    /// None = empty (no values seen), Some(null) = all nulls, Some(v) = first value seen.
219    first_value: Option<Scalar>,
220    element_dtype: DType,
221}
222
223impl IsConstantPartial {
224    fn check_value(&mut self, value: Scalar) {
225        if !self.is_constant {
226            return;
227        }
228        match &self.first_value {
229            None => {
230                self.first_value = Some(value);
231            }
232            Some(first) => {
233                if *first != value {
234                    self.is_constant = false;
235                }
236            }
237        }
238    }
239}
240
241static NAMES: std::sync::LazyLock<FieldNames> =
242    std::sync::LazyLock::new(|| FieldNames::from(["is_constant", "value"]));
243
244pub fn make_is_constant_partial_dtype(element_dtype: &DType) -> DType {
245    DType::Struct(
246        StructFields::new(
247            NAMES.clone(),
248            vec![
249                DType::Bool(Nullability::NonNullable),
250                element_dtype.as_nullable(),
251            ],
252        ),
253        Nullability::Nullable,
254    )
255}
256
257impl AggregateFnVTable for IsConstant {
258    type Options = EmptyOptions;
259    type Partial = IsConstantPartial;
260
261    fn id(&self) -> AggregateFnId {
262        static ID: CachedId = CachedId::new("vortex.is_constant");
263        *ID
264    }
265
266    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
267        unimplemented!("IsConstant is not yet serializable");
268    }
269
270    fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
271        match input_dtype {
272            DType::Null | DType::Variant(..) => None,
273            _ => Some(DType::Bool(Nullability::NonNullable)),
274        }
275    }
276
277    fn partial_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
278        match input_dtype {
279            DType::Null | DType::Variant(..) => None,
280            _ => Some(make_is_constant_partial_dtype(input_dtype)),
281        }
282    }
283
284    fn empty_partial(
285        &self,
286        _options: &Self::Options,
287        input_dtype: &DType,
288    ) -> VortexResult<Self::Partial> {
289        Ok(IsConstantPartial {
290            is_constant: true,
291            first_value: None,
292            element_dtype: input_dtype.clone(),
293        })
294    }
295
296    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
297        if !partial.is_constant {
298            return Ok(());
299        }
300
301        // Null struct means the other accumulator was empty, skip it.
302        if other.is_null() {
303            return Ok(());
304        }
305
306        let other_is_constant = other
307            .as_struct()
308            .field_by_idx(0)
309            .map(|s| s.as_bool().value().unwrap_or(false))
310            .unwrap_or(false);
311
312        if !other_is_constant {
313            partial.is_constant = false;
314            return Ok(());
315        }
316
317        let other_value = other.as_struct().field_by_idx(1);
318
319        if let Some(other_val) = other_value {
320            partial.check_value(other_val);
321        }
322
323        Ok(())
324    }
325
326    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
327        let dtype = make_is_constant_partial_dtype(&partial.element_dtype);
328        Ok(match &partial.first_value {
329            None => {
330                // Empty accumulator — return null struct.
331                Scalar::null(dtype)
332            }
333            Some(first_value) => Scalar::struct_(
334                dtype,
335                vec![
336                    Scalar::bool(partial.is_constant, Nullability::NonNullable),
337                    first_value
338                        .clone()
339                        .cast(&partial.element_dtype.as_nullable())?,
340                ],
341            ),
342        })
343    }
344
345    fn reset(&self, partial: &mut Self::Partial) {
346        partial.is_constant = true;
347        partial.first_value = None;
348    }
349
350    #[inline]
351    fn is_saturated(&self, partial: &Self::Partial) -> bool {
352        !partial.is_constant
353    }
354
355    fn accumulate(
356        &self,
357        partial: &mut Self::Partial,
358        batch: &Columnar,
359        ctx: &mut ExecutionCtx,
360    ) -> VortexResult<()> {
361        if !partial.is_constant {
362            return Ok(());
363        }
364
365        match batch {
366            Columnar::Constant(c) => {
367                partial.check_value(c.scalar().clone().into_nullable());
368                Ok(())
369            }
370            Columnar::Canonical(c) => {
371                if c.is_empty() {
372                    return Ok(());
373                }
374
375                // Convert to ArrayRef for DynArrayData methods.
376                let array_ref = c.clone().into_array();
377
378                let all_invalid = array_ref.all_invalid(ctx)?;
379                if all_invalid {
380                    partial.check_value(Scalar::null(partial.element_dtype.as_nullable()));
381                    return Ok(());
382                }
383
384                let all_valid = array_ref.all_valid(ctx)?;
385                // Mixed nulls → not constant.
386                if !all_valid && !all_invalid {
387                    partial.is_constant = false;
388                    return Ok(());
389                }
390
391                // All valid from here. Check batch-level constancy.
392                if c.len() == 1 {
393                    partial.check_value(array_ref.execute_scalar(0, ctx)?.into_nullable());
394                    return Ok(());
395                }
396
397                let batch_is_constant = match c {
398                    Canonical::Primitive(p) => check_primitive_constant(p),
399                    Canonical::Bool(b) => check_bool_constant(b),
400                    Canonical::VarBinView(v) => check_varbinview_constant(v),
401                    Canonical::Decimal(d) => check_decimal_constant(d),
402                    Canonical::Struct(s) => check_struct_constant(s, ctx)?,
403                    Canonical::Extension(e) => check_extension_constant(e, ctx)?,
404                    Canonical::List(l) => check_listview_constant(l, ctx)?,
405                    Canonical::FixedSizeList(f) => check_fixed_size_list_constant(f, ctx)?,
406                    Canonical::Null(_) => true,
407                    Canonical::Variant(_) => {
408                        vortex_bail!("Variant arrays don't support IsConstant")
409                    }
410                };
411
412                if !batch_is_constant {
413                    partial.is_constant = false;
414                    return Ok(());
415                }
416
417                partial.check_value(array_ref.execute_scalar(0, ctx)?.into_nullable());
418                Ok(())
419            }
420        }
421    }
422
423    fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
424        partials.get_item(NAMES.get(0).vortex_expect("out of bounds").clone())
425    }
426
427    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
428        if partial.first_value.is_none() {
429            // Empty accumulator → return false.
430            return Ok(Scalar::bool(false, Nullability::NonNullable));
431        }
432        Ok(Scalar::bool(partial.is_constant, Nullability::NonNullable))
433    }
434}
435
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray as _;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::fns::is_constant::is_constant;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::DecimalArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::stats::Stat;
    use crate::validity::Validity;

    // Migrated from compute/is_constant.rs: min/max statistics interplay.
    #[test]
    fn is_constant_min_max_no_nan() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();

        let distinct = buffer![0, 1].into_array();
        distinct
            .statistics()
            .compute_all(&[Stat::Min, Stat::Max], &mut ctx)?;
        assert!(!is_constant(&distinct, &mut ctx)?);

        let repeated = buffer![0, 0].into_array();
        repeated
            .statistics()
            .compute_all(&[Stat::Min, Stat::Max], &mut ctx)?;
        assert!(is_constant(&repeated, &mut ctx)?);

        // Nullable dtype, but every value is present.
        let nullable_repeated = PrimitiveArray::from_option_iter([Some(0), Some(0)]).into_array();
        assert!(is_constant(&nullable_repeated, &mut ctx)?);
        Ok(())
    }

    #[test]
    fn is_constant_min_max_with_nan() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();

        // A NaN among otherwise-equal values must not be reported as constant.
        let with_nan = PrimitiveArray::from_iter([0.0, 0.0, f32::NAN]).into_array();
        with_nan
            .statistics()
            .compute_all(&[Stat::Min, Stat::Max], &mut ctx)?;
        assert!(!is_constant(&with_nan, &mut ctx)?);

        let neg_infinities =
            PrimitiveArray::from_option_iter([Some(f32::NEG_INFINITY), Some(f32::NEG_INFINITY)])
                .into_array();
        neg_infinities
            .statistics()
            .compute_all(&[Stat::Min, Stat::Max], &mut ctx)?;
        assert!(is_constant(&neg_infinities, &mut ctx)?);
        Ok(())
    }

    // Migrated from arrays/bool/compute/is_constant.rs.
    #[rstest]
    #[case(vec![true], true)]
    #[case(vec![false; 65], true)]
    #[case([vec![true; 64], vec![false]].concat(), false)]
    fn test_bool_is_constant(#[case] values: Vec<bool>, #[case] expected: bool) -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let bools = BoolArray::from_iter(values).into_array();
        assert_eq!(is_constant(&bools, &mut ctx)?, expected);
        Ok(())
    }

    // Migrated from arrays/chunked/compute/is_constant.rs.
    #[test]
    fn empty_chunk_is_constant() -> VortexResult<()> {
        let empty = || Buffer::<u8>::empty().into_array();
        let saturated = || buffer![255u8, 255].into_array();
        let chunked = ChunkedArray::try_new(
            vec![empty(), empty(), saturated(), empty(), saturated()],
            DType::Primitive(PType::U8, Nullability::NonNullable),
        )?
        .into_array();

        let mut ctx = array_session().create_execution_ctx();
        assert!(is_constant(&chunked, &mut ctx)?);
        Ok(())
    }

    // Migrated from arrays/decimal/compute/is_constant.rs.
    #[test]
    fn test_decimal_is_constant() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();

        let cases = [
            (buffer![0i128, 1i128, 2i128], false),
            (buffer![100i128, 100i128, 100i128], true),
        ];
        for (values, expected) in cases {
            let decimals =
                DecimalArray::new(values, DecimalDType::new(19, 0), Validity::NonNullable);
            assert_eq!(is_constant(&decimals.into_array(), &mut ctx)?, expected);
        }
        Ok(())
    }

    // Migrated from arrays/list/compute/is_constant.rs.
    #[test]
    fn test_is_constant_nested_list() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();

        // Two rows, each `{xs: [0, 1]}`.
        let lists = ListArray::try_new(
            buffer![0i32, 1, 0, 1].into_array(),
            buffer![0u32, 2, 4].into_array(),
            Validity::NonNullable,
        )?;
        let rows = StructArray::try_new(
            FieldNames::from(["xs"]),
            vec![lists.into_array()],
            2,
            Validity::NonNullable,
        )?;

        // Evaluated twice: once on a clone, once on the original.
        assert!(is_constant(&rows.clone().into_array(), &mut ctx)?);
        assert!(is_constant(&rows.into_array(), &mut ctx)?);
        Ok(())
    }

    #[rstest]
    // [1, 2], [1, 2], [1, 2]
    #[case(vec![1i32, 2, 1, 2, 1, 2], vec![0u32, 2, 4, 6], true)]
    // [1, 2], [3], [4, 5]
    #[case(vec![1i32, 2, 3, 4, 5], vec![0u32, 2, 3, 5], false)]
    // [1, 2], [3, 4]
    #[case(vec![1i32, 2, 3, 4], vec![0u32, 2, 4], false)]
    // [], [], []
    #[case(vec![], vec![0u32, 0, 0, 0], true)]
    fn test_list_is_constant(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) -> VortexResult<()> {
        let lists = ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        )?;

        let mut ctx = array_session().create_execution_ctx();
        assert_eq!(is_constant(&lists.into_array(), &mut ctx)?, expected);
        Ok(())
    }

    #[test]
    fn test_list_is_constant_nested_lists() -> VortexResult<()> {
        // Inner lists: [1], [2], [1], [2].
        let inner = ListArray::try_new(
            buffer![1i32, 2, 1, 2].into_array(),
            buffer![0u32, 1, 2, 3, 4].into_array(),
            Validity::NonNullable,
        )?;

        // Outer lists: [[1], [2]], [[1], [2]].
        let outer = ListArray::try_new(
            inner.into_array(),
            buffer![0u32, 2, 4].into_array(),
            Validity::NonNullable,
        )?;

        let mut ctx = array_session().create_execution_ctx();
        // Both outer lists contain [[1], [2]], so should be constant
        assert!(is_constant(&outer.into_array(), &mut ctx)?);
        Ok(())
    }

    #[rstest]
    // 100 identical [1, 2] lists
    #[case([1i32, 2].repeat(100), (0..=100u32).map(|i| i * 2).collect(), true)]
    // Difference after threshold: 64 identical [1, 2] + one [3, 4]
    #[case(
        [[1i32, 2].repeat(64), vec![3, 4]].concat(),
        (0..=65u32).map(|i| i * 2).collect(),
        false
    )]
    // Difference in first 64: first 63 identical [1, 2] + one [3, 4] + rest identical [1, 2]
    #[case(
        [[1i32, 2].repeat(63), vec![3, 4], [1i32, 2].repeat(37)].concat(),
        (0..=100u32).map(|i| i * 2).collect(),
        false
    )]
    fn test_list_is_constant_with_threshold(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) -> VortexResult<()> {
        let lists = ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        )?;

        let mut ctx = array_session().create_execution_ctx();
        assert_eq!(is_constant(&lists.into_array(), &mut ctx)?, expected);
        Ok(())
    }
}