vortex_array/arrays/primitive/compute/
cast.rs1use vortex_buffer::Buffer;
5use vortex_buffer::BufferMut;
6use vortex_error::VortexResult;
7use vortex_error::vortex_bail;
8use vortex_error::vortex_err;
9use vortex_mask::AllOr;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::aggregate_fn;
16use crate::array::ArrayView;
17use crate::arrays::Primitive;
18use crate::arrays::PrimitiveArray;
19use crate::arrays::primitive::PrimitiveArrayExt;
20use crate::dtype::DType;
21use crate::dtype::NativePType;
22use crate::dtype::Nullability;
23use crate::dtype::PType;
24use crate::match_each_native_ptype;
25use crate::scalar_fn::fns::cast::CastKernel;
26
27impl CastKernel for Primitive {
28 fn cast(
29 array: ArrayView<'_, Primitive>,
30 dtype: &DType,
31 ctx: &mut ExecutionCtx,
32 ) -> VortexResult<Option<ArrayRef>> {
33 let DType::Primitive(new_ptype, new_nullability) = dtype else {
34 return Ok(None);
35 };
36 let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
37
38 let new_validity = array
40 .validity()?
41 .cast_nullability(new_nullability, array.len())?;
42
43 if array.ptype() == new_ptype {
45 return Ok(Some(unsafe {
47 PrimitiveArray::new_unchecked_from_handle(
48 array.buffer_handle().clone(),
49 array.ptype(),
50 new_validity,
51 )
52 .into_array()
53 }));
54 }
55
56 if array.ptype().is_int()
60 && new_ptype.is_int()
61 && array.ptype().byte_width() == new_ptype.byte_width()
62 {
63 if !values_fit_in(array, new_ptype, ctx) {
64 vortex_bail!(
65 Compute: "Cannot cast {} to {} — values exceed target range",
66 array.ptype(),
67 new_ptype,
68 );
69 }
70 return Ok(Some(unsafe {
73 PrimitiveArray::new_unchecked_from_handle(
74 array.buffer_handle().clone(),
75 new_ptype,
76 new_validity,
77 )
78 .into_array()
79 }));
80 }
81
82 let mask = array.validity_mask();
83
84 Ok(Some(match_each_native_ptype!(new_ptype, |T| {
86 match_each_native_ptype!(array.ptype(), |F| {
87 PrimitiveArray::new(cast::<F, T>(array.as_slice(), mask)?, new_validity)
88 .into_array()
89 })
90 })))
91 }
92}
93
94fn values_fit_in(
96 array: ArrayView<'_, Primitive>,
97 target_ptype: PType,
98 ctx: &mut ExecutionCtx,
99) -> bool {
100 let target_dtype = DType::Primitive(target_ptype, Nullability::NonNullable);
101 aggregate_fn::fns::min_max::min_max(array.array(), ctx)
102 .ok()
103 .flatten()
104 .is_none_or(|mm| mm.min.cast(&target_dtype).is_ok() && mm.max.cast(&target_dtype).is_ok())
105}
106
107fn cast<F: NativePType, T: NativePType>(array: &[F], mask: Mask) -> VortexResult<Buffer<T>> {
108 let try_cast = |src: F| -> VortexResult<T> {
109 T::from(src).ok_or_else(|| vortex_err!(Compute: "Failed to cast {} to {:?}", src, T::PTYPE))
110 };
111 match mask.bit_buffer() {
112 AllOr::None => Ok(Buffer::zeroed(array.len())),
113 AllOr::All => {
114 let mut buffer = BufferMut::with_capacity(array.len());
115 for &src in array {
116 unsafe { buffer.push_unchecked(try_cast(src)?) }
118 }
119 Ok(buffer.freeze())
120 }
121 AllOr::Some(b) => {
122 let mut buffer = BufferMut::with_capacity(array.len());
123 for (&src, valid) in array.iter().zip(b.iter()) {
124 let dst = if valid { try_cast(src)? } else { T::default() };
125 unsafe { buffer.push_unchecked(dst) }
127 }
128 Ok(buffer.freeze())
129 }
130 }
131}
132
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexError;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    /// Walks one small array through a chain of narrowing, widening and nullability-only casts,
    /// checking the values and the resulting validity after every step.
    // NOTE(review): the lint is presumably tripped by the repeated assertion macros — confirm.
    #[allow(clippy::cognitive_complexity)]
    #[test]
    fn cast_u32_u8() {
        let source = buffer![0u32, 10, 200].into_array();

        // u32 -> u8, non-nullable.
        let narrowed = source.cast(PType::U8.into()).unwrap().to_primitive();
        assert_arrays_eq!(narrowed, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(narrowed.validity(), Ok(Validity::NonNullable)));

        // u8 -> nullable u8: same values, validity becomes all-valid.
        let nullable_u8 = DType::Primitive(PType::U8, Nullability::Nullable);
        let made_nullable = narrowed
            .into_array()
            .cast(nullable_u8)
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            made_nullable,
            PrimitiveArray::new(buffer![0u8, 10, 200], Validity::AllValid)
        );
        assert!(matches!(made_nullable.validity(), Ok(Validity::AllValid)));

        // nullable u8 -> non-nullable u8.
        let non_nullable_u8 = DType::Primitive(PType::U8, Nullability::NonNullable);
        let made_non_nullable = made_nullable
            .into_array()
            .cast(non_nullable_u8)
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            made_non_nullable,
            PrimitiveArray::from_iter([0u8, 10, 200])
        );
        assert!(matches!(
            made_non_nullable.validity(),
            Ok(Validity::NonNullable)
        ));

        // non-nullable u8 -> nullable u32: ptype and nullability change together.
        let nullable_u32 = DType::Primitive(PType::U32, Nullability::Nullable);
        let widened = made_non_nullable
            .into_array()
            .cast(nullable_u32)
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            widened,
            PrimitiveArray::new(buffer![0u32, 10, 200], Validity::AllValid)
        );
        assert!(matches!(widened.validity(), Ok(Validity::AllValid)));

        // nullable u32 -> non-nullable u8: back to where the chain started.
        let back_to_u8 = DType::Primitive(PType::U8, Nullability::NonNullable);
        let round_tripped = widened
            .into_array()
            .cast(back_to_u8)
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(round_tripped, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(
            round_tripped.validity(),
            Ok(Validity::NonNullable)
        ));
    }

    /// Integer values convert to the matching `f32` values.
    #[test]
    fn cast_u32_f32() {
        let source = buffer![0u32, 10, 200].into_array();
        let as_f32 = source.cast(PType::F32.into()).unwrap().to_primitive();
        assert_arrays_eq!(as_f32, PrimitiveArray::from_iter([0.0f32, 10., 200.]));
    }

    /// A negative `i32` cannot be cast to `u32`; the failure is a `Compute` error that mentions
    /// the target range.
    #[test]
    fn cast_i32_u32() {
        let source = buffer![-1i32].into_array();
        let outcome = source
            .cast(PType::U32.into())
            .and_then(|casted| casted.to_canonical().map(|canonical| canonical.into_array()));
        let error = outcome.unwrap_err();
        assert!(matches!(error, VortexError::Compute(..)));
        assert!(error.to_string().contains("values exceed target range"));
    }

    /// Casting an array that actually contains nulls to a non-nullable dtype is rejected as an
    /// invalid argument.
    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let with_nulls = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        let outcome = with_nulls
            .into_array()
            .cast(PType::I32.into())
            .and_then(|casted| casted.to_canonical().map(|canonical| canonical.into_array()));
        let error = outcome.unwrap_err();

        assert!(matches!(error, VortexError::InvalidArgument(..)));
        assert!(
            error
                .to_string()
                .contains("Cannot cast array with invalid values to non-nullable type.")
        );
    }

    /// A value that would be out of range for the target type does not fail the cast when it sits
    /// at a null position.
    #[test]
    fn cast_with_invalid_nulls() {
        let source = PrimitiveArray::new(
            buffer![-1i32, 0, 10],
            Validity::from_iter([false, true, true]),
        );
        let nullable_u32 = DType::Primitive(PType::U32, Nullability::Nullable);
        let casted = source
            .into_array()
            .cast(nullable_u32)
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_option_iter([None, Some(0u32), Some(10)])
        );

        let expected_mask = Mask::from(BitBuffer::from(vec![false, true, true]));
        assert_eq!(casted.validity_mask().unwrap(), expected_mask);
    }

    /// A same-width integer cast whose values all fit hands back the same underlying buffer:
    /// the data pointer is identical before and after the cast.
    #[test]
    fn cast_same_width_int_reinterprets_buffer() -> vortex_error::VortexResult<()> {
        let source = PrimitiveArray::from_iter([0u32, 10, 100]);
        let source_ptr = source.as_slice::<u32>().as_ptr();

        let casted = source.into_array().cast(PType::I32.into())?.to_primitive();
        let casted_ptr = casted.as_slice::<i32>().as_ptr();

        assert_eq!(source_ptr as usize, casted_ptr as usize);
        assert_arrays_eq!(casted, PrimitiveArray::from_iter([0i32, 10, 100]));
        Ok(())
    }

    /// A same-width integer cast fails with a `Compute` error when a value does not fit the
    /// target type.
    #[test]
    fn cast_same_width_int_out_of_range_errors() {
        let source = buffer![u32::MAX].into_array();
        let outcome = source
            .cast(PType::I32.into())
            .and_then(|casted| casted.to_canonical().map(|canonical| canonical.into_array()));
        let error = outcome.unwrap_err();
        assert!(matches!(error, VortexError::Compute(..)));
    }

    /// An all-null array casts between same-width integer types even though its stored bytes
    /// would be out of range for the target type.
    #[test]
    fn cast_same_width_all_null() -> vortex_error::VortexResult<()> {
        let all_null = PrimitiveArray::new(buffer![0xFFu8, 0xFF], Validity::AllInvalid);
        let nullable_i8 = DType::Primitive(PType::I8, Nullability::Nullable);
        let casted = all_null.into_array().cast(nullable_i8)?.to_primitive();
        assert_eq!(casted.len(), 2);
        assert!(matches!(casted.validity(), Ok(Validity::AllInvalid)));
        Ok(())
    }

    /// An out-of-range value stored at a null position does not block a same-width integer cast;
    /// only the valid values have to fit.
    #[test]
    fn cast_same_width_int_nullable_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
        let source = PrimitiveArray::new(
            buffer![u32::MAX, 0u32, 42u32],
            Validity::from_iter([false, true, true]),
        );
        let nullable_i32 = DType::Primitive(PType::I32, Nullability::Nullable);
        let casted = source.into_array().cast(nullable_i32)?.to_primitive();
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_option_iter([None, Some(0i32), Some(42)])
        );
        Ok(())
    }

    /// Runs the shared cast conformance suite over every primitive type, plus nullable and
    /// single-element inputs.
    #[rstest]
    #[case(buffer![0u8, 1, 2, 3, 255].into_array())]
    #[case(buffer![0u16, 100, 1000, 65535].into_array())]
    #[case(buffer![0u32, 100, 1000, 1000000].into_array())]
    #[case(buffer![0u64, 100, 1000, 1000000000].into_array())]
    #[case(buffer![-128i8, -1, 0, 1, 127].into_array())]
    #[case(buffer![-1000i16, -1, 0, 1, 1000].into_array())]
    #[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())]
    #[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())]
    #[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())]
    #[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())]
    #[case(buffer![42u32].into_array())]
    fn test_cast_primitive_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(&array);
    }
}