Skip to main content

vortex_sequence/compute/
list_contains.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::Array;
5use vortex_array::ArrayRef;
6use vortex_array::arrays::BoolArray;
7use vortex_array::compute::ListContainsKernel;
8use vortex_array::compute::ListContainsKernelAdapter;
9use vortex_array::register_kernel;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12
13use crate::array::SequenceVTable;
14use crate::compute::compare::find_intersection_scalar;
15
16impl ListContainsKernel for SequenceVTable {
17    fn list_contains(
18        &self,
19        list: &dyn Array,
20        element: &Self::Array,
21    ) -> VortexResult<Option<ArrayRef>> {
22        let Some(list_scalar) = list.as_constant() else {
23            return Ok(None);
24        };
25
26        let list_elements = list_scalar
27            .as_list()
28            .elements()
29            .vortex_expect("non-null element (checked in entry)");
30
31        let mut set_indices: Vec<usize> = Vec::new();
32        for intercept in list_elements.iter() {
33            let Some(intercept) = intercept.as_primitive().pvalue() else {
34                continue;
35            };
36            if let Ok(intersection) = find_intersection_scalar(
37                element.base(),
38                element.multiplier(),
39                element.len(),
40                intercept,
41            ) {
42                set_indices.push(intersection)
43            }
44        }
45
46        let nullability = list.dtype().nullability() | element.dtype().nullability();
47
48        Ok(Some(
49            BoolArray::from_indices(element.len(), set_indices, nullability.into()).to_array(),
50        ))
51    }
52}
53
54register_kernel!(ListContainsKernelAdapter(SequenceVTable).lift());
55
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::list_contains;
    use vortex_array::scalar::Scalar;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;

    use crate::SequenceArray;

    /// Asserts that `list_contains` matches `expected` row by row for a fixed haystack and a
    /// generated needle sequence.
    ///
    /// - Haystack: a length-3 constant array whose every row is the nullable `i32` list
    ///   `[1, 3]`.
    /// - Needle: the non-nullable length-3 sequence `base, base + multiplier,
    ///   base + 2 * multiplier`.
    fn assert_contains_1_3(base: i32, multiplier: i32, expected: [bool; 3]) {
        let lists = ConstantArray::new(
            Scalar::list(
                Arc::new(I32.into()),
                vec![1.into(), 3.into()],
                Nullability::Nullable,
            ),
            3,
        );
        let sequence =
            SequenceArray::typed_new(base, multiplier, Nullability::NonNullable, 3).unwrap();

        let result = list_contains(lists.as_ref(), sequence.as_ref()).unwrap();
        let expected = BoolArray::from_iter(expected.map(Some));
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_list_contains_seq() {
        // [1, 3] in  1
        //            2
        //            3
        assert_contains_1_3(1, 1, [true, false, true]);

        // [1, 3] in  1
        //            3
        //            5
        assert_contains_1_3(1, 2, [true, true, false]);
    }

    #[test]
    fn test_list_contains_seq_disjoint() {
        // No list value falls inside the sequence, so every intersection lookup misses.
        // [1, 3] in  10
        //            11
        //            12
        assert_contains_1_3(10, 1, [false, false, false]);
    }
}