Skip to main content

vortex_array/arrays/list/compute/cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DType;
5use vortex_error::VortexResult;
6
7use crate::ArrayRef;
8use crate::IntoArray;
9use crate::arrays::ListArray;
10use crate::arrays::ListVTable;
11use crate::builtins::ArrayBuiltins;
12use crate::compute::CastReduce;
13use crate::vtable::ValidityHelper;
14
15impl CastReduce for ListVTable {
16    fn cast(array: &ListArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17        let Some(target_element_type) = dtype.as_list_element_opt() else {
18            return Ok(None);
19        };
20
21        let validity = array
22            .validity()
23            .clone()
24            .cast_nullability(dtype.nullability(), array.len())?;
25
26        let new_elements = array
27            .elements()
28            .cast((**target_element_type).clone())?
29            .to_canonical()?
30            .into_array();
31
32        ListArray::try_new(new_elements, array.offsets().clone(), validity)
33            .map(|a| Some(a.to_array()))
34    }
35}
36
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;

    /// Casting the lists `[1, 2]` and `[3]` to a nullable list of nullable `u64` yields the
    /// target dtype and preserves the number of lists.
    #[test]
    fn test_cast_list_success() {
        let list = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::Nullable)),
            Nullability::Nullable,
        );

        // `list` is still needed for the length check, so borrow it via `to_array`.
        let result = list.to_array().cast(target_dtype.clone()).unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), list.len());
    }

    /// Casting a list to a non-list (primitive) dtype must fail.
    #[test]
    fn test_cast_to_wrong_type() {
        let list = ListArray::try_new(
            buffer![0i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        // Can't cast list to u64.
        let target_dtype = DType::Primitive(PType::U64, Nullability::NonNullable);

        let result = list
            .into_array()
            .cast(target_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));
        assert!(result.is_err());
    }

    /// Narrowing to non-nullable types must fail when nulls are present, whether the nulls are
    /// in the list validity itself or in the element array.
    #[test]
    fn test_cant_cast_nulls_to_non_null() {
        // Nulls in the list itself.
        let list = ListArray::try_new(
            buffer![0i32, 2, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::Array(BoolArray::from_iter(vec![false, true]).into_array()),
        )
        .unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::Nullable)),
            Nullability::NonNullable,
        );

        let result = list
            .into_array()
            .cast(target_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));
        assert!(result.is_err());

        // Nulls in the list element array.
        let list = ListArray::try_new(
            PrimitiveArray::from_option_iter([Some(0i32), Some(2), None, None]).into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            Nullability::NonNullable,
        );

        let result = list
            .into_array()
            .cast(target_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));
        assert!(result.is_err());
    }

    /// Counterpart of `test_cant_cast_nulls_to_non_null`: a nullable list whose entries are all
    /// valid can be narrowed to a non-nullable list, including a non-nullable element type.
    #[test]
    fn test_cast_all_valid_list_to_non_nullable() {
        let list = ListArray::try_new(
            buffer![1i32, 2, 3].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::AllValid,
        )
        .unwrap();

        let target_dtype = DType::List(
            Arc::new(DType::Primitive(PType::I64, Nullability::NonNullable)),
            Nullability::NonNullable,
        );

        let result = list
            .into_array()
            .cast(target_dtype.clone())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), 2);
    }

    #[rstest]
    #[case(create_simple_list())]
    #[case(create_nullable_list())]
    #[case(create_string_list())]
    #[case(create_nested_list())]
    #[case(create_empty_lists())]
    fn test_cast_list_conformance(#[case] array: ListArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Non-nullable `i32` lists `[1, 2]`, `[]`, `[3, 4, 5]`, `[6]`.
    fn create_simple_list() -> ListArray {
        let data = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let offsets = buffer![0i64, 2, 2, 5, 6].into_array();

        ListArray::try_new(data, offsets, Validity::NonNullable).unwrap()
    }

    /// Two lists of nullable `i64` elements; the second list is itself null.
    fn create_nullable_list() -> ListArray {
        let data = PrimitiveArray::from_option_iter([
            Some(10i64),
            None,
            Some(20),
            Some(30),
            None,
            Some(40),
        ])
        .into_array();
        let offsets = buffer![0i64, 3, 6].into_array();
        let validity = Validity::Array(BoolArray::from_iter(vec![true, false]).into_array());

        ListArray::try_new(data, offsets, validity).unwrap()
    }

    /// Non-nullable UTF-8 lists `["hello", "world"]`, `["foo", "bar"]`.
    fn create_string_list() -> ListArray {
        let data = VarBinArray::from_iter(
            vec![Some("hello"), Some("world"), Some("foo"), Some("bar")],
            DType::Utf8(Nullability::NonNullable),
        )
        .into_array();
        let offsets = buffer![0i64, 2, 4].into_array();

        ListArray::try_new(data, offsets, Validity::NonNullable).unwrap()
    }

    /// A list of lists: `[[[1, 2], [3]], [[4, 5, 6]]]`.
    fn create_nested_list() -> ListArray {
        // Create inner lists: [[1, 2], [3], [4, 5, 6]]
        let inner_data = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let inner_offsets = buffer![0i64, 2, 3, 6].into_array();
        let inner_list = ListArray::try_new(inner_data, inner_offsets, Validity::NonNullable)
            .unwrap()
            .into_array();

        // Create outer list: [[[1, 2], [3]], [[4, 5, 6]]]
        let outer_offsets = buffer![0i64, 2, 3].into_array();

        ListArray::try_new(inner_list, outer_offsets, Validity::NonNullable).unwrap()
    }

    /// Lists `[]`, `[]`, `[42]`: exercises empty lists alongside a non-empty one.
    fn create_empty_lists() -> ListArray {
        let data = buffer![42u8].into_array();
        let offsets = buffer![0i64, 0, 0, 1].into_array();

        ListArray::try_new(data, offsets, Validity::NonNullable).unwrap()
    }
}