1use std::fmt::Formatter;
5use std::ops::BitOr;
6use std::ops::Deref;
7
8use arrow_buffer::bit_iterator::BitIndexIterator;
9use vortex_buffer::BitBuffer;
10use vortex_compute::logical::LogicalOr;
11use vortex_dtype::DType;
12use vortex_dtype::IntegerPType;
13use vortex_dtype::Nullability;
14use vortex_dtype::PTypeDowncastExt;
15use vortex_dtype::match_each_integer_ptype;
16use vortex_error::VortexResult;
17use vortex_error::vortex_bail;
18use vortex_error::vortex_err;
19use vortex_mask::Mask;
20use vortex_vector::BoolDatum;
21use vortex_vector::Datum;
22use vortex_vector::Vector;
23use vortex_vector::VectorOps;
24use vortex_vector::bool::BoolScalar;
25use vortex_vector::bool::BoolVector;
26use vortex_vector::listview::ListViewScalar;
27use vortex_vector::listview::ListViewVector;
28use vortex_vector::primitive::PVector;
29
30use crate::ArrayRef;
31use crate::compute::list_contains as compute_list_contains;
32use crate::expr::Arity;
33use crate::expr::Binary;
34use crate::expr::ChildName;
35use crate::expr::EmptyOptions;
36use crate::expr::ExecutionArgs;
37use crate::expr::ExprId;
38use crate::expr::Expression;
39use crate::expr::StatsCatalog;
40use crate::expr::VTable;
41use crate::expr::VTableExt;
42use crate::expr::exprs::binary::and;
43use crate::expr::exprs::binary::gt;
44use crate::expr::exprs::binary::lt;
45use crate::expr::exprs::binary::or;
46use crate::expr::exprs::literal::Literal;
47use crate::expr::exprs::literal::lit;
48use crate::expr::operators;
49
50pub struct ListContains;
51
52impl VTable for ListContains {
53 type Options = EmptyOptions;
54
55 fn id(&self) -> ExprId {
56 ExprId::from("vortex.list.contains")
57 }
58
59 fn serialize(&self, _instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
60 Ok(Some(vec![]))
61 }
62
63 fn deserialize(&self, _metadata: &[u8]) -> VortexResult<Self::Options> {
64 Ok(EmptyOptions)
65 }
66
67 fn arity(&self, _options: &Self::Options) -> Arity {
68 Arity::Exact(2)
69 }
70
71 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
72 match child_idx {
73 0 => ChildName::from("list"),
74 1 => ChildName::from("needle"),
75 _ => unreachable!(
76 "Invalid child index {} for ListContains expression",
77 child_idx
78 ),
79 }
80 }
81 fn fmt_sql(
82 &self,
83 _options: &Self::Options,
84 expr: &Expression,
85 f: &mut Formatter<'_>,
86 ) -> std::fmt::Result {
87 write!(f, "contains(")?;
88 expr.child(0).fmt_sql(f)?;
89 write!(f, ", ")?;
90 expr.child(1).fmt_sql(f)?;
91 write!(f, ")")
92 }
93
94 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
95 let list_dtype = &arg_dtypes[0];
96 let needle_dtype = &arg_dtypes[0];
97
98 let nullability = match list_dtype {
99 DType::List(_, list_nullability) => list_nullability,
100 _ => {
101 vortex_bail!(
102 "First argument to ListContains must be a List, got {:?}",
103 list_dtype
104 );
105 }
106 }
107 .bitor(needle_dtype.nullability());
108
109 Ok(DType::Bool(nullability))
110 }
111
112 fn evaluate(
113 &self,
114 _options: &Self::Options,
115 expr: &Expression,
116 scope: &ArrayRef,
117 ) -> VortexResult<ArrayRef> {
118 let list_array = expr.child(0).evaluate(scope)?;
119 let value_array = expr.child(1).evaluate(scope)?;
120 compute_list_contains(list_array.as_ref(), value_array.as_ref())
121 }
122
123 fn execute(&self, _options: &Self::Options, args: ExecutionArgs) -> VortexResult<Datum> {
124 let [lhs, rhs]: [Datum; _] = args
125 .datums
126 .try_into()
127 .map_err(|_| vortex_err!("Wrong number of arguments for ListContains expression"))?;
128
129 match (lhs, rhs) {
130 (Datum::Scalar(list_scalar), Datum::Scalar(needle_scalar)) => {
131 let list = list_scalar.into_list();
132 let found = list_contains_scalar_scalar(&list, &needle_scalar)?;
133 Ok(Datum::Scalar(BoolScalar::new(Some(found)).into()))
134 }
135 (Datum::Scalar(list_scalar), Datum::Vector(needle_vector)) => {
136 let matches =
137 constant_list_scalar_contains(list_scalar.into_list(), needle_vector)?;
138 Ok(Datum::Vector(matches.into()))
139 }
140 (Datum::Vector(list_vector), Datum::Scalar(needle_scalar)) => {
141 let matches =
142 list_contains_scalar(list_vector.into_list(), needle_scalar.into_list())?;
143 Ok(Datum::Vector(matches.into()))
144 }
145 (Datum::Vector(_), Datum::Vector(_)) => {
146 vortex_bail!(
147 "ListContains currently only supports constant needle (RHS) or constant list (LHS)"
148 )
149 }
150 }
151 }
152
153 fn stat_falsification(
154 &self,
155 _options: &Self::Options,
156 expr: &Expression,
157 catalog: &dyn StatsCatalog,
158 ) -> Option<Expression> {
159 let list = expr.child(0);
160 let needle = expr.child(1);
161
162 let min = list.stat_min(catalog)?;
165 let max = list.stat_max(catalog)?;
166 if min == max {
168 let list_ = min
169 .as_opt::<Literal>()
170 .and_then(|l| l.as_list_opt())
171 .and_then(|l| l.elements())?;
172 if list_.is_empty() {
173 return Some(lit(true));
175 }
176 let value_max = needle.stat_max(catalog)?;
177 let value_min = needle.stat_min(catalog)?;
178
179 return list_
180 .iter()
181 .map(move |v| {
182 or(
183 lt(value_max.clone(), lit(v.clone())),
184 gt(value_min.clone(), lit(v.clone())),
185 )
186 })
187 .reduce(and);
188 }
189
190 None
191 }
192
193 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
195 true
196 }
197
198 fn is_fallible(&self, _options: &Self::Options) -> bool {
199 false
200 }
201}
202
203pub fn list_contains(list: Expression, value: Expression) -> Expression {
212 ListContains.new_expr(EmptyOptions, [list, value])
213}
214
215fn list_contains_scalar(list: ListViewVector, value: ListViewScalar) -> VortexResult<BoolVector> {
218 let elems = list.elements();
225 if elems.is_empty() {
226 todo!()
229 }
230
231 let matches = Binary
232 .bind(operators::Operator::Eq)
233 .execute(ExecutionArgs {
234 datums: vec![
235 Datum::Vector(elems.deref().clone()),
236 Datum::Scalar(value.into()),
237 ],
238 dtypes: vec![],
240 row_count: elems.len(),
241 return_dtype: DType::Bool(Nullability::Nullable),
242 })?
243 .unwrap_into_vector(elems.len())
244 .into_bool()
245 .into_bits();
246
247 let offsets = list.offsets();
278 let sizes = list.sizes();
279
280 let list_matches = match_each_integer_ptype!(offsets.ptype(), |O| {
282 match_each_integer_ptype!(sizes.ptype(), |S| {
283 process_matches::<O, S>(
284 matches,
285 list.len(),
286 offsets.downcast::<O>(),
287 sizes.downcast::<S>(),
288 )
289 })
290 });
291
292 Ok(BoolVector::new(list_matches, list.validity().clone()))
293}
294
295fn constant_list_scalar_contains(list: ListViewScalar, values: Vector) -> VortexResult<BoolVector> {
298 let elements = list.value().elements();
299
300 let mut result: BoolVector = BoolVector::new(
303 BitBuffer::new_unset(values.len()),
304 Mask::new(values.len(), true),
305 );
306 for i in 0..elements.len() {
307 let element = Datum::Scalar(elements.scalar_at(i));
308 let compared: BoolDatum = Binary
309 .bind(operators::Operator::Eq)
310 .execute(ExecutionArgs {
311 datums: vec![Datum::Vector(values.clone()), element],
312 dtypes: vec![
313 ],
315 row_count: values.len(),
316 return_dtype: DType::Bool(Nullability::Nullable),
317 })?
318 .into_bool();
319 let compared = Datum::from(compared)
320 .unwrap_into_vector(values.len())
321 .into_bool();
322
323 result = LogicalOr::or(&result, &compared);
324 }
325
326 Ok(result)
327}
328
329fn list_contains_scalar_scalar(
331 list: &ListViewScalar,
332 needle: &vortex_vector::Scalar,
333) -> VortexResult<bool> {
334 let elements = list.value().elements();
335
336 let found = Binary
340 .bind(operators::Operator::Eq)
341 .execute(ExecutionArgs {
342 datums: vec![
343 Datum::Vector(elements.deref().clone()),
344 Datum::Scalar(needle.clone()),
345 ],
346 dtypes: vec![],
347 row_count: elements.len(),
348 return_dtype: DType::Bool(Nullability::Nullable),
349 })?
350 .unwrap_into_vector(elements.len())
351 .into_bool()
352 .into_bits();
353
354 let mut true_bits = BitIndexIterator::new(found.inner().as_ref(), 0, found.len());
355 Ok(true_bits.next().is_some())
356}
357
358fn process_matches<O, S>(
363 matches: BitBuffer,
364 list_array_len: usize,
365 offsets: &PVector<O>,
366 sizes: &PVector<S>,
367) -> BitBuffer
368where
369 O: IntegerPType,
370 S: IntegerPType,
371{
372 let offsets_slice = offsets.elements().as_slice();
373 let sizes_slice = sizes.elements().as_slice();
374
375 (0..list_array_len)
376 .map(|i| {
377 let offset = offsets_slice[i].as_();
379 let size = sizes_slice[i].as_();
380
381 let mut set_bits =
384 BitIndexIterator::new(matches.inner().as_slice(), matches.offset() + offset, size);
385 set_bits.next().is_some()
386 })
387 .collect::<BitBuffer>()
388}
389
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_buffer::BitBuffer;
    use vortex_dtype::DType;
    use vortex_dtype::Field;
    use vortex_dtype::FieldPath;
    use vortex_dtype::FieldPathSet;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_dtype::StructFields;
    use vortex_scalar::Scalar;
    use vortex_utils::aliases::hash_map::HashMap;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::list_contains;
    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::expr::exprs::binary::and;
    use crate::expr::exprs::binary::gt;
    use crate::expr::exprs::binary::lt;
    use crate::expr::exprs::binary::or;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::expr::stats::Stat;
    use crate::validity::Validity;

    /// Builds an `i32` list array from flat elements, list offsets and a validity.
    fn int_list(elements: Vec<i32>, offsets: Vec<i32>, validity: Validity) -> ArrayRef {
        ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            validity,
        )
        .unwrap()
        .into_array()
    }

    /// Two lists of five elements each: `[1, 1, 2, 2, 2]` and `[2, 2, 3, 3, 3]`.
    fn test_array() -> ArrayRef {
        int_list(
            vec![1, 1, 2, 2, 2, 2, 2, 3, 3, 3],
            vec![0, 5, 10],
            Validity::AllValid,
        )
    }

    /// Evaluates `contains($, needle)` against `lists`.
    fn contains_in_root(lists: &ArrayRef, needle: i32) -> ArrayRef {
        list_contains(root(), lit(needle)).evaluate(lists).unwrap()
    }

    /// The nullable boolean scalar the expression is expected to produce.
    fn nullable_bool(value: bool) -> Scalar {
        Scalar::bool(value, Nullability::Nullable)
    }

    #[test]
    pub fn test_one() {
        let result = contains_in_root(&test_array(), 1);

        assert_eq!(result.scalar_at(0), nullable_bool(true));
        assert_eq!(result.scalar_at(1), nullable_bool(false));
    }

    #[test]
    pub fn test_all() {
        let result = contains_in_root(&test_array(), 2);

        assert_eq!(result.scalar_at(0), nullable_bool(true));
        assert_eq!(result.scalar_at(1), nullable_bool(true));
    }

    #[test]
    pub fn test_none() {
        let result = contains_in_root(&test_array(), 4);

        assert_eq!(result.scalar_at(0), nullable_bool(false));
        assert_eq!(result.scalar_at(1), nullable_bool(false));
    }

    #[test]
    pub fn test_empty() {
        // The second list spans offsets 5..5 and is therefore empty.
        let lists = int_list(vec![1, 1, 2, 2, 2], vec![0, 5, 5], Validity::AllValid);
        let result = contains_in_root(&lists, 2);

        assert_eq!(result.scalar_at(0), nullable_bool(true));
        assert_eq!(result.scalar_at(1), nullable_bool(false));
    }

    #[test]
    pub fn test_nullable() {
        // The second list is marked invalid by the validity array.
        let validity =
            Validity::Array(BoolArray::from(BitBuffer::from(vec![true, false])).into_array());
        let lists = int_list(vec![1, 1, 2, 2, 2], vec![0, 5, 5], validity);
        let result = contains_in_root(&lists, 2);

        assert_eq!(result.scalar_at(0), nullable_bool(true));
        assert!(!result.is_valid(1));
    }

    #[test]
    pub fn test_return_type() {
        let element_dtype = Arc::new(DType::Primitive(I32, Nullability::NonNullable));
        let list_dtype = DType::List(element_dtype, Nullability::Nullable);
        let scope = DType::Struct(
            StructFields::new(["array"].into(), vec![list_dtype]),
            Nullability::NonNullable,
        );

        let expr = list_contains(get_item("array", root()), lit(2));

        assert_eq!(
            expr.return_dtype(&scope).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }

    #[test]
    pub fn list_falsification() {
        let haystack = lit(Scalar::list(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            vec![1.into(), 2.into(), 3.into()],
            Nullability::NonNullable,
        ));
        let expr = list_contains(haystack, col("a"));

        let available_stats = FieldPathSet::from_iter([
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("max".into())]),
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("min".into())]),
        ]);
        let (expr, st) = checked_pruning_expr(&expr, &available_stats).unwrap();

        // `v` lies outside the `[a_min, a_max]` range of column `a`.
        let outside_range = |v: i32| or(lt(col("a_max"), lit(v)), gt(col("a_min"), lit(v)));

        assert_eq!(
            &expr,
            &and(
                and(outside_range(1), outside_range(2)),
                outside_range(3)
            )
        );

        assert_eq!(
            st.map(),
            &HashMap::from_iter([(
                FieldPath::from_name("a"),
                HashSet::from([Stat::Min, Stat::Max])
            )])
        );
    }

    #[test]
    pub fn test_display() {
        let tagged = list_contains(get_item("tags", root()), lit("urgent"));
        assert_eq!(tagged.to_string(), "contains($.tags, \"urgent\")");

        let numeric = list_contains(root(), lit(42));
        assert_eq!(numeric.to_string(), "contains($, 42i32)");
    }
}