vortex_array/arrays/primitive/array/
cast.rs1use vortex_buffer::Buffer;
5use vortex_dtype::DType;
6use vortex_dtype::NativePType;
7use vortex_dtype::PType;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_panic;
11
12use crate::ToCanonical;
13use crate::arrays::PrimitiveArray;
14use crate::builtins::ArrayBuiltins;
15use crate::compute::min_max;
16use crate::vtable::ValidityHelper;
17
18impl PrimitiveArray {
19 pub fn as_slice<T: NativePType>(&self) -> &[T] {
27 if T::PTYPE != self.ptype() {
28 vortex_panic!(
29 "Attempted to get slice of type {} from array of type {}",
30 T::PTYPE,
31 self.ptype()
32 )
33 }
34
35 let byte_buffer = self
36 .buffer
37 .as_host_opt()
38 .vortex_expect("as_slice must be called on host buffer");
39 let raw_slice = byte_buffer.as_ptr();
40
41 unsafe { std::slice::from_raw_parts(raw_slice.cast(), byte_buffer.len() / size_of::<T>()) }
43 }
44
45 pub fn reinterpret_cast(&self, ptype: PType) -> Self {
46 if self.ptype() == ptype {
47 return self.clone();
48 }
49
50 assert_eq!(
51 self.ptype().byte_width(),
52 ptype.byte_width(),
53 "can't reinterpret cast between integers of two different widths"
54 );
55
56 PrimitiveArray::from_buffer_handle(
57 self.buffer_handle().clone(),
58 ptype,
59 self.validity().clone(),
60 )
61 }
62
63 pub fn narrow(&self) -> VortexResult<PrimitiveArray> {
65 if !self.ptype().is_int() {
66 return Ok(self.clone());
67 }
68
69 let Some(min_max) = min_max(self.as_ref())? else {
70 return Ok(PrimitiveArray::new(
71 Buffer::<u8>::zeroed(self.len()),
72 self.validity.clone(),
73 ));
74 };
75
76 let Ok(min) = min_max
79 .min
80 .cast(&PType::I64.into())
81 .and_then(|s| i64::try_from(&s))
82 else {
83 return Ok(self.clone());
84 };
85 let Ok(max) = min_max
86 .max
87 .cast(&PType::I64.into())
88 .and_then(|s| i64::try_from(&s))
89 else {
90 return Ok(self.clone());
91 };
92
93 if min < 0 || max < 0 {
94 if min >= i8::MIN as i64 && max <= i8::MAX as i64 {
96 return Ok(self
97 .to_array()
98 .cast(DType::Primitive(PType::I8, self.dtype().nullability()))?
99 .to_primitive());
100 }
101
102 if min >= i16::MIN as i64 && max <= i16::MAX as i64 {
103 return Ok(self
104 .to_array()
105 .cast(DType::Primitive(PType::I16, self.dtype().nullability()))?
106 .to_primitive());
107 }
108
109 if min >= i32::MIN as i64 && max <= i32::MAX as i64 {
110 return Ok(self
111 .to_array()
112 .cast(DType::Primitive(PType::I32, self.dtype().nullability()))?
113 .to_primitive());
114 }
115 } else {
116 if max <= u8::MAX as i64 {
118 return Ok(self
119 .to_array()
120 .cast(DType::Primitive(PType::U8, self.dtype().nullability()))?
121 .to_primitive());
122 }
123
124 if max <= u16::MAX as i64 {
125 return Ok(self
126 .to_array()
127 .cast(DType::Primitive(PType::U16, self.dtype().nullability()))?
128 .to_primitive());
129 }
130
131 if max <= u32::MAX as i64 {
132 return Ok(self
133 .to_array()
134 .cast(DType::Primitive(PType::U32, self.dtype().nullability()))?
135 .to_primitive());
136 }
137 }
138
139 Ok(self.clone())
140 }
141}
142
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    #[test]
    fn test_downcast_all_invalid() {
        let array = PrimitiveArray::new(
            buffer![0_u32, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            Validity::AllInvalid,
        );

        let result = array.narrow().unwrap();
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::U8, Nullability::Nullable)
        );
        assert_eq!(result.validity, Validity::AllInvalid);
    }

    #[rstest]
    #[case(vec![0_i64, 127], PType::U8)]
    #[case(vec![-128_i64, 127], PType::I8)]
    #[case(vec![-129_i64, 127], PType::I16)]
    #[case(vec![-128_i64, 128], PType::I16)]
    #[case(vec![-32768_i64, 32767], PType::I16)]
    #[case(vec![-32769_i64, 32767], PType::I32)]
    #[case(vec![-32768_i64, 32768], PType::I32)]
    #[case(vec![i32::MIN as i64, i32::MAX as i64], PType::I32)]
    fn test_downcast_signed(#[case] values: Vec<i64>, #[case] expected_ptype: PType) {
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), expected_ptype);
    }

    #[rstest]
    #[case(vec![0_u64, 255], PType::U8)]
    #[case(vec![0_u64, 256], PType::U16)]
    #[case(vec![0_u64, 65535], PType::U16)]
    #[case(vec![0_u64, 65536], PType::U32)]
    #[case(vec![0_u64, u32::MAX as u64], PType::U32)]
    fn test_downcast_unsigned(#[case] values: Vec<u64>, #[case] expected_ptype: PType) {
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), expected_ptype);
    }

    // Boundary ranges: all-negative input, a single element, and ranges just outside the
    // 32-bit types, which must keep the original 64-bit type.
    #[rstest]
    #[case(vec![-5_i64, -1], PType::I8)]
    #[case(vec![42_i64], PType::U8)]
    #[case(vec![i64::MIN, 0], PType::I64)]
    #[case(vec![i64::from(i32::MIN) - 1, 0], PType::I64)]
    #[case(vec![0_i64, i64::from(u32::MAX) + 1], PType::I64)]
    fn test_downcast_edge_ranges(#[case] values: Vec<i64>, #[case] expected_ptype: PType) {
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), expected_ptype);
    }

    #[test]
    fn test_downcast_keeps_original_if_too_large() {
        let array = PrimitiveArray::from_iter(vec![0_u64, u64::MAX]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::U64);
    }

    #[test]
    fn test_downcast_already_narrow_unsigned() {
        let array = PrimitiveArray::from_iter(vec![0_u8, 255]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::U8);
        let values: Vec<u8> = result.as_slice::<u8>().to_vec();
        assert_eq!(values, vec![0_u8, 255]);
    }

    #[test]
    fn test_downcast_already_narrow_signed() {
        let array = PrimitiveArray::from_iter(vec![-128_i8, 127]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::I8);
        let values: Vec<i8> = result.as_slice::<i8>().to_vec();
        assert_eq!(values, vec![-128_i8, 127]);
    }

    #[test]
    fn test_downcast_preserves_nullability() {
        let array = PrimitiveArray::from_option_iter([Some(0_i32), None, Some(127)]);
        let result = array.narrow().unwrap();
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::U8, Nullability::Nullable)
        );
        assert!(matches!(&result.validity, Validity::Array(_)));
    }

    #[test]
    fn test_downcast_preserves_values() {
        let values = vec![-100_i16, 0, 100];
        let array = PrimitiveArray::from_iter(values);
        let result = array.narrow().unwrap();

        assert_eq!(result.ptype(), PType::I8);
        let downscaled_values: Vec<i8> = result.as_slice::<i8>().to_vec();
        assert_eq!(downscaled_values, vec![-100_i8, 0, 100]);
    }

    #[test]
    fn test_downcast_preserves_unsigned_values() {
        let array = PrimitiveArray::from_iter(vec![1_u32, 200, 255]);
        let result = array.narrow().unwrap();

        assert_eq!(result.ptype(), PType::U8);
        let downscaled_values: Vec<u8> = result.as_slice::<u8>().to_vec();
        assert_eq!(downscaled_values, vec![1_u8, 200, 255]);
    }

    #[test]
    fn test_downcast_with_mixed_signs_chooses_signed() {
        let array = PrimitiveArray::from_iter(vec![-1_i32, 200]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::I16);
    }

    #[test]
    fn test_downcast_floats() {
        let array = PrimitiveArray::from_iter(vec![1.0_f32, 2.0, 3.0]);
        let result = array.narrow().unwrap();
        assert_eq!(result.ptype(), PType::F32);
    }

    #[test]
    fn test_downcast_empty_array() {
        let array = PrimitiveArray::new(Buffer::<i32>::empty(), Validity::AllInvalid);
        let result = array.narrow().unwrap();
        let array2 = PrimitiveArray::new(Buffer::<i64>::empty(), Validity::NonNullable);
        let result2 = array2.narrow().unwrap();
        assert_eq!(result.validity, Validity::AllInvalid);
        assert_eq!(result2.validity, Validity::NonNullable);
    }
}