vortex_array/arrays/primitive/array/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::Buffer;
5use vortex_dtype::DType;
6use vortex_dtype::NativePType;
7use vortex_dtype::PType;
8use vortex_error::VortexResult;
9use vortex_error::vortex_panic;
10
11use crate::ToCanonical;
12use crate::arrays::PrimitiveArray;
13use crate::compute::cast;
14use crate::compute::min_max;
15use crate::vtable::ValidityHelper;
16
17impl PrimitiveArray {
18    /// Return a slice of the array's buffer.
19    ///
20    /// NOTE: these values may be nonsense if the validity buffer indicates that the value is null.
21    pub fn as_slice<T: NativePType>(&self) -> &[T] {
22        if T::PTYPE != self.ptype() {
23            vortex_panic!(
24                "Attempted to get slice of type {} from array of type {}",
25                T::PTYPE,
26                self.ptype()
27            )
28        }
29        let raw_slice = self.byte_buffer().as_ptr();
30        // SAFETY: alignment of Buffer is checked on construction
31        unsafe {
32            std::slice::from_raw_parts(raw_slice.cast(), self.byte_buffer().len() / size_of::<T>())
33        }
34    }
35
36    pub fn reinterpret_cast(&self, ptype: PType) -> Self {
37        if self.ptype() == ptype {
38            return self.clone();
39        }
40
41        assert_eq!(
42            self.ptype().byte_width(),
43            ptype.byte_width(),
44            "can't reinterpret cast between integers of two different widths"
45        );
46
47        PrimitiveArray::from_byte_buffer(self.byte_buffer().clone(), ptype, self.validity().clone())
48    }
49
50    /// Narrow the array to the smallest possible integer type that can represent all values.
51    pub fn narrow(&self) -> VortexResult<PrimitiveArray> {
52        if !self.ptype().is_int() {
53            return Ok(self.clone());
54        }
55
56        let Some(min_max) = min_max(self.as_ref())? else {
57            return Ok(PrimitiveArray::new(
58                Buffer::<u8>::zeroed(self.len()),
59                self.validity.clone(),
60            ));
61        };
62
63        // If we can't cast to i64, then leave the array as its original type.
64        // It's too big to downcast anyway.
65        let Ok(min) = min_max.min.cast(&PType::I64.into()).and_then(i64::try_from) else {
66            return Ok(self.clone());
67        };
68        let Ok(max) = min_max.max.cast(&PType::I64.into()).and_then(i64::try_from) else {
69            return Ok(self.clone());
70        };
71
72        if min < 0 || max < 0 {
73            // Signed
74            if min >= i8::MIN as i64 && max <= i8::MAX as i64 {
75                return Ok(cast(
76                    self.as_ref(),
77                    &DType::Primitive(PType::I8, self.dtype().nullability()),
78                )?
79                .to_primitive());
80            }
81
82            if min >= i16::MIN as i64 && max <= i16::MAX as i64 {
83                return Ok(cast(
84                    self.as_ref(),
85                    &DType::Primitive(PType::I16, self.dtype().nullability()),
86                )?
87                .to_primitive());
88            }
89
90            if min >= i32::MIN as i64 && max <= i32::MAX as i64 {
91                return Ok(cast(
92                    self.as_ref(),
93                    &DType::Primitive(PType::I32, self.dtype().nullability()),
94                )?
95                .to_primitive());
96            }
97        } else {
98            // Unsigned
99            if max <= u8::MAX as i64 {
100                return Ok(cast(
101                    self.as_ref(),
102                    &DType::Primitive(PType::U8, self.dtype().nullability()),
103                )?
104                .to_primitive());
105            }
106
107            if max <= u16::MAX as i64 {
108                return Ok(cast(
109                    self.as_ref(),
110                    &DType::Primitive(PType::U16, self.dtype().nullability()),
111                )?
112                .to_primitive());
113            }
114
115            if max <= u32::MAX as i64 {
116                return Ok(cast(
117                    self.as_ref(),
118                    &DType::Primitive(PType::U32, self.dtype().nullability()),
119                )?
120                .to_primitive());
121            }
122        }
123
124        Ok(self.clone())
125    }
126}
127
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    /// An all-invalid array narrows to a nullable `u8` array that is still all-invalid.
    #[test]
    fn test_downcast_all_invalid() {
        let all_null = PrimitiveArray::new(
            buffer![0_u32, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            Validity::AllInvalid,
        );

        let narrowed = all_null.narrow().unwrap();

        let expected_dtype = DType::Primitive(PType::U8, Nullability::Nullable);
        assert_eq!(narrowed.dtype(), &expected_dtype);
        assert_eq!(narrowed.validity, Validity::AllInvalid);
    }

    /// Boundary values around the `i8`/`i16`/`i32` ranges pick the expected ptype.
    #[rstest]
    #[case(vec![0_i64, 127], PType::U8)]
    #[case(vec![-128_i64, 127], PType::I8)]
    #[case(vec![-129_i64, 127], PType::I16)]
    #[case(vec![-128_i64, 128], PType::I16)]
    #[case(vec![-32768_i64, 32767], PType::I16)]
    #[case(vec![-32769_i64, 32767], PType::I32)]
    #[case(vec![-32768_i64, 32768], PType::I32)]
    #[case(vec![i32::MIN as i64, i32::MAX as i64], PType::I32)]
    fn test_downcast_signed(#[case] values: Vec<i64>, #[case] expected_ptype: PType) {
        let narrowed = PrimitiveArray::from_iter(values).narrow().unwrap();
        assert_eq!(narrowed.ptype(), expected_ptype);
    }

    /// Boundary values around the `u8`/`u16`/`u32` ranges pick the expected ptype.
    #[rstest]
    #[case(vec![0_u64, 255], PType::U8)]
    #[case(vec![0_u64, 256], PType::U16)]
    #[case(vec![0_u64, 65535], PType::U16)]
    #[case(vec![0_u64, 65536], PType::U32)]
    #[case(vec![0_u64, u32::MAX as u64], PType::U32)]
    fn test_downcast_unsigned(#[case] values: Vec<u64>, #[case] expected_ptype: PType) {
        let narrowed = PrimitiveArray::from_iter(values).narrow().unwrap();
        assert_eq!(narrowed.ptype(), expected_ptype);
    }

    /// A `u64` array containing `u64::MAX` keeps its original ptype.
    #[test]
    fn test_downcast_keeps_original_if_too_large() {
        let narrowed = PrimitiveArray::from_iter(vec![0_u64, u64::MAX])
            .narrow()
            .unwrap();
        assert_eq!(narrowed.ptype(), PType::U64);
    }

    /// Narrowing a nullable array keeps both the nullable dtype and the validity array.
    #[test]
    fn test_downcast_preserves_nullability() {
        let with_null = PrimitiveArray::from_option_iter([Some(0_i32), None, Some(127)]);

        let narrowed = with_null.narrow().unwrap();

        let expected_dtype = DType::Primitive(PType::U8, Nullability::Nullable);
        assert_eq!(narrowed.dtype(), &expected_dtype);
        // The validity must still be array-backed, i.e. the null entry was not dropped.
        assert!(matches!(&narrowed.validity, Validity::Array(_)));
    }

    /// The narrowed buffer holds the same numeric values as the input.
    #[test]
    fn test_downcast_preserves_values() {
        let narrowed = PrimitiveArray::from_iter(vec![-100_i16, 0, 100])
            .narrow()
            .unwrap();

        assert_eq!(narrowed.ptype(), PType::I8);
        // Compare the narrowed contents element by element against the expected `i8` values.
        assert_eq!(narrowed.as_slice::<i8>(), [-100_i8, 0, 100].as_slice());
    }

    /// A negative minimum forces a signed type even when the maximum would fit in `u8`.
    #[test]
    fn test_downcast_with_mixed_signs_chooses_signed() {
        let narrowed = PrimitiveArray::from_iter(vec![-1_i32, 200])
            .narrow()
            .unwrap();
        assert_eq!(narrowed.ptype(), PType::I16);
    }

    /// Float arrays are never narrowed to integers and keep their ptype.
    #[test]
    fn test_downcast_floats() {
        let narrowed = PrimitiveArray::from_iter(vec![1.0_f32, 2.0, 3.0])
            .narrow()
            .unwrap();
        assert_eq!(narrowed.ptype(), PType::F32);
    }

    /// Empty arrays keep whatever validity they were constructed with.
    #[test]
    fn test_downcast_empty_array() {
        let nullable_empty = PrimitiveArray::new(Buffer::<i32>::empty(), Validity::AllInvalid);
        let non_nullable_empty =
            PrimitiveArray::new(Buffer::<i64>::empty(), Validity::NonNullable);

        let narrowed_nullable = nullable_empty.narrow().unwrap();
        let narrowed_non_nullable = non_nullable_empty.narrow().unwrap();

        assert_eq!(narrowed_nullable.validity, Validity::AllInvalid);
        assert_eq!(narrowed_non_nullable.validity, Validity::NonNullable);
    }
}