vortex_array/arrays/primitive/array/
cast.rs1use vortex_buffer::Buffer;
5use vortex_error::VortexExpect;
6use vortex_error::VortexResult;
7use vortex_error::vortex_panic;
8
9use crate::IntoArray;
10use crate::LEGACY_SESSION;
11use crate::ToCanonical;
12use crate::VortexSessionExecute;
13use crate::aggregate_fn::fns::min_max::min_max;
14use crate::arrays::PrimitiveArray;
15use crate::builtins::ArrayBuiltins;
16use crate::dtype::DType;
17use crate::dtype::NativePType;
18use crate::dtype::PType;
19use crate::vtable::ValidityHelper;
20
21impl PrimitiveArray {
22 pub fn as_slice<T: NativePType>(&self) -> &[T] {
30 if T::PTYPE != self.ptype() {
31 vortex_panic!(
32 "Attempted to get slice of type {} from array of type {}",
33 T::PTYPE,
34 self.ptype()
35 )
36 }
37
38 let byte_buffer = self
39 .buffer
40 .as_host_opt()
41 .vortex_expect("as_slice must be called on host buffer");
42 let raw_slice = byte_buffer.as_ptr();
43
44 unsafe { std::slice::from_raw_parts(raw_slice.cast(), byte_buffer.len() / size_of::<T>()) }
46 }
47
48 pub fn reinterpret_cast(&self, ptype: PType) -> Self {
49 if self.ptype() == ptype {
50 return self.clone();
51 }
52
53 assert_eq!(
54 self.ptype().byte_width(),
55 ptype.byte_width(),
56 "can't reinterpret cast between integers of two different widths"
57 );
58
59 PrimitiveArray::from_buffer_handle(
60 self.buffer_handle().clone(),
61 ptype,
62 self.validity().clone(),
63 )
64 }
65
66 pub fn narrow(&self) -> VortexResult<PrimitiveArray> {
68 if !self.ptype().is_int() {
69 return Ok(self.clone());
70 }
71
72 let mut ctx = LEGACY_SESSION.create_execution_ctx();
73 let Some(min_max) = min_max(&self.clone().into_array(), &mut ctx)? else {
74 return Ok(PrimitiveArray::new(
75 Buffer::<u8>::zeroed(self.len()),
76 self.validity.clone(),
77 ));
78 };
79
80 let Ok(min) = min_max
83 .min
84 .cast(&PType::I64.into())
85 .and_then(|s| i64::try_from(&s))
86 else {
87 return Ok(self.clone());
88 };
89 let Ok(max) = min_max
90 .max
91 .cast(&PType::I64.into())
92 .and_then(|s| i64::try_from(&s))
93 else {
94 return Ok(self.clone());
95 };
96
97 if min < 0 || max < 0 {
98 if min >= i8::MIN as i64 && max <= i8::MAX as i64 {
100 return Ok(self
101 .clone()
102 .into_array()
103 .cast(DType::Primitive(PType::I8, self.dtype().nullability()))?
104 .to_primitive());
105 }
106
107 if min >= i16::MIN as i64 && max <= i16::MAX as i64 {
108 return Ok(self
109 .clone()
110 .into_array()
111 .cast(DType::Primitive(PType::I16, self.dtype().nullability()))?
112 .to_primitive());
113 }
114
115 if min >= i32::MIN as i64 && max <= i32::MAX as i64 {
116 return Ok(self
117 .clone()
118 .into_array()
119 .cast(DType::Primitive(PType::I32, self.dtype().nullability()))?
120 .to_primitive());
121 }
122 } else {
123 if max <= u8::MAX as i64 {
125 return Ok(self
126 .clone()
127 .into_array()
128 .cast(DType::Primitive(PType::U8, self.dtype().nullability()))?
129 .to_primitive());
130 }
131
132 if max <= u16::MAX as i64 {
133 return Ok(self
134 .clone()
135 .into_array()
136 .cast(DType::Primitive(PType::U16, self.dtype().nullability()))?
137 .to_primitive());
138 }
139
140 if max <= u32::MAX as i64 {
141 return Ok(self
142 .clone()
143 .into_array()
144 .cast(DType::Primitive(PType::U32, self.dtype().nullability()))?
145 .to_primitive());
146 }
147 }
148
149 Ok(self.clone())
150 }
151}
152
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;

    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    #[test]
    fn test_downcast_all_invalid() {
        let array = PrimitiveArray::new(
            buffer![0_u32, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            Validity::AllInvalid,
        );

        let result = array.narrow().unwrap();
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::U8, Nullability::Nullable)
        );
        assert!(matches!(result.validity, Validity::AllInvalid));
    }

    #[rstest]
    #[case(vec![0_i64, 127], PType::U8)]
    #[case(vec![-128_i64, 127], PType::I8)]
    #[case(vec![-129_i64, 127], PType::I16)]
    #[case(vec![-128_i64, 128], PType::I16)]
    #[case(vec![-32768_i64, 32767], PType::I16)]
    #[case(vec![-32769_i64, 32767], PType::I32)]
    #[case(vec![-32768_i64, 32768], PType::I32)]
    #[case(vec![i32::MIN as i64, i32::MAX as i64], PType::I32)]
    // Entirely negative ranges still pick the narrowest signed type.
    #[case(vec![-5_i64, -1], PType::I8)]
    // Ranges wider than 32 bits keep the original type.
    #[case(vec![i64::from(i32::MIN) - 1, 0], PType::I64)]
    #[case(vec![0_i64, i64::from(u32::MAX) + 1], PType::I64)]
    fn test_downcast_signed(#[case] values: Vec<i64>, #[case] expected_ptype: PType) {
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), expected_ptype);
    }

    #[rstest]
    #[case(vec![0_u64, 255], PType::U8)]
    #[case(vec![0_u64, 256], PType::U16)]
    #[case(vec![0_u64, 65535], PType::U16)]
    #[case(vec![0_u64, 65536], PType::U32)]
    #[case(vec![0_u64, u32::MAX as u64], PType::U32)]
    // One past u32::MAX fits in i64 but no narrower unsigned type.
    #[case(vec![0_u64, u64::from(u32::MAX) + 1], PType::U64)]
    fn test_downcast_unsigned(#[case] values: Vec<u64>, #[case] expected_ptype: PType) {
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), expected_ptype);
    }

    #[test]
    fn test_downcast_keeps_original_if_too_large() {
        let array = PrimitiveArray::from_iter(vec![0_u64, u64::MAX]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::U64);
    }

    #[test]
    fn test_downcast_already_narrowest_is_unchanged() {
        let array = PrimitiveArray::from_iter(vec![1_u8, 2, 255]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::U8);
        assert_eq!(result.as_slice::<u8>().to_vec(), vec![1_u8, 2, 255]);
    }

    #[test]
    fn test_downcast_preserves_nullability() {
        let array = PrimitiveArray::from_option_iter([Some(0_i32), None, Some(127)]);
        let result = array.narrow().unwrap();
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::U8, Nullability::Nullable)
        );
        assert!(matches!(&result.validity, Validity::Array(_)));
    }

    #[test]
    fn test_downcast_preserves_values() {
        let values = vec![-100_i16, 0, 100];
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow().unwrap();

        assert_eq!(result.ptype(), PType::I8);
        let downscaled_values: Vec<i8> = result.as_slice::<i8>().to_vec();
        assert_eq!(downscaled_values, vec![-100_i8, 0, 100]);
    }

    #[test]
    fn test_downcast_with_mixed_signs_chooses_signed() {
        let array = PrimitiveArray::from_iter(vec![-1_i32, 200]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::I16);
    }

    #[test]
    fn test_downcast_floats() {
        let array = PrimitiveArray::from_iter(vec![1.0_f32, 2.0, 3.0]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::F32);
    }

    #[test]
    fn test_downcast_empty_array() {
        let array = PrimitiveArray::new(Buffer::<i32>::empty(), Validity::AllInvalid);
        let result = array.narrow().unwrap();
        let array2 = PrimitiveArray::new(Buffer::<i64>::empty(), Validity::NonNullable);
        let result2 = array2.narrow().unwrap();
        assert!(matches!(result.validity, Validity::AllInvalid));
        assert!(matches!(result2.validity, Validity::NonNullable));
    }
}