vortex_array/arrays/primitive/
downcast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::Buffer;
5use vortex_dtype::{DType, PType};
6use vortex_error::VortexResult;
7
8use crate::ToCanonical;
9use crate::arrays::PrimitiveArray;
10use crate::compute::{cast, min_max};
11
12impl PrimitiveArray {
13    pub fn downcast(&self) -> VortexResult<PrimitiveArray> {
14        if !self.ptype().is_int() {
15            return Ok(self.clone());
16        }
17
18        let Some(min_max) = min_max(self.as_ref())? else {
19            return Ok(PrimitiveArray::new(
20                Buffer::<u8>::zeroed(self.len()),
21                self.validity.clone(),
22            ));
23        };
24
25        // If we can't cast to i64, then leave the array as its original type.
26        // It's too big to downcast anyway.
27        let Ok(min) = min_max.min.cast(&PType::I64.into()).and_then(i64::try_from) else {
28            return Ok(self.clone());
29        };
30        let Ok(max) = min_max.max.cast(&PType::I64.into()).and_then(i64::try_from) else {
31            return Ok(self.clone());
32        };
33
34        if min < 0 || max < 0 {
35            // Signed
36            if min >= i8::MIN as i64 && max <= i8::MAX as i64 {
37                return Ok(cast(
38                    self.as_ref(),
39                    &DType::Primitive(PType::I8, self.dtype().nullability()),
40                )?
41                .to_primitive());
42            }
43
44            if min >= i16::MIN as i64 && max <= i16::MAX as i64 {
45                return Ok(cast(
46                    self.as_ref(),
47                    &DType::Primitive(PType::I16, self.dtype().nullability()),
48                )?
49                .to_primitive());
50            }
51
52            if min >= i32::MIN as i64 && max <= i32::MAX as i64 {
53                return Ok(cast(
54                    self.as_ref(),
55                    &DType::Primitive(PType::I32, self.dtype().nullability()),
56                )?
57                .to_primitive());
58            }
59        } else {
60            // Unsigned
61            if max <= u8::MAX as i64 {
62                return Ok(cast(
63                    self.as_ref(),
64                    &DType::Primitive(PType::U8, self.dtype().nullability()),
65                )?
66                .to_primitive());
67            }
68
69            if max <= u16::MAX as i64 {
70                return Ok(cast(
71                    self.as_ref(),
72                    &DType::Primitive(PType::U16, self.dtype().nullability()),
73                )?
74                .to_primitive());
75            }
76
77            if max <= u32::MAX as i64 {
78                return Ok(cast(
79                    self.as_ref(),
80                    &DType::Primitive(PType::U32, self.dtype().nullability()),
81                )?
82                .to_primitive());
83            }
84        }
85
86        Ok(self.clone())
87    }
88}
89
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::{DType, Nullability, PType};

    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    /// An array with no valid values becomes a nullable `u8` array and keeps its
    /// all-invalid validity.
    #[test]
    fn test_downcast_all_invalid() {
        let values = buffer![0_u32, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let narrowed = PrimitiveArray::new(values, Validity::AllInvalid)
            .downcast()
            .unwrap();

        let expected_dtype = DType::Primitive(PType::U8, Nullability::Nullable);
        assert_eq!(narrowed.dtype(), &expected_dtype);
        assert_eq!(narrowed.validity, Validity::AllInvalid);
    }

    /// Boundary values around each signed width select the expected type.
    #[rstest]
    #[case(vec![0_i64, 127], PType::U8)]
    #[case(vec![-128_i64, 127], PType::I8)]
    #[case(vec![-129_i64, 127], PType::I16)]
    #[case(vec![-128_i64, 128], PType::I16)]
    #[case(vec![-32768_i64, 32767], PType::I16)]
    #[case(vec![-32769_i64, 32767], PType::I32)]
    #[case(vec![-32768_i64, 32768], PType::I32)]
    #[case(vec![i32::MIN as i64, i32::MAX as i64], PType::I32)]
    fn test_downcast_signed(#[case] values: Vec<i64>, #[case] expected_ptype: PType) {
        let narrowed = PrimitiveArray::from_iter(values).downcast().unwrap();
        assert_eq!(narrowed.ptype(), expected_ptype);
    }

    /// Boundary values around each unsigned width select the expected type.
    #[rstest]
    #[case(vec![0_u64, 255], PType::U8)]
    #[case(vec![0_u64, 256], PType::U16)]
    #[case(vec![0_u64, 65535], PType::U16)]
    #[case(vec![0_u64, 65536], PType::U32)]
    #[case(vec![0_u64, u32::MAX as u64], PType::U32)]
    fn test_downcast_unsigned(#[case] values: Vec<u64>, #[case] expected_ptype: PType) {
        let narrowed = PrimitiveArray::from_iter(values).downcast().unwrap();
        assert_eq!(narrowed.ptype(), expected_ptype);
    }

    /// Values that need the full 64-bit range leave the type untouched.
    #[test]
    fn test_downcast_keeps_original_if_too_large() {
        let narrowed = PrimitiveArray::from_iter(vec![0_u64, u64::MAX])
            .downcast()
            .unwrap();
        assert_eq!(narrowed.ptype(), PType::U64);
    }

    /// A nullable input stays nullable and keeps an array-backed validity.
    #[test]
    fn test_downcast_preserves_nullability() {
        let input = PrimitiveArray::from_option_iter([Some(0_i32), None, Some(127)]);
        let narrowed = input.downcast().unwrap();

        let expected_dtype = DType::Primitive(PType::U8, Nullability::Nullable);
        assert_eq!(narrowed.dtype(), &expected_dtype);
        // The per-element validity must survive the narrowing.
        assert!(matches!(&narrowed.validity, Validity::Array(_)));
    }

    /// Narrowing changes the storage width, not the values.
    #[test]
    fn test_downcast_preserves_values() {
        let input = PrimitiveArray::from_iter(vec![-100_i16, 0, 100]);
        let narrowed = input.downcast().unwrap();

        assert_eq!(narrowed.ptype(), PType::I8);
        // Every element must read back unchanged at the narrower width.
        assert_eq!(
            narrowed.as_slice::<i8>(),
            [-100_i8, 0, 100].as_slice()
        );
    }

    /// A single negative value forces a signed result wide enough for the maximum.
    #[test]
    fn test_downcast_with_mixed_signs_chooses_signed() {
        let narrowed = PrimitiveArray::from_iter(vec![-1_i32, 200])
            .downcast()
            .unwrap();
        assert_eq!(narrowed.ptype(), PType::I16);
    }

    /// Floating-point arrays are never converted to integers.
    #[test]
    fn test_downcast_floats() {
        let narrowed = PrimitiveArray::from_iter(vec![1.0_f32, 2.0, 3.0])
            .downcast()
            .unwrap();
        assert_eq!(narrowed.ptype(), PType::F32);
    }

    /// Empty arrays keep whatever validity they started with.
    #[test]
    fn test_downcast_empty_array() {
        let nullable_empty = PrimitiveArray::new(Buffer::<i32>::empty(), Validity::AllInvalid);
        let non_nullable_empty =
            PrimitiveArray::new(Buffer::<i64>::empty(), Validity::NonNullable);

        let narrowed_nullable = nullable_empty.downcast().unwrap();
        let narrowed_non_nullable = non_nullable_empty.downcast().unwrap();

        assert_eq!(narrowed_nullable.validity, Validity::AllInvalid);
        assert_eq!(narrowed_non_nullable.validity, Validity::NonNullable);
    }
}