1use std::fmt::{Debug, Formatter};
5use std::hash::Hash;
6
7use vortex_array::compute::list_contains as compute_list_contains;
8use vortex_array::{ArrayRef, DeserializeMetadata, EmptyMetadata};
9use vortex_dtype::DType;
10use vortex_error::{VortexResult, vortex_bail};
11
12use crate::display::{DisplayAs, DisplayFormat};
13use crate::{
14 AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, LiteralVTable, Scope, StatsCatalog,
15 VTable, and, gt, lit, lt, or, vtable,
16};
17
18vtable!(ListContains);
19
/// Expression node with two child expressions: a `list` operand and a `value`
/// operand that is looked up in that list.
///
/// Evaluation hands both evaluated children to the `list_contains` compute
/// function and yields a boolean result.
// `Hash` is derived while `PartialEq` is written by hand further down in this
// file; both cover exactly the `list` and `value` fields, so the clippy lint
// about mixing a derived `Hash` with a manual `Eq` is silenced here.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, Hash, Eq)]
pub struct ListContainsExpr {
    /// Child expression producing the list operand.
    list: ExprRef,
    /// Child expression producing the value searched for in the list.
    value: ExprRef,
}
26
27impl PartialEq for ListContainsExpr {
28 fn eq(&self, other: &Self) -> bool {
29 self.list.eq(&other.list) && self.value.eq(&other.value)
30 }
31}
32
/// Unit type used as the `Encoding` associated type of the `VTable`
/// implementation for [`ListContainsExpr`]; it carries no data of its own.
pub struct ListContainsExprEncoding;
34
35impl VTable for ListContainsVTable {
36 type Expr = ListContainsExpr;
37 type Encoding = ListContainsExprEncoding;
38 type Metadata = EmptyMetadata;
39
40 fn id(_encoding: &Self::Encoding) -> ExprId {
41 ExprId::new_ref("list_contains")
42 }
43
44 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
45 ExprEncodingRef::new_ref(ListContainsExprEncoding.as_ref())
46 }
47
48 fn metadata(_expr: &Self::Expr) -> Option<Self::Metadata> {
49 Some(EmptyMetadata)
50 }
51
52 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
53 vec![&expr.list, &expr.value]
54 }
55
56 fn with_children(_expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
57 Ok(ListContainsExpr::new(
58 children[0].clone(),
59 children[1].clone(),
60 ))
61 }
62
63 fn build(
64 _encoding: &Self::Encoding,
65 _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
66 children: Vec<ExprRef>,
67 ) -> VortexResult<Self::Expr> {
68 if children.len() != 2 {
69 vortex_bail!(
70 "ListContains expression must have exactly 2 children, got {}",
71 children.len()
72 );
73 }
74 Ok(ListContainsExpr::new(
75 children[0].clone(),
76 children[1].clone(),
77 ))
78 }
79
80 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
81 compute_list_contains(
82 expr.list.evaluate(scope)?.as_ref(),
83 expr.value.evaluate(scope)?.as_ref(),
84 )
85 }
86
87 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
88 Ok(DType::Bool(
89 expr.list.return_dtype(scope)?.nullability()
90 | expr.value.return_dtype(scope)?.nullability(),
91 ))
92 }
93}
94
95impl ListContainsExpr {
96 pub fn new(list: ExprRef, value: ExprRef) -> Self {
97 Self { list, value }
98 }
99
100 pub fn new_expr(list: ExprRef, value: ExprRef) -> ExprRef {
101 Self::new(list, value).into_expr()
102 }
103
104 pub fn value(&self) -> &ExprRef {
105 &self.value
106 }
107}
108
109pub fn list_contains(list: ExprRef, value: ExprRef) -> ExprRef {
118 ListContainsExpr::new(list, value).into_expr()
119}
120
121impl DisplayAs for ListContainsExpr {
122 fn fmt_as(&self, df: DisplayFormat, f: &mut Formatter) -> std::fmt::Result {
123 match df {
124 DisplayFormat::Compact => {
125 write!(f, "contains({}, {})", &self.list, &self.value)
126 }
127 DisplayFormat::Tree => {
128 write!(f, "ListContains")
129 }
130 }
131 }
132
133 fn child_names(&self) -> Option<Vec<String>> {
134 Some(vec!["list".to_string(), "value".to_string()])
135 }
136}
137
138impl AnalysisExpr for ListContainsExpr {
139 fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
143 let min = self.list.min(catalog)?;
144 let max = self.list.max(catalog)?;
145 if min == max {
147 let list_ = min
148 .as_opt::<LiteralVTable>()
149 .and_then(|l| l.value().as_list_opt())
150 .and_then(|l| l.elements())?;
151 if list_.is_empty() {
152 return Some(lit(true));
154 }
155 let value_max = self.value.max(catalog)?;
156 let value_min = self.value.min(catalog)?;
157
158 return list_
159 .iter()
160 .map(move |v| {
161 or(
162 lt(value_max.clone(), lit(v.clone())),
163 gt(value_min.clone(), lit(v.clone())),
164 )
165 })
166 .reduce(and);
167 }
168
169 None
170 }
171}
172
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{BoolArray, BooleanBuffer, ListArray, PrimitiveArray};
    use vortex_array::stats::Stat;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray};
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Field, FieldPath, FieldPathSet, Nullability, StructFields};
    use vortex_scalar::Scalar;
    use vortex_utils::aliases::hash_map::HashMap;

    use crate::list_contains::list_contains;
    use crate::pruning::checked_pruning_expr;
    use crate::{Arc, HashSet, Scope, and, col, get_item, gt, lit, lt, or, root};

    /// Two-row list array: row 0 is `[1, 1, 2, 2, 2]`, row 1 is `[2, 2, 3, 3, 3]`.
    fn test_array() -> ArrayRef {
        let elements = PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2, 2, 2, 3, 3, 3]).into_array();
        let offsets = PrimitiveArray::from_iter(vec![0, 5, 10]).into_array();
        ListArray::try_new(elements, offsets, Validity::AllValid)
            .unwrap()
            .into_array()
    }

    /// Two-row list array with the given validity: row 0 is `[1, 1, 2, 2, 2]`,
    /// row 1 has no elements.
    fn array_with_empty_second_row(validity: Validity) -> ArrayRef {
        let elements = PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2]).into_array();
        let offsets = PrimitiveArray::from_iter(vec![0, 5, 5]).into_array();
        ListArray::try_new(elements, offsets, validity)
            .unwrap()
            .into_array()
    }

    /// Evaluates `list_contains(root(), lit(needle))` with `lists` as the scope.
    fn eval_contains(lists: ArrayRef, needle: i32) -> ArrayRef {
        let expr = list_contains(root(), lit(needle));
        expr.evaluate(&Scope::new(lists)).unwrap()
    }

    /// The nullable boolean scalar the evaluation is expected to produce.
    fn nullable_bool(value: bool) -> Scalar {
        Scalar::bool(value, Nullability::Nullable)
    }

    #[test]
    pub fn test_one() {
        let result = eval_contains(test_array(), 1);

        assert_eq!(result.scalar_at(0), nullable_bool(true));
        assert_eq!(result.scalar_at(1), nullable_bool(false));
    }

    #[test]
    pub fn test_all() {
        let result = eval_contains(test_array(), 2);

        assert_eq!(result.scalar_at(0), nullable_bool(true));
        assert_eq!(result.scalar_at(1), nullable_bool(true));
    }

    #[test]
    pub fn test_none() {
        let result = eval_contains(test_array(), 4);

        assert_eq!(result.scalar_at(0), nullable_bool(false));
        assert_eq!(result.scalar_at(1), nullable_bool(false));
    }

    #[test]
    pub fn test_empty() {
        let lists = array_with_empty_second_row(Validity::AllValid);
        let result = eval_contains(lists, 2);

        assert_eq!(result.scalar_at(0), nullable_bool(true));
        // An empty list contains nothing.
        assert_eq!(result.scalar_at(1), nullable_bool(false));
    }

    #[test]
    pub fn test_nullable() {
        let validity =
            Validity::Array(BoolArray::from(BooleanBuffer::from(vec![true, false])).into_array());
        let lists = array_with_empty_second_row(validity);
        let result = eval_contains(lists, 2);

        assert_eq!(result.scalar_at(0), nullable_bool(true));
        // The second list is null, so the result for it is null as well.
        assert!(!result.is_valid(1));
    }

    #[test]
    pub fn test_return_type() {
        let element_dtype = Arc::new(DType::Primitive(I32, Nullability::NonNullable));
        let list_dtype = DType::List(element_dtype, Nullability::Nullable);
        let scope = DType::Struct(
            StructFields::new(["array"].into(), vec![list_dtype]),
            Nullability::NonNullable,
        );

        let expr = list_contains(get_item("array", root()), lit(2));

        // The list field is nullable, so the result is nullable even though
        // the enclosing struct is not.
        let returned = expr.return_dtype(&scope).unwrap();
        assert_eq!(returned, DType::Bool(Nullability::Nullable));
    }

    #[test]
    pub fn list_falsification() {
        let literal_list = Scalar::list(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            vec![1.into(), 2.into(), 3.into()],
            Nullability::NonNullable,
        );
        let expr = list_contains(lit(literal_list), col("a"));

        let available_stats = FieldPathSet::from_iter([
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("max".into())]),
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("min".into())]),
        ]);
        let (pruning_expr, required_stats) =
            checked_pruning_expr(&expr, &available_stats).unwrap();

        // Clause stating that `element` lies outside the [a_min, a_max] range.
        let out_of_range = |element: i32| {
            or(
                lt(col("a_max"), lit(element)),
                gt(col("a_min"), lit(element)),
            )
        };
        let expected = and(and(out_of_range(1), out_of_range(2)), out_of_range(3));
        assert_eq!(&pruning_expr, &expected);

        let expected_stats = HashMap::from_iter([(
            FieldPath::from_name("a"),
            HashSet::from([Stat::Min, Stat::Max]),
        )]);
        assert_eq!(required_stats.map(), &expected_stats);
    }

    #[test]
    pub fn test_display() {
        let by_field = list_contains(get_item("tags", root()), lit("urgent"));
        assert_eq!(by_field.to_string(), "contains($.tags, \"urgent\")");

        let on_root = list_contains(root(), lit(42));
        assert_eq!(on_root.to_string(), "contains($, 42i32)");
    }
}