vortex_array/arrays/primitive/
downcast.rs1use vortex_buffer::Buffer;
5use vortex_dtype::{DType, PType};
6use vortex_error::VortexResult;
7
8use crate::ToCanonical;
9use crate::arrays::PrimitiveArray;
10use crate::compute::{cast, min_max};
11
12impl PrimitiveArray {
13 pub fn downcast(&self) -> VortexResult<PrimitiveArray> {
14 if !self.ptype().is_int() {
15 return Ok(self.clone());
16 }
17
18 let Some(min_max) = min_max(self.as_ref())? else {
19 return Ok(PrimitiveArray::new(
20 Buffer::<u8>::zeroed(self.len()),
21 self.validity.clone(),
22 ));
23 };
24
25 let Ok(min) = min_max.min.cast(&PType::I64.into()).and_then(i64::try_from) else {
28 return Ok(self.clone());
29 };
30 let Ok(max) = min_max.max.cast(&PType::I64.into()).and_then(i64::try_from) else {
31 return Ok(self.clone());
32 };
33
34 if min < 0 || max < 0 {
35 if min >= i8::MIN as i64 && max <= i8::MAX as i64 {
37 return Ok(cast(
38 self.as_ref(),
39 &DType::Primitive(PType::I8, self.dtype().nullability()),
40 )?
41 .to_primitive());
42 }
43
44 if min >= i16::MIN as i64 && max <= i16::MAX as i64 {
45 return Ok(cast(
46 self.as_ref(),
47 &DType::Primitive(PType::I16, self.dtype().nullability()),
48 )?
49 .to_primitive());
50 }
51
52 if min >= i32::MIN as i64 && max <= i32::MAX as i64 {
53 return Ok(cast(
54 self.as_ref(),
55 &DType::Primitive(PType::I32, self.dtype().nullability()),
56 )?
57 .to_primitive());
58 }
59 } else {
60 if max <= u8::MAX as i64 {
62 return Ok(cast(
63 self.as_ref(),
64 &DType::Primitive(PType::U8, self.dtype().nullability()),
65 )?
66 .to_primitive());
67 }
68
69 if max <= u16::MAX as i64 {
70 return Ok(cast(
71 self.as_ref(),
72 &DType::Primitive(PType::U16, self.dtype().nullability()),
73 )?
74 .to_primitive());
75 }
76
77 if max <= u32::MAX as i64 {
78 return Ok(cast(
79 self.as_ref(),
80 &DType::Primitive(PType::U32, self.dtype().nullability()),
81 )?
82 .to_primitive());
83 }
84 }
85
86 Ok(self.clone())
87 }
88}
89
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::{DType, Nullability, PType};

    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    #[test]
    fn test_downcast_all_invalid() {
        let array = PrimitiveArray::new(
            buffer![0_u32, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            Validity::AllInvalid,
        );

        let result = array.downcast().unwrap();
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::U8, Nullability::Nullable)
        );
        assert_eq!(result.validity, Validity::AllInvalid);
    }

    #[rstest]
    #[case(vec![0_i64, 127], PType::U8)]
    #[case(vec![-128_i64, 127], PType::I8)]
    // Entirely negative input (max < 0) must still pick a signed type.
    #[case(vec![-5_i64, -1], PType::I8)]
    #[case(vec![-129_i64, 127], PType::I16)]
    #[case(vec![-128_i64, 128], PType::I16)]
    #[case(vec![-32768_i64, 32767], PType::I16)]
    #[case(vec![-32769_i64, 32767], PType::I32)]
    #[case(vec![-32768_i64, 32768], PType::I32)]
    #[case(vec![i32::MIN as i64, i32::MAX as i64], PType::I32)]
    // Ranges that need 64 bits keep the original type.
    #[case(vec![i32::MIN as i64 - 1, 0], PType::I64)]
    #[case(vec![0_i64, u32::MAX as i64 + 1], PType::I64)]
    fn test_downcast_signed(#[case] values: Vec<i64>, #[case] expected_ptype: PType) {
        let array = PrimitiveArray::from_iter(values);
        let result = array.downcast().unwrap();
        assert_eq!(result.ptype(), expected_ptype);
    }

    #[rstest]
    #[case(vec![0_u64, 255], PType::U8)]
    #[case(vec![0_u64, 256], PType::U16)]
    #[case(vec![0_u64, 65535], PType::U16)]
    #[case(vec![0_u64, 65536], PType::U32)]
    #[case(vec![0_u64, u32::MAX as u64], PType::U32)]
    // Fits in i64 but not in u32: no narrower type exists, so U64 is kept.
    #[case(vec![0_u64, u32::MAX as u64 + 1], PType::U64)]
    fn test_downcast_unsigned(#[case] values: Vec<u64>, #[case] expected_ptype: PType) {
        let array = PrimitiveArray::from_iter(values);
        let result = array.downcast().unwrap();
        assert_eq!(result.ptype(), expected_ptype);
    }

    #[test]
    fn test_downcast_keeps_original_if_too_large() {
        let array = PrimitiveArray::from_iter(vec![0_u64, u64::MAX]);
        let result = array.downcast().unwrap();
        assert_eq!(result.ptype(), PType::U64);
    }

    #[test]
    fn test_downcast_preserves_nullability() {
        let array = PrimitiveArray::from_option_iter([Some(0_i32), None, Some(127)]);
        let result = array.downcast().unwrap();
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::U8, Nullability::Nullable)
        );
        assert!(matches!(&result.validity, Validity::Array(_)));
    }

    #[test]
    fn test_downcast_preserves_values() {
        let values = vec![-100_i16, 0, 100];
        let array = PrimitiveArray::from_iter(values);
        let result = array.downcast().unwrap();

        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::I8, Nullability::NonNullable)
        );
        let downscaled_values: Vec<i8> = result.as_slice::<i8>().to_vec();
        assert_eq!(downscaled_values, vec![-100_i8, 0, 100]);
    }

    #[test]
    fn test_downcast_with_mixed_signs_chooses_signed() {
        let array = PrimitiveArray::from_iter(vec![-1_i32, 200]);
        let result = array.downcast().unwrap();
        assert_eq!(result.ptype(), PType::I16);
    }

    #[test]
    fn test_downcast_floats() {
        let array = PrimitiveArray::from_iter(vec![1.0_f32, 2.0, 3.0]);
        let result = array.downcast().unwrap();
        assert_eq!(result.ptype(), PType::F32);
    }

    #[test]
    fn test_downcast_empty_array() {
        // Empty arrays have no min/max, so they take the zero-filled u8 path and keep their
        // validity regardless of the original integer width.
        let nullable = PrimitiveArray::new(Buffer::<i32>::empty(), Validity::AllInvalid);
        let result = nullable.downcast().unwrap();
        assert_eq!(result.ptype(), PType::U8);
        assert_eq!(result.validity, Validity::AllInvalid);

        let non_nullable = PrimitiveArray::new(Buffer::<i64>::empty(), Validity::NonNullable);
        let result = non_nullable.downcast().unwrap();
        assert_eq!(result.ptype(), PType::U8);
        assert_eq!(result.validity, Validity::NonNullable);
    }
}