Skip to main content

vortex_array/arrays/list/compute/
is_constant.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::arrays::ListArray;
7use crate::arrays::ListVTable;
8use crate::compute::IsConstantKernel;
9use crate::compute::IsConstantKernelAdapter;
10use crate::compute::IsConstantOpts;
11use crate::compute::is_constant;
12use crate::compute::numeric;
13use crate::register_kernel;
14use crate::scalar::NumericOperator;
15
/// Upper bound on how many leading lists have their lengths compared one offset pair at a time.
///
/// Arrays longer than this have the remaining list lengths checked in bulk instead.
const SMALL_ARRAY_THRESHOLD: usize = 64;
17
18impl IsConstantKernel for ListVTable {
19    fn is_constant(&self, array: &ListArray, opts: &IsConstantOpts) -> VortexResult<Option<bool>> {
20        // At this point, we're guaranteed:
21        // - Array has at least 2 elements
22        // - All elements are valid (no nulls)
23
24        let manual_check_until = std::cmp::min(SMALL_ARRAY_THRESHOLD, array.len());
25
26        // We can first quickly check if all of the list lengths are equal. If not, then we know the
27        // array cannot be constant.
28        let first_list_len = array.offset_at(1)? - array.offset_at(0)?;
29        for i in 1..manual_check_until {
30            let current_list_len = array.offset_at(i + 1)? - array.offset_at(i)?;
31            if current_list_len != first_list_len {
32                return Ok(Some(false));
33            }
34        }
35
36        // Since we were not able to determine that this list array was **not** constant, and the
37        // cost is negligible, then don't bother doing the rest of this expensive check.
38        if opts.is_negligible_cost() {
39            return Ok(None);
40        }
41
42        // If the array is long, do an optimistic check on the remainder of the list lengths.
43        if array.len() > SMALL_ARRAY_THRESHOLD {
44            // check the rest of the element lengths
45            let start_offsets = array.offsets().slice(SMALL_ARRAY_THRESHOLD..array.len())?;
46            let end_offsets = array
47                .offsets()
48                .slice(SMALL_ARRAY_THRESHOLD + 1..array.len() + 1)?;
49            let list_lengths = numeric(&end_offsets, &start_offsets, NumericOperator::Sub)?;
50
51            if !is_constant(&list_lengths)?.unwrap_or_default() {
52                return Ok(Some(false));
53            }
54        }
55
56        debug_assert!(
57            array.len() > 1,
58            "precondition for `is_constant` is incorrect"
59        );
60        let first_scalar = array.scalar_at(0)?; // We checked the array length above.
61
62        // All lists have the same length, so compare the actual list contents.
63        for i in 1..array.len() {
64            let current_scalar = array.scalar_at(i)?;
65            if current_scalar != first_scalar {
66                return Ok(Some(false));
67            }
68        }
69
70        Ok(Some(true))
71    }
72}
73
// Hand the `ListVTable` implementation of `IsConstantKernel` above, wrapped in
// `IsConstantKernelAdapter`, to the crate's `register_kernel!` macro.
// NOTE(review): presumably this is what lets the generic `is_constant` dispatch reach this
// kernel — confirm against the macro definition.
register_kernel!(IsConstantKernelAdapter(ListVTable).lift());
75
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::FieldNames;

    use crate::IntoArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::compute::is_constant;
    use crate::validity::Validity;

    /// A struct whose only field is a list column of two identical lists is constant.
    #[test]
    fn test_is_constant_nested_list() {
        // xs = [[0, 1], [0, 1]]
        let xs = ListArray::try_new(
            buffer![0i32, 1, 0, 1].into_array(),
            buffer![0u32, 2, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let struct_of_lists = StructArray::try_new(
            FieldNames::from(["xs"]),
            vec![xs.into_array()],
            2,
            Validity::NonNullable,
        )
        .unwrap();

        // Check both through an owned array and through a borrowed view of the struct.
        assert!(
            is_constant(&struct_of_lists.clone().into_array())
                .unwrap()
                .unwrap()
        );
        assert!(
            is_constant(struct_of_lists.as_ref())
                .unwrap()
                .unwrap_or_default()
        );
    }

    /// Small list arrays: identical lists, differing lengths, differing contents, all empty.
    #[rstest]
    #[case(
        // [1,2], [1, 2], [1, 2]
        vec![1i32, 2, 1, 2, 1, 2],
        vec![0u32, 2, 4, 6],
        true
    )]
    #[case(
        // [1, 2], [3], [4, 5]
        vec![1i32, 2, 3, 4, 5],
        vec![0u32, 2, 3, 5],
        false
    )]
    #[case(
        // [1, 2], [3, 4]
        vec![1i32, 2, 3, 4],
        vec![0u32, 2, 4],
        false
    )]
    #[case(
        // [], [], []
        vec![],
        vec![0u32, 0, 0, 0],
        true
    )]
    fn test_list_is_constant(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) {
        let list_array = ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let result = is_constant(&list_array.into_array()).unwrap();
        assert_eq!(result.unwrap(), expected);
    }

    /// A list of lists whose outer entries are identical is constant.
    #[test]
    fn test_list_is_constant_nested_lists() {
        // Inner lists: [1], [2], [1], [2]
        let inner_elements = buffer![1i32, 2, 1, 2].into_array();
        let inner_offsets = buffer![0u32, 1, 2, 3, 4].into_array();
        let inner_lists =
            ListArray::try_new(inner_elements, inner_offsets, Validity::NonNullable).unwrap();

        let outer_offsets = buffer![0u32, 2, 4].into_array();
        let outer_list = ListArray::try_new(
            inner_lists.into_array(),
            outer_offsets,
            Validity::NonNullable,
        )
        .unwrap();

        // Both outer lists contain [[1], [2]], so should be constant
        assert!(is_constant(&outer_list.into_array()).unwrap().unwrap());
    }

    /// Arrays longer than the kernel's 64-list threshold, with the difference placed on either
    /// side of that boundary.
    #[rstest]
    #[case(
        // 100 identical [1, 2] lists
        [1i32, 2].repeat(100),
        (0u32..101).map(|i| i * 2).collect(),
        true
    )]
    #[case(
        // Difference after threshold: 64 identical [1, 2] + one [3, 4]
        {
            let mut elements = [1i32, 2].repeat(64);
            elements.extend_from_slice(&[3, 4]);
            elements
        },
        (0u32..66).map(|i| i * 2).collect(),
        false
    )]
    #[case(
        // Difference in first 64: first 63 identical [1, 2] + one [3, 4] + 36 identical [1, 2]
        // (100 lists in total, matching the 101 offsets below)
        {
            let mut elements = [1i32, 2].repeat(63);
            elements.extend_from_slice(&[3, 4]);
            elements.extend([1i32, 2].repeat(36));
            elements
        },
        (0u32..101).map(|i| i * 2).collect(),
        false
    )]
    #[case(
        // Length changes only after threshold: 64 lists of [1, 2] + 36 lists of [1, 2, 3]
        {
            let mut elements = [1i32, 2].repeat(64);
            elements.extend([1i32, 2, 3].repeat(36));
            elements
        },
        (0u32..=64)
            .map(|i| i * 2)
            .chain((1u32..=36).map(|i| 128 + i * 3))
            .collect(),
        false
    )]
    fn test_list_is_constant_with_threshold(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) {
        let list_array = ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let result = is_constant(&list_array.into_array()).unwrap();
        assert_eq!(result.unwrap(), expected);
    }
}