vortex_array/arrays/primitive/
downcast.rs1use vortex_buffer::Buffer;
5use vortex_dtype::{DType, PType};
6use vortex_error::VortexResult;
7
8use crate::ToCanonical;
9use crate::arrays::PrimitiveArray;
10use crate::compute::{cast, min_max};
11use crate::validity::Validity;
12
13impl PrimitiveArray {
14 pub fn downcast(&self) -> VortexResult<PrimitiveArray> {
15 if !self.ptype().is_int() {
16 return Ok(self.clone());
17 }
18
19 let Some(min_max) = min_max(self.as_ref())? else {
20 return Ok(PrimitiveArray::new(
21 Buffer::<u8>::zeroed(self.len()),
22 Validity::AllInvalid,
23 ));
24 };
25
26 let Ok(min) = min_max.min.cast(&PType::I64.into()).and_then(i64::try_from) else {
29 return Ok(self.clone());
30 };
31 let Ok(max) = min_max.max.cast(&PType::I64.into()).and_then(i64::try_from) else {
32 return Ok(self.clone());
33 };
34
35 if min < 0 || max < 0 {
36 if min >= i8::MIN as i64 && max <= i8::MAX as i64 {
38 return cast(
39 self.as_ref(),
40 &DType::Primitive(PType::I8, self.dtype().nullability()),
41 )?
42 .to_primitive();
43 }
44
45 if min >= i16::MIN as i64 && max <= i16::MAX as i64 {
46 return cast(
47 self.as_ref(),
48 &DType::Primitive(PType::I16, self.dtype().nullability()),
49 )?
50 .to_primitive();
51 }
52
53 if min >= i32::MIN as i64 && max <= i32::MAX as i64 {
54 return cast(
55 self.as_ref(),
56 &DType::Primitive(PType::I32, self.dtype().nullability()),
57 )?
58 .to_primitive();
59 }
60 } else {
61 if max <= u8::MAX as i64 {
63 return cast(
64 self.as_ref(),
65 &DType::Primitive(PType::U8, self.dtype().nullability()),
66 )?
67 .to_primitive();
68 }
69
70 if max <= u16::MAX as i64 {
71 return cast(
72 self.as_ref(),
73 &DType::Primitive(PType::U16, self.dtype().nullability()),
74 )?
75 .to_primitive();
76 }
77
78 if max <= u32::MAX as i64 {
79 return cast(
80 self.as_ref(),
81 &DType::Primitive(PType::U32, self.dtype().nullability()),
82 )?
83 .to_primitive();
84 }
85 }
86
87 Ok(self.clone())
88 }
89}
90
91#[cfg(test)]
92mod tests {
93 use rstest::rstest;
94 use vortex_buffer::buffer;
95 use vortex_dtype::{DType, Nullability, PType};
96
97 use crate::arrays::PrimitiveArray;
98 use crate::validity::Validity;
99
100 #[test]
101 fn test_downcast_all_invalid() {
102 let array = PrimitiveArray::new(
103 buffer![0_u32, 0, 0, 0, 0, 0, 0, 0, 0, 0],
104 Validity::AllInvalid,
105 );
106
107 let result = array.downcast().unwrap();
108 assert_eq!(
109 result.dtype(),
110 &DType::Primitive(PType::U8, Nullability::Nullable)
111 );
112 assert_eq!(result.validity, Validity::AllInvalid);
113 }
114
115 #[rstest]
116 #[case(vec![0_i64, 127], PType::U8)]
117 #[case(vec![-128_i64, 127], PType::I8)]
118 #[case(vec![-129_i64, 127], PType::I16)]
119 #[case(vec![-128_i64, 128], PType::I16)]
120 #[case(vec![-32768_i64, 32767], PType::I16)]
121 #[case(vec![-32769_i64, 32767], PType::I32)]
122 #[case(vec![-32768_i64, 32768], PType::I32)]
123 #[case(vec![i32::MIN as i64, i32::MAX as i64], PType::I32)]
124 fn test_downcast_signed(#[case] values: Vec<i64>, #[case] expected_ptype: PType) {
125 let array = PrimitiveArray::from_iter(values);
126 let result = array.downcast().unwrap();
127 assert_eq!(result.ptype(), expected_ptype);
128 }
129
130 #[rstest]
131 #[case(vec![0_u64, 255], PType::U8)]
132 #[case(vec![0_u64, 256], PType::U16)]
133 #[case(vec![0_u64, 65535], PType::U16)]
134 #[case(vec![0_u64, 65536], PType::U32)]
135 #[case(vec![0_u64, u32::MAX as u64], PType::U32)]
136 fn test_downcast_unsigned(#[case] values: Vec<u64>, #[case] expected_ptype: PType) {
137 let array = PrimitiveArray::from_iter(values);
138 let result = array.downcast().unwrap();
139 assert_eq!(result.ptype(), expected_ptype);
140 }
141
142 #[test]
143 fn test_downcast_keeps_original_if_too_large() {
144 let array = PrimitiveArray::from_iter(vec![0_u64, u64::MAX]);
145 let result = array.downcast().unwrap();
146 assert_eq!(result.ptype(), PType::U64);
147 }
148
149 #[test]
150 fn test_downcast_preserves_nullability() {
151 let array = PrimitiveArray::from_option_iter([Some(0_i32), None, Some(127)]);
152 let result = array.downcast().unwrap();
153 assert_eq!(
154 result.dtype(),
155 &DType::Primitive(PType::U8, Nullability::Nullable)
156 );
157 assert!(matches!(&result.validity, Validity::Array(_)));
159 }
160
161 #[test]
162 fn test_downcast_preserves_values() {
163 let values = vec![-100_i16, 0, 100];
164 let array = PrimitiveArray::from_iter(values);
165 let result = array.downcast().unwrap();
166
167 assert_eq!(result.ptype(), PType::I8);
168 let downscaled_values: Vec<i8> = result.as_slice::<i8>().to_vec();
170 assert_eq!(downscaled_values, vec![-100_i8, 0, 100]);
171 }
172
173 #[test]
174 fn test_downcast_with_mixed_signs_chooses_signed() {
175 let array = PrimitiveArray::from_iter(vec![-1_i32, 200]);
176 let result = array.downcast().unwrap();
177 assert_eq!(result.ptype(), PType::I16);
178 }
179
180 #[test]
181 fn test_downcast_floats() {
182 let array = PrimitiveArray::from_iter(vec![1.0_f32, 2.0, 3.0]);
183 let result = array.downcast().unwrap();
184 assert_eq!(result.ptype(), PType::F32);
186 }
187
188 #[test]
189 fn test_downcast_empty_array() {
190 let array = PrimitiveArray::from_iter(Vec::<i32>::new());
191 let result = array.downcast().unwrap();
192 assert_eq!(result.validity, Validity::AllInvalid);
194 }
195}