vortex_array/arrays/list/compute/
is_constant.rs1use vortex_error::VortexResult;
5use vortex_scalar::NumericOperator;
6
7use crate::arrays::{ListArray, ListVTable};
8use crate::compute::{IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, numeric};
9use crate::register_kernel;
10
11const SMALL_ARRAY_THRESHOLD: usize = 64;
12
13impl IsConstantKernel for ListVTable {
14 fn is_constant(&self, array: &ListArray, opts: &IsConstantOpts) -> VortexResult<Option<bool>> {
15 let manual_check_until = std::cmp::min(SMALL_ARRAY_THRESHOLD, array.len());
20
21 let first_list_len = array.offset_at(1) - array.offset_at(0);
24 for i in 1..manual_check_until {
25 let current_list_len = array.offset_at(i + 1) - array.offset_at(i);
26 if current_list_len != first_list_len {
27 return Ok(Some(false));
28 }
29 }
30
31 if opts.is_negligible_cost() {
34 return Ok(None);
35 }
36
37 if array.len() > SMALL_ARRAY_THRESHOLD {
39 let start_offsets = array.offsets.slice(SMALL_ARRAY_THRESHOLD..array.len());
41 let end_offsets = array
42 .offsets
43 .slice(SMALL_ARRAY_THRESHOLD + 1..array.len() + 1);
44 let list_lengths = numeric(&end_offsets, &start_offsets, NumericOperator::Sub)?;
45
46 if !list_lengths.is_constant() {
47 return Ok(Some(false));
48 }
49 }
50
51 debug_assert!(
52 array.len() > 1,
53 "precondition for `is_constant` is incorrect"
54 );
55 let first_scalar = array.scalar_at(0); for i in 1..array.len() {
59 let current_scalar = array.scalar_at(i);
60 if current_scalar != first_scalar {
61 return Ok(Some(false));
62 }
63 }
64
65 Ok(Some(true))
66 }
67}
68
69register_kernel!(IsConstantKernelAdapter(ListVTable).lift());
70
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::FieldNames;

    use crate::IntoArray;
    use crate::arrays::{ListArray, PrimitiveArray, StructArray};
    use crate::compute::is_constant;
    use crate::validity::Validity;

    /// A struct whose only field is a list column holding two identical lists is constant.
    #[test]
    fn test_is_constant_nested_list() {
        let xs = ListArray::try_new(
            buffer![0i32, 1, 0, 1].into_array(),
            buffer![0u32, 2, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let struct_of_lists = StructArray::try_new(
            FieldNames::from(["xs"]),
            vec![xs.into_array()],
            2,
            Validity::NonNullable,
        )
        .unwrap();
        assert!(
            is_constant(&struct_of_lists.clone().into_array())
                .unwrap()
                .unwrap()
        );
        assert!(struct_of_lists.is_constant());
    }

    /// Small arrays (below `SMALL_ARRAY_THRESHOLD`), checked entirely by the manual path.
    #[rstest]
    // Three lists of [1, 2].
    #[case(
        vec![1i32, 2, 1, 2, 1, 2],
        vec![0u32, 2, 4, 6],
        true
    )]
    // List lengths 2, 1, 2: rejected by the length check.
    #[case(
        vec![1i32, 2, 3, 4, 5],
        vec![0u32, 2, 3, 5],
        false
    )]
    // Same lengths, different contents: [1, 2] vs [3, 4].
    #[case(
        vec![1i32, 2, 3, 4],
        vec![0u32, 2, 4],
        false
    )]
    // Three empty lists.
    #[case(
        vec![],
        vec![0u32, 0, 0, 0],
        true
    )]
    fn test_list_is_constant(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) {
        let list_array = ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let result = is_constant(&list_array.into_array()).unwrap();
        assert_eq!(result.unwrap(), expected);
    }

    /// A list of lists whose two outer lists are both [[1], [2]].
    #[test]
    fn test_list_is_constant_nested_lists() {
        let inner_elements = buffer![1i32, 2, 1, 2].into_array();
        let inner_offsets = buffer![0u32, 1, 2, 3, 4].into_array();
        let inner_lists =
            ListArray::try_new(inner_elements, inner_offsets, Validity::NonNullable).unwrap();

        let outer_offsets = buffer![0u32, 2, 4].into_array();
        let outer_list = ListArray::try_new(
            inner_lists.into_array(),
            outer_offsets,
            Validity::NonNullable,
        )
        .unwrap();

        assert!(is_constant(&outer_list.into_array()).unwrap().unwrap());
    }

    /// Arrays longer than `SMALL_ARRAY_THRESHOLD` (64), exercising the vectorized length check
    /// and the scalar comparison past the threshold.
    #[rstest]
    // 100 lists of [1, 2].
    #[case(
        [1i32, 2].repeat(100),
        (0..101).map(|i| (i * 2) as u32).collect(),
        true
    )]
    // 65 lists of length 2; only the last one (index 64) differs in contents: [3, 4].
    #[case(
        {
            let mut elements = [1i32, 2].repeat(64);
            elements.extend_from_slice(&[3, 4]);
            elements
        },
        (0..66).map(|i| (i * 2) as u32).collect(),
        false
    )]
    // 100 lists of length 2; list 63 is [3, 4]. The elements run two values past the final
    // offset; those values belong to no list.
    #[case(
        {
            let mut elements = [1i32, 2].repeat(63);
            elements.extend_from_slice(&[3, 4]);
            elements.extend([1i32, 2].repeat(37));
            elements
        },
        (0..101).map(|i| (i * 2) as u32).collect(),
        false
    )]
    // Lists 0..64 have length 2 and lists 64..100 have length 1: the tail's lengths are
    // constant among themselves but differ from the head's.
    #[case(
        {
            let mut elements = [1i32, 2].repeat(64);
            elements.extend([1i32].repeat(36));
            elements
        },
        (0..101u32).map(|i| if i <= 64 { i * 2 } else { i + 64 }).collect(),
        false
    )]
    // List 80 (past the threshold) has length 1; every other list has length 2.
    #[case(
        [1i32, 2].repeat(100),
        (0..101u32).map(|i| if i <= 80 { i * 2 } else { i * 2 - 1 }).collect(),
        false
    )]
    fn test_list_is_constant_with_threshold(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) {
        let list_array = ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let result = is_constant(&list_array.into_array()).unwrap();
        assert_eq!(result.unwrap(), expected);
    }
}