vortex_sequence/compute/
list_contains.rs1use vortex_array::Array;
5use vortex_array::ArrayRef;
6use vortex_array::arrays::BoolArray;
7use vortex_array::scalar_fn::fns::list_contains::ListContainsElementReduce;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10
11use crate::array::SequenceVTable;
12use crate::compute::compare::find_intersection_scalar;
13
14impl ListContainsElementReduce for SequenceVTable {
15 fn list_contains(list: &ArrayRef, element: &Self::Array) -> VortexResult<Option<ArrayRef>> {
16 let Some(list_scalar) = list.as_constant() else {
17 return Ok(None);
18 };
19
20 let list_elements = list_scalar
21 .as_list()
22 .elements()
23 .vortex_expect("non-null element (checked in entry)");
24
25 let mut set_indices: Vec<usize> = Vec::new();
26 for intercept in list_elements.iter() {
27 let Some(intercept) = intercept.as_primitive().pvalue() else {
28 continue;
29 };
30 if let Ok(intersection) = find_intersection_scalar(
31 element.base(),
32 element.multiplier(),
33 element.len(),
34 intercept,
35 ) {
36 set_indices.push(intersection)
37 }
38 }
39
40 let nullability = list.dtype().nullability() | element.dtype().nullability();
41
42 Ok(Some(
43 BoolArray::from_indices(element.len(), set_indices, nullability.into()).to_array(),
44 ))
45 }
46}
47
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::Array;
    use vortex_array::arrays::BoolArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType::I32;
    use vortex_array::expr::list_contains;
    use vortex_array::expr::lit;
    use vortex_array::expr::root;
    use vortex_array::scalar::Scalar;

    use crate::SequenceArray;

    /// Probes several three-element sequences starting at 1 against the constant list [1, 3].
    #[test]
    fn test_list_contains_seq() {
        let needles = Scalar::list(
            Arc::new(I32.into()),
            vec![1.into(), 3.into()],
            Nullability::Nullable,
        );

        // Each case is (multiplier, expected result) for a sequence with base 1 and length 3.
        let cases = [
            // Sequence 1, 2, 3: only the first and last values appear in the list.
            (1, [Some(true), Some(false), Some(true)]),
            // Sequence 1, 3, 5: the first two values appear in the list.
            (2, [Some(true), Some(true), Some(false)]),
        ];

        for (multiplier, expected) in cases {
            let array =
                SequenceArray::try_new_typed(1, multiplier, Nullability::NonNullable, 3).unwrap();
            let expr = list_contains(lit(needles.clone()), root());
            let result = array.apply(&expr).unwrap();
            let want = BoolArray::from_iter(expected);
            assert_arrays_eq!(result, want);
        }
    }
}