Skip to main content

vortex_array/arrays/list/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::ArrayRef;
7use crate::ExecutionCtx;
8use crate::IntoArray;
9use crate::array::ArrayView;
10use crate::arrays::List;
11use crate::arrays::ListArray;
12use crate::arrays::list::ListArrayExt;
13use crate::builtins::ArrayBuiltins;
14use crate::dtype::DType;
15use crate::scalar_fn::fns::cast::CastKernel;
16use crate::scalar_fn::fns::cast::CastReduce;
17
18impl CastReduce for List {
19    fn cast(array: ArrayView<'_, List>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
20        let Some(target_element_type) = dtype.as_list_element_opt() else {
21            return Ok(None);
22        };
23
24        let Some(validity) = array
25            .validity()?
26            .trivially_cast_nullability(dtype.nullability(), array.len())?
27        else {
28            return Ok(None);
29        };
30
31        let new_elements = array.elements().cast((**target_element_type).clone())?;
32
33        Ok(Some(
34            unsafe { ListArray::new_unchecked(new_elements, array.offsets().clone(), validity) }
35                .into_array(),
36        ))
37    }
38}
39
40impl CastKernel for List {
41    fn cast(
42        array: ArrayView<'_, List>,
43        dtype: &DType,
44        ctx: &mut ExecutionCtx,
45    ) -> VortexResult<Option<ArrayRef>> {
46        let Some(target_element_type) = dtype.as_list_element_opt() else {
47            return Ok(None);
48        };
49
50        let validity = array
51            .validity()?
52            .cast_nullability(dtype.nullability(), array.len(), ctx)?;
53
54        let new_elements = array.elements().cast((**target_element_type).clone())?;
55
56        Ok(Some(
57            unsafe { ListArray::new_unchecked(new_elements, array.offsets().clone(), validity) }
58                .into_array(),
59        ))
60    }
61}
62
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::Canonical;
    use crate::IntoArray;
    use crate::RecursiveCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    /// Session from which each test creates a fresh execution context.
    static SESSION: LazyLock<VortexSession> = LazyLock::new(crate::array_session);

    /// A non-nullable `List<i32>` cast to a nullable list of nullable `u64` reports the
    /// requested dtype and keeps its length.
    #[test]
    fn test_cast_list_success() {
        let source = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();
        let source_len = source.len();

        let nullable_u64 = Arc::new(DType::Primitive(PType::U64, Nullability::Nullable));
        let target_dtype = DType::List(nullable_u64, Nullability::Nullable);

        let cast = source.into_array().cast(target_dtype.clone()).unwrap();

        assert_eq!(cast.dtype(), &target_dtype);
        assert_eq!(cast.len(), source_len);
    }

    /// A list cannot be cast to a scalar primitive type. The failure may come from `cast`
    /// itself or from executing the result, so both steps are chained before checking.
    #[test]
    fn test_cast_to_wrong_type() {
        let source = ListArray::try_new(
            buffer![0i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        // There is no list -> u64 cast.
        let target_dtype = DType::Primitive(PType::U64, Nullability::NonNullable);

        let outcome = source
            .into_array()
            .cast(target_dtype)
            .and_then(|cast| {
                let mut ctx = SESSION.create_execution_ctx();
                cast.execute::<Canonical>(&mut ctx)
            })
            .map(|canonical| canonical.into_array());

        assert!(outcome.is_err());
    }

    /// Casting to non-nullable types fails when nulls are present, whether they sit in the
    /// list validity or inside the element array.
    #[test]
    fn test_cant_cast_nulls_to_non_null() {
        // Case 1: the first list entry is null, but the target list type is non-nullable.
        let null_entry_list = ListArray::try_new(
            buffer![0i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::Array(BoolArray::from_iter(vec![false, true]).into_array()),
        )
        .unwrap();

        let non_nullable_list = DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::Nullable)),
            Nullability::NonNullable,
        );

        let outer_nulls = null_entry_list
            .into_array()
            .cast(non_nullable_list)
            .and_then(|cast| {
                let mut ctx = SESSION.create_execution_ctx();
                cast.execute::<Canonical>(&mut ctx)
            })
            .map(|canonical| canonical.into_array());
        assert!(outer_nulls.is_err());

        // Case 2: every list entry is valid, but the elements they cover include a null
        // (index 2) and the target element type is non-nullable. The inner cast error is
        // deferred until the elements are executed, hence the recursive execution.
        let null_element_list = ListArray::try_new(
            PrimitiveArray::from_option_iter([Some(0i32), Some(2), None, None]).into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let fully_non_nullable = DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            Nullability::NonNullable,
        );

        let inner_nulls = null_element_list
            .into_array()
            .cast(fully_non_nullable)
            .and_then(|cast| {
                let mut ctx = SESSION.create_execution_ctx();
                cast.execute::<RecursiveCanonical>(&mut ctx)
                    .map(|recursive| recursive.0.into_array())
            });
        assert!(inner_nulls.is_err());
    }

    /// Runs the shared cast conformance suite over a variety of list shapes.
    #[rstest]
    #[case(create_simple_list())]
    #[case(create_nullable_list())]
    #[case(create_string_list())]
    #[case(create_nested_list())]
    #[case(create_empty_lists())]
    fn test_cast_list_conformance(#[case] array: ListArray) {
        let array_ref = array.into_array();
        test_cast_conformance(&array_ref);
    }

    /// Four non-null lists of `i32` with `i64` offsets: `[1, 2]`, `[]`, `[3, 4, 5]`, `[6]`.
    fn create_simple_list() -> ListArray {
        let elements = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let offsets = buffer![0i64, 2, 2, 5, 6].into_array();
        ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap()
    }

    /// Two lists over nullable `i64` elements: `[10, null, 20]`, then a null list entry whose
    /// underlying slice is `[30, null, 40]`.
    fn create_nullable_list() -> ListArray {
        let elements = PrimitiveArray::from_option_iter([
            Some(10i64),
            None,
            Some(20),
            Some(30),
            None,
            Some(40),
        ])
        .into_array();
        let offsets = buffer![0i64, 3, 6].into_array();
        let validity = Validity::Array(BoolArray::from_iter(vec![true, false]).into_array());
        ListArray::try_new(elements, offsets, validity).unwrap()
    }

    /// Two non-null lists of non-null UTF-8 strings: `["hello", "world"]`, `["foo", "bar"]`.
    fn create_string_list() -> ListArray {
        let words = vec![Some("hello"), Some("world"), Some("foo"), Some("bar")];
        let elements =
            VarBinArray::from_iter(words, DType::Utf8(Nullability::NonNullable)).into_array();
        let offsets = buffer![0i64, 2, 4].into_array();
        ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap()
    }

    /// Two lists whose elements are themselves lists of `i32`:
    /// `[[1, 2], [3]]` and `[[4, 5, 6]]`.
    fn create_nested_list() -> ListArray {
        // Inner lists: [1, 2], [3], [4, 5, 6].
        let inner = ListArray::try_new(
            buffer![1i32, 2, 3, 4, 5, 6].into_array(),
            buffer![0i64, 2, 3, 6].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        // Outer offsets group the inner lists as [inner0, inner1] and [inner2].
        let outer_offsets = buffer![0i64, 2, 3].into_array();
        ListArray::try_new(inner, outer_offsets, Validity::NonNullable).unwrap()
    }

    /// Three lists of `u8`, two of them empty: `[]`, `[]`, `[42]`.
    fn create_empty_lists() -> ListArray {
        ListArray::try_new(
            buffer![42u8].into_array(),
            buffer![0i64, 0, 0, 1].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }
}