vortex_sequence/compute/
list_contains.rs1use vortex_array::Array;
5use vortex_array::ArrayRef;
6use vortex_array::arrays::BoolArray;
7use vortex_array::compute::ListContainsKernel;
8use vortex_array::compute::ListContainsKernelAdapter;
9use vortex_array::register_kernel;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12
13use crate::array::SequenceVTable;
14use crate::compute::compare::find_intersection_scalar;
15
16impl ListContainsKernel for SequenceVTable {
17 fn list_contains(
18 &self,
19 list: &dyn Array,
20 element: &Self::Array,
21 ) -> VortexResult<Option<ArrayRef>> {
22 let Some(list_scalar) = list.as_constant() else {
23 return Ok(None);
24 };
25
26 let list_elements = list_scalar
27 .as_list()
28 .elements()
29 .vortex_expect("non-null element (checked in entry)");
30
31 let set_indices = list_elements
32 .iter()
33 .flat_map(|elem| {
34 elem.as_primitive().pvalue().and_then(|intercept| {
35 find_intersection_scalar(
36 element.base(),
37 element.multiplier(),
38 element.len(),
39 intercept,
40 )
41 })
42 })
43 .collect::<Vec<_>>();
44
45 let nullability = list.dtype().nullability() | element.dtype().nullability();
46
47 Ok(Some(
48 BoolArray::from_indices(element.len(), set_indices, nullability.into()).to_array(),
49 ))
50 }
51}
52
53register_kernel!(ListContainsKernelAdapter(SequenceVTable).lift());
54
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::ToCanonical;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::compute::list_contains;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_scalar::Scalar;

    use crate::SequenceArray;

    /// Checks `list_contains` for a constant list `[1, 3]` against two three-element
    /// sequences that both start at 1 but use different multipliers.
    #[test]
    fn test_list_contains_seq() {
        let needles = Scalar::list(
            Arc::new(I32.into()),
            vec![1.into(), 3.into()],
            Nullability::Nullable,
        );
        let elements = ConstantArray::new(needles, 3);

        // Runs the compute function against a sequence starting at 1 with the given
        // multiplier and returns the resulting booleans.
        let membership = |multiplier: i32| {
            let sequence =
                SequenceArray::typed_new(1, multiplier, Nullability::NonNullable, 3).unwrap();

            list_contains(elements.as_ref(), sequence.as_ref())
                .unwrap()
                .to_bool()
                .bool_vec()
        };

        // Multiplier 1: the first and third positions are reported as contained.
        assert_eq!(membership(1), vec![true, false, true]);

        // Multiplier 2: the first and second positions are reported as contained.
        assert_eq!(membership(2), vec![true, true, false]);
    }
}