vortex_array/arrays/list/compute/cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DType;
5use vortex_error::VortexResult;
6
7use crate::arrays::{ListArray, ListVTable};
8use crate::compute::{CastKernel, CastKernelAdapter, cast};
9use crate::vtable::ValidityHelper;
10use crate::{ArrayRef, register_kernel};
11
12impl CastKernel for ListVTable {
13    fn cast(&self, array: &Self::Array, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
14        let Some(target_element_type) = dtype.as_list_element() else {
15            return Ok(None);
16        };
17
18        let validity = array
19            .validity()
20            .clone()
21            .cast_nullability(dtype.nullability())?;
22
23        ListArray::try_new(
24            cast(array.elements(), target_element_type)?,
25            array.offsets().clone(),
26            validity,
27        )
28        .map(|a| Some(a.to_array()))
29    }
30}
31
32register_kernel!(CastKernelAdapter(ListVTable).lift());
33
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability};

    use crate::arrays::{BoolArray, ListArray, PrimitiveArray};
    use crate::compute::cast;
    use crate::validity::Validity;

    /// Casting a non-nullable `List<i32>` to a nullable `List<u64?>` yields an array
    /// of the requested dtype with the same number of lists.
    #[test]
    fn test_cast_list_success() {
        // Offsets [0, 2, 3] describe two lists: [1, 2] and [3].
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4]).to_array(),
            PrimitiveArray::from_iter([0, 2, 3]).to_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(
                vortex_dtype::PType::U64,
                Nullability::Nullable,
            )),
            Nullability::Nullable,
        );

        let result = cast(list.to_array().as_ref(), &target_dtype).unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), list.len());
    }

    /// A list whose entries include nulls can be cast when the target list dtype is
    /// nullable.
    #[test]
    fn test_cast_nullable_list_with_nulls() {
        // Two lists ([0, 2] and [3]); the first is null. The validity has one entry
        // per list.
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([0i32, 2, 3, 4]).to_array(),
            PrimitiveArray::from_iter([0, 2, 3]).to_array(),
            Validity::Array(BoolArray::from_iter(vec![false, true]).to_array()),
        )
        .unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(
                vortex_dtype::PType::U64,
                Nullability::Nullable,
            )),
            Nullability::Nullable,
        );

        let result = cast(list.to_array().as_ref(), &target_dtype).unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), list.len());
    }

    /// A list cannot be cast to a non-list dtype.
    #[test]
    fn test_cast_to_wrong_type() {
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([0i32, 2, 3, 4]).to_array(),
            PrimitiveArray::from_iter([0, 2, 3]).to_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let target_dtype = DType::Primitive(vortex_dtype::PType::U64, Nullability::NonNullable);
        // can't cast list to u64

        let result = cast(list.to_array().as_ref(), &target_dtype);
        assert!(result.is_err());
    }

    /// Casting fails when nulls (either whole lists or individual elements) would
    /// have to fit into a non-nullable target.
    #[test]
    fn test_cant_cast_nulls_to_non_null() {
        // Nulls in the list itself. Offsets [0, 2, 3] define two lists, so the
        // validity must have exactly two entries; the first list is null.
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([0i32, 2, 3, 4]).to_array(),
            PrimitiveArray::from_iter([0, 2, 3]).to_array(),
            Validity::Array(BoolArray::from_iter(vec![false, true]).to_array()),
        )
        .unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(
                vortex_dtype::PType::U64,
                Nullability::Nullable,
            )),
            Nullability::NonNullable,
        );

        let result = cast(list.to_array().as_ref(), &target_dtype);
        assert!(result.is_err());

        // Nulls in the list element array: element index 2 (a null) is referenced by
        // the second list.
        let list = ListArray::try_new(
            PrimitiveArray::from_option_iter([Some(0i32), Some(2), None, None]).to_array(),
            PrimitiveArray::from_iter([0, 2, 3]).to_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(
                vortex_dtype::PType::U64,
                Nullability::NonNullable,
            )),
            Nullability::NonNullable,
        );

        let result = cast(list.to_array().as_ref(), &target_dtype);
        assert!(result.is_err());
    }
}