1use num_traits::AsPrimitive;
5use num_traits::NumCast;
6use vortex_buffer::Buffer;
7use vortex_buffer::BufferMut;
8use vortex_compute::lane_kernels::IndexedSinkExt;
9use vortex_compute::lane_kernels::IndexedSourceExt;
10use vortex_compute::lane_kernels::ReinterpretSink;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14use vortex_mask::Mask;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::IntoArray;
19use crate::aggregate_fn;
20use crate::array::ArrayView;
21use crate::arrays::Primitive;
22use crate::arrays::PrimitiveArray;
23use crate::arrays::primitive::PrimitiveArrayExt;
24use crate::dtype::DType;
25use crate::dtype::NativePType;
26use crate::dtype::Nullability;
27use crate::dtype::PType;
28use crate::expr::stats::Stat;
29use crate::expr::stats::StatsProvider;
30use crate::match_each_native_ptype;
31use crate::scalar_fn::fns::cast::CastKernel;
32use crate::scalar_fn::fns::cast::CastReduce;
33use crate::validity::Validity;
34
35impl CastReduce for Primitive {
36 fn cast(array: ArrayView<'_, Primitive>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
37 let DType::Primitive(new_ptype, new_nullability) = dtype else {
40 return Ok(None);
41 };
42 if *new_ptype != array.ptype() {
43 return Ok(None);
44 }
45
46 let Some(new_validity) = array
47 .validity()?
48 .trivially_cast_nullability(*new_nullability, array.len())?
49 else {
50 return Ok(None);
51 };
52
53 Ok(Some(unsafe {
55 PrimitiveArray::new_unchecked_from_handle(
56 array.buffer_handle().clone(),
57 array.ptype(),
58 new_validity,
59 )
60 .into_array()
61 }))
62 }
63}
64
65impl CastKernel for Primitive {
66 fn cast(
67 array: ArrayView<'_, Primitive>,
68 dtype: &DType,
69 ctx: &mut ExecutionCtx,
70 ) -> VortexResult<Option<ArrayRef>> {
71 let DType::Primitive(new_ptype, new_nullability) = dtype else {
72 return Ok(None);
73 };
74 let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
75 let src_ptype = array.ptype();
76
77 let new_validity = array
78 .validity()?
79 .cast_nullability(new_nullability, array.len(), ctx)?;
80
81 let same_rep = src_ptype == new_ptype
85 || (src_ptype.is_int()
86 && new_ptype.is_int()
87 && src_ptype.byte_width() == new_ptype.byte_width());
88 if same_rep {
89 if !values_fit_in(array, new_ptype, ctx, true) {
90 vortex_bail!(
91 Compute: "Cannot cast {} to {} — values exceed target range",
92 src_ptype, new_ptype,
93 );
94 }
95 return Ok(Some(reinterpret(array, new_ptype, new_validity)));
96 }
97
98 Ok(Some(match_each_native_ptype!(new_ptype, |T| {
101 match_each_native_ptype!(src_ptype, |F| {
102 cast_values::<F, T>(array, new_validity, ctx)?
103 })
104 })))
105 }
106}
107
108fn cast_values<F, T>(
110 array: ArrayView<'_, Primitive>,
111 new_validity: Validity,
112 ctx: &mut ExecutionCtx,
113) -> VortexResult<ArrayRef>
114where
115 F: NativePType + AsPrimitive<T>,
116 T: NativePType,
117{
118 let overflow = || {
119 vortex_err!(
120 Compute: "Cannot cast {} to {} — value exceeds target range",
121 F::PTYPE, T::PTYPE,
122 )
123 };
124
125 fn casts_losslessly_to(from: PType, to: PType) -> bool {
127 from.least_supertype(to) == Some(to)
128 }
129
130 let target_dtype = DType::Primitive(T::PTYPE, Nullability::NonNullable);
132 let infallible = casts_losslessly_to(F::PTYPE, T::PTYPE)
133 || cached_values_fit_in(array, &target_dtype).unwrap_or(false);
134
135 let len = array.len();
136
137 let same_bit_width = F::PTYPE.byte_width() == T::PTYPE.byte_width();
139 let owned: Option<BufferMut<F>> = same_bit_width
140 .then(|| array.into_owned().try_into_buffer_mut::<F>().ok())
141 .flatten();
142 let values: &[F] = array.as_slice::<F>();
143
144 if infallible {
145 return match owned {
146 Some(mut buf) => {
147 ReinterpretSink::<F, T>::new(buf.as_mut_slice()).map_into_in_place(|v: F| v.as_());
148 let result: BufferMut<T> = unsafe { buf.transmute::<T>() };
150 Ok(PrimitiveArray::new(result.freeze(), new_validity).into_array())
151 }
152 None => {
153 let mut buffer = BufferMut::<T>::with_capacity(len);
154 values.map_into(&mut buffer.spare_capacity_mut()[..len], |v| v.as_());
155 unsafe { buffer.set_len(len) };
157 Ok(PrimitiveArray::new(buffer.freeze(), new_validity).into_array())
158 }
159 };
160 }
161
162 let mask = array.validity()?.execute_mask(len, ctx)?;
163
164 let buffer: Buffer<T> = match (&mask, owned) {
165 (Mask::AllTrue(_), Some(mut buf)) => {
166 ReinterpretSink::<F, T>::new(buf.as_mut_slice())
167 .try_map_in_place(|v: F| <T as NumCast>::from(v))
168 .map_err(|_| overflow())?;
169 let result: BufferMut<T> = unsafe { buf.transmute::<T>() };
171 result.freeze()
172 }
173 (Mask::AllTrue(_), None) => {
174 let mut buffer = BufferMut::<T>::with_capacity(len);
175 values
176 .try_map_into(&mut buffer.spare_capacity_mut()[..len], |v| {
177 <T as NumCast>::from(v)
178 })
179 .map_err(|_| overflow())?;
180 unsafe { buffer.set_len(len) };
182 buffer.freeze()
183 }
184 (Mask::AllFalse(_), _) => BufferMut::<T>::zeroed(len).freeze(),
185 (Mask::Values(m), Some(mut buf)) => {
186 ReinterpretSink::<F, T>::new(buf.as_mut_slice())
187 .try_map_masked_in_place(m.bit_buffer(), |v: F| <T as NumCast>::from(v))
188 .map_err(|_| overflow())?;
189 let result: BufferMut<T> = unsafe { buf.transmute::<T>() };
191 result.freeze()
192 }
193 (Mask::Values(m), None) => {
194 let mut buffer = BufferMut::<T>::with_capacity(len);
195 values
196 .try_map_masked_into(
197 m.bit_buffer(),
198 &mut buffer.spare_capacity_mut()[..len],
199 |v| <T as NumCast>::from(v),
200 )
201 .map_err(|_| overflow())?;
202 unsafe { buffer.set_len(len) };
204 buffer.freeze()
205 }
206 };
207
208 Ok(PrimitiveArray::new(buffer, new_validity).into_array())
209}
210
211fn reinterpret(
212 array: ArrayView<'_, Primitive>,
213 new_ptype: PType,
214 new_validity: Validity,
215) -> ArrayRef {
216 unsafe {
219 PrimitiveArray::new_unchecked_from_handle(
220 array.buffer_handle().clone(),
221 new_ptype,
222 new_validity,
223 )
224 }
225 .into_array()
226}
227
228fn values_fit_in(
234 array: ArrayView<'_, Primitive>,
235 target_ptype: PType,
236 ctx: &mut ExecutionCtx,
237 compute: bool,
238) -> bool {
239 let target_dtype = DType::Primitive(target_ptype, Nullability::NonNullable);
240 if let Some(fits) = cached_values_fit_in(array, &target_dtype) {
241 return fits;
242 }
243 if !compute {
244 return false;
245 }
246 aggregate_fn::fns::min_max::min_max(
247 array.array(),
248 ctx,
249 aggregate_fn::NumericalAggregateOpts::default(),
250 )
251 .ok()
252 .flatten()
253 .is_none_or(|mm| mm.min.cast(&target_dtype).is_ok() && mm.max.cast(&target_dtype).is_ok())
254}
255
256fn cached_values_fit_in(array: ArrayView<'_, Primitive>, target_dtype: &DType) -> Option<bool> {
259 let stats = array.array().statistics();
260 let min = stats.get(Stat::Min).as_exact()?;
261 let max = stats.get(Stat::Max).as_exact()?;
262 Some(min.cast(target_dtype).is_ok() && max.cast(target_dtype).is_ok())
263}
264
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexError;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    #[expect(deprecated)]
    use crate::canonical::ToCanonical as _;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    // Walks one array through a chain of casts: narrowing, nullability changes in both
    // directions, widening, and narrowing again, checking values and validity at each step.
    #[test]
    fn cast_u32_u8() {
        let mut ctx = array_session().create_execution_ctx();
        let input = buffer![0u32, 10, 200].into_array();

        // u32 -> u8, all values in range.
        #[expect(deprecated)]
        let narrowed = input.cast(PType::U8.into()).unwrap().to_primitive();
        assert_arrays_eq!(
            narrowed,
            PrimitiveArray::from_iter([0u8, 10, 200]),
            &mut ctx
        );
        assert!(matches!(narrowed.validity(), Ok(Validity::NonNullable)));

        // Same ptype, non-nullable -> nullable.
        #[expect(deprecated)]
        let nullable_u8 = narrowed
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            nullable_u8,
            PrimitiveArray::new(buffer![0u8, 10, 200], Validity::AllValid),
            &mut ctx
        );
        assert!(matches!(nullable_u8.validity(), Ok(Validity::AllValid)));

        // Same ptype, nullable -> non-nullable.
        #[expect(deprecated)]
        let non_nullable_u8 = nullable_u8
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            non_nullable_u8,
            PrimitiveArray::from_iter([0u8, 10, 200]),
            &mut ctx
        );
        assert!(matches!(
            non_nullable_u8.validity(),
            Ok(Validity::NonNullable)
        ));

        // Widening and becoming nullable at the same time.
        #[expect(deprecated)]
        let widened = non_nullable_u8
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            widened,
            PrimitiveArray::new(buffer![0u32, 10, 200], Validity::AllValid),
            &mut ctx
        );
        assert!(matches!(widened.validity(), Ok(Validity::AllValid)));

        // Narrowing and dropping nullability at the same time.
        #[expect(deprecated)]
        let narrowed_again = widened
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            narrowed_again,
            PrimitiveArray::from_iter([0u8, 10, 200]),
            &mut ctx
        );
        assert!(matches!(
            narrowed_again.validity(),
            Ok(Validity::NonNullable)
        ));
    }

    // Integer -> float of the same width.
    #[test]
    fn cast_u32_f32() {
        let mut ctx = array_session().create_execution_ctx();
        let input = buffer![0u32, 10, 200].into_array();
        #[expect(deprecated)]
        let as_f32 = input.cast(PType::F32.into()).unwrap().to_primitive();
        assert_arrays_eq!(
            as_f32,
            PrimitiveArray::from_iter([0.0f32, 10., 200.]),
            &mut ctx
        );
    }

    // A negative value cannot become unsigned; the same-width integer path reports it.
    #[test]
    fn cast_i32_u32() {
        let input = buffer![-1i32].into_array();
        #[expect(deprecated)]
        let err = input
            .cast(PType::U32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();
        assert!(matches!(err, VortexError::Compute(..)));
        assert!(err.to_string().contains("values exceed target range"));
    }

    // An array that actually contains a null cannot be cast to a non-nullable dtype.
    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let input = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        #[expect(deprecated)]
        let failure = input
            .into_array()
            .cast(PType::I32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();

        assert!(matches!(failure, VortexError::InvalidArgument(..)));
        assert!(
            failure
                .to_string()
                .contains("Cannot cast array with invalid values to non-nullable type.")
        );
    }

    // The out-of-range value (-1) sits at a null position, so i32 -> u32 must succeed and
    // the validity must be carried over unchanged.
    #[test]
    fn cast_with_invalid_nulls() {
        let mut ctx = array_session().create_execution_ctx();
        let input = PrimitiveArray::new(
            buffer![-1i32, 0, 10],
            Validity::from_iter([false, true, true]),
        );
        #[expect(deprecated)]
        let casted = input
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_option_iter([None, Some(0u32), Some(10)]),
            &mut ctx
        );
        assert_eq!(
            casted
                .as_ref()
                .validity()
                .unwrap()
                .execute_mask(
                    casted.as_ref().len(),
                    &mut array_session().create_execution_ctx()
                )
                .unwrap(),
            Mask::from(BitBuffer::from(vec![false, true, true]))
        );
    }

    // A same-width integer cast with in-range values must reuse the source allocation:
    // the data pointer before and after the cast is the same.
    #[test]
    fn cast_same_width_int_reinterprets_buffer() -> vortex_error::VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let original = PrimitiveArray::from_iter([0u32, 10, 100]);
        let original_ptr = original.as_slice::<u32>().as_ptr();

        #[expect(deprecated)]
        let casted = original.into_array().cast(PType::I32.into())?.to_primitive();
        let casted_ptr = casted.as_slice::<i32>().as_ptr();

        assert_eq!(original_ptr as usize, casted_ptr as usize);
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_iter([0i32, 10, 100]),
            &mut ctx
        );
        Ok(())
    }

    // u32::MAX does not fit in i32, so the same-width cast must fail instead of wrapping.
    #[test]
    fn cast_same_width_int_out_of_range_errors() {
        let input = buffer![u32::MAX].into_array();
        #[expect(deprecated)]
        let failure = input
            .cast(PType::I32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();
        assert!(matches!(failure, VortexError::Compute(..)));
    }

    // Every value is null, so the (out-of-range for i8) payload bytes must not matter.
    #[test]
    fn cast_same_width_all_null() -> vortex_error::VortexResult<()> {
        let input = PrimitiveArray::new(buffer![0xFFu8, 0xFF], Validity::AllInvalid);
        #[expect(deprecated)]
        let result = input
            .into_array()
            .cast(DType::Primitive(PType::I8, Nullability::Nullable))?
            .to_primitive();
        assert_eq!(result.len(), 2);
        assert!(matches!(result.validity(), Ok(Validity::AllInvalid)));
        Ok(())
    }

    // Same-width cast where the only out-of-range value (u32::MAX) is at a null position.
    #[test]
    fn cast_same_width_int_nullable_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let input = PrimitiveArray::new(
            buffer![u32::MAX, 0u32, 42u32],
            Validity::from_iter([false, true, true]),
        );
        #[expect(deprecated)]
        let result = input
            .into_array()
            .cast(DType::Primitive(PType::I32, Nullability::Nullable))?
            .to_primitive();
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([None, Some(0i32), Some(42)]),
            &mut ctx
        );
        Ok(())
    }

    // Narrowing cast where the only out-of-range value (1000) is at a null position.
    #[test]
    fn cast_u32_to_u8_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let input = PrimitiveArray::new(
            buffer![1000u32, 10u32, 42u32],
            Validity::from_iter([false, true, true]),
        );
        #[expect(deprecated)]
        let result = input
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::Nullable))?
            .to_primitive();
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([None, Some(10u8), Some(42)]),
            &mut ctx
        );
        Ok(())
    }

    // Runs the shared cast conformance suite over every primitive type, including float
    // special values, nullable inputs, and a single-element array.
    #[rstest]
    #[case(buffer![0u8, 1, 2, 3, 255].into_array())]
    #[case(buffer![0u16, 100, 1000, 65535].into_array())]
    #[case(buffer![0u32, 100, 1000, 1000000].into_array())]
    #[case(buffer![0u64, 100, 1000, 1000000000].into_array())]
    #[case(buffer![-128i8, -1, 0, 1, 127].into_array())]
    #[case(buffer![-1000i16, -1, 0, 1, 1000].into_array())]
    #[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())]
    #[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())]
    #[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())]
    #[case(buffer![f32::NAN, f32::INFINITY, f32::NEG_INFINITY, 0.0f32].into_array())]
    #[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())]
    #[case(buffer![f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 0.0f64].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())]
    #[case(buffer![42u32].into_array())]
    fn test_cast_primitive_conformance(#[case] array: ArrayRef) {
        test_cast_conformance(&array);
    }
}