vortex_array/arrays/primitive/array/
cast.rs1use vortex_buffer::Buffer;
5use vortex_dtype::{DType, NativePType, PType};
6use vortex_error::{VortexResult, vortex_panic};
7
8use crate::ToCanonical;
9use crate::arrays::PrimitiveArray;
10use crate::compute::{cast, min_max};
11use crate::vtable::ValidityHelper;
12
13impl PrimitiveArray {
14 pub fn as_slice<T: NativePType>(&self) -> &[T] {
18 if T::PTYPE != self.ptype() {
19 vortex_panic!(
20 "Attempted to get slice of type {} from array of type {}",
21 T::PTYPE,
22 self.ptype()
23 )
24 }
25 let raw_slice = self.byte_buffer().as_ptr();
26 unsafe {
28 std::slice::from_raw_parts(raw_slice.cast(), self.byte_buffer().len() / size_of::<T>())
29 }
30 }
31
32 pub fn reinterpret_cast(&self, ptype: PType) -> Self {
33 if self.ptype() == ptype {
34 return self.clone();
35 }
36
37 assert_eq!(
38 self.ptype().byte_width(),
39 ptype.byte_width(),
40 "can't reinterpret cast between integers of two different widths"
41 );
42
43 PrimitiveArray::from_byte_buffer(self.byte_buffer().clone(), ptype, self.validity().clone())
44 }
45
46 pub fn narrow(&self) -> VortexResult<PrimitiveArray> {
48 if !self.ptype().is_int() {
49 return Ok(self.clone());
50 }
51
52 let Some(min_max) = min_max(self.as_ref())? else {
53 return Ok(PrimitiveArray::new(
54 Buffer::<u8>::zeroed(self.len()),
55 self.validity.clone(),
56 ));
57 };
58
59 let Ok(min) = min_max.min.cast(&PType::I64.into()).and_then(i64::try_from) else {
62 return Ok(self.clone());
63 };
64 let Ok(max) = min_max.max.cast(&PType::I64.into()).and_then(i64::try_from) else {
65 return Ok(self.clone());
66 };
67
68 if min < 0 || max < 0 {
69 if min >= i8::MIN as i64 && max <= i8::MAX as i64 {
71 return Ok(cast(
72 self.as_ref(),
73 &DType::Primitive(PType::I8, self.dtype().nullability()),
74 )?
75 .to_primitive());
76 }
77
78 if min >= i16::MIN as i64 && max <= i16::MAX as i64 {
79 return Ok(cast(
80 self.as_ref(),
81 &DType::Primitive(PType::I16, self.dtype().nullability()),
82 )?
83 .to_primitive());
84 }
85
86 if min >= i32::MIN as i64 && max <= i32::MAX as i64 {
87 return Ok(cast(
88 self.as_ref(),
89 &DType::Primitive(PType::I32, self.dtype().nullability()),
90 )?
91 .to_primitive());
92 }
93 } else {
94 if max <= u8::MAX as i64 {
96 return Ok(cast(
97 self.as_ref(),
98 &DType::Primitive(PType::U8, self.dtype().nullability()),
99 )?
100 .to_primitive());
101 }
102
103 if max <= u16::MAX as i64 {
104 return Ok(cast(
105 self.as_ref(),
106 &DType::Primitive(PType::U16, self.dtype().nullability()),
107 )?
108 .to_primitive());
109 }
110
111 if max <= u32::MAX as i64 {
112 return Ok(cast(
113 self.as_ref(),
114 &DType::Primitive(PType::U32, self.dtype().nullability()),
115 )?
116 .to_primitive());
117 }
118 }
119
120 Ok(self.clone())
121 }
122}
123
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::{DType, Nullability, PType};

    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    #[test]
    fn test_downcast_all_invalid() {
        let array = PrimitiveArray::new(
            buffer![0_u32, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            Validity::AllInvalid,
        );

        let result = array.narrow().unwrap();
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::U8, Nullability::Nullable)
        );
        assert_eq!(result.validity, Validity::AllInvalid);
    }

    #[rstest]
    #[case(vec![0_i64, 127], PType::U8)]
    #[case(vec![-128_i64, 127], PType::I8)]
    #[case(vec![-1_i64, -1], PType::I8)]
    #[case(vec![-129_i64, 127], PType::I16)]
    #[case(vec![-128_i64, 128], PType::I16)]
    #[case(vec![-32768_i64, 32767], PType::I16)]
    #[case(vec![-32769_i64, 32767], PType::I32)]
    #[case(vec![-32768_i64, 32768], PType::I32)]
    #[case(vec![i32::MIN as i64, i32::MAX as i64], PType::I32)]
    // Non-negative values beyond i32::MAX still fit in u32.
    #[case(vec![0_i64, i32::MAX as i64 + 1], PType::U32)]
    // Nothing narrower than 64 bits fits: the original type is kept.
    #[case(vec![i32::MIN as i64 - 1, 0], PType::I64)]
    #[case(vec![-1_i64, i32::MAX as i64 + 1], PType::I64)]
    fn test_downcast_signed(#[case] values: Vec<i64>, #[case] expected_ptype: PType) {
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), expected_ptype);
    }

    #[rstest]
    #[case(vec![0_u64, 255], PType::U8)]
    #[case(vec![0_u64, 256], PType::U16)]
    #[case(vec![0_u64, 65535], PType::U16)]
    #[case(vec![0_u64, 65536], PType::U32)]
    #[case(vec![0_u64, u32::MAX as u64], PType::U32)]
    // Fits in i64 but not in u32: the original type is kept.
    #[case(vec![0_u64, u32::MAX as u64 + 1], PType::U64)]
    fn test_downcast_unsigned(#[case] values: Vec<u64>, #[case] expected_ptype: PType) {
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), expected_ptype);
    }

    #[test]
    fn test_downcast_keeps_original_if_too_large() {
        let array = PrimitiveArray::from_iter(vec![0_u64, u64::MAX]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::U64);
    }

    #[test]
    fn test_downcast_already_narrow_is_unchanged() {
        let unsigned = PrimitiveArray::from_iter(vec![0_u8, 255]);
        let result = unsigned.narrow().unwrap();
        assert_eq!(result.ptype(), PType::U8);
        assert_eq!(result.as_slice::<u8>().to_vec(), vec![0_u8, 255]);

        let signed = PrimitiveArray::from_iter(vec![-128_i8, 127]);
        let result = signed.narrow().unwrap();
        assert_eq!(result.ptype(), PType::I8);
        assert_eq!(result.as_slice::<i8>().to_vec(), vec![-128_i8, 127]);
    }

    #[test]
    fn test_downcast_preserves_nullability() {
        let array = PrimitiveArray::from_option_iter([Some(0_i32), None, Some(127)]);
        let result = array.narrow().unwrap();
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::U8, Nullability::Nullable)
        );
        assert!(matches!(&result.validity, Validity::Array(_)));
    }

    #[test]
    fn test_downcast_preserves_values() {
        let values = vec![-100_i16, 0, 100];
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow().unwrap();

        assert_eq!(result.ptype(), PType::I8);
        let downscaled_values: Vec<i8> = result.as_slice::<i8>().to_vec();
        assert_eq!(downscaled_values, vec![-100_i8, 0, 100]);
    }

    #[test]
    fn test_downcast_preserves_unsigned_values() {
        let array = PrimitiveArray::from_iter(vec![0_u64, 300, 65535]);
        let result = array.narrow().unwrap();

        assert_eq!(result.ptype(), PType::U16);
        let downscaled_values: Vec<u16> = result.as_slice::<u16>().to_vec();
        assert_eq!(downscaled_values, vec![0_u16, 300, 65535]);
    }

    #[test]
    fn test_downcast_with_mixed_signs_chooses_signed() {
        let array = PrimitiveArray::from_iter(vec![-1_i32, 200]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::I16);
    }

    #[test]
    fn test_downcast_floats() {
        let array = PrimitiveArray::from_iter(vec![1.0_f32, 2.0, 3.0]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::F32);
    }

    #[test]
    fn test_downcast_empty_array() {
        let array = PrimitiveArray::new(Buffer::<i32>::empty(), Validity::AllInvalid);
        let result = array.narrow().unwrap();
        let array2 = PrimitiveArray::new(Buffer::<i64>::empty(), Validity::NonNullable);
        let result2 = array2.narrow().unwrap();
        assert_eq!(result.validity, Validity::AllInvalid);
        assert_eq!(result2.validity, Validity::NonNullable);
    }
}