vortex_array/arrays/primitive/compute/cast.rs

1use vortex_buffer::{Buffer, BufferMut};
2use vortex_dtype::{DType, NativePType, Nullability, match_each_native_ptype};
3use vortex_error::{VortexResult, vortex_bail, vortex_err};
4
5use crate::arrays::PrimitiveEncoding;
6use crate::arrays::primitive::PrimitiveArray;
7use crate::compute::CastFn;
8use crate::validity::Validity;
9use crate::variants::PrimitiveArrayTrait;
10use crate::{Array, ArrayRef};
11
12impl CastFn<&PrimitiveArray> for PrimitiveEncoding {
13    fn cast(&self, array: &PrimitiveArray, dtype: &DType) -> VortexResult<ArrayRef> {
14        let DType::Primitive(new_ptype, new_nullability) = dtype else {
15            vortex_bail!(MismatchedTypes: "primitive type", dtype);
16        };
17        let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
18
19        // First, check that the cast is compatible with the source array's validity
20        let new_validity = if array.dtype().nullability() == new_nullability {
21            array.validity().clone()
22        } else if new_nullability == Nullability::Nullable {
23            // from non-nullable to nullable
24            array.validity().clone().into_nullable()
25        } else if new_nullability == Nullability::NonNullable
26            && array.validity().to_logical(array.len())?.all_true()
27        {
28            // from nullable but all valid, to non-nullable
29            Validity::NonNullable
30        } else {
31            vortex_bail!(
32                "invalid cast from nullable to non-nullable, since source array actually contains nulls"
33            );
34        };
35
36        // If the bit width is the same, we can short-circuit and simply update the validity
37        if array.ptype() == new_ptype {
38            return Ok(PrimitiveArray::from_byte_buffer(
39                array.byte_buffer().clone(),
40                array.ptype(),
41                new_validity,
42            )
43            .into_array());
44        }
45
46        // Otherwise, we need to cast the values one-by-one
47        match_each_native_ptype!(new_ptype, |$T| {
48            Ok(PrimitiveArray::new(
49                cast::<$T>(array)?,
50                new_validity,
51            ).into_array())
52        })
53    }
54}
55
56fn cast<T: NativePType>(array: &PrimitiveArray) -> VortexResult<Buffer<T>> {
57    let mut buffer = BufferMut::with_capacity(array.len());
58    match_each_native_ptype!(array.ptype(), |$P| {
59        for item in array.as_slice::<$P>() {
60            let item = T::from(*item).ok_or_else(
61                || vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, T::PTYPE),
62            )?;
63            // SAFETY: we've pre-allocated the required capacity
64            unsafe { buffer.push_unchecked(item) }
65        }
66    });
67    Ok(buffer.freeze())
68}
69
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::VortexError;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::try_cast;
    use crate::validity::Validity;

    #[test]
    fn cast_u32_u8() {
        let arr = buffer![0u32, 10, 200].into_array();

        // cast from u32 to u8
        let p = try_cast(&arr, PType::U8.into())
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);

        // to nullable
        let p = try_cast(&p, &DType::Primitive(PType::U8, Nullability::Nullable))
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::AllValid);

        // back to non-nullable
        let p = try_cast(&p, &DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);

        // to nullable u32
        let p = try_cast(&p, &DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(p.as_slice::<u32>(), vec![0u32, 10, 200]);
        assert_eq!(p.validity(), &Validity::AllValid);

        // to non-nullable u8
        let p = try_cast(&p, &DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);
    }

    #[test]
    fn cast_u32_f32() {
        let arr = buffer![0u32, 10, 200].into_array();
        let f32arr = try_cast(&arr, PType::F32.into())
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(f32arr.as_slice::<f32>(), vec![0.0f32, 10., 200.]);
    }

    #[test]
    fn cast_i32_u32() {
        let arr = buffer![-1i32].into_array();
        let error = try_cast(&arr, PType::U32.into()).unwrap_err();
        match error {
            VortexError::ComputeError(s, _) => {
                assert_eq!(s.to_string(), "Failed to cast -1 to U32");
            }
            other => panic!("expected ComputeError, got: {other}"),
        }
    }

    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let arr = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        let err = try_cast(&arr, PType::I32.into()).unwrap_err();
        match err {
            VortexError::InvalidArgument(s, _) => {
                assert_eq!(
                    s.to_string(),
                    "invalid cast from nullable to non-nullable, since source array actually contains nulls"
                );
            }
            other => panic!("expected InvalidArgument, got: {other}"),
        }
    }
}