vortex_compute/cast/
pvector.rs1use num_traits::NumCast;
5use vortex_dtype::DType;
6use vortex_dtype::NativePType;
7use vortex_dtype::match_each_native_ptype;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_err;
11use vortex_vector::Scalar;
12use vortex_vector::ScalarOps;
13use vortex_vector::Vector;
14use vortex_vector::VectorOps;
15use vortex_vector::primitive::PScalar;
16use vortex_vector::primitive::PVector;
17use vortex_vector::primitive::PrimitiveScalar;
18use vortex_vector::primitive::PrimitiveVector;
19use vortex_vector::primitive::cast;
20
21use crate::cast::Cast;
22use crate::cast::try_cast_scalar_common;
23use crate::cast::try_cast_vector_common;
24
25impl<T: NativePType> Cast for PVector<T> {
26 type Output = Vector;
27
28 fn cast(&self, target_dtype: &DType) -> VortexResult<Vector> {
30 if let Some(result) = try_cast_vector_common(self, target_dtype)? {
31 return Ok(result);
32 }
33
34 match target_dtype {
35 DType::Primitive(target_ptype, n)
37 if *target_ptype == T::PTYPE && (n.is_nullable() || self.validity().all_true()) =>
38 {
39 Ok(self.clone().into())
40 }
41 DType::Primitive(target_ptype, n) if n.is_nullable() || self.validity().all_true() => {
43 match_each_native_ptype!(*target_ptype, |Dst| {
44 let result = cast::cast_pvector::<T, Dst>(self)?;
45 Ok(PrimitiveVector::from(result).into())
46 })
47 }
48 _ => {
49 vortex_bail!("Cannot cast PVector<{}> to {}", T::PTYPE, target_dtype);
50 }
51 }
52 }
53}
54
55impl<T: NativePType> Cast for PScalar<T> {
56 type Output = Scalar;
57
58 fn cast(&self, target_dtype: &DType) -> VortexResult<Scalar> {
60 if let Some(result) = try_cast_scalar_common(self, target_dtype)? {
61 return Ok(result);
62 }
63
64 match target_dtype {
65 DType::Primitive(target_ptype, n)
67 if *target_ptype == T::PTYPE && (n.is_nullable() || self.is_valid()) =>
68 {
69 Ok(self.clone().into())
70 }
71 DType::Primitive(target_ptype, n) if n.is_nullable() || self.is_valid() => {
73 match_each_native_ptype!(*target_ptype, |Dst| {
74 let result = match self.value() {
75 None => PScalar::null(),
76 Some(v) => {
77 let converted = <Dst as NumCast>::from(v).ok_or_else(|| {
78 vortex_err!(ComputeError: "Failed to cast {} to {:?}", v, Dst::PTYPE)
79 })?;
80 PScalar::new(Some(converted))
81 }
82 };
83 Ok(PrimitiveScalar::from(result).into())
84 })
85 }
86 _ => {
87 vortex_bail!("Cannot cast PScalar<{}> to {}", T::PTYPE, target_dtype);
88 }
89 }
90 }
91}
92
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_dtype::PTypeDowncast;
    use vortex_error::VortexError;
    use vortex_mask::Mask;
    use vortex_vector::ScalarOps;
    use vortex_vector::VectorOps;
    use vortex_vector::primitive::PScalar;
    use vortex_vector::primitive::PVector;

    use crate::cast::Cast;

    /// A small in-range `u32` vector can be cast to every primitive type and stays fully valid.
    #[rstest]
    #[case(PType::U8)]
    #[case(PType::U16)]
    #[case(PType::U32)]
    #[case(PType::U64)]
    #[case(PType::I8)]
    #[case(PType::I16)]
    #[case(PType::I32)]
    #[case(PType::I64)]
    #[case(PType::F32)]
    #[case(PType::F64)]
    fn cast_u32_to_ptype(#[case] target: PType) {
        let vec: PVector<u32> = buffer![0u32, 10, 100].into();
        let result = vec.cast(&target.into()).unwrap();
        assert!(result.as_primitive().validity().all_true());
        assert_eq!(result.len(), 3);
    }

    /// Widening casts to `f64` must preserve every value exactly, not merely succeed.
    #[test]
    fn cast_various_types_to_f64() {
        let u8_vec: PVector<u8> = buffer![0u8, 1, 2, 3, 255].into();
        let p = u8_vec.cast(&PType::F64.into()).unwrap().into_primitive().into_f64();
        assert_eq!(p.as_ref(), &[0.0f64, 1.0, 2.0, 3.0, 255.0]);

        let u16_vec: PVector<u16> = buffer![0u16, 100, 1000].into();
        let p = u16_vec.cast(&PType::F64.into()).unwrap().into_primitive().into_f64();
        assert_eq!(p.as_ref(), &[0.0f64, 100.0, 1000.0]);

        let u32_vec: PVector<u32> = buffer![0u32, 100, 1000, 1000000].into();
        let p = u32_vec.cast(&PType::F64.into()).unwrap().into_primitive().into_f64();
        assert_eq!(p.as_ref(), &[0.0f64, 100.0, 1000.0, 1000000.0]);

        let i8_vec: PVector<i8> = buffer![0i8, -1, 1, 127].into();
        let p = i8_vec.cast(&PType::F64.into()).unwrap().into_primitive().into_f64();
        assert_eq!(p.as_ref(), &[0.0f64, -1.0, 1.0, 127.0]);

        let i32_vec: PVector<i32> = buffer![-1000000i32, -1, 0, 1, 1000000].into();
        let p = i32_vec.cast(&PType::F64.into()).unwrap().into_primitive().into_f64();
        assert_eq!(p.as_ref(), &[-1000000.0f64, -1.0, 0.0, 1.0, 1000000.0]);

        let f32_vec: PVector<f32> = buffer![0.0f32, 1.5, -2.5, 100.0].into();
        let p = f32_vec.cast(&PType::F64.into()).unwrap().into_primitive().into_f64();
        assert_eq!(p.as_ref(), &[0.0f64, 1.5, -2.5, 100.0]);
    }

    #[test]
    fn cast_u32_u8() {
        let vec: PVector<u32> = buffer![0u32, 10, 200].into();

        let result = vec.cast(&PType::U8.into()).unwrap();
        let p = result.into_primitive().into_u8();
        assert_eq!(p.as_ref(), &[0u8, 10, 200]);
        assert!(p.validity().all_true());
    }

    #[test]
    fn cast_u32_f32() {
        let vec: PVector<u32> = buffer![0u32, 10, 200].into();
        let result = vec.cast(&PType::F32.into()).unwrap();
        let p = result.into_primitive().into_f32();
        assert_eq!(p.as_ref(), &[0.0f32, 10., 200.]);
    }

    /// Casting to the vector's own ptype returns the same values and validity.
    #[test]
    fn cast_same_ptype_is_identity() {
        let vec: PVector<u32> = buffer![0u32, 10, 200].into();
        let result = vec.cast(&PType::U32.into()).unwrap();
        let p = result.into_primitive().into_u32();
        assert_eq!(p.as_ref(), &[0u32, 10, 200]);
        assert!(p.validity().all_true());
    }

    #[test]
    fn cast_i32_u32_overflow() {
        let vec: PVector<i32> = buffer![-1i32].into();
        let error = vec
            .cast(&PType::U32.into())
            .err()
            .expect("casting -1 to u32 must fail");
        match error {
            VortexError::ComputeError(msg, _) => {
                assert_eq!(msg.to_string(), "Failed to cast -1 to U32")
            }
            other => panic!("expected ComputeError, got {other}"),
        }
    }

    #[test]
    fn cast_with_invalid_nulls() {
        let vec: PVector<i32> = PVector::new(
            buffer![-1i32, 0, 10],
            Mask::from(BitBuffer::from(vec![false, true, true])),
        );

        let result = vec
            .cast(&DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();
        let p = result.into_primitive().into_u32();
        assert_eq!(p.as_ref(), &[0u32, 0, 10]);
        assert_eq!(
            *p.validity(),
            Mask::from(BitBuffer::from(vec![false, true, true]))
        );
    }

    /// A vector containing nulls cannot be cast to a non-nullable dtype, whether or not the
    /// ptype changes.
    #[test]
    fn cast_nulls_to_non_nullable_fails() {
        let vec: PVector<i32> = PVector::new(
            buffer![-1i32, 0, 10],
            Mask::from(BitBuffer::from(vec![false, true, true])),
        );

        assert!(
            vec.cast(&DType::Primitive(PType::U32, Nullability::NonNullable))
                .is_err()
        );
        assert!(
            vec.cast(&DType::Primitive(PType::I32, Nullability::NonNullable))
                .is_err()
        );
    }

    #[test]
    fn cast_all_null_vector() {
        let vec: PVector<i32> = PVector::new(buffer![-1i32, -2, -3], Mask::new_false(3));

        let result = vec
            .cast(&DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();
        let p = result.into_primitive().into_u32();
        assert_eq!(p.as_ref(), &[0u32, 0, 0]);
        assert!(p.validity().all_false());
    }

    #[rstest]
    #[case(42i32, PType::U32)]
    #[case(0i32, PType::U8)]
    #[case(255i32, PType::U8)]
    #[case(100i32, PType::F64)]
    fn cast_scalar_valid(#[case] value: i32, #[case] target: PType) {
        let scalar: PScalar<i32> = PScalar::new(Some(value));
        let result = scalar.cast(&target.into()).unwrap();
        assert!(result.as_primitive().is_valid());
    }

    #[test]
    fn cast_scalar_i32_u32_overflow() {
        let scalar: PScalar<i32> = PScalar::new(Some(-1));
        let error = scalar
            .cast(&PType::U32.into())
            .err()
            .expect("casting -1 to u32 must fail");
        match error {
            VortexError::ComputeError(msg, _) => {
                assert_eq!(msg.to_string(), "Failed to cast -1 to U32")
            }
            other => panic!("expected ComputeError, got {other}"),
        }
    }

    #[test]
    fn cast_scalar_null() {
        let scalar: PScalar<i32> = PScalar::null();
        let result = scalar
            .cast(&DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();
        let p = result.into_primitive().into_u32();
        assert_eq!(p.value(), None);
    }

    /// A null scalar cannot be cast to a non-nullable dtype.
    #[test]
    fn cast_scalar_null_to_non_nullable_fails() {
        let scalar: PScalar<i32> = PScalar::null();
        assert!(
            scalar
                .cast(&DType::Primitive(PType::U32, Nullability::NonNullable))
                .is_err()
        );
    }

    /// Casting a scalar to its own ptype keeps its value.
    #[test]
    fn cast_scalar_same_ptype() {
        let scalar: PScalar<i32> = PScalar::new(Some(7));
        let p = scalar
            .cast(&PType::I32.into())
            .unwrap()
            .into_primitive()
            .into_i32();
        assert_eq!(p.value(), Some(7));
    }

    #[test]
    fn cast_scalar_u32_f64() {
        let scalar: PScalar<u32> = PScalar::new(Some(12345));
        let result = scalar.cast(&PType::F64.into()).unwrap();
        let p = result.into_primitive().into_f64();
        assert_eq!(p.value(), Some(12345.0f64));
    }
}