1use num_traits::AsPrimitive;
5use num_traits::NumCast;
6use vortex_buffer::Buffer;
7use vortex_buffer::BufferMut;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_err;
11use vortex_mask::Mask;
12
13use crate::ArrayRef;
14use crate::ExecutionCtx;
15use crate::IntoArray;
16use crate::aggregate_fn;
17use crate::array::ArrayView;
18use crate::arrays::Primitive;
19use crate::arrays::PrimitiveArray;
20use crate::arrays::primitive::PrimitiveArrayExt;
21use crate::dtype::DType;
22use crate::dtype::NativePType;
23use crate::dtype::Nullability;
24use crate::dtype::PType;
25use crate::expr::stats::Stat;
26use crate::expr::stats::StatsProvider;
27use crate::match_each_native_ptype;
28use crate::scalar_fn::fns::cast::CastKernel;
29use crate::scalar_fn::fns::cast::CastReduce;
30use crate::validity::Validity;
31
32impl CastReduce for Primitive {
33 fn cast(array: ArrayView<'_, Primitive>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
34 let DType::Primitive(new_ptype, new_nullability) = dtype else {
37 return Ok(None);
38 };
39 if *new_ptype != array.ptype() {
40 return Ok(None);
41 }
42
43 let Some(new_validity) = array
44 .validity()?
45 .trivially_cast_nullability(*new_nullability, array.len())?
46 else {
47 return Ok(None);
48 };
49
50 Ok(Some(unsafe {
52 PrimitiveArray::new_unchecked_from_handle(
53 array.buffer_handle().clone(),
54 array.ptype(),
55 new_validity,
56 )
57 .into_array()
58 }))
59 }
60}
61
62impl CastKernel for Primitive {
63 fn cast(
64 array: ArrayView<'_, Primitive>,
65 dtype: &DType,
66 ctx: &mut ExecutionCtx,
67 ) -> VortexResult<Option<ArrayRef>> {
68 let DType::Primitive(new_ptype, new_nullability) = dtype else {
69 return Ok(None);
70 };
71 let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
72 let src_ptype = array.ptype();
73
74 let new_validity = array
75 .validity()?
76 .cast_nullability(new_nullability, array.len(), ctx)?;
77
78 let same_rep = src_ptype == new_ptype
82 || (src_ptype.is_int()
83 && new_ptype.is_int()
84 && src_ptype.byte_width() == new_ptype.byte_width());
85 if same_rep {
86 if !values_fit_in(array, new_ptype, ctx, true) {
87 vortex_bail!(
88 Compute: "Cannot cast {} to {} — values exceed target range",
89 src_ptype, new_ptype,
90 );
91 }
92 return Ok(Some(reinterpret(array, new_ptype, new_validity)));
93 }
94
95 Ok(Some(match_each_native_ptype!(new_ptype, |T| {
98 match_each_native_ptype!(src_ptype, |F| {
99 cast_values::<F, T>(array, new_validity, ctx)?
100 })
101 })))
102 }
103}
104
105fn cast_values<F, T>(
109 array: ArrayView<'_, Primitive>,
110 new_validity: Validity,
111 ctx: &mut ExecutionCtx,
112) -> VortexResult<ArrayRef>
113where
114 F: NativePType + AsPrimitive<T>,
115 T: NativePType,
116{
117 let values = array.as_slice::<F>();
118
119 if values_always_fit(F::PTYPE, T::PTYPE) || values_fit_in(array, T::PTYPE, ctx, false) {
123 return Ok(PrimitiveArray::new(cast::<F, T>(values), new_validity).into_array());
124 }
125
126 let mask = array.validity()?.execute_mask(array.len(), ctx)?;
132 let overflow = || {
133 vortex_err!(
134 Compute: "Cannot cast {} to {} — value exceeds target range",
135 F::PTYPE, T::PTYPE,
136 )
137 };
138 let buffer: Buffer<T> = match &mask {
139 Mask::AllTrue(_) => BufferMut::try_from_trusted_len_iter(
140 values
141 .iter()
142 .map(|&v| <T as NumCast>::from(v).ok_or_else(overflow)),
143 )?
144 .freeze(),
145 Mask::AllFalse(_) => BufferMut::<T>::zeroed(values.len()).freeze(),
146 Mask::Values(m) => BufferMut::try_from_trusted_len_iter(
147 values.iter().zip(m.bit_buffer().iter()).map(|(&v, valid)| {
148 let factor = if valid { F::one() } else { F::zero() };
149 <T as NumCast>::from(v * factor).ok_or_else(overflow)
150 }),
151 )?
152 .freeze(),
153 };
154
155 Ok(PrimitiveArray::new(buffer, new_validity).into_array())
156}
157
158fn cast<F: NativePType + AsPrimitive<T>, T: NativePType>(array: &[F]) -> Buffer<T> {
161 BufferMut::from_trusted_len_iter(array.iter().map(|&src| src.as_())).freeze()
162}
163
164fn reinterpret(
165 array: ArrayView<'_, Primitive>,
166 new_ptype: PType,
167 new_validity: Validity,
168) -> ArrayRef {
169 unsafe {
172 PrimitiveArray::new_unchecked_from_handle(
173 array.buffer_handle().clone(),
174 new_ptype,
175 new_validity,
176 )
177 }
178 .into_array()
179}
180
181fn values_always_fit(src: PType, target: PType) -> bool {
185 if src == target {
186 return true;
187 }
188 if src.is_int() && target.is_int() {
189 return target.byte_width() > src.byte_width()
190 && (src.is_unsigned_int() || target.is_signed_int());
191 }
192 if src.is_float() && target.is_float() {
193 return target.byte_width() > src.byte_width();
194 }
195 src.is_int() && matches!(target, PType::F32 | PType::F64)
196}
197
198fn values_fit_in(
204 array: ArrayView<'_, Primitive>,
205 target_ptype: PType,
206 ctx: &mut ExecutionCtx,
207 compute: bool,
208) -> bool {
209 let target_dtype = DType::Primitive(target_ptype, Nullability::NonNullable);
210 if let Some(fits) = cached_values_fit_in(array, &target_dtype) {
211 return fits;
212 }
213 if !compute {
214 return false;
215 }
216 aggregate_fn::fns::min_max::min_max(array.array(), ctx)
217 .ok()
218 .flatten()
219 .is_none_or(|mm| mm.min.cast(&target_dtype).is_ok() && mm.max.cast(&target_dtype).is_ok())
220}
221
222fn cached_values_fit_in(array: ArrayView<'_, Primitive>, target_dtype: &DType) -> Option<bool> {
225 let stats = array.array().statistics();
226 let min = stats.get(Stat::Min).as_exact()?;
227 let max = stats.get(Stat::Max).as_exact()?;
228 Some(min.cast(target_dtype).is_ok() && max.cast(target_dtype).is_ok())
229}
230
// Unit tests for primitive casts: narrowing/widening, nullability changes,
// same-width integer reinterpretation, and values stored under nulls.
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexError;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    #[expect(deprecated)]
    use crate::canonical::ToCanonical as _;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    // Chains u32 -> u8 -> nullable u8 -> non-nullable u8 -> nullable u32 ->
    // non-nullable u8, checking both values and validity after every step.
    #[test]
    fn cast_u32_u8() {
        let arr = buffer![0u32, 10, 200].into_array();

        // Narrowing cast; all values fit in u8.
        #[expect(deprecated)]
        let p = arr.cast(PType::U8.into()).unwrap().to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(p.validity(), Ok(Validity::NonNullable)));

        // Same ptype, non-nullable -> nullable.
        #[expect(deprecated)]
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::new(buffer![0u8, 10, 200], Validity::AllValid)
        );
        assert!(matches!(p.validity(), Ok(Validity::AllValid)));

        // Same ptype, nullable (all valid) -> non-nullable.
        #[expect(deprecated)]
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(p.validity(), Ok(Validity::NonNullable)));

        // Widening cast combined with non-nullable -> nullable.
        #[expect(deprecated)]
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::new(buffer![0u32, 10, 200], Validity::AllValid)
        );
        assert!(matches!(p.validity(), Ok(Validity::AllValid)));

        // Narrowing cast combined with nullable -> non-nullable.
        #[expect(deprecated)]
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(p.validity(), Ok(Validity::NonNullable)));
    }

    // Integer -> float cast keeps the numeric values.
    #[test]
    fn cast_u32_f32() {
        let arr = buffer![0u32, 10, 200].into_array();
        #[expect(deprecated)]
        let u8arr = arr.cast(PType::F32.into()).unwrap().to_primitive();
        assert_arrays_eq!(u8arr, PrimitiveArray::from_iter([0.0f32, 10., 200.]));
    }

    // A negative i32 cannot be represented as u32: expect a Compute error whose
    // message mentions "values exceed target range".
    #[test]
    fn cast_i32_u32() {
        let arr = buffer![-1i32].into_array();
        #[expect(deprecated)]
        let error = arr
            .cast(PType::U32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();
        assert!(matches!(error, VortexError::Compute(..)));
        assert!(error.to_string().contains("values exceed target range"));
    }

    // Casting an array that contains a null to a non-nullable dtype must fail
    // with an InvalidArgument error.
    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let arr = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        #[expect(deprecated)]
        let err = arr
            .into_array()
            .cast(PType::I32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();

        assert!(matches!(err, VortexError::InvalidArgument(..)));
        assert!(
            err.to_string()
                .contains("Cannot cast array with invalid values to non-nullable type.")
        );
    }

    // The out-of-range value (-1) sits under a null, so the i32 -> nullable u32
    // cast succeeds and the validity mask is carried over unchanged.
    #[test]
    fn cast_with_invalid_nulls() {
        let arr = PrimitiveArray::new(
            buffer![-1i32, 0, 10],
            Validity::from_iter([false, true, true]),
        );
        #[expect(deprecated)]
        let p = arr
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::from_option_iter([None, Some(0u32), Some(10)])
        );
        assert_eq!(
            p.as_ref()
                .validity()
                .unwrap()
                .execute_mask(p.as_ref().len(), &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Mask::from(BitBuffer::from(vec![false, true, true]))
        );
    }

    // A u32 -> i32 cast with in-range values must hand back the same data
    // pointer, i.e. the values are not copied.
    #[test]
    fn cast_same_width_int_reinterprets_buffer() -> vortex_error::VortexResult<()> {
        let src = PrimitiveArray::from_iter([0u32, 10, 100]);
        let src_ptr = src.as_slice::<u32>().as_ptr();

        #[expect(deprecated)]
        let dst = src.into_array().cast(PType::I32.into())?.to_primitive();
        let dst_ptr = dst.as_slice::<i32>().as_ptr();

        assert_eq!(src_ptr as usize, dst_ptr as usize);
        assert_arrays_eq!(dst, PrimitiveArray::from_iter([0i32, 10, 100]));
        Ok(())
    }

    // u32::MAX does not fit in i32, so the same-width cast must fail with a
    // Compute error.
    #[test]
    fn cast_same_width_int_out_of_range_errors() {
        let arr = buffer![u32::MAX].into_array();
        #[expect(deprecated)]
        let err = arr
            .cast(PType::I32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();
        assert!(matches!(err, VortexError::Compute(..)));
    }

    // An all-null u8 array whose stored bytes (0xFF) exceed the i8 range still
    // casts to nullable i8 and stays all-invalid.
    #[test]
    fn cast_same_width_all_null() -> vortex_error::VortexResult<()> {
        let arr = PrimitiveArray::new(buffer![0xFFu8, 0xFF], Validity::AllInvalid);
        #[expect(deprecated)]
        let casted = arr
            .into_array()
            .cast(DType::Primitive(PType::I8, Nullability::Nullable))?
            .to_primitive();
        assert_eq!(casted.len(), 2);
        assert!(matches!(casted.validity(), Ok(Validity::AllInvalid)));
        Ok(())
    }

    // Same-width u32 -> i32 cast where the only out-of-range value (u32::MAX)
    // is stored under a null: the cast must succeed.
    #[test]
    fn cast_same_width_int_nullable_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
        let arr = PrimitiveArray::new(
            buffer![u32::MAX, 0u32, 42u32],
            Validity::from_iter([false, true, true]),
        );
        #[expect(deprecated)]
        let casted = arr
            .into_array()
            .cast(DType::Primitive(PType::I32, Nullability::Nullable))?
            .to_primitive();
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_option_iter([None, Some(0i32), Some(42)])
        );
        Ok(())
    }

    // Narrowing u32 -> u8 cast where the only out-of-range value (1000) is
    // stored under a null: the cast must succeed.
    #[test]
    fn cast_u32_to_u8_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
        let arr = PrimitiveArray::new(
            buffer![1000u32, 10u32, 42u32],
            Validity::from_iter([false, true, true]),
        );
        #[expect(deprecated)]
        let casted = arr
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::Nullable))?
            .to_primitive();
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_option_iter([None, Some(10u8), Some(42)])
        );
        Ok(())
    }

    // Runs the shared cast conformance suite over unsigned, signed, float
    // (including NaN and infinities), nullable, and single-element inputs.
    #[rstest]
    #[case(buffer![0u8, 1, 2, 3, 255].into_array())]
    #[case(buffer![0u16, 100, 1000, 65535].into_array())]
    #[case(buffer![0u32, 100, 1000, 1000000].into_array())]
    #[case(buffer![0u64, 100, 1000, 1000000000].into_array())]
    #[case(buffer![-128i8, -1, 0, 1, 127].into_array())]
    #[case(buffer![-1000i16, -1, 0, 1, 1000].into_array())]
    #[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())]
    #[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())]
    #[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())]
    #[case(buffer![f32::NAN, f32::INFINITY, f32::NEG_INFINITY, 0.0f32].into_array())]
    #[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())]
    #[case(buffer![f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 0.0f64].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())]
    #[case(buffer![42u32].into_array())]
    fn test_cast_primitive_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(&array);
    }
}