vortex_array/arrays/list/compute/
is_constant.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_scalar::NumericOperator;
6
7use crate::arrays::ListArray;
8use crate::arrays::ListVTable;
9use crate::compute::IsConstantKernel;
10use crate::compute::IsConstantKernelAdapter;
11use crate::compute::IsConstantOpts;
12use crate::compute::numeric;
13use crate::register_kernel;
14
/// Number of leading lists whose lengths `is_constant` compares one at a time.
///
/// Arrays longer than this additionally get a bulk list-length check that is computed by
/// subtracting slices of the offsets array.
const SMALL_ARRAY_THRESHOLD: usize = 64;
16
17impl IsConstantKernel for ListVTable {
18    fn is_constant(&self, array: &ListArray, opts: &IsConstantOpts) -> VortexResult<Option<bool>> {
19        // At this point, we're guaranteed:
20        // - Array has at least 2 elements
21        // - All elements are valid (no nulls)
22
23        let manual_check_until = std::cmp::min(SMALL_ARRAY_THRESHOLD, array.len());
24
25        // We can first quickly check if all of the list lengths are equal. If not, then we know the
26        // array cannot be constant.
27        let first_list_len = array.offset_at(1) - array.offset_at(0);
28        for i in 1..manual_check_until {
29            let current_list_len = array.offset_at(i + 1) - array.offset_at(i);
30            if current_list_len != first_list_len {
31                return Ok(Some(false));
32            }
33        }
34
35        // Since we were not able to determine that this list array was **not** constant, and the
36        // cost is negligible, then don't bother doing the rest of this expensive check.
37        if opts.is_negligible_cost() {
38            return Ok(None);
39        }
40
41        // If the array is long, do an optimistic check on the remainder of the list lengths.
42        if array.len() > SMALL_ARRAY_THRESHOLD {
43            // check the rest of the element lengths
44            let start_offsets = array.offsets().slice(SMALL_ARRAY_THRESHOLD..array.len());
45            let end_offsets = array
46                .offsets()
47                .slice(SMALL_ARRAY_THRESHOLD + 1..array.len() + 1);
48            let list_lengths = numeric(&end_offsets, &start_offsets, NumericOperator::Sub)?;
49
50            if !list_lengths.is_constant() {
51                return Ok(Some(false));
52            }
53        }
54
55        debug_assert!(
56            array.len() > 1,
57            "precondition for `is_constant` is incorrect"
58        );
59        let first_scalar = array.scalar_at(0); // We checked the array length above.
60
61        // All lists have the same length, so compare the actual list contents.
62        for i in 1..array.len() {
63            let current_scalar = array.scalar_at(i);
64            if current_scalar != first_scalar {
65                return Ok(Some(false));
66            }
67        }
68
69        Ok(Some(true))
70    }
71}
72
// Pass the lifted `IsConstantKernelAdapter` for `ListVTable` to `register_kernel!`.
// NOTE(review): presumably this is what lets the `is_constant` compute function dispatch to the
// `IsConstantKernel` impl for list arrays; confirm against the macro definition.
register_kernel!(IsConstantKernelAdapter(ListVTable).lift());
74
#[cfg(test)]
mod tests {

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::FieldNames;

    use crate::IntoArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::compute::is_constant;
    use crate::validity::Validity;

    // A struct whose only field is a constant list array must itself be reported as constant.
    #[test]
    fn test_is_constant_nested_list() {
        let xs = ListArray::try_new(
            buffer![0i32, 1, 0, 1].into_array(),
            buffer![0u32, 2, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let struct_of_lists = StructArray::try_new(
            FieldNames::from(["xs"]),
            vec![xs.into_array()],
            2,
            Validity::NonNullable,
        )
        .unwrap();
        assert!(
            is_constant(&struct_of_lists.clone().into_array())
                .unwrap()
                .unwrap()
        );
        assert!(struct_of_lists.is_constant());
    }

    // Small arrays: only the per-element length loop and the scalar comparison are exercised.
    #[rstest]
    #[case(
        // [1,2], [1, 2], [1, 2]
        vec![1i32, 2, 1, 2, 1, 2],
        vec![0u32, 2, 4, 6],
        true
    )]
    #[case(
        // [1, 2], [3], [4, 5]
        vec![1i32, 2, 3, 4, 5],
        vec![0u32, 2, 3, 5],
        false
    )]
    #[case(
        // [1, 2], [3, 4]
        vec![1i32, 2, 3, 4],
        vec![0u32, 2, 4],
        false
    )]
    #[case(
        // [], [], []
        vec![],
        vec![0u32, 0, 0, 0],
        true
    )]
    fn test_list_is_constant(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) {
        let list_array = ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let result = is_constant(&list_array.into_array()).unwrap();
        assert_eq!(result.unwrap(), expected);
    }

    // Lists of lists: the outer lists are compared by their (nested) contents.
    #[test]
    fn test_list_is_constant_nested_lists() {
        let inner_elements = buffer![1i32, 2, 1, 2].into_array();
        let inner_offsets = buffer![0u32, 1, 2, 3, 4].into_array();
        let inner_lists =
            ListArray::try_new(inner_elements, inner_offsets, Validity::NonNullable).unwrap();

        let outer_offsets = buffer![0u32, 2, 4].into_array();
        let outer_list = ListArray::try_new(
            inner_lists.into_array(),
            outer_offsets,
            Validity::NonNullable,
        )
        .unwrap();

        // Both outer lists contain [[1], [2]], so should be constant
        assert!(is_constant(&outer_list.into_array()).unwrap().unwrap());
    }

    // Arrays longer than the 64-element threshold, which also go through the bulk length check.
    #[rstest]
    #[case(
        // 100 identical [1, 2] lists
        [1i32, 2].repeat(100),
        (0..101).map(|i| (i * 2) as u32).collect(),
        true
    )]
    #[case(
        // Difference after threshold: 64 identical [1, 2] + one [3, 4]
        {
            let mut elements = [1i32, 2].repeat(64);
            elements.extend_from_slice(&[3, 4]);
            elements
        },
        (0..66).map(|i| (i * 2) as u32).collect(),
        false
    )]
    #[case(
        // Difference in first 64: first 63 identical [1, 2] + one [3, 4] + 36 identical [1, 2]
        // (100 lists in total, matching the 101 offsets below)
        {
            let mut elements = [1i32, 2].repeat(63);
            elements.extend_from_slice(&[3, 4]);
            elements.extend([1i32, 2].repeat(36));
            elements
        },
        (0..101).map(|i| (i * 2) as u32).collect(),
        false
    )]
    #[case(
        // List length varies after threshold: 64 identical [1, 2] + [1] + [1, 2]
        {
            let mut elements = [1i32, 2].repeat(64);
            elements.extend_from_slice(&[1, 1, 2]);
            elements
        },
        (0..65).map(|i| (i * 2) as u32).chain([129, 131]).collect(),
        false
    )]
    #[case(
        // Uniform but shorter lists after threshold: 64 identical [1, 2] + two [1] lists
        {
            let mut elements = [1i32, 2].repeat(64);
            elements.extend_from_slice(&[1, 1]);
            elements
        },
        (0..65).map(|i| (i * 2) as u32).chain([129, 130]).collect(),
        false
    )]
    fn test_list_is_constant_with_threshold(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) {
        let list_array = ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let result = is_constant(&list_array.into_array()).unwrap();
        assert_eq!(result.unwrap(), expected);
    }
}
214        assert_eq!(result.unwrap(), expected);
215    }
216}