Skip to main content

vortex_array/arrays/primitive/array/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::Buffer;
5use vortex_error::VortexExpect;
6use vortex_error::VortexResult;
7use vortex_error::vortex_panic;
8
9use crate::IntoArray;
10use crate::ToCanonical;
11use crate::arrays::PrimitiveArray;
12use crate::builtins::ArrayBuiltins;
13use crate::compute::min_max;
14use crate::dtype::DType;
15use crate::dtype::NativePType;
16use crate::dtype::PType;
17use crate::vtable::ValidityHelper;
18
19impl PrimitiveArray {
20    /// Return a slice of the array's buffer.
21    ///
22    /// NOTE: these values may be nonsense if the validity buffer indicates that the value is null.
23    ///
24    /// # Panic
25    ///
26    /// This operation will panic if the array is not backed by host memory.
27    pub fn as_slice<T: NativePType>(&self) -> &[T] {
28        if T::PTYPE != self.ptype() {
29            vortex_panic!(
30                "Attempted to get slice of type {} from array of type {}",
31                T::PTYPE,
32                self.ptype()
33            )
34        }
35
36        let byte_buffer = self
37            .buffer
38            .as_host_opt()
39            .vortex_expect("as_slice must be called on host buffer");
40        let raw_slice = byte_buffer.as_ptr();
41
42        // SAFETY: alignment of Buffer is checked on construction
43        unsafe { std::slice::from_raw_parts(raw_slice.cast(), byte_buffer.len() / size_of::<T>()) }
44    }
45
46    pub fn reinterpret_cast(&self, ptype: PType) -> Self {
47        if self.ptype() == ptype {
48            return self.clone();
49        }
50
51        assert_eq!(
52            self.ptype().byte_width(),
53            ptype.byte_width(),
54            "can't reinterpret cast between integers of two different widths"
55        );
56
57        PrimitiveArray::from_buffer_handle(
58            self.buffer_handle().clone(),
59            ptype,
60            self.validity().clone(),
61        )
62    }
63
64    /// Narrow the array to the smallest possible integer type that can represent all values.
65    pub fn narrow(&self) -> VortexResult<PrimitiveArray> {
66        if !self.ptype().is_int() {
67            return Ok(self.clone());
68        }
69
70        let Some(min_max) = min_max(&self.clone().into_array())? else {
71            return Ok(PrimitiveArray::new(
72                Buffer::<u8>::zeroed(self.len()),
73                self.validity.clone(),
74            ));
75        };
76
77        // If we can't cast to i64, then leave the array as its original type.
78        // It's too big to downcast anyway.
79        let Ok(min) = min_max
80            .min
81            .cast(&PType::I64.into())
82            .and_then(|s| i64::try_from(&s))
83        else {
84            return Ok(self.clone());
85        };
86        let Ok(max) = min_max
87            .max
88            .cast(&PType::I64.into())
89            .and_then(|s| i64::try_from(&s))
90        else {
91            return Ok(self.clone());
92        };
93
94        if min < 0 || max < 0 {
95            // Signed
96            if min >= i8::MIN as i64 && max <= i8::MAX as i64 {
97                return Ok(self
98                    .clone()
99                    .into_array()
100                    .cast(DType::Primitive(PType::I8, self.dtype().nullability()))?
101                    .to_primitive());
102            }
103
104            if min >= i16::MIN as i64 && max <= i16::MAX as i64 {
105                return Ok(self
106                    .clone()
107                    .into_array()
108                    .cast(DType::Primitive(PType::I16, self.dtype().nullability()))?
109                    .to_primitive());
110            }
111
112            if min >= i32::MIN as i64 && max <= i32::MAX as i64 {
113                return Ok(self
114                    .clone()
115                    .into_array()
116                    .cast(DType::Primitive(PType::I32, self.dtype().nullability()))?
117                    .to_primitive());
118            }
119        } else {
120            // Unsigned
121            if max <= u8::MAX as i64 {
122                return Ok(self
123                    .clone()
124                    .into_array()
125                    .cast(DType::Primitive(PType::U8, self.dtype().nullability()))?
126                    .to_primitive());
127            }
128
129            if max <= u16::MAX as i64 {
130                return Ok(self
131                    .clone()
132                    .into_array()
133                    .cast(DType::Primitive(PType::U16, self.dtype().nullability()))?
134                    .to_primitive());
135            }
136
137            if max <= u32::MAX as i64 {
138                return Ok(self
139                    .clone()
140                    .into_array()
141                    .cast(DType::Primitive(PType::U32, self.dtype().nullability()))?
142                    .to_primitive());
143            }
144        }
145
146        Ok(self.clone())
147    }
148}
149
#[cfg(test)]
mod tests {
    //! Tests for `PrimitiveArray::narrow`.

    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;

    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    /// An all-invalid array narrows to nullable `u8` and stays all-invalid.
    #[test]
    fn test_downcast_all_invalid() {
        let input = PrimitiveArray::new(
            buffer![0_u32, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            Validity::AllInvalid,
        );

        let narrowed = input.narrow().unwrap();

        let expected_dtype = DType::Primitive(PType::U8, Nullability::Nullable);
        assert_eq!(narrowed.dtype(), &expected_dtype);
        assert_eq!(narrowed.validity, Validity::AllInvalid);
    }

    /// Boundary values around each signed width select the expected type.
    #[rstest]
    #[case(vec![0_i64, 127], PType::U8)]
    #[case(vec![-128_i64, 127], PType::I8)]
    #[case(vec![-129_i64, 127], PType::I16)]
    #[case(vec![-128_i64, 128], PType::I16)]
    #[case(vec![-32768_i64, 32767], PType::I16)]
    #[case(vec![-32769_i64, 32767], PType::I32)]
    #[case(vec![-32768_i64, 32768], PType::I32)]
    #[case(vec![i32::MIN as i64, i32::MAX as i64], PType::I32)]
    fn test_downcast_signed(#[case] values: Vec<i64>, #[case] expected_ptype: PType) {
        let narrowed = PrimitiveArray::from_iter(values).narrow().unwrap();
        assert_eq!(narrowed.ptype(), expected_ptype);
    }

    /// Boundary values around each unsigned width select the expected type.
    #[rstest]
    #[case(vec![0_u64, 255], PType::U8)]
    #[case(vec![0_u64, 256], PType::U16)]
    #[case(vec![0_u64, 65535], PType::U16)]
    #[case(vec![0_u64, 65536], PType::U32)]
    #[case(vec![0_u64, u32::MAX as u64], PType::U32)]
    fn test_downcast_unsigned(#[case] values: Vec<u64>, #[case] expected_ptype: PType) {
        let narrowed = PrimitiveArray::from_iter(values).narrow().unwrap();
        assert_eq!(narrowed.ptype(), expected_ptype);
    }

    /// Values that exceed every narrower type leave the array's type untouched.
    #[test]
    fn test_downcast_keeps_original_if_too_large() {
        let input = PrimitiveArray::from_iter(vec![0_u64, u64::MAX]);
        let narrowed = input.narrow().unwrap();
        assert_eq!(narrowed.ptype(), PType::U64);
    }

    /// Narrowing keeps the dtype nullable and keeps the validity array.
    #[test]
    fn test_downcast_preserves_nullability() {
        let input = PrimitiveArray::from_option_iter([Some(0_i32), None, Some(127)]);
        let narrowed = input.narrow().unwrap();

        let expected_dtype = DType::Primitive(PType::U8, Nullability::Nullable);
        assert_eq!(narrowed.dtype(), &expected_dtype);
        // Check that validity is preserved (the array should still have nullable values)
        assert!(matches!(&narrowed.validity, Validity::Array(_)));
    }

    /// The element values survive the narrowing cast.
    #[test]
    fn test_downcast_preserves_values() {
        let input = PrimitiveArray::from_iter(vec![-100_i16, 0, 100]);
        let narrowed = input.narrow().unwrap();

        assert_eq!(narrowed.ptype(), PType::I8);
        // Check that the values were properly downscaled
        let expected: [i8; 3] = [-100, 0, 100];
        assert_eq!(narrowed.as_slice::<i8>(), &expected[..]);
    }

    /// A negative value forces a signed target wide enough for the positive maximum.
    #[test]
    fn test_downcast_with_mixed_signs_chooses_signed() {
        let input = PrimitiveArray::from_iter(vec![-1_i32, 200]);
        let narrowed = input.narrow().unwrap();
        assert_eq!(narrowed.ptype(), PType::I16);
    }

    /// Floating-point arrays are not narrowed.
    #[test]
    fn test_downcast_floats() {
        let input = PrimitiveArray::from_iter(vec![1.0_f32, 2.0, 3.0]);
        let narrowed = input.narrow().unwrap();
        // Floats should remain unchanged since they can't be downscaled to integers
        assert_eq!(narrowed.ptype(), PType::F32);
    }

    /// Empty arrays keep whatever validity they started with.
    #[test]
    fn test_downcast_empty_array() {
        let nullable_empty = PrimitiveArray::new(Buffer::<i32>::empty(), Validity::AllInvalid);
        let non_nullable_empty =
            PrimitiveArray::new(Buffer::<i64>::empty(), Validity::NonNullable);

        let narrowed_nullable = nullable_empty.narrow().unwrap();
        let narrowed_non_nullable = non_nullable_empty.narrow().unwrap();

        // Empty arrays should not have their validity changed
        assert_eq!(narrowed_nullable.validity, Validity::AllInvalid);
        assert_eq!(narrowed_non_nullable.validity, Validity::NonNullable);
    }
}