1use std::fmt::Formatter;
5
6use vortex_dtype::DType;
7use vortex_error::{VortexResult, vortex_bail};
8
9use crate::ArrayRef;
10use crate::compute::list_contains as compute_list_contains;
11use crate::expr::exprs::binary::{and, gt, lt, or};
12use crate::expr::exprs::literal::{Literal, lit};
13use crate::expr::{ChildName, ExprId, Expression, ExpressionView, StatsCatalog, VTable, VTableExt};
14
/// Expression vtable registered under the id `vortex.list.contains`.
///
/// The expression has exactly two children: the list (child 0) and the needle (child 1).
pub struct ListContains;
16
17impl VTable for ListContains {
18 type Instance = ();
19
20 fn id(&self) -> ExprId {
21 ExprId::from("vortex.list.contains")
22 }
23
24 fn serialize(&self, _instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
25 Ok(Some(vec![]))
26 }
27
28 fn deserialize(&self, _metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
29 Ok(Some(()))
30 }
31
32 fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
33 if expr.children().len() != 2 {
34 vortex_bail!(
35 "ListContains expression requires exactly 2 children, got {}",
36 expr.children().len()
37 );
38 }
39 Ok(())
40 }
41
42 fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
43 match child_idx {
44 0 => ChildName::from("list"),
45 1 => ChildName::from("needle"),
46 _ => unreachable!(
47 "Invalid child index {} for ListContains expression",
48 child_idx
49 ),
50 }
51 }
52
53 fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
54 write!(f, "contains(")?;
55 expr.child(0).fmt_sql(f)?;
56 write!(f, ", ")?;
57 expr.child(1).fmt_sql(f)?;
58 write!(f, ")")
59 }
60
61 fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
62 let list_dtype = expr.child(0).return_dtype(scope)?;
63 let value_dtype = expr.child(1).return_dtype(scope)?;
64
65 let nullability = match list_dtype {
66 DType::List(_, list_nullability) => list_nullability,
67 _ => {
68 vortex_bail!(
69 "First argument to ListContains must be a List, got {:?}",
70 list_dtype
71 );
72 }
73 } | value_dtype.nullability();
74
75 Ok(DType::Bool(nullability))
76 }
77
78 fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
79 let list_array = expr.child(0).evaluate(scope)?;
80 let value_array = expr.child(1).evaluate(scope)?;
81 compute_list_contains(list_array.as_ref(), value_array.as_ref())
82 }
83
84 fn stat_falsification(
85 &self,
86 expr: &ExpressionView<Self>,
87 catalog: &mut dyn StatsCatalog,
88 ) -> Option<Expression> {
89 let min = expr.list().stat_min(catalog)?;
92 let max = expr.list().stat_max(catalog)?;
93 if min == max {
95 let list_ = min
96 .as_opt::<Literal>()
97 .and_then(|l| l.data().as_list_opt())
98 .and_then(|l| l.elements())?;
99 if list_.is_empty() {
100 return Some(lit(true));
102 }
103 let value_max = expr.needle().stat_max(catalog)?;
104 let value_min = expr.needle().stat_min(catalog)?;
105
106 return list_
107 .iter()
108 .map(move |v| {
109 or(
110 lt(value_max.clone(), lit(v.clone())),
111 gt(value_min.clone(), lit(v.clone())),
112 )
113 })
114 .reduce(and);
115 }
116
117 None
118 }
119}
120
121pub fn list_contains(list: Expression, value: Expression) -> Expression {
130 ListContains.new_expr((), [list, value])
131}
132
133impl ExpressionView<'_, ListContains> {
134 pub fn list(&self) -> &Expression {
135 &self.children()[0]
136 }
137
138 pub fn needle(&self) -> &Expression {
139 &self.children()[1]
140 }
141}
142
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_buffer::BitBuffer;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Field, FieldPath, FieldPathSet, Nullability, StructFields};
    use vortex_scalar::Scalar;
    use vortex_utils::aliases::hash_map::HashMap;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::list_contains;
    use crate::arrays::{BoolArray, ListArray, PrimitiveArray};
    use crate::expr::exprs::binary::{and, gt, lt, or};
    use crate::expr::exprs::get_item::{col, get_item};
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::stats::Stat;
    use crate::validity::Validity;
    use crate::{Array, ArrayRef, IntoArray};

    /// Shorthand for the nullable boolean scalar the assertions compare against.
    fn nullable_bool(value: bool) -> Scalar {
        Scalar::bool(value, Nullability::Nullable)
    }

    /// Two all-valid lists: `[1, 1, 2, 2, 2]` and `[2, 2, 3, 3, 3]`.
    fn test_array() -> ArrayRef {
        let elements = PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2, 2, 2, 3, 3, 3]).into_array();
        let offsets = PrimitiveArray::from_iter(vec![0, 5, 10]).into_array();
        ListArray::try_new(elements, offsets, Validity::AllValid)
            .unwrap()
            .into_array()
    }

    /// The needle appears in the first list only.
    #[test]
    pub fn test_one() {
        let lists = test_array();

        let contains = list_contains(root(), lit(1));
        let result = contains.evaluate(&lists).unwrap();

        assert_eq!(result.scalar_at(0), nullable_bool(true));
        assert_eq!(result.scalar_at(1), nullable_bool(false));
    }

    /// The needle appears in both lists.
    #[test]
    pub fn test_all() {
        let lists = test_array();

        let contains = list_contains(root(), lit(2));
        let result = contains.evaluate(&lists).unwrap();

        assert_eq!(result.scalar_at(0), nullable_bool(true));
        assert_eq!(result.scalar_at(1), nullable_bool(true));
    }

    /// The needle appears in neither list.
    #[test]
    pub fn test_none() {
        let lists = test_array();

        let contains = list_contains(root(), lit(4));
        let result = contains.evaluate(&lists).unwrap();

        assert_eq!(result.scalar_at(0), nullable_bool(false));
        assert_eq!(result.scalar_at(1), nullable_bool(false));
    }

    /// The second list is empty (offsets `5..5`) and therefore never contains the needle.
    #[test]
    pub fn test_empty() {
        let elements = PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2]).into_array();
        let offsets = PrimitiveArray::from_iter(vec![0, 5, 5]).into_array();
        let lists = ListArray::try_new(elements, offsets, Validity::AllValid)
            .unwrap()
            .into_array();

        let contains = list_contains(root(), lit(2));
        let result = contains.evaluate(&lists).unwrap();

        assert_eq!(result.scalar_at(0), nullable_bool(true));
        assert_eq!(result.scalar_at(1), nullable_bool(false));
    }

    /// The second list is marked invalid, so its result must be invalid as well.
    #[test]
    pub fn test_nullable() {
        let elements = PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2]).into_array();
        let offsets = PrimitiveArray::from_iter(vec![0, 5, 5]).into_array();
        let validity =
            Validity::Array(BoolArray::from(BitBuffer::from(vec![true, false])).into_array());
        let lists = ListArray::try_new(elements, offsets, validity)
            .unwrap()
            .into_array();

        let contains = list_contains(root(), lit(2));
        let result = contains.evaluate(&lists).unwrap();

        assert_eq!(result.scalar_at(0), nullable_bool(true));
        assert!(!result.is_valid(1));
    }

    /// A nullable list column yields a nullable boolean, even inside a non-nullable struct
    /// scope.
    #[test]
    pub fn test_return_type() {
        let element_dtype = Arc::new(DType::Primitive(I32, Nullability::NonNullable));
        let list_dtype = DType::List(element_dtype, Nullability::Nullable);
        let scope = DType::Struct(
            StructFields::new(["array"].into(), vec![list_dtype]),
            Nullability::NonNullable,
        );

        let contains = list_contains(get_item("array", root()), lit(2));

        assert_eq!(
            contains.return_dtype(&scope).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }

    /// A literal list searched for column `a` prunes on `a`'s min/max statistics.
    #[test]
    pub fn list_falsification() {
        let haystack = lit(Scalar::list(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            vec![1.into(), 2.into(), 3.into()],
            Nullability::NonNullable,
        ));
        let contains = list_contains(haystack, col("a"));

        let available_stats = FieldPathSet::from_iter([
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("max".into())]),
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("min".into())]),
        ]);
        let (pruning, required_stats) = checked_pruning_expr(&contains, &available_stats).unwrap();

        // One clause per list element: the element lies outside the [a_min, a_max] range.
        let outside_range = |v: i32| or(lt(col("a_max"), lit(v)), gt(col("a_min"), lit(v)));
        let expected = and(
            and(outside_range(1), outside_range(2)),
            outside_range(3),
        );
        assert_eq!(&pruning, &expected);

        assert_eq!(
            required_stats.map(),
            &HashMap::from_iter([(
                FieldPath::from_name("a"),
                HashSet::from([Stat::Min, Stat::Max])
            )])
        );
    }

    /// Checks the string rendering of the expression.
    #[test]
    pub fn test_display() {
        let tagged = list_contains(get_item("tags", root()), lit("urgent"));
        assert_eq!(tagged.to_string(), "contains($.tags, \"urgent\")");

        let rooted = list_contains(root(), lit(42));
        assert_eq!(rooted.to_string(), "contains($, 42i32)");
    }
}