vortex_array/expr/exprs/
list_contains.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6use vortex_dtype::DType;
7use vortex_error::{VortexResult, vortex_bail};
8
9use crate::ArrayRef;
10use crate::compute::list_contains as compute_list_contains;
11use crate::expr::exprs::binary::{and, gt, lt, or};
12use crate::expr::exprs::literal::{Literal, lit};
13use crate::expr::{ChildName, ExprId, Expression, ExpressionView, StatsCatalog, VTable, VTableExt};
14
15pub struct ListContains;
16
17impl VTable for ListContains {
18    type Instance = ();
19
20    fn id(&self) -> ExprId {
21        ExprId::from("vortex.list.contains")
22    }
23
24    fn serialize(&self, _instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
25        Ok(Some(vec![]))
26    }
27
28    fn deserialize(&self, _metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
29        Ok(Some(()))
30    }
31
32    fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
33        if expr.children().len() != 2 {
34            vortex_bail!(
35                "ListContains expression requires exactly 2 children, got {}",
36                expr.children().len()
37            );
38        }
39        Ok(())
40    }
41
42    fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
43        match child_idx {
44            0 => ChildName::from("list"),
45            1 => ChildName::from("needle"),
46            _ => unreachable!(
47                "Invalid child index {} for ListContains expression",
48                child_idx
49            ),
50        }
51    }
52
53    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
54        write!(f, "contains(")?;
55        expr.child(0).fmt_sql(f)?;
56        write!(f, ", ")?;
57        expr.child(1).fmt_sql(f)?;
58        write!(f, ")")
59    }
60
61    fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
62        let list_dtype = expr.child(0).return_dtype(scope)?;
63        let value_dtype = expr.child(1).return_dtype(scope)?;
64
65        let nullability = match list_dtype {
66            DType::List(_, list_nullability) => list_nullability,
67            _ => {
68                vortex_bail!(
69                    "First argument to ListContains must be a List, got {:?}",
70                    list_dtype
71                );
72            }
73        } | value_dtype.nullability();
74
75        Ok(DType::Bool(nullability))
76    }
77
78    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
79        let list_array = expr.child(0).evaluate(scope)?;
80        let value_array = expr.child(1).evaluate(scope)?;
81        compute_list_contains(list_array.as_ref(), value_array.as_ref())
82    }
83
84    fn stat_falsification(
85        &self,
86        expr: &ExpressionView<Self>,
87        catalog: &mut dyn StatsCatalog,
88    ) -> Option<Expression> {
89        // falsification(contains([1,2,5], x)) =>
90        //   falsification(x != 1) and falsification(x != 2) and falsification(x != 5)
91        let min = expr.list().stat_min(catalog)?;
92        let max = expr.list().stat_max(catalog)?;
93        // If the list is constant when we can compare each element to the value
94        if min == max {
95            let list_ = min
96                .as_opt::<Literal>()
97                .and_then(|l| l.data().as_list_opt())
98                .and_then(|l| l.elements())?;
99            if list_.is_empty() {
100                // contains([], x) is always false.
101                return Some(lit(true));
102            }
103            let value_max = expr.needle().stat_max(catalog)?;
104            let value_min = expr.needle().stat_min(catalog)?;
105
106            return list_
107                .iter()
108                .map(move |v| {
109                    or(
110                        lt(value_max.clone(), lit(v.clone())),
111                        gt(value_min.clone(), lit(v.clone())),
112                    )
113                })
114                .reduce(and);
115        }
116
117        None
118    }
119}
120
121/// Creates an expression that checks if a value is contained in a list.
122///
123/// Returns a boolean array indicating whether the value appears in each list.
124///
125/// ```rust
126/// # use vortex_array::expr::{list_contains, lit, root};
127/// let expr = list_contains(root(), lit(42));
128/// ```
129pub fn list_contains(list: Expression, value: Expression) -> Expression {
130    ListContains.new_expr((), [list, value])
131}
132
133impl ExpressionView<'_, ListContains> {
134    pub fn list(&self) -> &Expression {
135        &self.children()[0]
136    }
137
138    pub fn needle(&self) -> &Expression {
139        &self.children()[1]
140    }
141}
142
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_buffer::BitBuffer;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Field, FieldPath, FieldPathSet, Nullability, StructFields};
    use vortex_scalar::Scalar;
    use vortex_utils::aliases::hash_map::HashMap;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::list_contains;
    use crate::arrays::{BoolArray, ListArray, PrimitiveArray};
    use crate::expr::exprs::binary::{and, gt, lt, or};
    use crate::expr::exprs::get_item::{col, get_item};
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::stats::Stat;
    use crate::validity::Validity;
    use crate::{Array, ArrayRef, IntoArray};

    /// The nullable boolean scalar that `list_contains` yields for a valid row.
    fn nullable_bool(value: bool) -> Scalar {
        Scalar::bool(value, Nullability::Nullable)
    }

    /// Assembles a list array of `i32` elements split at `offsets`.
    fn make_list(elements: Vec<i32>, offsets: Vec<i32>, validity: Validity) -> ArrayRef {
        ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            validity,
        )
        .unwrap()
        .into_array()
    }

    /// Two lists: `[1, 1, 2, 2, 2]` and `[2, 2, 3, 3, 3]`.
    fn test_array() -> ArrayRef {
        make_list(
            vec![1, 1, 2, 2, 2, 2, 2, 3, 3, 3],
            vec![0, 5, 10],
            Validity::AllValid,
        )
    }

    /// Evaluates `contains(root, needle)` over `array` and checks every row against `expected`.
    fn assert_contains(array: &ArrayRef, needle: i32, expected: &[bool]) {
        let result = list_contains(root(), lit(needle)).evaluate(array).unwrap();
        for (row, &want) in expected.iter().enumerate() {
            assert_eq!(result.scalar_at(row), nullable_bool(want));
        }
    }

    #[test]
    pub fn test_one() {
        assert_contains(&test_array(), 1, &[true, false]);
    }

    #[test]
    pub fn test_all() {
        assert_contains(&test_array(), 2, &[true, true]);
    }

    #[test]
    pub fn test_none() {
        assert_contains(&test_array(), 4, &[false, false]);
    }

    #[test]
    pub fn test_empty() {
        // The second list is empty (offsets 5..5).
        let lists = make_list(vec![1, 1, 2, 2, 2], vec![0, 5, 5], Validity::AllValid);
        assert_contains(&lists, 2, &[true, false]);
    }

    #[test]
    pub fn test_nullable() {
        let validity =
            Validity::Array(BoolArray::from(BitBuffer::from(vec![true, false])).into_array());
        let lists = make_list(vec![1, 1, 2, 2, 2], vec![0, 5, 5], validity);

        let result = list_contains(root(), lit(2)).evaluate(&lists).unwrap();

        assert_eq!(result.scalar_at(0), nullable_bool(true));
        // A null list yields a null result rather than `false`.
        assert!(!result.is_valid(1));
    }

    #[test]
    pub fn test_return_type() {
        let list_of_i32 = DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            Nullability::Nullable,
        );
        let scope = DType::Struct(
            StructFields::new(["array"].into(), vec![list_of_i32]),
            Nullability::NonNullable,
        );

        let dtype = list_contains(get_item("array", root()), lit(2))
            .return_dtype(&scope)
            .unwrap();

        // Nullable because the list field is nullable, even though the scope is not.
        assert_eq!(dtype, DType::Bool(Nullability::Nullable));
    }

    #[test]
    pub fn list_falsification() {
        let haystack = lit(Scalar::list(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            vec![1.into(), 2.into(), 3.into()],
            Nullability::NonNullable,
        ));
        let expr = list_contains(haystack, col("a"));
        let available_stats = FieldPathSet::from_iter([
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("max".into())]),
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("min".into())]),
        ]);

        let (pruning, stats) = checked_pruning_expr(&expr, &available_stats).unwrap();

        // One `a_max < v OR a_min > v` clause per element, left-folded with AND.
        let excludes = |v: i32| or(lt(col("a_max"), lit(v)), gt(col("a_min"), lit(v)));
        assert_eq!(&pruning, &and(and(excludes(1), excludes(2)), excludes(3)));

        assert_eq!(
            stats.map(),
            &HashMap::from_iter([(
                FieldPath::from_name("a"),
                HashSet::from([Stat::Min, Stat::Max])
            )])
        );
    }

    #[test]
    pub fn test_display() {
        let tagged = list_contains(get_item("tags", root()), lit("urgent"));
        assert_eq!(tagged.to_string(), "contains($.tags, \"urgent\")");

        let on_root = list_contains(root(), lit(42));
        assert_eq!(on_root.to_string(), "contains($, 42i32)");
    }
}