vortex_sequence/compute/
list_contains.rs1use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::IntoArray;
7use vortex_array::arrays::BoolArray;
8use vortex_array::scalar_fn::fns::list_contains::ListContainsElementReduce;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11
12use crate::array::Sequence;
13use crate::compute::compare::find_intersection_scalar;
14
15impl ListContainsElementReduce for Sequence {
16 fn list_contains(
17 list: &ArrayRef,
18 element: ArrayView<'_, Self>,
19 ) -> VortexResult<Option<ArrayRef>> {
20 let Some(list_scalar) = list.as_constant() else {
21 return Ok(None);
22 };
23
24 let list_elements = list_scalar
25 .as_list()
26 .elements()
27 .vortex_expect("non-null element (checked in entry)");
28
29 let mut set_indices: Vec<usize> = Vec::new();
30 for intercept in list_elements.iter() {
31 let Some(intercept) = intercept.as_primitive().pvalue() else {
32 continue;
33 };
34 if let Ok(intersection) = find_intersection_scalar(
35 element.base(),
36 element.multiplier(),
37 element.len(),
38 intercept,
39 ) {
40 set_indices.push(intersection)
41 }
42 }
43
44 let nullability = list.dtype().nullability() | element.dtype().nullability();
45
46 Ok(Some(
47 BoolArray::from_indices(element.len(), set_indices, nullability.into()).into_array(),
48 ))
49 }
50}
51
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::IntoArray;
    use vortex_array::arrays::BoolArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType::I32;
    use vortex_array::expr::list_contains;
    use vortex_array::expr::lit;
    use vortex_array::expr::root;
    use vortex_array::scalar::Scalar;

    use crate::Sequence;

    /// Evaluates `list_contains([1, 3], <sequence>)` against two sequences that both start
    /// at 1 and have length 3, one with multiplier 1 and one with multiplier 2.
    #[test]
    fn test_list_contains_seq() {
        // The constant list being searched: a nullable list of i32 holding 1 and 3.
        let needles = Scalar::list(
            Arc::new(I32.into()),
            vec![1.into(), 3.into()],
            Nullability::Nullable,
        );
        // The expression is identical for every case, so it is built a single time.
        let contains_expr = list_contains(lit(needles), root());

        let multiplier_one = Sequence::try_new_typed(1, 1, Nullability::NonNullable, 3)
            .unwrap()
            .into_array();
        let multiplier_two = Sequence::try_new_typed(1, 2, Nullability::NonNullable, 3)
            .unwrap()
            .into_array();

        // Each case pairs an input array with the membership flags expected for it.
        let cases = [
            (multiplier_one, [Some(true), Some(false), Some(true)]),
            (multiplier_two, [Some(true), Some(true), Some(false)]),
        ];

        for (array, expected_flags) in cases {
            let result = array.into_array().apply(&contains_expr).unwrap();
            let expected = BoolArray::from_iter(expected_flags);
            assert_arrays_eq!(result, expected);
        }
    }
}