1use std::fmt::{Debug, Display, Formatter};
5use std::hash::Hash;
6
7use vortex_array::compute::list_contains as compute_list_contains;
8use vortex_array::{ArrayRef, DeserializeMetadata, EmptyMetadata};
9use vortex_dtype::DType;
10use vortex_error::{VortexResult, vortex_bail};
11
12use crate::{
13 AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, LiteralVTable, Scope, StatsCatalog,
14 VTable, and, gt, lit, lt, or, vtable,
15};
16
// Declares the `ListContains` expression kind for the crate's `vtable!` machinery; the
// `VTable for ListContainsVTable` impl further down supplies its behaviour.
// NOTE(review): the exact items generated by `vtable!` live in the macro definition — confirm there.
vtable!(ListContains);

/// Expression `contains(list, value)`: evaluates `list` and `value` and tests whether each list
/// contains the value, delegating the actual work to `vortex_array::compute::list_contains`.
// `Hash` is derived while `PartialEq` is hand-written (it compares the same two fields a derive
// would), which is exactly what this clippy lint flags. NOTE(review): the manual impl is
// presumably a workaround for `#[derive(PartialEq)]` not compiling with `ExprRef` fields — confirm.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, Hash, Eq)]
pub struct ListContainsExpr {
    /// Operand that evaluates to the list(s) being searched.
    list: ExprRef,
    /// Operand that evaluates to the value being searched for.
    value: ExprRef,
}
25
26impl PartialEq for ListContainsExpr {
27 fn eq(&self, other: &Self) -> bool {
28 self.list.eq(&other.list) && self.value.eq(&other.value)
29 }
30}
31
/// Encoding marker for [`ListContainsExpr`]; its expression id is `"list_contains"`.
pub struct ListContainsExprEncoding;
33
34impl VTable for ListContainsVTable {
35 type Expr = ListContainsExpr;
36 type Encoding = ListContainsExprEncoding;
37 type Metadata = EmptyMetadata;
38
39 fn id(_encoding: &Self::Encoding) -> ExprId {
40 ExprId::new_ref("list_contains")
41 }
42
43 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
44 ExprEncodingRef::new_ref(ListContainsExprEncoding.as_ref())
45 }
46
47 fn metadata(_expr: &Self::Expr) -> Option<Self::Metadata> {
48 Some(EmptyMetadata)
49 }
50
51 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
52 vec![&expr.list, &expr.value]
53 }
54
55 fn with_children(_expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
56 Ok(ListContainsExpr::new(
57 children[0].clone(),
58 children[1].clone(),
59 ))
60 }
61
62 fn build(
63 _encoding: &Self::Encoding,
64 _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
65 children: Vec<ExprRef>,
66 ) -> VortexResult<Self::Expr> {
67 if children.len() != 2 {
68 vortex_bail!(
69 "ListContains expression must have exactly 2 children, got {}",
70 children.len()
71 );
72 }
73 Ok(ListContainsExpr::new(
74 children[0].clone(),
75 children[1].clone(),
76 ))
77 }
78
79 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
80 compute_list_contains(
81 expr.list.evaluate(scope)?.as_ref(),
82 expr.value.evaluate(scope)?.as_ref(),
83 )
84 }
85
86 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
87 Ok(DType::Bool(
88 expr.list.return_dtype(scope)?.nullability()
89 | expr.value.return_dtype(scope)?.nullability(),
90 ))
91 }
92}
93
94impl ListContainsExpr {
95 pub fn new(list: ExprRef, value: ExprRef) -> Self {
96 Self { list, value }
97 }
98
99 pub fn new_expr(list: ExprRef, value: ExprRef) -> ExprRef {
100 Self::new(list, value).into_expr()
101 }
102
103 pub fn value(&self) -> &ExprRef {
104 &self.value
105 }
106}
107
108pub fn list_contains(list: ExprRef, value: ExprRef) -> ExprRef {
117 ListContainsExpr::new(list, value).into_expr()
118}
119
120impl Display for ListContainsExpr {
121 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
122 write!(f, "contains({}, {})", &self.list, &self.value)
123 }
124}
125
126impl AnalysisExpr for ListContainsExpr {
127 fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
131 let min = self.list.min(catalog)?;
132 let max = self.list.max(catalog)?;
133 if min == max {
135 let list_ = min
136 .as_opt::<LiteralVTable>()
137 .and_then(|l| l.value().as_list_opt())
138 .and_then(|l| l.elements())?;
139 if list_.is_empty() {
140 return Some(lit(true));
142 }
143 let value_max = self.value.max(catalog)?;
144 let value_min = self.value.min(catalog)?;
145
146 return list_
147 .iter()
148 .map(move |v| {
149 or(
150 lt(value_max.clone(), lit(v.clone())),
151 gt(value_min.clone(), lit(v.clone())),
152 )
153 })
154 .reduce(and);
155 }
156
157 None
158 }
159}
160
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{BoolArray, BooleanBuffer, ListArray, PrimitiveArray};
    use vortex_array::stats::Stat;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray};
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Field, FieldPath, FieldPathSet, Nullability, StructFields};
    use vortex_scalar::Scalar;
    use vortex_utils::aliases::hash_map::HashMap;

    use crate::list_contains::list_contains;
    use crate::pruning::checked_pruning_expr;
    use crate::{Arc, HashSet, Scope, and, col, get_item, gt, lit, lt, or, root};

    /// Two all-valid lists built from offsets `[0, 5, 10]`: `[1, 1, 2, 2, 2]` and `[2, 2, 3, 3, 3]`.
    fn test_array() -> ArrayRef {
        ListArray::try_new(
            PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2, 2, 2, 3, 3, 3]).into_array(),
            PrimitiveArray::from_iter(vec![0, 5, 10]).into_array(),
            Validity::AllValid,
        )
        .unwrap()
        .into_array()
    }

    // `1` appears only in the first list.
    #[test]
    pub fn test_one() {
        let arr = test_array();

        let expr = list_contains(root(), lit(1));
        let item = expr.evaluate(&Scope::new(arr)).unwrap();

        assert_eq!(
            item.scalar_at(0).unwrap(),
            Scalar::bool(true, Nullability::Nullable)
        );
        assert_eq!(
            item.scalar_at(1).unwrap(),
            Scalar::bool(false, Nullability::Nullable)
        );
    }

    // `2` appears in both lists.
    #[test]
    pub fn test_all() {
        let arr = test_array();

        let expr = list_contains(root(), lit(2));
        let item = expr.evaluate(&Scope::new(arr)).unwrap();

        assert_eq!(
            item.scalar_at(0).unwrap(),
            Scalar::bool(true, Nullability::Nullable)
        );
        assert_eq!(
            item.scalar_at(1).unwrap(),
            Scalar::bool(true, Nullability::Nullable)
        );
    }

    // `4` appears in neither list.
    #[test]
    pub fn test_none() {
        let arr = test_array();

        let expr = list_contains(root(), lit(4));
        let item = expr.evaluate(&Scope::new(arr)).unwrap();

        assert_eq!(
            item.scalar_at(0).unwrap(),
            Scalar::bool(false, Nullability::Nullable)
        );
        assert_eq!(
            item.scalar_at(1).unwrap(),
            Scalar::bool(false, Nullability::Nullable)
        );
    }

    // Offsets `[0, 5, 5]` make the second list empty; an empty list contains nothing.
    #[test]
    pub fn test_empty() {
        let arr = ListArray::try_new(
            PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2]).into_array(),
            PrimitiveArray::from_iter(vec![0, 5, 5]).into_array(),
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        let expr = list_contains(root(), lit(2));
        let item = expr.evaluate(&Scope::new(arr)).unwrap();

        assert_eq!(
            item.scalar_at(0).unwrap(),
            Scalar::bool(true, Nullability::Nullable)
        );
        assert_eq!(
            item.scalar_at(1).unwrap(),
            Scalar::bool(false, Nullability::Nullable)
        );
    }

    // The second list is null (validity `false`), so its result must be null as well.
    #[test]
    pub fn test_nullable() {
        let arr = ListArray::try_new(
            PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2]).into_array(),
            PrimitiveArray::from_iter(vec![0, 5, 5]).into_array(),
            Validity::Array(BoolArray::from(BooleanBuffer::from(vec![true, false])).into_array()),
        )
        .unwrap()
        .into_array();

        let expr = list_contains(root(), lit(2));
        let item = expr.evaluate(&Scope::new(arr)).unwrap();

        assert_eq!(
            item.scalar_at(0).unwrap(),
            Scalar::bool(true, Nullability::Nullable)
        );
        assert!(!item.is_valid(1).unwrap());
    }

    // A nullable list operand yields a nullable boolean result dtype.
    #[test]
    pub fn test_return_type() {
        let scope = DType::Struct(
            StructFields::new(
                ["array"].into(),
                vec![DType::List(
                    Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                    Nullability::Nullable,
                )],
            ),
            Nullability::NonNullable,
        );

        let expr = list_contains(get_item("array", root()), lit(2));

        assert_eq!(
            expr.return_dtype(&scope).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }

    // Literal list `[1, 2, 3]` against column `a`: the pruning predicate is a left-nested
    // conjunction of per-element "a's range excludes this element" checks, and only `a`'s
    // min/max statistics are requested.
    #[test]
    pub fn list_falsification() {
        let expr = list_contains(
            lit(Scalar::list(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                vec![1.into(), 2.into(), 3.into()],
                Nullability::NonNullable,
            )),
            col("a"),
        );

        let (expr, st) = checked_pruning_expr(
            &expr,
            &FieldPathSet::from_iter([
                FieldPath::from_iter([Field::Name("a".into()), Field::Name("max".into())]),
                FieldPath::from_iter([Field::Name("a".into()), Field::Name("min".into())]),
            ]),
        )
        .unwrap();

        assert_eq!(
            &expr,
            &and(
                and(
                    or(lt(col("a_max"), lit(1i32)), gt(col("a_min"), lit(1i32)),),
                    or(lt(col("a_max"), lit(2i32)), gt(col("a_min"), lit(2i32)),)
                ),
                or(lt(col("a_max"), lit(3i32)), gt(col("a_min"), lit(3i32)),)
            )
        );

        assert_eq!(
            st.map(),
            &HashMap::from_iter([(
                FieldPath::from_name("a"),
                HashSet::from([Stat::Min, Stat::Max])
            )])
        );
    }
}