1use std::fmt::Formatter;
5use std::ops::BitOr;
6
7use vortex_dtype::DType;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_err;
11use vortex_session::VortexSession;
12
13use crate::ArrayRef;
14use crate::IntoArray;
15use crate::arrays::ConstantArray;
16use crate::compute::list_contains as compute_list_contains;
17use crate::expr::Arity;
18use crate::expr::ChildName;
19use crate::expr::EmptyOptions;
20use crate::expr::ExecutionArgs;
21use crate::expr::ExprId;
22use crate::expr::Expression;
23use crate::expr::StatsCatalog;
24use crate::expr::VTable;
25use crate::expr::VTableExt;
26use crate::expr::and_collect;
27use crate::expr::exprs::binary::gt;
28use crate::expr::exprs::binary::lt;
29use crate::expr::exprs::binary::or;
30use crate::expr::exprs::literal::Literal;
31use crate::expr::exprs::literal::lit;
32use crate::scalar::Scalar;
33
34pub struct ListContains;
35
36impl VTable for ListContains {
37 type Options = EmptyOptions;
38
39 fn id(&self) -> ExprId {
40 ExprId::from("vortex.list.contains")
41 }
42
43 fn serialize(&self, _instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
44 Ok(Some(vec![]))
45 }
46
47 fn deserialize(
48 &self,
49 _metadata: &[u8],
50 _session: &VortexSession,
51 ) -> VortexResult<Self::Options> {
52 Ok(EmptyOptions)
53 }
54
55 fn arity(&self, _options: &Self::Options) -> Arity {
56 Arity::Exact(2)
57 }
58
59 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
60 match child_idx {
61 0 => ChildName::from("list"),
62 1 => ChildName::from("needle"),
63 _ => unreachable!(
64 "Invalid child index {} for ListContains expression",
65 child_idx
66 ),
67 }
68 }
69 fn fmt_sql(
70 &self,
71 _options: &Self::Options,
72 expr: &Expression,
73 f: &mut Formatter<'_>,
74 ) -> std::fmt::Result {
75 write!(f, "contains(")?;
76 expr.child(0).fmt_sql(f)?;
77 write!(f, ", ")?;
78 expr.child(1).fmt_sql(f)?;
79 write!(f, ")")
80 }
81
82 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
83 let list_dtype = &arg_dtypes[0];
84 let needle_dtype = &arg_dtypes[0];
85
86 let nullability = match list_dtype {
87 DType::List(_, list_nullability) => list_nullability,
88 _ => {
89 vortex_bail!(
90 "First argument to ListContains must be a List, got {:?}",
91 list_dtype
92 );
93 }
94 }
95 .bitor(needle_dtype.nullability());
96
97 Ok(DType::Bool(nullability))
98 }
99
100 fn execute(&self, _options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
101 let [list_array, value_array]: [ArrayRef; _] = args
102 .inputs
103 .try_into()
104 .map_err(|_| vortex_err!("Wrong number of arguments for ListContains expression"))?;
105
106 if let Some(list_scalar) = list_array.as_constant()
107 && let Some(value_scalar) = value_array.as_constant()
108 {
109 let result = compute_contains_scalar(&list_scalar, &value_scalar)?;
110 return Ok(ConstantArray::new(result, args.row_count).into_array());
111 }
112
113 compute_list_contains(list_array.as_ref(), value_array.as_ref())?.execute(args.ctx)
114 }
115
116 fn stat_falsification(
117 &self,
118 _options: &Self::Options,
119 expr: &Expression,
120 catalog: &dyn StatsCatalog,
121 ) -> Option<Expression> {
122 let list = expr.child(0);
123 let needle = expr.child(1);
124
125 let min = list.stat_min(catalog)?;
128 let max = list.stat_max(catalog)?;
129 if min == max {
131 let list_ = min
132 .as_opt::<Literal>()
133 .and_then(|l| l.as_list_opt())
134 .and_then(|l| l.elements())?;
135 if list_.is_empty() {
136 return Some(lit(true));
138 }
139 let value_max = needle.stat_max(catalog)?;
140 let value_min = needle.stat_min(catalog)?;
141
142 return and_collect(list_.iter().map(move |v| {
143 or(
144 lt(value_max.clone(), lit(v.clone())),
145 gt(value_min.clone(), lit(v.clone())),
146 )
147 }));
148 }
149
150 None
151 }
152
153 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
155 true
156 }
157
158 fn is_fallible(&self, _options: &Self::Options) -> bool {
159 false
160 }
161}
162
163fn compute_contains_scalar(list: &Scalar, needle: &Scalar) -> VortexResult<Scalar> {
164 let nullability = list.dtype().nullability() | needle.dtype().nullability();
165
166 if list.is_null() || needle.is_null() {
168 return Ok(Scalar::null(DType::Bool(nullability)));
169 }
170
171 let list_scalar = list.as_list();
172 let elements = list_scalar
173 .elements()
174 .ok_or_else(|| vortex_err!("Expected non-null list"))?;
175
176 let contains = elements.iter().any(|elem| elem == needle);
177 Ok(Scalar::bool(contains, nullability))
178}
179
180pub fn list_contains(list: Expression, value: Expression) -> Expression {
189 ListContains.new_expr(EmptyOptions, [list, value])
190}
191
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_buffer::BitBuffer;
    use vortex_dtype::DType;
    use vortex_dtype::Field;
    use vortex_dtype::FieldPath;
    use vortex_dtype::FieldPathSet;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_dtype::StructFields;
    use vortex_utils::aliases::hash_map::HashMap;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::list_contains;
    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::expr::exprs::binary::and;
    use crate::expr::exprs::binary::gt;
    use crate::expr::exprs::binary::lt;
    use crate::expr::exprs::binary::or;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::expr::stats::Stat;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Two all-valid lists built from offsets `[0, 5, 10]`:
    /// row 0 = `[1, 1, 2, 2, 2]`, row 1 = `[2, 2, 3, 3, 3]`.
    fn test_array() -> ArrayRef {
        ListArray::try_new(
            PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2, 2, 2, 3, 3, 3]).into_array(),
            PrimitiveArray::from_iter(vec![0, 5, 10]).into_array(),
            Validity::AllValid,
        )
        .unwrap()
        .into_array()
    }

    /// A needle that appears only in the first list.
    #[test]
    pub fn test_one() {
        let arr = test_array();

        let expr = list_contains(root(), lit(1));
        let item = arr.apply(&expr).unwrap();

        assert_eq!(
            item.scalar_at(0).unwrap(),
            Scalar::bool(true, Nullability::Nullable)
        );
        assert_eq!(
            item.scalar_at(1).unwrap(),
            Scalar::bool(false, Nullability::Nullable)
        );
    }

    /// A needle that appears in every list.
    #[test]
    pub fn test_all() {
        let arr = test_array();

        let expr = list_contains(root(), lit(2));
        let item = arr.apply(&expr).unwrap();

        assert_eq!(
            item.scalar_at(0).unwrap(),
            Scalar::bool(true, Nullability::Nullable)
        );
        assert_eq!(
            item.scalar_at(1).unwrap(),
            Scalar::bool(true, Nullability::Nullable)
        );
    }

    /// A needle that appears in no list.
    #[test]
    pub fn test_none() {
        let arr = test_array();

        let expr = list_contains(root(), lit(4));
        let item = arr.apply(&expr).unwrap();

        assert_eq!(
            item.scalar_at(0).unwrap(),
            Scalar::bool(false, Nullability::Nullable)
        );
        assert_eq!(
            item.scalar_at(1).unwrap(),
            Scalar::bool(false, Nullability::Nullable)
        );
    }

    /// Offsets `[0, 5, 5]` make row 1 an empty (but valid) list, which never
    /// contains the needle.
    #[test]
    pub fn test_empty() {
        let arr = ListArray::try_new(
            PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2]).into_array(),
            PrimitiveArray::from_iter(vec![0, 5, 5]).into_array(),
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        let expr = list_contains(root(), lit(2));
        let item = arr.apply(&expr).unwrap();

        assert_eq!(
            item.scalar_at(0).unwrap(),
            Scalar::bool(true, Nullability::Nullable)
        );
        assert_eq!(
            item.scalar_at(1).unwrap(),
            Scalar::bool(false, Nullability::Nullable)
        );
    }

    /// Validity `[true, false]` marks row 1 as a null list; its result must be null.
    #[test]
    pub fn test_nullable() {
        let arr = ListArray::try_new(
            PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2]).into_array(),
            PrimitiveArray::from_iter(vec![0, 5, 5]).into_array(),
            Validity::Array(BoolArray::from(BitBuffer::from(vec![true, false])).into_array()),
        )
        .unwrap()
        .into_array();

        let expr = list_contains(root(), lit(2));
        let item = arr.apply(&expr).unwrap();

        assert_eq!(
            item.scalar_at(0).unwrap(),
            Scalar::bool(true, Nullability::Nullable)
        );
        assert!(!item.is_valid(1).unwrap());
    }

    /// A nullable list operand yields a nullable boolean return dtype.
    #[test]
    pub fn test_return_type() {
        let scope = DType::Struct(
            StructFields::new(
                ["array"].into(),
                vec![DType::List(
                    Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                    Nullability::Nullable,
                )],
            ),
            Nullability::NonNullable,
        );

        let expr = list_contains(get_item("array", root()), lit(2));

        assert_eq!(
            expr.return_dtype(&scope).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }

    /// With a literal list `[1, 2, 3]`, the pruning expression requires that column
    /// `a`'s `[min, max]` range excludes every element. Only `a`'s Min/Max stats
    /// should be referenced.
    #[test]
    pub fn list_falsification() {
        let expr = list_contains(
            lit(Scalar::list(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                vec![1.into(), 2.into(), 3.into()],
                Nullability::NonNullable,
            )),
            col("a"),
        );

        let (expr, st) = checked_pruning_expr(
            &expr,
            &FieldPathSet::from_iter([
                FieldPath::from_iter([Field::Name("a".into()), Field::Name("max".into())]),
                FieldPath::from_iter([Field::Name("a".into()), Field::Name("min".into())]),
            ]),
        )
        .unwrap();

        assert_eq!(
            &expr,
            &and(
                and(
                    or(lt(col("a_max"), lit(1i32)), gt(col("a_min"), lit(1i32)),),
                    or(lt(col("a_max"), lit(2i32)), gt(col("a_min"), lit(2i32)),)
                ),
                or(lt(col("a_max"), lit(3i32)), gt(col("a_min"), lit(3i32)),)
            )
        );

        assert_eq!(
            st.map(),
            &HashMap::from_iter([(
                FieldPath::from_name("a"),
                HashSet::from([Stat::Min, Stat::Max])
            )])
        );
    }

    /// The SQL-like rendering is `contains(<list>, <needle>)`.
    #[test]
    pub fn test_display() {
        let expr = list_contains(get_item("tags", root()), lit("urgent"));
        assert_eq!(expr.to_string(), "contains($.tags, \"urgent\")");

        let expr2 = list_contains(root(), lit(42));
        assert_eq!(expr2.to_string(), "contains($, 42i32)");
    }

    /// Both operands are non-nullable constants, which exercises the scalar fast
    /// path. The result is a non-nullable boolean.
    #[test]
    pub fn test_constant_scalars() {
        let arr = test_array();

        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            vec![1.into(), 2.into(), 3.into()],
            Nullability::NonNullable,
        );

        let expr = list_contains(lit(list_scalar.clone()), lit(2i32));
        let result = arr.apply(&expr).unwrap();
        assert_eq!(
            result.scalar_at(0).unwrap(),
            Scalar::bool(true, Nullability::NonNullable)
        );

        let expr = list_contains(lit(list_scalar), lit(42i32));
        let result = arr.apply(&expr).unwrap();
        assert_eq!(
            result.scalar_at(0).unwrap(),
            Scalar::bool(false, Nullability::NonNullable)
        );
    }
}