vortex_array/arrays/list/compute/
is_constant.rs1use vortex_error::VortexResult;
5use vortex_scalar::NumericOperator;
6
7use crate::arrays::ListArray;
8use crate::arrays::ListVTable;
9use crate::compute::IsConstantKernel;
10use crate::compute::IsConstantKernelAdapter;
11use crate::compute::IsConstantOpts;
12use crate::compute::numeric;
13use crate::register_kernel;
14
15const SMALL_ARRAY_THRESHOLD: usize = 64;
16
17impl IsConstantKernel for ListVTable {
18 fn is_constant(&self, array: &ListArray, opts: &IsConstantOpts) -> VortexResult<Option<bool>> {
19 let manual_check_until = std::cmp::min(SMALL_ARRAY_THRESHOLD, array.len());
24
25 let first_list_len = array.offset_at(1) - array.offset_at(0);
28 for i in 1..manual_check_until {
29 let current_list_len = array.offset_at(i + 1) - array.offset_at(i);
30 if current_list_len != first_list_len {
31 return Ok(Some(false));
32 }
33 }
34
35 if opts.is_negligible_cost() {
38 return Ok(None);
39 }
40
41 if array.len() > SMALL_ARRAY_THRESHOLD {
43 let start_offsets = array.offsets().slice(SMALL_ARRAY_THRESHOLD..array.len());
45 let end_offsets = array
46 .offsets()
47 .slice(SMALL_ARRAY_THRESHOLD + 1..array.len() + 1);
48 let list_lengths = numeric(&end_offsets, &start_offsets, NumericOperator::Sub)?;
49
50 if !list_lengths.is_constant() {
51 return Ok(Some(false));
52 }
53 }
54
55 debug_assert!(
56 array.len() > 1,
57 "precondition for `is_constant` is incorrect"
58 );
59 let first_scalar = array.scalar_at(0); for i in 1..array.len() {
63 let current_scalar = array.scalar_at(i);
64 if current_scalar != first_scalar {
65 return Ok(Some(false));
66 }
67 }
68
69 Ok(Some(true))
70 }
71}
72
73register_kernel!(IsConstantKernelAdapter(ListVTable).lift());
74
#[cfg(test)]
mod tests {

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::FieldNames;

    use crate::IntoArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::compute::is_constant;
    use crate::validity::Validity;

    /// A struct whose single field holds two identical lists (`[0, 1]`, `[0, 1]`) is constant.
    #[test]
    fn test_is_constant_nested_list() {
        let xs = ListArray::try_new(
            buffer![0i32, 1, 0, 1].into_array(),
            buffer![0u32, 2, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let struct_of_lists = StructArray::try_new(
            FieldNames::from(["xs"]),
            vec![xs.into_array()],
            2,
            Validity::NonNullable,
        )
        .unwrap();
        assert!(
            is_constant(&struct_of_lists.clone().into_array())
                .unwrap()
                .unwrap()
        );
        assert!(struct_of_lists.is_constant());
    }

    #[rstest]
    #[case(
        // [1, 2], [1, 2], [1, 2]
        vec![1i32, 2, 1, 2, 1, 2],
        vec![0u32, 2, 4, 6],
        true
    )]
    #[case(
        // [1, 2], [3], [4, 5]: lengths differ
        vec![1i32, 2, 3, 4, 5],
        vec![0u32, 2, 3, 5],
        false
    )]
    #[case(
        // [1, 2], [3, 4]: same lengths, different values
        vec![1i32, 2, 3, 4],
        vec![0u32, 2, 4],
        false
    )]
    #[case(
        // [], [], []
        vec![],
        vec![0u32, 0, 0, 0],
        true
    )]
    fn test_list_is_constant(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) {
        let list_array = ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let result = is_constant(&list_array.into_array()).unwrap();
        assert_eq!(result.unwrap(), expected);
    }

    /// Two outer lists, each holding the inner lists `[1]` and `[2]`, are constant.
    #[test]
    fn test_list_is_constant_nested_lists() {
        let inner_elements = buffer![1i32, 2, 1, 2].into_array();
        let inner_offsets = buffer![0u32, 1, 2, 3, 4].into_array();
        let inner_lists =
            ListArray::try_new(inner_elements, inner_offsets, Validity::NonNullable).unwrap();

        let outer_offsets = buffer![0u32, 2, 4].into_array();
        let outer_list = ListArray::try_new(
            inner_lists.into_array(),
            outer_offsets,
            Validity::NonNullable,
        )
        .unwrap();

        assert!(is_constant(&outer_list.into_array()).unwrap().unwrap());
    }

    #[rstest]
    #[case(
        // 100 copies of [1, 2]: exercises the vectorized length check past the threshold.
        [1i32, 2].repeat(100),
        (0..101).map(|i| (i * 2) as u32).collect(),
        true
    )]
    #[case(
        // 64 copies of [1, 2] followed by [3, 4]: the differing list lies past the
        // manually checked prefix.
        {
            let mut elements = [1i32, 2].repeat(64);
            elements.extend_from_slice(&[3, 4]);
            elements
        },
        (0..66).map(|i| (i * 2) as u32).collect(),
        false
    )]
    #[case(
        // 63 copies of [1, 2], then [3, 4], then 37 more [1, 2]: the differing list is the
        // last one inside the manually checked prefix.
        {
            let mut elements = [1i32, 2].repeat(63);
            elements.extend_from_slice(&[3, 4]);
            elements.extend([1i32, 2].repeat(37));
            elements
        },
        (0..101).map(|i| (i * 2) as u32).collect(),
        false
    )]
    #[case(
        // 64 lists of length 2 followed by 36 lists of length 4: lengths agree within the
        // manually checked prefix and within the rest, but the two groups differ.
        [1i32, 2].repeat(136),
        (0..=64u32).map(|i| i * 2).chain((1..=36u32).map(|i| 128 + i * 4)).collect(),
        false
    )]
    fn test_list_is_constant_with_threshold(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) {
        let list_array = ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let result = is_constant(&list_array.into_array()).unwrap();
        assert_eq!(result.unwrap(), expected);
    }
}