Skip to main content

vortex_array/arrays/list/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::ArrayRef;
7use crate::IntoArray;
8use crate::arrays::ListArray;
9use crate::arrays::ListVTable;
10use crate::builtins::ArrayBuiltins;
11use crate::dtype::DType;
12use crate::scalar_fn::fns::cast::CastReduce;
13use crate::vtable::ValidityHelper;
14
15impl CastReduce for ListVTable {
16    fn cast(array: &ListArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17        let Some(target_element_type) = dtype.as_list_element_opt() else {
18            return Ok(None);
19        };
20
21        let validity = array
22            .validity()
23            .clone()
24            .cast_nullability(dtype.nullability(), array.len())?;
25
26        let new_elements = array.elements().cast((**target_element_type).clone())?;
27
28        ListArray::try_new(new_elements, array.offsets().clone(), validity)
29            .map(|a| Some(a.into_array()))
30    }
31}
32
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::RecursiveCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    /// Casting a non-nullable `List<i32>` to `List<u64?>?` yields exactly the requested dtype
    /// and preserves the number of lists.
    #[test]
    fn test_cast_list_success() {
        let list = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::Nullable)),
            Nullability::Nullable,
        );

        let result = list
            .clone()
            .into_array()
            .cast(target_dtype.clone())
            .unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), list.len());
    }

    /// A list cannot be cast to a non-list (primitive) dtype.
    #[test]
    fn test_cast_to_wrong_type() {
        let list = ListArray::try_new(
            buffer![0i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let target_dtype = DType::Primitive(PType::U64, Nullability::NonNullable);
        // can't cast list to u64

        let result = list
            .into_array()
            .cast(target_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));
        assert!(result.is_err());
    }

    /// Casting to a non-nullable dtype must fail when nulls are present, either at the list
    /// level or within the elements.
    #[test]
    fn test_cant_cast_nulls_to_non_null() {
        // Test that if list has nulls, the conversion will fail

        // Nulls in the list itself
        let list = ListArray::try_new(
            buffer![0i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::Array(BoolArray::from_iter(vec![false, true]).into_array()),
        )
        .unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::Nullable)),
            Nullability::NonNullable,
        );

        let result = list
            .into_array()
            .cast(target_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));
        assert!(result.is_err());

        // Nulls in list element array — the inner cast error is deferred until
        // the elements are executed.
        let list = ListArray::try_new(
            PrimitiveArray::from_option_iter([Some(0i32), Some(2), None, None]).into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            Nullability::NonNullable,
        );

        let result = list.into_array().cast(target_dtype).and_then(|a| {
            a.execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
                .map(|c| c.0.into_array())
        });
        assert!(result.is_err());
    }

    /// Runs the shared cast conformance suite over a variety of list shapes.
    #[rstest]
    #[case(create_simple_list())]
    #[case(create_nullable_list())]
    #[case(create_string_list())]
    #[case(create_nested_list())]
    #[case(create_empty_lists())]
    fn test_cast_list_conformance(#[case] array: ListArray) {
        test_cast_conformance(&array.into_array());
    }

    /// Non-nullable `List<i32>`: `[[1, 2], [], [3, 4, 5], [6]]`.
    fn create_simple_list() -> ListArray {
        let data = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let offsets = buffer![0i64, 2, 2, 5, 6].into_array();

        ListArray::try_new(data, offsets, Validity::NonNullable).unwrap()
    }

    /// Nullable `List<i64?>` with null elements and a null second list.
    fn create_nullable_list() -> ListArray {
        let data = PrimitiveArray::from_option_iter([
            Some(10i64),
            None,
            Some(20),
            Some(30),
            None,
            Some(40),
        ])
        .into_array();
        let offsets = buffer![0i64, 3, 6].into_array();
        let validity = Validity::Array(BoolArray::from_iter(vec![true, false]).into_array());

        ListArray::try_new(data, offsets, validity).unwrap()
    }

    /// Non-nullable `List<utf8>`: `[["hello", "world"], ["foo", "bar"]]`.
    fn create_string_list() -> ListArray {
        let data = VarBinArray::from_iter(
            vec![Some("hello"), Some("world"), Some("foo"), Some("bar")],
            DType::Utf8(Nullability::NonNullable),
        )
        .into_array();
        let offsets = buffer![0i64, 2, 4].into_array();

        ListArray::try_new(data, offsets, Validity::NonNullable).unwrap()
    }

    /// Two-level nested list `List<List<i32>>`.
    fn create_nested_list() -> ListArray {
        // Create inner lists: [[1, 2], [3], [4, 5, 6]]
        let inner_data = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let inner_offsets = buffer![0i64, 2, 3, 6].into_array();
        let inner_list = ListArray::try_new(inner_data, inner_offsets, Validity::NonNullable)
            .unwrap()
            .into_array();

        // Create outer list: [[[1, 2], [3]], [[4, 5, 6]]]
        let outer_offsets = buffer![0i64, 2, 3].into_array();

        ListArray::try_new(inner_list, outer_offsets, Validity::NonNullable).unwrap()
    }

    /// List array whose first two lists are empty: `[[], [], [42]]`.
    fn create_empty_lists() -> ListArray {
        let data = buffer![42u8].into_array();
        let offsets = buffer![0i64, 0, 0, 1].into_array();

        ListArray::try_new(data, offsets, Validity::NonNullable).unwrap()
    }
}