Skip to main content

vortex_array/arrays/primitive/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::AsPrimitive;
5use num_traits::NumCast;
6use vortex_buffer::Buffer;
7use vortex_buffer::BufferMut;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_err;
11use vortex_mask::Mask;
12
13use crate::ArrayRef;
14use crate::ExecutionCtx;
15use crate::IntoArray;
16use crate::aggregate_fn;
17use crate::array::ArrayView;
18use crate::arrays::Primitive;
19use crate::arrays::PrimitiveArray;
20use crate::arrays::primitive::PrimitiveArrayExt;
21use crate::dtype::DType;
22use crate::dtype::NativePType;
23use crate::dtype::Nullability;
24use crate::dtype::PType;
25use crate::expr::stats::Stat;
26use crate::expr::stats::StatsProvider;
27use crate::match_each_native_ptype;
28use crate::scalar_fn::fns::cast::CastKernel;
29use crate::scalar_fn::fns::cast::CastReduce;
30use crate::validity::Validity;
31
32impl CastReduce for Primitive {
33    fn cast(array: ArrayView<'_, Primitive>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
34        // Only the same ptype is reducible without execution; type changes need the kernel
35        // to verify values fit in the target range.
36        let DType::Primitive(new_ptype, new_nullability) = dtype else {
37            return Ok(None);
38        };
39        if *new_ptype != array.ptype() {
40            return Ok(None);
41        }
42
43        let Some(new_validity) = array
44            .validity()?
45            .trivially_cast_nullability(*new_nullability, array.len())?
46        else {
47            return Ok(None);
48        };
49
50        // SAFETY: validity and data buffer still have same length.
51        Ok(Some(unsafe {
52            PrimitiveArray::new_unchecked_from_handle(
53                array.buffer_handle().clone(),
54                array.ptype(),
55                new_validity,
56            )
57            .into_array()
58        }))
59    }
60}
61
62impl CastKernel for Primitive {
63    fn cast(
64        array: ArrayView<'_, Primitive>,
65        dtype: &DType,
66        ctx: &mut ExecutionCtx,
67    ) -> VortexResult<Option<ArrayRef>> {
68        let DType::Primitive(new_ptype, new_nullability) = dtype else {
69            return Ok(None);
70        };
71        let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
72        let src_ptype = array.ptype();
73
74        let new_validity = array
75            .validity()?
76            .cast_nullability(new_nullability, array.len(), ctx)?;
77
78        // Same bit representation: either the same ptype (only the nullability changed) or two
79        // same-width integers (identical layout under 2's complement). The only non-trivial case
80        // is the sign change between same-width ints, which still needs a value-range check.
81        let same_rep = src_ptype == new_ptype
82            || (src_ptype.is_int()
83                && new_ptype.is_int()
84                && src_ptype.byte_width() == new_ptype.byte_width());
85        if same_rep {
86            if !values_fit_in(array, new_ptype, ctx, true) {
87                vortex_bail!(
88                    Compute: "Cannot cast {} to {} — values exceed target range",
89                    src_ptype, new_ptype,
90                );
91            }
92            return Ok(Some(reinterpret(array, new_ptype, new_validity)));
93        }
94
95        // Different bit rep: cast each element. `cast_values` picks a pure or checked loop based
96        // on whether the conversion is statically infallible.
97        Ok(Some(match_each_native_ptype!(new_ptype, |T| {
98            match_each_native_ptype!(src_ptype, |F| {
99                cast_values::<F, T>(array, new_validity, ctx)?
100            })
101        })))
102    }
103}
104
105/// Cast values from `F` to `T`. For infallible casts this is a pure pass; for fallible casts
106/// each valid value goes through a checked `NumCast::from` and the kernel bails if any of them
107/// overflow `T`. Invalid positions use the wrapping `as` cast since their values are masked out.
108fn cast_values<F, T>(
109    array: ArrayView<'_, Primitive>,
110    new_validity: Validity,
111    ctx: &mut ExecutionCtx,
112) -> VortexResult<ArrayRef>
113where
114    F: NativePType + AsPrimitive<T>,
115    T: NativePType,
116{
117    let values = array.as_slice::<F>();
118
119    // Fast path: statically infallible, or cached min/max prove every valid value fits in `T`.
120    // The cached check never triggers a stats computation — if the bounds aren't already known
121    // we fall through to the per-lane loop below.
122    if values_always_fit(F::PTYPE, T::PTYPE) || values_fit_in(array, T::PTYPE, ctx, false) {
123        return Ok(PrimitiveArray::new(cast::<F, T>(values), new_validity).into_array());
124    }
125
126    // TODO(joe): if the values source and target have the same bit-width we can
127    // mutate in place.
128
129    // Fallible: invalid lanes are pre-multiplied to zero so the checked cast always succeeds for
130    // them; valid lanes go through `NumCast::from` and the whole cast bails on the first overflow.
131    let mask = array.validity()?.execute_mask(array.len(), ctx)?;
132    let overflow = || {
133        vortex_err!(
134            Compute: "Cannot cast {} to {} — value exceeds target range",
135            F::PTYPE, T::PTYPE,
136        )
137    };
138    let buffer: Buffer<T> = match &mask {
139        Mask::AllTrue(_) => BufferMut::try_from_trusted_len_iter(
140            values
141                .iter()
142                .map(|&v| <T as NumCast>::from(v).ok_or_else(overflow)),
143        )?
144        .freeze(),
145        Mask::AllFalse(_) => BufferMut::<T>::zeroed(values.len()).freeze(),
146        Mask::Values(m) => BufferMut::try_from_trusted_len_iter(
147            values.iter().zip(m.bit_buffer().iter()).map(|(&v, valid)| {
148                let factor = if valid { F::one() } else { F::zero() };
149                <T as NumCast>::from(v * factor).ok_or_else(overflow)
150            }),
151        )?
152        .freeze(),
153    };
154
155    Ok(PrimitiveArray::new(buffer, new_validity).into_array())
156}
157
158/// Out-of-range values at invalid positions are truncated/wrapped by `as`, which is fine because
159/// they are masked out by validity.
160fn cast<F: NativePType + AsPrimitive<T>, T: NativePType>(array: &[F]) -> Buffer<T> {
161    BufferMut::from_trusted_len_iter(array.iter().map(|&src| src.as_())).freeze()
162}
163
164fn reinterpret(
165    array: ArrayView<'_, Primitive>,
166    new_ptype: PType,
167    new_validity: Validity,
168) -> ArrayRef {
169    // SAFETY: caller has verified the bit representation is compatible and that validity length
170    // still matches the buffer length.
171    unsafe {
172        PrimitiveArray::new_unchecked_from_handle(
173            array.buffer_handle().clone(),
174            new_ptype,
175            new_validity,
176        )
177    }
178    .into_array()
179}
180
181/// Returns `true` if every value of `src` is guaranteed representable in `target` without
182/// overflow. Precision may be lost (e.g. large integers cast to `f32`), but the cast can never
183/// produce an out-of-range result.
184fn values_always_fit(src: PType, target: PType) -> bool {
185    if src == target {
186        return true;
187    }
188    if src.is_int() && target.is_int() {
189        return target.byte_width() > src.byte_width()
190            && (src.is_unsigned_int() || target.is_signed_int());
191    }
192    if src.is_float() && target.is_float() {
193        return target.byte_width() > src.byte_width();
194    }
195    src.is_int() && matches!(target, PType::F32 | PType::F64)
196}
197
198/// Returns `true` if all valid values in `array` are representable as `target_ptype`.
199///
200/// Cached min/max statistics are consulted first. If either bound is missing, the function either
201/// computes them with a single pass (when `compute` is `true`) or returns `false` so the caller
202/// can fall back to a slower path (when `compute` is `false`).
203fn values_fit_in(
204    array: ArrayView<'_, Primitive>,
205    target_ptype: PType,
206    ctx: &mut ExecutionCtx,
207    compute: bool,
208) -> bool {
209    let target_dtype = DType::Primitive(target_ptype, Nullability::NonNullable);
210    if let Some(fits) = cached_values_fit_in(array, &target_dtype) {
211        return fits;
212    }
213    if !compute {
214        return false;
215    }
216    aggregate_fn::fns::min_max::min_max(array.array(), ctx)
217        .ok()
218        .flatten()
219        .is_none_or(|mm| mm.min.cast(&target_dtype).is_ok() && mm.max.cast(&target_dtype).is_ok())
220}
221
222/// Cached-only check: returns `Some(fits)` if both `Min` and `Max` are present as `Exact` in the
223/// stats cache, otherwise `None`.
224fn cached_values_fit_in(array: ArrayView<'_, Primitive>, target_dtype: &DType) -> Option<bool> {
225    let stats = array.array().statistics();
226    let min = stats.get(Stat::Min).as_exact()?;
227    let max = stats.get(Stat::Max).as_exact()?;
228    Some(min.cast(target_dtype).is_ok() && max.cast(target_dtype).is_ok())
229}
230
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexError;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    #[expect(deprecated)]
    use crate::canonical::ToCanonical as _;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    /// Walks a chain of casts (narrowing, nullability flips, widening) and checks both the
    /// values and the resulting validity after every step.
    #[test]
    fn cast_u32_u8() {
        let source = buffer![0u32, 10, 200].into_array();

        // u32 -> u8, nullability untouched.
        #[expect(deprecated)]
        let narrowed = source.cast(PType::U8.into()).unwrap().to_primitive();
        assert_arrays_eq!(narrowed, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(narrowed.validity(), Ok(Validity::NonNullable)));

        // u8 -> nullable u8.
        let nullable_u8 = DType::Primitive(PType::U8, Nullability::Nullable);
        #[expect(deprecated)]
        let as_nullable = narrowed
            .into_array()
            .cast(nullable_u8)
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            as_nullable,
            PrimitiveArray::new(buffer![0u8, 10, 200], Validity::AllValid)
        );
        assert!(matches!(as_nullable.validity(), Ok(Validity::AllValid)));

        // nullable u8 -> non-nullable u8.
        let non_nullable_u8 = DType::Primitive(PType::U8, Nullability::NonNullable);
        #[expect(deprecated)]
        let as_non_nullable = as_nullable
            .into_array()
            .cast(non_nullable_u8)
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(as_non_nullable, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(
            as_non_nullable.validity(),
            Ok(Validity::NonNullable)
        ));

        // u8 -> nullable u32.
        let nullable_u32 = DType::Primitive(PType::U32, Nullability::Nullable);
        #[expect(deprecated)]
        let widened = as_non_nullable
            .into_array()
            .cast(nullable_u32)
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            widened,
            PrimitiveArray::new(buffer![0u32, 10, 200], Validity::AllValid)
        );
        assert!(matches!(widened.validity(), Ok(Validity::AllValid)));

        // nullable u32 -> non-nullable u8.
        let back_to_u8 = DType::Primitive(PType::U8, Nullability::NonNullable);
        #[expect(deprecated)]
        let round_tripped = widened
            .into_array()
            .cast(back_to_u8)
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(round_tripped, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(
            round_tripped.validity(),
            Ok(Validity::NonNullable)
        ));
    }

    /// Integer to float conversion keeps the numeric values.
    #[test]
    fn cast_u32_f32() {
        let ints = buffer![0u32, 10, 200].into_array();
        #[expect(deprecated)]
        let floats = ints.cast(PType::F32.into()).unwrap().to_primitive();
        assert_arrays_eq!(floats, PrimitiveArray::from_iter([0.0f32, 10., 200.]));
    }

    /// A negative value cannot be represented as `u32`, so the cast must report a `Compute`
    /// error.
    #[test]
    fn cast_i32_u32() {
        let negative = buffer![-1i32].into_array();
        #[expect(deprecated)]
        let outcome = negative
            .cast(PType::U32.into())
            .and_then(|casted| casted.to_canonical().map(|canonical| canonical.into_array()));
        let error = outcome.unwrap_err();
        assert!(matches!(error, VortexError::Compute(..)));
        assert!(error.to_string().contains("values exceed target range"));
    }

    /// Dropping nullability from an array that actually contains nulls is rejected.
    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let with_nulls = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        #[expect(deprecated)]
        let outcome = with_nulls
            .into_array()
            .cast(PType::I32.into())
            .and_then(|casted| casted.to_canonical().map(|canonical| canonical.into_array()));
        let err = outcome.unwrap_err();

        assert!(matches!(err, VortexError::InvalidArgument(..)));
        assert!(
            err.to_string()
                .contains("Cannot cast array with invalid values to non-nullable type.")
        );
    }

    /// An out-of-range value sitting at an invalid position must not fail the cast, and the
    /// validity mask must survive unchanged.
    #[test]
    fn cast_with_invalid_nulls() {
        let source = PrimitiveArray::new(
            buffer![-1i32, 0, 10],
            Validity::from_iter([false, true, true]),
        );
        let nullable_u32 = DType::Primitive(PType::U32, Nullability::Nullable);
        #[expect(deprecated)]
        let p = source
            .into_array()
            .cast(nullable_u32)
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::from_option_iter([None, Some(0u32), Some(10)])
        );

        let expected_mask = Mask::from(BitBuffer::from(vec![false, true, true]));
        assert_eq!(
            p.as_ref()
                .validity()
                .unwrap()
                .execute_mask(p.as_ref().len(), &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            expected_mask
        );
    }

    /// Same-width integer cast where every value fits: the result must share the source data
    /// pointer, i.e. no new buffer is allocated.
    #[test]
    fn cast_same_width_int_reinterprets_buffer() -> vortex_error::VortexResult<()> {
        let src = PrimitiveArray::from_iter([0u32, 10, 100]);
        let src_addr = src.as_slice::<u32>().as_ptr() as usize;

        #[expect(deprecated)]
        let dst = src.into_array().cast(PType::I32.into())?.to_primitive();
        let dst_addr = dst.as_slice::<i32>().as_ptr() as usize;

        // Zero-copy: both arrays must point at the same data.
        assert_eq!(src_addr, dst_addr);
        assert_arrays_eq!(dst, PrimitiveArray::from_iter([0i32, 10, 100]));
        Ok(())
    }

    /// Same-width integer cast where a value does not fit: the cast must surface a `Compute`
    /// error rather than reinterpret the bits.
    #[test]
    fn cast_same_width_int_out_of_range_errors() {
        let too_large = buffer![u32::MAX].into_array();
        #[expect(deprecated)]
        let outcome = too_large
            .cast(PType::I32.into())
            .and_then(|casted| casted.to_canonical().map(|canonical| canonical.into_array()));
        let err = outcome.unwrap_err();
        assert!(matches!(err, VortexError::Compute(..)));
    }

    /// An all-null array cast between same-width types succeeds regardless of what the buffer
    /// holds.
    #[test]
    fn cast_same_width_all_null() -> vortex_error::VortexResult<()> {
        let all_null = PrimitiveArray::new(buffer![0xFFu8, 0xFF], Validity::AllInvalid);
        let nullable_i8 = DType::Primitive(PType::I8, Nullability::Nullable);
        #[expect(deprecated)]
        let casted = all_null.into_array().cast(nullable_i8)?.to_primitive();
        assert_eq!(casted.len(), 2);
        assert!(matches!(casted.validity(), Ok(Validity::AllInvalid)));
        Ok(())
    }

    /// Same-width integer cast with nulls: an out-of-range value at a null position must not
    /// prevent the cast from succeeding.
    #[test]
    fn cast_same_width_int_nullable_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
        // The null slot holds u32::MAX, which does not fit in i32, but the slot is invalid so
        // the cast is still expected to succeed.
        let source = PrimitiveArray::new(
            buffer![u32::MAX, 0u32, 42u32],
            Validity::from_iter([false, true, true]),
        );
        let nullable_i32 = DType::Primitive(PType::I32, Nullability::Nullable);
        #[expect(deprecated)]
        let casted = source.into_array().cast(nullable_i32)?.to_primitive();
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_option_iter([None, Some(0i32), Some(42)])
        );
        Ok(())
    }

    /// Narrowing cast with nulls: an out-of-range value at a null position is ignored.
    #[test]
    fn cast_u32_to_u8_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
        let source = PrimitiveArray::new(
            buffer![1000u32, 10u32, 42u32],
            Validity::from_iter([false, true, true]),
        );
        let nullable_u8 = DType::Primitive(PType::U8, Nullability::Nullable);
        #[expect(deprecated)]
        let casted = source.into_array().cast(nullable_u8)?.to_primitive();
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_option_iter([None, Some(10u8), Some(42)])
        );
        Ok(())
    }

    /// Runs the shared cast conformance suite over a spread of integer, float and nullable
    /// inputs.
    #[rstest]
    #[case(buffer![0u8, 1, 2, 3, 255].into_array())]
    #[case(buffer![0u16, 100, 1000, 65535].into_array())]
    #[case(buffer![0u32, 100, 1000, 1000000].into_array())]
    #[case(buffer![0u64, 100, 1000, 1000000000].into_array())]
    #[case(buffer![-128i8, -1, 0, 1, 127].into_array())]
    #[case(buffer![-1000i16, -1, 0, 1, 1000].into_array())]
    #[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())]
    #[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())]
    #[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())]
    #[case(buffer![f32::NAN, f32::INFINITY, f32::NEG_INFINITY, 0.0f32].into_array())]
    #[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())]
    #[case(buffer![f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 0.0f64].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())]
    #[case(buffer![42u32].into_array())]
    fn test_cast_primitive_conformance(#[case] input: crate::ArrayRef) {
        test_cast_conformance(&input);
    }
}