vortex_array/arrays/primitive/compute/
cast.rs1use vortex_buffer::{Buffer, BufferMut};
5use vortex_dtype::{DType, NativePType, Nullability, match_each_native_ptype};
6use vortex_error::{VortexResult, vortex_bail, vortex_err};
7
8use crate::arrays::PrimitiveVTable;
9use crate::arrays::primitive::PrimitiveArray;
10use crate::compute::{CastKernel, CastKernelAdapter};
11use crate::validity::Validity;
12use crate::vtable::ValidityHelper;
13use crate::{ArrayRef, IntoArray, register_kernel};
14
15impl CastKernel for PrimitiveVTable {
16 fn cast(&self, array: &PrimitiveArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17 let DType::Primitive(new_ptype, new_nullability) = dtype else {
18 return Ok(None);
19 };
20 let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
21
22 let new_validity = if array.dtype().nullability() == new_nullability {
24 array.validity().clone()
25 } else if new_nullability == Nullability::Nullable {
26 array.validity().clone().into_nullable()
28 } else if new_nullability == Nullability::NonNullable && array.validity().all_valid()? {
29 Validity::NonNullable
31 } else {
32 vortex_bail!(
33 "invalid cast from nullable to non-nullable, since source array actually contains nulls"
34 );
35 };
36
37 if array.ptype() == new_ptype {
39 return Ok(Some(
40 PrimitiveArray::from_byte_buffer(
41 array.byte_buffer().clone(),
42 array.ptype(),
43 new_validity,
44 )
45 .into_array(),
46 ));
47 }
48
49 match_each_native_ptype!(new_ptype, |T| {
51 Ok(Some(
52 PrimitiveArray::new(cast::<T>(array)?, new_validity).into_array(),
53 ))
54 })
55 }
56}
57
58register_kernel!(CastKernelAdapter(PrimitiveVTable).lift());
59
60fn cast<T: NativePType>(array: &PrimitiveArray) -> VortexResult<Buffer<T>> {
61 let mut buffer = BufferMut::with_capacity(array.len());
62 match_each_native_ptype!(array.ptype(), |P| {
63 for item in array.as_slice::<P>() {
64 let item = T::from(*item).ok_or_else(
65 || vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, T::PTYPE),
66 )?;
67 unsafe { buffer.push_unchecked(item) }
69 }
70 });
71 Ok(buffer.freeze())
72}
73
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::VortexError;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::cast;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Round-trips a `u32` array through `u8`/`u32` and both nullabilities, checking that the
    /// values survive and that the validity follows the target nullability.
    #[test]
    fn cast_u32_u8() {
        let arr = buffer![0u32, 10, 200].into_array();

        let p = cast(&arr, PType::U8.into())
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);

        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U8, Nullability::Nullable),
        )
        .unwrap()
        .to_primitive()
        .unwrap();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::AllValid);

        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U8, Nullability::NonNullable),
        )
        .unwrap()
        .to_primitive()
        .unwrap();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);

        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap()
        .to_primitive()
        .unwrap();
        assert_eq!(p.as_slice::<u32>(), vec![0u32, 10, 200]);
        assert_eq!(p.validity(), &Validity::AllValid);

        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U8, Nullability::NonNullable),
        )
        .unwrap()
        .to_primitive()
        .unwrap();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);
    }

    /// Integer to float conversion keeps the numeric values.
    #[test]
    fn cast_u32_f32() {
        let arr = buffer![0u32, 10, 200].into_array();
        let f32arr = cast(&arr, PType::F32.into())
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(f32arr.as_slice::<f32>(), vec![0.0f32, 10., 200.]);
    }

    /// A negative value cannot be represented as `u32` and must surface a `ComputeError`.
    #[test]
    fn cast_i32_u32() {
        let arr = buffer![-1i32].into_array();
        let err = cast(&arr, PType::U32.into()).unwrap_err();
        let s = match err {
            VortexError::ComputeError(s, _) => s,
            other => panic!("expected ComputeError, got: {other}"),
        };
        assert_eq!(s.to_string(), "Failed to cast -1 to U32");
    }

    /// Casting an array that really contains nulls to a non-nullable type is rejected.
    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let arr = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        let err = cast(arr.as_ref(), PType::I32.into()).unwrap_err();
        let s = match err {
            VortexError::InvalidArgument(s, _) => s,
            other => panic!("expected InvalidArgument, got: {other}"),
        };
        assert_eq!(
            s.to_string(),
            "invalid cast from nullable to non-nullable, since source array actually contains nulls"
        );
    }
}