vortex_sequence/compute/list_contains.rs

1use vortex_array::arrays::BoolArray;
2use vortex_array::compute::{ListContainsKernel, ListContainsKernelAdapter};
3use vortex_array::{Array, ArrayRef, register_kernel};
4use vortex_error::{VortexExpect, VortexResult};
5
6use crate::array::SequenceVTable;
7use crate::compute::compare::find_intersection_scalar;
8
9impl ListContainsKernel for SequenceVTable {
10    fn list_contains(
11        &self,
12        list: &dyn Array,
13        element: &Self::Array,
14    ) -> VortexResult<Option<ArrayRef>> {
15        let Some(list_scalar) = list.as_constant() else {
16            return Ok(None);
17        };
18
19        let list_elements = list_scalar
20            .as_list()
21            .elements()
22            .vortex_expect("non-null element (checked in entry)");
23
24        let set_indices = list_elements
25            .iter()
26            .flat_map(|elem| {
27                elem.as_primitive().pvalue().and_then(|intercept| {
28                    find_intersection_scalar(
29                        element.base(),
30                        element.multiplier(),
31                        element.len(),
32                        intercept,
33                    )
34                })
35            })
36            .collect::<Vec<_>>();
37
38        let nullability = list.dtype().nullability() | element.dtype().nullability();
39
40        Ok(Some(
41            BoolArray::from_indices(element.len(), set_indices, nullability.into()).to_array(),
42        ))
43    }
44}
45
46register_kernel!(ListContainsKernelAdapter(SequenceVTable).lift());
47
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::arrays::ConstantArray;
    use vortex_array::compute::list_contains;
    use vortex_array::{Array, ToCanonical};
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_scalar::Scalar;

    use crate::SequenceArray;

    /// Runs `list_contains` and materialises the boolean result as a `Vec<bool>`.
    fn contains(list: &dyn Array, sequence: &dyn Array) -> Vec<bool> {
        list_contains(list, sequence)
            .unwrap()
            .to_bool()
            .unwrap()
            .bool_vec()
            .unwrap()
    }

    #[test]
    fn test_list_contains_seq() {
        // The constant list [1, 3], repeated for each of three rows.
        let list = ConstantArray::new(
            Scalar::list(
                Arc::new(I32.into()),
                vec![1.into(), 3.into()],
                Nullability::Nullable,
            ),
            3,
        );

        // Sequence 1, 2, 3: the rows holding 1 and 3 are in the list.
        let step_one = SequenceArray::typed_new(1, 1, 3).unwrap();
        assert_eq!(
            contains(list.as_ref(), step_one.as_ref()),
            vec![true, false, true]
        );

        // Sequence 1, 3, 5: the rows holding 1 and 3 are in the list.
        let step_two = SequenceArray::typed_new(1, 2, 3).unwrap();
        assert_eq!(
            contains(list.as_ref(), step_two.as_ref()),
            vec![true, true, false]
        );
    }
}