vortex_sequence/compute/list_contains.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::Array;
5use vortex_array::ArrayRef;
6use vortex_array::arrays::BoolArray;
7use vortex_array::compute::ListContainsKernel;
8use vortex_array::compute::ListContainsKernelAdapter;
9use vortex_array::register_kernel;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12
13use crate::array::SequenceVTable;
14use crate::compute::compare::find_intersection_scalar;
15
16impl ListContainsKernel for SequenceVTable {
17    fn list_contains(
18        &self,
19        list: &dyn Array,
20        element: &Self::Array,
21    ) -> VortexResult<Option<ArrayRef>> {
22        let Some(list_scalar) = list.as_constant() else {
23            return Ok(None);
24        };
25
26        let list_elements = list_scalar
27            .as_list()
28            .elements()
29            .vortex_expect("non-null element (checked in entry)");
30
31        let set_indices = list_elements
32            .iter()
33            .flat_map(|elem| {
34                elem.as_primitive().pvalue().and_then(|intercept| {
35                    find_intersection_scalar(
36                        element.base(),
37                        element.multiplier(),
38                        element.len(),
39                        intercept,
40                    )
41                })
42            })
43            .collect::<Vec<_>>();
44
45        let nullability = list.dtype().nullability() | element.dtype().nullability();
46
47        Ok(Some(
48            BoolArray::from_indices(element.len(), set_indices, nullability.into()).to_array(),
49        ))
50    }
51}
52
53register_kernel!(ListContainsKernelAdapter(SequenceVTable).lift());
54
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::ToCanonical;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::compute::list_contains;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_scalar::Scalar;

    use crate::SequenceArray;

    /// Runs `list_contains(haystack, seq)` where `seq` is the three-element, non-nullable
    /// sequence `base, base + step, base + 2 * step`, and returns the per-row booleans.
    fn contains_over_sequence(haystack: &ConstantArray, base: i32, step: i32) -> Vec<bool> {
        let seq = SequenceArray::typed_new(base, step, Nullability::NonNullable, 3).unwrap();
        list_contains(haystack.as_ref(), seq.as_ref())
            .unwrap()
            .to_bool()
            .bool_vec()
    }

    #[test]
    fn test_list_contains_seq() {
        // Every one of the three rows sees the same constant list [1, 3].
        let haystack = ConstantArray::new(
            Scalar::list(
                Arc::new(I32.into()),
                vec![1.into(), 3.into()],
                Nullability::Nullable,
            ),
            3,
        );

        // Sequence 1, 2, 3: rows holding 1 and 3 are in the list, 2 is not.
        assert_eq!(contains_over_sequence(&haystack, 1, 1), vec![true, false, true]);

        // Sequence 1, 3, 5: rows holding 1 and 3 are in the list, 5 is not.
        assert_eq!(contains_over_sequence(&haystack, 1, 2), vec![true, true, false]);
    }
}