1use num_traits::AsPrimitive;
5use vortex_error::VortexResult;
6use vortex_error::vortex_bail;
7use vortex_session::VortexSession;
8use vortex_session::registry::CachedId;
9
10use crate::ArrayRef;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::array::ArrayView;
14use crate::arrays::ConstantArray;
15use crate::arrays::FixedSizeList;
16use crate::arrays::List;
17use crate::arrays::ListView;
18use crate::arrays::fixed_size_list::FixedSizeListArrayExt;
19use crate::arrays::list::ListArrayExt;
20use crate::arrays::listview::ListViewArrayExt;
21use crate::builtins::ArrayBuiltins;
22use crate::dtype::DType;
23use crate::dtype::Nullability;
24use crate::dtype::PType;
25use crate::expr::Expression;
26use crate::matcher::Matcher;
27use crate::scalar::Scalar;
28use crate::scalar_fn::Arity;
29use crate::scalar_fn::ChildName;
30use crate::scalar_fn::EmptyOptions;
31use crate::scalar_fn::ExecutionArgs;
32use crate::scalar_fn::ScalarFnId;
33use crate::scalar_fn::ScalarFnVTable;
34use crate::scalar_fn::fns::operators::Operator;
35
/// Scalar function (`vortex.list.length`) returning the number of elements in each list of
/// its single input.
///
/// Accepts inputs of `List` or `FixedSizeList` dtype (arrays encoded as `List`, `ListView`
/// or `FixedSizeList`) and produces `u64` lengths that carry the input's nullability; a null
/// list yields a null length.
#[derive(Clone)]
pub struct ListLength;
43
44impl ScalarFnVTable for ListLength {
45 type Options = EmptyOptions;
46
47 fn id(&self) -> ScalarFnId {
48 static ID: CachedId = CachedId::new("vortex.list.length");
49 *ID
50 }
51
52 fn serialize(&self, _instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
53 Ok(Some(vec![]))
54 }
55
56 fn deserialize(
57 &self,
58 _metadata: &[u8],
59 _session: &VortexSession,
60 ) -> VortexResult<Self::Options> {
61 Ok(EmptyOptions)
62 }
63
64 fn arity(&self, _options: &Self::Options) -> Arity {
65 Arity::Exact(1)
66 }
67
68 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
69 match child_idx {
70 0 => ChildName::from("input"),
71 _ => unreachable!("Invalid child index {child_idx} for list_length()"),
72 }
73 }
74
75 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
76 match &arg_dtypes[0] {
77 DType::List(_, nullable) | DType::FixedSizeList(_, _, nullable) => {
78 Ok(DType::Primitive(PType::U64, *nullable))
79 }
80 other => vortex_bail!("list_length() requires List or FixedSizeList, got {other}"),
81 }
82 }
83
84 fn execute(
85 &self,
86 _options: &Self::Options,
87 args: &dyn ExecutionArgs,
88 ctx: &mut ExecutionCtx,
89 ) -> VortexResult<ArrayRef> {
90 let input = args.get(0)?;
91 let nullability = input.dtype().nullability();
92
93 if let Some(scalar) = input.as_constant() {
94 let len_scalar = scalar_list_length(&scalar, nullability)?;
95 return Ok(ConstantArray::new(len_scalar, args.row_count()).into_array());
96 }
97
98 list_length(&input, nullability, ctx)
99 }
100
101 fn validity(
102 &self,
103 _: &Self::Options,
104 expression: &Expression,
105 ) -> VortexResult<Option<Expression>> {
106 Ok(Some(expression.child(0).validity()?))
107 }
108
109 fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
110 false
111 }
112
113 fn is_fallible(&self, _options: &Self::Options) -> bool {
114 false
115 }
116}
117
118fn scalar_list_length(scalar: &Scalar, nullability: Nullability) -> VortexResult<Scalar> {
119 if scalar.is_null() {
120 let dtype = DType::Primitive(PType::U64, Nullability::Nullable);
121 return Ok(Scalar::null(dtype));
122 }
123 let len: u64 = scalar.as_list().len().as_();
124 Ok(Scalar::primitive(len, nullability))
125}
126
127pub(crate) fn list_length(
128 array: &ArrayRef,
129 nullability: Nullability,
130 ctx: &mut ExecutionCtx,
131) -> VortexResult<ArrayRef> {
132 let any_list = array.clone().execute_until::<AnyList>(ctx)?;
133
134 let (lengths, validity) = if let Some(fsl) = any_list.as_opt::<FixedSizeList>() {
135 let size = fsl.list_size() as u64;
137 let lengths =
138 ConstantArray::new(Scalar::primitive(size, Nullability::NonNullable), fsl.len())
139 .into_array();
140 (lengths, fsl.validity()?)
141 } else if let Some(lv) = any_list.as_opt::<ListView>() {
142 (lv.sizes().clone(), lv.listview_validity())
144 } else if let Some(l) = any_list.as_opt::<List>() {
145 let lengths = list_length_from_offsets(l)?;
146 (lengths, l.list_validity())
147 } else {
148 let dtype = any_list.dtype();
149 vortex_bail!("list_length() requires List, ListView, or FixedSizeList but got {dtype}")
150 };
151
152 let len = lengths.len();
154 let lengths = lengths.cast(DType::Primitive(PType::U64, nullability))?;
155
156 if matches!(nullability, Nullability::Nullable) {
158 lengths.mask(validity.to_array(len))
159 } else {
160 Ok(lengths)
161 }
162}
163
164fn list_length_from_offsets(list: ArrayView<'_, List>) -> VortexResult<ArrayRef> {
167 let offsets = list.offsets();
168 let n = offsets.len().saturating_sub(1);
169
170 offsets
171 .slice(1..offsets.len())?
172 .binary(offsets.slice(0..n)?, Operator::Sub)
173}
174
175struct AnyList;
177
178impl Matcher for AnyList {
179 type Match<'a> = ();
180
181 fn try_match(array: &ArrayRef) -> Option<Self::Match<'_>> {
182 (array.as_opt::<List>().is_some()
183 || array.as_opt::<ListView>().is_some()
184 || array.as_opt::<FixedSizeList>().is_some())
185 .then_some(())
186 }
187}
188
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::FixedSizeListArray;
    use crate::arrays::ListArray;
    use crate::arrays::ListViewArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::cast;
    use crate::expr::list_length;
    use crate::expr::root;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Seven `i32` elements shared by the variable-size list fixtures: `1..=6`, then a null.
    fn create_list_elements() -> ArrayRef {
        PrimitiveArray::from_option_iter::<i32, _>((1..=6).map(Some).chain(std::iter::once(None)))
            .into_array()
    }

    /// Validity shared by the nullable fixtures: rows 0 and 2 are valid, rows 1 and 3 are null.
    fn alternating_validity() -> Validity {
        Validity::Array(BoolArray::from_iter([true, false, true, false]).into_array())
    }

    /// Four list views over `create_list_elements()`, with non-monotonic offsets.
    fn create_list_view(validity: Validity) -> ArrayRef {
        ListViewArray::new(
            create_list_elements(),
            buffer![5u32, 0, 4, 1].into_array(),
            buffer![2u32, 3, 0, 2].into_array(),
            validity,
        )
        .into_array()
    }

    /// Four fixed-size lists of width 2 over the elements `1..=8`.
    fn create_fixed_size_list(validity: Validity) -> ArrayRef {
        let elements = PrimitiveArray::from_iter(1i32..=8).into_array();
        FixedSizeListArray::new(elements, 2, validity, 4).into_array()
    }

    #[rstest]
    #[case(buffer![0u32, 2, 5, 5, 7].into_array())]
    #[case(buffer![0u64, 2, 5, 5, 7].into_array())]
    fn test_list_length(#[case] offsets: ArrayRef) -> VortexResult<()> {
        let list = ListArray::try_new(create_list_elements(), offsets, Validity::NonNullable)?
            .into_array();
        let lengths = list.apply(&list_length(root()))?;

        let mut ctx = array_session().create_execution_ctx();
        assert_arrays_eq!(lengths, PrimitiveArray::from_iter([2u64, 3, 0, 2]), &mut ctx);
        Ok(())
    }

    #[rstest]
    #[case(buffer![0u32, 2, 5, 5, 7].into_array())]
    #[case(buffer![0u64, 2, 5, 5, 7].into_array())]
    fn test_nullable_list_length(#[case] offsets: ArrayRef) -> VortexResult<()> {
        let list = ListArray::try_new(create_list_elements(), offsets, alternating_validity())?
            .into_array();

        let mut ctx = array_session().create_execution_ctx();
        let lengths = list
            .apply(&list_length(root()))?
            .execute::<PrimitiveArray>(&mut ctx)?;

        let expected = PrimitiveArray::from_option_iter::<u64, _>([Some(2), None, Some(0), None]);
        assert_arrays_eq!(lengths, expected, &mut ctx);
        Ok(())
    }

    #[test]
    fn test_null_scalar_list_length() -> VortexResult<()> {
        let list_dtype = DType::List(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            Nullability::Nullable,
        );
        let array = ConstantArray::new(Scalar::null(list_dtype), 2).into_array();
        let lengths = array.apply(&list_length(root()))?;

        let mut ctx = array_session().create_execution_ctx();
        for row in 0..2 {
            assert!(!lengths.is_valid(row, &mut ctx)?);
        }
        Ok(())
    }

    #[test]
    fn test_listview_length() -> VortexResult<()> {
        let view = create_list_view(Validity::NonNullable);
        let lengths = view.apply(&list_length(root()))?;

        let mut ctx = array_session().create_execution_ctx();
        assert_arrays_eq!(lengths, PrimitiveArray::from_iter([2u64, 3, 0, 2]), &mut ctx);
        Ok(())
    }

    #[test]
    fn test_listview_length_nullable() -> VortexResult<()> {
        let view = create_list_view(alternating_validity());

        let mut ctx = array_session().create_execution_ctx();
        let lengths = view
            .apply(&list_length(root()))?
            .execute::<PrimitiveArray>(&mut ctx)?;

        let expected = PrimitiveArray::from_option_iter::<u64, _>([Some(2), None, Some(0), None]);
        assert_arrays_eq!(lengths, expected, &mut ctx);
        Ok(())
    }

    #[test]
    fn test_list_length_take() -> VortexResult<()> {
        let list = ListArray::try_new(
            create_list_elements(),
            buffer![0u32, 2, 5, 5, 7].into_array(),
            Validity::NonNullable,
        )?
        .into_array();
        // Lengths are [2, 3, 0, 2]; taking rows 3, 0, 2 leaves [2, 2, 0].
        let taken = list.take(buffer![3u64, 0, 2].into_array())?;
        let lengths = taken.apply(&list_length(root()))?;

        let mut ctx = array_session().create_execution_ctx();
        assert_arrays_eq!(lengths, PrimitiveArray::from_iter([2u64, 2, 0]), &mut ctx);
        Ok(())
    }

    #[test]
    fn test_fixed_size_list_length() -> VortexResult<()> {
        let fixed = create_fixed_size_list(Validity::NonNullable);
        let lengths = fixed.apply(&list_length(root()))?;

        let mut ctx = array_session().create_execution_ctx();
        assert_arrays_eq!(lengths, PrimitiveArray::from_iter([2u64, 2, 2, 2]), &mut ctx);
        Ok(())
    }

    #[test]
    fn test_fixed_size_list_length_nullable() -> VortexResult<()> {
        let fixed = create_fixed_size_list(alternating_validity());

        let mut ctx = array_session().create_execution_ctx();
        let lengths = fixed
            .apply(&list_length(root()))?
            .execute::<PrimitiveArray>(&mut ctx)?;

        let expected = PrimitiveArray::from_option_iter::<u64, _>([Some(2), None, Some(2), None]);
        assert_arrays_eq!(lengths, expected, &mut ctx);
        Ok(())
    }

    #[test]
    fn test_fallible_child_expression_fails() -> VortexResult<()> {
        let fixed = create_fixed_size_list(alternating_validity());
        // Casting a list with null rows to a non-nullable list dtype must fail.
        let non_nullable_fsl = DType::FixedSizeList(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            2,
            Nullability::NonNullable,
        );
        let lengths = fixed.apply(&list_length(cast(root(), non_nullable_fsl)))?;

        let mut ctx = array_session().create_execution_ctx();
        let outcome = lengths.execute::<ArrayRef>(&mut ctx);
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Cannot cast array with invalid values to non-nullable type."));
        Ok(())
    }

    #[test]
    fn test_display() {
        let rendered = list_length(root()).to_string();
        assert_eq!(rendered, "vortex.list.length($)");
    }
}