Skip to main content

vortex_sequence/compute/
list_contains.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::DynArray;
6use vortex_array::IntoArray;
7use vortex_array::arrays::BoolArray;
8use vortex_array::scalar_fn::fns::list_contains::ListContainsElementReduce;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11
12use crate::array::SequenceVTable;
13use crate::compute::compare::find_intersection_scalar;
14
15impl ListContainsElementReduce for SequenceVTable {
16    fn list_contains(list: &ArrayRef, element: &Self::Array) -> VortexResult<Option<ArrayRef>> {
17        let Some(list_scalar) = list.as_constant() else {
18            return Ok(None);
19        };
20
21        let list_elements = list_scalar
22            .as_list()
23            .elements()
24            .vortex_expect("non-null element (checked in entry)");
25
26        let mut set_indices: Vec<usize> = Vec::new();
27        for intercept in list_elements.iter() {
28            let Some(intercept) = intercept.as_primitive().pvalue() else {
29                continue;
30            };
31            if let Ok(intersection) = find_intersection_scalar(
32                element.base(),
33                element.multiplier(),
34                element.len(),
35                intercept,
36            ) {
37                set_indices.push(intersection)
38            }
39        }
40
41        let nullability = list.dtype().nullability() | element.dtype().nullability();
42
43        Ok(Some(
44            BoolArray::from_indices(element.len(), set_indices, nullability.into()).into_array(),
45        ))
46    }
47}
48
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::DynArray;
    use vortex_array::arrays::BoolArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType::I32;
    use vortex_array::expr::list_contains;
    use vortex_array::expr::lit;
    use vortex_array::expr::root;
    use vortex_array::scalar::Scalar;

    use crate::SequenceArray;

    /// Checks `list_contains([1, 3], seq)` against several three-element sequences.
    #[test]
    fn test_list_contains_seq() {
        // Nullable constant list literal `[1, 3]` probed against every sequence below.
        let needles = Scalar::list(
            Arc::new(I32.into()),
            vec![1.into(), 3.into()],
            Nullability::Nullable,
        );

        // (base, multiplier, expected membership of each of the three sequence values)
        let cases = [
            // sequence 1, 2, 3: 1 and 3 are in the list, 2 is not
            (1, 1, [Some(true), Some(false), Some(true)]),
            // sequence 1, 3, 5: 1 and 3 are in the list, 5 is not
            (1, 2, [Some(true), Some(true), Some(false)]),
        ];

        for (base, multiplier, membership) in cases {
            let array =
                SequenceArray::try_new_typed(base, multiplier, Nullability::NonNullable, 3)
                    .unwrap();

            let contains = list_contains(lit(needles.clone()), root());
            let result = array.apply(&contains).unwrap();

            let expected = BoolArray::from_iter(membership);
            assert_arrays_eq!(result, expected);
        }
    }
}