Skip to main content

vortex_sequence/compute/
list_contains.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::IntoArray;
7use vortex_array::arrays::BoolArray;
8use vortex_array::scalar_fn::fns::list_contains::ListContainsElementReduce;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11
12use crate::array::Sequence;
13use crate::compute::compare::find_intersection_scalar;
14
15impl ListContainsElementReduce for Sequence {
16    fn list_contains(
17        list: &ArrayRef,
18        element: ArrayView<'_, Self>,
19    ) -> VortexResult<Option<ArrayRef>> {
20        let Some(list_scalar) = list.as_constant() else {
21            return Ok(None);
22        };
23
24        let list_elements = list_scalar
25            .as_list()
26            .elements()
27            .vortex_expect("non-null element (checked in entry)");
28
29        let mut set_indices: Vec<usize> = Vec::new();
30        for intercept in list_elements.iter() {
31            let Some(intercept) = intercept.as_primitive().pvalue() else {
32                continue;
33            };
34            if let Ok(intersection) = find_intersection_scalar(
35                element.base(),
36                element.multiplier(),
37                element.len(),
38                intercept,
39            ) {
40                set_indices.push(intersection)
41            }
42        }
43
44        let nullability = list.dtype().nullability() | element.dtype().nullability();
45
46        Ok(Some(
47            BoolArray::from_indices(element.len(), set_indices, nullability.into()).into_array(),
48        ))
49    }
50}
51
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::IntoArray;
    use vortex_array::arrays::BoolArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType::I32;
    use vortex_array::expr::list_contains;
    use vortex_array::expr::lit;
    use vortex_array::expr::root;
    use vortex_array::scalar::Scalar;

    use crate::Sequence;

    /// Checks `list_contains([1, 3], seq)` against several 3-element sequences.
    #[test]
    fn test_list_contains_seq() {
        // The constant list probed in every case below.
        let list_scalar = Scalar::list(
            Arc::new(I32.into()),
            vec![1.into(), 3.into()],
            Nullability::Nullable,
        );

        // Each case pairs a sequence array with the expected membership mask.
        let cases = [
            // [1, 3] in [1, 2, 3]  (base 1, multiplier 1)
            (
                Sequence::try_new_typed(1, 1, Nullability::NonNullable, 3)
                    .unwrap()
                    .into_array(),
                BoolArray::from_iter([Some(true), Some(false), Some(true)]),
            ),
            // [1, 3] in [1, 3, 5]  (base 1, multiplier 2)
            (
                Sequence::try_new_typed(1, 2, Nullability::NonNullable, 3)
                    .unwrap()
                    .into_array(),
                BoolArray::from_iter([Some(true), Some(true), Some(false)]),
            ),
        ];

        for (array, expected) in cases {
            let expr = list_contains(lit(list_scalar.clone()), root());
            let result = array.apply(&expr).unwrap();
            assert_arrays_eq!(result, expected);
        }
    }
}