vortex_sequence/compute/
list_contains.rs1use vortex_array::arrays::BoolArray;
2use vortex_array::compute::{ListContainsKernel, ListContainsKernelAdapter};
3use vortex_array::{Array, ArrayRef, register_kernel};
4use vortex_error::{VortexExpect, VortexResult};
5
6use crate::array::SequenceVTable;
7use crate::compute::compare::find_intersection_scalar;
8
9impl ListContainsKernel for SequenceVTable {
10 fn list_contains(
11 &self,
12 list: &dyn Array,
13 element: &Self::Array,
14 ) -> VortexResult<Option<ArrayRef>> {
15 let Some(list_scalar) = list.as_constant() else {
16 return Ok(None);
17 };
18
19 let list_elements = list_scalar
20 .as_list()
21 .elements()
22 .vortex_expect("non-null element (checked in entry)");
23
24 let set_indices = list_elements
25 .iter()
26 .flat_map(|elem| {
27 elem.as_primitive().pvalue().and_then(|intercept| {
28 find_intersection_scalar(
29 element.base(),
30 element.multiplier(),
31 element.len(),
32 intercept,
33 )
34 })
35 })
36 .collect::<Vec<_>>();
37
38 let nullability = list.dtype().nullability() | element.dtype().nullability();
39
40 Ok(Some(
41 BoolArray::from_indices(element.len(), set_indices, nullability.into()).to_array(),
42 ))
43 }
44}
45
// Registers the `ListContainsKernel` impl for `SequenceVTable` through `register_kernel!`,
// wrapped in `ListContainsKernelAdapter`. This presumably makes it discoverable by
// `vortex_array::compute::list_contains` dispatch.
// NOTE(review): `.lift()` presumably converts the adapter into the form the registry stores;
// confirm against vortex_array's `register_kernel!` definition.
register_kernel!(ListContainsKernelAdapter(SequenceVTable).lift());
47
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::ToCanonical;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::compute::list_contains;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_scalar::Scalar;

    use crate::SequenceArray;

    /// Probes two sequences of length 3 against the nullable constant list `[1, 3]`.
    #[test]
    fn test_list_contains_seq() {
        // The same list `[1, 3]` is repeated for each of the three rows.
        let haystack = ConstantArray::new(
            Scalar::list(
                Arc::new(I32.into()),
                vec![1.into(), 3.into()],
                Nullability::Nullable,
            ),
            3,
        );

        // (base, multiplier, expected per-row membership in `[1, 3]`)
        let cases = [
            (1, 1, vec![true, false, true]), // values 1, 2, 3
            (1, 2, vec![true, true, false]), // values 1, 3, 5
        ];

        for (base, multiplier, expected) in cases {
            let sequence = SequenceArray::typed_new(base, multiplier, 3).unwrap();

            let got = list_contains(haystack.as_ref(), sequence.as_ref())
                .unwrap()
                .to_bool()
                .unwrap()
                .bool_vec()
                .unwrap();

            assert_eq!(got, expected, "base={base}, multiplier={multiplier}");
        }
    }
}