vortex_array/arrays/list/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
use vortex_dtype::DType;
use vortex_error::VortexResult;

use crate::ArrayRef;
use crate::IntoArray;
use crate::arrays::ListArray;
use crate::arrays::ListVTable;
use crate::compute::CastKernel;
use crate::compute::CastKernelAdapter;
use crate::compute::cast;
use crate::register_kernel;
use crate::vtable::ValidityHelper;
15
16impl CastKernel for ListVTable {
17    fn cast(&self, array: &Self::Array, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
18        let Some(target_element_type) = dtype.as_list_element_opt() else {
19            return Ok(None);
20        };
21
22        let validity = array
23            .validity()
24            .clone()
25            .cast_nullability(dtype.nullability(), array.len())?;
26
27        ListArray::try_new(
28            cast(array.elements(), target_element_type)?,
29            array.offsets().clone(),
30            validity,
31        )
32        .map(|a| Some(a.to_array()))
33    }
34}
35
36register_kernel!(CastKernelAdapter(ListVTable).lift());
37
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;

    #[test]
    fn test_cast_list_success() {
        // Logical value: [[1, 2], [3]].
        let list = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::Nullable)),
            Nullability::Nullable,
        );

        let result = cast(list.to_array().as_ref(), &target_dtype).unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), list.len());
    }

    #[test]
    fn test_cast_to_wrong_type() {
        let list = ListArray::try_new(
            buffer![0i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        // A list can't be cast to a primitive u64.
        let target_dtype = DType::Primitive(PType::U64, Nullability::NonNullable);

        let result = cast(list.to_array().as_ref(), &target_dtype);
        assert!(result.is_err());
    }

    #[test]
    fn test_cant_cast_nulls_to_non_null() {
        // Casting to a non-nullable target must fail when nulls are present.

        // Nulls in the list itself: the first list is null.
        let list = ListArray::try_new(
            buffer![0i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::Array(BoolArray::from_iter(vec![false, true]).into_array()),
        )
        .unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::Nullable)),
            Nullability::NonNullable,
        );

        let result = cast(list.to_array().as_ref(), &target_dtype);
        assert!(result.is_err());

        // Nulls in the element array: the second list, [null], references a null element.
        let list = ListArray::try_new(
            PrimitiveArray::from_option_iter([Some(0i32), Some(2), None, None]).into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            Nullability::NonNullable,
        );

        let result = cast(list.to_array().as_ref(), &target_dtype);
        assert!(result.is_err());
    }

    #[rstest]
    #[case(create_simple_list())]
    #[case(create_nullable_list())]
    #[case(create_string_list())]
    #[case(create_nested_list())]
    #[case(create_empty_lists())]
    fn test_cast_list_conformance(#[case] array: ListArray) {
        test_cast_conformance(array.as_ref());
    }

    /// `[[1, 2], [], [3, 4, 5], [6]]`, non-nullable.
    fn create_simple_list() -> ListArray {
        let data = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let offsets = buffer![0i64, 2, 2, 5, 6].into_array();

        ListArray::try_new(data, offsets, Validity::NonNullable).unwrap()
    }

    /// `[[10, null, 20], null]`: nullable elements and a null second list.
    fn create_nullable_list() -> ListArray {
        let data = PrimitiveArray::from_option_iter([
            Some(10i64),
            None,
            Some(20),
            Some(30),
            None,
            Some(40),
        ])
        .into_array();
        let offsets = buffer![0i64, 3, 6].into_array();
        let validity = Validity::Array(BoolArray::from_iter(vec![true, false]).into_array());

        ListArray::try_new(data, offsets, validity).unwrap()
    }

    /// `[["hello", "world"], ["foo", "bar"]]`, non-nullable UTF-8 elements.
    fn create_string_list() -> ListArray {
        let data = VarBinArray::from_iter(
            vec![Some("hello"), Some("world"), Some("foo"), Some("bar")],
            DType::Utf8(Nullability::NonNullable),
        )
        .into_array();
        let offsets = buffer![0i64, 2, 4].into_array();

        ListArray::try_new(data, offsets, Validity::NonNullable).unwrap()
    }

    /// `[[[1, 2], [3]], [[4, 5, 6]]]`: a list whose elements are themselves lists.
    fn create_nested_list() -> ListArray {
        // Inner lists: [[1, 2], [3], [4, 5, 6]].
        let inner_data = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let inner_offsets = buffer![0i64, 2, 3, 6].into_array();
        let inner_list = ListArray::try_new(inner_data, inner_offsets, Validity::NonNullable)
            .unwrap()
            .into_array();

        // Outer list groups the inner lists as [[[1, 2], [3]], [[4, 5, 6]]].
        let outer_offsets = buffer![0i64, 2, 3].into_array();

        ListArray::try_new(inner_list, outer_offsets, Validity::NonNullable).unwrap()
    }

    /// `[[], [], [42]]`: leading empty lists followed by a single-element list.
    fn create_empty_lists() -> ListArray {
        let data = buffer![42u8].into_array();
        let offsets = buffer![0i64, 0, 0, 1].into_array();

        ListArray::try_new(data, offsets, Validity::NonNullable).unwrap()
    }
}