vortex_scalar/
pvalue.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use core::fmt::Display;
5use std::cmp::Ordering;
6use std::hash::{Hash, Hasher};
7
8use num_traits::NumCast;
9use paste::paste;
10use vortex_dtype::half::f16;
11use vortex_dtype::{NativePType, PType, ToBytes};
12use vortex_error::{VortexError, VortexExpect, vortex_err};
13
14/// A primitive value that can represent any primitive type supported by Vortex.
15///
16/// `PValue` is used to store primitive scalar values in a type-erased manner,
17/// supporting all primitive types (integers, floats) with various bit widths.
18#[derive(Debug, Clone, Copy)]
19pub enum PValue {
20    /// Unsigned 8-bit integer.
21    U8(u8),
22    /// Unsigned 16-bit integer.
23    U16(u16),
24    /// Unsigned 32-bit integer.
25    U32(u32),
26    /// Unsigned 64-bit integer.
27    U64(u64),
28    /// Signed 8-bit integer.
29    I8(i8),
30    /// Signed 16-bit integer.
31    I16(i16),
32    /// Signed 32-bit integer.
33    I32(i32),
34    /// Signed 64-bit integer.
35    I64(i64),
36    /// 16-bit floating point.
37    F16(f16),
38    /// 32-bit floating point.
39    F32(f32),
40    /// 64-bit floating point.
41    F64(f64),
42}
43
44impl PartialEq for PValue {
45    fn eq(&self, other: &Self) -> bool {
46        match (self, other) {
47            (Self::U8(s), o) => o.as_u64().vortex_expect("upcast") == *s as u64,
48            (Self::U16(s), o) => o.as_u64().vortex_expect("upcast") == *s as u64,
49            (Self::U32(s), o) => o.as_u64().vortex_expect("upcast") == *s as u64,
50            (Self::U64(s), o) => o.as_u64().vortex_expect("upcast") == *s,
51            (Self::I8(s), o) => o.as_i64().vortex_expect("upcast") == *s as i64,
52            (Self::I16(s), o) => o.as_i64().vortex_expect("upcast") == *s as i64,
53            (Self::I32(s), o) => o.as_i64().vortex_expect("upcast") == *s as i64,
54            (Self::I64(s), o) => o.as_i64().vortex_expect("upcast") == *s,
55            (Self::F16(s), Self::F16(o)) => s.is_eq(*o),
56            (Self::F32(s), Self::F32(o)) => s.is_eq(*o),
57            (Self::F64(s), Self::F64(o)) => s.is_eq(*o),
58            (..) => false,
59        }
60    }
61}
62
63impl Eq for PValue {}
64
65impl PartialOrd for PValue {
66    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
67        match (self, other) {
68            (Self::U8(s), o) => Some((*s as u64).cmp(&o.as_u64().vortex_expect("upcast"))),
69            (Self::U16(s), o) => Some((*s as u64).cmp(&o.as_u64().vortex_expect("upcast"))),
70            (Self::U32(s), o) => Some((*s as u64).cmp(&o.as_u64().vortex_expect("upcast"))),
71            (Self::U64(s), o) => Some((*s).cmp(&o.as_u64().vortex_expect("upcast"))),
72            (Self::I8(s), o) => Some((*s as i64).cmp(&o.as_i64().vortex_expect("upcast"))),
73            (Self::I16(s), o) => Some((*s as i64).cmp(&o.as_i64().vortex_expect("upcast"))),
74            (Self::I32(s), o) => Some((*s as i64).cmp(&o.as_i64().vortex_expect("upcast"))),
75            (Self::I64(s), o) => Some((*s).cmp(&o.as_i64().vortex_expect("upcast"))),
76            (Self::F16(s), Self::F16(o)) => Some(s.total_compare(*o)),
77            (Self::F32(s), Self::F32(o)) => Some(s.total_compare(*o)),
78            (Self::F64(s), Self::F64(o)) => Some(s.total_compare(*o)),
79            (..) => None,
80        }
81    }
82}
83
84impl Hash for PValue {
85    fn hash<H: Hasher>(&self, state: &mut H) {
86        match self {
87            PValue::U8(_) | PValue::U16(_) | PValue::U32(_) | PValue::U64(_) => {
88                self.as_u64().vortex_expect("upcast").hash(state)
89            }
90            PValue::I8(_) | PValue::I16(_) | PValue::I32(_) | PValue::I64(_) => {
91                self.as_i64().vortex_expect("upcast").hash(state)
92            }
93            PValue::F16(v) => v.to_le_bytes().hash(state),
94            PValue::F32(v) => v.to_le_bytes().hash(state),
95            PValue::F64(v) => v.to_le_bytes().hash(state),
96        }
97    }
98}
99
100impl ToBytes for PValue {
101    fn to_le_bytes(&self) -> &[u8] {
102        match self {
103            PValue::U8(v) => v.to_le_bytes(),
104            PValue::U16(v) => v.to_le_bytes(),
105            PValue::U32(v) => v.to_le_bytes(),
106            PValue::U64(v) => v.to_le_bytes(),
107            PValue::I8(v) => v.to_le_bytes(),
108            PValue::I16(v) => v.to_le_bytes(),
109            PValue::I32(v) => v.to_le_bytes(),
110            PValue::I64(v) => v.to_le_bytes(),
111            PValue::F16(v) => v.to_le_bytes(),
112            PValue::F32(v) => v.to_le_bytes(),
113            PValue::F64(v) => v.to_le_bytes(),
114        }
115    }
116}
117
// Generates an inherent `as_<type>` accessor that converts whichever variant is
// held into `$T` through `NumCast`, yielding `None` when `NumCast` rejects the
// conversion. The `$PT` argument is accepted for symmetry with the other macros
// in this file but is not used in the expansion.
macro_rules! as_primitive {
    ($T:ty, $PT:tt) => {
        paste! {
            #[doc = "Access PValue as `" $T "`, returning `None` if conversion is unsuccessful"]
            pub fn [<as_ $T>](self) -> Option<$T> {
                match self {
                    PValue::U8(raw) => <$T as NumCast>::from(raw),
                    PValue::U16(raw) => <$T as NumCast>::from(raw),
                    PValue::U32(raw) => <$T as NumCast>::from(raw),
                    PValue::U64(raw) => <$T as NumCast>::from(raw),
                    PValue::I8(raw) => <$T as NumCast>::from(raw),
                    PValue::I16(raw) => <$T as NumCast>::from(raw),
                    PValue::I32(raw) => <$T as NumCast>::from(raw),
                    PValue::I64(raw) => <$T as NumCast>::from(raw),
                    PValue::F16(raw) => <$T as NumCast>::from(raw),
                    PValue::F32(raw) => <$T as NumCast>::from(raw),
                    PValue::F64(raw) => <$T as NumCast>::from(raw),
                }
            }
        }
    };
}
140
141impl PValue {
142    /// Creates a zero value for the given primitive type.
143    pub fn zero(ptype: PType) -> PValue {
144        match ptype {
145            PType::U8 => PValue::U8(0),
146            PType::U16 => PValue::U16(0),
147            PType::U32 => PValue::U32(0),
148            PType::U64 => PValue::U64(0),
149            PType::I8 => PValue::I8(0),
150            PType::I16 => PValue::I16(0),
151            PType::I32 => PValue::I32(0),
152            PType::I64 => PValue::I64(0),
153            PType::F16 => PValue::F16(f16::from_f32(0.0)),
154            PType::F32 => PValue::F32(0.0),
155            PType::F64 => PValue::F64(0.0),
156        }
157    }
158
159    /// Returns the primitive type of this value.
160    pub fn ptype(&self) -> PType {
161        match self {
162            Self::U8(_) => PType::U8,
163            Self::U16(_) => PType::U16,
164            Self::U32(_) => PType::U32,
165            Self::U64(_) => PType::U64,
166            Self::I8(_) => PType::I8,
167            Self::I16(_) => PType::I16,
168            Self::I32(_) => PType::I32,
169            Self::I64(_) => PType::I64,
170            Self::F16(_) => PType::F16,
171            Self::F32(_) => PType::F32,
172            Self::F64(_) => PType::F64,
173        }
174    }
175
176    /// Returns true if this value is of the given primitive type.
177    pub fn is_instance_of(&self, ptype: &PType) -> bool {
178        &self.ptype() == ptype
179    }
180
181    /// Converts this value to a specific native primitive type.
182    ///
183    /// Returns an error if the conversion is not supported or would overflow.
184    #[inline]
185    pub fn as_primitive<T: NativePType + TryFrom<PValue, Error = VortexError>>(
186        &self,
187    ) -> Result<T, VortexError> {
188        T::try_from(*self)
189    }
190
191    /// Reinterprets the bits of this value as a different primitive type.
192    ///
193    /// This performs a bitwise cast between types of the same width.
194    ///
195    /// # Panics
196    ///
197    /// Panics if the target type has a different byte width than this value.
198    pub fn reinterpret_cast(&self, ptype: PType) -> Self {
199        if ptype == self.ptype() {
200            return *self;
201        }
202
203        assert_eq!(
204            ptype.byte_width(),
205            self.ptype().byte_width(),
206            "Cannot reinterpret cast between types of different widths"
207        );
208
209        match self {
210            PValue::U8(v) => u8::cast_signed(*v).into(),
211            PValue::U16(v) => match ptype {
212                PType::I16 => u16::cast_signed(*v).into(),
213                PType::F16 => f16::from_bits(*v).into(),
214                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
215            },
216            PValue::U32(v) => match ptype {
217                PType::I32 => u32::cast_signed(*v).into(),
218                PType::F32 => f32::from_bits(*v).into(),
219                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
220            },
221            PValue::U64(v) => match ptype {
222                PType::I64 => u64::cast_signed(*v).into(),
223                PType::F64 => f64::from_bits(*v).into(),
224                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
225            },
226            PValue::I8(v) => i8::cast_unsigned(*v).into(),
227            PValue::I16(v) => match ptype {
228                PType::U16 => i16::cast_unsigned(*v).into(),
229                PType::F16 => f16::from_bits(v.cast_unsigned()).into(),
230                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
231            },
232            PValue::I32(v) => match ptype {
233                PType::U32 => i32::cast_unsigned(*v).into(),
234                PType::F32 => f32::from_bits(i32::cast_unsigned(*v)).into(),
235                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
236            },
237            PValue::I64(v) => match ptype {
238                PType::U64 => i64::cast_unsigned(*v).into(),
239                PType::F64 => f64::from_bits(i64::cast_unsigned(*v)).into(),
240                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
241            },
242            PValue::F16(v) => match ptype {
243                PType::U16 => v.to_bits().into(),
244                PType::I16 => v.to_bits().cast_signed().into(),
245                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
246            },
247            PValue::F32(v) => match ptype {
248                PType::U32 => f32::to_bits(*v).into(),
249                PType::I32 => f32::to_bits(*v).cast_signed().into(),
250                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
251            },
252            PValue::F64(v) => match ptype {
253                PType::U64 => f64::to_bits(*v).into(),
254                PType::I64 => f64::to_bits(*v).cast_signed().into(),
255                _ => unreachable!("Only same width type are allowed to be reinterpreted"),
256            },
257        }
258    }
259
260    as_primitive!(i8, I8);
261    as_primitive!(i16, I16);
262    as_primitive!(i32, I32);
263    as_primitive!(i64, I64);
264    as_primitive!(u8, U8);
265    as_primitive!(u16, U16);
266    as_primitive!(u32, U32);
267    as_primitive!(u64, U64);
268    as_primitive!(f16, F16);
269    as_primitive!(f32, F32);
270    as_primitive!(f64, F64);
271}
272
273macro_rules! int_pvalue {
274    ($T:ty, $PT:tt) => {
275        impl TryFrom<PValue> for $T {
276            type Error = VortexError;
277
278            fn try_from(value: PValue) -> Result<Self, Self::Error> {
279                match value {
280                    PValue::U8(v) => <$T as NumCast>::from(v),
281                    PValue::U16(v) => <$T as NumCast>::from(v),
282                    PValue::U32(v) => <$T as NumCast>::from(v),
283                    PValue::U64(v) => <$T as NumCast>::from(v),
284                    PValue::I8(v) => <$T as NumCast>::from(v),
285                    PValue::I16(v) => <$T as NumCast>::from(v),
286                    PValue::I32(v) => <$T as NumCast>::from(v),
287                    PValue::I64(v) => <$T as NumCast>::from(v),
288                    _ => None,
289                }
290                .ok_or_else(|| {
291                    vortex_err!("Cannot read primitive value {:?} as {}", value, PType::$PT)
292                })
293            }
294        }
295    };
296}
297
298int_pvalue!(u8, U8);
299int_pvalue!(u16, U16);
300int_pvalue!(u32, U32);
301int_pvalue!(u64, U64);
302int_pvalue!(usize, U64);
303int_pvalue!(i8, I8);
304int_pvalue!(i16, I16);
305int_pvalue!(i32, I32);
306int_pvalue!(i64, I64);
307
308impl TryFrom<PValue> for f64 {
309    type Error = VortexError;
310
311    fn try_from(value: PValue) -> Result<Self, Self::Error> {
312        // We serialize f64 as u64, but this can also sometimes be narrowed down to u8 if e.g. == 0
313        match value {
314            PValue::U8(u) => Some(Self::from_bits(u as u64)),
315            PValue::U16(u) => Some(Self::from_bits(u as u64)),
316            PValue::U32(u) => Some(Self::from_bits(u as u64)),
317            PValue::U64(u) => Some(Self::from_bits(u)),
318            PValue::F16(f) => <Self as NumCast>::from(f),
319            PValue::F32(f) => <Self as NumCast>::from(f),
320            PValue::F64(f) => <Self as NumCast>::from(f),
321            _ => None,
322        }
323        .ok_or_else(|| vortex_err!("Cannot read primitive value {:?} as {}", value, PType::F64))
324    }
325}
326
327impl TryFrom<PValue> for f32 {
328    type Error = VortexError;
329
330    #[allow(clippy::cast_possible_truncation)]
331    fn try_from(value: PValue) -> Result<Self, Self::Error> {
332        // We serialize f32 as u32, but this can also sometimes be narrowed down to u8 if e.g. == 0
333        match value {
334            PValue::U8(u) => Some(Self::from_bits(u as u32)),
335            PValue::U16(u) => Some(Self::from_bits(u as u32)),
336            PValue::U32(u) => Some(Self::from_bits(u)),
337            // We assume that the value was created from a valid f16 and only changed in serialization
338            PValue::U64(u) => <Self as NumCast>::from(Self::from_bits(u as u32)),
339            PValue::F16(f) => <Self as NumCast>::from(f),
340            PValue::F32(f) => <Self as NumCast>::from(f),
341            PValue::F64(f) => <Self as NumCast>::from(f),
342            _ => None,
343        }
344        .ok_or_else(|| vortex_err!("Cannot read primitive value {:?} as {}", value, PType::F32))
345    }
346}
347
348impl TryFrom<PValue> for f16 {
349    type Error = VortexError;
350
351    #[allow(clippy::cast_possible_truncation)]
352    fn try_from(value: PValue) -> Result<Self, Self::Error> {
353        // We serialize f16 as u16, but this can also sometimes be narrowed down to u8 if e.g. == 0
354        match value {
355            PValue::U8(u) => Some(Self::from_bits(u as u16)),
356            PValue::U16(u) => Some(Self::from_bits(u)),
357            // We assume that the value was created from a valid f16 and only changed in serialization
358            PValue::U32(u) => Some(Self::from_bits(u as u16)),
359            PValue::U64(u) => Some(Self::from_bits(u as u16)),
360            PValue::F16(u) => Some(u),
361            PValue::F32(f) => <Self as NumCast>::from(f),
362            PValue::F64(f) => <Self as NumCast>::from(f),
363            _ => None,
364        }
365        .ok_or_else(|| vortex_err!("Cannot read primitive value {:?} as {}", value, PType::F16))
366    }
367}
368
369macro_rules! impl_pvalue {
370    ($T:ty, $PT:tt) => {
371        impl From<$T> for PValue {
372            fn from(value: $T) -> Self {
373                PValue::$PT(value)
374            }
375        }
376    };
377}
378
379impl_pvalue!(u8, U8);
380impl_pvalue!(u16, U16);
381impl_pvalue!(u32, U32);
382impl_pvalue!(u64, U64);
383impl_pvalue!(i8, I8);
384impl_pvalue!(i16, I16);
385impl_pvalue!(i32, I32);
386impl_pvalue!(i64, I64);
387impl_pvalue!(f16, F16);
388impl_pvalue!(f32, F32);
389impl_pvalue!(f64, F64);
390
391impl From<usize> for PValue {
392    #[inline]
393    fn from(value: usize) -> PValue {
394        PValue::U64(value as u64)
395    }
396}
397
398impl Display for PValue {
399    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
400        match self {
401            Self::U8(v) => write!(f, "{v}u8"),
402            Self::U16(v) => write!(f, "{v}u16"),
403            Self::U32(v) => write!(f, "{v}u32"),
404            Self::U64(v) => write!(f, "{v}u64"),
405            Self::I8(v) => write!(f, "{v}i8"),
406            Self::I16(v) => write!(f, "{v}i16"),
407            Self::I32(v) => write!(f, "{v}i32"),
408            Self::I64(v) => write!(f, "{v}i64"),
409            Self::F16(v) => write!(f, "{v}f16"),
410            Self::F32(v) => write!(f, "{v}f32"),
411            Self::F64(v) => write!(f, "{v}f64"),
412        }
413    }
414}
415
#[cfg(test)]
#[allow(clippy::disallowed_types)]
mod test {
    use std::cmp::Ordering;
    use std::collections::HashSet;

    use vortex_dtype::PType;
    use vortex_dtype::half::f16;

    use crate::PValue;

    /// `is_instance_of` accepts only the exact variant type, never a different
    /// width, signedness, or numeric family.
    #[test]
    pub fn test_is_instance_of() {
        let unsigned = PValue::U8(10);
        assert!(unsigned.is_instance_of(&PType::U8));
        for mismatch in [PType::U16, PType::I8, PType::F16] {
            assert!(!unsigned.is_instance_of(&mismatch));
        }

        let signed = PValue::I8(10);
        assert!(signed.is_instance_of(&PType::I8));
        for mismatch in [PType::I16, PType::U8, PType::F16] {
            assert!(!signed.is_instance_of(&mismatch));
        }

        let half = PValue::F16(f16::from_f32(10.0));
        assert!(half.is_instance_of(&PType::F16));
        for mismatch in [PType::F32, PType::U16, PType::I16] {
            assert!(!half.is_instance_of(&mismatch));
        }
    }

    /// Signed integers are ordered by value even when their widths differ.
    #[test]
    fn test_compare_different_types() {
        let four = PValue::I8(4);
        for five in [PValue::I8(5), PValue::I64(5)] {
            assert_eq!(four.partial_cmp(&five), Some(Ordering::Less));
        }
    }

    /// Every integer encoding of `1` collapses to one set entry, and every
    /// signed encoding of `-1` to another.
    #[test]
    fn test_hash() {
        let values = [
            PValue::U8(1),
            PValue::U16(1),
            PValue::U32(1),
            PValue::U64(1),
            PValue::I8(1),
            PValue::I16(1),
            PValue::I32(1),
            PValue::I64(1),
            PValue::I8(-1),
            PValue::I16(-1),
            PValue::I32(-1),
            PValue::I64(-1),
        ];
        let set: HashSet<PValue> = values.into_iter().collect();
        assert_eq!(set.len(), 2);
    }
}