vortex_array/arrays/primitive/compute/cast.rs

1use vortex_buffer::{Buffer, BufferMut};
2use vortex_dtype::{DType, NativePType, Nullability, match_each_native_ptype};
3use vortex_error::{VortexResult, vortex_bail, vortex_err};
4
5use crate::arrays::PrimitiveVTable;
6use crate::arrays::primitive::PrimitiveArray;
7use crate::compute::{CastKernel, CastKernelAdapter};
8use crate::validity::Validity;
9use crate::vtable::ValidityHelper;
10use crate::{ArrayRef, IntoArray, register_kernel};
11
12impl CastKernel for PrimitiveVTable {
13    fn cast(&self, array: &PrimitiveArray, dtype: &DType) -> VortexResult<ArrayRef> {
14        let DType::Primitive(new_ptype, new_nullability) = dtype else {
15            vortex_bail!(MismatchedTypes: "primitive type", dtype);
16        };
17        let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
18
19        // First, check that the cast is compatible with the source array's validity
20        let new_validity = if array.dtype().nullability() == new_nullability {
21            array.validity().clone()
22        } else if new_nullability == Nullability::Nullable {
23            // from non-nullable to nullable
24            array.validity().clone().into_nullable()
25        } else if new_nullability == Nullability::NonNullable && array.validity().all_valid()? {
26            // from nullable but all valid, to non-nullable
27            Validity::NonNullable
28        } else {
29            vortex_bail!(
30                "invalid cast from nullable to non-nullable, since source array actually contains nulls"
31            );
32        };
33
34        // If the bit width is the same, we can short-circuit and simply update the validity
35        if array.ptype() == new_ptype {
36            return Ok(PrimitiveArray::from_byte_buffer(
37                array.byte_buffer().clone(),
38                array.ptype(),
39                new_validity,
40            )
41            .into_array());
42        }
43
44        // Otherwise, we need to cast the values one-by-one
45        match_each_native_ptype!(new_ptype, |T| {
46            Ok(PrimitiveArray::new(cast::<T>(array)?, new_validity).into_array())
47        })
48    }
49}
50
// Register the primitive `CastKernel` implementation, wrapped in `CastKernelAdapter`, via the
// crate's `register_kernel!` macro.
register_kernel!(CastKernelAdapter(PrimitiveVTable).lift());
52
53fn cast<T: NativePType>(array: &PrimitiveArray) -> VortexResult<Buffer<T>> {
54    let mut buffer = BufferMut::with_capacity(array.len());
55    match_each_native_ptype!(array.ptype(), |P| {
56        for item in array.as_slice::<P>() {
57            let item = T::from(*item).ok_or_else(
58                || vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, T::PTYPE),
59            )?;
60            // SAFETY: we've pre-allocated the required capacity
61            unsafe { buffer.push_unchecked(item) }
62        }
63    });
64    Ok(buffer.freeze())
65}
66
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::VortexError;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::cast;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Round-trips a `u32` array through `u8`/`u32` and nullable/non-nullable targets, checking
    /// that the values survive and that the validity follows the target nullability.
    #[test]
    fn cast_u32_u8() {
        let arr = buffer![0u32, 10, 200].into_array();

        // cast from u32 to u8
        let p = cast(&arr, PType::U8.into())
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);

        // to nullable
        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U8, Nullability::Nullable),
        )
        .unwrap()
        .to_primitive()
        .unwrap();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::AllValid);

        // back to non-nullable
        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U8, Nullability::NonNullable),
        )
        .unwrap()
        .to_primitive()
        .unwrap();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);

        // to nullable u32
        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap()
        .to_primitive()
        .unwrap();
        assert_eq!(p.as_slice::<u32>(), vec![0u32, 10, 200]);
        assert_eq!(p.validity(), &Validity::AllValid);

        // to non-nullable u8
        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U8, Nullability::NonNullable),
        )
        .unwrap()
        .to_primitive()
        .unwrap();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);
    }

    /// Integer -> float conversion of values that are exactly representable.
    #[test]
    fn cast_u32_f32() {
        let arr = buffer![0u32, 10, 200].into_array();
        let f32arr = cast(&arr, PType::F32.into())
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(f32arr.as_slice::<f32>(), vec![0.0f32, 10., 200.]);
    }

    /// A negative value cannot be represented as `u32`, so the cast must fail with a
    /// `ComputeError` naming the offending value.
    #[test]
    fn cast_i32_u32() {
        let arr = buffer![-1i32].into_array();
        match cast(&arr, PType::U32.into()).unwrap_err() {
            VortexError::ComputeError(msg, _) => {
                assert_eq!(msg.to_string(), "Failed to cast -1 to U32");
            }
            other => panic!("expected a ComputeError, got {other:?}"),
        }
    }

    /// Casting an array that really contains nulls to a non-nullable type must be rejected.
    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let arr = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        match cast(arr.as_ref(), PType::I32.into()).unwrap_err() {
            VortexError::InvalidArgument(msg, _) => {
                assert_eq!(
                    msg.to_string(),
                    "invalid cast from nullable to non-nullable, since source array actually contains nulls"
                );
            }
            other => panic!("expected an InvalidArgument error, got {other:?}"),
        }
    }
}