vortex_array/arrays/list/compute/
cast.rs1use vortex_error::VortexResult;
5
6use crate::ArrayRef;
7use crate::ExecutionCtx;
8use crate::IntoArray;
9use crate::array::ArrayView;
10use crate::arrays::List;
11use crate::arrays::ListArray;
12use crate::arrays::list::ListArrayExt;
13use crate::builtins::ArrayBuiltins;
14use crate::dtype::DType;
15use crate::scalar_fn::fns::cast::CastKernel;
16use crate::scalar_fn::fns::cast::CastReduce;
17
18impl CastReduce for List {
19 fn cast(array: ArrayView<'_, List>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
20 let Some(target_element_type) = dtype.as_list_element_opt() else {
21 return Ok(None);
22 };
23
24 let Some(validity) = array
25 .validity()?
26 .trivially_cast_nullability(dtype.nullability(), array.len())?
27 else {
28 return Ok(None);
29 };
30
31 let new_elements = array.elements().cast((**target_element_type).clone())?;
32
33 Ok(Some(
34 unsafe { ListArray::new_unchecked(new_elements, array.offsets().clone(), validity) }
35 .into_array(),
36 ))
37 }
38}
39
40impl CastKernel for List {
41 fn cast(
42 array: ArrayView<'_, List>,
43 dtype: &DType,
44 ctx: &mut ExecutionCtx,
45 ) -> VortexResult<Option<ArrayRef>> {
46 let Some(target_element_type) = dtype.as_list_element_opt() else {
47 return Ok(None);
48 };
49
50 let validity = array
51 .validity()?
52 .cast_nullability(dtype.nullability(), array.len(), ctx)?;
53
54 let new_elements = array.elements().cast((**target_element_type).clone())?;
55
56 Ok(Some(
57 unsafe { ListArray::new_unchecked(new_elements, array.offsets().clone(), validity) }
58 .into_array(),
59 ))
60 }
61}
62
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::Canonical;
    use crate::IntoArray;
    use crate::RecursiveCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    /// Session used to create execution contexts for the tests that execute a cast.
    static SESSION: LazyLock<VortexSession> = LazyLock::new(crate::array_session);

    /// Casting a non-nullable `i32` list to a nullable list of nullable `u64` reports the
    /// requested dtype and keeps the list length.
    #[test]
    fn test_cast_list_success() {
        let elements = buffer![1i32, 2, 3, 4].into_array();
        let offsets = buffer![0, 2, 3].into_array();
        let source = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();

        let element_dtype = DType::Primitive(PType::U64, Nullability::Nullable);
        let target_dtype = DType::List(Arc::new(element_dtype), Nullability::Nullable);

        // Record the length up front so the source can be consumed by `into_array`.
        let expected_len = source.len();
        let cast = source.into_array().cast(target_dtype.clone()).unwrap();

        assert_eq!(cast.dtype(), &target_dtype);
        assert_eq!(cast.len(), expected_len);
    }

    /// Casting a list to a primitive (non-list) dtype fails, either when the cast is
    /// requested or when it is executed to canonical form.
    #[test]
    fn test_cast_to_wrong_type() {
        let elements = buffer![0i32, 2, 3, 4].into_array();
        let offsets = buffer![0, 2, 3].into_array();
        let source = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();

        let target_dtype = DType::Primitive(PType::U64, Nullability::NonNullable);

        let outcome = match source.into_array().cast(target_dtype) {
            Ok(cast) => cast
                .execute::<Canonical>(&mut SESSION.create_execution_ctx())
                .map(|canonical| canonical.into_array()),
            Err(err) => Err(err),
        };
        assert!(outcome.is_err());
    }

    /// Nulls cannot be cast away, neither at the list level nor at the element level.
    #[test]
    fn test_cant_cast_nulls_to_non_null() {
        // Scenario 1: the list itself contains a null entry and the target list is non-nullable.
        {
            let elements = buffer![0i32, 2, 3, 4].into_array();
            let offsets = buffer![0, 2, 3].into_array();
            let list_validity =
                Validity::Array(BoolArray::from_iter(vec![false, true]).into_array());
            let with_null_list = ListArray::try_new(elements, offsets, list_validity).unwrap();

            let element_dtype = DType::Primitive(PType::U64, Nullability::Nullable);
            let non_nullable_list = DType::List(Arc::new(element_dtype), Nullability::NonNullable);

            let outcome = match with_null_list.into_array().cast(non_nullable_list) {
                Ok(cast) => cast
                    .execute::<Canonical>(&mut SESSION.create_execution_ctx())
                    .map(|canonical| canonical.into_array()),
                Err(err) => Err(err),
            };
            assert!(outcome.is_err());
        }

        // Scenario 2: the elements contain nulls and the target element type is non-nullable.
        // Executed recursively so that the element cast is evaluated as well.
        {
            let elements =
                PrimitiveArray::from_option_iter([Some(0i32), Some(2), None, None]).into_array();
            let offsets = buffer![0, 2, 3].into_array();
            let with_null_elements =
                ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();

            let element_dtype = DType::Primitive(PType::U64, Nullability::NonNullable);
            let non_nullable_elements =
                DType::List(Arc::new(element_dtype), Nullability::NonNullable);

            let outcome = match with_null_elements.into_array().cast(non_nullable_elements) {
                Ok(cast) => cast
                    .execute::<RecursiveCanonical>(&mut SESSION.create_execution_ctx())
                    .map(|canonical| canonical.0.into_array()),
                Err(err) => Err(err),
            };
            assert!(outcome.is_err());
        }
    }

    /// Runs the shared cast conformance suite over a variety of list shapes.
    #[rstest]
    #[case(create_simple_list())]
    #[case(create_nullable_list())]
    #[case(create_string_list())]
    #[case(create_nested_list())]
    #[case(create_empty_lists())]
    fn test_cast_list_conformance(#[case] array: ListArray) {
        let as_array = array.into_array();
        test_cast_conformance(&as_array);
    }

    /// Non-nullable list of `i32` with four entries, one of which is empty.
    fn create_simple_list() -> ListArray {
        ListArray::try_new(
            buffer![1i32, 2, 3, 4, 5, 6].into_array(),
            buffer![0i64, 2, 2, 5, 6].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// Two-entry list of nullable `i64` where the second list entry is itself null.
    fn create_nullable_list() -> ListArray {
        let values = [Some(10i64), None, Some(20), Some(30), None, Some(40)];

        ListArray::try_new(
            PrimitiveArray::from_option_iter(values).into_array(),
            buffer![0i64, 3, 6].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, false]).into_array()),
        )
        .unwrap()
    }

    /// Non-nullable list of non-nullable UTF-8 strings, two strings per entry.
    fn create_string_list() -> ListArray {
        let words = vec![Some("hello"), Some("world"), Some("foo"), Some("bar")];
        let strings = VarBinArray::from_iter(words, DType::Utf8(Nullability::NonNullable));

        ListArray::try_new(
            strings.into_array(),
            buffer![0i64, 2, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// List of lists of `i32`: three inner lists grouped into two outer entries.
    fn create_nested_list() -> ListArray {
        let inner = ListArray::try_new(
            buffer![1i32, 2, 3, 4, 5, 6].into_array(),
            buffer![0i64, 2, 3, 6].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        ListArray::try_new(
            inner.into_array(),
            buffer![0i64, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// List of `u8` whose first two entries are empty and whose last entry holds one value.
    fn create_empty_lists() -> ListArray {
        ListArray::try_new(
            buffer![42u8].into_array(),
            buffer![0i64, 0, 0, 1].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }
}