vortex_scalar/
pvalue.rs

1use core::fmt::Display;
2use std::cmp::Ordering;
3use std::hash::{Hash, Hasher};
4use std::mem;
5
6use num_traits::NumCast;
7use paste::paste;
8use vortex_dtype::half::f16;
9use vortex_dtype::{NativePType, PType, ToBytes};
10use vortex_error::{VortexError, VortexExpect, vortex_err};
11
12#[derive(Debug, Clone, Copy)]
13pub enum PValue {
14    U8(u8),
15    U16(u16),
16    U32(u32),
17    U64(u64),
18    I8(i8),
19    I16(i16),
20    I32(i32),
21    I64(i64),
22    F16(f16),
23    F32(f32),
24    F64(f64),
25}
26
27impl PartialEq for PValue {
28    fn eq(&self, other: &Self) -> bool {
29        match (self, other) {
30            (Self::U8(s), o) => o.as_u64().vortex_expect("upcast") == *s as u64,
31            (Self::U16(s), o) => o.as_u64().vortex_expect("upcast") == *s as u64,
32            (Self::U32(s), o) => o.as_u64().vortex_expect("upcast") == *s as u64,
33            (Self::U64(s), o) => o.as_u64().vortex_expect("upcast") == *s,
34            (Self::I8(s), o) => o.as_i64().vortex_expect("upcast") == *s as i64,
35            (Self::I16(s), o) => o.as_i64().vortex_expect("upcast") == *s as i64,
36            (Self::I32(s), o) => o.as_i64().vortex_expect("upcast") == *s as i64,
37            (Self::I64(s), o) => o.as_i64().vortex_expect("upcast") == *s,
38            (Self::F16(s), Self::F16(o)) => s.is_eq(*o),
39            (Self::F32(s), Self::F32(o)) => s.is_eq(*o),
40            (Self::F64(s), Self::F64(o)) => s.is_eq(*o),
41            (..) => false,
42        }
43    }
44}
45
46impl Eq for PValue {}
47
48impl PartialOrd for PValue {
49    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
50        match (self, other) {
51            (Self::U8(s), o) => Some((*s as u64).cmp(&o.as_u64().vortex_expect("upcast"))),
52            (Self::U16(s), o) => Some((*s as u64).cmp(&o.as_u64().vortex_expect("upcast"))),
53            (Self::U32(s), o) => Some((*s as u64).cmp(&o.as_u64().vortex_expect("upcast"))),
54            (Self::U64(s), o) => Some((*s).cmp(&o.as_u64().vortex_expect("upcast"))),
55            (Self::I8(s), o) => Some((*s as i64).cmp(&o.as_i64().vortex_expect("upcast"))),
56            (Self::I16(s), o) => Some((*s as i64).cmp(&o.as_i64().vortex_expect("upcast"))),
57            (Self::I32(s), o) => Some((*s as i64).cmp(&o.as_i64().vortex_expect("upcast"))),
58            (Self::I64(s), o) => Some((*s).cmp(&o.as_i64().vortex_expect("upcast"))),
59            (Self::F16(s), Self::F16(o)) => Some(s.total_compare(*o)),
60            (Self::F32(s), Self::F32(o)) => Some(s.total_compare(*o)),
61            (Self::F64(s), Self::F64(o)) => Some(s.total_compare(*o)),
62            (..) => None,
63        }
64    }
65}
66
67impl Hash for PValue {
68    fn hash<H: Hasher>(&self, state: &mut H) {
69        self.to_le_bytes().hash(state);
70    }
71}
72
73impl ToBytes for PValue {
74    fn to_le_bytes(&self) -> &[u8] {
75        match self {
76            PValue::U8(v) => v.to_le_bytes(),
77            PValue::U16(v) => v.to_le_bytes(),
78            PValue::U32(v) => v.to_le_bytes(),
79            PValue::U64(v) => v.to_le_bytes(),
80            PValue::I8(v) => v.to_le_bytes(),
81            PValue::I16(v) => v.to_le_bytes(),
82            PValue::I32(v) => v.to_le_bytes(),
83            PValue::I64(v) => v.to_le_bytes(),
84            PValue::F16(v) => v.to_le_bytes(),
85            PValue::F32(v) => v.to_le_bytes(),
86            PValue::F64(v) => v.to_le_bytes(),
87        }
88    }
89}
90
// Generates an inherent method `as_<target>(self) -> Option<target>` on `PValue` that converts
// any variant with `NumCast`, yielding `None` when the value does not fit the target type.
// The second argument (the matching `PType` variant name) is accepted but not used.
macro_rules! as_primitive {
    ($target:ty, $ptype:tt) => {
        paste! {
            #[doc = "Access PValue as `" $target "`, returning `None` if conversion is unsuccessful"]
            pub fn [<as_ $target>](self) -> Option<$target> {
                match self {
                    Self::F16(x) => <$target as NumCast>::from(x),
                    Self::F32(x) => <$target as NumCast>::from(x),
                    Self::F64(x) => <$target as NumCast>::from(x),
                    Self::I8(x) => <$target as NumCast>::from(x),
                    Self::I16(x) => <$target as NumCast>::from(x),
                    Self::I32(x) => <$target as NumCast>::from(x),
                    Self::I64(x) => <$target as NumCast>::from(x),
                    Self::U8(x) => <$target as NumCast>::from(x),
                    Self::U16(x) => <$target as NumCast>::from(x),
                    Self::U32(x) => <$target as NumCast>::from(x),
                    Self::U64(x) => <$target as NumCast>::from(x),
                }
            }
        }
    };
}
113
114impl PValue {
115    pub fn ptype(&self) -> PType {
116        match self {
117            Self::U8(_) => PType::U8,
118            Self::U16(_) => PType::U16,
119            Self::U32(_) => PType::U32,
120            Self::U64(_) => PType::U64,
121            Self::I8(_) => PType::I8,
122            Self::I16(_) => PType::I16,
123            Self::I32(_) => PType::I32,
124            Self::I64(_) => PType::I64,
125            Self::F16(_) => PType::F16,
126            Self::F32(_) => PType::F32,
127            Self::F64(_) => PType::F64,
128        }
129    }
130
131    pub fn is_instance_of(&self, ptype: &PType) -> bool {
132        &self.ptype() == ptype
133    }
134
135    #[inline]
136    pub fn as_primitive<T: NativePType + TryFrom<PValue, Error = VortexError>>(
137        &self,
138    ) -> Result<T, VortexError> {
139        T::try_from(*self)
140    }
141
142    #[allow(clippy::transmute_int_to_float, clippy::transmute_float_to_int)]
143    pub fn reinterpret_cast(&self, ptype: PType) -> Self {
144        if ptype == self.ptype() {
145            return *self;
146        }
147
148        assert_eq!(
149            ptype.byte_width(),
150            self.ptype().byte_width(),
151            "Cannot reinterpret cast between types of different widths"
152        );
153
154        match self {
155            PValue::U8(v) => unsafe { mem::transmute::<u8, i8>(*v) }.into(),
156            PValue::U16(v) => match ptype {
157                PType::I16 => unsafe { mem::transmute::<u16, i16>(*v) }.into(),
158                PType::F16 => unsafe { mem::transmute::<u16, f16>(*v) }.into(),
159                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
160            },
161            PValue::U32(v) => match ptype {
162                PType::I32 => unsafe { mem::transmute::<u32, i32>(*v) }.into(),
163                PType::F32 => unsafe { mem::transmute::<u32, f32>(*v) }.into(),
164                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
165            },
166            PValue::U64(v) => match ptype {
167                PType::I64 => unsafe { mem::transmute::<u64, i64>(*v) }.into(),
168                PType::F64 => unsafe { mem::transmute::<u64, f64>(*v) }.into(),
169                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
170            },
171            PValue::I8(v) => unsafe { mem::transmute::<i8, u8>(*v) }.into(),
172            PValue::I16(v) => match ptype {
173                PType::U16 => unsafe { mem::transmute::<i16, u16>(*v) }.into(),
174                PType::F16 => unsafe { mem::transmute::<i16, f16>(*v) }.into(),
175                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
176            },
177            PValue::I32(v) => match ptype {
178                PType::U32 => unsafe { mem::transmute::<i32, u32>(*v) }.into(),
179                PType::F32 => unsafe { mem::transmute::<i32, f32>(*v) }.into(),
180                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
181            },
182            PValue::I64(v) => match ptype {
183                PType::U64 => unsafe { mem::transmute::<i64, u64>(*v) }.into(),
184                PType::F64 => unsafe { mem::transmute::<i64, f64>(*v) }.into(),
185                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
186            },
187            PValue::F16(v) => match ptype {
188                PType::U16 => unsafe { mem::transmute::<f16, u16>(*v) }.into(),
189                PType::I16 => unsafe { mem::transmute::<f16, i16>(*v) }.into(),
190                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
191            },
192            PValue::F32(v) => match ptype {
193                PType::U32 => unsafe { mem::transmute::<f32, u32>(*v) }.into(),
194                PType::I32 => unsafe { mem::transmute::<f32, i32>(*v) }.into(),
195                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
196            },
197            PValue::F64(v) => match ptype {
198                PType::U64 => unsafe { mem::transmute::<f64, u64>(*v) }.into(),
199                PType::I64 => unsafe { mem::transmute::<f64, i64>(*v) }.into(),
200                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
201            },
202        }
203    }
204
205    as_primitive!(i8, I8);
206    as_primitive!(i16, I16);
207    as_primitive!(i32, I32);
208    as_primitive!(i64, I64);
209    as_primitive!(u8, U8);
210    as_primitive!(u16, U16);
211    as_primitive!(u32, U32);
212    as_primitive!(u64, U64);
213    as_primitive!(f16, F16);
214    as_primitive!(f32, F32);
215    as_primitive!(f64, F64);
216}
217
218macro_rules! int_pvalue {
219    ($T:ty, $PT:tt) => {
220        impl TryFrom<PValue> for $T {
221            type Error = VortexError;
222
223            fn try_from(value: PValue) -> Result<Self, Self::Error> {
224                match value {
225                    PValue::U8(v) => <$T as NumCast>::from(v),
226                    PValue::U16(v) => <$T as NumCast>::from(v),
227                    PValue::U32(v) => <$T as NumCast>::from(v),
228                    PValue::U64(v) => <$T as NumCast>::from(v),
229                    PValue::I8(v) => <$T as NumCast>::from(v),
230                    PValue::I16(v) => <$T as NumCast>::from(v),
231                    PValue::I32(v) => <$T as NumCast>::from(v),
232                    PValue::I64(v) => <$T as NumCast>::from(v),
233                    _ => None,
234                }
235                .ok_or_else(|| {
236                    vortex_err!("Cannot read primitive value {:?} as {}", value, PType::$PT)
237                })
238            }
239        }
240    };
241}
242
243int_pvalue!(u8, U8);
244int_pvalue!(u16, U16);
245int_pvalue!(u32, U32);
246int_pvalue!(u64, U64);
247int_pvalue!(usize, U64);
248int_pvalue!(i8, I8);
249int_pvalue!(i16, I16);
250int_pvalue!(i32, I32);
251int_pvalue!(i64, I64);
252
253impl TryFrom<PValue> for f64 {
254    type Error = VortexError;
255
256    fn try_from(value: PValue) -> Result<Self, Self::Error> {
257        // We serialize f64 as u64, but this can also sometimes be narrowed down to u8 if e.g. == 0
258        match value {
259            PValue::U8(u) => Some(Self::from_bits(u as u64)),
260            PValue::U16(u) => Some(Self::from_bits(u as u64)),
261            PValue::U32(u) => Some(Self::from_bits(u as u64)),
262            PValue::U64(u) => Some(Self::from_bits(u)),
263            PValue::F16(f) => <Self as NumCast>::from(f),
264            PValue::F32(f) => <Self as NumCast>::from(f),
265            PValue::F64(f) => <Self as NumCast>::from(f),
266            _ => None,
267        }
268        .ok_or_else(|| vortex_err!("Cannot read primitive value {:?} as {}", value, PType::F64))
269    }
270}
271
272impl TryFrom<PValue> for f32 {
273    type Error = VortexError;
274
275    #[allow(clippy::cast_possible_truncation)]
276    fn try_from(value: PValue) -> Result<Self, Self::Error> {
277        // We serialize f32 as u32, but this can also sometimes be narrowed down to u8 if e.g. == 0
278        match value {
279            PValue::U8(u) => Some(Self::from_bits(u as u32)),
280            PValue::U16(u) => Some(Self::from_bits(u as u32)),
281            PValue::U32(u) => Some(Self::from_bits(u)),
282            // We assume that the value was created from a valid f16 and only changed in serialization
283            PValue::U64(u) => <Self as NumCast>::from(Self::from_bits(u as u32)),
284            PValue::F16(f) => <Self as NumCast>::from(f),
285            PValue::F32(f) => <Self as NumCast>::from(f),
286            PValue::F64(f) => <Self as NumCast>::from(f),
287            _ => None,
288        }
289        .ok_or_else(|| vortex_err!("Cannot read primitive value {:?} as {}", value, PType::F32))
290    }
291}
292
293impl TryFrom<PValue> for f16 {
294    type Error = VortexError;
295
296    #[allow(clippy::cast_possible_truncation)]
297    fn try_from(value: PValue) -> Result<Self, Self::Error> {
298        // We serialize f16 as u16, but this can also sometimes be narrowed down to u8 if e.g. == 0
299        match value {
300            PValue::U8(u) => Some(Self::from_bits(u as u16)),
301            PValue::U16(u) => Some(Self::from_bits(u)),
302            // We assume that the value was created from a valid f16 and only changed in serialization
303            PValue::U32(u) => Some(Self::from_bits(u as u16)),
304            PValue::U64(u) => Some(Self::from_bits(u as u16)),
305            PValue::F16(u) => Some(u),
306            PValue::F32(f) => <Self as NumCast>::from(f),
307            PValue::F64(f) => <Self as NumCast>::from(f),
308            _ => None,
309        }
310        .ok_or_else(|| vortex_err!("Cannot read primitive value {:?} as {}", value, PType::F16))
311    }
312}
313
314macro_rules! impl_pvalue {
315    ($T:ty, $PT:tt) => {
316        impl From<$T> for PValue {
317            fn from(value: $T) -> Self {
318                PValue::$PT(value)
319            }
320        }
321    };
322}
323
324impl_pvalue!(u8, U8);
325impl_pvalue!(u16, U16);
326impl_pvalue!(u32, U32);
327impl_pvalue!(u64, U64);
328impl_pvalue!(i8, I8);
329impl_pvalue!(i16, I16);
330impl_pvalue!(i32, I32);
331impl_pvalue!(i64, I64);
332impl_pvalue!(f16, F16);
333impl_pvalue!(f32, F32);
334impl_pvalue!(f64, F64);
335
336impl From<usize> for PValue {
337    fn from(value: usize) -> PValue {
338        PValue::U64(value as u64)
339    }
340}
341
342impl Display for PValue {
343    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344        match self {
345            Self::U8(v) => write!(f, "{}u8", v),
346            Self::U16(v) => write!(f, "{}u16", v),
347            Self::U32(v) => write!(f, "{}u32", v),
348            Self::U64(v) => write!(f, "{}u64", v),
349            Self::I8(v) => write!(f, "{}i8", v),
350            Self::I16(v) => write!(f, "{}i16", v),
351            Self::I32(v) => write!(f, "{}i32", v),
352            Self::I64(v) => write!(f, "{}i64", v),
353            Self::F16(v) => write!(f, "{}f16", v),
354            Self::F32(v) => write!(f, "{}f32", v),
355            Self::F64(v) => write!(f, "{}f64", v),
356        }
357    }
358}
359
#[cfg(test)]
mod test {
    use std::cmp::Ordering;

    use vortex_dtype::PType;
    use vortex_dtype::half::f16;

    use crate::PValue;

    #[test]
    pub fn test_is_instance_of() {
        let ten_f16 = PValue::F16(f16::from_f32(10.0));
        // (value, candidate type, expected result of `is_instance_of`)
        let cases = [
            (PValue::U8(10), PType::U8, true),
            (PValue::U8(10), PType::U16, false),
            (PValue::U8(10), PType::I8, false),
            (PValue::U8(10), PType::F16, false),
            (PValue::I8(10), PType::I8, true),
            (PValue::I8(10), PType::I16, false),
            (PValue::I8(10), PType::U8, false),
            (PValue::I8(10), PType::F16, false),
            (ten_f16, PType::F16, true),
            (ten_f16, PType::F32, false),
            (ten_f16, PType::U16, false),
            (ten_f16, PType::I16, false),
        ];

        for (value, ptype, expected) in &cases {
            assert_eq!(
                value.is_instance_of(ptype),
                *expected,
                "{value:?}.is_instance_of({ptype})"
            );
        }
    }

    #[test]
    fn test_compare_different_types() {
        let four = PValue::I8(4);
        assert_eq!(four.partial_cmp(&PValue::I8(5)), Some(Ordering::Less));
        assert_eq!(four.partial_cmp(&PValue::I64(5)), Some(Ordering::Less));
    }
}