vortex_compute/cast/
pvector.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::NumCast;
5use vortex_dtype::DType;
6use vortex_dtype::NativePType;
7use vortex_dtype::match_each_native_ptype;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_err;
11use vortex_vector::Scalar;
12use vortex_vector::ScalarOps;
13use vortex_vector::Vector;
14use vortex_vector::VectorOps;
15use vortex_vector::primitive::PScalar;
16use vortex_vector::primitive::PVector;
17use vortex_vector::primitive::PrimitiveScalar;
18use vortex_vector::primitive::PrimitiveVector;
19use vortex_vector::primitive::cast;
20
21use crate::cast::Cast;
22use crate::cast::try_cast_scalar_common;
23use crate::cast::try_cast_vector_common;
24
25impl<T: NativePType> Cast for PVector<T> {
26    type Output = Vector;
27
28    /// Cast a primitive vector to a different primitive type.
29    fn cast(&self, target_dtype: &DType) -> VortexResult<Vector> {
30        if let Some(result) = try_cast_vector_common(self, target_dtype)? {
31            return Ok(result);
32        }
33
34        match target_dtype {
35            // We have the same `PType` and we have compatible nullability.
36            DType::Primitive(target_ptype, n)
37                if *target_ptype == T::PTYPE && (n.is_nullable() || self.validity().all_true()) =>
38            {
39                Ok(self.clone().into())
40            }
41            // We can possibly convert to the target `PType` and we have compatible nullability.
42            DType::Primitive(target_ptype, n) if n.is_nullable() || self.validity().all_true() => {
43                match_each_native_ptype!(*target_ptype, |Dst| {
44                    let result = cast::cast_pvector::<T, Dst>(self)?;
45                    Ok(PrimitiveVector::from(result).into())
46                })
47            }
48            _ => {
49                vortex_bail!("Cannot cast PVector<{}> to {}", T::PTYPE, target_dtype);
50            }
51        }
52    }
53}
54
55impl<T: NativePType> Cast for PScalar<T> {
56    type Output = Scalar;
57
58    /// Cast a primitive scalar to a different primitive type.
59    fn cast(&self, target_dtype: &DType) -> VortexResult<Scalar> {
60        if let Some(result) = try_cast_scalar_common(self, target_dtype)? {
61            return Ok(result);
62        }
63
64        match target_dtype {
65            // We have the same `PType` and we have compatible nullability.
66            DType::Primitive(target_ptype, n)
67                if *target_ptype == T::PTYPE && (n.is_nullable() || self.is_valid()) =>
68            {
69                Ok(self.clone().into())
70            }
71            // We can possibly convert to the target `PType` and we have compatible nullability.
72            DType::Primitive(target_ptype, n) if n.is_nullable() || self.is_valid() => {
73                match_each_native_ptype!(*target_ptype, |Dst| {
74                    let result = match self.value() {
75                        None => PScalar::null(),
76                        Some(v) => {
77                            let converted = <Dst as NumCast>::from(v).ok_or_else(|| {
78                                vortex_err!(ComputeError: "Failed to cast {} to {:?}", v, Dst::PTYPE)
79                            })?;
80                            PScalar::new(Some(converted))
81                        }
82                    };
83                    Ok(PrimitiveScalar::from(result).into())
84                })
85            }
86            _ => {
87                vortex_bail!("Cannot cast PScalar<{}> to {}", T::PTYPE, target_dtype);
88            }
89        }
90    }
91}
92
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_dtype::PTypeDowncast;
    use vortex_error::VortexError;
    use vortex_mask::Mask;
    use vortex_vector::ScalarOps;
    use vortex_vector::VectorOps;
    use vortex_vector::primitive::PScalar;
    use vortex_vector::primitive::PVector;

    use crate::cast::Cast;

    #[rstest]
    #[case(PType::U8)]
    #[case(PType::U16)]
    #[case(PType::U32)]
    #[case(PType::U64)]
    #[case(PType::I8)]
    #[case(PType::I16)]
    #[case(PType::I32)]
    #[case(PType::I64)]
    #[case(PType::F32)]
    #[case(PType::F64)]
    fn cast_u32_to_ptype(#[case] target: PType) {
        // Use values that fit in all target types (including i8: -128..127).
        let vec: PVector<u32> = buffer![0u32, 10, 100].into();
        let result = vec.cast(&target.into()).unwrap();
        assert!(result.as_primitive().validity().all_true());
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn cast_various_types_to_f64() {
        // Test casting from various primitive types to f64. Every source value below is
        // exactly representable as an f64, so the converted values are compared exactly
        // instead of only checking that the cast succeeded.
        let u8_vec: PVector<u8> = buffer![0u8, 1, 2, 3, 255].into();
        let p = u8_vec
            .cast(&PType::F64.into())
            .unwrap()
            .into_primitive()
            .into_f64();
        assert_eq!(p.as_ref(), &[0.0f64, 1.0, 2.0, 3.0, 255.0]);

        let u16_vec: PVector<u16> = buffer![0u16, 100, 1000].into();
        let p = u16_vec
            .cast(&PType::F64.into())
            .unwrap()
            .into_primitive()
            .into_f64();
        assert_eq!(p.as_ref(), &[0.0f64, 100.0, 1000.0]);

        let u32_vec: PVector<u32> = buffer![0u32, 100, 1000, 1000000].into();
        let p = u32_vec
            .cast(&PType::F64.into())
            .unwrap()
            .into_primitive()
            .into_f64();
        assert_eq!(p.as_ref(), &[0.0f64, 100.0, 1000.0, 1000000.0]);

        let i8_vec: PVector<i8> = buffer![0i8, -1, 1, 127].into();
        let p = i8_vec
            .cast(&PType::F64.into())
            .unwrap()
            .into_primitive()
            .into_f64();
        assert_eq!(p.as_ref(), &[0.0f64, -1.0, 1.0, 127.0]);

        let i32_vec: PVector<i32> = buffer![-1000000i32, -1, 0, 1, 1000000].into();
        let p = i32_vec
            .cast(&PType::F64.into())
            .unwrap()
            .into_primitive()
            .into_f64();
        assert_eq!(p.as_ref(), &[-1000000.0f64, -1.0, 0.0, 1.0, 1000000.0]);

        let f32_vec: PVector<f32> = buffer![0.0f32, 1.5, -2.5, 100.0].into();
        let p = f32_vec
            .cast(&PType::F64.into())
            .unwrap()
            .into_primitive()
            .into_f64();
        assert_eq!(p.as_ref(), &[0.0f64, 1.5, -2.5, 100.0]);
    }

    #[test]
    fn cast_u32_u8() {
        let vec: PVector<u32> = buffer![0u32, 10, 200].into();

        // Cast from u32 to u8.
        let result = vec.cast(&PType::U8.into()).unwrap();
        let p = result.into_primitive().into_u8();
        assert_eq!(p.as_ref(), &[0u8, 10, 200]);
        assert!(p.validity().all_true());
    }

    #[test]
    fn cast_u32_f32() {
        let vec: PVector<u32> = buffer![0u32, 10, 200].into();
        let result = vec.cast(&PType::F32.into()).unwrap();
        let p = result.into_primitive().into_f32();
        assert_eq!(p.as_ref(), &[0.0f32, 10., 200.]);
    }

    #[test]
    fn cast_i32_u32_overflow() {
        let vec: PVector<i32> = buffer![-1i32].into();
        let error = vec.cast(&PType::U32.into()).err().unwrap();
        let VortexError::ComputeError(s, _) = error else {
            panic!("expected a ComputeError for an out-of-range cast");
        };
        assert_eq!(s.to_string(), "Failed to cast -1 to U32");
    }

    #[test]
    fn cast_with_invalid_nulls() {
        // Create a vector with an invalid value at position 0 (which would overflow).
        let vec: PVector<i32> = PVector::new(
            buffer![-1i32, 0, 10],
            Mask::from(BitBuffer::from(vec![false, true, true])),
        );

        // Cast to nullable u32 should succeed because the invalid value is masked.
        let result = vec
            .cast(&DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();
        let p = result.into_primitive().into_u32();
        assert_eq!(p.as_ref(), &[0u32, 0, 10]);
        assert_eq!(
            *p.validity(),
            Mask::from(BitBuffer::from(vec![false, true, true]))
        );
    }

    #[test]
    fn cast_all_null_vector() {
        let vec: PVector<i32> = PVector::new(buffer![-1i32, -2, -3], Mask::new_false(3));

        // Cast to nullable u32 should succeed because all values are masked.
        let result = vec
            .cast(&DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();
        let p = result.into_primitive().into_u32();
        assert_eq!(p.as_ref(), &[0u32, 0, 0]);
        assert!(p.validity().all_false());
    }

    #[rstest]
    #[case(42i32, PType::U32)]
    #[case(0i32, PType::U8)]
    #[case(255i32, PType::U8)]
    #[case(100i32, PType::F64)]
    fn cast_scalar_valid(#[case] value: i32, #[case] target: PType) {
        let scalar: PScalar<i32> = PScalar::new(Some(value));
        let result = scalar.cast(&target.into()).unwrap();
        assert!(result.as_primitive().is_valid());
    }

    #[test]
    fn cast_scalar_i32_u32_overflow() {
        let scalar: PScalar<i32> = PScalar::new(Some(-1));
        let error = scalar.cast(&PType::U32.into()).err().unwrap();
        let VortexError::ComputeError(s, _) = error else {
            panic!("expected a ComputeError for an out-of-range cast");
        };
        assert_eq!(s.to_string(), "Failed to cast -1 to U32");
    }

    #[test]
    fn cast_scalar_null() {
        let scalar: PScalar<i32> = PScalar::null();
        let result = scalar
            .cast(&DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();
        let p = result.into_primitive().into_u32();
        assert_eq!(p.value(), None);
    }

    #[test]
    fn cast_scalar_u32_f64() {
        let scalar: PScalar<u32> = PScalar::new(Some(12345));
        let result = scalar.cast(&PType::F64.into()).unwrap();
        let p = result.into_primitive().into_f64();
        assert_eq!(p.value(), Some(12345.0f64));
    }
}