Skip to main content

vortex_sequence/compute/
list_contains.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::Array;
5use vortex_array::ArrayRef;
6use vortex_array::arrays::BoolArray;
7use vortex_array::scalar_fn::fns::list_contains::ListContainsElementReduce;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10
11use crate::array::SequenceVTable;
12use crate::compute::compare::find_intersection_scalar;
13
14impl ListContainsElementReduce for SequenceVTable {
15    fn list_contains(list: &ArrayRef, element: &Self::Array) -> VortexResult<Option<ArrayRef>> {
16        let Some(list_scalar) = list.as_constant() else {
17            return Ok(None);
18        };
19
20        let list_elements = list_scalar
21            .as_list()
22            .elements()
23            .vortex_expect("non-null element (checked in entry)");
24
25        let mut set_indices: Vec<usize> = Vec::new();
26        for intercept in list_elements.iter() {
27            let Some(intercept) = intercept.as_primitive().pvalue() else {
28                continue;
29            };
30            if let Ok(intersection) = find_intersection_scalar(
31                element.base(),
32                element.multiplier(),
33                element.len(),
34                intercept,
35            ) {
36                set_indices.push(intersection)
37            }
38        }
39
40        let nullability = list.dtype().nullability() | element.dtype().nullability();
41
42        Ok(Some(
43            BoolArray::from_indices(element.len(), set_indices, nullability.into()).to_array(),
44        ))
45    }
46}
47
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::Array;
    use vortex_array::arrays::BoolArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType::I32;
    use vortex_array::expr::list_contains;
    use vortex_array::expr::lit;
    use vortex_array::expr::root;
    use vortex_array::scalar::Scalar;

    use crate::SequenceArray;

    /// Checks `list_contains([1, 3], <sequence>)` against three-element sequences that
    /// start at 1 and differ only in their multiplier.
    #[test]
    fn test_list_contains_seq() {
        let needles = Scalar::list(
            Arc::new(I32.into()),
            vec![1.into(), 3.into()],
            Nullability::Nullable,
        );

        // Each case is (multiplier, expected membership per position).
        let cases = [
            // [1, 3] in  1, 2, 3
            (1, [Some(true), Some(false), Some(true)]),
            // [1, 3] in  1, 3, 5
            (2, [Some(true), Some(true), Some(false)]),
        ];

        for (multiplier, expected_bits) in cases {
            let array =
                SequenceArray::try_new_typed(1, multiplier, Nullability::NonNullable, 3).unwrap();

            let expr = list_contains(lit(needles.clone()), root());
            let result = array.apply(&expr).unwrap();
            let expected = BoolArray::from_iter(expected_bits);
            assert_arrays_eq!(result, expected);
        }
    }
}