1use std::any::Any;
2use std::fmt::{Debug, Display, Formatter};
3use std::hash::Hash;
4use std::sync::Arc;
5
6use vortex_array::ArrayRef;
7use vortex_array::compute::list_contains as compute_list_contains;
8use vortex_dtype::DType;
9use vortex_error::VortexResult;
10
11use crate::{AnalysisExpr, ExprRef, Scope, ScopeDType, VortexExpr};
12
13#[derive(Debug, Clone, Eq, Hash)]
14#[allow(clippy::derived_hash_with_manual_eq)]
15pub struct ListContains {
16 list: ExprRef,
17 value: ExprRef,
18}
19
20impl ListContains {
21 pub fn new_expr(list: ExprRef, value: ExprRef) -> ExprRef {
22 Arc::new(Self { list, value })
23 }
24
25 pub fn value(&self) -> &ExprRef {
26 &self.value
27 }
28}
29
30pub fn list_contains(list: ExprRef, value: ExprRef) -> ExprRef {
31 ListContains::new_expr(list, value)
32}
33
34impl Display for ListContains {
35 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
36 write!(f, "contains({}, {})", &self.list, &self.value)
37 }
38}
39
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind;
    use vortex_proto::expr::kind::Kind;

    use crate::list_contains::ListContains;
    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id};

    /// Serde handle for [`ListContains`], registered under the id `"list_contains"`.
    pub(crate) struct ListContainsSerde;

    impl Id for ListContainsSerde {
        fn id(&self) -> &'static str {
            "list_contains"
        }
    }

    impl ExprDeserialize for ListContainsSerde {
        /// Rebuilds a [`ListContains`] from its serialized kind and its children,
        /// which must be exactly `[list, value]` in that order.
        ///
        /// # Errors
        ///
        /// Fails if `kind` is not `Kind::ListContains`, or if `children` does not hold
        /// exactly two expressions.
        fn deserialize(&self, kind: &Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            let Kind::ListContains(kind::ListContains {}) = kind else {
                vortex_bail!("wrong kind {:?}, want list_contains", kind)
            };

            // Serialized expressions come from outside the process: report a malformed
            // child list as an error instead of panicking on an out-of-bounds index.
            // Destructuring also moves the children out rather than cloning them.
            let [list, value] = match <[ExprRef; 2]>::try_from(children) {
                Ok(pair) => pair,
                Err(children) => {
                    vortex_bail!(
                        "list_contains expects 2 children, got {}",
                        children.len()
                    )
                }
            };

            Ok(ListContains::new_expr(list, value))
        }
    }

    impl ExprSerializable for ListContains {
        fn id(&self) -> &'static str {
            ListContainsSerde.id()
        }

        /// `ListContains` carries no payload beyond its children, so the kind is empty.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            Ok(Kind::ListContains(kind::ListContains {}))
        }
    }
}
80
81impl AnalysisExpr for ListContains {}
82
83impl VortexExpr for ListContains {
84 fn as_any(&self) -> &dyn Any {
85 self
86 }
87
88 fn unchecked_evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
89 compute_list_contains(
90 self.list.evaluate(scope)?.as_ref(),
91 self.value.evaluate(scope)?.as_ref(),
92 )
93 }
94
95 fn children(&self) -> Vec<&ExprRef> {
96 vec![&self.list, &self.value]
97 }
98
99 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
100 assert_eq!(children.len(), 2);
101 Self::new_expr(children[0].clone(), children[1].clone())
102 }
103
104 fn return_dtype(&self, scope_dtype: &ScopeDType) -> VortexResult<DType> {
105 Ok(DType::Bool(
106 self.list.return_dtype(scope_dtype)?.nullability()
107 | self.value.return_dtype(scope_dtype)?.nullability(),
108 ))
109 }
110}
111
112impl PartialEq for ListContains {
113 fn eq(&self, other: &ListContains) -> bool {
114 self.value.eq(&other.value) && self.list.eq(&other.list)
115 }
116}
117
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{BoolArray, BooleanBuffer, ListArray, PrimitiveArray};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray};
    use vortex_dtype::{FieldNames, Nullability, StructFields};
    use vortex_scalar::Scalar;

    use crate::list_contains::list_contains;
    use crate::{Arc, DType, Scope, ScopeDType, get_item, lit, root};

    /// Builds a list array from flat `elements`, list boundaries `offsets`, and the
    /// list-level `validity`.
    fn build_list(elements: Vec<i32>, offsets: Vec<i32>, validity: Validity) -> ArrayRef {
        let elements = PrimitiveArray::from_iter(elements).into_array();
        let offsets = PrimitiveArray::from_iter(offsets).into_array();
        ListArray::try_new(elements, offsets, validity)
            .unwrap()
            .into_array()
    }

    /// Two valid lists: `[1, 1, 2, 2, 2]` and `[2, 2, 3, 3, 3]`.
    fn test_array() -> ArrayRef {
        build_list(
            vec![1, 1, 2, 2, 2, 2, 2, 3, 3, 3],
            vec![0, 5, 10],
            Validity::AllValid,
        )
    }

    /// Evaluates `contains(root(), needle)` with `array` as the scope root.
    fn eval_contains(array: ArrayRef, needle: i32) -> ArrayRef {
        list_contains(root(), lit(needle))
            .evaluate(&Scope::new(array))
            .unwrap()
    }

    /// Asserts that row `index` of `result` is the nullable boolean `expected`.
    fn assert_contains_at(result: &ArrayRef, index: usize, expected: bool) {
        assert_eq!(
            result.scalar_at(index).unwrap(),
            Scalar::bool(expected, Nullability::Nullable)
        );
    }

    #[test]
    fn test_one() {
        let result = eval_contains(test_array(), 1);
        assert_contains_at(&result, 0, true);
        assert_contains_at(&result, 1, false);
    }

    #[test]
    fn test_all() {
        let result = eval_contains(test_array(), 2);
        assert_contains_at(&result, 0, true);
        assert_contains_at(&result, 1, true);
    }

    #[test]
    fn test_none() {
        let result = eval_contains(test_array(), 4);
        assert_contains_at(&result, 0, false);
        assert_contains_at(&result, 1, false);
    }

    #[test]
    fn test_empty() {
        // Second list spans offsets 5..5, i.e. it is empty.
        let array = build_list(vec![1, 1, 2, 2, 2], vec![0, 5, 5], Validity::AllValid);
        let result = eval_contains(array, 2);
        assert_contains_at(&result, 0, true);
        assert_contains_at(&result, 1, false);
    }

    #[test]
    fn test_nullable() {
        // Second list is marked invalid, so its result must be null.
        let validity = BoolArray::from(BooleanBuffer::from(vec![true, false])).into_array();
        let array = build_list(vec![1, 1, 2, 2, 2], vec![0, 5, 5], Validity::Array(validity));
        let result = eval_contains(array, 2);
        assert_contains_at(&result, 0, true);
        assert!(!result.is_valid(1).unwrap());
    }

    #[test]
    fn test_return_type() {
        let element = DType::Primitive(vortex_dtype::PType::I32, Nullability::NonNullable);
        let list = DType::List(Arc::new(element), Nullability::Nullable);
        let fields = StructFields::new(FieldNames::from(["array".into()]), vec![list]);
        let scope = ScopeDType::new(DType::Struct(Arc::new(fields), Nullability::NonNullable));

        let expr = list_contains(get_item("array", root()), lit(2));

        assert_eq!(
            expr.return_dtype(&scope).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }
}