Skip to main content

vortex_array/arrays/list/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::ArrayRef;
7use crate::ExecutionCtx;
8use crate::IntoArray;
9use crate::array::ArrayView;
10use crate::arrays::List;
11use crate::arrays::ListArray;
12use crate::arrays::list::ListArrayExt;
13use crate::builtins::ArrayBuiltins;
14use crate::dtype::DType;
15use crate::scalar_fn::fns::cast::CastKernel;
16use crate::scalar_fn::fns::cast::CastReduce;
17
18impl CastReduce for List {
19    fn cast(array: ArrayView<'_, List>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
20        let Some(target_element_type) = dtype.as_list_element_opt() else {
21            return Ok(None);
22        };
23
24        let Some(validity) = array
25            .validity()?
26            .trivial_cast_nullability(dtype.nullability(), array.len())?
27        else {
28            return Ok(None);
29        };
30
31        let new_elements = array.elements().cast((**target_element_type).clone())?;
32
33        Ok(Some(
34            unsafe { ListArray::new_unchecked(new_elements, array.offsets().clone(), validity) }
35                .into_array(),
36        ))
37    }
38}
39
40impl CastKernel for List {
41    fn cast(
42        array: ArrayView<'_, List>,
43        dtype: &DType,
44        ctx: &mut ExecutionCtx,
45    ) -> VortexResult<Option<ArrayRef>> {
46        let Some(target_element_type) = dtype.as_list_element_opt() else {
47            return Ok(None);
48        };
49
50        let validity = array
51            .validity()?
52            .cast_nullability(dtype.nullability(), array.len(), ctx)?;
53
54        let new_elements = array.elements().cast((**target_element_type).clone())?;
55
56        Ok(Some(
57            unsafe { ListArray::new_unchecked(new_elements, array.offsets().clone(), validity) }
58                .into_array(),
59        ))
60    }
61}
62
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_array::session::ArraySession;
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::Canonical;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::RecursiveCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    /// Session with only `ArraySession` registered, used to execute cast results.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Builds `List(Primitive(ptype, element), list)`.
    fn list_of(ptype: PType, element: Nullability, list: Nullability) -> DType {
        DType::List(Arc::new(DType::Primitive(ptype, element)), list)
    }

    #[test]
    fn test_cast_list_success() {
        let source = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();
        let expected_len = source.len();
        let target = list_of(PType::U64, Nullability::Nullable, Nullability::Nullable);

        let casted = source.into_array().cast(target.clone()).unwrap();

        assert_eq!(casted.dtype(), &target);
        assert_eq!(casted.len(), expected_len);
    }

    #[test]
    fn test_cast_to_wrong_type() {
        let source = ListArray::try_new(
            buffer![0i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        // A list column cannot become a plain u64 column. The failure is accepted whether it
        // comes from `cast` itself or from executing the cast result.
        let outcome = source
            .into_array()
            .cast(DType::Primitive(PType::U64, Nullability::NonNullable))
            .and_then(|casted| casted.execute::<Canonical>(&mut SESSION.create_execution_ctx()))
            .map(|canonical| canonical.into_array());

        assert!(outcome.is_err());
    }

    #[test]
    fn test_cant_cast_nulls_to_non_null() {
        // Case 1: the list itself has a null row, so a non-nullable list target must fail.
        {
            let source = ListArray::try_new(
                buffer![0i32, 2, 3, 4].into_array(),
                buffer![0, 2, 3].into_array(),
                Validity::Array(BoolArray::from_iter([false, true]).into_array()),
            )
            .unwrap();
            let target = list_of(PType::U64, Nullability::Nullable, Nullability::NonNullable);

            let outcome = source
                .into_array()
                .cast(target)
                .and_then(|casted| {
                    casted.execute::<Canonical>(&mut SESSION.create_execution_ctx())
                })
                .map(|canonical| canonical.into_array());
            assert!(outcome.is_err());
        }

        // Case 2: the elements contain nulls but the element target is non-nullable. The
        // inner cast error is deferred, so the elements must be executed recursively to see it.
        {
            let source = ListArray::try_new(
                PrimitiveArray::from_option_iter([Some(0i32), Some(2), None, None]).into_array(),
                buffer![0, 2, 3].into_array(),
                Validity::NonNullable,
            )
            .unwrap();
            let target = list_of(PType::U64, Nullability::NonNullable, Nullability::NonNullable);

            let outcome = source.into_array().cast(target).and_then(|casted| {
                casted
                    .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
                    .map(|recursive| recursive.0.into_array())
            });
            assert!(outcome.is_err());
        }
    }

    #[rstest]
    #[case(create_simple_list())]
    #[case(create_nullable_list())]
    #[case(create_string_list())]
    #[case(create_nested_list())]
    #[case(create_empty_lists())]
    fn test_cast_list_conformance(#[case] array: ListArray) {
        let erased = array.into_array();
        test_cast_conformance(&erased);
    }

    /// `[[1, 2], [], [3, 4, 5], [6]]` with non-nullable rows.
    fn create_simple_list() -> ListArray {
        ListArray::try_new(
            buffer![1i32, 2, 3, 4, 5, 6].into_array(),
            buffer![0i64, 2, 2, 5, 6].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// Two rows over nullable `i64` elements: `[10, null, 20]` and a null row (which masks
    /// the elements `[30, null, 40]`).
    fn create_nullable_list() -> ListArray {
        let elements = PrimitiveArray::from_option_iter([
            Some(10i64),
            None,
            Some(20),
            Some(30),
            None,
            Some(40),
        ])
        .into_array();
        let row_validity = Validity::Array(BoolArray::from_iter([true, false]).into_array());

        ListArray::try_new(elements, buffer![0i64, 3, 6].into_array(), row_validity).unwrap()
    }

    /// `[["hello", "world"], ["foo", "bar"]]` over non-nullable UTF-8 elements.
    fn create_string_list() -> ListArray {
        let words = ["hello", "world", "foo", "bar"].map(Some).to_vec();
        let elements =
            VarBinArray::from_iter(words, DType::Utf8(Nullability::NonNullable)).into_array();

        ListArray::try_new(
            elements,
            buffer![0i64, 2, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// Outer list `[[[1, 2], [3]], [[4, 5, 6]]]` whose elements are themselves lists
    /// `[[1, 2], [3], [4, 5, 6]]`.
    fn create_nested_list() -> ListArray {
        let inner = ListArray::try_new(
            buffer![1i32, 2, 3, 4, 5, 6].into_array(),
            buffer![0i64, 2, 3, 6].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        ListArray::try_new(
            inner.into_array(),
            buffer![0i64, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// Three rows where the first two are empty: `[[], [], [42]]`.
    fn create_empty_lists() -> ListArray {
        ListArray::try_new(
            buffer![42u8].into_array(),
            buffer![0i64, 0, 0, 1].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }
}