vortex_array/arrays/primitive/compute/
cast.rs1use vortex_buffer::{Buffer, BufferMut};
5use vortex_dtype::{DType, NativePType, match_each_native_ptype};
6use vortex_error::{VortexResult, vortex_err};
7use vortex_mask::{AllOr, Mask};
8
9use crate::arrays::PrimitiveVTable;
10use crate::arrays::primitive::PrimitiveArray;
11use crate::compute::{CastKernel, CastKernelAdapter};
12use crate::vtable::ValidityHelper;
13use crate::{ArrayRef, IntoArray, register_kernel};
14
15impl CastKernel for PrimitiveVTable {
16 fn cast(&self, array: &PrimitiveArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17 let DType::Primitive(new_ptype, new_nullability) = dtype else {
18 return Ok(None);
19 };
20 let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
21
22 let new_validity = array
24 .validity()
25 .clone()
26 .cast_nullability(new_nullability, array.len())?;
27
28 if array.ptype() == new_ptype {
30 return Ok(Some(
31 PrimitiveArray::from_byte_buffer(
32 array.byte_buffer().clone(),
33 array.ptype(),
34 new_validity,
35 )
36 .into_array(),
37 ));
38 }
39
40 let mask = array.validity_mask();
41
42 Ok(Some(match_each_native_ptype!(new_ptype, |T| {
44 match_each_native_ptype!(array.ptype(), |F| {
45 PrimitiveArray::new(cast::<F, T>(array.as_slice(), mask)?, new_validity)
46 .into_array()
47 })
48 })))
49 }
50}
51
52register_kernel!(CastKernelAdapter(PrimitiveVTable).lift());
53
54fn cast<F: NativePType, T: NativePType>(array: &[F], mask: Mask) -> VortexResult<Buffer<T>> {
55 match mask.boolean_buffer() {
56 AllOr::All => {
57 let mut buffer = BufferMut::with_capacity(array.len());
58 for item in array {
59 let item = T::from(*item).ok_or_else(
60 || vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, T::PTYPE),
61 )?;
62 unsafe { buffer.push_unchecked(item) }
64 }
65 Ok(buffer.freeze())
66 }
67 AllOr::None => Ok(Buffer::zeroed(array.len())),
68 AllOr::Some(b) => {
69 let mut buffer = BufferMut::with_capacity(array.len());
71 for (item, valid) in array.iter().zip(b.iter()) {
72 if valid {
73 let item = T::from(*item).ok_or_else(
74 || vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, T::PTYPE),
75 )?;
76 unsafe { buffer.push_unchecked(item) }
78 } else {
79 unsafe { buffer.push_unchecked(T::default()) }
81 }
82 }
83 Ok(buffer.freeze())
84 }
85 }
86}
87
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::VortexError;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Casts a `u32` array to `u8` and back while toggling nullability, checking that the
    /// values survive and that the validity follows the requested nullability at each step.
    #[test]
    fn cast_u32_u8() {
        let arr = buffer![0u32, 10, 200].into_array();

        let p = cast(&arr, PType::U8.into()).unwrap().to_primitive();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);

        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U8, Nullability::Nullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::AllValid);

        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U8, Nullability::NonNullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);

        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u32>(), vec![0u32, 10, 200]);
        assert_eq!(p.validity(), &Validity::AllValid);

        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U8, Nullability::NonNullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);
    }

    #[test]
    fn cast_u32_f32() {
        let arr = buffer![0u32, 10, 200].into_array();
        let f32arr = cast(&arr, PType::F32.into()).unwrap().to_primitive();
        assert_eq!(f32arr.as_slice::<f32>(), vec![0.0f32, 10., 200.]);
    }

    /// A negative value is not representable as `u32`, so the cast must fail with a
    /// `ComputeError` naming the offending value and target type.
    #[test]
    fn cast_i32_u32() {
        let arr = buffer![-1i32].into_array();
        let s = match cast(&arr, PType::U32.into()).unwrap_err() {
            VortexError::ComputeError(s, _) => s,
            other => panic!("expected ComputeError, got {other:?}"),
        };
        assert_eq!(s.to_string(), "Failed to cast -1 to U32");
    }

    /// Casting an array that contains nulls to a non-nullable type is rejected.
    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let arr = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        let s = match cast(arr.as_ref(), PType::I32.into()).unwrap_err() {
            VortexError::InvalidArgument(s, _) => s,
            other => panic!("expected InvalidArgument, got {other:?}"),
        };
        assert_eq!(
            s.to_string(),
            "Cannot cast array with invalid values to non-nullable type."
        );
    }

    /// Values under null slots are not converted, so an unrepresentable value (`-1` -> `u32`)
    /// hidden behind a null does not cause an error and is zeroed in the output.
    #[test]
    fn cast_with_invalid_nulls() {
        let arr = PrimitiveArray::new(
            buffer![-1i32, 0, 10],
            Validity::from_iter([false, true, true]),
        );
        let p = cast(
            arr.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u32>(), vec![0, 0, 10]);
        assert_eq!(
            p.validity_mask(),
            Mask::from(BooleanBuffer::from(vec![false, true, true]))
        );
    }

    #[rstest]
    #[case(buffer![0u8, 1, 2, 3, 255].into_array())]
    #[case(buffer![0u16, 100, 1000, 65535].into_array())]
    #[case(buffer![0u32, 100, 1000, 1000000].into_array())]
    #[case(buffer![0u64, 100, 1000, 1000000000].into_array())]
    #[case(buffer![-128i8, -1, 0, 1, 127].into_array())]
    #[case(buffer![-1000i16, -1, 0, 1, 1000].into_array())]
    #[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())]
    #[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())]
    #[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())]
    #[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())]
    #[case(buffer![42u32].into_array())]
    fn test_cast_primitive_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(array.as_ref());
    }
}