Skip to main content

vortex_array/arrays/primitive/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::Buffer;
5use vortex_buffer::BufferMut;
6use vortex_error::VortexResult;
7use vortex_error::vortex_bail;
8use vortex_error::vortex_err;
9use vortex_mask::AllOr;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::aggregate_fn;
16use crate::arrays::Primitive;
17use crate::arrays::PrimitiveArray;
18use crate::dtype::DType;
19use crate::dtype::NativePType;
20use crate::dtype::Nullability;
21use crate::dtype::PType;
22use crate::match_each_native_ptype;
23use crate::scalar_fn::fns::cast::CastKernel;
24use crate::vtable::ValidityHelper;
25
26impl CastKernel for Primitive {
27    fn cast(
28        array: &PrimitiveArray,
29        dtype: &DType,
30        ctx: &mut ExecutionCtx,
31    ) -> VortexResult<Option<ArrayRef>> {
32        let DType::Primitive(new_ptype, new_nullability) = dtype else {
33            return Ok(None);
34        };
35        let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
36
37        // First, check that the cast is compatible with the source array's validity
38        let new_validity = array
39            .validity()
40            .clone()
41            .cast_nullability(new_nullability, array.len())?;
42
43        // Same ptype: zero-copy, just update validity.
44        if array.ptype() == new_ptype {
45            // SAFETY: validity and data buffer still have same length
46            return Ok(Some(unsafe {
47                PrimitiveArray::new_unchecked_from_handle(
48                    array.buffer_handle().clone(),
49                    array.ptype(),
50                    new_validity,
51                )
52                .into_array()
53            }));
54        }
55
56        // Same-width integers have identical bit representations due to 2's
57        // complement. If all values fit in the target range, reinterpret with
58        // no allocation.
59        if array.ptype().is_int()
60            && new_ptype.is_int()
61            && array.ptype().byte_width() == new_ptype.byte_width()
62        {
63            if !values_fit_in(array, new_ptype, ctx) {
64                vortex_bail!(
65                    Compute: "Cannot cast {} to {} — values exceed target range",
66                    array.ptype(),
67                    new_ptype,
68                );
69            }
70            // SAFETY: both types are integers with the same size and alignment, and
71            // min/max confirm all valid values are representable in the target type.
72            return Ok(Some(unsafe {
73                PrimitiveArray::new_unchecked_from_handle(
74                    array.buffer_handle().clone(),
75                    new_ptype,
76                    new_validity,
77                )
78                .into_array()
79            }));
80        }
81
82        let mask = array.validity_mask()?;
83
84        // Otherwise, we need to cast the values one-by-one.
85        Ok(Some(match_each_native_ptype!(new_ptype, |T| {
86            match_each_native_ptype!(array.ptype(), |F| {
87                PrimitiveArray::new(cast::<F, T>(array.as_slice(), mask)?, new_validity)
88                    .into_array()
89            })
90        })))
91    }
92}
93
94/// Returns `true` if all valid values in `array` are representable as `target_ptype`.
95fn values_fit_in(array: &PrimitiveArray, target_ptype: PType, ctx: &mut ExecutionCtx) -> bool {
96    let target_dtype = DType::Primitive(target_ptype, Nullability::NonNullable);
97    aggregate_fn::fns::min_max::min_max(&array.clone().into_array(), ctx)
98        .ok()
99        .flatten()
100        .is_none_or(|mm| mm.min.cast(&target_dtype).is_ok() && mm.max.cast(&target_dtype).is_ok())
101}
102
103fn cast<F: NativePType, T: NativePType>(array: &[F], mask: Mask) -> VortexResult<Buffer<T>> {
104    let try_cast = |src: F| -> VortexResult<T> {
105        T::from(src).ok_or_else(|| vortex_err!(Compute: "Failed to cast {} to {:?}", src, T::PTYPE))
106    };
107    match mask.bit_buffer() {
108        AllOr::None => Ok(Buffer::zeroed(array.len())),
109        AllOr::All => {
110            let mut buffer = BufferMut::with_capacity(array.len());
111            for &src in array {
112                // SAFETY: we've pre-allocated the required capacity
113                unsafe { buffer.push_unchecked(try_cast(src)?) }
114            }
115            Ok(buffer.freeze())
116        }
117        AllOr::Some(b) => {
118            let mut buffer = BufferMut::with_capacity(array.len());
119            for (&src, valid) in array.iter().zip(b.iter()) {
120                let dst = if valid { try_cast(src)? } else { T::default() };
121                // SAFETY: we've pre-allocated the required capacity
122                unsafe { buffer.push_unchecked(dst) }
123            }
124            Ok(buffer.freeze())
125        }
126    }
127}
128
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexError;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Round-trips a small array through narrowing, nullability changes and widening.
    /// At every step it checks both the values and the resulting validity variant.
    // The long, linear run of assertions trips clippy's complexity heuristic even though
    // each individual step is trivial.
    #[allow(clippy::cognitive_complexity)]
    #[test]
    fn cast_u32_u8() {
        let arr = buffer![0u32, 10, 200].into_array();

        // cast from u32 to u8
        let p = arr.cast(PType::U8.into()).unwrap().to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(p.validity(), Validity::NonNullable));

        // to nullable
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::new(buffer![0u8, 10, 200], Validity::AllValid)
        );
        assert!(matches!(p.validity(), Validity::AllValid));

        // back to non-nullable
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(p.validity(), Validity::NonNullable));

        // to nullable u32
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::new(buffer![0u32, 10, 200], Validity::AllValid)
        );
        assert!(matches!(p.validity(), Validity::AllValid));

        // to non-nullable u8
        let p = p
            .into_array()
            .cast(DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert!(matches!(p.validity(), Validity::NonNullable));
    }

    #[test]
    fn cast_u32_f32() {
        let arr = buffer![0u32, 10, 200].into_array();
        let f32arr = arr.cast(PType::F32.into()).unwrap().to_primitive();
        assert_arrays_eq!(f32arr, PrimitiveArray::from_iter([0.0f32, 10., 200.]));
    }

    #[test]
    fn cast_i32_u32() {
        let arr = buffer![-1i32].into_array();
        let error = arr
            .cast(PType::U32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();
        assert!(matches!(error, VortexError::Compute(..)));
        assert!(error.to_string().contains("values exceed target range"));
    }

    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let arr = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        let err = arr
            .into_array()
            .cast(PType::I32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();

        assert!(matches!(err, VortexError::InvalidArgument(..)));
        assert!(
            err.to_string()
                .contains("Cannot cast array with invalid values to non-nullable type.")
        );
    }

    #[test]
    fn cast_with_invalid_nulls() {
        let arr = PrimitiveArray::new(
            buffer![-1i32, 0, 10],
            Validity::from_iter([false, true, true]),
        );
        let p = arr
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            p,
            PrimitiveArray::from_option_iter([None, Some(0u32), Some(10)])
        );
        assert_eq!(
            p.validity_mask().unwrap(),
            Mask::from(BitBuffer::from(vec![false, true, true]))
        );
    }

    /// Same-width integer cast where all values fit: should reinterpret the
    /// buffer without allocation (pointer identity).
    #[test]
    fn cast_same_width_int_reinterprets_buffer() -> vortex_error::VortexResult<()> {
        let src = PrimitiveArray::from_iter([0u32, 10, 100]);
        let src_ptr = src.as_slice::<u32>().as_ptr();

        let dst = src.into_array().cast(PType::I32.into())?.to_primitive();
        let dst_ptr = dst.as_slice::<i32>().as_ptr();

        // Zero-copy: the data pointer should be identical.
        assert_eq!(src_ptr as usize, dst_ptr as usize);
        assert_arrays_eq!(dst, PrimitiveArray::from_iter([0i32, 10, 100]));
        Ok(())
    }

    /// Same-width signed-to-unsigned cast of non-negative values, up to the signed
    /// maximum, should also reinterpret the buffer without allocation.
    #[test]
    fn cast_same_width_signed_to_unsigned_reinterprets() -> vortex_error::VortexResult<()> {
        let src = PrimitiveArray::from_iter([0i64, 7, i64::MAX]);
        let src_ptr = src.as_slice::<i64>().as_ptr();

        let dst = src.into_array().cast(PType::U64.into())?.to_primitive();
        let dst_ptr = dst.as_slice::<u64>().as_ptr();

        // Zero-copy: the data pointer should be identical.
        assert_eq!(src_ptr as usize, dst_ptr as usize);
        assert_arrays_eq!(
            dst,
            PrimitiveArray::from_iter([0u64, 7, i64::MAX.unsigned_abs()])
        );
        Ok(())
    }

    /// Same-width integer cast where values don't fit: the min/max range check
    /// must reject the reinterpretation with a compute error.
    #[test]
    fn cast_same_width_int_out_of_range_errors() {
        let arr = buffer![u32::MAX].into_array();
        let err = arr
            .cast(PType::I32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();
        assert!(matches!(err, VortexError::Compute(..)));
        assert!(err.to_string().contains("values exceed target range"));
    }

    /// All-null array cast between same-width types should succeed without
    /// touching the buffer contents.
    #[test]
    fn cast_same_width_all_null() -> vortex_error::VortexResult<()> {
        let arr = PrimitiveArray::new(buffer![0xFFu8, 0xFF], Validity::AllInvalid);
        let casted = arr
            .into_array()
            .cast(DType::Primitive(PType::I8, Nullability::Nullable))?
            .to_primitive();
        assert_eq!(casted.len(), 2);
        assert!(matches!(casted.validity(), Validity::AllInvalid));
        Ok(())
    }

    /// Same-width integer cast with nullable values: out-of-range nulls should
    /// not prevent the cast from succeeding.
    #[test]
    fn cast_same_width_int_nullable_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
        // The null position holds u32::MAX, which doesn't fit in i32. It is masked
        // as invalid, so the cast should still succeed via reinterpretation.
        let arr = PrimitiveArray::new(
            buffer![u32::MAX, 0u32, 42u32],
            Validity::from_iter([false, true, true]),
        );
        let casted = arr
            .into_array()
            .cast(DType::Primitive(PType::I32, Nullability::Nullable))?
            .to_primitive();
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_option_iter([None, Some(0i32), Some(42)])
        );
        Ok(())
    }

    #[rstest]
    #[case(buffer![0u8, 1, 2, 3, 255].into_array())]
    #[case(buffer![0u16, 100, 1000, 65535].into_array())]
    #[case(buffer![0u32, 100, 1000, 1000000].into_array())]
    #[case(buffer![0u64, 100, 1000, 1000000000].into_array())]
    #[case(buffer![-128i8, -1, 0, 1, 127].into_array())]
    #[case(buffer![-1000i16, -1, 0, 1, 1000].into_array())]
    #[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())]
    #[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())]
    #[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())]
    #[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())]
    #[case(buffer![42u32].into_array())]
    fn test_cast_primitive_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(&array);
    }
}