vortex_scalar/
pvalue.rs

1use core::fmt::Display;
2use std::cmp::Ordering;
3use std::hash::{Hash, Hasher};
4
5use num_traits::NumCast;
6use paste::paste;
7use vortex_dtype::half::f16;
8use vortex_dtype::{NativePType, PType, ToBytes};
9use vortex_error::{VortexError, VortexExpect, vortex_err};
10
/// A single primitive scalar value, tagged with the native type it is stored as.
///
/// Each variant wraps the Rust type of the same name; [`PValue::ptype`] reports the
/// corresponding [`PType`].
#[derive(Debug, Clone, Copy)]
pub enum PValue {
    /// An unsigned 8-bit integer.
    U8(u8),
    /// An unsigned 16-bit integer.
    U16(u16),
    /// An unsigned 32-bit integer.
    U32(u32),
    /// An unsigned 64-bit integer.
    U64(u64),
    /// A signed 8-bit integer.
    I8(i8),
    /// A signed 16-bit integer.
    I16(i16),
    /// A signed 32-bit integer.
    I32(i32),
    /// A signed 64-bit integer.
    I64(i64),
    /// A 16-bit (half precision) float.
    F16(f16),
    /// A 32-bit float.
    F32(f32),
    /// A 64-bit float.
    F64(f64),
}
25
26impl PartialEq for PValue {
27    fn eq(&self, other: &Self) -> bool {
28        match (self, other) {
29            (Self::U8(s), o) => o.as_u64().vortex_expect("upcast") == *s as u64,
30            (Self::U16(s), o) => o.as_u64().vortex_expect("upcast") == *s as u64,
31            (Self::U32(s), o) => o.as_u64().vortex_expect("upcast") == *s as u64,
32            (Self::U64(s), o) => o.as_u64().vortex_expect("upcast") == *s,
33            (Self::I8(s), o) => o.as_i64().vortex_expect("upcast") == *s as i64,
34            (Self::I16(s), o) => o.as_i64().vortex_expect("upcast") == *s as i64,
35            (Self::I32(s), o) => o.as_i64().vortex_expect("upcast") == *s as i64,
36            (Self::I64(s), o) => o.as_i64().vortex_expect("upcast") == *s,
37            (Self::F16(s), Self::F16(o)) => s.is_eq(*o),
38            (Self::F32(s), Self::F32(o)) => s.is_eq(*o),
39            (Self::F64(s), Self::F64(o)) => s.is_eq(*o),
40            (..) => false,
41        }
42    }
43}
44
45impl Eq for PValue {}
46
47impl PartialOrd for PValue {
48    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
49        match (self, other) {
50            (Self::U8(s), o) => Some((*s as u64).cmp(&o.as_u64().vortex_expect("upcast"))),
51            (Self::U16(s), o) => Some((*s as u64).cmp(&o.as_u64().vortex_expect("upcast"))),
52            (Self::U32(s), o) => Some((*s as u64).cmp(&o.as_u64().vortex_expect("upcast"))),
53            (Self::U64(s), o) => Some((*s).cmp(&o.as_u64().vortex_expect("upcast"))),
54            (Self::I8(s), o) => Some((*s as i64).cmp(&o.as_i64().vortex_expect("upcast"))),
55            (Self::I16(s), o) => Some((*s as i64).cmp(&o.as_i64().vortex_expect("upcast"))),
56            (Self::I32(s), o) => Some((*s as i64).cmp(&o.as_i64().vortex_expect("upcast"))),
57            (Self::I64(s), o) => Some((*s).cmp(&o.as_i64().vortex_expect("upcast"))),
58            (Self::F16(s), Self::F16(o)) => Some(s.total_compare(*o)),
59            (Self::F32(s), Self::F32(o)) => Some(s.total_compare(*o)),
60            (Self::F64(s), Self::F64(o)) => Some(s.total_compare(*o)),
61            (..) => None,
62        }
63    }
64}
65
66impl Hash for PValue {
67    fn hash<H: Hasher>(&self, state: &mut H) {
68        match self {
69            PValue::U8(_) | PValue::U16(_) | PValue::U32(_) | PValue::U64(_) => {
70                self.as_u64().vortex_expect("upcast").hash(state)
71            }
72            PValue::I8(_) | PValue::I16(_) | PValue::I32(_) | PValue::I64(_) => {
73                self.as_i64().vortex_expect("upcast").hash(state)
74            }
75            PValue::F16(v) => v.to_le_bytes().hash(state),
76            PValue::F32(v) => v.to_le_bytes().hash(state),
77            PValue::F64(v) => v.to_le_bytes().hash(state),
78        }
79    }
80}
81
82impl ToBytes for PValue {
83    fn to_le_bytes(&self) -> &[u8] {
84        match self {
85            PValue::U8(v) => v.to_le_bytes(),
86            PValue::U16(v) => v.to_le_bytes(),
87            PValue::U32(v) => v.to_le_bytes(),
88            PValue::U64(v) => v.to_le_bytes(),
89            PValue::I8(v) => v.to_le_bytes(),
90            PValue::I16(v) => v.to_le_bytes(),
91            PValue::I32(v) => v.to_le_bytes(),
92            PValue::I64(v) => v.to_le_bytes(),
93            PValue::F16(v) => v.to_le_bytes(),
94            PValue::F32(v) => v.to_le_bytes(),
95            PValue::F64(v) => v.to_le_bytes(),
96        }
97    }
98}
99
// Generates an `as_<type>` accessor on `PValue` that converts the wrapped value to the
// requested type via `NumCast`, yielding `None` when `NumCast` cannot perform the conversion.
// The second argument (the variant name) is accepted for call-site symmetry but is not used.
macro_rules! as_primitive {
    ($target:ty, $_variant:tt) => {
        paste! {
            #[doc = "Access PValue as `" $target "`, returning `None` if conversion is unsuccessful"]
            pub fn [<as_ $target>](self) -> Option<$target> {
                // The `NumCast` implementor is inferred from the return type.
                match self {
                    PValue::U8(inner) => NumCast::from(inner),
                    PValue::U16(inner) => NumCast::from(inner),
                    PValue::U32(inner) => NumCast::from(inner),
                    PValue::U64(inner) => NumCast::from(inner),
                    PValue::I8(inner) => NumCast::from(inner),
                    PValue::I16(inner) => NumCast::from(inner),
                    PValue::I32(inner) => NumCast::from(inner),
                    PValue::I64(inner) => NumCast::from(inner),
                    PValue::F16(inner) => NumCast::from(inner),
                    PValue::F32(inner) => NumCast::from(inner),
                    PValue::F64(inner) => NumCast::from(inner),
                }
            }
        }
    };
}
122
123impl PValue {
124    pub fn zero(ptype: PType) -> PValue {
125        match ptype {
126            PType::U8 => PValue::U8(0),
127            PType::U16 => PValue::U16(0),
128            PType::U32 => PValue::U32(0),
129            PType::U64 => PValue::U64(0),
130            PType::I8 => PValue::I8(0),
131            PType::I16 => PValue::I16(0),
132            PType::I32 => PValue::I32(0),
133            PType::I64 => PValue::I64(0),
134            PType::F16 => PValue::F16(f16::from_f32(0.0)),
135            PType::F32 => PValue::F32(0.0),
136            PType::F64 => PValue::F64(0.0),
137        }
138    }
139
140    pub fn ptype(&self) -> PType {
141        match self {
142            Self::U8(_) => PType::U8,
143            Self::U16(_) => PType::U16,
144            Self::U32(_) => PType::U32,
145            Self::U64(_) => PType::U64,
146            Self::I8(_) => PType::I8,
147            Self::I16(_) => PType::I16,
148            Self::I32(_) => PType::I32,
149            Self::I64(_) => PType::I64,
150            Self::F16(_) => PType::F16,
151            Self::F32(_) => PType::F32,
152            Self::F64(_) => PType::F64,
153        }
154    }
155
156    pub fn is_instance_of(&self, ptype: &PType) -> bool {
157        &self.ptype() == ptype
158    }
159
160    #[inline]
161    pub fn as_primitive<T: NativePType + TryFrom<PValue, Error = VortexError>>(
162        &self,
163    ) -> Result<T, VortexError> {
164        T::try_from(*self)
165    }
166
167    pub fn reinterpret_cast(&self, ptype: PType) -> Self {
168        if ptype == self.ptype() {
169            return *self;
170        }
171
172        assert_eq!(
173            ptype.byte_width(),
174            self.ptype().byte_width(),
175            "Cannot reinterpret cast between types of different widths"
176        );
177
178        match self {
179            PValue::U8(v) => u8::cast_signed(*v).into(),
180            PValue::U16(v) => match ptype {
181                PType::I16 => u16::cast_signed(*v).into(),
182                PType::F16 => f16::from_bits(*v).into(),
183                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
184            },
185            PValue::U32(v) => match ptype {
186                PType::I32 => u32::cast_signed(*v).into(),
187                PType::F32 => f32::from_bits(*v).into(),
188                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
189            },
190            PValue::U64(v) => match ptype {
191                PType::I64 => u64::cast_signed(*v).into(),
192                PType::F64 => f64::from_bits(*v).into(),
193                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
194            },
195            PValue::I8(v) => i8::cast_unsigned(*v).into(),
196            PValue::I16(v) => match ptype {
197                PType::U16 => i16::cast_unsigned(*v).into(),
198                PType::F16 => f16::from_bits(v.cast_unsigned()).into(),
199                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
200            },
201            PValue::I32(v) => match ptype {
202                PType::U32 => i32::cast_unsigned(*v).into(),
203                PType::F32 => f32::from_bits(i32::cast_unsigned(*v)).into(),
204                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
205            },
206            PValue::I64(v) => match ptype {
207                PType::U64 => i64::cast_unsigned(*v).into(),
208                PType::F64 => f64::from_bits(i64::cast_unsigned(*v)).into(),
209                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
210            },
211            PValue::F16(v) => match ptype {
212                PType::U16 => v.to_bits().into(),
213                PType::I16 => v.to_bits().cast_signed().into(),
214                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
215            },
216            PValue::F32(v) => match ptype {
217                PType::U32 => f32::to_bits(*v).into(),
218                PType::I32 => f32::to_bits(*v).cast_signed().into(),
219                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
220            },
221            PValue::F64(v) => match ptype {
222                PType::U64 => f64::to_bits(*v).into(),
223                PType::I64 => f64::to_bits(*v).cast_signed().into(),
224                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
225            },
226        }
227    }
228
229    as_primitive!(i8, I8);
230    as_primitive!(i16, I16);
231    as_primitive!(i32, I32);
232    as_primitive!(i64, I64);
233    as_primitive!(u8, U8);
234    as_primitive!(u16, U16);
235    as_primitive!(u32, U32);
236    as_primitive!(u64, U64);
237    as_primitive!(f16, F16);
238    as_primitive!(f32, F32);
239    as_primitive!(f64, F64);
240}
241
242macro_rules! int_pvalue {
243    ($T:ty, $PT:tt) => {
244        impl TryFrom<PValue> for $T {
245            type Error = VortexError;
246
247            fn try_from(value: PValue) -> Result<Self, Self::Error> {
248                match value {
249                    PValue::U8(v) => <$T as NumCast>::from(v),
250                    PValue::U16(v) => <$T as NumCast>::from(v),
251                    PValue::U32(v) => <$T as NumCast>::from(v),
252                    PValue::U64(v) => <$T as NumCast>::from(v),
253                    PValue::I8(v) => <$T as NumCast>::from(v),
254                    PValue::I16(v) => <$T as NumCast>::from(v),
255                    PValue::I32(v) => <$T as NumCast>::from(v),
256                    PValue::I64(v) => <$T as NumCast>::from(v),
257                    _ => None,
258                }
259                .ok_or_else(|| {
260                    vortex_err!("Cannot read primitive value {:?} as {}", value, PType::$PT)
261                })
262            }
263        }
264    };
265}
266
267int_pvalue!(u8, U8);
268int_pvalue!(u16, U16);
269int_pvalue!(u32, U32);
270int_pvalue!(u64, U64);
271int_pvalue!(usize, U64);
272int_pvalue!(i8, I8);
273int_pvalue!(i16, I16);
274int_pvalue!(i32, I32);
275int_pvalue!(i64, I64);
276
277impl TryFrom<PValue> for f64 {
278    type Error = VortexError;
279
280    fn try_from(value: PValue) -> Result<Self, Self::Error> {
281        // We serialize f64 as u64, but this can also sometimes be narrowed down to u8 if e.g. == 0
282        match value {
283            PValue::U8(u) => Some(Self::from_bits(u as u64)),
284            PValue::U16(u) => Some(Self::from_bits(u as u64)),
285            PValue::U32(u) => Some(Self::from_bits(u as u64)),
286            PValue::U64(u) => Some(Self::from_bits(u)),
287            PValue::F16(f) => <Self as NumCast>::from(f),
288            PValue::F32(f) => <Self as NumCast>::from(f),
289            PValue::F64(f) => <Self as NumCast>::from(f),
290            _ => None,
291        }
292        .ok_or_else(|| vortex_err!("Cannot read primitive value {:?} as {}", value, PType::F64))
293    }
294}
295
296impl TryFrom<PValue> for f32 {
297    type Error = VortexError;
298
299    #[allow(clippy::cast_possible_truncation)]
300    fn try_from(value: PValue) -> Result<Self, Self::Error> {
301        // We serialize f32 as u32, but this can also sometimes be narrowed down to u8 if e.g. == 0
302        match value {
303            PValue::U8(u) => Some(Self::from_bits(u as u32)),
304            PValue::U16(u) => Some(Self::from_bits(u as u32)),
305            PValue::U32(u) => Some(Self::from_bits(u)),
306            // We assume that the value was created from a valid f16 and only changed in serialization
307            PValue::U64(u) => <Self as NumCast>::from(Self::from_bits(u as u32)),
308            PValue::F16(f) => <Self as NumCast>::from(f),
309            PValue::F32(f) => <Self as NumCast>::from(f),
310            PValue::F64(f) => <Self as NumCast>::from(f),
311            _ => None,
312        }
313        .ok_or_else(|| vortex_err!("Cannot read primitive value {:?} as {}", value, PType::F32))
314    }
315}
316
317impl TryFrom<PValue> for f16 {
318    type Error = VortexError;
319
320    #[allow(clippy::cast_possible_truncation)]
321    fn try_from(value: PValue) -> Result<Self, Self::Error> {
322        // We serialize f16 as u16, but this can also sometimes be narrowed down to u8 if e.g. == 0
323        match value {
324            PValue::U8(u) => Some(Self::from_bits(u as u16)),
325            PValue::U16(u) => Some(Self::from_bits(u)),
326            // We assume that the value was created from a valid f16 and only changed in serialization
327            PValue::U32(u) => Some(Self::from_bits(u as u16)),
328            PValue::U64(u) => Some(Self::from_bits(u as u16)),
329            PValue::F16(u) => Some(u),
330            PValue::F32(f) => <Self as NumCast>::from(f),
331            PValue::F64(f) => <Self as NumCast>::from(f),
332            _ => None,
333        }
334        .ok_or_else(|| vortex_err!("Cannot read primitive value {:?} as {}", value, PType::F16))
335    }
336}
337
338macro_rules! impl_pvalue {
339    ($T:ty, $PT:tt) => {
340        impl From<$T> for PValue {
341            fn from(value: $T) -> Self {
342                PValue::$PT(value)
343            }
344        }
345    };
346}
347
348impl_pvalue!(u8, U8);
349impl_pvalue!(u16, U16);
350impl_pvalue!(u32, U32);
351impl_pvalue!(u64, U64);
352impl_pvalue!(i8, I8);
353impl_pvalue!(i16, I16);
354impl_pvalue!(i32, I32);
355impl_pvalue!(i64, I64);
356impl_pvalue!(f16, F16);
357impl_pvalue!(f32, F32);
358impl_pvalue!(f64, F64);
359
360impl From<usize> for PValue {
361    #[inline]
362    fn from(value: usize) -> PValue {
363        PValue::U64(value as u64)
364    }
365}
366
367impl Display for PValue {
368    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
369        match self {
370            Self::U8(v) => write!(f, "{v}u8"),
371            Self::U16(v) => write!(f, "{v}u16"),
372            Self::U32(v) => write!(f, "{v}u32"),
373            Self::U64(v) => write!(f, "{v}u64"),
374            Self::I8(v) => write!(f, "{v}i8"),
375            Self::I16(v) => write!(f, "{v}i16"),
376            Self::I32(v) => write!(f, "{v}i32"),
377            Self::I64(v) => write!(f, "{v}i64"),
378            Self::F16(v) => write!(f, "{v}f16"),
379            Self::F32(v) => write!(f, "{v}f32"),
380            Self::F64(v) => write!(f, "{v}f64"),
381        }
382    }
383}
384
#[cfg(test)]
#[allow(clippy::disallowed_types)]
mod test {
    use std::cmp::Ordering;
    use std::collections::HashSet;

    use vortex_dtype::PType;
    use vortex_dtype::half::f16;

    use crate::PValue;

    /// A value is an instance of exactly the type it is stored as, and of no other type.
    #[test]
    pub fn test_is_instance_of() {
        let half = PValue::F16(f16::from_f32(10.0));
        let cases = [
            (PValue::U8(10), PType::U8, true),
            (PValue::U8(10), PType::U16, false),
            (PValue::U8(10), PType::I8, false),
            (PValue::U8(10), PType::F16, false),
            (PValue::I8(10), PType::I8, true),
            (PValue::I8(10), PType::I16, false),
            (PValue::I8(10), PType::U8, false),
            (PValue::I8(10), PType::F16, false),
            (half, PType::F16, true),
            (half, PType::F32, false),
            (half, PType::U16, false),
            (half, PType::I16, false),
        ];

        for (value, ptype, expected) in cases {
            assert_eq!(value.is_instance_of(&ptype), expected);
        }
    }

    /// Signed integers are ordered by value, whether or not the widths match.
    #[test]
    fn test_compare_different_types() {
        let lhs = PValue::I8(4);
        for rhs in [PValue::I8(5), PValue::I64(5)] {
            assert_eq!(lhs.partial_cmp(&rhs), Some(Ordering::Less));
        }
    }

    /// The same integer stored at different widths collapses to one hash-set entry.
    #[test]
    fn test_hash() {
        let values = [
            PValue::U8(1),
            PValue::U16(1),
            PValue::U32(1),
            PValue::U64(1),
            PValue::I8(1),
            PValue::I16(1),
            PValue::I32(1),
            PValue::I64(1),
            PValue::I8(-1),
            PValue::I16(-1),
            PValue::I32(-1),
            PValue::I64(-1),
        ];
        let set: HashSet<PValue> = values.into_iter().collect();
        assert_eq!(set.len(), 2);
    }
}