Skip to main content

vortex_array/arrays/primitive/array/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::Buffer;
5use vortex_error::VortexExpect;
6use vortex_error::VortexResult;
7use vortex_error::vortex_panic;
8
9use crate::IntoArray;
10use crate::LEGACY_SESSION;
11use crate::ToCanonical;
12use crate::VortexSessionExecute;
13use crate::aggregate_fn::fns::min_max::min_max;
14use crate::arrays::PrimitiveArray;
15use crate::builtins::ArrayBuiltins;
16use crate::dtype::DType;
17use crate::dtype::NativePType;
18use crate::dtype::PType;
19use crate::vtable::ValidityHelper;
20
21impl PrimitiveArray {
22    /// Return a slice of the array's buffer.
23    ///
24    /// NOTE: these values may be nonsense if the validity buffer indicates that the value is null.
25    ///
26    /// # Panic
27    ///
28    /// This operation will panic if the array is not backed by host memory.
29    pub fn as_slice<T: NativePType>(&self) -> &[T] {
30        if T::PTYPE != self.ptype() {
31            vortex_panic!(
32                "Attempted to get slice of type {} from array of type {}",
33                T::PTYPE,
34                self.ptype()
35            )
36        }
37
38        let byte_buffer = self
39            .buffer
40            .as_host_opt()
41            .vortex_expect("as_slice must be called on host buffer");
42        let raw_slice = byte_buffer.as_ptr();
43
44        // SAFETY: alignment of Buffer is checked on construction
45        unsafe { std::slice::from_raw_parts(raw_slice.cast(), byte_buffer.len() / size_of::<T>()) }
46    }
47
48    pub fn reinterpret_cast(&self, ptype: PType) -> Self {
49        if self.ptype() == ptype {
50            return self.clone();
51        }
52
53        assert_eq!(
54            self.ptype().byte_width(),
55            ptype.byte_width(),
56            "can't reinterpret cast between integers of two different widths"
57        );
58
59        PrimitiveArray::from_buffer_handle(
60            self.buffer_handle().clone(),
61            ptype,
62            self.validity().clone(),
63        )
64    }
65
66    /// Narrow the array to the smallest possible integer type that can represent all values.
67    pub fn narrow(&self) -> VortexResult<PrimitiveArray> {
68        if !self.ptype().is_int() {
69            return Ok(self.clone());
70        }
71
72        let mut ctx = LEGACY_SESSION.create_execution_ctx();
73        let Some(min_max) = min_max(&self.clone().into_array(), &mut ctx)? else {
74            return Ok(PrimitiveArray::new(
75                Buffer::<u8>::zeroed(self.len()),
76                self.validity.clone(),
77            ));
78        };
79
80        // If we can't cast to i64, then leave the array as its original type.
81        // It's too big to downcast anyway.
82        let Ok(min) = min_max
83            .min
84            .cast(&PType::I64.into())
85            .and_then(|s| i64::try_from(&s))
86        else {
87            return Ok(self.clone());
88        };
89        let Ok(max) = min_max
90            .max
91            .cast(&PType::I64.into())
92            .and_then(|s| i64::try_from(&s))
93        else {
94            return Ok(self.clone());
95        };
96
97        if min < 0 || max < 0 {
98            // Signed
99            if min >= i8::MIN as i64 && max <= i8::MAX as i64 {
100                return Ok(self
101                    .clone()
102                    .into_array()
103                    .cast(DType::Primitive(PType::I8, self.dtype().nullability()))?
104                    .to_primitive());
105            }
106
107            if min >= i16::MIN as i64 && max <= i16::MAX as i64 {
108                return Ok(self
109                    .clone()
110                    .into_array()
111                    .cast(DType::Primitive(PType::I16, self.dtype().nullability()))?
112                    .to_primitive());
113            }
114
115            if min >= i32::MIN as i64 && max <= i32::MAX as i64 {
116                return Ok(self
117                    .clone()
118                    .into_array()
119                    .cast(DType::Primitive(PType::I32, self.dtype().nullability()))?
120                    .to_primitive());
121            }
122        } else {
123            // Unsigned
124            if max <= u8::MAX as i64 {
125                return Ok(self
126                    .clone()
127                    .into_array()
128                    .cast(DType::Primitive(PType::U8, self.dtype().nullability()))?
129                    .to_primitive());
130            }
131
132            if max <= u16::MAX as i64 {
133                return Ok(self
134                    .clone()
135                    .into_array()
136                    .cast(DType::Primitive(PType::U16, self.dtype().nullability()))?
137                    .to_primitive());
138            }
139
140            if max <= u32::MAX as i64 {
141                return Ok(self
142                    .clone()
143                    .into_array()
144                    .cast(DType::Primitive(PType::U32, self.dtype().nullability()))?
145                    .to_primitive());
146            }
147        }
148
149        Ok(self.clone())
150    }
151}
152
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;

    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    #[test]
    fn test_downcast_all_invalid() {
        let array = PrimitiveArray::new(
            buffer![0_u32, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            Validity::AllInvalid,
        );

        let result = array.narrow().unwrap();
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::U8, Nullability::Nullable)
        );
        assert!(matches!(result.validity, Validity::AllInvalid));
    }

    #[rstest]
    #[case(vec![0_i64, 127], PType::U8)]
    #[case(vec![-128_i64, 127], PType::I8)]
    #[case(vec![-129_i64, 127], PType::I16)]
    #[case(vec![-128_i64, 128], PType::I16)]
    #[case(vec![-32768_i64, 32767], PType::I16)]
    #[case(vec![-32769_i64, 32767], PType::I32)]
    #[case(vec![-32768_i64, 32768], PType::I32)]
    #[case(vec![i32::MIN as i64, i32::MAX as i64], PType::I32)]
    #[case(vec![-5_i64, -1], PType::I8)]
    #[case(vec![i32::MIN as i64 - 1, 0], PType::I64)]
    #[case(vec![0_i64, i32::MAX as i64 + 1, -1], PType::I64)]
    fn test_downcast_signed(#[case] values: Vec<i64>, #[case] expected_ptype: PType) {
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), expected_ptype);
    }

    #[rstest]
    #[case(vec![0_u64, 255], PType::U8)]
    #[case(vec![0_u64, 256], PType::U16)]
    #[case(vec![0_u64, 65535], PType::U16)]
    #[case(vec![0_u64, 65536], PType::U32)]
    #[case(vec![0_u64, u32::MAX as u64], PType::U32)]
    #[case(vec![0_u64, u32::MAX as u64 + 1], PType::U64)]
    fn test_downcast_unsigned(#[case] values: Vec<u64>, #[case] expected_ptype: PType) {
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), expected_ptype);
    }

    #[test]
    fn test_downcast_keeps_original_if_too_large() {
        let array = PrimitiveArray::from_iter(vec![0_u64, u64::MAX]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::U64);
    }

    #[test]
    fn test_downcast_already_narrow_is_unchanged() {
        let array = PrimitiveArray::from_iter(vec![0_u8, 255]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::U8);
        assert_eq!(result.as_slice::<u8>().to_vec(), vec![0_u8, 255]);
    }

    #[test]
    fn test_downcast_preserves_nullability() {
        let array = PrimitiveArray::from_option_iter([Some(0_i32), None, Some(127)]);
        let result = array.narrow().unwrap();
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::U8, Nullability::Nullable)
        );
        // Check that validity is preserved (the array should still have nullable values)
        assert!(matches!(&result.validity, Validity::Array(_)));
    }

    #[test]
    fn test_downcast_preserves_values() {
        let values = vec![-100_i16, 0, 100];
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow().unwrap();

        assert_eq!(result.ptype(), PType::I8);
        // Check that the values were properly downscaled
        let downscaled_values: Vec<i8> = result.as_slice::<i8>().to_vec();
        assert_eq!(downscaled_values, vec![-100_i8, 0, 100]);
    }

    #[test]
    fn test_downcast_with_mixed_signs_chooses_signed() {
        let array = PrimitiveArray::from_iter(vec![-1_i32, 200]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::I16);
    }

    #[test]
    fn test_downcast_floats() {
        let array = PrimitiveArray::from_iter(vec![1.0_f32, 2.0, 3.0]);
        let result = array.narrow().unwrap();
        // Floats should remain unchanged since they can't be downscaled to integers
        assert_eq!(result.ptype(), PType::F32);
    }

    #[test]
    fn test_downcast_empty_array() {
        let array = PrimitiveArray::new(Buffer::<i32>::empty(), Validity::AllInvalid);
        let result = array.narrow().unwrap();
        let array2 = PrimitiveArray::new(Buffer::<i64>::empty(), Validity::NonNullable);
        let result2 = array2.narrow().unwrap();
        // Empty arrays should not have their validity changed
        assert!(matches!(result.validity, Validity::AllInvalid));
        assert!(matches!(result2.validity, Validity::NonNullable));
    }

    #[test]
    fn test_reinterpret_cast_same_width() {
        let array = PrimitiveArray::from_iter(vec![1_u32, 2]);
        let result = array.reinterpret_cast(PType::I32);
        assert_eq!(result.ptype(), PType::I32);
        assert_eq!(result.as_slice::<i32>().to_vec(), vec![1_i32, 2]);
    }

    #[test]
    #[should_panic]
    fn test_reinterpret_cast_different_width_panics() {
        let array = PrimitiveArray::from_iter(vec![1_u32]);
        let _ = array.reinterpret_cast(PType::I64);
    }
}