vortex_array/arrays/list/compute/
cast.rs1use vortex_dtype::DType;
5use vortex_error::VortexResult;
6
7use crate::ArrayRef;
8use crate::arrays::ListArray;
9use crate::arrays::ListVTable;
10use crate::compute::CastKernel;
11use crate::compute::CastKernelAdapter;
12use crate::compute::cast;
13use crate::register_kernel;
14use crate::vtable::ValidityHelper;
15
16impl CastKernel for ListVTable {
17 fn cast(&self, array: &Self::Array, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
18 let Some(target_element_type) = dtype.as_list_element_opt() else {
19 return Ok(None);
20 };
21
22 let validity = array
23 .validity()
24 .clone()
25 .cast_nullability(dtype.nullability(), array.len())?;
26
27 ListArray::try_new(
28 cast(array.elements(), target_element_type)?,
29 array.offsets().clone(),
30 validity,
31 )
32 .map(|a| Some(a.to_array()))
33 }
34}
35
36register_kernel!(CastKernelAdapter(ListVTable).lift());
37
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;

    /// `list<i32>` -> `list<u64?>?` succeeds, yielding the target dtype and the same length.
    #[test]
    fn test_cast_list_success() {
        let elements = buffer![1i32, 2, 3, 4].into_array();
        let offsets = buffer![0, 2, 3].into_array();
        let list = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::Nullable)),
            Nullability::Nullable,
        );

        let casted = cast(list.to_array().as_ref(), &target_dtype).unwrap();
        assert_eq!(casted.dtype(), &target_dtype);
        assert_eq!(casted.len(), list.len());
    }

    /// Casting a list to a non-list (primitive) dtype is rejected.
    #[test]
    fn test_cast_to_wrong_type() {
        let elements = buffer![0i32, 2, 3, 4].into_array();
        let offsets = buffer![0, 2, 3].into_array();
        let list = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();

        let target_dtype = DType::Primitive(PType::U64, Nullability::NonNullable);
        assert!(cast(list.to_array().as_ref(), &target_dtype).is_err());
    }

    /// Nulls cannot be cast away, neither at the list level nor at the element level.
    #[test]
    fn test_cant_cast_nulls_to_non_null() {
        // Case 1: the first list entry is null, but the target list type is non-nullable.
        let null_list_validity =
            Validity::Array(BoolArray::from_iter(vec![false, true]).into_array());
        let list = ListArray::try_new(
            buffer![0i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            null_list_validity,
        )
        .unwrap();

        let non_null_list_of_nullable = DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::Nullable)),
            Nullability::NonNullable,
        );
        assert!(cast(list.to_array().as_ref(), &non_null_list_of_nullable).is_err());

        // Case 2: the elements contain nulls, but the target element type is non-nullable.
        let nullable_elements =
            PrimitiveArray::from_option_iter([Some(0i32), Some(2), None, None]).into_array();
        let list = ListArray::try_new(
            nullable_elements,
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let non_null_list_of_non_null = DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            Nullability::NonNullable,
        );
        assert!(cast(list.to_array().as_ref(), &non_null_list_of_non_null).is_err());
    }

    #[rstest]
    #[case(create_simple_list())]
    #[case(create_nullable_list())]
    #[case(create_string_list())]
    #[case(create_nested_list())]
    #[case(create_empty_lists())]
    fn test_cast_list_conformance(#[case] array: ListArray) {
        test_cast_conformance(array.as_ref());
    }

    /// `[[1, 2], [], [3, 4, 5], [6]]`, non-nullable.
    fn create_simple_list() -> ListArray {
        ListArray::try_new(
            buffer![1i32, 2, 3, 4, 5, 6].into_array(),
            buffer![0i64, 2, 2, 5, 6].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// `[[10, null, 20], null]`: nullable i64 elements, and the second list entry is null.
    fn create_nullable_list() -> ListArray {
        let elements = PrimitiveArray::from_option_iter([
            Some(10i64),
            None,
            Some(20),
            Some(30),
            None,
            Some(40),
        ])
        .into_array();
        let offsets = buffer![0i64, 3, 6].into_array();
        let list_validity = Validity::Array(BoolArray::from_iter(vec![true, false]).into_array());

        ListArray::try_new(elements, offsets, list_validity).unwrap()
    }

    /// `[["hello", "world"], ["foo", "bar"]]` with non-nullable UTF-8 elements.
    fn create_string_list() -> ListArray {
        let words = VarBinArray::from_iter(
            vec![Some("hello"), Some("world"), Some("foo"), Some("bar")],
            DType::Utf8(Nullability::NonNullable),
        )
        .into_array();

        ListArray::try_new(words, buffer![0i64, 2, 4].into_array(), Validity::NonNullable)
            .unwrap()
    }

    /// `[[[1, 2], [3]], [[4, 5, 6]]]`: a list whose elements are themselves lists.
    fn create_nested_list() -> ListArray {
        // Inner lists: [1, 2], [3], [4, 5, 6].
        let inner_list = ListArray::try_new(
            buffer![1i32, 2, 3, 4, 5, 6].into_array(),
            buffer![0i64, 2, 3, 6].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        // The outer list groups inner lists as [0..2) and [2..3).
        ListArray::try_new(
            inner_list,
            buffer![0i64, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// `[[], [], [42]]`: two empty lists followed by a single-element list.
    fn create_empty_lists() -> ListArray {
        ListArray::try_new(
            buffer![42u8].into_array(),
            buffer![0i64, 0, 0, 1].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }
}