vortex_sequence/compute/
list_contains.rs1use vortex_array::arrays::BoolArray;
5use vortex_array::compute::{ListContainsKernel, ListContainsKernelAdapter};
6use vortex_array::{Array, ArrayRef, register_kernel};
7use vortex_error::{VortexExpect, VortexResult};
8
9use crate::array::SequenceVTable;
10use crate::compute::compare::find_intersection_scalar;
11
12impl ListContainsKernel for SequenceVTable {
13 fn list_contains(
14 &self,
15 list: &dyn Array,
16 element: &Self::Array,
17 ) -> VortexResult<Option<ArrayRef>> {
18 let Some(list_scalar) = list.as_constant() else {
19 return Ok(None);
20 };
21
22 let list_elements = list_scalar
23 .as_list()
24 .elements()
25 .vortex_expect("non-null element (checked in entry)");
26
27 let set_indices = list_elements
28 .iter()
29 .flat_map(|elem| {
30 elem.as_primitive().pvalue().and_then(|intercept| {
31 find_intersection_scalar(
32 element.base(),
33 element.multiplier(),
34 element.len(),
35 intercept,
36 )
37 })
38 })
39 .collect::<Vec<_>>();
40
41 let nullability = list.dtype().nullability() | element.dtype().nullability();
42
43 Ok(Some(
44 BoolArray::from_indices(element.len(), set_indices, nullability.into()).to_array(),
45 ))
46 }
47}
48
// Registers the sequence `list_contains` kernel (wrapped in `ListContainsKernelAdapter`) through
// the `register_kernel!` macro. NOTE(review): `.lift()` presumably adapts the adapter to the
// registry's expected kernel type; confirm against the `register_kernel!` definition.
register_kernel!(ListContainsKernelAdapter(SequenceVTable).lift());
50
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::ToCanonical;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::compute::list_contains;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_scalar::Scalar;

    use crate::SequenceArray;

    #[test]
    fn test_list_contains_seq() {
        // Three rows, each holding the same nullable i32 list `[1, 3]`.
        let list = ConstantArray::new(
            Scalar::list(
                Arc::new(I32.into()),
                vec![1.into(), 3.into()],
                Nullability::Nullable,
            ),
            3,
        );

        // Runs `list_contains` against `list` for the given sequence and decodes the
        // result into plain booleans, one per row.
        let contains_per_row = |seq: SequenceArray| {
            list_contains(list.as_ref(), seq.as_ref())
                .unwrap()
                .to_bool()
                .unwrap()
                .bool_vec()
                .unwrap()
        };

        // Sequence 1, 2, 3: the first and last rows hold list values.
        let unit_step = SequenceArray::typed_new(1, 1, Nullability::NonNullable, 3).unwrap();
        assert_eq!(contains_per_row(unit_step), vec![true, false, true]);

        // Sequence 1, 3, 5: the first two rows hold list values.
        let step_two = SequenceArray::typed_new(1, 2, Nullability::NonNullable, 3).unwrap();
        assert_eq!(contains_per_row(step_two), vec![true, true, false]);
    }
}