vortex_sequence/compute/
list_contains.rs1use vortex_array::ArrayRef;
5use vortex_array::DynArray;
6use vortex_array::IntoArray;
7use vortex_array::arrays::BoolArray;
8use vortex_array::scalar_fn::fns::list_contains::ListContainsElementReduce;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11
12use crate::array::Sequence;
13use crate::compute::compare::find_intersection_scalar;
14
15impl ListContainsElementReduce for Sequence {
16 fn list_contains(list: &ArrayRef, element: &Self::Array) -> VortexResult<Option<ArrayRef>> {
17 let Some(list_scalar) = list.as_constant() else {
18 return Ok(None);
19 };
20
21 let list_elements = list_scalar
22 .as_list()
23 .elements()
24 .vortex_expect("non-null element (checked in entry)");
25
26 let mut set_indices: Vec<usize> = Vec::new();
27 for intercept in list_elements.iter() {
28 let Some(intercept) = intercept.as_primitive().pvalue() else {
29 continue;
30 };
31 if let Ok(intersection) = find_intersection_scalar(
32 element.base(),
33 element.multiplier(),
34 element.len(),
35 intercept,
36 ) {
37 set_indices.push(intersection)
38 }
39 }
40
41 let nullability = list.dtype().nullability() | element.dtype().nullability();
42
43 Ok(Some(
44 BoolArray::from_indices(element.len(), set_indices, nullability.into()).into_array(),
45 ))
46 }
47}
48
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::DynArray;
    use vortex_array::arrays::BoolArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType::I32;
    use vortex_array::expr::list_contains;
    use vortex_array::expr::lit;
    use vortex_array::expr::root;
    use vortex_array::scalar::Scalar;

    use crate::SequenceArray;

    /// Applies `list_contains(lit(needles), root())` to `array` and asserts the
    /// result equals `expected`.
    fn assert_list_contains(array: SequenceArray, needles: Scalar, expected: BoolArray) {
        let expr = list_contains(lit(needles), root());
        let result = array.apply(&expr).unwrap();
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_list_contains_seq() {
        // Nullable i32 list [1, 3], searched for in every case below.
        let needles = Scalar::list(
            Arc::new(I32.into()),
            vec![1.into(), 3.into()],
            Nullability::Nullable,
        );

        // base 1, step 1 => [1, 2, 3]: only the middle value is missing from the list.
        assert_list_contains(
            SequenceArray::try_new_typed(1, 1, Nullability::NonNullable, 3).unwrap(),
            needles.clone(),
            BoolArray::from_iter([Some(true), Some(false), Some(true)]),
        );

        // base 1, step 2 => [1, 3, 5]: only the last value is missing from the list.
        assert_list_contains(
            SequenceArray::try_new_typed(1, 2, Nullability::NonNullable, 3).unwrap(),
            needles,
            BoolArray::from_iter([Some(true), Some(true), Some(false)]),
        );
    }
}