vortex_array/expr/exprs/list_contains.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6use vortex_dtype::DType;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9
10use crate::ArrayRef;
11use crate::compute::list_contains as compute_list_contains;
12use crate::expr::ChildName;
13use crate::expr::ExprId;
14use crate::expr::Expression;
15use crate::expr::ExpressionView;
16use crate::expr::StatsCatalog;
17use crate::expr::VTable;
18use crate::expr::VTableExt;
19use crate::expr::exprs::binary::and;
20use crate::expr::exprs::binary::gt;
21use crate::expr::exprs::binary::lt;
22use crate::expr::exprs::binary::or;
23use crate::expr::exprs::literal::Literal;
24use crate::expr::exprs::literal::lit;
25
/// Expression vtable for `contains(list, needle)`: tests whether the `needle` child's value
/// occurs in the list produced by the `list` child.
pub struct ListContains;
27
28impl VTable for ListContains {
29    type Instance = ();
30
31    fn id(&self) -> ExprId {
32        ExprId::from("vortex.list.contains")
33    }
34
35    fn serialize(&self, _instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
36        Ok(Some(vec![]))
37    }
38
39    fn deserialize(&self, _metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
40        Ok(Some(()))
41    }
42
43    fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
44        if expr.children().len() != 2 {
45            vortex_bail!(
46                "ListContains expression requires exactly 2 children, got {}",
47                expr.children().len()
48            );
49        }
50        Ok(())
51    }
52
53    fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
54        match child_idx {
55            0 => ChildName::from("list"),
56            1 => ChildName::from("needle"),
57            _ => unreachable!(
58                "Invalid child index {} for ListContains expression",
59                child_idx
60            ),
61        }
62    }
63
64    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
65        write!(f, "contains(")?;
66        expr.child(0).fmt_sql(f)?;
67        write!(f, ", ")?;
68        expr.child(1).fmt_sql(f)?;
69        write!(f, ")")
70    }
71
72    fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
73        let list_dtype = expr.child(0).return_dtype(scope)?;
74        let value_dtype = expr.child(1).return_dtype(scope)?;
75
76        let nullability = match list_dtype {
77            DType::List(_, list_nullability) => list_nullability,
78            _ => {
79                vortex_bail!(
80                    "First argument to ListContains must be a List, got {:?}",
81                    list_dtype
82                );
83            }
84        } | value_dtype.nullability();
85
86        Ok(DType::Bool(nullability))
87    }
88
89    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
90        let list_array = expr.child(0).evaluate(scope)?;
91        let value_array = expr.child(1).evaluate(scope)?;
92        compute_list_contains(list_array.as_ref(), value_array.as_ref())
93    }
94
95    fn stat_falsification(
96        &self,
97        expr: &ExpressionView<Self>,
98        catalog: &dyn StatsCatalog,
99    ) -> Option<Expression> {
100        // falsification(contains([1,2,5], x)) =>
101        //   falsification(x != 1) and falsification(x != 2) and falsification(x != 5)
102        let min = expr.list().stat_min(catalog)?;
103        let max = expr.list().stat_max(catalog)?;
104        // If the list is constant when we can compare each element to the value
105        if min == max {
106            let list_ = min
107                .as_opt::<Literal>()
108                .and_then(|l| l.data().as_list_opt())
109                .and_then(|l| l.elements())?;
110            if list_.is_empty() {
111                // contains([], x) is always false.
112                return Some(lit(true));
113            }
114            let value_max = expr.needle().stat_max(catalog)?;
115            let value_min = expr.needle().stat_min(catalog)?;
116
117            return list_
118                .iter()
119                .map(move |v| {
120                    or(
121                        lt(value_max.clone(), lit(v.clone())),
122                        gt(value_min.clone(), lit(v.clone())),
123                    )
124                })
125                .reduce(and);
126        }
127
128        None
129    }
130
131    // Nullability matters for contains([], x) where x is false.
132    fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool {
133        true
134    }
135}
136
137/// Creates an expression that checks if a value is contained in a list.
138///
139/// Returns a boolean array indicating whether the value appears in each list.
140///
141/// ```rust
142/// # use vortex_array::expr::{list_contains, lit, root};
143/// let expr = list_contains(root(), lit(42));
144/// ```
145pub fn list_contains(list: Expression, value: Expression) -> Expression {
146    ListContains.new_expr((), [list, value])
147}
148
149impl ExpressionView<'_, ListContains> {
150    pub fn list(&self) -> &Expression {
151        &self.children()[0]
152    }
153
154    pub fn needle(&self) -> &Expression {
155        &self.children()[1]
156    }
157}
158
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_buffer::BitBuffer;
    use vortex_dtype::DType;
    use vortex_dtype::Field;
    use vortex_dtype::FieldPath;
    use vortex_dtype::FieldPathSet;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_dtype::StructFields;
    use vortex_scalar::Scalar;
    use vortex_utils::aliases::hash_map::HashMap;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::list_contains;
    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::expr::exprs::binary::and;
    use crate::expr::exprs::binary::gt;
    use crate::expr::exprs::binary::lt;
    use crate::expr::exprs::binary::or;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::expr::stats::Stat;
    use crate::validity::Validity;

    /// Builds a list array whose flattened `i32` values are split into lists at `offsets`.
    fn build_lists(elements: Vec<i32>, offsets: Vec<i32>, validity: Validity) -> ArrayRef {
        let elements = PrimitiveArray::from_iter(elements).into_array();
        let offsets = PrimitiveArray::from_iter(offsets).into_array();
        ListArray::try_new(elements, offsets, validity)
            .unwrap()
            .into_array()
    }

    /// Two all-valid lists: `[1, 1, 2, 2, 2]` and `[2, 2, 3, 3, 3]`.
    fn test_array() -> ArrayRef {
        build_lists(
            vec![1, 1, 2, 2, 2, 2, 2, 3, 3, 3],
            vec![0, 5, 10],
            Validity::AllValid,
        )
    }

    /// Evaluates `contains($, needle)` against `lists`.
    fn contains_in(lists: &ArrayRef, needle: i32) -> ArrayRef {
        list_contains(root(), lit(needle)).evaluate(lists).unwrap()
    }

    /// Shorthand for a nullable boolean scalar, the result element type in these tests.
    fn nullable_bool(value: bool) -> Scalar {
        Scalar::bool(value, Nullability::Nullable)
    }

    #[test]
    fn test_one() {
        let result = contains_in(&test_array(), 1);
        assert_eq!(result.scalar_at(0), nullable_bool(true));
        assert_eq!(result.scalar_at(1), nullable_bool(false));
    }

    #[test]
    fn test_all() {
        let result = contains_in(&test_array(), 2);
        assert_eq!(result.scalar_at(0), nullable_bool(true));
        assert_eq!(result.scalar_at(1), nullable_bool(true));
    }

    #[test]
    fn test_none() {
        let result = contains_in(&test_array(), 4);
        assert_eq!(result.scalar_at(0), nullable_bool(false));
        assert_eq!(result.scalar_at(1), nullable_bool(false));
    }

    #[test]
    fn test_empty() {
        // The second list spans offsets 5..5 and is therefore empty.
        let lists = build_lists(vec![1, 1, 2, 2, 2], vec![0, 5, 5], Validity::AllValid);
        let result = contains_in(&lists, 2);
        assert_eq!(result.scalar_at(0), nullable_bool(true));
        assert_eq!(result.scalar_at(1), nullable_bool(false));
    }

    #[test]
    fn test_nullable() {
        // The second (empty) list is marked null by the validity mask.
        let validity =
            Validity::Array(BoolArray::from(BitBuffer::from(vec![true, false])).into_array());
        let lists = build_lists(vec![1, 1, 2, 2, 2], vec![0, 5, 5], validity);
        let result = contains_in(&lists, 2);
        assert_eq!(result.scalar_at(0), nullable_bool(true));
        assert!(!result.is_valid(1));
    }

    #[test]
    fn test_return_type() {
        let array_dtype = DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            Nullability::Nullable,
        );
        let scope = DType::Struct(
            StructFields::new(["array"].into(), vec![array_dtype]),
            Nullability::NonNullable,
        );

        // Nullable because the list column is nullable, even though the scope is not.
        let expr = list_contains(get_item("array", root()), lit(2));
        assert_eq!(
            expr.return_dtype(&scope).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }

    #[test]
    fn list_falsification() {
        let constant_list = lit(Scalar::list(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            vec![1.into(), 2.into(), 3.into()],
            Nullability::NonNullable,
        ));
        let expr = list_contains(constant_list, col("a"));

        let available_stats = FieldPathSet::from_iter([
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("max".into())]),
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("min".into())]),
        ]);
        let (pruning, stats) = checked_pruning_expr(&expr, &available_stats).unwrap();

        // Element `v` is ruled out when the range of `a` lies entirely below or above it.
        let excludes = |v: i32| or(lt(col("a_max"), lit(v)), gt(col("a_min"), lit(v)));
        assert_eq!(
            &pruning,
            &and(and(excludes(1), excludes(2)), excludes(3))
        );

        assert_eq!(
            stats.map(),
            &HashMap::from_iter([(
                FieldPath::from_name("a"),
                HashSet::from([Stat::Min, Stat::Max])
            )])
        );
    }

    #[test]
    fn test_display() {
        let tags_expr = list_contains(get_item("tags", root()), lit("urgent"));
        assert_eq!(tags_expr.to_string(), "contains($.tags, \"urgent\")");

        let root_expr = list_contains(root(), lit(42));
        assert_eq!(root_expr.to_string(), "contains($, 42i32)");
    }
}