Skip to main content

vortex_array/arrays/primitive/array/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::Buffer;
5use vortex_dtype::DType;
6use vortex_dtype::NativePType;
7use vortex_dtype::PType;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_panic;
11
12use crate::ToCanonical;
13use crate::arrays::PrimitiveArray;
14use crate::builtins::ArrayBuiltins;
15use crate::compute::min_max;
16use crate::vtable::ValidityHelper;
17
18impl PrimitiveArray {
19    /// Return a slice of the array's buffer.
20    ///
21    /// NOTE: these values may be nonsense if the validity buffer indicates that the value is null.
22    ///
23    /// # Panic
24    ///
25    /// This operation will panic if the array is not backed by host memory.
26    pub fn as_slice<T: NativePType>(&self) -> &[T] {
27        if T::PTYPE != self.ptype() {
28            vortex_panic!(
29                "Attempted to get slice of type {} from array of type {}",
30                T::PTYPE,
31                self.ptype()
32            )
33        }
34
35        let byte_buffer = self
36            .buffer
37            .as_host_opt()
38            .vortex_expect("as_slice must be called on host buffer");
39        let raw_slice = byte_buffer.as_ptr();
40
41        // SAFETY: alignment of Buffer is checked on construction
42        unsafe { std::slice::from_raw_parts(raw_slice.cast(), byte_buffer.len() / size_of::<T>()) }
43    }
44
45    pub fn reinterpret_cast(&self, ptype: PType) -> Self {
46        if self.ptype() == ptype {
47            return self.clone();
48        }
49
50        assert_eq!(
51            self.ptype().byte_width(),
52            ptype.byte_width(),
53            "can't reinterpret cast between integers of two different widths"
54        );
55
56        PrimitiveArray::from_buffer_handle(
57            self.buffer_handle().clone(),
58            ptype,
59            self.validity().clone(),
60        )
61    }
62
63    /// Narrow the array to the smallest possible integer type that can represent all values.
64    pub fn narrow(&self) -> VortexResult<PrimitiveArray> {
65        if !self.ptype().is_int() {
66            return Ok(self.clone());
67        }
68
69        let Some(min_max) = min_max(self.as_ref())? else {
70            return Ok(PrimitiveArray::new(
71                Buffer::<u8>::zeroed(self.len()),
72                self.validity.clone(),
73            ));
74        };
75
76        // If we can't cast to i64, then leave the array as its original type.
77        // It's too big to downcast anyway.
78        let Ok(min) = min_max
79            .min
80            .cast(&PType::I64.into())
81            .and_then(|s| i64::try_from(&s))
82        else {
83            return Ok(self.clone());
84        };
85        let Ok(max) = min_max
86            .max
87            .cast(&PType::I64.into())
88            .and_then(|s| i64::try_from(&s))
89        else {
90            return Ok(self.clone());
91        };
92
93        if min < 0 || max < 0 {
94            // Signed
95            if min >= i8::MIN as i64 && max <= i8::MAX as i64 {
96                return Ok(self
97                    .to_array()
98                    .cast(DType::Primitive(PType::I8, self.dtype().nullability()))?
99                    .to_primitive());
100            }
101
102            if min >= i16::MIN as i64 && max <= i16::MAX as i64 {
103                return Ok(self
104                    .to_array()
105                    .cast(DType::Primitive(PType::I16, self.dtype().nullability()))?
106                    .to_primitive());
107            }
108
109            if min >= i32::MIN as i64 && max <= i32::MAX as i64 {
110                return Ok(self
111                    .to_array()
112                    .cast(DType::Primitive(PType::I32, self.dtype().nullability()))?
113                    .to_primitive());
114            }
115        } else {
116            // Unsigned
117            if max <= u8::MAX as i64 {
118                return Ok(self
119                    .to_array()
120                    .cast(DType::Primitive(PType::U8, self.dtype().nullability()))?
121                    .to_primitive());
122            }
123
124            if max <= u16::MAX as i64 {
125                return Ok(self
126                    .to_array()
127                    .cast(DType::Primitive(PType::U16, self.dtype().nullability()))?
128                    .to_primitive());
129            }
130
131            if max <= u32::MAX as i64 {
132                return Ok(self
133                    .to_array()
134                    .cast(DType::Primitive(PType::U32, self.dtype().nullability()))?
135                    .to_primitive());
136            }
137        }
138
139        Ok(self.clone())
140    }
141}
142
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    #[test]
    fn test_downcast_all_invalid() {
        let array = PrimitiveArray::new(
            buffer![0_u32, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            Validity::AllInvalid,
        );

        let result = array.narrow().unwrap();
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::U8, Nullability::Nullable)
        );
        assert_eq!(result.validity, Validity::AllInvalid);
    }

    #[rstest]
    #[case(vec![0_i64, 127], PType::U8)]
    #[case(vec![-128_i64, 127], PType::I8)]
    #[case(vec![-129_i64, 127], PType::I16)]
    #[case(vec![-128_i64, 128], PType::I16)]
    #[case(vec![-32768_i64, 32767], PType::I16)]
    #[case(vec![-32769_i64, 32767], PType::I32)]
    #[case(vec![-32768_i64, 32768], PType::I32)]
    #[case(vec![i32::MIN as i64, i32::MAX as i64], PType::I32)]
    fn test_downcast_signed(#[case] values: Vec<i64>, #[case] expected_ptype: PType) {
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), expected_ptype);
    }

    #[rstest]
    #[case(vec![0_u64, 255], PType::U8)]
    #[case(vec![0_u64, 256], PType::U16)]
    #[case(vec![0_u64, 65535], PType::U16)]
    #[case(vec![0_u64, 65536], PType::U32)]
    #[case(vec![0_u64, u32::MAX as u64], PType::U32)]
    fn test_downcast_unsigned(#[case] values: Vec<u64>, #[case] expected_ptype: PType) {
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), expected_ptype);
    }

    #[test]
    fn test_downcast_keeps_original_if_too_large() {
        let array = PrimitiveArray::from_iter(vec![0_u64, u64::MAX]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::U64);
    }

    #[test]
    fn test_downcast_keeps_original_if_below_i32() {
        // Fits in i64 but no 32-bit-or-smaller signed type, so the type must not change.
        let array = PrimitiveArray::from_iter(vec![i64::MIN, 0]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::I64);
    }

    #[test]
    fn test_downcast_preserves_nullability() {
        let array = PrimitiveArray::from_option_iter([Some(0_i32), None, Some(127)]);
        let result = array.narrow().unwrap();
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::U8, Nullability::Nullable)
        );
        // Check that validity is preserved (the array should still have nullable values)
        assert!(matches!(&result.validity, Validity::Array(_)));
    }

    #[test]
    fn test_downcast_preserves_values() {
        let values = vec![-100_i16, 0, 100];
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow().unwrap();

        assert_eq!(result.ptype(), PType::I8);
        // Check that the values were properly downscaled
        let downscaled_values: Vec<i8> = result.as_slice::<i8>().to_vec();
        assert_eq!(downscaled_values, vec![-100_i8, 0, 100]);
    }

    #[test]
    fn test_downcast_with_mixed_signs_chooses_signed() {
        let array = PrimitiveArray::from_iter(vec![-1_i32, 200]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::I16);
    }

    #[test]
    fn test_downcast_floats() {
        let array = PrimitiveArray::from_iter(vec![1.0_f32, 2.0, 3.0]);
        let result = array.narrow().unwrap();
        // Floats should remain unchanged since they can't be downscaled to integers
        assert_eq!(result.ptype(), PType::F32);
    }

    #[test]
    fn test_downcast_empty_array() {
        // Empty arrays should not have their validity changed.
        let nullable = PrimitiveArray::new(Buffer::<i32>::empty(), Validity::AllInvalid);
        assert_eq!(nullable.narrow().unwrap().validity, Validity::AllInvalid);

        let non_nullable = PrimitiveArray::new(Buffer::<i64>::empty(), Validity::NonNullable);
        assert_eq!(
            non_nullable.narrow().unwrap().validity,
            Validity::NonNullable
        );
    }

    #[test]
    fn test_reinterpret_cast_preserves_bits() {
        let array = PrimitiveArray::from_iter(vec![u32::MAX, 1]);
        let result = array.reinterpret_cast(PType::I32);
        assert_eq!(result.ptype(), PType::I32);
        // Same bytes, new interpretation: u32::MAX is all ones, i.e. -1 as i32.
        assert_eq!(result.as_slice::<i32>(), &[-1_i32, 1]);
    }

    #[test]
    fn test_reinterpret_cast_same_type_is_noop() {
        let array = PrimitiveArray::from_iter(vec![u32::MAX, 1]);
        let result = array.reinterpret_cast(PType::U32);
        assert_eq!(result.ptype(), PType::U32);
        assert_eq!(result.as_slice::<u32>(), &[u32::MAX, 1]);
    }

    #[test]
    #[should_panic(expected = "can't reinterpret cast between integers of two different widths")]
    fn test_reinterpret_cast_rejects_width_mismatch() {
        let array = PrimitiveArray::from_iter(vec![1_u32]);
        let _ = array.reinterpret_cast(PType::I64);
    }

    #[test]
    #[should_panic(expected = "Attempted to get slice of type")]
    fn test_as_slice_rejects_ptype_mismatch() {
        let array = PrimitiveArray::from_iter(vec![1_u32]);
        let _ = array.as_slice::<i32>();
    }
}