vortex_array/arrays/primitive/
downcast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::Buffer;
5use vortex_dtype::{DType, PType};
6use vortex_error::VortexResult;
7
8use crate::ToCanonical;
9use crate::arrays::PrimitiveArray;
10use crate::compute::{cast, min_max};
11use crate::validity::Validity;
12
13impl PrimitiveArray {
14    pub fn downcast(&self) -> VortexResult<PrimitiveArray> {
15        if !self.ptype().is_int() {
16            return Ok(self.clone());
17        }
18
19        let Some(min_max) = min_max(self.as_ref())? else {
20            return Ok(PrimitiveArray::new(
21                Buffer::<u8>::zeroed(self.len()),
22                Validity::AllInvalid,
23            ));
24        };
25
26        // If we can't cast to i64, then leave the array as its original type.
27        // It's too big to downcast anyway.
28        let Ok(min) = min_max.min.cast(&PType::I64.into()).and_then(i64::try_from) else {
29            return Ok(self.clone());
30        };
31        let Ok(max) = min_max.max.cast(&PType::I64.into()).and_then(i64::try_from) else {
32            return Ok(self.clone());
33        };
34
35        if min < 0 || max < 0 {
36            // Signed
37            if min >= i8::MIN as i64 && max <= i8::MAX as i64 {
38                return cast(
39                    self.as_ref(),
40                    &DType::Primitive(PType::I8, self.dtype().nullability()),
41                )?
42                .to_primitive();
43            }
44
45            if min >= i16::MIN as i64 && max <= i16::MAX as i64 {
46                return cast(
47                    self.as_ref(),
48                    &DType::Primitive(PType::I16, self.dtype().nullability()),
49                )?
50                .to_primitive();
51            }
52
53            if min >= i32::MIN as i64 && max <= i32::MAX as i64 {
54                return cast(
55                    self.as_ref(),
56                    &DType::Primitive(PType::I32, self.dtype().nullability()),
57                )?
58                .to_primitive();
59            }
60        } else {
61            // Unsigned
62            if max <= u8::MAX as i64 {
63                return cast(
64                    self.as_ref(),
65                    &DType::Primitive(PType::U8, self.dtype().nullability()),
66                )?
67                .to_primitive();
68            }
69
70            if max <= u16::MAX as i64 {
71                return cast(
72                    self.as_ref(),
73                    &DType::Primitive(PType::U16, self.dtype().nullability()),
74                )?
75                .to_primitive();
76            }
77
78            if max <= u32::MAX as i64 {
79                return cast(
80                    self.as_ref(),
81                    &DType::Primitive(PType::U32, self.dtype().nullability()),
82                )?
83                .to_primitive();
84            }
85        }
86
87        Ok(self.clone())
88    }
89}
90
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    /// An array with no valid values collapses to a nullable, all-invalid `u8` array.
    #[test]
    fn test_downcast_all_invalid() {
        let input = PrimitiveArray::new(
            buffer![0_u32, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            Validity::AllInvalid,
        );

        let narrowed = input.downcast().unwrap();

        let expected_dtype = DType::Primitive(PType::U8, Nullability::Nullable);
        assert_eq!(narrowed.dtype(), &expected_dtype);
        assert_eq!(narrowed.validity, Validity::AllInvalid);
    }

    /// Boundary values around each signed width select the expected type.
    #[rstest]
    #[case(vec![0_i64, 127], PType::U8)]
    #[case(vec![-128_i64, 127], PType::I8)]
    #[case(vec![-129_i64, 127], PType::I16)]
    #[case(vec![-128_i64, 128], PType::I16)]
    #[case(vec![-32768_i64, 32767], PType::I16)]
    #[case(vec![-32769_i64, 32767], PType::I32)]
    #[case(vec![-32768_i64, 32768], PType::I32)]
    #[case(vec![i32::MIN as i64, i32::MAX as i64], PType::I32)]
    fn test_downcast_signed(#[case] values: Vec<i64>, #[case] expected_ptype: PType) {
        let narrowed = PrimitiveArray::from_iter(values).downcast().unwrap();
        assert_eq!(narrowed.ptype(), expected_ptype);
    }

    /// Boundary values around each unsigned width select the expected type.
    #[rstest]
    #[case(vec![0_u64, 255], PType::U8)]
    #[case(vec![0_u64, 256], PType::U16)]
    #[case(vec![0_u64, 65535], PType::U16)]
    #[case(vec![0_u64, 65536], PType::U32)]
    #[case(vec![0_u64, u32::MAX as u64], PType::U32)]
    fn test_downcast_unsigned(#[case] values: Vec<u64>, #[case] expected_ptype: PType) {
        let narrowed = PrimitiveArray::from_iter(values).downcast().unwrap();
        assert_eq!(narrowed.ptype(), expected_ptype);
    }

    /// Values beyond the 32-bit range leave the array at its original type.
    #[test]
    fn test_downcast_keeps_original_if_too_large() {
        let narrowed = PrimitiveArray::from_iter(vec![0_u64, u64::MAX])
            .downcast()
            .unwrap();
        assert_eq!(narrowed.ptype(), PType::U64);
    }

    /// A nullable input stays nullable and keeps its validity array after narrowing.
    #[test]
    fn test_downcast_preserves_nullability() {
        let input = PrimitiveArray::from_option_iter([Some(0_i32), None, Some(127)]);
        let narrowed = input.downcast().unwrap();

        let expected_dtype = DType::Primitive(PType::U8, Nullability::Nullable);
        assert_eq!(narrowed.dtype(), &expected_dtype);
        // The validity must still be backed by an array, not collapsed to a constant.
        assert!(matches!(&narrowed.validity, Validity::Array(_)));
    }

    /// Narrowing changes the element type but not the element values.
    #[test]
    fn test_downcast_preserves_values() {
        let input = PrimitiveArray::from_iter(vec![-100_i16, 0, 100]);
        let narrowed = input.downcast().unwrap();

        assert_eq!(narrowed.ptype(), PType::I8);
        assert_eq!(
            narrowed.as_slice::<i8>().to_vec(),
            vec![-100_i8, 0, 100]
        );
    }

    /// A single negative value forces a signed target wide enough for the maximum.
    #[test]
    fn test_downcast_with_mixed_signs_chooses_signed() {
        let narrowed = PrimitiveArray::from_iter(vec![-1_i32, 200])
            .downcast()
            .unwrap();
        assert_eq!(narrowed.ptype(), PType::I16);
    }

    /// Floating-point arrays are not integers, so they pass through untouched.
    #[test]
    fn test_downcast_floats() {
        let narrowed = PrimitiveArray::from_iter(vec![1.0_f32, 2.0, 3.0])
            .downcast()
            .unwrap();
        assert_eq!(narrowed.ptype(), PType::F32);
    }

    /// An empty array has no min/max, so it takes the all-invalid path.
    #[test]
    fn test_downcast_empty_array() {
        let narrowed = PrimitiveArray::from_iter(Vec::<i32>::new())
            .downcast()
            .unwrap();
        assert_eq!(narrowed.validity, Validity::AllInvalid);
    }
}