vortex_array/arrays/list/compute/
cast.rs1use vortex_dtype::DType;
5use vortex_error::VortexResult;
6
7use crate::arrays::{ListArray, ListVTable};
8use crate::compute::{CastKernel, CastKernelAdapter, cast};
9use crate::vtable::ValidityHelper;
10use crate::{ArrayRef, register_kernel};
11
12impl CastKernel for ListVTable {
13 fn cast(&self, array: &Self::Array, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
14 let Some(target_element_type) = dtype.as_list_element_opt() else {
15 return Ok(None);
16 };
17
18 let validity = array
19 .validity()
20 .clone()
21 .cast_nullability(dtype.nullability())?;
22
23 ListArray::try_new(
24 cast(array.elements(), target_element_type)?,
25 array.offsets().clone(),
26 validity,
27 )
28 .map(|a| Some(a.to_array()))
29 }
30}
31
32register_kernel!(CastKernelAdapter(ListVTable).lift());
33
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::IntoArray;
    use crate::arrays::{BoolArray, ListArray, PrimitiveArray, VarBinArray};
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;

    /// Builds a list-of-`u64` dtype with the given element and list nullability.
    fn u64_list_dtype(element_nullability: Nullability, list_nullability: Nullability) -> DType {
        DType::List(
            Arc::new(DType::Primitive(PType::U64, element_nullability)),
            list_nullability,
        )
    }

    /// Wraps `elements` in a list with the fixed offsets `[0, 2, 3]` and the
    /// given list-level validity.
    fn two_entry_list(elements: PrimitiveArray, validity: Validity) -> ListArray {
        ListArray::try_new(
            elements.to_array(),
            PrimitiveArray::from_iter([0, 2, 3]).to_array(),
            validity,
        )
        .unwrap()
    }

    #[test]
    fn test_cast_list_success() {
        let source = two_entry_list(
            PrimitiveArray::from_iter([1i32, 2, 3, 4]),
            Validity::NonNullable,
        );
        let expected_dtype = u64_list_dtype(Nullability::Nullable, Nullability::Nullable);

        let casted = cast(source.to_array().as_ref(), &expected_dtype).unwrap();

        assert_eq!(casted.dtype(), &expected_dtype);
        assert_eq!(casted.len(), source.len());
    }

    #[test]
    fn test_cast_to_wrong_type() {
        let source = two_entry_list(
            PrimitiveArray::from_iter([0i32, 2, 3, 4]),
            Validity::NonNullable,
        );
        // A primitive target is not a list dtype, so the cast must be rejected.
        let primitive_dtype = DType::Primitive(PType::U64, Nullability::NonNullable);

        assert!(cast(source.to_array().as_ref(), &primitive_dtype).is_err());
    }

    #[test]
    fn test_cant_cast_nulls_to_non_null() {
        // A list carrying a validity array of `[false, true]` must not cast to
        // a non-nullable list.
        let null_entry_list = two_entry_list(
            PrimitiveArray::from_iter([0i32, 2, 3, 4]),
            Validity::Array(BoolArray::from_iter(vec![false, true]).to_array()),
        );
        let non_null_list = u64_list_dtype(Nullability::Nullable, Nullability::NonNullable);
        assert!(cast(null_entry_list.to_array().as_ref(), &non_null_list).is_err());

        // Elements containing `None` must not cast to a non-nullable element
        // type, even though the list itself is non-nullable.
        let null_element_list = two_entry_list(
            PrimitiveArray::from_option_iter([Some(0i32), Some(2), None, None]),
            Validity::NonNullable,
        );
        let non_null_elements =
            u64_list_dtype(Nullability::NonNullable, Nullability::NonNullable);
        assert!(cast(null_element_list.to_array().as_ref(), &non_null_elements).is_err());
    }

    #[rstest]
    #[case(create_simple_list())]
    #[case(create_nullable_list())]
    #[case(create_string_list())]
    #[case(create_nested_list())]
    #[case(create_empty_lists())]
    fn test_cast_list_conformance(#[case] array: ListArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Non-nullable `i32` elements with offsets `[0, 2, 2, 5, 6]`.
    fn create_simple_list() -> ListArray {
        ListArray::try_new(
            buffer![1i32, 2, 3, 4, 5, 6].into_array(),
            buffer![0i64, 2, 2, 5, 6].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// Optional `i64` elements (some `None`) with a `[true, false]` list validity.
    fn create_nullable_list() -> ListArray {
        let elements = PrimitiveArray::from_option_iter([
            Some(10i64),
            None,
            Some(20),
            Some(30),
            None,
            Some(40),
        ]);
        let list_validity = BoolArray::from_iter(vec![true, false]);

        ListArray::try_new(
            elements.into_array(),
            buffer![0i64, 3, 6].into_array(),
            Validity::Array(list_validity.into_array()),
        )
        .unwrap()
    }

    /// Non-nullable UTF-8 elements with offsets `[0, 2, 4]`.
    fn create_string_list() -> ListArray {
        let strings = VarBinArray::from_iter(
            vec![Some("hello"), Some("world"), Some("foo"), Some("bar")],
            DType::Utf8(Nullability::NonNullable),
        );

        ListArray::try_new(
            strings.into_array(),
            buffer![0i64, 2, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// A list whose elements are themselves a list of `i32`.
    fn create_nested_list() -> ListArray {
        let inner = ListArray::try_new(
            buffer![1i32, 2, 3, 4, 5, 6].into_array(),
            buffer![0i64, 2, 3, 6].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        ListArray::try_new(
            inner.into_array(),
            buffer![0i64, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// A single `u8` element with offsets `[0, 0, 0, 1]`.
    fn create_empty_lists() -> ListArray {
        ListArray::try_new(
            buffer![42u8].into_array(),
            buffer![0i64, 0, 0, 1].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }
}