1use std::fmt::Formatter;
5
6use vortex_dtype::DType;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9
10use crate::ArrayRef;
11use crate::compute::list_contains as compute_list_contains;
12use crate::expr::ChildName;
13use crate::expr::ExprId;
14use crate::expr::Expression;
15use crate::expr::ExpressionView;
16use crate::expr::StatsCatalog;
17use crate::expr::VTable;
18use crate::expr::VTableExt;
19use crate::expr::exprs::binary::and;
20use crate::expr::exprs::binary::gt;
21use crate::expr::exprs::binary::lt;
22use crate::expr::exprs::binary::or;
23use crate::expr::exprs::literal::Literal;
24use crate::expr::exprs::literal::lit;
25
26pub struct ListContains;
27
28impl VTable for ListContains {
29 type Instance = ();
30
31 fn id(&self) -> ExprId {
32 ExprId::from("vortex.list.contains")
33 }
34
35 fn serialize(&self, _instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
36 Ok(Some(vec![]))
37 }
38
39 fn deserialize(&self, _metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
40 Ok(Some(()))
41 }
42
43 fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
44 if expr.children().len() != 2 {
45 vortex_bail!(
46 "ListContains expression requires exactly 2 children, got {}",
47 expr.children().len()
48 );
49 }
50 Ok(())
51 }
52
53 fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
54 match child_idx {
55 0 => ChildName::from("list"),
56 1 => ChildName::from("needle"),
57 _ => unreachable!(
58 "Invalid child index {} for ListContains expression",
59 child_idx
60 ),
61 }
62 }
63
64 fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
65 write!(f, "contains(")?;
66 expr.child(0).fmt_sql(f)?;
67 write!(f, ", ")?;
68 expr.child(1).fmt_sql(f)?;
69 write!(f, ")")
70 }
71
72 fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
73 let list_dtype = expr.child(0).return_dtype(scope)?;
74 let value_dtype = expr.child(1).return_dtype(scope)?;
75
76 let nullability = match list_dtype {
77 DType::List(_, list_nullability) => list_nullability,
78 _ => {
79 vortex_bail!(
80 "First argument to ListContains must be a List, got {:?}",
81 list_dtype
82 );
83 }
84 } | value_dtype.nullability();
85
86 Ok(DType::Bool(nullability))
87 }
88
89 fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
90 let list_array = expr.child(0).evaluate(scope)?;
91 let value_array = expr.child(1).evaluate(scope)?;
92 compute_list_contains(list_array.as_ref(), value_array.as_ref())
93 }
94
95 fn stat_falsification(
96 &self,
97 expr: &ExpressionView<Self>,
98 catalog: &dyn StatsCatalog,
99 ) -> Option<Expression> {
100 let min = expr.list().stat_min(catalog)?;
103 let max = expr.list().stat_max(catalog)?;
104 if min == max {
106 let list_ = min
107 .as_opt::<Literal>()
108 .and_then(|l| l.data().as_list_opt())
109 .and_then(|l| l.elements())?;
110 if list_.is_empty() {
111 return Some(lit(true));
113 }
114 let value_max = expr.needle().stat_max(catalog)?;
115 let value_min = expr.needle().stat_min(catalog)?;
116
117 return list_
118 .iter()
119 .map(move |v| {
120 or(
121 lt(value_max.clone(), lit(v.clone())),
122 gt(value_min.clone(), lit(v.clone())),
123 )
124 })
125 .reduce(and);
126 }
127
128 None
129 }
130
131 fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool {
133 true
134 }
135}
136
137pub fn list_contains(list: Expression, value: Expression) -> Expression {
146 ListContains.new_expr((), [list, value])
147}
148
149impl ExpressionView<'_, ListContains> {
150 pub fn list(&self) -> &Expression {
151 &self.children()[0]
152 }
153
154 pub fn needle(&self) -> &Expression {
155 &self.children()[1]
156 }
157}
158
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_buffer::BitBuffer;
    use vortex_dtype::DType;
    use vortex_dtype::Field;
    use vortex_dtype::FieldPath;
    use vortex_dtype::FieldPathSet;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_dtype::StructFields;
    use vortex_scalar::Scalar;
    use vortex_utils::aliases::hash_map::HashMap;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::list_contains;
    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::expr::exprs::binary::and;
    use crate::expr::exprs::binary::gt;
    use crate::expr::exprs::binary::lt;
    use crate::expr::exprs::binary::or;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::expr::stats::Stat;
    use crate::validity::Validity;

    /// The nullable boolean scalar that evaluation results are compared against.
    fn nullable_bool(value: bool) -> Scalar {
        Scalar::bool(value, Nullability::Nullable)
    }

    /// Builds a list array from flat `i32` elements, `i32` offsets and a validity.
    fn list_of(elements: Vec<i32>, offsets: Vec<i32>, validity: Validity) -> ArrayRef {
        let elements = PrimitiveArray::from_iter(elements).into_array();
        let offsets = PrimitiveArray::from_iter(offsets).into_array();
        ListArray::try_new(elements, offsets, validity)
            .unwrap()
            .into_array()
    }

    /// Two all-valid lists: `[1, 1, 2, 2, 2]` and `[2, 2, 3, 3, 3]`.
    fn test_array() -> ArrayRef {
        list_of(
            vec![1, 1, 2, 2, 2, 2, 2, 3, 3, 3],
            vec![0, 5, 10],
            Validity::AllValid,
        )
    }

    /// Evaluates `contains($, needle)` with `arr` as the root scope.
    fn root_contains(arr: &ArrayRef, needle: i32) -> ArrayRef {
        list_contains(root(), lit(needle)).evaluate(arr).unwrap()
    }

    #[test]
    pub fn test_one() {
        let item = root_contains(&test_array(), 1);

        assert_eq!(item.scalar_at(0), nullable_bool(true));
        assert_eq!(item.scalar_at(1), nullable_bool(false));
    }

    #[test]
    pub fn test_all() {
        let item = root_contains(&test_array(), 2);

        assert_eq!(item.scalar_at(0), nullable_bool(true));
        assert_eq!(item.scalar_at(1), nullable_bool(true));
    }

    #[test]
    pub fn test_none() {
        let item = root_contains(&test_array(), 4);

        assert_eq!(item.scalar_at(0), nullable_bool(false));
        assert_eq!(item.scalar_at(1), nullable_bool(false));
    }

    #[test]
    pub fn test_empty() {
        // The second list spans offsets 5..5, i.e. it is empty.
        let arr = list_of(vec![1, 1, 2, 2, 2], vec![0, 5, 5], Validity::AllValid);

        let item = root_contains(&arr, 2);

        assert_eq!(item.scalar_at(0), nullable_bool(true));
        assert_eq!(item.scalar_at(1), nullable_bool(false));
    }

    #[test]
    pub fn test_nullable() {
        // The second list is marked invalid by the validity array.
        let validity_bits = BitBuffer::from(vec![true, false]);
        let validity = Validity::Array(BoolArray::from(validity_bits).into_array());
        let arr = list_of(vec![1, 1, 2, 2, 2], vec![0, 5, 5], validity);

        let item = root_contains(&arr, 2);

        assert_eq!(item.scalar_at(0), nullable_bool(true));
        assert!(!item.is_valid(1));
    }

    #[test]
    pub fn test_return_type() {
        let element_dtype = Arc::new(DType::Primitive(I32, Nullability::NonNullable));
        let list_dtype = DType::List(element_dtype, Nullability::Nullable);
        let scope = DType::Struct(
            StructFields::new(["array"].into(), vec![list_dtype]),
            Nullability::NonNullable,
        );

        let expr = list_contains(get_item("array", root()), lit(2));
        let actual = expr.return_dtype(&scope).unwrap();

        assert_eq!(actual, DType::Bool(Nullability::Nullable));
    }

    #[test]
    pub fn list_falsification() {
        let haystack = Scalar::list(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            (1i32..=3).map(Scalar::from).collect(),
            Nullability::NonNullable,
        );
        let expr = list_contains(lit(haystack), col("a"));

        let available_stats = FieldPathSet::from_iter([
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("max".into())]),
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("min".into())]),
        ]);
        let (expr, st) = checked_pruning_expr(&expr, &available_stats).unwrap();

        // One clause per list element: the needle's range lies entirely outside that element.
        let outside = |bound: i32| {
            or(
                lt(col("a_max"), lit(bound)),
                gt(col("a_min"), lit(bound)),
            )
        };
        let expected = and(and(outside(1), outside(2)), outside(3));
        assert_eq!(&expr, &expected);

        let expected_stats = HashMap::from_iter([(
            FieldPath::from_name("a"),
            HashSet::from([Stat::Min, Stat::Max]),
        )]);
        assert_eq!(st.map(), &expected_stats);
    }

    #[test]
    pub fn test_display() {
        let tagged = list_contains(get_item("tags", root()), lit("urgent"));
        assert_eq!(tagged.to_string(), "contains($.tags, \"urgent\")");

        let rooted = list_contains(root(), lit(42));
        assert_eq!(rooted.to_string(), "contains($, 42i32)");
    }
}