vortex_array/arrays/primitive/array/
cast.rs1use vortex_buffer::Buffer;
5use vortex_error::VortexExpect;
6use vortex_error::VortexResult;
7use vortex_error::vortex_panic;
8
9use crate::IntoArray;
10use crate::ToCanonical;
11use crate::arrays::PrimitiveArray;
12use crate::builtins::ArrayBuiltins;
13use crate::compute::min_max;
14use crate::dtype::DType;
15use crate::dtype::NativePType;
16use crate::dtype::PType;
17use crate::vtable::ValidityHelper;
18
19impl PrimitiveArray {
20 pub fn as_slice<T: NativePType>(&self) -> &[T] {
28 if T::PTYPE != self.ptype() {
29 vortex_panic!(
30 "Attempted to get slice of type {} from array of type {}",
31 T::PTYPE,
32 self.ptype()
33 )
34 }
35
36 let byte_buffer = self
37 .buffer
38 .as_host_opt()
39 .vortex_expect("as_slice must be called on host buffer");
40 let raw_slice = byte_buffer.as_ptr();
41
42 unsafe { std::slice::from_raw_parts(raw_slice.cast(), byte_buffer.len() / size_of::<T>()) }
44 }
45
46 pub fn reinterpret_cast(&self, ptype: PType) -> Self {
47 if self.ptype() == ptype {
48 return self.clone();
49 }
50
51 assert_eq!(
52 self.ptype().byte_width(),
53 ptype.byte_width(),
54 "can't reinterpret cast between integers of two different widths"
55 );
56
57 PrimitiveArray::from_buffer_handle(
58 self.buffer_handle().clone(),
59 ptype,
60 self.validity().clone(),
61 )
62 }
63
64 pub fn narrow(&self) -> VortexResult<PrimitiveArray> {
66 if !self.ptype().is_int() {
67 return Ok(self.clone());
68 }
69
70 let Some(min_max) = min_max(&self.clone().into_array())? else {
71 return Ok(PrimitiveArray::new(
72 Buffer::<u8>::zeroed(self.len()),
73 self.validity.clone(),
74 ));
75 };
76
77 let Ok(min) = min_max
80 .min
81 .cast(&PType::I64.into())
82 .and_then(|s| i64::try_from(&s))
83 else {
84 return Ok(self.clone());
85 };
86 let Ok(max) = min_max
87 .max
88 .cast(&PType::I64.into())
89 .and_then(|s| i64::try_from(&s))
90 else {
91 return Ok(self.clone());
92 };
93
94 if min < 0 || max < 0 {
95 if min >= i8::MIN as i64 && max <= i8::MAX as i64 {
97 return Ok(self
98 .clone()
99 .into_array()
100 .cast(DType::Primitive(PType::I8, self.dtype().nullability()))?
101 .to_primitive());
102 }
103
104 if min >= i16::MIN as i64 && max <= i16::MAX as i64 {
105 return Ok(self
106 .clone()
107 .into_array()
108 .cast(DType::Primitive(PType::I16, self.dtype().nullability()))?
109 .to_primitive());
110 }
111
112 if min >= i32::MIN as i64 && max <= i32::MAX as i64 {
113 return Ok(self
114 .clone()
115 .into_array()
116 .cast(DType::Primitive(PType::I32, self.dtype().nullability()))?
117 .to_primitive());
118 }
119 } else {
120 if max <= u8::MAX as i64 {
122 return Ok(self
123 .clone()
124 .into_array()
125 .cast(DType::Primitive(PType::U8, self.dtype().nullability()))?
126 .to_primitive());
127 }
128
129 if max <= u16::MAX as i64 {
130 return Ok(self
131 .clone()
132 .into_array()
133 .cast(DType::Primitive(PType::U16, self.dtype().nullability()))?
134 .to_primitive());
135 }
136
137 if max <= u32::MAX as i64 {
138 return Ok(self
139 .clone()
140 .into_array()
141 .cast(DType::Primitive(PType::U32, self.dtype().nullability()))?
142 .to_primitive());
143 }
144 }
145
146 Ok(self.clone())
147 }
148}
149
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;

    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    #[test]
    fn test_downcast_all_invalid() {
        let array = PrimitiveArray::new(
            buffer![0_u32, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            Validity::AllInvalid,
        );

        let result = array.narrow().unwrap();
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::U8, Nullability::Nullable)
        );
        assert_eq!(result.validity, Validity::AllInvalid);
    }

    #[rstest]
    #[case(vec![0_i64, 127], PType::U8)]
    #[case(vec![0_i64, 300], PType::U16)]
    #[case(vec![-128_i64, 127], PType::I8)]
    #[case(vec![-129_i64, 127], PType::I16)]
    #[case(vec![-128_i64, 128], PType::I16)]
    #[case(vec![-32768_i64, 32767], PType::I16)]
    #[case(vec![-32769_i64, 32767], PType::I32)]
    #[case(vec![-32768_i64, 32768], PType::I32)]
    #[case(vec![i32::MIN as i64, i32::MAX as i64], PType::I32)]
    fn test_downcast_signed(#[case] values: Vec<i64>, #[case] expected_ptype: PType) {
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), expected_ptype);
    }

    #[rstest]
    #[case(vec![0_u64, 255], PType::U8)]
    #[case(vec![0_u64, 256], PType::U16)]
    #[case(vec![0_u64, 65535], PType::U16)]
    #[case(vec![0_u64, 65536], PType::U32)]
    #[case(vec![0_u64, u32::MAX as u64], PType::U32)]
    fn test_downcast_unsigned(#[case] values: Vec<u64>, #[case] expected_ptype: PType) {
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), expected_ptype);
    }

    #[test]
    fn test_downcast_keeps_original_if_too_large() {
        let array = PrimitiveArray::from_iter(vec![0_u64, u64::MAX]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::U64);
    }

    #[test]
    fn test_downcast_preserves_nullability() {
        let array = PrimitiveArray::from_option_iter([Some(0_i32), None, Some(127)]);
        let result = array.narrow().unwrap();
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::U8, Nullability::Nullable)
        );
        assert!(matches!(&result.validity, Validity::Array(_)));
    }

    #[test]
    fn test_downcast_preserves_values() {
        let values = vec![-100_i16, 0, 100];
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow().unwrap();

        assert_eq!(result.ptype(), PType::I8);
        let downscaled_values: Vec<i8> = result.as_slice::<i8>().to_vec();
        assert_eq!(downscaled_values, vec![-100_i8, 0, 100]);
    }

    #[test]
    fn test_downcast_with_mixed_signs_chooses_signed() {
        let array = PrimitiveArray::from_iter(vec![-1_i32, 200]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::I16);
    }

    #[test]
    fn test_downcast_floats() {
        let array = PrimitiveArray::from_iter(vec![1.0_f32, 2.0, 3.0]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::F32);
    }

    #[test]
    fn test_downcast_empty_array() {
        let array = PrimitiveArray::new(Buffer::<i32>::empty(), Validity::AllInvalid);
        let result = array.narrow().unwrap();
        assert_eq!(result.validity, Validity::AllInvalid);

        let array2 = PrimitiveArray::new(Buffer::<i64>::empty(), Validity::NonNullable);
        let result2 = array2.narrow().unwrap();
        assert_eq!(result2.validity, Validity::NonNullable);
    }

    #[test]
    fn test_narrow_already_narrowest_is_unchanged() {
        let array = PrimitiveArray::from_iter(vec![-1_i8, 5]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::I8);
        assert_eq!(result.as_slice::<i8>(), &[-1_i8, 5]);
    }

    #[test]
    fn test_reinterpret_cast_same_width() {
        let array = PrimitiveArray::from_iter(vec![-1_i32, 0, 1]);
        let result = array.reinterpret_cast(PType::U32);
        assert_eq!(result.ptype(), PType::U32);
        assert_eq!(result.as_slice::<u32>(), &[u32::MAX, 0, 1]);
    }

    #[test]
    #[should_panic(expected = "can't reinterpret cast between integers of two different widths")]
    fn test_reinterpret_cast_different_width_panics() {
        let array = PrimitiveArray::from_iter(vec![1_i32]);
        let _ = array.reinterpret_cast(PType::I64);
    }

    #[test]
    #[should_panic]
    fn test_as_slice_wrong_type_panics() {
        let array = PrimitiveArray::from_iter(vec![1_i32]);
        let _ = array.as_slice::<u32>();
    }
}