Skip to main content

vortex_array/arrays/primitive/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::AsPrimitive;
5use vortex_buffer::Buffer;
6use vortex_buffer::BufferMut;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9
10use crate::ArrayRef;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::aggregate_fn;
14use crate::array::ArrayView;
15use crate::arrays::Primitive;
16use crate::arrays::PrimitiveArray;
17use crate::arrays::primitive::PrimitiveArrayExt;
18use crate::dtype::DType;
19use crate::dtype::NativePType;
20use crate::dtype::Nullability;
21use crate::dtype::PType;
22use crate::match_each_native_ptype;
23use crate::scalar_fn::fns::cast::CastKernel;
24
25impl CastKernel for Primitive {
26    fn cast(
27        array: ArrayView<'_, Primitive>,
28        dtype: &DType,
29        ctx: &mut ExecutionCtx,
30    ) -> VortexResult<Option<ArrayRef>> {
31        let DType::Primitive(new_ptype, new_nullability) = dtype else {
32            return Ok(None);
33        };
34        let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
35
36        // First, check that the cast is compatible with the source array's validity
37        let new_validity = array
38            .validity()?
39            .cast_nullability(new_nullability, array.len())?;
40
41        // Same ptype: zero-copy, just update validity.
42        if array.ptype() == new_ptype {
43            // SAFETY: validity and data buffer still have same length
44            return Ok(Some(unsafe {
45                PrimitiveArray::new_unchecked_from_handle(
46                    array.buffer_handle().clone(),
47                    array.ptype(),
48                    new_validity,
49                )
50                .into_array()
51            }));
52        }
53
54        if !values_fit_in(array, new_ptype, ctx) {
55            vortex_bail!(
56                Compute: "Cannot cast {} to {} — values exceed target range",
57                array.ptype(),
58                new_ptype,
59            );
60        }
61
62        // Same-width integers have identical bit representations due to 2's
63        // complement. If all values fit in the target range, reinterpret with
64        // no allocation.
65        if array.ptype().is_int()
66            && new_ptype.is_int()
67            && array.ptype().byte_width() == new_ptype.byte_width()
68        {
69            // SAFETY: both types are integers with the same size and alignment, and
70            // min/max confirm all valid values are representable in the target type.
71            return Ok(Some(unsafe {
72                PrimitiveArray::new_unchecked_from_handle(
73                    array.buffer_handle().clone(),
74                    new_ptype,
75                    new_validity,
76                )
77                .into_array()
78            }));
79        }
80
81        // Otherwise, we need to cast the values one-by-one.
82        Ok(Some(match_each_native_ptype!(new_ptype, |T| {
83            match_each_native_ptype!(array.ptype(), |F| {
84                PrimitiveArray::new(cast::<F, T>(array.as_slice()), new_validity).into_array()
85            })
86        })))
87    }
88}
89
90/// Returns `true` if all valid values in `array` are representable as `target_ptype`.
91fn values_fit_in(
92    array: ArrayView<'_, Primitive>,
93    target_ptype: PType,
94    ctx: &mut ExecutionCtx,
95) -> bool {
96    let target_dtype = DType::Primitive(target_ptype, Nullability::NonNullable);
97    aggregate_fn::fns::min_max::min_max(array.array(), ctx)
98        .ok()
99        .flatten()
100        .is_none_or(|mm| mm.min.cast(&target_dtype).is_ok() && mm.max.cast(&target_dtype).is_ok())
101}
102
103/// Caller must ensure all valid values are representable via `values_fit_in`.
104/// Out-of-range values at invalid positions are truncated/wrapped by `as`,
105/// which is fine because they are masked out by validity.
106fn cast<F: NativePType + AsPrimitive<T>, T: NativePType>(array: &[F]) -> Buffer<T> {
107    BufferMut::from_trusted_len_iter(array.iter().map(|&src| src.as_())).freeze()
108}
109
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexError;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    /// Casts `array` to `dtype`, canonicalizes the result, and returns the error
    /// that must surface somewhere along that path. Panics if the whole chain
    /// succeeds.
    #[track_caller]
    fn cast_err(array: crate::ArrayRef, dtype: DType) -> VortexError {
        array
            .cast(dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err()
    }

    /// Narrowing cast plus a round trip through nullable/non-nullable variants.
    #[test]
    fn cast_u32_u8() {
        let arr = buffer![0u32, 10, 200].into_array();

        // cast from u32 to u8
        let p = arr.cast(PType::U8.into()).unwrap().to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(p.validity(), Ok(Validity::NonNullable)));

        // to nullable
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::new(buffer![0u8, 10, 200], Validity::AllValid)
        );
        assert!(matches!(p.validity(), Ok(Validity::AllValid)));

        // back to non-nullable
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(p.validity(), Ok(Validity::NonNullable)));

        // to nullable u32
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::new(buffer![0u32, 10, 200], Validity::AllValid)
        );
        assert!(matches!(p.validity(), Ok(Validity::AllValid)));

        // to non-nullable u8
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(p.validity(), Ok(Validity::NonNullable)));
    }

    /// Integer to float cast goes through the element-wise conversion.
    #[test]
    fn cast_u32_f32() {
        let arr = buffer![0u32, 10, 200].into_array();
        let f32arr = arr.cast(PType::F32.into()).unwrap().to_primitive();
        assert_arrays_eq!(f32arr, PrimitiveArray::from_iter([0.0f32, 10., 200.]));
    }

    /// A negative value has no unsigned representation, so the cast must fail.
    #[test]
    fn cast_i32_u32() {
        let arr = buffer![-1i32].into_array();
        let error = cast_err(arr, PType::U32.into());
        assert!(matches!(error, VortexError::Compute(..)));
        assert!(error.to_string().contains("values exceed target range"));
    }

    /// An array that actually contains nulls cannot become non-nullable.
    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let arr = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        let err = cast_err(arr.into_array(), PType::I32.into());

        assert!(matches!(err, VortexError::InvalidArgument(..)));
        assert!(
            err.to_string()
                .contains("Cannot cast array with invalid values to non-nullable type.")
        );
    }

    /// An out-of-range value at an invalid position must not block the cast.
    #[test]
    fn cast_with_invalid_nulls() {
        let arr = PrimitiveArray::new(
            buffer![-1i32, 0, 10],
            Validity::from_iter([false, true, true]),
        );
        let p = arr
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::from_option_iter([None, Some(0u32), Some(10)])
        );
        assert_eq!(
            p.as_ref()
                .validity()
                .unwrap()
                .to_mask(p.as_ref().len(), &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Mask::from(BitBuffer::from(vec![false, true, true]))
        );
    }

    /// Same-width integer cast where all values fit: should reinterpret the
    /// buffer without allocation (pointer identity).
    #[test]
    fn cast_same_width_int_reinterprets_buffer() -> vortex_error::VortexResult<()> {
        let src = PrimitiveArray::from_iter([0u32, 10, 100]);
        let src_ptr = src.as_slice::<u32>().as_ptr();

        let dst = src.into_array().cast(PType::I32.into())?.to_primitive();
        let dst_ptr = dst.as_slice::<i32>().as_ptr();

        // Zero-copy: the data pointer should be identical.
        assert_eq!(src_ptr as usize, dst_ptr as usize);
        assert_arrays_eq!(dst, PrimitiveArray::from_iter([0i32, 10, 100]));
        Ok(())
    }

    /// Same-width integer cast where a valid value doesn't fit: the range check
    /// must reject the cast with a compute error instead of reinterpreting the
    /// bits into a different value.
    #[test]
    fn cast_same_width_int_out_of_range_errors() {
        let arr = buffer![u32::MAX].into_array();
        let err = cast_err(arr, PType::I32.into());
        assert!(matches!(err, VortexError::Compute(..)));
        assert!(err.to_string().contains("values exceed target range"));
    }

    /// All-null array cast between same-width types should succeed without
    /// touching the buffer contents.
    #[test]
    fn cast_same_width_all_null() -> vortex_error::VortexResult<()> {
        let arr = PrimitiveArray::new(buffer![0xFFu8, 0xFF], Validity::AllInvalid);
        let casted = arr
            .into_array()
            .cast(DType::Primitive(PType::I8, Nullability::Nullable))?
            .to_primitive();
        assert_eq!(casted.len(), 2);
        assert!(matches!(casted.validity(), Ok(Validity::AllInvalid)));
        Ok(())
    }

    /// Same-width integer cast with nullable values: out-of-range nulls should
    /// not prevent the cast from succeeding.
    #[test]
    fn cast_same_width_int_nullable_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
        // The null position holds u32::MAX which doesn't fit in i32, but it's
        // masked as invalid so the cast should still succeed via reinterpret.
        let arr = PrimitiveArray::new(
            buffer![u32::MAX, 0u32, 42u32],
            Validity::from_iter([false, true, true]),
        );
        let casted = arr
            .into_array()
            .cast(DType::Primitive(PType::I32, Nullability::Nullable))?
            .to_primitive();
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_option_iter([None, Some(0i32), Some(42)])
        );
        Ok(())
    }

    /// Narrowing cast with nullable values: an out-of-range value at an invalid
    /// position should not prevent the cast from succeeding.
    #[test]
    fn cast_u32_to_u8_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
        let arr = PrimitiveArray::new(
            buffer![1000u32, 10u32, 42u32],
            Validity::from_iter([false, true, true]),
        );
        let casted = arr
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::Nullable))?
            .to_primitive();
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_option_iter([None, Some(10u8), Some(42)])
        );
        Ok(())
    }

    /// Runs the shared cast conformance suite over a spread of primitive arrays.
    #[rstest]
    #[case(buffer![0u8, 1, 2, 3, 255].into_array())]
    #[case(buffer![0u16, 100, 1000, 65535].into_array())]
    #[case(buffer![0u32, 100, 1000, 1000000].into_array())]
    #[case(buffer![0u64, 100, 1000, 1000000000].into_array())]
    #[case(buffer![-128i8, -1, 0, 1, 127].into_array())]
    #[case(buffer![-1000i16, -1, 0, 1, 1000].into_array())]
    #[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())]
    #[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())]
    #[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())]
    #[case(buffer![f32::NAN, f32::INFINITY, f32::NEG_INFINITY, 0.0f32].into_array())]
    #[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())]
    #[case(buffer![f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 0.0f64].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())]
    #[case(buffer![42u32].into_array())]
    fn test_cast_primitive_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(&array);
    }
}