vortex_array/arrays/list/compute/
cast.rs1use vortex_dtype::DType;
5use vortex_error::VortexResult;
6
7use crate::ArrayRef;
8use crate::IntoArray;
9use crate::arrays::ListArray;
10use crate::arrays::ListVTable;
11use crate::builtins::ArrayBuiltins;
12use crate::compute::CastReduce;
13use crate::vtable::ValidityHelper;
14
15impl CastReduce for ListVTable {
16 fn cast(array: &ListArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17 let Some(target_element_type) = dtype.as_list_element_opt() else {
18 return Ok(None);
19 };
20
21 let validity = array
22 .validity()
23 .clone()
24 .cast_nullability(dtype.nullability(), array.len())?;
25
26 let new_elements = array
27 .elements()
28 .cast((**target_element_type).clone())?
29 .to_canonical()?
30 .into_array();
31
32 ListArray::try_new(new_elements, array.offsets().clone(), validity)
33 .map(|a| Some(a.to_array()))
34 }
35}
36
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;

    /// Casts `list` to `target_dtype`, canonicalizes the result so that errors
    /// raised while evaluating the cast are observed too, and asserts that the
    /// whole operation failed.
    #[track_caller]
    fn assert_cast_fails(list: ListArray, target_dtype: DType) {
        let result = list
            .into_array()
            .cast(target_dtype)
            .and_then(|array| array.to_canonical().map(|canonical| canonical.into_array()));
        assert!(result.is_err(), "expected the list cast to fail");
    }

    #[test]
    fn test_cast_list_success() {
        let list = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::Nullable)),
            Nullability::Nullable,
        );

        let result = list.to_array().cast(target_dtype.clone()).unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), list.len());
    }

    #[test]
    fn test_cast_to_wrong_type() {
        let list = ListArray::try_new(
            buffer![0i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        // A list cannot be cast to a non-list (primitive) type.
        assert_cast_fails(
            list,
            DType::Primitive(PType::U64, Nullability::NonNullable),
        );
    }

    #[test]
    fn test_cant_cast_nulls_to_non_null() {
        // The first list is null, so the outer list cannot become non-nullable.
        let list = ListArray::try_new(
            buffer![0i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::Array(BoolArray::from_iter(vec![false, true]).into_array()),
        )
        .unwrap();

        assert_cast_fails(
            list,
            DType::List(
                Arc::new(DType::Primitive(PType::U64, Nullability::Nullable)),
                Nullability::NonNullable,
            ),
        );

        // The elements contain nulls, so the element type cannot become non-nullable.
        let list = ListArray::try_new(
            PrimitiveArray::from_option_iter([Some(0i32), Some(2), None, None]).into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        assert_cast_fails(
            list,
            DType::List(
                Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
                Nullability::NonNullable,
            ),
        );
    }

    #[rstest]
    #[case(create_simple_list())]
    #[case(create_nullable_list())]
    #[case(create_string_list())]
    #[case(create_nested_list())]
    #[case(create_empty_lists())]
    fn test_cast_list_conformance(#[case] array: ListArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Non-nullable `i32` lists `[1, 2]`, `[]`, `[3, 4, 5]`, `[6]`.
    fn create_simple_list() -> ListArray {
        let data = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let offsets = buffer![0i64, 2, 2, 5, 6].into_array();

        ListArray::try_new(data, offsets, Validity::NonNullable).unwrap()
    }

    /// Two lists of nullable `i64` elements; the second list itself is null.
    fn create_nullable_list() -> ListArray {
        let data = PrimitiveArray::from_option_iter([
            Some(10i64),
            None,
            Some(20),
            Some(30),
            None,
            Some(40),
        ])
        .into_array();
        let offsets = buffer![0i64, 3, 6].into_array();
        let validity = Validity::Array(BoolArray::from_iter(vec![true, false]).into_array());

        ListArray::try_new(data, offsets, validity).unwrap()
    }

    /// Two non-nullable lists of UTF-8 strings, two strings each.
    fn create_string_list() -> ListArray {
        let data = VarBinArray::from_iter(
            vec![Some("hello"), Some("world"), Some("foo"), Some("bar")],
            DType::Utf8(Nullability::NonNullable),
        )
        .into_array();
        let offsets = buffer![0i64, 2, 4].into_array();

        ListArray::try_new(data, offsets, Validity::NonNullable).unwrap()
    }

    /// A list of lists: three inner `i32` lists grouped into two outer lists.
    fn create_nested_list() -> ListArray {
        let inner_data = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let inner_offsets = buffer![0i64, 2, 3, 6].into_array();
        let inner_list = ListArray::try_new(inner_data, inner_offsets, Validity::NonNullable)
            .unwrap()
            .into_array();

        let outer_offsets = buffer![0i64, 2, 3].into_array();

        ListArray::try_new(inner_list, outer_offsets, Validity::NonNullable).unwrap()
    }

    /// Two empty lists followed by the single-element list `[42u8]`.
    fn create_empty_lists() -> ListArray {
        let data = buffer![42u8].into_array();
        let offsets = buffer![0i64, 0, 0, 1].into_array();

        ListArray::try_new(data, offsets, Validity::NonNullable).unwrap()
    }
}