vortex_array/arrays/list/compute/is_constant.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_scalar::NumericOperator;
6
7use crate::arrays::{ListArray, ListVTable};
8use crate::compute::{IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, numeric};
9use crate::register_kernel;
10
11const SMALL_ARRAY_THRESHOLD: usize = 64;
12
13impl IsConstantKernel for ListVTable {
14    fn is_constant(&self, array: &ListArray, opts: &IsConstantOpts) -> VortexResult<Option<bool>> {
15        // At this point, we're guaranteed:
16        // - Array has at least 2 elements
17        // - All elements are valid (no nulls)
18
19        let manual_check_until = std::cmp::min(SMALL_ARRAY_THRESHOLD, array.len());
20
21        // We can first quickly check if all of the list lengths are equal. If not, then we know the
22        // array cannot be constant.
23        let first_list_len = array.offset_at(1) - array.offset_at(0);
24        for i in 1..manual_check_until {
25            let current_list_len = array.offset_at(i + 1) - array.offset_at(i);
26            if current_list_len != first_list_len {
27                return Ok(Some(false));
28            }
29        }
30
31        // Since we were not able to determine that this list array was **not** constant, and the
32        // cost is negligible, then don't bother doing the rest of this expensive check.
33        if opts.is_negligible_cost() {
34            return Ok(None);
35        }
36
37        // If the array is long, do an optimistic check on the remainder of the list lengths.
38        if array.len() > SMALL_ARRAY_THRESHOLD {
39            // check the rest of the element lengths
40            let start_offsets = array.offsets.slice(SMALL_ARRAY_THRESHOLD..array.len());
41            let end_offsets = array
42                .offsets
43                .slice(SMALL_ARRAY_THRESHOLD + 1..array.len() + 1);
44            let list_lengths = numeric(&end_offsets, &start_offsets, NumericOperator::Sub)?;
45
46            if !list_lengths.is_constant() {
47                return Ok(Some(false));
48            }
49        }
50
51        debug_assert!(
52            array.len() > 1,
53            "precondition for `is_constant` is incorrect"
54        );
55        let first_scalar = array.scalar_at(0); // We checked the array length above.
56
57        // All lists have the same length, so compare the actual list contents.
58        for i in 1..array.len() {
59            let current_scalar = array.scalar_at(i);
60            if current_scalar != first_scalar {
61                return Ok(Some(false));
62            }
63        }
64
65        Ok(Some(true))
66    }
67}
68
69register_kernel!(IsConstantKernelAdapter(ListVTable).lift());
70
#[cfg(test)]
mod tests {

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::FieldNames;

    use crate::IntoArray;
    use crate::arrays::{ListArray, PrimitiveArray, StructArray};
    use crate::compute::is_constant;
    use crate::validity::Validity;

    #[test]
    fn test_is_constant_nested_list() {
        let xs = ListArray::try_new(
            buffer![0i32, 1, 0, 1].into_array(),
            buffer![0u32, 2, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let struct_of_lists = StructArray::try_new(
            FieldNames::from(["xs"]),
            vec![xs.into_array()],
            2,
            Validity::NonNullable,
        )
        .unwrap();
        assert!(
            is_constant(&struct_of_lists.clone().into_array())
                .unwrap()
                .unwrap()
        );
        assert!(struct_of_lists.is_constant());
    }

    #[rstest]
    #[case(
        // [1,2], [1, 2], [1, 2]
        vec![1i32, 2, 1, 2, 1, 2],
        vec![0u32, 2, 4, 6],
        true
    )]
    #[case(
        // [1, 2], [3], [4, 5]
        vec![1i32, 2, 3, 4, 5],
        vec![0u32, 2, 3, 5],
        false
    )]
    #[case(
        // [1, 2], [3, 4]
        vec![1i32, 2, 3, 4],
        vec![0u32, 2, 4],
        false
    )]
    #[case(
        // [], [], []
        vec![],
        vec![0u32, 0, 0, 0],
        true
    )]
    fn test_list_is_constant(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) {
        let list_array = ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let result = is_constant(&list_array.into_array()).unwrap();
        assert_eq!(result.unwrap(), expected);
    }

    #[test]
    fn test_list_is_constant_nested_lists() {
        let inner_elements = buffer![1i32, 2, 1, 2].into_array();
        let inner_offsets = buffer![0u32, 1, 2, 3, 4].into_array();
        let inner_lists =
            ListArray::try_new(inner_elements, inner_offsets, Validity::NonNullable).unwrap();

        let outer_offsets = buffer![0u32, 2, 4].into_array();
        let outer_list = ListArray::try_new(
            inner_lists.into_array(),
            outer_offsets,
            Validity::NonNullable,
        )
        .unwrap();

        // Both outer lists contain [[1], [2]], so should be constant
        assert!(is_constant(&outer_list.into_array()).unwrap().unwrap());
    }

    #[rstest]
    #[case(
        // 100 identical [1, 2] lists
        [1i32, 2].repeat(100),
        (0..101).map(|i| (i * 2) as u32).collect(),
        true
    )]
    #[case(
        // Difference after threshold: 64 identical [1, 2] + one [3, 4]
        {
            let mut elements = [1i32, 2].repeat(64);
            elements.extend_from_slice(&[3, 4]);
            elements
        },
        (0..66).map(|i| (i * 2) as u32).collect(),
        false
    )]
    #[case(
        // Difference in first 64: first 63 identical [1, 2] + one [3, 4] + rest identical [1, 2]
        // (101 lists in total, so 102 offsets covering all 202 elements)
        {
            let mut elements = [1i32, 2].repeat(63);
            elements.extend_from_slice(&[3, 4]);
            elements.extend([1i32, 2].repeat(37));
            elements
        },
        (0..102).map(|i| (i * 2) as u32).collect(),
        false
    )]
    #[case(
        // Lengths differ across the threshold: 64 lists [1, 2] followed by 36 lists [1].
        // The tail lengths are uniform among themselves but differ from the prefix length.
        {
            let mut elements = [1i32, 2].repeat(64);
            elements.extend([1i32].repeat(36));
            elements
        },
        (0..=64u32).map(|i| i * 2).chain(129..=164).collect(),
        false
    )]
    fn test_list_is_constant_with_threshold(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) {
        let list_array = ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let result = is_constant(&list_array.into_array()).unwrap();
        assert_eq!(result.unwrap(), expected);
    }
}