vortex_array/arrays/list/compute/
cast.rs1use vortex_error::VortexResult;
5
6use crate::ArrayRef;
7use crate::ExecutionCtx;
8use crate::IntoArray;
9use crate::array::ArrayView;
10use crate::arrays::List;
11use crate::arrays::ListArray;
12use crate::arrays::list::ListArrayExt;
13use crate::builtins::ArrayBuiltins;
14use crate::dtype::DType;
15use crate::scalar_fn::fns::cast::CastKernel;
16use crate::scalar_fn::fns::cast::CastReduce;
17
18impl CastReduce for List {
19 fn cast(array: ArrayView<'_, List>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
20 let Some(target_element_type) = dtype.as_list_element_opt() else {
21 return Ok(None);
22 };
23
24 let Some(validity) = array
25 .validity()?
26 .trivial_cast_nullability(dtype.nullability(), array.len())?
27 else {
28 return Ok(None);
29 };
30
31 let new_elements = array.elements().cast((**target_element_type).clone())?;
32
33 Ok(Some(
34 unsafe { ListArray::new_unchecked(new_elements, array.offsets().clone(), validity) }
35 .into_array(),
36 ))
37 }
38}
39
40impl CastKernel for List {
41 fn cast(
42 array: ArrayView<'_, List>,
43 dtype: &DType,
44 ctx: &mut ExecutionCtx,
45 ) -> VortexResult<Option<ArrayRef>> {
46 let Some(target_element_type) = dtype.as_list_element_opt() else {
47 return Ok(None);
48 };
49
50 let validity = array
51 .validity()?
52 .cast_nullability(dtype.nullability(), array.len(), ctx)?;
53
54 let new_elements = array.elements().cast((**target_element_type).clone())?;
55
56 Ok(Some(
57 unsafe { ListArray::new_unchecked(new_elements, array.offsets().clone(), validity) }
58 .into_array(),
59 ))
60 }
61}
62
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_array::session::ArraySession;
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::Canonical;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::RecursiveCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    /// Session built with only the `ArraySession` extension, used to execute casts
    /// to `Canonical` in the error-path tests.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Casting `list<i32>` to `list<u64?>?` keeps the row count and adopts the target
    /// dtype, both on the value returned by `cast` and after fully executing it.
    #[test]
    fn test_cast_list_success() {
        let list = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::Nullable)),
            Nullability::Nullable,
        );

        let result = list
            .clone()
            .into_array()
            .cast(target_dtype.clone())
            .unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), list.len());

        // `cast` may defer the actual work until execution (the error tests below
        // only observe failures after `execute`), so run the cast all the way down,
        // elements included, and re-check the result.
        let executed = result
            .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
            .unwrap()
            .0
            .into_array();
        assert_eq!(executed.dtype(), &target_dtype);
        assert_eq!(executed.len(), list.len());
    }

    /// Casting a list array to a non-list dtype must fail once executed.
    #[test]
    fn test_cast_to_wrong_type() {
        let list = ListArray::try_new(
            buffer![0i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let target_dtype = DType::Primitive(PType::U64, Nullability::NonNullable);
        let result = list
            .into_array()
            .cast(target_dtype)
            .and_then(|a| a.execute::<Canonical>(&mut SESSION.create_execution_ctx()))
            .map(|c| c.into_array());
        assert!(result.is_err());
    }

    /// Null values cannot survive a cast to a non-nullable type. This is checked
    /// both for a null list row and for a null element inside a list.
    #[test]
    fn test_cant_cast_nulls_to_non_null() {
        // The first list row is null; the target list type is non-nullable.
        let list = ListArray::try_new(
            buffer![0i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::Array(BoolArray::from_iter(vec![false, true]).into_array()),
        )
        .unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::Nullable)),
            Nullability::NonNullable,
        );

        let result = list
            .into_array()
            .cast(target_dtype)
            .and_then(|a| a.execute::<Canonical>(&mut SESSION.create_execution_ctx()))
            .map(|c| c.into_array());
        assert!(result.is_err());

        // Element 2 is null and lies inside the second list; the target element
        // type is non-nullable. Executing recursively forces the elements to be cast.
        let list = ListArray::try_new(
            PrimitiveArray::from_option_iter([Some(0i32), Some(2), None, None]).into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            Nullability::NonNullable,
        );

        let result = list.into_array().cast(target_dtype).and_then(|a| {
            a.execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
                .map(|c| c.0.into_array())
        });
        assert!(result.is_err());
    }

    /// Runs the shared cast conformance suite over a range of list shapes.
    #[rstest]
    #[case(create_simple_list())]
    #[case(create_nullable_list())]
    #[case(create_string_list())]
    #[case(create_nested_list())]
    #[case(create_empty_lists())]
    fn test_cast_list_conformance(#[case] array: ListArray) {
        test_cast_conformance(&array.into_array());
    }

    /// Four non-nullable `i32` lists, including an empty one: `[1, 2], [], [3, 4, 5], [6]`.
    fn create_simple_list() -> ListArray {
        let data = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let offsets = buffer![0i64, 2, 2, 5, 6].into_array();

        ListArray::try_new(data, offsets, Validity::NonNullable).unwrap()
    }

    /// Two lists of nullable `i64` elements; the second list row is itself null.
    fn create_nullable_list() -> ListArray {
        let data = PrimitiveArray::from_option_iter([
            Some(10i64),
            None,
            Some(20),
            Some(30),
            None,
            Some(40),
        ])
        .into_array();
        let offsets = buffer![0i64, 3, 6].into_array();
        let validity = Validity::Array(BoolArray::from_iter(vec![true, false]).into_array());

        ListArray::try_new(data, offsets, validity).unwrap()
    }

    /// Two non-nullable lists of UTF-8 strings: `["hello", "world"], ["foo", "bar"]`.
    fn create_string_list() -> ListArray {
        let data = VarBinArray::from_iter(
            vec![Some("hello"), Some("world"), Some("foo"), Some("bar")],
            DType::Utf8(Nullability::NonNullable),
        )
        .into_array();
        let offsets = buffer![0i64, 2, 4].into_array();

        ListArray::try_new(data, offsets, Validity::NonNullable).unwrap()
    }

    /// A list of lists: two outer rows over the inner lists `[1, 2], [3], [4, 5, 6]`.
    fn create_nested_list() -> ListArray {
        let inner_data = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let inner_offsets = buffer![0i64, 2, 3, 6].into_array();
        let inner_list = ListArray::try_new(inner_data, inner_offsets, Validity::NonNullable)
            .unwrap()
            .into_array();

        let outer_offsets = buffer![0i64, 2, 3].into_array();

        ListArray::try_new(inner_list, outer_offsets, Validity::NonNullable).unwrap()
    }

    /// Three `u8` lists where the first two are empty: `[], [], [42]`.
    fn create_empty_lists() -> ListArray {
        let data = buffer![42u8].into_array();
        let offsets = buffer![0i64, 0, 0, 1].into_array();

        ListArray::try_new(data, offsets, Validity::NonNullable).unwrap()
    }
}