1use arrow_buffer::BooleanBuffer;
4use arrow_buffer::bit_iterator::BitIndexIterator;
5use num_traits::AsPrimitive;
6use vortex_buffer::Buffer;
7use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
8use vortex_error::{VortexResult, vortex_bail};
9use vortex_mask::Mask;
10use vortex_scalar::Scalar;
11
12use crate::arrays::{BoolArray, ConstantArray, ListArray};
13use crate::compute::{Operator, compare, invert};
14use crate::validity::Validity;
15use crate::vtable::ValidityHelper;
16use crate::{Array, ArrayRef, IntoArray, ToCanonical};
17
18pub fn list_contains(array: &dyn Array, value: Scalar) -> VortexResult<ArrayRef> {
49 let DType::List(elem_dtype, _nullability) = array.dtype() else {
50 vortex_bail!("Array must be of List type");
51 };
52 if &**elem_dtype != value.dtype() {
53 vortex_bail!("Element type of ListArray does not match search value");
54 }
55
56 if array.is_constant() && array.len() > 1 {
58 let contains = list_contains(&array.slice(0, 1)?, value)?;
59 return Ok(ConstantArray::new(contains.scalar_at(0)?, array.len()).into_array());
60 }
61
62 let list_array = array.to_list()?;
65
66 if value.is_null() {
67 return list_contains_null(&list_array);
68 }
69
70 let elems = list_array.elements();
71 let ends = list_array.offsets().to_primitive()?;
72
73 let rhs = ConstantArray::new(value, elems.len());
74 let matching_elements = compare(elems, rhs.as_ref(), Operator::Eq)?;
75 let matches = matching_elements.to_bool()?;
76
77 if let Some(pred) = matches.as_constant() {
79 if matches!(pred.as_bool().value(), None | Some(false)) {
80 return Ok(ConstantArray::new::<bool>(false, list_array.len()).into_array());
82 }
83 }
84
85 match_each_integer_ptype!(ends.ptype(), |$T| {
86 Ok(reduce_with_ends(ends.as_slice::<$T>(), &matches.boolean_buffer(), list_array.validity().clone()))
87 })
88}
89
90fn list_contains_null(list_array: &ListArray) -> VortexResult<ArrayRef> {
93 let elems = list_array.elements();
94
95 match elems.validity_mask()? {
97 Mask::AllTrue(_) => match list_array.validity() {
99 Validity::NonNullable => {
100 Ok(ConstantArray::new::<bool>(false, list_array.len()).into_array())
101 }
102 Validity::AllValid => Ok(ConstantArray::new(
103 Scalar::bool(true, Nullability::Nullable),
104 list_array.len(),
105 )
106 .into_array()),
107 Validity::AllInvalid => Ok(ConstantArray::new(
108 Scalar::null(DType::Bool(Nullability::Nullable)),
109 list_array.len(),
110 )
111 .into_array()),
112 Validity::Array(list_mask) => {
113 let buffer = BooleanBuffer::new_unset(list_array.len());
115 Ok(BoolArray::new(buffer, Validity::Array(list_mask.clone())).into_array())
116 }
117 },
118 Mask::AllFalse(_) => Ok(ConstantArray::new::<bool>(true, list_array.len()).into_array()),
120 Mask::Values(mask) => {
121 let nulls = invert(&mask.into_array())?.to_bool()?;
122 let ends = list_array.offsets().to_primitive()?;
123 match_each_integer_ptype!(ends.ptype(), |$T| {
124 Ok(reduce_with_ends(
125 list_array.offsets().to_primitive()?.as_slice::<$T>(),
126 &nulls.boolean_buffer(),
127 list_array.validity().clone(),
128 ))
129 })
130 }
131 }
132}
133
134fn reduce_with_ends<T: NativePType + AsPrimitive<usize>>(
137 ends: &[T],
138 matches: &BooleanBuffer,
139 validity: Validity,
140) -> ArrayRef {
141 let mask: BooleanBuffer = ends
142 .windows(2)
143 .map(|window| {
144 let len = window[1].as_() - window[0].as_();
145 let mut set_bits = BitIndexIterator::new(matches.values(), window[0].as_(), len);
146 set_bits.next().is_some()
147 })
148 .collect();
149
150 BoolArray::new(mask, validity).into_array()
151}
152
153pub fn list_elem_len(array: &dyn Array) -> VortexResult<ArrayRef> {
176 if !matches!(array.dtype(), DType::List(..)) {
177 vortex_bail!("Array must be of list type");
178 }
179
180 if array.is_constant() && array.len() > 1 {
182 let elem_lens = list_elem_len(&array.slice(0, 1)?)?;
183 return Ok(ConstantArray::new(elem_lens.scalar_at(0)?, array.len()).into_array());
184 }
185
186 let list_array = array.to_list()?;
187 let offsets = list_array.offsets().to_primitive()?;
188 let lens_array = match_each_integer_ptype!(offsets.ptype(), |$T| {
189 element_lens(offsets.as_slice::<$T>()).into_array()
190 });
191
192 Ok(lens_array)
193}
194
195fn element_lens<T: NativePType>(values: &[T]) -> Buffer<T> {
196 values
197 .windows(2)
198 .map(|window| window[1] - window[0])
199 .collect()
200}
201
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::{BoolArray, ConstantArray, ConstantVTable, ListArray, VarBinArray};
    use crate::canonical::ToCanonical;
    use crate::compute::list_contains;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{ArrayRef, IntoArray};

    /// Builds a list array of non-nullable UTF-8 strings with `u64` offsets.
    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
        let element_dtype = Arc::new(DType::Utf8(Nullability::NonNullable));
        ListArray::from_iter_slow::<u64, _>(values, element_dtype).unwrap()
    }

    /// Builds a non-nullable list array whose elements are nullable UTF-8 strings.
    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
        // Offsets are the running total of list lengths, starting at zero.
        let offsets = std::iter::once(0u64)
            .chain(values.iter().scan(0u64, |total, list| {
                *total += list.len() as u64;
                Some(*total)
            }))
            .collect_vec();
        let offsets = Buffer::from_iter(offsets).into_array();

        let flat = values.into_iter().flatten().collect_vec();
        let elements = VarBinArray::from_iter(flat, DType::Utf8(Nullability::Nullable)).into_array();

        ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array()
    }

    /// Builds the expected boolean result; `validity == None` means non-nullable.
    fn bool_array(values: Vec<bool>, validity: Option<Vec<bool>>) -> BoolArray {
        let validity = if let Some(bits) = validity {
            Validity::from_iter(bits)
        } else {
            Validity::NonNullable
        };
        BoolArray::new(values.into_iter().collect(), validity)
    }

    #[rstest]
    // The value is present in some lists; the empty list never contains it.
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
        Some("a"),
        bool_array(vec![false, true, true], None)
    )]
    // Searching for null finds the lists that hold a null element.
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
        None,
        bool_array(vec![false, true, true], None)
    )]
    // Every element equals the value.
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a"]]),
        Some("a"),
        bool_array(vec![false, true, true], None)
    )]
    // Only empty lists.
    #[case(
        nonnull_strings(vec![vec![], vec![], vec![]]),
        Some("a"),
        bool_array(vec![false, false, false], None)
    )]
    // No element equals the value.
    #[case(
        nonnull_strings(vec![vec!["b"], vec![], vec!["b"]]),
        Some("a"),
        bool_array(vec![false, false, false], None)
    )]
    fn test_contains_nullable(
        #[case] list_array: ArrayRef,
        #[case] value: Option<&str>,
        #[case] expected: BoolArray,
    ) {
        let element_nullability = list_array.dtype().as_list_element().unwrap().nullability();
        let scalar = value.map_or_else(
            || Scalar::null(DType::Utf8(Nullability::Nullable)),
            |v| Scalar::utf8(v, element_nullability),
        );

        let actual = list_contains(&list_array, scalar)
            .expect("list_contains failed")
            .to_bool()
            .expect("to_bool failed");

        let actual_bits = actual.boolean_buffer().iter().collect_vec();
        let expected_bits = expected.boolean_buffer().iter().collect_vec();
        assert_eq!(actual_bits, expected_bits);
        assert_eq!(actual.validity(), expected.validity());
    }

    #[test]
    fn test_constant_list() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let list = Scalar::list(
            element_dtype,
            vec![1i32.into(), 2i32.into(), 3i32.into()],
            Nullability::NonNullable,
        );
        let list_array = ConstantArray::new(list, 2).into_array();

        let contains = list_contains(&list_array, 2i32.into()).unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");

        let bits = contains
            .to_bool()
            .unwrap()
            .boolean_buffer()
            .iter()
            .collect_vec();
        assert_eq!(bits, vec![true, true]);
    }
}