vortex_array/arrays/primitive/compute/
cast.rs1use vortex_buffer::{Buffer, BufferMut};
2use vortex_dtype::{DType, NativePType, Nullability, match_each_native_ptype};
3use vortex_error::{VortexResult, vortex_bail, vortex_err};
4
5use crate::arrays::PrimitiveVTable;
6use crate::arrays::primitive::PrimitiveArray;
7use crate::compute::{CastKernel, CastKernelAdapter};
8use crate::validity::Validity;
9use crate::vtable::ValidityHelper;
10use crate::{ArrayRef, IntoArray, register_kernel};
11
12impl CastKernel for PrimitiveVTable {
13 fn cast(&self, array: &PrimitiveArray, dtype: &DType) -> VortexResult<ArrayRef> {
14 let DType::Primitive(new_ptype, new_nullability) = dtype else {
15 vortex_bail!(MismatchedTypes: "primitive type", dtype);
16 };
17 let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
18
19 let new_validity = if array.dtype().nullability() == new_nullability {
21 array.validity().clone()
22 } else if new_nullability == Nullability::Nullable {
23 array.validity().clone().into_nullable()
25 } else if new_nullability == Nullability::NonNullable && array.validity().all_valid()? {
26 Validity::NonNullable
28 } else {
29 vortex_bail!(
30 "invalid cast from nullable to non-nullable, since source array actually contains nulls"
31 );
32 };
33
34 if array.ptype() == new_ptype {
36 return Ok(PrimitiveArray::from_byte_buffer(
37 array.byte_buffer().clone(),
38 array.ptype(),
39 new_validity,
40 )
41 .into_array());
42 }
43
44 match_each_native_ptype!(new_ptype, |$T| {
46 Ok(PrimitiveArray::new(
47 cast::<$T>(array)?,
48 new_validity,
49 ).into_array())
50 })
51 }
52}
53
54register_kernel!(CastKernelAdapter(PrimitiveVTable).lift());
55
56fn cast<T: NativePType>(array: &PrimitiveArray) -> VortexResult<Buffer<T>> {
57 let mut buffer = BufferMut::with_capacity(array.len());
58 match_each_native_ptype!(array.ptype(), |$P| {
59 for item in array.as_slice::<$P>() {
60 let item = T::from(*item).ok_or_else(
61 || vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, T::PTYPE),
62 )?;
63 unsafe { buffer.push_unchecked(item) }
65 }
66 });
67 Ok(buffer.freeze())
68}
69
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::VortexError;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::cast;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Chains several casts on the same data, alternating ptype and
    /// nullability, and checks values and validity after every step.
    #[test]
    fn cast_u32_u8() {
        const EXPECTED_U8: [u8; 3] = [0, 10, 200];
        const EXPECTED_U32: [u32; 3] = [0, 10, 200];

        let nullable_u8 = DType::Primitive(PType::U8, Nullability::Nullable);
        let non_nullable_u8 = DType::Primitive(PType::U8, Nullability::NonNullable);
        let nullable_u32 = DType::Primitive(PType::U32, Nullability::Nullable);

        let source = buffer![0u32, 10, 200].into_array();

        // u32 -> u8, nullability untouched.
        let narrowed = cast(&source, PType::U8.into())
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(narrowed.as_slice::<u8>(), EXPECTED_U8.as_slice());
        assert_eq!(narrowed.validity(), &Validity::NonNullable);

        // u8 -> nullable u8.
        let made_nullable = cast(narrowed.as_ref(), &nullable_u8)
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(made_nullable.as_slice::<u8>(), EXPECTED_U8.as_slice());
        assert_eq!(made_nullable.validity(), &Validity::AllValid);

        // nullable u8 -> non-nullable u8.
        let made_non_nullable = cast(made_nullable.as_ref(), &non_nullable_u8)
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(made_non_nullable.as_slice::<u8>(), EXPECTED_U8.as_slice());
        assert_eq!(made_non_nullable.validity(), &Validity::NonNullable);

        // non-nullable u8 -> nullable u32 (ptype and nullability both change).
        let widened = cast(made_non_nullable.as_ref(), &nullable_u32)
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(widened.as_slice::<u32>(), EXPECTED_U32.as_slice());
        assert_eq!(widened.validity(), &Validity::AllValid);

        // nullable u32 -> non-nullable u8 (both change again).
        let round_tripped = cast(widened.as_ref(), &non_nullable_u8)
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(round_tripped.as_slice::<u8>(), EXPECTED_U8.as_slice());
        assert_eq!(round_tripped.validity(), &Validity::NonNullable);
    }

    /// Integer to float conversion keeps the numeric values.
    #[test]
    fn cast_u32_f32() {
        let source = buffer![0u32, 10, 200].into_array();
        let floats = cast(&source, PType::F32.into())
            .unwrap()
            .to_primitive()
            .unwrap();
        let expected = [0.0f32, 10., 200.];
        assert_eq!(floats.as_slice::<f32>(), expected.as_slice());
    }

    /// A negative value cannot be cast to an unsigned type.
    #[test]
    fn cast_i32_u32() {
        let source = buffer![-1i32].into_array();
        match cast(&source, PType::U32.into()).err().unwrap() {
            VortexError::ComputeError(message, _) => {
                assert_eq!(message.to_string(), "Failed to cast -1 to U32");
            }
            _ => unreachable!(),
        }
    }

    /// Casting an array that holds nulls to a non-nullable dtype is rejected.
    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let with_nulls = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        match cast(with_nulls.as_ref(), PType::I32.into()).unwrap_err() {
            VortexError::InvalidArgument(message, _) => {
                assert_eq!(
                    message.to_string(),
                    "invalid cast from nullable to non-nullable, since source array actually contains nulls"
                );
            }
            _ => unreachable!(),
        }
    }
}