vortex_array/arrays/primitive/array/
cast.rs1use vortex_buffer::Buffer;
5use vortex_dtype::DType;
6use vortex_dtype::NativePType;
7use vortex_dtype::PType;
8use vortex_error::VortexResult;
9use vortex_error::vortex_panic;
10
11use crate::ToCanonical;
12use crate::arrays::PrimitiveArray;
13use crate::compute::cast;
14use crate::compute::min_max;
15use crate::vtable::ValidityHelper;
16
17impl PrimitiveArray {
18 pub fn as_slice<T: NativePType>(&self) -> &[T] {
22 if T::PTYPE != self.ptype() {
23 vortex_panic!(
24 "Attempted to get slice of type {} from array of type {}",
25 T::PTYPE,
26 self.ptype()
27 )
28 }
29 let raw_slice = self.byte_buffer().as_ptr();
30 unsafe {
32 std::slice::from_raw_parts(raw_slice.cast(), self.byte_buffer().len() / size_of::<T>())
33 }
34 }
35
36 pub fn reinterpret_cast(&self, ptype: PType) -> Self {
37 if self.ptype() == ptype {
38 return self.clone();
39 }
40
41 assert_eq!(
42 self.ptype().byte_width(),
43 ptype.byte_width(),
44 "can't reinterpret cast between integers of two different widths"
45 );
46
47 PrimitiveArray::from_byte_buffer(self.byte_buffer().clone(), ptype, self.validity().clone())
48 }
49
50 pub fn narrow(&self) -> VortexResult<PrimitiveArray> {
52 if !self.ptype().is_int() {
53 return Ok(self.clone());
54 }
55
56 let Some(min_max) = min_max(self.as_ref())? else {
57 return Ok(PrimitiveArray::new(
58 Buffer::<u8>::zeroed(self.len()),
59 self.validity.clone(),
60 ));
61 };
62
63 let Ok(min) = min_max.min.cast(&PType::I64.into()).and_then(i64::try_from) else {
66 return Ok(self.clone());
67 };
68 let Ok(max) = min_max.max.cast(&PType::I64.into()).and_then(i64::try_from) else {
69 return Ok(self.clone());
70 };
71
72 if min < 0 || max < 0 {
73 if min >= i8::MIN as i64 && max <= i8::MAX as i64 {
75 return Ok(cast(
76 self.as_ref(),
77 &DType::Primitive(PType::I8, self.dtype().nullability()),
78 )?
79 .to_primitive());
80 }
81
82 if min >= i16::MIN as i64 && max <= i16::MAX as i64 {
83 return Ok(cast(
84 self.as_ref(),
85 &DType::Primitive(PType::I16, self.dtype().nullability()),
86 )?
87 .to_primitive());
88 }
89
90 if min >= i32::MIN as i64 && max <= i32::MAX as i64 {
91 return Ok(cast(
92 self.as_ref(),
93 &DType::Primitive(PType::I32, self.dtype().nullability()),
94 )?
95 .to_primitive());
96 }
97 } else {
98 if max <= u8::MAX as i64 {
100 return Ok(cast(
101 self.as_ref(),
102 &DType::Primitive(PType::U8, self.dtype().nullability()),
103 )?
104 .to_primitive());
105 }
106
107 if max <= u16::MAX as i64 {
108 return Ok(cast(
109 self.as_ref(),
110 &DType::Primitive(PType::U16, self.dtype().nullability()),
111 )?
112 .to_primitive());
113 }
114
115 if max <= u32::MAX as i64 {
116 return Ok(cast(
117 self.as_ref(),
118 &DType::Primitive(PType::U32, self.dtype().nullability()),
119 )?
120 .to_primitive());
121 }
122 }
123
124 Ok(self.clone())
125 }
126}
127
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    /// An all-invalid array narrows to nullable `U8` and stays all-invalid.
    #[test]
    fn test_downcast_all_invalid() {
        let input = PrimitiveArray::new(
            buffer![0_u32, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            Validity::AllInvalid,
        );

        let narrowed = input.narrow().unwrap();

        let expected_dtype = DType::Primitive(PType::U8, Nullability::Nullable);
        assert_eq!(narrowed.dtype(), &expected_dtype);
        assert_eq!(narrowed.validity, Validity::AllInvalid);
    }

    /// Signed inputs pick the narrowest type covering both bounds.
    #[rstest]
    #[case(vec![0_i64, 127], PType::U8)]
    #[case(vec![-128_i64, 127], PType::I8)]
    #[case(vec![-129_i64, 127], PType::I16)]
    #[case(vec![-128_i64, 128], PType::I16)]
    #[case(vec![-32768_i64, 32767], PType::I16)]
    #[case(vec![-32769_i64, 32767], PType::I32)]
    #[case(vec![-32768_i64, 32768], PType::I32)]
    #[case(vec![i32::MIN as i64, i32::MAX as i64], PType::I32)]
    fn test_downcast_signed(#[case] values: Vec<i64>, #[case] expected_ptype: PType) {
        let narrowed = PrimitiveArray::from_iter(values).narrow().unwrap();
        assert_eq!(narrowed.ptype(), expected_ptype);
    }

    /// Unsigned inputs pick the narrowest unsigned type covering the maximum.
    #[rstest]
    #[case(vec![0_u64, 255], PType::U8)]
    #[case(vec![0_u64, 256], PType::U16)]
    #[case(vec![0_u64, 65535], PType::U16)]
    #[case(vec![0_u64, 65536], PType::U32)]
    #[case(vec![0_u64, u32::MAX as u64], PType::U32)]
    fn test_downcast_unsigned(#[case] values: Vec<u64>, #[case] expected_ptype: PType) {
        let narrowed = PrimitiveArray::from_iter(values).narrow().unwrap();
        assert_eq!(narrowed.ptype(), expected_ptype);
    }

    /// Values that need the full 64 bits leave the array's type untouched.
    #[test]
    fn test_downcast_keeps_original_if_too_large() {
        let narrowed = PrimitiveArray::from_iter(vec![0_u64, u64::MAX])
            .narrow()
            .unwrap();
        assert_eq!(narrowed.ptype(), PType::U64);
    }

    /// Narrowing keeps the nullable dtype and the validity array.
    #[test]
    fn test_downcast_preserves_nullability() {
        let input = PrimitiveArray::from_option_iter([Some(0_i32), None, Some(127)]);

        let narrowed = input.narrow().unwrap();

        let expected_dtype = DType::Primitive(PType::U8, Nullability::Nullable);
        assert_eq!(narrowed.dtype(), &expected_dtype);
        assert!(matches!(&narrowed.validity, Validity::Array(_)));
    }

    /// Narrowing changes the storage type but not the values.
    #[test]
    fn test_downcast_preserves_values() {
        let input = PrimitiveArray::from_iter(vec![-100_i16, 0, 100]);

        let narrowed = input.narrow().unwrap();

        assert_eq!(narrowed.ptype(), PType::I8);
        assert_eq!(
            narrowed.as_slice::<i8>(),
            [-100_i8, 0, 100].as_slice()
        );
    }

    /// A negative bound forces a signed target, even if the maximum would fit in `U8`.
    #[test]
    fn test_downcast_with_mixed_signs_chooses_signed() {
        let narrowed = PrimitiveArray::from_iter(vec![-1_i32, 200])
            .narrow()
            .unwrap();
        assert_eq!(narrowed.ptype(), PType::I16);
    }

    /// Floating-point arrays are never narrowed.
    #[test]
    fn test_downcast_floats() {
        let narrowed = PrimitiveArray::from_iter(vec![1.0_f32, 2.0, 3.0])
            .narrow()
            .unwrap();
        assert_eq!(narrowed.ptype(), PType::F32);
    }

    /// Empty arrays keep their validity, whether nullable or not.
    #[test]
    fn test_downcast_empty_array() {
        let nullable_empty = PrimitiveArray::new(Buffer::<i32>::empty(), Validity::AllInvalid);
        let non_nullable_empty =
            PrimitiveArray::new(Buffer::<i64>::empty(), Validity::NonNullable);

        let narrowed_nullable = nullable_empty.narrow().unwrap();
        let narrowed_non_nullable = non_nullable_empty.narrow().unwrap();

        assert_eq!(narrowed_nullable.validity, Validity::AllInvalid);
        assert_eq!(narrowed_non_nullable.validity, Validity::NonNullable);
    }
}