Skip to main content

vortex_array/arrays/primitive/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::AsPrimitive;
5use vortex_buffer::Buffer;
6use vortex_buffer::BufferMut;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9
10use crate::ArrayRef;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::aggregate_fn;
14use crate::array::ArrayView;
15use crate::arrays::Primitive;
16use crate::arrays::PrimitiveArray;
17use crate::arrays::primitive::PrimitiveArrayExt;
18use crate::dtype::DType;
19use crate::dtype::NativePType;
20use crate::dtype::Nullability;
21use crate::dtype::PType;
22use crate::match_each_native_ptype;
23use crate::scalar_fn::fns::cast::CastKernel;
24use crate::scalar_fn::fns::cast::CastReduce;
25
26impl CastReduce for Primitive {
27    fn cast(array: ArrayView<'_, Primitive>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
28        // Only the same ptype is reducible without execution; type changes need the kernel
29        // to verify values fit in the target range.
30        let DType::Primitive(new_ptype, new_nullability) = dtype else {
31            return Ok(None);
32        };
33        if *new_ptype != array.ptype() {
34            return Ok(None);
35        }
36
37        let Some(new_validity) = array
38            .validity()?
39            .trivial_cast_nullability(*new_nullability, array.len())?
40        else {
41            return Ok(None);
42        };
43
44        // SAFETY: validity and data buffer still have same length.
45        Ok(Some(unsafe {
46            PrimitiveArray::new_unchecked_from_handle(
47                array.buffer_handle().clone(),
48                array.ptype(),
49                new_validity,
50            )
51            .into_array()
52        }))
53    }
54}
55
56impl CastKernel for Primitive {
57    fn cast(
58        array: ArrayView<'_, Primitive>,
59        dtype: &DType,
60        ctx: &mut ExecutionCtx,
61    ) -> VortexResult<Option<ArrayRef>> {
62        let DType::Primitive(new_ptype, new_nullability) = dtype else {
63            return Ok(None);
64        };
65        let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
66
67        // First, check that the cast is compatible with the source array's validity
68        let new_validity = array
69            .validity()?
70            .cast_nullability(new_nullability, array.len(), ctx)?;
71
72        // Same ptype: zero-copy, just update validity.
73        if array.ptype() == new_ptype {
74            // SAFETY: validity and data buffer still have same length
75            return Ok(Some(unsafe {
76                PrimitiveArray::new_unchecked_from_handle(
77                    array.buffer_handle().clone(),
78                    array.ptype(),
79                    new_validity,
80                )
81                .into_array()
82            }));
83        }
84
85        if !values_fit_in(array, new_ptype, ctx) {
86            vortex_bail!(
87                Compute: "Cannot cast {} to {} — values exceed target range",
88                array.ptype(),
89                new_ptype,
90            );
91        }
92
93        // Same-width integers have identical bit representations due to 2's
94        // complement. If all values fit in the target range, reinterpret with
95        // no allocation.
96        if array.ptype().is_int()
97            && new_ptype.is_int()
98            && array.ptype().byte_width() == new_ptype.byte_width()
99        {
100            // SAFETY: both types are integers with the same size and alignment, and
101            // min/max confirm all valid values are representable in the target type.
102            return Ok(Some(unsafe {
103                PrimitiveArray::new_unchecked_from_handle(
104                    array.buffer_handle().clone(),
105                    new_ptype,
106                    new_validity,
107                )
108                .into_array()
109            }));
110        }
111
112        // Otherwise, we need to cast the values one-by-one.
113        Ok(Some(match_each_native_ptype!(new_ptype, |T| {
114            match_each_native_ptype!(array.ptype(), |F| {
115                PrimitiveArray::new(cast::<F, T>(array.as_slice()), new_validity).into_array()
116            })
117        })))
118    }
119}
120
121/// Returns `true` if all valid values in `array` are representable as `target_ptype`.
122fn values_fit_in(
123    array: ArrayView<'_, Primitive>,
124    target_ptype: PType,
125    ctx: &mut ExecutionCtx,
126) -> bool {
127    let target_dtype = DType::Primitive(target_ptype, Nullability::NonNullable);
128    aggregate_fn::fns::min_max::min_max(array.array(), ctx)
129        .ok()
130        .flatten()
131        .is_none_or(|mm| mm.min.cast(&target_dtype).is_ok() && mm.max.cast(&target_dtype).is_ok())
132}
133
134/// Caller must ensure all valid values are representable via `values_fit_in`.
135/// Out-of-range values at invalid positions are truncated/wrapped by `as`,
136/// which is fine because they are masked out by validity.
137fn cast<F: NativePType + AsPrimitive<T>, T: NativePType>(array: &[F]) -> Buffer<T> {
138    BufferMut::from_trusted_len_iter(array.iter().map(|&src| src.as_())).freeze()
139}
140
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexError;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    #[expect(deprecated)]
    use crate::canonical::ToCanonical as _;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    #[test]
    fn cast_u32_u8() {
        let arr = buffer![0u32, 10, 200].into_array();

        // cast from u32 to u8
        #[expect(deprecated)]
        let p = arr.cast(PType::U8.into()).unwrap().to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(p.validity(), Ok(Validity::NonNullable)));

        // to nullable
        #[expect(deprecated)]
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::new(buffer![0u8, 10, 200], Validity::AllValid)
        );
        assert!(matches!(p.validity(), Ok(Validity::AllValid)));

        // back to non-nullable
        #[expect(deprecated)]
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(p.validity(), Ok(Validity::NonNullable)));

        // to nullable u32
        #[expect(deprecated)]
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::new(buffer![0u32, 10, 200], Validity::AllValid)
        );
        assert!(matches!(p.validity(), Ok(Validity::AllValid)));

        // to non-nullable u8
        #[expect(deprecated)]
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(p.validity(), Ok(Validity::NonNullable)));
    }

    #[test]
    fn cast_u32_f32() {
        let arr = buffer![0u32, 10, 200].into_array();
        #[expect(deprecated)]
        let f32arr = arr.cast(PType::F32.into()).unwrap().to_primitive();
        assert_arrays_eq!(f32arr, PrimitiveArray::from_iter([0.0f32, 10., 200.]));
    }

    #[test]
    fn cast_i32_u32() {
        let arr = buffer![-1i32].into_array();
        #[expect(deprecated)]
        let error = arr
            .cast(PType::U32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();
        assert!(matches!(error, VortexError::Compute(..)));
        assert!(error.to_string().contains("values exceed target range"));
    }

    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let arr = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        #[expect(deprecated)]
        let err = arr
            .into_array()
            .cast(PType::I32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();

        assert!(matches!(err, VortexError::InvalidArgument(..)));
        assert!(
            err.to_string()
                .contains("Cannot cast array with invalid values to non-nullable type.")
        );
    }

    #[test]
    fn cast_with_invalid_nulls() {
        let arr = PrimitiveArray::new(
            buffer![-1i32, 0, 10],
            Validity::from_iter([false, true, true]),
        );
        #[expect(deprecated)]
        let p = arr
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::from_option_iter([None, Some(0u32), Some(10)])
        );
        assert_eq!(
            p.as_ref()
                .validity()
                .unwrap()
                .execute_mask(p.as_ref().len(), &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Mask::from(BitBuffer::from(vec![false, true, true]))
        );
    }

    /// Same-width integer cast where all values fit: should reinterpret the
    /// buffer without allocation (pointer identity).
    #[test]
    fn cast_same_width_int_reinterprets_buffer() -> vortex_error::VortexResult<()> {
        let src = PrimitiveArray::from_iter([0u32, 10, 100]);
        let src_ptr = src.as_slice::<u32>().as_ptr();

        #[expect(deprecated)]
        let dst = src.into_array().cast(PType::I32.into())?.to_primitive();
        let dst_ptr = dst.as_slice::<i32>().as_ptr();

        // Zero-copy: the data pointer should be identical.
        assert_eq!(src_ptr as usize, dst_ptr as usize);
        assert_arrays_eq!(dst, PrimitiveArray::from_iter([0i32, 10, 100]));
        Ok(())
    }

    /// Same-width integer cast where values don't fit: the range check must
    /// reject it with a compute error before the buffer is reinterpreted.
    #[test]
    fn cast_same_width_int_out_of_range_errors() {
        let arr = buffer![u32::MAX].into_array();
        #[expect(deprecated)]
        let err = arr
            .cast(PType::I32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();
        assert!(matches!(err, VortexError::Compute(..)));
        assert!(err.to_string().contains("values exceed target range"));
    }

    /// All-null array cast between same-width types should succeed without
    /// touching the buffer contents.
    #[test]
    fn cast_same_width_all_null() -> vortex_error::VortexResult<()> {
        let arr = PrimitiveArray::new(buffer![0xFFu8, 0xFF], Validity::AllInvalid);
        #[expect(deprecated)]
        let casted = arr
            .into_array()
            .cast(DType::Primitive(PType::I8, Nullability::Nullable))?
            .to_primitive();
        assert_eq!(casted.len(), 2);
        assert!(matches!(casted.validity(), Ok(Validity::AllInvalid)));
        Ok(())
    }

    /// Same-width integer cast with nullable values: out-of-range nulls should
    /// not prevent the cast from succeeding.
    #[test]
    fn cast_same_width_int_nullable_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
        // The null position holds u32::MAX which doesn't fit in i32, but it's
        // masked as invalid so the cast should still succeed via reinterpret.
        let arr = PrimitiveArray::new(
            buffer![u32::MAX, 0u32, 42u32],
            Validity::from_iter([false, true, true]),
        );
        #[expect(deprecated)]
        let casted = arr
            .into_array()
            .cast(DType::Primitive(PType::I32, Nullability::Nullable))?
            .to_primitive();
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_option_iter([None, Some(0i32), Some(42)])
        );
        Ok(())
    }

    #[test]
    fn cast_u32_to_u8_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
        let arr = PrimitiveArray::new(
            buffer![1000u32, 10u32, 42u32],
            Validity::from_iter([false, true, true]),
        );
        #[expect(deprecated)]
        let casted = arr
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::Nullable))?
            .to_primitive();
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_option_iter([None, Some(10u8), Some(42)])
        );
        Ok(())
    }

    /// An empty array has no values to range-check, so a narrowing cast must
    /// succeed and produce an empty result.
    #[test]
    fn cast_empty_array() -> vortex_error::VortexResult<()> {
        let arr = PrimitiveArray::from_iter(std::iter::empty::<u32>());
        #[expect(deprecated)]
        let casted = arr.into_array().cast(PType::U8.into())?.to_primitive();
        assert_eq!(casted.len(), 0);
        Ok(())
    }

    #[rstest]
    #[case(buffer![0u8, 1, 2, 3, 255].into_array())]
    #[case(buffer![0u16, 100, 1000, 65535].into_array())]
    #[case(buffer![0u32, 100, 1000, 1000000].into_array())]
    #[case(buffer![0u64, 100, 1000, 1000000000].into_array())]
    #[case(buffer![-128i8, -1, 0, 1, 127].into_array())]
    #[case(buffer![-1000i16, -1, 0, 1, 1000].into_array())]
    #[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())]
    #[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())]
    #[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())]
    #[case(buffer![f32::NAN, f32::INFINITY, f32::NEG_INFINITY, 0.0f32].into_array())]
    #[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())]
    #[case(buffer![f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 0.0f64].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())]
    #[case(buffer![42u32].into_array())]
    fn test_cast_primitive_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(&array);
    }
}