1use std::any::Any;
2use std::fmt::{Debug, Display, Formatter};
3use std::hash::Hash;
4use std::sync::Arc;
5
6use vortex_array::ArrayRef;
7use vortex_array::compute::list_contains as compute_list_contains;
8use vortex_dtype::DType;
9use vortex_error::VortexResult;
10
11use crate::{
12 AnalysisExpr, ExprRef, Literal, Scope, ScopeDType, StatsCatalog, VortexExpr, and, gt, lit, lt,
13 or,
14};
15
16#[derive(Debug, Clone, Eq, Hash)]
17#[allow(clippy::derived_hash_with_manual_eq)]
18pub struct ListContains {
19 list: ExprRef,
20 value: ExprRef,
21}
22
23impl ListContains {
24 pub fn new_expr(list: ExprRef, value: ExprRef) -> ExprRef {
25 Arc::new(Self { list, value })
26 }
27
28 pub fn value(&self) -> &ExprRef {
29 &self.value
30 }
31}
32
33pub fn list_contains(list: ExprRef, value: ExprRef) -> ExprRef {
34 ListContains::new_expr(list, value)
35}
36
37impl Display for ListContains {
38 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
39 write!(f, "contains({}, {})", &self.list, &self.value)
40 }
41}
42
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind;
    use vortex_proto::expr::kind::Kind;

    use crate::list_contains::ListContains;
    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id};

    /// (De)serialization hooks for [`ListContains`], identified by `"list_contains"`.
    pub(crate) struct ListContainsSerde;

    impl Id for ListContainsSerde {
        fn id(&self) -> &'static str {
            "list_contains"
        }
    }

    impl ExprDeserialize for ListContainsSerde {
        /// Rebuilds a [`ListContains`] expression from its proto `kind` and its
        /// already-deserialized `children`.
        ///
        /// `children` is expected to be `[list, value]`, in that order.
        ///
        /// # Errors
        ///
        /// Returns an error if `kind` is not `Kind::ListContains`, or if `children`
        /// does not contain exactly two expressions.
        fn deserialize(&self, kind: &Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            let Kind::ListContains(kind::ListContains {}) = kind else {
                vortex_bail!("wrong kind {:?}, want list_contains", kind)
            };

            // Serialized expressions come from outside this process; report a malformed
            // child list as an error rather than panicking on out-of-bounds indexing.
            if children.len() != 2 {
                vortex_bail!(
                    "list_contains expects exactly 2 children (list, value), got {}",
                    children.len()
                );
            }

            Ok(ListContains::new_expr(
                children[0].clone(),
                children[1].clone(),
            ))
        }
    }

    impl ExprSerializable for ListContains {
        fn id(&self) -> &'static str {
            ListContainsSerde.id()
        }

        /// `ListContains` carries no options of its own; its operands are serialized
        /// as children.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            Ok(Kind::ListContains(kind::ListContains {}))
        }
    }
}
83
84impl AnalysisExpr for ListContains {
85 fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
89 let min = self.list.min(catalog)?;
90 let max = self.list.max(catalog)?;
91 if min == max {
93 let list_ = min
94 .as_any()
95 .downcast_ref::<Literal>()
96 .and_then(|l| l.value().as_list_opt())
97 .and_then(|l| l.elements())?;
98 if list_.is_empty() {
99 return Some(lit(true));
101 }
102 let value_max = self.value.max(catalog)?;
103 let value_min = self.value.min(catalog)?;
104
105 return list_
106 .iter()
107 .map(move |v| {
108 or(
109 lt(value_max.clone(), lit(v.clone())),
110 gt(value_min.clone(), lit(v.clone())),
111 )
112 })
113 .reduce(and);
114 }
115
116 None
117 }
118}
119
120impl VortexExpr for ListContains {
121 fn as_any(&self) -> &dyn Any {
122 self
123 }
124
125 fn unchecked_evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
126 compute_list_contains(
127 self.list.evaluate(scope)?.as_ref(),
128 self.value.evaluate(scope)?.as_ref(),
129 )
130 }
131
132 fn children(&self) -> Vec<&ExprRef> {
133 vec![&self.list, &self.value]
134 }
135
136 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
137 assert_eq!(children.len(), 2);
138 Self::new_expr(children[0].clone(), children[1].clone())
139 }
140
141 fn return_dtype(&self, scope_dtype: &ScopeDType) -> VortexResult<DType> {
142 Ok(DType::Bool(
143 self.list.return_dtype(scope_dtype)?.nullability()
144 | self.value.return_dtype(scope_dtype)?.nullability(),
145 ))
146 }
147}
148
149impl PartialEq for ListContains {
150 fn eq(&self, other: &ListContains) -> bool {
151 self.value.eq(&other.value) && self.list.eq(&other.list)
152 }
153}
154
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{BoolArray, BooleanBuffer, ListArray, PrimitiveArray};
    use vortex_array::stats::Stat;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray};
    use vortex_dtype::PType::I32;
    use vortex_dtype::{Field, FieldPath, FieldPathSet, Nullability, StructFields};
    use vortex_scalar::Scalar;
    use vortex_utils::aliases::hash_map::HashMap;

    use crate::list_contains::list_contains;
    use crate::pruning::checked_pruning_expr;
    use crate::{
        AccessPath, Arc, DType, HashSet, Scope, ScopeDType, ScopeFieldPathSet, and, get_item,
        get_item_scope, gt, lit, lt, or, root,
    };

    /// Builds a list array whose rows are `elements` split at `offsets`.
    fn list_array(elements: Vec<i32>, offsets: Vec<i32>, validity: Validity) -> ArrayRef {
        ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            validity,
        )
        .unwrap()
        .into_array()
    }

    /// Two valid rows: `[1, 1, 2, 2, 2]` and `[2, 2, 3, 3, 3]`.
    fn test_array() -> ArrayRef {
        list_array(
            vec![1, 1, 2, 2, 2, 2, 2, 3, 3, 3],
            vec![0, 5, 10],
            Validity::AllValid,
        )
    }

    /// Evaluates `list_contains(root(), lit(needle))` over `array`.
    fn evaluate_contains(array: ArrayRef, needle: i32) -> ArrayRef {
        list_contains(root(), lit(needle))
            .evaluate(&Scope::new(array))
            .unwrap()
    }

    /// Asserts that `result[index]` is the nullable boolean scalar `expected`.
    fn assert_bool_at(result: &ArrayRef, index: usize, expected: bool) {
        assert_eq!(
            result.scalar_at(index).unwrap(),
            Scalar::bool(expected, Nullability::Nullable)
        );
    }

    #[test]
    pub fn test_one() {
        let result = evaluate_contains(test_array(), 1);
        assert_bool_at(&result, 0, true);
        assert_bool_at(&result, 1, false);
    }

    #[test]
    pub fn test_all() {
        let result = evaluate_contains(test_array(), 2);
        assert_bool_at(&result, 0, true);
        assert_bool_at(&result, 1, true);
    }

    #[test]
    pub fn test_none() {
        let result = evaluate_contains(test_array(), 4);
        assert_bool_at(&result, 0, false);
        assert_bool_at(&result, 1, false);
    }

    #[test]
    pub fn test_empty() {
        // Second row is an empty list.
        let array = list_array(vec![1, 1, 2, 2, 2], vec![0, 5, 5], Validity::AllValid);
        let result = evaluate_contains(array, 2);
        assert_bool_at(&result, 0, true);
        assert_bool_at(&result, 1, false);
    }

    #[test]
    pub fn test_nullable() {
        // Second row is null.
        let validity =
            Validity::Array(BoolArray::from(BooleanBuffer::from(vec![true, false])).into_array());
        let array = list_array(vec![1, 1, 2, 2, 2], vec![0, 5, 5], validity);
        let result = evaluate_contains(array, 2);
        assert_bool_at(&result, 0, true);
        assert!(!result.is_valid(1).unwrap());
    }

    #[test]
    pub fn test_return_type() {
        let element_dtype = Arc::new(DType::Primitive(I32, Nullability::NonNullable));
        let list_dtype = DType::List(element_dtype, Nullability::Nullable);
        let scope = ScopeDType::new(DType::Struct(
            StructFields::new(["array"].into(), vec![list_dtype]),
            Nullability::NonNullable,
        ));

        let expr = list_contains(get_item("array", root()), lit(2));

        assert_eq!(
            expr.return_dtype(&scope).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }

    #[test]
    pub fn list_falsification() {
        let expr = list_contains(
            lit(Scalar::list(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                vec![1.into(), 2.into(), 3.into()],
                Nullability::NonNullable,
            )),
            get_item_scope("a"),
        );

        let available_stats = ScopeFieldPathSet::new(FieldPathSet::from_iter([
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("max".into())]),
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("min".into())]),
        ]));
        let (pruning, stats) = checked_pruning_expr(&expr, &available_stats).unwrap();

        // One "value range misses element" conjunct per list element, folded left.
        let misses = |element: i32| {
            or(
                lt(get_item_scope("a_max"), lit(element)),
                gt(get_item_scope("a_min"), lit(element)),
            )
        };
        assert_eq!(&pruning, &and(and(misses(1), misses(2)), misses(3)));

        assert_eq!(
            stats.map(),
            &HashMap::from_iter([(
                AccessPath::root_field("a".into()),
                HashSet::from([Stat::Min, Stat::Max])
            )])
        );
    }
}