// vortex_array/arrays/list/compute/cast.rs
use vortex_dtype::DType;
use vortex_error::VortexResult;

use crate::arrays::{ListArray, ListVTable};
use crate::compute::{CastKernel, CastKernelAdapter, cast};
use crate::vtable::ValidityHelper;
use crate::{ArrayRef, register_kernel};

12impl CastKernel for ListVTable {
13 fn cast(&self, array: &Self::Array, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
14 let Some(target_element_type) = dtype.as_list_element() else {
15 return Ok(None);
16 };
17
18 let validity = array
19 .validity()
20 .clone()
21 .cast_nullability(dtype.nullability())?;
22
23 ListArray::try_new(
24 cast(array.elements(), target_element_type)?,
25 array.offsets().clone(),
26 validity,
27 )
28 .map(|a| Some(a.to_array()))
29 }
30}
31
32register_kernel!(CastKernelAdapter(ListVTable).lift());
33
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType};

    use crate::arrays::{BoolArray, ListArray, PrimitiveArray};
    use crate::compute::cast;
    use crate::validity::Validity;

    /// Builds a `List<U64>` dtype with the given element and list-level nullability.
    fn list_of_u64(element_nullability: Nullability, list_nullability: Nullability) -> DType {
        DType::List(
            Arc::new(DType::Primitive(PType::U64, element_nullability)),
            list_nullability,
        )
    }

    #[test]
    fn test_cast_list_success() {
        // Offsets [0, 2, 3] describe two lists: [1, 2] and [3].
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4]).to_array(),
            PrimitiveArray::from_iter([0, 2, 3]).to_array(),
            Validity::NonNullable,
        )
        .unwrap();

        // Widening both the element type and the nullability must succeed.
        let target_dtype = list_of_u64(Nullability::Nullable, Nullability::Nullable);

        let result = cast(list.to_array().as_ref(), &target_dtype).unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), list.len());
    }

    #[test]
    fn test_cast_to_wrong_type() {
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([0i32, 2, 3, 4]).to_array(),
            PrimitiveArray::from_iter([0, 2, 3]).to_array(),
            Validity::NonNullable,
        )
        .unwrap();

        // The list cast kernel returns `Ok(None)` for non-list targets, so the
        // overall cast of a list to a primitive dtype is expected to fail.
        let target_dtype = DType::Primitive(PType::U64, Nullability::NonNullable);
        let result = cast(list.to_array().as_ref(), &target_dtype);
        assert!(result.is_err());
    }

    #[test]
    fn test_cant_cast_nulls_to_non_null() {
        // Case 1: a null *list* cannot be cast to a non-nullable list dtype.
        // Offsets [0, 2, 3] describe exactly two lists, so the validity must
        // have exactly two entries; here the first list is null.
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([0i32, 2, 3, 4]).to_array(),
            PrimitiveArray::from_iter([0, 2, 3]).to_array(),
            Validity::Array(BoolArray::from_iter([false, true]).to_array()),
        )
        .unwrap();
        assert_eq!(list.len(), 2);

        let target_dtype = list_of_u64(Nullability::Nullable, Nullability::NonNullable);
        let result = cast(list.to_array().as_ref(), &target_dtype);
        assert!(result.is_err());

        // Case 2: a null *element* cannot be cast to a non-nullable element dtype.
        // The lists are [0, 2] and [null].
        let list = ListArray::try_new(
            PrimitiveArray::from_option_iter([Some(0i32), Some(2), None, None]).to_array(),
            PrimitiveArray::from_iter([0, 2, 3]).to_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let target_dtype = list_of_u64(Nullability::NonNullable, Nullability::NonNullable);
        let result = cast(list.to_array().as_ref(), &target_dtype);
        assert!(result.is_err());
    }
}