1use arrow_buffer::BooleanBuffer;
4use arrow_buffer::bit_iterator::BitIndexIterator;
5use num_traits::AsPrimitive;
6use vortex_buffer::Buffer;
7use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
8use vortex_error::{VortexResult, vortex_bail};
9use vortex_mask::Mask;
10use vortex_scalar::Scalar;
11
12use crate::arrays::{BoolArray, ConstantArray, ListArray};
13use crate::compute::{Operator, compare, invert};
14use crate::validity::Validity;
15use crate::vtable::ValidityHelper;
16use crate::{Array, ArrayRef, IntoArray, ToCanonical};
17
18pub fn list_contains(array: &dyn Array, value: Scalar) -> VortexResult<ArrayRef> {
49 let DType::List(elem_dtype, _nullability) = array.dtype() else {
50 vortex_bail!("Array must be of List type");
51 };
52 if &**elem_dtype != value.dtype() {
53 vortex_bail!("Element type of ListArray does not match search value");
54 }
55
56 if array.is_constant() && array.len() > 1 {
58 let contains = list_contains(&array.slice(0, 1)?, value)?;
59 return Ok(ConstantArray::new(contains.scalar_at(0)?, array.len()).into_array());
60 }
61
62 let list_array = array.to_list()?;
65
66 if value.is_null() {
67 return list_contains_null(&list_array);
68 }
69
70 let elems = list_array.elements();
71 let ends = list_array.offsets().to_primitive()?;
72
73 let rhs = ConstantArray::new(value, elems.len());
74 let matching_elements = compare(elems, rhs.as_ref(), Operator::Eq)?;
75 let matches = matching_elements.to_bool()?;
76
77 if let Some(pred) = matches.as_constant() {
79 if matches!(pred.as_bool().value(), None | Some(false)) {
80 return Ok(ConstantArray::new::<bool>(false, list_array.len()).into_array());
82 }
83 }
84
85 match_each_integer_ptype!(ends.ptype(), |T| {
86 Ok(reduce_with_ends(
87 ends.as_slice::<T>(),
88 matches.boolean_buffer(),
89 list_array.validity().clone(),
90 ))
91 })
92}
93
94fn list_contains_null(list_array: &ListArray) -> VortexResult<ArrayRef> {
97 let elems = list_array.elements();
98
99 match elems.validity_mask()? {
101 Mask::AllTrue(_) => match list_array.validity() {
103 Validity::NonNullable => {
104 Ok(ConstantArray::new::<bool>(false, list_array.len()).into_array())
105 }
106 Validity::AllValid => Ok(ConstantArray::new(
107 Scalar::bool(true, Nullability::Nullable),
108 list_array.len(),
109 )
110 .into_array()),
111 Validity::AllInvalid => Ok(ConstantArray::new(
112 Scalar::null(DType::Bool(Nullability::Nullable)),
113 list_array.len(),
114 )
115 .into_array()),
116 Validity::Array(list_mask) => {
117 let buffer = BooleanBuffer::new_unset(list_array.len());
119 Ok(BoolArray::new(buffer, Validity::Array(list_mask.clone())).into_array())
120 }
121 },
122 Mask::AllFalse(_) => Ok(ConstantArray::new::<bool>(true, list_array.len()).into_array()),
124 Mask::Values(mask) => {
125 let nulls = invert(&mask.into_array())?.to_bool()?;
126 let ends = list_array.offsets().to_primitive()?;
127 match_each_integer_ptype!(ends.ptype(), |T| {
128 Ok(reduce_with_ends(
129 list_array.offsets().to_primitive()?.as_slice::<T>(),
130 nulls.boolean_buffer(),
131 list_array.validity().clone(),
132 ))
133 })
134 }
135 }
136}
137
138fn reduce_with_ends<T: NativePType + AsPrimitive<usize>>(
141 ends: &[T],
142 matches: &BooleanBuffer,
143 validity: Validity,
144) -> ArrayRef {
145 let mask: BooleanBuffer = ends
146 .windows(2)
147 .map(|window| {
148 let len = window[1].as_() - window[0].as_();
149 let mut set_bits = BitIndexIterator::new(matches.values(), window[0].as_(), len);
150 set_bits.next().is_some()
151 })
152 .collect();
153
154 BoolArray::new(mask, validity).into_array()
155}
156
157pub fn list_elem_len(array: &dyn Array) -> VortexResult<ArrayRef> {
180 if !matches!(array.dtype(), DType::List(..)) {
181 vortex_bail!("Array must be of list type");
182 }
183
184 if array.is_constant() && array.len() > 1 {
186 let elem_lens = list_elem_len(&array.slice(0, 1)?)?;
187 return Ok(ConstantArray::new(elem_lens.scalar_at(0)?, array.len()).into_array());
188 }
189
190 let list_array = array.to_list()?;
191 let offsets = list_array.offsets().to_primitive()?;
192 let lens_array = match_each_integer_ptype!(offsets.ptype(), |T| {
193 element_lens(offsets.as_slice::<T>()).into_array()
194 });
195
196 Ok(lens_array)
197}
198
199fn element_lens<T: NativePType>(values: &[T]) -> Buffer<T> {
200 values
201 .windows(2)
202 .map(|window| window[1] - window[0])
203 .collect()
204}
205
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::{BoolArray, ConstantArray, ConstantVTable, ListArray, VarBinArray};
    use crate::canonical::ToCanonical;
    use crate::compute::list_contains;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{ArrayRef, IntoArray};

    /// Builds a list array of non-nullable UTF-8 strings using `u64` offsets.
    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
        ListArray::from_iter_slow::<u64, _>(values, Arc::new(DType::Utf8(Nullability::NonNullable)))
            .unwrap()
    }

    /// Builds a non-nullable list array whose elements are nullable UTF-8 strings.
    ///
    /// The elements are the flattened inner vectors. The offsets are the running sum of the
    /// inner vector lengths, with a leading `0`.
    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
        let elements = values.iter().flatten().cloned().collect_vec();
        let mut offsets = values
            .iter()
            .scan(0u64, |st, v| {
                *st += v.len() as u64;
                Some(*st)
            })
            .collect_vec();
        offsets.insert(0, 0u64);
        let offsets = Buffer::from_iter(offsets).into_array();

        let elements =
            VarBinArray::from_iter(elements, DType::Utf8(Nullability::Nullable)).into_array();

        ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array()
    }

    /// Builds a `BoolArray` from `values`.
    ///
    /// With `validity` set to `None`, the array is non-nullable. Otherwise its validity is built
    /// from the given flags via `Validity::from_iter`.
    fn bool_array(values: Vec<bool>, validity: Option<Vec<bool>>) -> BoolArray {
        let validity = match validity {
            None => Validity::NonNullable,
            Some(v) => Validity::from_iter(v),
        };

        BoolArray::new(values.into_iter().collect(), validity)
    }

    #[rstest]
    // "a" is in the second and third lists; the empty list never matches.
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
        Some("a"),
        bool_array(vec![false, true, true], None)
    )]
    // Searching for null: every list holding a `None` element matches.
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
        None,
        bool_array(vec![false, true, true], None)
    )]
    // Repeated identical non-empty lists after an empty one.
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a"]]),
        Some("a"),
        bool_array(vec![false, true, true], None)
    )]
    // Every list is empty, so nothing can match.
    #[case(
        nonnull_strings(vec![vec![], vec![], vec![]]),
        Some("a"),
        bool_array(vec![false, false, false], None)
    )]
    // The search value is absent from every list.
    #[case(
        nonnull_strings(vec![vec!["b"], vec![], vec!["b"]]),
        Some("a"),
        bool_array(vec![false, false, false], None)
    )]
    fn test_contains_nullable(
        #[case] list_array: ArrayRef,
        #[case] value: Option<&str>,
        #[case] expected: BoolArray,
    ) {
        // Non-null search values reuse the element nullability so the dtype check passes;
        // a null search value is a nullable UTF-8 null scalar.
        let element_nullability = list_array.dtype().as_list_element().unwrap().nullability();
        let scalar = match value {
            None => Scalar::null(DType::Utf8(Nullability::Nullable)),
            Some(v) => Scalar::utf8(v, element_nullability),
        };
        let result = list_contains(&list_array, scalar).expect("list_contains failed");
        let bool_result = result.to_bool().expect("to_bool failed");
        // Compare both the boolean values and the validity of the result.
        assert_eq!(
            bool_result.boolean_buffer().iter().collect_vec(),
            expected.boolean_buffer().iter().collect_vec()
        );
        assert_eq!(bool_result.validity(), expected.validity());
    }

    /// A constant list array should take the constant fast path and stay constant.
    #[test]
    fn test_constant_list() {
        // Two copies of the list [1, 2, 3].
        let list_array = ConstantArray::new(
            Scalar::list(
                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
                vec![1i32.into(), 2i32.into(), 3i32.into()],
                Nullability::NonNullable,
            ),
            2,
        )
        .into_array();

        let contains = list_contains(&list_array, 2i32.into()).unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");
        assert_eq!(
            contains
                .to_bool()
                .unwrap()
                .boolean_buffer()
                .iter()
                .collect_vec(),
            vec![true, true],
        );
    }
}