vortex_array/arrays/list/compute/
cast.rs1use vortex_error::VortexResult;
5
6use crate::ArrayRef;
7use crate::IntoArray;
8use crate::arrays::ListArray;
9use crate::arrays::ListVTable;
10use crate::builtins::ArrayBuiltins;
11use crate::dtype::DType;
12use crate::scalar_fn::fns::cast::CastReduce;
13use crate::vtable::ValidityHelper;
14
15impl CastReduce for ListVTable {
16 fn cast(array: &ListArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17 let Some(target_element_type) = dtype.as_list_element_opt() else {
18 return Ok(None);
19 };
20
21 let validity = array
22 .validity()
23 .clone()
24 .cast_nullability(dtype.nullability(), array.len())?;
25
26 let new_elements = array.elements().cast((**target_element_type).clone())?;
27
28 ListArray::try_new(new_elements, array.offsets().clone(), validity)
29 .map(|a| Some(a.into_array()))
30 }
31}
32
#[cfg(test)]
mod tests {
    // Tests for casting `ListArray`s: targeted success/failure cases plus the shared
    // cast conformance suite over several list shapes.
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::RecursiveCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    /// Builds the dtype `List<Primitive(ptype, element_nullability)>` whose outer
    /// list carries `list_nullability`.
    fn list_of_primitive(
        ptype: PType,
        element_nullability: Nullability,
        list_nullability: Nullability,
    ) -> DType {
        DType::List(
            Arc::new(DType::Primitive(ptype, element_nullability)),
            list_nullability,
        )
    }

    /// A non-nullable `List<i32>` can be cast to a nullable `List<u64?>`; the result
    /// reports the requested dtype and keeps the same number of rows.
    #[test]
    fn test_cast_list_success() {
        let source = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();
        let row_count = source.len();

        let target = list_of_primitive(PType::U64, Nullability::Nullable, Nullability::Nullable);
        let cast = source.into_array().cast(target.clone()).unwrap();

        assert_eq!(cast.dtype(), &target);
        assert_eq!(cast.len(), row_count);
    }

    /// Casting a list to a primitive dtype must fail, whether the error is raised by
    /// `cast` itself or only once the result is canonicalized.
    #[test]
    fn test_cast_to_wrong_type() {
        let source = ListArray::try_new(
            buffer![0i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();
        let target = DType::Primitive(PType::U64, Nullability::NonNullable);

        let outcome = source
            .into_array()
            .cast(target)
            .and_then(|cast| cast.to_canonical().map(|canonical| canonical.into_array()));
        assert!(outcome.is_err());
    }

    /// Nulls block a cast to a non-nullable target, both at the list level and at the
    /// element level.
    #[test]
    fn test_cant_cast_nulls_to_non_null() {
        // Row 0 of the list itself is null, so the list cannot become non-nullable.
        let null_row = ListArray::try_new(
            buffer![0i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::Array(BoolArray::from_iter(vec![false, true]).into_array()),
        )
        .unwrap();
        let non_null_list =
            list_of_primitive(PType::U64, Nullability::Nullable, Nullability::NonNullable);
        let outer = null_row
            .into_array()
            .cast(non_null_list)
            .and_then(|cast| cast.to_canonical().map(|canonical| canonical.into_array()));
        assert!(outer.is_err());

        // Every row is valid, but element 2 (the sole element of row 1) is null, so the
        // elements cannot become non-nullable. The element cast is applied to the child
        // elements array, so canonicalize recursively so that child cast is evaluated.
        let null_element = ListArray::try_new(
            PrimitiveArray::from_option_iter([Some(0i32), Some(2), None, None]).into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();
        let all_non_null =
            list_of_primitive(PType::U64, Nullability::NonNullable, Nullability::NonNullable);
        let inner = null_element.into_array().cast(all_non_null).and_then(|cast| {
            cast.execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
                .map(|canonical| canonical.0.into_array())
        });
        assert!(inner.is_err());
    }

    /// Runs the shared cast conformance suite over a spread of list shapes.
    #[rstest]
    #[case(create_simple_list())]
    #[case(create_nullable_list())]
    #[case(create_string_list())]
    #[case(create_nested_list())]
    #[case(create_empty_lists())]
    fn test_cast_list_conformance(#[case] array: ListArray) {
        let erased = array.into_array();
        test_cast_conformance(&erased);
    }

    /// Rows `[1, 2]`, `[]`, `[3, 4, 5]`, `[6]` with no nulls anywhere.
    fn create_simple_list() -> ListArray {
        let offsets = buffer![0i64, 2, 2, 5, 6].into_array();
        let elements = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap()
    }

    /// Two rows of nullable `i64` elements, `[10, null, 20]` and `[30, null, 40]`,
    /// where the second row is itself null.
    fn create_nullable_list() -> ListArray {
        let elements = PrimitiveArray::from_option_iter([
            Some(10i64),
            None,
            Some(20),
            Some(30),
            None,
            Some(40),
        ])
        .into_array();
        let row_validity = Validity::Array(BoolArray::from_iter(vec![true, false]).into_array());
        ListArray::try_new(elements, buffer![0i64, 3, 6].into_array(), row_validity).unwrap()
    }

    /// Two rows of two non-nullable UTF-8 strings each.
    fn create_string_list() -> ListArray {
        let words = vec![Some("hello"), Some("world"), Some("foo"), Some("bar")];
        let elements =
            VarBinArray::from_iter(words, DType::Utf8(Nullability::NonNullable)).into_array();
        ListArray::try_new(elements, buffer![0i64, 2, 4].into_array(), Validity::NonNullable)
            .unwrap()
    }

    /// A `List<List<i32>>` with rows `[[1, 2], [3]]` and `[[4, 5, 6]]`.
    fn create_nested_list() -> ListArray {
        let inner = ListArray::try_new(
            buffer![1i32, 2, 3, 4, 5, 6].into_array(),
            buffer![0i64, 2, 3, 6].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();
        let outer_offsets = buffer![0i64, 2, 3].into_array();
        ListArray::try_new(inner, outer_offsets, Validity::NonNullable).unwrap()
    }

    /// Three `u8` rows, the first two empty: `[]`, `[]`, `[42]`.
    fn create_empty_lists() -> ListArray {
        ListArray::try_new(
            buffer![42u8].into_array(),
            buffer![0i64, 0, 0, 1].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }
}