vortex_sequence/compute/
list_contains.rs1use vortex_array::Array;
5use vortex_array::ArrayRef;
6use vortex_array::arrays::BoolArray;
7use vortex_array::compute::ListContainsKernel;
8use vortex_array::compute::ListContainsKernelAdapter;
9use vortex_array::register_kernel;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12
13use crate::array::SequenceVTable;
14use crate::compute::compare::find_intersection_scalar;
15
16impl ListContainsKernel for SequenceVTable {
17 fn list_contains(
18 &self,
19 list: &dyn Array,
20 element: &Self::Array,
21 ) -> VortexResult<Option<ArrayRef>> {
22 let Some(list_scalar) = list.as_constant() else {
23 return Ok(None);
24 };
25
26 let list_elements = list_scalar
27 .as_list()
28 .elements()
29 .vortex_expect("non-null element (checked in entry)");
30
31 let mut set_indices: Vec<usize> = Vec::new();
32 for intercept in list_elements.iter() {
33 let Some(intercept) = intercept.as_primitive().pvalue() else {
34 continue;
35 };
36 if let Ok(intersection) = find_intersection_scalar(
37 element.base(),
38 element.multiplier(),
39 element.len(),
40 intercept,
41 ) {
42 set_indices.push(intersection)
43 }
44 }
45
46 let nullability = list.dtype().nullability() | element.dtype().nullability();
47
48 Ok(Some(
49 BoolArray::from_indices(element.len(), set_indices, nullability.into()).to_array(),
50 ))
51 }
52}
53
54register_kernel!(ListContainsKernelAdapter(SequenceVTable).lift());
55
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::list_contains;
    use vortex_array::scalar::Scalar;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;

    use crate::SequenceArray;

    /// Builds a constant, nullable `List<i32>` array of length `len` whose every row is `values`.
    fn constant_i32_list(values: &[i32], len: usize) -> ConstantArray {
        ConstantArray::new(
            Scalar::list(
                Arc::new(I32.into()),
                values.iter().copied().map(Scalar::from).collect::<Vec<_>>(),
                Nullability::Nullable,
            ),
            len,
        )
    }

    /// Checks `list_contains(list, base + i * multiplier)` row-by-row against `expected`.
    ///
    /// The sequence length is taken from `expected`. The list side is nullable, so the expected
    /// result is built as a nullable `BoolArray` with every row valid.
    fn assert_seq_contains(list: &[i32], base: i32, multiplier: i32, expected: &[bool]) {
        let len = expected.len();
        let haystack = constant_i32_list(list, len);
        let needles =
            SequenceArray::typed_new(base, multiplier, Nullability::NonNullable, len).unwrap();

        let result = list_contains(haystack.as_ref(), needles.as_ref()).unwrap();
        let expected = BoolArray::from_iter(expected.iter().copied().map(Some));
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_list_contains_seq() {
        // Sequence [1, 2, 3] contains list values 1 and 3.
        assert_seq_contains(&[1, 3], 1, 1, &[true, false, true]);
        // Sequence [1, 3, 5] contains list values 1 and 3.
        assert_seq_contains(&[1, 3], 1, 2, &[true, true, false]);
    }

    #[test]
    fn test_list_contains_seq_no_match() {
        // Values below, above, and far outside [1, 3] never match any row.
        assert_seq_contains(&[0, 4, 10], 1, 1, &[false, false, false]);
        // 2 and 4 lie inside [1, 5] but fall between the steps of [1, 3, 5]; 7 is past the end.
        assert_seq_contains(&[2, 4, 7], 1, 2, &[false, false, false]);
    }
}