vortex_array/arrays/primitive/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::{Buffer, BufferMut};
5use vortex_dtype::{DType, NativePType, match_each_native_ptype};
6use vortex_error::{VortexResult, vortex_err};
7use vortex_mask::{AllOr, Mask};
8
9use crate::arrays::PrimitiveVTable;
10use crate::arrays::primitive::PrimitiveArray;
11use crate::compute::{CastKernel, CastKernelAdapter};
12use crate::vtable::ValidityHelper;
13use crate::{ArrayRef, IntoArray, register_kernel};
14
15impl CastKernel for PrimitiveVTable {
16    fn cast(&self, array: &PrimitiveArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17        let DType::Primitive(new_ptype, new_nullability) = dtype else {
18            return Ok(None);
19        };
20        let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
21
22        // First, check that the cast is compatible with the source array's validity
23        let new_validity = array
24            .validity()
25            .clone()
26            .cast_nullability(new_nullability, array.len())?;
27
28        // If the bit width is the same, we can short-circuit and simply update the validity
29        if array.ptype() == new_ptype {
30            return Ok(Some(
31                PrimitiveArray::from_byte_buffer(
32                    array.byte_buffer().clone(),
33                    array.ptype(),
34                    new_validity,
35                )
36                .into_array(),
37            ));
38        }
39
40        let mask = array.validity_mask();
41
42        // Otherwise, we need to cast the values one-by-one
43        Ok(Some(match_each_native_ptype!(new_ptype, |T| {
44            match_each_native_ptype!(array.ptype(), |F| {
45                PrimitiveArray::new(cast::<F, T>(array.as_slice(), mask)?, new_validity)
46                    .into_array()
47            })
48        })))
49    }
50}
51
52register_kernel!(CastKernelAdapter(PrimitiveVTable).lift());
53
54fn cast<F: NativePType, T: NativePType>(array: &[F], mask: Mask) -> VortexResult<Buffer<T>> {
55    match mask.bit_buffer() {
56        AllOr::All => {
57            let mut buffer = BufferMut::with_capacity(array.len());
58            for item in array {
59                let item = T::from(*item).ok_or_else(
60                    || vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, T::PTYPE),
61                )?;
62                // SAFETY: we've pre-allocated the required capacity
63                unsafe { buffer.push_unchecked(item) }
64            }
65            Ok(buffer.freeze())
66        }
67        AllOr::None => Ok(Buffer::zeroed(array.len())),
68        AllOr::Some(b) => {
69            // TODO(robert): Depending on density of the buffer might be better to prefill Buffer and only write valid values
70            let mut buffer = BufferMut::with_capacity(array.len());
71            for (item, valid) in array.iter().zip(b.iter()) {
72                if valid {
73                    let item = T::from(*item).ok_or_else(
74                        || vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, T::PTYPE),
75                    )?;
76                    // SAFETY: we've pre-allocated the required capacity
77                    unsafe { buffer.push_unchecked(item) }
78                } else {
79                    // SAFETY: we've pre-allocated the required capacity
80                    unsafe { buffer.push_unchecked(T::default()) }
81                }
82            }
83            Ok(buffer.freeze())
84        }
85    }
86}
87
/// Tests for the primitive cast kernel: ptype conversions, nullability changes, error
/// reporting for unrepresentable values, and handling of values hidden behind nulls.
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::{BitBuffer, buffer};
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::VortexError;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    #[test]
    fn cast_u32_u8() {
        let arr = buffer![0u32, 10, 200].into_array();

        // cast from u32 to u8
        let p = cast(&arr, PType::U8.into()).unwrap().to_primitive();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);

        // to nullable
        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U8, Nullability::Nullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::AllValid);

        // back to non-nullable
        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U8, Nullability::NonNullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);

        // to nullable u32
        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u32>(), vec![0u32, 10, 200]);
        assert_eq!(p.validity(), &Validity::AllValid);

        // to non-nullable u8
        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U8, Nullability::NonNullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);
    }

    #[test]
    fn cast_u32_f32() {
        let arr = buffer![0u32, 10, 200].into_array();
        let f32arr = cast(&arr, PType::F32.into()).unwrap().to_primitive();
        assert_eq!(f32arr.as_slice::<f32>(), vec![0.0f32, 10., 200.]);
    }

    #[test]
    fn cast_i32_u32() {
        let arr = buffer![-1i32].into_array();
        let error = cast(&arr, PType::U32.into()).err().unwrap();
        let VortexError::ComputeError(s, _) = error else {
            unreachable!()
        };
        assert_eq!(s.to_string(), "Failed to cast -1 to U32");
    }

    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let arr = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        let err = cast(arr.as_ref(), PType::I32.into()).unwrap_err();
        let VortexError::InvalidArgument(s, _) = err else {
            unreachable!()
        };
        assert_eq!(
            s.to_string(),
            "Cannot cast array with invalid values to non-nullable type."
        );
    }

    #[test]
    fn cast_with_invalid_nulls() {
        let arr = PrimitiveArray::new(
            buffer![-1i32, 0, 10],
            Validity::from_iter([false, true, true]),
        );
        let p = cast(
            arr.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u32>(), vec![0, 0, 10]);
        assert_eq!(
            p.validity_mask(),
            Mask::from(BitBuffer::from(vec![false, true, true]))
        );
    }

    #[test]
    fn cast_all_invalid_ignores_hidden_values() {
        // Every slot is null, so the out-of-range values (-1, 300 for u8) must not be
        // converted; the result is zero-filled and stays all-invalid.
        let arr = PrimitiveArray::new(buffer![-1i32, 300, 7], Validity::AllInvalid);
        let p = cast(
            arr.as_ref(),
            &DType::Primitive(PType::U8, Nullability::Nullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 0, 0]);
        assert_eq!(p.validity(), &Validity::AllInvalid);
    }

    #[rstest]
    #[case(buffer![0u8, 1, 2, 3, 255].into_array())]
    #[case(buffer![0u16, 100, 1000, 65535].into_array())]
    #[case(buffer![0u32, 100, 1000, 1000000].into_array())]
    #[case(buffer![0u64, 100, 1000, 1000000000].into_array())]
    #[case(buffer![-128i8, -1, 0, 1, 127].into_array())]
    #[case(buffer![-1000i16, -1, 0, 1, 1000].into_array())]
    #[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())]
    #[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())]
    #[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())]
    #[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())]
    #[case(buffer![42u32].into_array())]
    fn test_cast_primitive_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(array.as_ref());
    }
}