1use num_traits::AsPrimitive;
5use vortex_buffer::Buffer;
6use vortex_buffer::BufferMut;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9
10use crate::ArrayRef;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::aggregate_fn;
14use crate::array::ArrayView;
15use crate::arrays::Primitive;
16use crate::arrays::PrimitiveArray;
17use crate::arrays::primitive::PrimitiveArrayExt;
18use crate::dtype::DType;
19use crate::dtype::NativePType;
20use crate::dtype::Nullability;
21use crate::dtype::PType;
22use crate::match_each_native_ptype;
23use crate::scalar_fn::fns::cast::CastKernel;
24use crate::scalar_fn::fns::cast::CastReduce;
25
26impl CastReduce for Primitive {
27 fn cast(array: ArrayView<'_, Primitive>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
28 let DType::Primitive(new_ptype, new_nullability) = dtype else {
31 return Ok(None);
32 };
33 if *new_ptype != array.ptype() {
34 return Ok(None);
35 }
36
37 let Some(new_validity) = array
38 .validity()?
39 .trivial_cast_nullability(*new_nullability, array.len())?
40 else {
41 return Ok(None);
42 };
43
44 Ok(Some(unsafe {
46 PrimitiveArray::new_unchecked_from_handle(
47 array.buffer_handle().clone(),
48 array.ptype(),
49 new_validity,
50 )
51 .into_array()
52 }))
53 }
54}
55
56impl CastKernel for Primitive {
57 fn cast(
58 array: ArrayView<'_, Primitive>,
59 dtype: &DType,
60 ctx: &mut ExecutionCtx,
61 ) -> VortexResult<Option<ArrayRef>> {
62 let DType::Primitive(new_ptype, new_nullability) = dtype else {
63 return Ok(None);
64 };
65 let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
66
67 let new_validity = array
69 .validity()?
70 .cast_nullability(new_nullability, array.len(), ctx)?;
71
72 if array.ptype() == new_ptype {
74 return Ok(Some(unsafe {
76 PrimitiveArray::new_unchecked_from_handle(
77 array.buffer_handle().clone(),
78 array.ptype(),
79 new_validity,
80 )
81 .into_array()
82 }));
83 }
84
85 if !values_fit_in(array, new_ptype, ctx) {
86 vortex_bail!(
87 Compute: "Cannot cast {} to {} — values exceed target range",
88 array.ptype(),
89 new_ptype,
90 );
91 }
92
93 if array.ptype().is_int()
97 && new_ptype.is_int()
98 && array.ptype().byte_width() == new_ptype.byte_width()
99 {
100 return Ok(Some(unsafe {
103 PrimitiveArray::new_unchecked_from_handle(
104 array.buffer_handle().clone(),
105 new_ptype,
106 new_validity,
107 )
108 .into_array()
109 }));
110 }
111
112 Ok(Some(match_each_native_ptype!(new_ptype, |T| {
114 match_each_native_ptype!(array.ptype(), |F| {
115 PrimitiveArray::new(cast::<F, T>(array.as_slice()), new_validity).into_array()
116 })
117 })))
118 }
119}
120
121fn values_fit_in(
123 array: ArrayView<'_, Primitive>,
124 target_ptype: PType,
125 ctx: &mut ExecutionCtx,
126) -> bool {
127 let target_dtype = DType::Primitive(target_ptype, Nullability::NonNullable);
128 aggregate_fn::fns::min_max::min_max(array.array(), ctx)
129 .ok()
130 .flatten()
131 .is_none_or(|mm| mm.min.cast(&target_dtype).is_ok() && mm.max.cast(&target_dtype).is_ok())
132}
133
134fn cast<F: NativePType + AsPrimitive<T>, T: NativePType>(array: &[F]) -> Buffer<T> {
138 BufferMut::from_trusted_len_iter(array.iter().map(|&src| src.as_())).freeze()
139}
140
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexError;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    #[expect(deprecated)]
    use crate::canonical::ToCanonical as _;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    /// Chains narrowing, nullability changes and widening on one small array. Values and the
    /// resulting validity are checked after every step.
    #[test]
    fn cast_u32_u8() {
        let arr = buffer![0u32, 10, 200].into_array();

        // u32 -> u8 (non-nullable): every value is in range, validity stays NonNullable.
        #[expect(deprecated)]
        let p = arr.cast(PType::U8.into()).unwrap().to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(p.validity(), Ok(Validity::NonNullable)));

        // Same ptype, non-nullable -> nullable: validity becomes AllValid.
        #[expect(deprecated)]
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::new(buffer![0u8, 10, 200], Validity::AllValid)
        );
        assert!(matches!(p.validity(), Ok(Validity::AllValid)));

        // Same ptype, nullable (with no nulls) -> non-nullable.
        #[expect(deprecated)]
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(p.validity(), Ok(Validity::NonNullable)));

        // Widen u8 -> u32 while also switching to nullable.
        #[expect(deprecated)]
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::new(buffer![0u32, 10, 200], Validity::AllValid)
        );
        assert!(matches!(p.validity(), Ok(Validity::AllValid)));

        // Narrow back to u8 while switching to non-nullable.
        #[expect(deprecated)]
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(p.validity(), Ok(Validity::NonNullable)));
    }

    /// Integer-to-float conversion goes through the element-wise path.
    #[test]
    fn cast_u32_f32() {
        let arr = buffer![0u32, 10, 200].into_array();
        // NOTE(review): despite its name, `u8arr` holds the f32 result of the cast.
        #[expect(deprecated)]
        let u8arr = arr.cast(PType::F32.into()).unwrap().to_primitive();
        assert_arrays_eq!(u8arr, PrimitiveArray::from_iter([0.0f32, 10., 200.]));
    }

    /// A negative value cannot be cast to an unsigned type and must produce a `Compute` error.
    #[test]
    fn cast_i32_u32() {
        let arr = buffer![-1i32].into_array();
        // The error may surface either from `cast` itself or when the result is canonicalized.
        // Both steps are chained before unwrapping the error.
        #[expect(deprecated)]
        let error = arr
            .cast(PType::U32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();
        assert!(matches!(error, VortexError::Compute(..)));
        assert!(error.to_string().contains("values exceed target range"));
    }

    /// Dropping nullability from an array that contains nulls is an `InvalidArgument` error.
    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let arr = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        #[expect(deprecated)]
        let err = arr
            .into_array()
            .cast(PType::I32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();

        assert!(matches!(err, VortexError::InvalidArgument(..)));
        assert!(
            err.to_string()
                .contains("Cannot cast array with invalid values to non-nullable type.")
        );
    }

    /// An out-of-range value (-1 for u32) at a null position must not fail the cast. The
    /// validity mask must be carried over unchanged.
    #[test]
    fn cast_with_invalid_nulls() {
        let arr = PrimitiveArray::new(
            buffer![-1i32, 0, 10],
            Validity::from_iter([false, true, true]),
        );
        #[expect(deprecated)]
        let p = arr
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::from_option_iter([None, Some(0u32), Some(10)])
        );
        assert_eq!(
            p.as_ref()
                .validity()
                .unwrap()
                .execute_mask(p.as_ref().len(), &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Mask::from(BitBuffer::from(vec![false, true, true]))
        );
    }

    /// Same-width integer casts must reuse the source buffer. This is checked by comparing the
    /// data pointers before and after the cast.
    #[test]
    fn cast_same_width_int_reinterprets_buffer() -> vortex_error::VortexResult<()> {
        let src = PrimitiveArray::from_iter([0u32, 10, 100]);
        let src_ptr = src.as_slice::<u32>().as_ptr();

        #[expect(deprecated)]
        let dst = src.into_array().cast(PType::I32.into())?.to_primitive();
        let dst_ptr = dst.as_slice::<i32>().as_ptr();

        assert_eq!(src_ptr as usize, dst_ptr as usize);
        assert_arrays_eq!(dst, PrimitiveArray::from_iter([0i32, 10, 100]));
        Ok(())
    }

    /// The zero-copy same-width path still applies the range check: u32::MAX does not fit i32.
    #[test]
    fn cast_same_width_int_out_of_range_errors() {
        let arr = buffer![u32::MAX].into_array();
        #[expect(deprecated)]
        let err = arr
            .cast(PType::I32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();
        assert!(matches!(err, VortexError::Compute(..)));
    }

    /// An all-null array keeps its length and `AllInvalid` validity through a same-width cast,
    /// regardless of the bytes stored under the nulls.
    #[test]
    fn cast_same_width_all_null() -> vortex_error::VortexResult<()> {
        let arr = PrimitiveArray::new(buffer![0xFFu8, 0xFF], Validity::AllInvalid);
        #[expect(deprecated)]
        let casted = arr
            .into_array()
            .cast(DType::Primitive(PType::I8, Nullability::Nullable))?
            .to_primitive();
        assert_eq!(casted.len(), 2);
        assert!(matches!(casted.validity(), Ok(Validity::AllInvalid)));
        Ok(())
    }

    /// A same-width cast succeeds when the only out-of-range value (u32::MAX) sits under a null.
    #[test]
    fn cast_same_width_int_nullable_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
        let arr = PrimitiveArray::new(
            buffer![u32::MAX, 0u32, 42u32],
            Validity::from_iter([false, true, true]),
        );
        #[expect(deprecated)]
        let casted = arr
            .into_array()
            .cast(DType::Primitive(PType::I32, Nullability::Nullable))?
            .to_primitive();
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_option_iter([None, Some(0i32), Some(42)])
        );
        Ok(())
    }

    /// A narrowing cast succeeds when the only out-of-range value (1000 for u8) sits under a null.
    #[test]
    fn cast_u32_to_u8_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
        let arr = PrimitiveArray::new(
            buffer![1000u32, 10u32, 42u32],
            Validity::from_iter([false, true, true]),
        );
        #[expect(deprecated)]
        let casted = arr
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::Nullable))?
            .to_primitive();
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_option_iter([None, Some(10u8), Some(42)])
        );
        Ok(())
    }

    /// Runs the shared cast conformance suite over every primitive ptype. Cases include
    /// extreme values, float specials (NaN, ±inf), nullable inputs and a single-element array.
    #[rstest]
    #[case(buffer![0u8, 1, 2, 3, 255].into_array())]
    #[case(buffer![0u16, 100, 1000, 65535].into_array())]
    #[case(buffer![0u32, 100, 1000, 1000000].into_array())]
    #[case(buffer![0u64, 100, 1000, 1000000000].into_array())]
    #[case(buffer![-128i8, -1, 0, 1, 127].into_array())]
    #[case(buffer![-1000i16, -1, 0, 1, 1000].into_array())]
    #[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())]
    #[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())]
    #[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())]
    #[case(buffer![f32::NAN, f32::INFINITY, f32::NEG_INFINITY, 0.0f32].into_array())]
    #[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())]
    #[case(buffer![f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 0.0f64].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())]
    #[case(buffer![42u32].into_array())]
    fn test_cast_primitive_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(&array);
    }
}