vortex_array/arrays/primitive/array/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::Buffer;
5use vortex_dtype::{DType, NativePType, PType};
6use vortex_error::{VortexResult, vortex_panic};
7
8use crate::ToCanonical;
9use crate::arrays::PrimitiveArray;
10use crate::compute::{cast, min_max};
11use crate::vtable::ValidityHelper;
12
13impl PrimitiveArray {
14    /// Return a slice of the array's buffer.
15    ///
16    /// NOTE: these values may be nonsense if the validity buffer indicates that the value is null.
17    pub fn as_slice<T: NativePType>(&self) -> &[T] {
18        if T::PTYPE != self.ptype() {
19            vortex_panic!(
20                "Attempted to get slice of type {} from array of type {}",
21                T::PTYPE,
22                self.ptype()
23            )
24        }
25        let raw_slice = self.byte_buffer().as_ptr();
26        // SAFETY: alignment of Buffer is checked on construction
27        unsafe {
28            std::slice::from_raw_parts(raw_slice.cast(), self.byte_buffer().len() / size_of::<T>())
29        }
30    }
31
32    pub fn reinterpret_cast(&self, ptype: PType) -> Self {
33        if self.ptype() == ptype {
34            return self.clone();
35        }
36
37        assert_eq!(
38            self.ptype().byte_width(),
39            ptype.byte_width(),
40            "can't reinterpret cast between integers of two different widths"
41        );
42
43        PrimitiveArray::from_byte_buffer(self.byte_buffer().clone(), ptype, self.validity().clone())
44    }
45
46    /// Narrow the array to the smallest possible integer type that can represent all values.
47    pub fn narrow(&self) -> VortexResult<PrimitiveArray> {
48        if !self.ptype().is_int() {
49            return Ok(self.clone());
50        }
51
52        let Some(min_max) = min_max(self.as_ref())? else {
53            return Ok(PrimitiveArray::new(
54                Buffer::<u8>::zeroed(self.len()),
55                self.validity.clone(),
56            ));
57        };
58
59        // If we can't cast to i64, then leave the array as its original type.
60        // It's too big to downcast anyway.
61        let Ok(min) = min_max.min.cast(&PType::I64.into()).and_then(i64::try_from) else {
62            return Ok(self.clone());
63        };
64        let Ok(max) = min_max.max.cast(&PType::I64.into()).and_then(i64::try_from) else {
65            return Ok(self.clone());
66        };
67
68        if min < 0 || max < 0 {
69            // Signed
70            if min >= i8::MIN as i64 && max <= i8::MAX as i64 {
71                return Ok(cast(
72                    self.as_ref(),
73                    &DType::Primitive(PType::I8, self.dtype().nullability()),
74                )?
75                .to_primitive());
76            }
77
78            if min >= i16::MIN as i64 && max <= i16::MAX as i64 {
79                return Ok(cast(
80                    self.as_ref(),
81                    &DType::Primitive(PType::I16, self.dtype().nullability()),
82                )?
83                .to_primitive());
84            }
85
86            if min >= i32::MIN as i64 && max <= i32::MAX as i64 {
87                return Ok(cast(
88                    self.as_ref(),
89                    &DType::Primitive(PType::I32, self.dtype().nullability()),
90                )?
91                .to_primitive());
92            }
93        } else {
94            // Unsigned
95            if max <= u8::MAX as i64 {
96                return Ok(cast(
97                    self.as_ref(),
98                    &DType::Primitive(PType::U8, self.dtype().nullability()),
99                )?
100                .to_primitive());
101            }
102
103            if max <= u16::MAX as i64 {
104                return Ok(cast(
105                    self.as_ref(),
106                    &DType::Primitive(PType::U16, self.dtype().nullability()),
107                )?
108                .to_primitive());
109            }
110
111            if max <= u32::MAX as i64 {
112                return Ok(cast(
113                    self.as_ref(),
114                    &DType::Primitive(PType::U32, self.dtype().nullability()),
115                )?
116                .to_primitive());
117            }
118        }
119
120        Ok(self.clone())
121    }
122}
123
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::{DType, Nullability, PType};

    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    /// An all-null array narrows to a nullable `u8` array and stays all-null.
    #[test]
    fn test_downcast_all_invalid() {
        let narrowed = PrimitiveArray::new(
            buffer![0_u32, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            Validity::AllInvalid,
        )
        .narrow()
        .unwrap();

        let expected_dtype = DType::Primitive(PType::U8, Nullability::Nullable);
        assert_eq!(narrowed.dtype(), &expected_dtype);
        assert_eq!(narrowed.validity, Validity::AllInvalid);
    }

    /// Boundary cases around the signed 8/16/32-bit limits for `i64` input.
    #[rstest]
    #[case(vec![0_i64, 127], PType::U8)]
    #[case(vec![-128_i64, 127], PType::I8)]
    #[case(vec![-129_i64, 127], PType::I16)]
    #[case(vec![-128_i64, 128], PType::I16)]
    #[case(vec![-32768_i64, 32767], PType::I16)]
    #[case(vec![-32769_i64, 32767], PType::I32)]
    #[case(vec![-32768_i64, 32768], PType::I32)]
    #[case(vec![i32::MIN as i64, i32::MAX as i64], PType::I32)]
    fn test_downcast_signed(#[case] values: Vec<i64>, #[case] expected_ptype: PType) {
        let narrowed = PrimitiveArray::from_iter(values).narrow().unwrap();
        assert_eq!(narrowed.ptype(), expected_ptype);
    }

    /// Boundary cases around the unsigned 8/16/32-bit limits for `u64` input.
    #[rstest]
    #[case(vec![0_u64, 255], PType::U8)]
    #[case(vec![0_u64, 256], PType::U16)]
    #[case(vec![0_u64, 65535], PType::U16)]
    #[case(vec![0_u64, 65536], PType::U32)]
    #[case(vec![0_u64, u32::MAX as u64], PType::U32)]
    fn test_downcast_unsigned(#[case] values: Vec<u64>, #[case] expected_ptype: PType) {
        let narrowed = PrimitiveArray::from_iter(values).narrow().unwrap();
        assert_eq!(narrowed.ptype(), expected_ptype);
    }

    /// Values that exceed every narrower type leave the ptype untouched.
    #[test]
    fn test_downcast_keeps_original_if_too_large() {
        let narrowed = PrimitiveArray::from_iter(vec![0_u64, u64::MAX])
            .narrow()
            .unwrap();
        assert_eq!(narrowed.ptype(), PType::U64);
    }

    /// Narrowing keeps the dtype nullable and retains an array-backed validity.
    #[test]
    fn test_downcast_preserves_nullability() {
        let with_nulls = PrimitiveArray::from_option_iter([Some(0_i32), None, Some(127)]);
        let narrowed = with_nulls.narrow().unwrap();

        let expected_dtype = DType::Primitive(PType::U8, Nullability::Nullable);
        assert_eq!(narrowed.dtype(), &expected_dtype);
        // The narrowed array must still carry per-element validity.
        assert!(matches!(&narrowed.validity, Validity::Array(_)));
    }

    /// The element values survive the narrowing cast unchanged.
    #[test]
    fn test_downcast_preserves_values() {
        let narrowed = PrimitiveArray::from_iter(vec![-100_i16, 0, 100])
            .narrow()
            .unwrap();

        assert_eq!(narrowed.ptype(), PType::I8);
        // Compare the narrowed contents element by element.
        let narrowed_values: Vec<i8> = narrowed.as_slice::<i8>().to_vec();
        assert_eq!(narrowed_values, vec![-100_i8, 0, 100]);
    }

    /// A single negative value forces a signed target type.
    #[test]
    fn test_downcast_with_mixed_signs_chooses_signed() {
        let narrowed = PrimitiveArray::from_iter(vec![-1_i32, 200])
            .narrow()
            .unwrap();
        assert_eq!(narrowed.ptype(), PType::I16);
    }

    /// Float arrays are not narrowed.
    #[test]
    fn test_downcast_floats() {
        let narrowed = PrimitiveArray::from_iter(vec![1.0_f32, 2.0, 3.0])
            .narrow()
            .unwrap();
        // Narrowing only applies to integer ptypes, so the float ptype is kept.
        assert_eq!(narrowed.ptype(), PType::F32);
    }

    /// Empty arrays keep whatever validity they started with.
    #[test]
    fn test_downcast_empty_array() {
        let nullable_empty = PrimitiveArray::new(Buffer::<i32>::empty(), Validity::AllInvalid);
        let narrowed_nullable = nullable_empty.narrow().unwrap();
        let non_nullable_empty =
            PrimitiveArray::new(Buffer::<i64>::empty(), Validity::NonNullable);
        let narrowed_non_nullable = non_nullable_empty.narrow().unwrap();

        // Narrowing an empty array must not change its validity.
        assert_eq!(narrowed_nullable.validity, Validity::AllInvalid);
        assert_eq!(narrowed_non_nullable.validity, Validity::NonNullable);
    }
}
231}