vortex_array/arrays/list/compute/
is_constant.rs1use vortex_error::VortexResult;
2use vortex_scalar::NumericOperator;
3
4use crate::arrays::{ListArray, ListVTable};
5use crate::compute::{IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, numeric};
6use crate::register_kernel;
7
/// Number of leading lists whose lengths the list `is_constant` kernel compares one offset pair
/// at a time. Arrays longer than this have the lengths of the remaining lists checked with a
/// single bulk subtraction of the offsets instead.
8const SMALL_ARRAY_THRESHOLD: usize = 64;
9
10impl IsConstantKernel for ListVTable {
11 fn is_constant(&self, array: &ListArray, opts: &IsConstantOpts) -> VortexResult<Option<bool>> {
12 let manual_check_until = std::cmp::min(SMALL_ARRAY_THRESHOLD, array.len());
17
18 let first_list_len = array.offset_at(1) - array.offset_at(0);
19 for i in 1..manual_check_until {
20 let current_list_len = array.offset_at(i + 1) - array.offset_at(i);
21 if current_list_len != first_list_len {
22 return Ok(Some(false));
23 }
24 }
25
26 if opts.is_negligible_cost() {
27 return Ok(None);
28 }
29
30 if array.len() > SMALL_ARRAY_THRESHOLD {
31 let start_offsets = array.offsets.slice(SMALL_ARRAY_THRESHOLD, array.len())?;
33 let end_offsets = array
34 .offsets
35 .slice(SMALL_ARRAY_THRESHOLD + 1, array.len() + 1)?;
36 let list_lengths = numeric(&end_offsets, &start_offsets, NumericOperator::Sub)?;
37
38 if !list_lengths.is_constant() {
39 return Ok(Some(false));
40 }
41 }
42
43 let first_scalar = array.scalar_at(0)?;
45 for i in 1..array.len() {
46 let current_scalar = array.scalar_at(i)?;
47 if current_scalar != first_scalar {
48 return Ok(Some(false));
49 }
50 }
51
52 Ok(Some(true))
53 }
54}
55
// Registers the `IsConstantKernel` implementation for `ListVTable` above, wrapped in
// `IsConstantKernelAdapter` and lifted, through the crate's `register_kernel!` macro.
56register_kernel!(IsConstantKernelAdapter(ListVTable).lift());
57
#[cfg(test)]
mod tests {
    // Tests for `is_constant` on list arrays: small arrays, arrays longer than the 64-list
    // prefix that the kernel checks by hand, nested lists, and lists inside a struct.

    use rstest::rstest;
    use vortex_dtype::FieldNames;

    use crate::IntoArray;
    use crate::arrays::{ListArray, PrimitiveArray, StructArray};
    use crate::compute::is_constant;
    use crate::validity::Validity;

    /// Builds a non-nullable list-of-`i32` array from flat `elements` and list `offsets`.
    fn list_of_i32(elements: Vec<i32>, offsets: Vec<u32>) -> ListArray {
        ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// A struct whose only field holds two identical lists must itself be constant.
    #[test]
    fn test_is_constant_nested_list() {
        let xs = list_of_i32(vec![0, 1, 0, 1], vec![0, 2, 4]);

        let struct_of_lists = StructArray::try_new(
            FieldNames::from(["xs"]),
            vec![xs.into_array()],
            2,
            Validity::NonNullable,
        )
        .unwrap();

        // Check both the free function and the method on the array.
        assert!(
            is_constant(&struct_of_lists.clone().into_array())
                .unwrap()
                .unwrap()
        );
        assert!(struct_of_lists.is_constant());
    }

    #[rstest]
    #[case(
        // [1, 2], [1, 2], [1, 2]: every list is identical.
        vec![1i32, 2, 1, 2, 1, 2],
        vec![0u32, 2, 4, 6],
        true
    )]
    #[case(
        // [1, 2], [3], [4, 5]: the list lengths differ.
        vec![1i32, 2, 3, 4, 5],
        vec![0u32, 2, 3, 5],
        false
    )]
    #[case(
        // [1, 2], [3, 4]: same lengths, different values.
        vec![1i32, 2, 3, 4],
        vec![0u32, 2, 4],
        false
    )]
    #[case(
        // Three empty lists are all equal.
        vec![],
        vec![0u32, 0, 0, 0],
        true
    )]
    fn test_list_is_constant(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) {
        let list_array = list_of_i32(elements, offsets);

        assert_eq!(
            is_constant(&list_array.into_array()).unwrap(),
            Some(expected)
        );
    }

    /// A list of lists where both outer entries are `[[1], [2]]` is constant.
    #[test]
    fn test_list_is_constant_nested_lists() {
        // Inner lists: [1], [2], [1], [2].
        let inner_lists = list_of_i32(vec![1, 2, 1, 2], vec![0, 1, 2, 3, 4]);

        // Outer lists group the inner lists in pairs.
        let outer_offsets = PrimitiveArray::from_iter([0u32, 2, 4]).into_array();
        let outer_list = ListArray::try_new(
            inner_lists.into_array(),
            outer_offsets,
            Validity::NonNullable,
        )
        .unwrap();

        assert!(is_constant(&outer_list.into_array()).unwrap().unwrap());
    }

    #[rstest]
    #[case(
        // 100 identical [1, 2] lists: constant well past the first 64 lists.
        [1i32, 2].repeat(100),
        (0u32..101).map(|i| i * 2).collect(),
        true
    )]
    #[case(
        // 64 identical lists followed by a differing one at index 64, i.e. just after the
        // hand-checked prefix. All lengths match, so only the value comparison can catch it.
        {
            let mut elements = [1i32, 2].repeat(64);
            elements.extend_from_slice(&[3, 4]);
            elements
        },
        (0u32..66).map(|i| i * 2).collect(),
        false
    )]
    #[case(
        // The differing list sits at index 63, the last one inside the first 64 lists.
        // 63 + 1 + 36 lists = 100 lists, matching the 101 offsets below exactly.
        {
            let mut elements = [1i32, 2].repeat(63);
            elements.extend_from_slice(&[3, 4]);
            elements.extend([1i32, 2].repeat(36));
            elements
        },
        (0u32..101).map(|i| i * 2).collect(),
        false
    )]
    fn test_large_list_is_constant(
        #[case] elements: Vec<i32>,
        #[case] offsets: Vec<u32>,
        #[case] expected: bool,
    ) {
        let list_array = list_of_i32(elements, offsets);

        assert_eq!(
            is_constant(&list_array.into_array()).unwrap(),
            Some(expected)
        );
    }
}