Skip to main content

vortex_array/scalar_fn/fns/
list_length.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::AsPrimitive;
5use vortex_error::VortexResult;
6use vortex_error::vortex_bail;
7use vortex_session::VortexSession;
8use vortex_session::registry::CachedId;
9
10use crate::ArrayRef;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::array::ArrayView;
14use crate::arrays::ConstantArray;
15use crate::arrays::FixedSizeList;
16use crate::arrays::List;
17use crate::arrays::ListView;
18use crate::arrays::fixed_size_list::FixedSizeListArrayExt;
19use crate::arrays::list::ListArrayExt;
20use crate::arrays::listview::ListViewArrayExt;
21use crate::builtins::ArrayBuiltins;
22use crate::dtype::DType;
23use crate::dtype::Nullability;
24use crate::dtype::PType;
25use crate::expr::Expression;
26use crate::matcher::Matcher;
27use crate::scalar::Scalar;
28use crate::scalar_fn::Arity;
29use crate::scalar_fn::ChildName;
30use crate::scalar_fn::EmptyOptions;
31use crate::scalar_fn::ExecutionArgs;
32use crate::scalar_fn::ScalarFnId;
33use crate::scalar_fn::ScalarFnVTable;
34use crate::scalar_fn::fns::operators::Operator;
35
/// Number of elements in each list of a `List` or `FixedSizeList` typed array.
///
/// This is computed purely from the list's offsets (`ListArray`), sizes (`ListViewArray`), or
/// dtype (`FixedSizeListArray`) without reading the element *values*. Validity is carried over
/// from the original array.
///
/// The function is registered under the id `vortex.list.length`, takes exactly one child
/// (named `input`), and produces a `u64` array whose nullability matches that of the input.
#[derive(Clone)]
pub struct ListLength;
43
44impl ScalarFnVTable for ListLength {
45    type Options = EmptyOptions;
46
47    fn id(&self) -> ScalarFnId {
48        static ID: CachedId = CachedId::new("vortex.list.length");
49        *ID
50    }
51
52    fn serialize(&self, _instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
53        Ok(Some(vec![]))
54    }
55
56    fn deserialize(
57        &self,
58        _metadata: &[u8],
59        _session: &VortexSession,
60    ) -> VortexResult<Self::Options> {
61        Ok(EmptyOptions)
62    }
63
64    fn arity(&self, _options: &Self::Options) -> Arity {
65        Arity::Exact(1)
66    }
67
68    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
69        match child_idx {
70            0 => ChildName::from("input"),
71            _ => unreachable!("Invalid child index {child_idx} for list_length()"),
72        }
73    }
74
75    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
76        match &arg_dtypes[0] {
77            DType::List(_, nullable) | DType::FixedSizeList(_, _, nullable) => {
78                Ok(DType::Primitive(PType::U64, *nullable))
79            }
80            other => vortex_bail!("list_length() requires List or FixedSizeList, got {other}"),
81        }
82    }
83
84    fn execute(
85        &self,
86        _options: &Self::Options,
87        args: &dyn ExecutionArgs,
88        ctx: &mut ExecutionCtx,
89    ) -> VortexResult<ArrayRef> {
90        let input = args.get(0)?;
91        let nullability = input.dtype().nullability();
92
93        if let Some(scalar) = input.as_constant() {
94            let len_scalar = scalar_list_length(&scalar, nullability)?;
95            return Ok(ConstantArray::new(len_scalar, args.row_count()).into_array());
96        }
97
98        list_length(&input, nullability, ctx)
99    }
100
101    fn validity(
102        &self,
103        _: &Self::Options,
104        expression: &Expression,
105    ) -> VortexResult<Option<Expression>> {
106        Ok(Some(expression.child(0).validity()?))
107    }
108
109    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
110        false
111    }
112
113    fn is_fallible(&self, _options: &Self::Options) -> bool {
114        false
115    }
116}
117
118fn scalar_list_length(scalar: &Scalar, nullability: Nullability) -> VortexResult<Scalar> {
119    if scalar.is_null() {
120        let dtype = DType::Primitive(PType::U64, Nullability::Nullable);
121        return Ok(Scalar::null(dtype));
122    }
123    let len: u64 = scalar.as_list().len().as_();
124    Ok(Scalar::primitive(len, nullability))
125}
126
127pub(crate) fn list_length(
128    array: &ArrayRef,
129    nullability: Nullability,
130    ctx: &mut ExecutionCtx,
131) -> VortexResult<ArrayRef> {
132    let any_list = array.clone().execute_until::<AnyList>(ctx)?;
133
134    let (lengths, validity) = if let Some(fsl) = any_list.as_opt::<FixedSizeList>() {
135        // The length of fixed-size list is constant, so just need to carry over validity
136        let size = fsl.list_size() as u64;
137        let lengths =
138            ConstantArray::new(Scalar::primitive(size, Nullability::NonNullable), fsl.len())
139                .into_array();
140        (lengths, fsl.validity()?)
141    } else if let Some(lv) = any_list.as_opt::<ListView>() {
142        // Length array is exactly the sizes child
143        (lv.sizes().clone(), lv.listview_validity())
144    } else if let Some(l) = any_list.as_opt::<List>() {
145        let lengths = list_length_from_offsets(l)?;
146        (lengths, l.list_validity())
147    } else {
148        let dtype = any_list.dtype();
149        vortex_bail!("list_length() requires List, ListView, or FixedSizeList but got {dtype}")
150    };
151
152    // Cast to `U64`
153    let len = lengths.len();
154    let lengths = lengths.cast(DType::Primitive(PType::U64, nullability))?;
155
156    // Carry over validity mask for nullable arrays
157    if matches!(nullability, Nullability::Nullable) {
158        lengths.mask(validity.to_array(len))
159    } else {
160        Ok(lengths)
161    }
162}
163
164/// Calculate the lengths of `ListArray` elements via the `offsets` child:
165/// `length[i] = offsets[i + 1] - offsets[i]`.
166fn list_length_from_offsets(list: ArrayView<'_, List>) -> VortexResult<ArrayRef> {
167    let offsets = list.offsets();
168    let n = offsets.len().saturating_sub(1);
169
170    offsets
171        .slice(1..offsets.len())?
172        .binary(offsets.slice(0..n)?, Operator::Sub)
173}
174
175/// Matches an `Array<List>`, `Array<ListView>`, or `Array<FixedSizeList>`
176struct AnyList;
177
178impl Matcher for AnyList {
179    type Match<'a> = ();
180
181    fn try_match(array: &ArrayRef) -> Option<Self::Match<'_>> {
182        (array.as_opt::<List>().is_some()
183            || array.as_opt::<ListView>().is_some()
184            || array.as_opt::<FixedSizeList>().is_some())
185        .then_some(())
186    }
187}
188
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::FixedSizeListArray;
    use crate::arrays::ListArray;
    use crate::arrays::ListViewArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::cast;
    use crate::expr::list_length;
    use crate::expr::root;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Seven nullable `i32` elements (six values followed by one null) backing the list fixtures.
    fn create_list_elements() -> ArrayRef {
        let values: [Option<i32>; 7] = [Some(1), Some(2), Some(3), Some(4), Some(5), Some(6), None];
        PrimitiveArray::from_option_iter::<i32, _>(values).into_array()
    }

    /// Four fixed-size lists of size 2 over 8 primitive elements, with the given validity.
    fn create_fixed_size_list(validity: Validity) -> ArrayRef {
        let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6, 7, 8]).into_array();
        FixedSizeListArray::new(elements, 2, validity, 4).into_array()
    }

    /// Four-row validity in which rows 0 and 2 are valid and rows 1 and 3 are null.
    fn alternating_validity() -> Validity {
        Validity::Array(BoolArray::from_iter([true, false, true, false]).into_array())
    }

    /// Non-nullable `ListArray` lengths are the differences of adjacent offsets.
    #[rstest]
    #[case(buffer![0u32, 2, 5, 5, 7].into_array())]
    #[case(buffer![0u64, 2, 5, 5, 7].into_array())]
    fn test_list_length(#[case] offsets: ArrayRef) -> VortexResult<()> {
        let list = ListArray::try_new(create_list_elements(), offsets, Validity::NonNullable)?
            .into_array();
        let lengths = list.apply(&list_length(root()))?;

        let mut ctx = array_session().create_execution_ctx();
        assert_arrays_eq!(lengths, PrimitiveArray::from_iter([2u64, 3, 0, 2]), &mut ctx);
        Ok(())
    }

    /// Null rows of a `ListArray` stay null in the length output.
    #[rstest]
    #[case(buffer![0u32, 2, 5, 5, 7].into_array())]
    #[case(buffer![0u64, 2, 5, 5, 7].into_array())]
    fn test_nullable_list_length(#[case] offsets: ArrayRef) -> VortexResult<()> {
        let list = ListArray::try_new(create_list_elements(), offsets, alternating_validity())?
            .into_array();
        let lengths = list.apply(&list_length(root()))?;

        let mut ctx = array_session().create_execution_ctx();
        let lengths = lengths.execute::<PrimitiveArray>(&mut ctx)?;

        let expected = PrimitiveArray::from_option_iter::<u64, _>([Some(2), None, Some(0), None]);
        assert_arrays_eq!(lengths, expected, &mut ctx);

        Ok(())
    }

    /// A constant array of a null list scalar yields only null lengths.
    #[test]
    fn test_null_scalar_list_length() -> VortexResult<()> {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let null_list = Scalar::null(DType::List(element_dtype, Nullability::Nullable));
        let constant = ConstantArray::new(null_list, 2).into_array();
        let lengths = constant.apply(&list_length(root()))?;

        let mut ctx = array_session().create_execution_ctx();
        assert!(!lengths.is_valid(0, &mut ctx)?);
        assert!(!lengths.is_valid(1, &mut ctx)?);
        Ok(())
    }

    /// Non-nullable `ListViewArray` lengths equal the `sizes` child.
    #[test]
    fn test_listview_length() -> VortexResult<()> {
        let offsets = buffer![5u32, 0, 4, 1].into_array();
        let sizes = buffer![2u32, 3, 0, 2].into_array();
        let list_view =
            ListViewArray::new(create_list_elements(), offsets, sizes, Validity::NonNullable)
                .into_array();
        let lengths = list_view.apply(&list_length(root()))?;

        let mut ctx = array_session().create_execution_ctx();
        assert_arrays_eq!(lengths, PrimitiveArray::from_iter([2u64, 3, 0, 2]), &mut ctx);
        Ok(())
    }

    /// Null rows of a `ListViewArray` stay null in the length output.
    #[test]
    fn test_listview_length_nullable() -> VortexResult<()> {
        let offsets = buffer![5u32, 0, 4, 1].into_array();
        let sizes = buffer![2u32, 3, 0, 2].into_array();
        let list_view =
            ListViewArray::new(create_list_elements(), offsets, sizes, alternating_validity())
                .into_array();
        let lengths = list_view.apply(&list_length(root()))?;

        let mut ctx = array_session().create_execution_ctx();
        let lengths = lengths.execute::<PrimitiveArray>(&mut ctx)?;

        let expected = PrimitiveArray::from_option_iter::<u64, _>([Some(2), None, Some(0), None]);
        assert_arrays_eq!(lengths, expected, &mut ctx);
        Ok(())
    }

    /// Lengths are computed correctly on the result of a `take` over a `ListArray`.
    #[test]
    fn test_list_length_take() -> VortexResult<()> {
        let offsets = buffer![0u32, 2, 5, 5, 7].into_array();
        let list = ListArray::try_new(create_list_elements(), offsets, Validity::NonNullable)?
            .into_array();
        let indices = buffer![3u64, 0, 2].into_array();
        let taken = list.take(indices)?;

        let lengths = taken.apply(&list_length(root()))?;
        let mut ctx = array_session().create_execution_ctx();
        assert_arrays_eq!(lengths, PrimitiveArray::from_iter([2u64, 2, 0]), &mut ctx);
        Ok(())
    }

    /// Every row of a non-nullable fixed-size list reports the fixed size.
    #[test]
    fn test_fixed_size_list_length() -> VortexResult<()> {
        let fixed = create_fixed_size_list(Validity::NonNullable);
        let lengths = fixed.apply(&list_length(root()))?;

        let mut ctx = array_session().create_execution_ctx();
        assert_arrays_eq!(lengths, PrimitiveArray::from_iter([2u64, 2, 2, 2]), &mut ctx);
        Ok(())
    }

    /// Null rows of a fixed-size list stay null in the length output.
    #[test]
    fn test_fixed_size_list_length_nullable() -> VortexResult<()> {
        let fixed = create_fixed_size_list(alternating_validity());
        let lengths = fixed.apply(&list_length(root()))?;

        let mut ctx = array_session().create_execution_ctx();
        let lengths = lengths.execute::<PrimitiveArray>(&mut ctx)?;

        let expected = PrimitiveArray::from_option_iter::<u64, _>([Some(2), None, Some(2), None]);
        assert_arrays_eq!(lengths, expected, &mut ctx);
        Ok(())
    }

    /// An error raised by the child expression (here a cast of null rows to a non-nullable
    /// dtype) surfaces when the length expression is executed.
    #[test]
    fn test_fallible_child_expression_fails() -> VortexResult<()> {
        let fixed = create_fixed_size_list(alternating_validity());
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let failing_cast_dtype = DType::FixedSizeList(element_dtype, 2, Nullability::NonNullable);

        let lengths = fixed.apply(&list_length(cast(root(), failing_cast_dtype)))?;

        let mut ctx = array_session().create_execution_ctx();
        let outcome = lengths.execute::<ArrayRef>(&mut ctx);

        assert!(outcome.is_err());

        let err_message = outcome.unwrap_err().to_string();

        assert!(
            err_message.contains("Cannot cast array with invalid values to non-nullable type.")
        );

        Ok(())
    }

    /// The expression renders as its function id applied to the root placeholder.
    #[test]
    fn test_display() {
        let rendered = list_length(root()).to_string();
        assert_eq!(rendered, "vortex.list.length($)");
    }
}