Skip to main content

vortex_array/arrays/primitive/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::AsPrimitive;
5use num_traits::NumCast;
6use vortex_buffer::Buffer;
7use vortex_buffer::BufferMut;
8use vortex_compute::lane_kernels::IndexedSinkExt;
9use vortex_compute::lane_kernels::IndexedSourceExt;
10use vortex_compute::lane_kernels::ReinterpretSink;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14use vortex_mask::Mask;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::IntoArray;
19use crate::aggregate_fn;
20use crate::array::ArrayView;
21use crate::arrays::Primitive;
22use crate::arrays::PrimitiveArray;
23use crate::arrays::primitive::PrimitiveArrayExt;
24use crate::dtype::DType;
25use crate::dtype::NativePType;
26use crate::dtype::Nullability;
27use crate::dtype::PType;
28use crate::expr::stats::Stat;
29use crate::expr::stats::StatsProvider;
30use crate::match_each_native_ptype;
31use crate::scalar_fn::fns::cast::CastKernel;
32use crate::scalar_fn::fns::cast::CastReduce;
33use crate::validity::Validity;
34
35impl CastReduce for Primitive {
36    fn cast(array: ArrayView<'_, Primitive>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
37        // Only the same ptype is reducible without execution; type changes need the kernel
38        // to verify values fit in the target range.
39        let DType::Primitive(new_ptype, new_nullability) = dtype else {
40            return Ok(None);
41        };
42        if *new_ptype != array.ptype() {
43            return Ok(None);
44        }
45
46        let Some(new_validity) = array
47            .validity()?
48            .trivially_cast_nullability(*new_nullability, array.len())?
49        else {
50            return Ok(None);
51        };
52
53        // SAFETY: validity and data buffer still have same length.
54        Ok(Some(unsafe {
55            PrimitiveArray::new_unchecked_from_handle(
56                array.buffer_handle().clone(),
57                array.ptype(),
58                new_validity,
59            )
60            .into_array()
61        }))
62    }
63}
64
65impl CastKernel for Primitive {
66    fn cast(
67        array: ArrayView<'_, Primitive>,
68        dtype: &DType,
69        ctx: &mut ExecutionCtx,
70    ) -> VortexResult<Option<ArrayRef>> {
71        let DType::Primitive(new_ptype, new_nullability) = dtype else {
72            return Ok(None);
73        };
74        let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
75        let src_ptype = array.ptype();
76
77        let new_validity = array
78            .validity()?
79            .cast_nullability(new_nullability, array.len(), ctx)?;
80
81        // Same bit representation: either the same ptype (only the nullability changed) or two
82        // same-width integers (identical layout under 2's complement). The only non-trivial case
83        // is the sign change between same-width ints, which still needs a value-range check.
84        let same_rep = src_ptype == new_ptype
85            || (src_ptype.is_int()
86                && new_ptype.is_int()
87                && src_ptype.byte_width() == new_ptype.byte_width());
88        if same_rep {
89            if !values_fit_in(array, new_ptype, ctx, true) {
90                vortex_bail!(
91                    Compute: "Cannot cast {} to {} — values exceed target range",
92                    src_ptype, new_ptype,
93                );
94            }
95            return Ok(Some(reinterpret(array, new_ptype, new_validity)));
96        }
97
98        // Different bit rep: cast each element. `cast_values` picks a pure or checked loop based
99        // on whether the conversion is statically infallible.
100        Ok(Some(match_each_native_ptype!(new_ptype, |T| {
101            match_each_native_ptype!(src_ptype, |F| {
102                cast_values::<F, T>(array, new_validity, ctx)?
103            })
104        })))
105    }
106}
107
108/// Cast Primitive values from `F` to `T`.
109fn cast_values<F, T>(
110    array: ArrayView<'_, Primitive>,
111    new_validity: Validity,
112    ctx: &mut ExecutionCtx,
113) -> VortexResult<ArrayRef>
114where
115    F: NativePType + AsPrimitive<T>,
116    T: NativePType,
117{
118    let overflow = || {
119        vortex_err!(
120            Compute: "Cannot cast {} to {} — value exceeds target range",
121            F::PTYPE, T::PTYPE,
122        )
123    };
124
125    // Returns `true` if every value of `from` is representable in `to` without loss.
126    fn casts_losslessly_to(from: PType, to: PType) -> bool {
127        from.least_supertype(to) == Some(to)
128    }
129
130    // Skip the fallible kernel when type widening or (cached) min/max prove every value fits.
131    let target_dtype = DType::Primitive(T::PTYPE, Nullability::NonNullable);
132    let infallible = casts_losslessly_to(F::PTYPE, T::PTYPE)
133        || cached_values_fit_in(array, &target_dtype).unwrap_or(false);
134
135    let len = array.len();
136
137    // If F and T have the same byte width, try to take unique ownership of the buffer.
138    let same_bit_width = F::PTYPE.byte_width() == T::PTYPE.byte_width();
139    let owned: Option<BufferMut<F>> = same_bit_width
140        .then(|| array.into_owned().try_into_buffer_mut::<F>().ok())
141        .flatten();
142    let values: &[F] = array.as_slice::<F>();
143
144    if infallible {
145        return match owned {
146            Some(mut buf) => {
147                ReinterpretSink::<F, T>::new(buf.as_mut_slice()).map_into_in_place(|v: F| v.as_());
148                // SAFETY: same size + alignment for NativePType
149                let result: BufferMut<T> = unsafe { buf.transmute::<T>() };
150                Ok(PrimitiveArray::new(result.freeze(), new_validity).into_array())
151            }
152            None => {
153                let mut buffer = BufferMut::<T>::with_capacity(len);
154                values.map_into(&mut buffer.spare_capacity_mut()[..len], |v| v.as_());
155                // SAFETY: map_into initializes every lane.
156                unsafe { buffer.set_len(len) };
157                Ok(PrimitiveArray::new(buffer.freeze(), new_validity).into_array())
158            }
159        };
160    }
161
162    let mask = array.validity()?.execute_mask(len, ctx)?;
163
164    let buffer: Buffer<T> = match (&mask, owned) {
165        (Mask::AllTrue(_), Some(mut buf)) => {
166            ReinterpretSink::<F, T>::new(buf.as_mut_slice())
167                .try_map_in_place(|v: F| <T as NumCast>::from(v))
168                .map_err(|_| overflow())?;
169            // SAFETY: same size + alignment for NativePType
170            let result: BufferMut<T> = unsafe { buf.transmute::<T>() };
171            result.freeze()
172        }
173        (Mask::AllTrue(_), None) => {
174            let mut buffer = BufferMut::<T>::with_capacity(len);
175            values
176                .try_map_into(&mut buffer.spare_capacity_mut()[..len], |v| {
177                    <T as NumCast>::from(v)
178                })
179                .map_err(|_| overflow())?;
180            // SAFETY: initialized every lane.
181            unsafe { buffer.set_len(len) };
182            buffer.freeze()
183        }
184        (Mask::AllFalse(_), _) => BufferMut::<T>::zeroed(len).freeze(),
185        (Mask::Values(m), Some(mut buf)) => {
186            ReinterpretSink::<F, T>::new(buf.as_mut_slice())
187                .try_map_masked_in_place(m.bit_buffer(), |v: F| <T as NumCast>::from(v))
188                .map_err(|_| overflow())?;
189            // SAFETY: same size + alignment for NativePType
190            let result: BufferMut<T> = unsafe { buf.transmute::<T>() };
191            result.freeze()
192        }
193        (Mask::Values(m), None) => {
194            let mut buffer = BufferMut::<T>::with_capacity(len);
195            values
196                .try_map_masked_into(
197                    m.bit_buffer(),
198                    &mut buffer.spare_capacity_mut()[..len],
199                    |v| <T as NumCast>::from(v),
200                )
201                .map_err(|_| overflow())?;
202            // SAFETY: initialized every lane.
203            unsafe { buffer.set_len(len) };
204            buffer.freeze()
205        }
206    };
207
208    Ok(PrimitiveArray::new(buffer, new_validity).into_array())
209}
210
211fn reinterpret(
212    array: ArrayView<'_, Primitive>,
213    new_ptype: PType,
214    new_validity: Validity,
215) -> ArrayRef {
216    // SAFETY: caller has verified the bit representation is compatible and that validity length
217    // still matches the buffer length.
218    unsafe {
219        PrimitiveArray::new_unchecked_from_handle(
220            array.buffer_handle().clone(),
221            new_ptype,
222            new_validity,
223        )
224    }
225    .into_array()
226}
227
228/// Returns `true` if all valid values in `array` are representable as `target_ptype`.
229///
230/// Cached min/max statistics are consulted first. If either bound is missing, the function either
231/// computes them with a single pass (when `compute` is `true`) or returns `false` so the caller
232/// can fall back to a slower path (when `compute` is `false`).
233fn values_fit_in(
234    array: ArrayView<'_, Primitive>,
235    target_ptype: PType,
236    ctx: &mut ExecutionCtx,
237    compute: bool,
238) -> bool {
239    let target_dtype = DType::Primitive(target_ptype, Nullability::NonNullable);
240    if let Some(fits) = cached_values_fit_in(array, &target_dtype) {
241        return fits;
242    }
243    if !compute {
244        return false;
245    }
246    aggregate_fn::fns::min_max::min_max(
247        array.array(),
248        ctx,
249        aggregate_fn::NumericalAggregateOpts::default(),
250    )
251    .ok()
252    .flatten()
253    .is_none_or(|mm| mm.min.cast(&target_dtype).is_ok() && mm.max.cast(&target_dtype).is_ok())
254}
255
256/// Cached-only check: returns `Some(fits)` if both `Min` and `Max` are present as `Exact` in the
257/// stats cache, otherwise `None`.
258fn cached_values_fit_in(array: ArrayView<'_, Primitive>, target_dtype: &DType) -> Option<bool> {
259    let stats = array.array().statistics();
260    let min = stats.get(Stat::Min).as_exact()?;
261    let max = stats.get(Stat::Max).as_exact()?;
262    Some(min.cast(target_dtype).is_ok() && max.cast(target_dtype).is_ok())
263}
264
// Unit tests for the primitive cast reduce/kernel paths. Statements that go through the
// `ToCanonical` helpers carry `#[expect(deprecated)]` (the `ToCanonical` import itself is marked
// deprecated below).
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexError;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    #[expect(deprecated)]
    use crate::canonical::ToCanonical as _;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    /// Narrowing u32 -> u8 with in-range values, followed by nullability round-trips and a
    /// widening back to u32; checks both values and the resulting validity variant each step.
    #[test]
    fn cast_u32_u8() {
        let mut ctx = array_session().create_execution_ctx();
        let arr = buffer![0u32, 10, 200].into_array();

        // cast from u32 to u8
        #[expect(deprecated)]
        let p = arr.cast(PType::U8.into()).unwrap().to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]), &mut ctx);
        assert!(matches!(p.validity(), Ok(Validity::NonNullable)));

        // to nullable
        #[expect(deprecated)]
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::new(buffer![0u8, 10, 200], Validity::AllValid),
            &mut ctx
        );
        assert!(matches!(p.validity(), Ok(Validity::AllValid)));

        // back to non-nullable
        #[expect(deprecated)]
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]), &mut ctx);
        assert!(matches!(p.validity(), Ok(Validity::NonNullable)));

        // to nullable u32
        #[expect(deprecated)]
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::new(buffer![0u32, 10, 200], Validity::AllValid),
            &mut ctx
        );
        assert!(matches!(p.validity(), Ok(Validity::AllValid)));

        // to non-nullable u8
        #[expect(deprecated)]
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]), &mut ctx);
        assert!(matches!(p.validity(), Ok(Validity::NonNullable)));
    }

    /// Integer -> float conversion (different bit representation, element-wise path).
    #[test]
    fn cast_u32_f32() {
        let mut ctx = array_session().create_execution_ctx();
        let arr = buffer![0u32, 10, 200].into_array();
        #[expect(deprecated)]
        let u8arr = arr.cast(PType::F32.into()).unwrap().to_primitive();
        assert_arrays_eq!(
            u8arr,
            PrimitiveArray::from_iter([0.0f32, 10., 200.]),
            &mut ctx
        );
    }

    /// A negative i32 cannot become u32: the same-width range check must reject it with the
    /// "values exceed target range" `Compute` error.
    #[test]
    fn cast_i32_u32() {
        let arr = buffer![-1i32].into_array();
        #[expect(deprecated)]
        let error = arr
            .cast(PType::U32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();
        assert!(matches!(error, VortexError::Compute(..)));
        assert!(error.to_string().contains("values exceed target range"));
    }

    /// Dropping nullability while nulls are present must fail with `InvalidArgument`.
    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let arr = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        #[expect(deprecated)]
        let err = arr
            .into_array()
            .cast(PType::I32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();

        assert!(matches!(err, VortexError::InvalidArgument(..)));
        assert!(
            err.to_string()
                .contains("Cannot cast array with invalid values to non-nullable type.")
        );
    }

    /// An out-of-range value (-1) sitting in a null slot must not fail an i32 -> u32 cast,
    /// and the validity mask must be carried over unchanged.
    #[test]
    fn cast_with_invalid_nulls() {
        let mut ctx = array_session().create_execution_ctx();
        let arr = PrimitiveArray::new(
            buffer![-1i32, 0, 10],
            Validity::from_iter([false, true, true]),
        );
        #[expect(deprecated)]
        let p = arr
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::from_option_iter([None, Some(0u32), Some(10)]),
            &mut ctx
        );
        assert_eq!(
            p.as_ref()
                .validity()
                .unwrap()
                .execute_mask(
                    p.as_ref().len(),
                    &mut array_session().create_execution_ctx()
                )
                .unwrap(),
            Mask::from(BitBuffer::from(vec![false, true, true]))
        );
    }

    /// Same-width integer cast where all values fit: should reinterpret the
    /// buffer without allocation (pointer identity).
    #[test]
    fn cast_same_width_int_reinterprets_buffer() -> vortex_error::VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let src = PrimitiveArray::from_iter([0u32, 10, 100]);
        let src_ptr = src.as_slice::<u32>().as_ptr();

        #[expect(deprecated)]
        let dst = src.into_array().cast(PType::I32.into())?.to_primitive();
        let dst_ptr = dst.as_slice::<i32>().as_ptr();

        // Zero-copy: the data pointer should be identical.
        assert_eq!(src_ptr as usize, dst_ptr as usize);
        assert_arrays_eq!(dst, PrimitiveArray::from_iter([0i32, 10, 100]), &mut ctx);
        Ok(())
    }

    /// Same-width integer cast where values don't fit: the kernel's range check should
    /// reject it with a `Compute` error instead of silently reinterpreting the bits.
    #[test]
    fn cast_same_width_int_out_of_range_errors() {
        let arr = buffer![u32::MAX].into_array();
        #[expect(deprecated)]
        let err = arr
            .cast(PType::I32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();
        assert!(matches!(err, VortexError::Compute(..)));
    }

    /// All-null array cast between same-width types should succeed without
    /// touching the buffer contents.
    #[test]
    fn cast_same_width_all_null() -> vortex_error::VortexResult<()> {
        let arr = PrimitiveArray::new(buffer![0xFFu8, 0xFF], Validity::AllInvalid);
        #[expect(deprecated)]
        let casted = arr
            .into_array()
            .cast(DType::Primitive(PType::I8, Nullability::Nullable))?
            .to_primitive();
        assert_eq!(casted.len(), 2);
        assert!(matches!(casted.validity(), Ok(Validity::AllInvalid)));
        Ok(())
    }

    /// Same-width integer cast with nullable values: out-of-range nulls should
    /// not prevent the cast from succeeding.
    #[test]
    fn cast_same_width_int_nullable_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        // The null position holds u32::MAX which doesn't fit in i32, but it's
        // masked as invalid so the cast should still succeed via reinterpret.
        let arr = PrimitiveArray::new(
            buffer![u32::MAX, 0u32, 42u32],
            Validity::from_iter([false, true, true]),
        );
        #[expect(deprecated)]
        let casted = arr
            .into_array()
            .cast(DType::Primitive(PType::I32, Nullability::Nullable))?
            .to_primitive();
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_option_iter([None, Some(0i32), Some(42)]),
            &mut ctx
        );
        Ok(())
    }

    /// Narrowing cast (different byte width): an out-of-range value in a null slot must be
    /// skipped by the masked, checked conversion rather than reported as an overflow.
    #[test]
    fn cast_u32_to_u8_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let arr = PrimitiveArray::new(
            buffer![1000u32, 10u32, 42u32],
            Validity::from_iter([false, true, true]),
        );
        #[expect(deprecated)]
        let casted = arr
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::Nullable))?
            .to_primitive();
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_option_iter([None, Some(10u8), Some(42)]),
            &mut ctx
        );
        Ok(())
    }

    // Shared cast conformance suite over every native ptype, including float special values
    // (NaN, +/-inf), nullable inputs, and a single-element array.
    #[rstest]
    #[case(buffer![0u8, 1, 2, 3, 255].into_array())]
    #[case(buffer![0u16, 100, 1000, 65535].into_array())]
    #[case(buffer![0u32, 100, 1000, 1000000].into_array())]
    #[case(buffer![0u64, 100, 1000, 1000000000].into_array())]
    #[case(buffer![-128i8, -1, 0, 1, 127].into_array())]
    #[case(buffer![-1000i16, -1, 0, 1, 1000].into_array())]
    #[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())]
    #[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())]
    #[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())]
    #[case(buffer![f32::NAN, f32::INFINITY, f32::NEG_INFINITY, 0.0f32].into_array())]
    #[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())]
    #[case(buffer![f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 0.0f64].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())]
    #[case(buffer![42u32].into_array())]
    fn test_cast_primitive_conformance(#[case] array: ArrayRef) {
        test_cast_conformance(&array);
    }
}