1use num_traits::AsPrimitive;
5use vortex_buffer::Buffer;
6use vortex_buffer::BufferMut;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9
10use crate::ArrayRef;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::aggregate_fn;
14use crate::array::ArrayView;
15use crate::arrays::Primitive;
16use crate::arrays::PrimitiveArray;
17use crate::arrays::primitive::PrimitiveArrayExt;
18use crate::dtype::DType;
19use crate::dtype::NativePType;
20use crate::dtype::Nullability;
21use crate::dtype::PType;
22use crate::match_each_native_ptype;
23use crate::scalar_fn::fns::cast::CastKernel;
24
25impl CastKernel for Primitive {
26 fn cast(
27 array: ArrayView<'_, Primitive>,
28 dtype: &DType,
29 ctx: &mut ExecutionCtx,
30 ) -> VortexResult<Option<ArrayRef>> {
31 let DType::Primitive(new_ptype, new_nullability) = dtype else {
32 return Ok(None);
33 };
34 let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
35
36 let new_validity = array
38 .validity()?
39 .cast_nullability(new_nullability, array.len())?;
40
41 if array.ptype() == new_ptype {
43 return Ok(Some(unsafe {
45 PrimitiveArray::new_unchecked_from_handle(
46 array.buffer_handle().clone(),
47 array.ptype(),
48 new_validity,
49 )
50 .into_array()
51 }));
52 }
53
54 if !values_fit_in(array, new_ptype, ctx) {
55 vortex_bail!(
56 Compute: "Cannot cast {} to {} — values exceed target range",
57 array.ptype(),
58 new_ptype,
59 );
60 }
61
62 if array.ptype().is_int()
66 && new_ptype.is_int()
67 && array.ptype().byte_width() == new_ptype.byte_width()
68 {
69 return Ok(Some(unsafe {
72 PrimitiveArray::new_unchecked_from_handle(
73 array.buffer_handle().clone(),
74 new_ptype,
75 new_validity,
76 )
77 .into_array()
78 }));
79 }
80
81 Ok(Some(match_each_native_ptype!(new_ptype, |T| {
83 match_each_native_ptype!(array.ptype(), |F| {
84 PrimitiveArray::new(cast::<F, T>(array.as_slice()), new_validity).into_array()
85 })
86 })))
87 }
88}
89
90fn values_fit_in(
92 array: ArrayView<'_, Primitive>,
93 target_ptype: PType,
94 ctx: &mut ExecutionCtx,
95) -> bool {
96 let target_dtype = DType::Primitive(target_ptype, Nullability::NonNullable);
97 aggregate_fn::fns::min_max::min_max(array.array(), ctx)
98 .ok()
99 .flatten()
100 .is_none_or(|mm| mm.min.cast(&target_dtype).is_ok() && mm.max.cast(&target_dtype).is_ok())
101}
102
103fn cast<F: NativePType + AsPrimitive<T>, T: NativePType>(array: &[F]) -> Buffer<T> {
107 BufferMut::from_trusted_len_iter(array.iter().map(|&src| src.as_())).freeze()
108}
109
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexError;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    /// Round-trips values through narrowing, widening and nullability changes.
    #[test]
    fn cast_u32_u8() {
        let arr = buffer![0u32, 10, 200].into_array();

        let p = arr.cast(PType::U8.into()).unwrap().to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(p.validity(), Ok(Validity::NonNullable)));

        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::new(buffer![0u8, 10, 200], Validity::AllValid)
        );
        assert!(matches!(p.validity(), Ok(Validity::AllValid)));

        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(p.validity(), Ok(Validity::NonNullable)));

        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::new(buffer![0u32, 10, 200], Validity::AllValid)
        );
        assert!(matches!(p.validity(), Ok(Validity::AllValid)));

        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(p.validity(), Ok(Validity::NonNullable)));
    }

    #[test]
    fn cast_u32_f32() {
        let arr = buffer![0u32, 10, 200].into_array();
        let f32arr = arr.cast(PType::F32.into()).unwrap().to_primitive();
        assert_arrays_eq!(f32arr, PrimitiveArray::from_iter([0.0f32, 10., 200.]));
    }

    #[test]
    fn cast_i32_u32() {
        let arr = buffer![-1i32].into_array();
        let error = arr
            .cast(PType::U32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();
        assert!(matches!(error, VortexError::Compute(..)));
        assert!(error.to_string().contains("values exceed target range"));
    }

    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let arr = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        let err = arr
            .into_array()
            .cast(PType::I32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();

        assert!(matches!(err, VortexError::InvalidArgument(..)));
        assert!(
            err.to_string()
                .contains("Cannot cast array with invalid values to non-nullable type.")
        );
    }

    #[test]
    fn cast_with_invalid_nulls() {
        let arr = PrimitiveArray::new(
            buffer![-1i32, 0, 10],
            Validity::from_iter([false, true, true]),
        );
        let p = arr
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::from_option_iter([None, Some(0u32), Some(10)])
        );
        assert_eq!(
            p.as_ref()
                .validity()
                .unwrap()
                .to_mask(p.as_ref().len(), &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Mask::from(BitBuffer::from(vec![false, true, true]))
        );
    }

    /// Same-width integer casts reuse the source buffer instead of copying it.
    #[test]
    fn cast_same_width_int_reinterprets_buffer() -> vortex_error::VortexResult<()> {
        let src = PrimitiveArray::from_iter([0u32, 10, 100]);
        let src_ptr = src.as_slice::<u32>().as_ptr();

        let dst = src.into_array().cast(PType::I32.into())?.to_primitive();
        let dst_ptr = dst.as_slice::<i32>().as_ptr();

        assert_eq!(src_ptr as usize, dst_ptr as usize);
        assert_arrays_eq!(dst, PrimitiveArray::from_iter([0i32, 10, 100]));
        Ok(())
    }

    /// Matching byte widths do not bypass the range check: `u32::MAX` does not fit in `i32`.
    #[test]
    fn cast_same_width_int_out_of_range_errors() {
        let arr = buffer![u32::MAX].into_array();
        let err = arr
            .cast(PType::I32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();
        assert!(matches!(err, VortexError::Compute(..)));
    }

    /// An all-null array casts even though its masked bytes (`0xFF`) exceed `i8::MAX`.
    #[test]
    fn cast_same_width_all_null() -> vortex_error::VortexResult<()> {
        let arr = PrimitiveArray::new(buffer![0xFFu8, 0xFF], Validity::AllInvalid);
        let casted = arr
            .into_array()
            .cast(DType::Primitive(PType::I8, Nullability::Nullable))?
            .to_primitive();
        assert_eq!(casted.len(), 2);
        assert!(matches!(casted.validity(), Ok(Validity::AllInvalid)));
        Ok(())
    }

    /// Out-of-range values hidden behind nulls must not fail a same-width cast.
    #[test]
    fn cast_same_width_int_nullable_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
        let arr = PrimitiveArray::new(
            buffer![u32::MAX, 0u32, 42u32],
            Validity::from_iter([false, true, true]),
        );
        let casted = arr
            .into_array()
            .cast(DType::Primitive(PType::I32, Nullability::Nullable))?
            .to_primitive();
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_option_iter([None, Some(0i32), Some(42)])
        );
        Ok(())
    }

    /// Out-of-range values hidden behind nulls must not fail a narrowing cast.
    #[test]
    fn cast_u32_to_u8_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
        let arr = PrimitiveArray::new(
            buffer![1000u32, 10u32, 42u32],
            Validity::from_iter([false, true, true]),
        );
        let casted = arr
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::Nullable))?
            .to_primitive();
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_option_iter([None, Some(10u8), Some(42)])
        );
        Ok(())
    }

    /// Lossless widening casts keep the extreme values of the source type intact.
    #[test]
    fn cast_widening_preserves_extreme_values() -> vortex_error::VortexResult<()> {
        let widened = buffer![0u32, u32::MAX]
            .into_array()
            .cast(PType::I64.into())?
            .to_primitive();
        assert_arrays_eq!(
            widened,
            PrimitiveArray::from_iter([0i64, i64::from(u32::MAX)])
        );

        let widened = buffer![i8::MIN, -1i8, i8::MAX]
            .into_array()
            .cast(PType::I16.into())?
            .to_primitive();
        assert_arrays_eq!(
            widened,
            PrimitiveArray::from_iter([i16::from(i8::MIN), -1, i16::from(i8::MAX)])
        );

        let widened = buffer![f32::MIN, 1.5f32, f32::MAX]
            .into_array()
            .cast(PType::F64.into())?
            .to_primitive();
        assert_arrays_eq!(
            widened,
            PrimitiveArray::from_iter([f64::from(f32::MIN), 1.5, f64::from(f32::MAX)])
        );
        Ok(())
    }

    /// A wider target does not help when signedness differs: `-1i8` cannot become `u16`.
    #[test]
    fn cast_negative_to_wider_unsigned_errors() {
        let arr = buffer![-1i8, 5].into_array();
        let err = arr
            .cast(PType::U16.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();
        assert!(matches!(err, VortexError::Compute(..)));
    }

    #[rstest]
    #[case(buffer![0u8, 1, 2, 3, 255].into_array())]
    #[case(buffer![0u16, 100, 1000, 65535].into_array())]
    #[case(buffer![0u32, 100, 1000, 1000000].into_array())]
    #[case(buffer![0u64, 100, 1000, 1000000000].into_array())]
    #[case(buffer![-128i8, -1, 0, 1, 127].into_array())]
    #[case(buffer![-1000i16, -1, 0, 1, 1000].into_array())]
    #[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())]
    #[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())]
    #[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())]
    #[case(buffer![f32::NAN, f32::INFINITY, f32::NEG_INFINITY, 0.0f32].into_array())]
    #[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())]
    #[case(buffer![f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 0.0f64].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())]
    #[case(buffer![42u32].into_array())]
    fn test_cast_primitive_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(&array);
    }
}