vortex_array/arrays/list/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
use vortex_dtype::DType;
use vortex_error::VortexResult;

use crate::arrays::{ListArray, ListVTable};
use crate::compute::{CastKernel, CastKernelAdapter, cast};
use crate::vtable::ValidityHelper;
use crate::{ArrayRef, IntoArray, register_kernel};
11
12impl CastKernel for ListVTable {
13    fn cast(&self, array: &Self::Array, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
14        let Some(target_element_type) = dtype.as_list_element_opt() else {
15            return Ok(None);
16        };
17
18        let validity = array
19            .validity()
20            .clone()
21            .cast_nullability(dtype.nullability())?;
22
23        ListArray::try_new(
24            cast(array.elements(), target_element_type)?,
25            array.offsets().clone(),
26            validity,
27        )
28        .map(|a| Some(a.to_array()))
29    }
30}
31
32register_kernel!(CastKernelAdapter(ListVTable).lift());
33
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::IntoArray;
    use crate::arrays::{BoolArray, ListArray, PrimitiveArray, VarBinArray};
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;

    /// Shorthand for a `DType::List` with the given element dtype and list-level nullability.
    fn list_dtype(element: DType, nullability: Nullability) -> DType {
        DType::List(Arc::new(element), nullability)
    }

    #[test]
    fn test_cast_list_success() {
        let source = ListArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4]).to_array(),
            PrimitiveArray::from_iter([0, 2, 3]).to_array(),
            Validity::NonNullable,
        )
        .unwrap();
        // Widen the i32 elements to nullable u64 and make the list itself nullable.
        let target = list_dtype(
            DType::Primitive(PType::U64, Nullability::Nullable),
            Nullability::Nullable,
        );

        let casted = cast(source.to_array().as_ref(), &target).unwrap();

        assert_eq!(casted.dtype(), &target);
        assert_eq!(casted.len(), source.len());
    }

    #[test]
    fn test_cast_to_wrong_type() {
        let source = ListArray::try_new(
            PrimitiveArray::from_iter([0i32, 2, 3, 4]).to_array(),
            PrimitiveArray::from_iter([0, 2, 3]).to_array(),
            Validity::NonNullable,
        )
        .unwrap();

        // A list array has no valid cast to a primitive dtype such as u64.
        let target = DType::Primitive(PType::U64, Nullability::NonNullable);
        assert!(cast(source.to_array().as_ref(), &target).is_err());
    }

    #[test]
    fn test_cant_cast_nulls_to_non_null() {
        // Casting to a non-nullable target must fail whenever the data contains nulls.

        // Case 1: the list validity marks the first list as null.
        {
            let source = ListArray::try_new(
                PrimitiveArray::from_iter([0i32, 2, 3, 4]).to_array(),
                PrimitiveArray::from_iter([0, 2, 3]).to_array(),
                Validity::Array(BoolArray::from_iter(vec![false, true]).to_array()),
            )
            .unwrap();
            let target = list_dtype(
                DType::Primitive(PType::U64, Nullability::Nullable),
                Nullability::NonNullable,
            );
            assert!(cast(source.to_array().as_ref(), &target).is_err());
        }

        // Case 2: every list is valid, but the element array itself contains nulls.
        {
            let source = ListArray::try_new(
                PrimitiveArray::from_option_iter([Some(0i32), Some(2), None, None]).to_array(),
                PrimitiveArray::from_iter([0, 2, 3]).to_array(),
                Validity::NonNullable,
            )
            .unwrap();
            let target = list_dtype(
                DType::Primitive(PType::U64, Nullability::NonNullable),
                Nullability::NonNullable,
            );
            assert!(cast(source.to_array().as_ref(), &target).is_err());
        }
    }

    #[rstest]
    #[case(create_simple_list())]
    #[case(create_nullable_list())]
    #[case(create_string_list())]
    #[case(create_nested_list())]
    #[case(create_empty_lists())]
    fn test_cast_list_conformance(#[case] array: ListArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Lists `[1, 2]`, `[]`, `[3, 4, 5]`, `[6]` over non-nullable i32 elements.
    fn create_simple_list() -> ListArray {
        let elements = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let list_offsets = buffer![0i64, 2, 2, 5, 6].into_array();

        ListArray::try_new(elements, list_offsets, Validity::NonNullable).unwrap()
    }

    /// Two lists of three nullable i64 elements each; the second list is itself null.
    fn create_nullable_list() -> ListArray {
        let elements = PrimitiveArray::from_option_iter([
            Some(10i64),
            None,
            Some(20),
            Some(30),
            None,
            Some(40),
        ])
        .into_array();
        let list_offsets = buffer![0i64, 3, 6].into_array();
        let list_validity = Validity::Array(BoolArray::from_iter(vec![true, false]).into_array());

        ListArray::try_new(elements, list_offsets, list_validity).unwrap()
    }

    /// Lists `["hello", "world"]` and `["foo", "bar"]` over non-nullable UTF-8 elements.
    fn create_string_list() -> ListArray {
        let elements = VarBinArray::from_iter(
            vec![Some("hello"), Some("world"), Some("foo"), Some("bar")],
            DType::Utf8(Nullability::NonNullable),
        )
        .into_array();
        let list_offsets = buffer![0i64, 2, 4].into_array();

        ListArray::try_new(elements, list_offsets, Validity::NonNullable).unwrap()
    }

    /// A list of lists: `[[[1, 2], [3]], [[4, 5, 6]]]`.
    fn create_nested_list() -> ListArray {
        // Inner level: [[1, 2], [3], [4, 5, 6]]
        let inner_elements = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let inner_offsets = buffer![0i64, 2, 3, 6].into_array();
        let inner = ListArray::try_new(inner_elements, inner_offsets, Validity::NonNullable)
            .unwrap()
            .into_array();

        // Outer level groups the inner lists as [0..2) and [2..3).
        let outer_offsets = buffer![0i64, 2, 3].into_array();

        ListArray::try_new(inner, outer_offsets, Validity::NonNullable).unwrap()
    }

    /// Lists `[]`, `[]`, `[42]` — two empty lists followed by a single-element list.
    fn create_empty_lists() -> ListArray {
        let elements = buffer![42u8].into_array();
        let list_offsets = buffer![0i64, 0, 0, 1].into_array();

        ListArray::try_new(elements, list_offsets, Validity::NonNullable).unwrap()
    }
}