vortex_array/arrays/list/compute/is_constant.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_scalar::NumericOperator;
6
7use crate::arrays::{ListArray, ListVTable};
8use crate::compute::{IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, numeric};
9use crate::register_kernel;
10
11const SMALL_ARRAY_THRESHOLD: usize = 64;
12
13impl IsConstantKernel for ListVTable {
14    fn is_constant(&self, array: &ListArray, opts: &IsConstantOpts) -> VortexResult<Option<bool>> {
15        // At this point, we're guaranteed:
16        // - Array has at least 2 elements
17        // - All elements are valid (no nulls)
18
19        let manual_check_until = std::cmp::min(SMALL_ARRAY_THRESHOLD, array.len());
20
21        let first_list_len = array.offset_at(1) - array.offset_at(0);
22        for i in 1..manual_check_until {
23            let current_list_len = array.offset_at(i + 1) - array.offset_at(i);
24            if current_list_len != first_list_len {
25                return Ok(Some(false));
26            }
27        }
28
29        if opts.is_negligible_cost() {
30            return Ok(None);
31        }
32
33        if array.len() > SMALL_ARRAY_THRESHOLD {
34            // check the rest of the element lengths
35            let start_offsets = array.offsets.slice(SMALL_ARRAY_THRESHOLD, array.len())?;
36            let end_offsets = array
37                .offsets
38                .slice(SMALL_ARRAY_THRESHOLD + 1, array.len() + 1)?;
39            let list_lengths = numeric(&end_offsets, &start_offsets, NumericOperator::Sub)?;
40
41            if !list_lengths.is_constant() {
42                return Ok(Some(false));
43            }
44        }
45
46        // If all lists have the same length, compare the actual list contents
47        let first_scalar = array.scalar_at(0)?;
48        for i in 1..array.len() {
49            let current_scalar = array.scalar_at(i)?;
50            if current_scalar != first_scalar {
51                return Ok(Some(false));
52            }
53        }
54
55        Ok(Some(true))
56    }
57}
58
59register_kernel!(IsConstantKernelAdapter(ListVTable).lift());
60
#[cfg(test)]
mod tests {

    use rstest::rstest;
    use vortex_dtype::FieldNames;

    use crate::IntoArray;
    use crate::arrays::{ListArray, PrimitiveArray, StructArray};
    use crate::compute::is_constant;
    use crate::validity::Validity;

    /// Builds a non-nullable `ListArray` of `i32` elements from raw elements and `u32` offsets.
    fn i32_list_array(elements: Vec<i32>, offsets: Vec<u32>) -> ListArray {
        ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// A struct whose single list field holds two equal lists (`[0, 1]`, `[0, 1]`) is constant.
    #[test]
    fn test_is_constant_nested_list() {
        let xs = ListArray::try_new(
            PrimitiveArray::from_iter([0i32, 1, 0, 1]).into_array(),
            PrimitiveArray::from_iter([0u32, 2, 4]).into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let struct_of_lists = StructArray::try_new(
            FieldNames::from(["xs"]),
            vec![xs.into_array()],
            2,
            Validity::NonNullable,
        )
        .unwrap();
        assert!(
            is_constant(&struct_of_lists.clone().into_array())
                .unwrap()
                .unwrap()
        );
        assert!(struct_of_lists.is_constant());
    }

    /// Small list arrays, all handled by the one-by-one length pass plus content comparison.
    #[rstest]
    #[case(
        // [1,2], [1, 2], [1, 2]
        vec![1i32, 2, 1, 2, 1, 2],
        vec![0u32, 2, 4, 6],
        true
    )]
    #[case(
        // [1, 2], [3], [4, 5]
        vec![1i32, 2, 3, 4, 5],
        vec![0u32, 2, 3, 5],
        false
    )]
    #[case(
        // [1, 2], [3, 4]
        vec![1i32, 2, 3, 4],
        vec![0u32, 2, 4],
        false
    )]
    #[case(
        // [], [], []
        vec![],
        vec![0u32, 0, 0, 0],
        true
    )]
    fn test_list_is_constant(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) {
        let list_array = i32_list_array(elements, offsets);

        let result = is_constant(&list_array.into_array()).unwrap();
        assert_eq!(result.unwrap(), expected);
    }

    /// Lists of lists compare element-wise through the nested scalars.
    #[test]
    fn test_list_is_constant_nested_lists() {
        let inner_elements = PrimitiveArray::from_iter([1i32, 2, 1, 2]).into_array();
        let inner_offsets = PrimitiveArray::from_iter([0u32, 1, 2, 3, 4]).into_array();
        let inner_lists =
            ListArray::try_new(inner_elements, inner_offsets, Validity::NonNullable).unwrap();

        let outer_offsets = PrimitiveArray::from_iter([0u32, 2, 4]).into_array();
        let outer_list = ListArray::try_new(
            inner_lists.into_array(),
            outer_offsets,
            Validity::NonNullable,
        )
        .unwrap();

        // Both outer lists contain [[1], [2]], so should be constant
        assert!(is_constant(&outer_list.into_array()).unwrap().unwrap());
    }

    /// Arrays longer than the 64-list threshold, which also exercise the bulk length check.
    #[rstest]
    #[case(
        // 100 identical [1, 2] lists
        [1i32, 2].repeat(100),
        (0..101).map(|i| (i * 2) as u32).collect(),
        true
    )]
    #[case(
        // Difference after threshold: 64 identical [1, 2] + one [3, 4]
        {
            let mut elements = [1i32, 2].repeat(64);
            elements.extend_from_slice(&[3, 4]);
            elements
        },
        (0..66).map(|i| (i * 2) as u32).collect(),
        false
    )]
    #[case(
        // Difference in first 64: first 63 identical [1, 2] + one [3, 4] + 36 more [1, 2]
        // (100 lists in total, matching the 101 offsets)
        {
            let mut elements = [1i32, 2].repeat(63);
            elements.extend_from_slice(&[3, 4]);
            elements.extend([1i32, 2].repeat(36));
            elements
        },
        (0..101).map(|i| (i * 2) as u32).collect(),
        false
    )]
    #[case(
        // Length change straddling the threshold: 64 [1, 2] lists, then two [1, 2, 3] lists.
        // The tail lengths agree with each other but not with the first list.
        {
            let mut elements = [1i32, 2].repeat(64);
            elements.extend([1i32, 2, 3].repeat(2));
            elements
        },
        (0..=64).map(|i| (i * 2) as u32).chain([131, 134]).collect(),
        false
    )]
    fn test_large_list_is_constant(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) {
        let list_array = i32_list_array(elements, offsets);

        let result = is_constant(&list_array.into_array()).unwrap();
        assert_eq!(result.unwrap(), expected);
    }
}