vortex_sequence/compute/
list_contains.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::BoolArray;
5use vortex_array::compute::{ListContainsKernel, ListContainsKernelAdapter};
6use vortex_array::{Array, ArrayRef, register_kernel};
7use vortex_error::{VortexExpect, VortexResult};
8
9use crate::array::SequenceVTable;
10use crate::compute::compare::find_intersection_scalar;
11
12impl ListContainsKernel for SequenceVTable {
13    fn list_contains(
14        &self,
15        list: &dyn Array,
16        element: &Self::Array,
17    ) -> VortexResult<Option<ArrayRef>> {
18        let Some(list_scalar) = list.as_constant() else {
19            return Ok(None);
20        };
21
22        let list_elements = list_scalar
23            .as_list()
24            .elements()
25            .vortex_expect("non-null element (checked in entry)");
26
27        let set_indices = list_elements
28            .iter()
29            .flat_map(|elem| {
30                elem.as_primitive().pvalue().and_then(|intercept| {
31                    find_intersection_scalar(
32                        element.base(),
33                        element.multiplier(),
34                        element.len(),
35                        intercept,
36                    )
37                })
38            })
39            .collect::<Vec<_>>();
40
41        let nullability = list.dtype().nullability() | element.dtype().nullability();
42
43        Ok(Some(
44            BoolArray::from_indices(element.len(), set_indices, nullability.into()).to_array(),
45        ))
46    }
47}
48
// Register the sequence `ListContainsKernel` implementation above, wrapped in
// `ListContainsKernelAdapter`, with the compute kernel registry.
register_kernel!(ListContainsKernelAdapter(SequenceVTable).lift());
50
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::ToCanonical;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::compute::list_contains;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_scalar::Scalar;

    use crate::SequenceArray;

    /// Evaluates `list_contains` with the constant (nullable) list `[1, 3]` as
    /// the haystack and the length-3 `i32` sequence
    /// `base, base + multiplier, base + 2 * multiplier` as the needles.
    /// Returns the per-row result as plain booleans.
    fn contains_1_or_3(base: i32, multiplier: i32) -> Vec<bool> {
        let elements = ConstantArray::new(
            Scalar::list(
                Arc::new(I32.into()),
                vec![1.into(), 3.into()],
                Nullability::Nullable,
            ),
            3,
        );
        let array =
            SequenceArray::typed_new(base, multiplier, Nullability::NonNullable, 3).unwrap();

        list_contains(elements.as_ref(), array.as_ref())
            .unwrap()
            .to_bool()
            .unwrap()
            .bool_vec()
            .unwrap()
    }

    #[test]
    fn test_list_contains_seq() {
        // Sequence [1, 2, 3]: rows holding 1 and 3 match, the row holding 2 does not.
        assert_eq!(contains_1_or_3(1, 1), vec![true, false, true]);

        // Sequence [1, 3, 5]: rows holding 1 and 3 match, the row holding 5 does not.
        assert_eq!(contains_1_or_3(1, 2), vec![true, true, false]);
    }

    #[test]
    fn test_list_contains_seq_no_match() {
        // Sequence [2, 4, 6]: neither 1 nor 3 lies on the stride, so no row is set.
        assert_eq!(contains_1_or_3(2, 2), vec![false, false, false]);
    }
}