vortex_scalar/
primitive.rs

1use std::any::type_name;
2use std::cmp::Ordering;
3use std::fmt::{Debug, Display, Formatter};
4use std::ops::{Add, Sub};
5
6use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive};
7use vortex_dtype::half::f16;
8use vortex_dtype::{DType, NativePType, Nullability, PType, match_each_native_ptype};
9use vortex_error::{
10    VortexError, VortexExpect as _, VortexResult, VortexUnwrap, vortex_err, vortex_panic,
11};
12
13use crate::pvalue::PValue;
14use crate::{InnerScalarValue, Scalar, ScalarValue};
15
16#[derive(Debug, Clone, Copy, Hash)]
17pub struct PrimitiveScalar<'a> {
18    dtype: &'a DType,
19    ptype: PType,
20    pvalue: Option<PValue>,
21}
22
23impl Display for PrimitiveScalar<'_> {
24    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
25        match self.pvalue {
26            None => write!(f, "null"),
27            Some(pv) => write!(f, "{}", pv),
28        }
29    }
30}
31
32impl PartialEq for PrimitiveScalar<'_> {
33    fn eq(&self, other: &Self) -> bool {
34        self.dtype.eq_ignore_nullability(other.dtype) && self.pvalue == other.pvalue
35    }
36}
37
38impl Eq for PrimitiveScalar<'_> {}
39
40/// Ord is not implemented since it's undefined for different PTypes
41impl PartialOrd for PrimitiveScalar<'_> {
42    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
43        if !self.dtype.eq_ignore_nullability(other.dtype) {
44            return None;
45        }
46        self.pvalue.partial_cmp(&other.pvalue)
47    }
48}
49
50impl<'a> PrimitiveScalar<'a> {
51    pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
52        let ptype = PType::try_from(dtype)?;
53
54        // Read the serialized value into the correct PValue.
55        // The serialized form may come back over the wire as e.g. any integer type.
56        let pvalue = match_each_native_ptype!(ptype, |$T| {
57            if let Some(pvalue) = value.as_pvalue()? {
58                Some(PValue::from(<$T>::try_from(pvalue)?))
59            } else {
60                None
61            }
62        });
63
64        Ok(Self {
65            dtype,
66            ptype,
67            pvalue,
68        })
69    }
70
71    #[inline]
72    pub fn dtype(&self) -> &'a DType {
73        self.dtype
74    }
75
76    #[inline]
77    pub fn ptype(&self) -> PType {
78        self.ptype
79    }
80
81    #[inline]
82    pub fn pvalue(&self) -> Option<PValue> {
83        self.pvalue
84    }
85
86    pub fn typed_value<T: NativePType + TryFrom<PValue, Error = VortexError>>(&self) -> Option<T> {
87        assert_eq!(
88            self.ptype,
89            T::PTYPE,
90            "Attempting to read {} scalar as {}",
91            self.ptype,
92            T::PTYPE
93        );
94
95        self.pvalue.map(|pv| pv.as_primitive::<T>().vortex_unwrap())
96    }
97
98    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
99        let ptype = PType::try_from(dtype)?;
100        let pvalue = self
101            .pvalue
102            .vortex_expect("nullness handled in Scalar::cast");
103        Ok(match_each_native_ptype!(ptype, |$Q| {
104            Scalar::primitive(
105                pvalue
106                    .as_primitive::<$Q>()
107                    .map_err(|err| vortex_err!("Can't cast {} scalar {} to {} (cause: {})", self.ptype, pvalue, dtype, err))?,
108                dtype.nullability()
109            )
110        }))
111    }
112
113    /// Attempt to extract the primitive value as the given type.
114    /// Fails on a bad cast.
115    pub fn as_<T: FromPrimitiveOrF16>(&self) -> VortexResult<Option<T>> {
116        match self.pvalue {
117            None => Ok(None),
118            Some(pv) => Ok(Some(match pv {
119                PValue::U8(v) => T::from_u8(v)
120                    .ok_or_else(|| vortex_err!("Failed to cast u8 to {}", type_name::<T>())),
121                PValue::U16(v) => T::from_u16(v)
122                    .ok_or_else(|| vortex_err!("Failed to cast u16 to {}", type_name::<T>())),
123                PValue::U32(v) => T::from_u32(v)
124                    .ok_or_else(|| vortex_err!("Failed to cast u32 to {}", type_name::<T>())),
125                PValue::U64(v) => T::from_u64(v)
126                    .ok_or_else(|| vortex_err!("Failed to cast u64 to {}", type_name::<T>())),
127                PValue::I8(v) => T::from_i8(v)
128                    .ok_or_else(|| vortex_err!("Failed to cast i8 to {}", type_name::<T>())),
129                PValue::I16(v) => T::from_i16(v)
130                    .ok_or_else(|| vortex_err!("Failed to cast i16 to {}", type_name::<T>())),
131                PValue::I32(v) => T::from_i32(v)
132                    .ok_or_else(|| vortex_err!("Failed to cast i32 to {}", type_name::<T>())),
133                PValue::I64(v) => T::from_i64(v)
134                    .ok_or_else(|| vortex_err!("Failed to cast i64 to {}", type_name::<T>())),
135                PValue::F16(v) => T::from_f16(v)
136                    .ok_or_else(|| vortex_err!("Failed to cast f16 to {}", type_name::<T>())),
137                PValue::F32(v) => T::from_f32(v)
138                    .ok_or_else(|| vortex_err!("Failed to cast f32 to {}", type_name::<T>())),
139                PValue::F64(v) => T::from_f64(v)
140                    .ok_or_else(|| vortex_err!("Failed to cast f64 to {}", type_name::<T>())),
141            }?)),
142        }
143    }
144}
145
146pub trait FromPrimitiveOrF16: FromPrimitive {
147    fn from_f16(v: f16) -> Option<Self>;
148}
149
150macro_rules! from_primitive_or_f16_for_non_floating_point {
151    ($T:ty) => {
152        impl FromPrimitiveOrF16 for $T {
153            fn from_f16(_: f16) -> Option<Self> {
154                None
155            }
156        }
157    };
158}
159
160from_primitive_or_f16_for_non_floating_point!(usize);
161from_primitive_or_f16_for_non_floating_point!(u8);
162from_primitive_or_f16_for_non_floating_point!(u16);
163from_primitive_or_f16_for_non_floating_point!(u32);
164from_primitive_or_f16_for_non_floating_point!(u64);
165from_primitive_or_f16_for_non_floating_point!(i8);
166from_primitive_or_f16_for_non_floating_point!(i16);
167from_primitive_or_f16_for_non_floating_point!(i32);
168from_primitive_or_f16_for_non_floating_point!(i64);
169
170impl FromPrimitiveOrF16 for f16 {
171    fn from_f16(v: f16) -> Option<Self> {
172        Some(v)
173    }
174}
175
176impl FromPrimitiveOrF16 for f32 {
177    fn from_f16(v: f16) -> Option<Self> {
178        Some(v.to_f32())
179    }
180}
181
182impl FromPrimitiveOrF16 for f64 {
183    fn from_f16(v: f16) -> Option<Self> {
184        Some(v.to_f64())
185    }
186}
187
188impl<'a> TryFrom<&'a Scalar> for PrimitiveScalar<'a> {
189    type Error = VortexError;
190
191    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
192        Self::try_new(value.dtype(), value.value())
193    }
194}
195
196impl Sub for PrimitiveScalar<'_> {
197    type Output = Self;
198
199    fn sub(self, rhs: Self) -> Self::Output {
200        self.checked_sub(&rhs)
201            .vortex_expect("PrimitiveScalar subtract: overflow or underflow")
202    }
203}
204
205impl CheckedSub for PrimitiveScalar<'_> {
206    fn checked_sub(&self, rhs: &Self) -> Option<Self> {
207        self.checked_binary_numeric(rhs, BinaryNumericOperator::Sub)
208    }
209}
210
211impl Add for PrimitiveScalar<'_> {
212    type Output = Self;
213
214    fn add(self, rhs: Self) -> Self::Output {
215        self.checked_add(&rhs)
216            .vortex_expect("PrimitiveScalar add: overflow or underflow")
217    }
218}
219
220impl CheckedAdd for PrimitiveScalar<'_> {
221    fn checked_add(&self, rhs: &Self) -> Option<Self> {
222        self.checked_binary_numeric(rhs, BinaryNumericOperator::Add)
223    }
224}
225
226impl Scalar {
227    pub fn primitive<T: NativePType + Into<PValue>>(value: T, nullability: Nullability) -> Self {
228        Self::primitive_value(value.into(), T::PTYPE, nullability)
229    }
230
231    /// Create a PrimitiveScalar from a PValue.
232    ///
233    /// Note that an explicit PType is passed since any compatible PValue may be used as the value
234    /// for a primitive type.
235    pub fn primitive_value(value: PValue, ptype: PType, nullability: Nullability) -> Self {
236        Self {
237            dtype: DType::Primitive(ptype, nullability),
238            value: ScalarValue(InnerScalarValue::Primitive(value)),
239        }
240    }
241
242    pub fn reinterpret_cast(&self, ptype: PType) -> Self {
243        let primitive = PrimitiveScalar::try_from(self).unwrap_or_else(|e| {
244            vortex_panic!(e, "Failed to reinterpret cast {} to {}", self.dtype, ptype)
245        });
246        if primitive.ptype() == ptype {
247            return self.clone();
248        }
249
250        assert_eq!(
251            primitive.ptype().byte_width(),
252            ptype.byte_width(),
253            "can't reinterpret cast between integers of two different widths"
254        );
255
256        Scalar::new(
257            DType::Primitive(ptype, self.dtype.nullability()),
258            primitive
259                .pvalue
260                .map(|p| p.reinterpret_cast(ptype))
261                .map(|x| ScalarValue(InnerScalarValue::Primitive(x)))
262                .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)),
263        )
264    }
265}
266
267macro_rules! primitive_scalar {
268    ($T:ty) => {
269        impl TryFrom<&Scalar> for $T {
270            type Error = VortexError;
271
272            fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
273                <Option<$T>>::try_from(value)?
274                    .ok_or_else(|| vortex_err!("Can't extract present value from null scalar"))
275            }
276        }
277
278        impl TryFrom<Scalar> for $T {
279            type Error = VortexError;
280
281            fn try_from(value: Scalar) -> Result<Self, Self::Error> {
282                <$T>::try_from(&value)
283            }
284        }
285
286        impl TryFrom<&Scalar> for Option<$T> {
287            type Error = VortexError;
288
289            fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
290                Ok(PrimitiveScalar::try_from(value)?.typed_value::<$T>())
291            }
292        }
293
294        impl TryFrom<Scalar> for Option<$T> {
295            type Error = VortexError;
296
297            fn try_from(value: Scalar) -> Result<Self, Self::Error> {
298                <Option<$T>>::try_from(&value)
299            }
300        }
301
302        impl From<$T> for Scalar {
303            fn from(value: $T) -> Self {
304                Scalar {
305                    dtype: DType::Primitive(<$T>::PTYPE, Nullability::NonNullable),
306                    value: ScalarValue(InnerScalarValue::Primitive(value.into())),
307                }
308            }
309        }
310    };
311}
312
313primitive_scalar!(u8);
314primitive_scalar!(u16);
315primitive_scalar!(u32);
316primitive_scalar!(u64);
317primitive_scalar!(i8);
318primitive_scalar!(i16);
319primitive_scalar!(i32);
320primitive_scalar!(i64);
321primitive_scalar!(f16);
322primitive_scalar!(f32);
323primitive_scalar!(f64);
324
325/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics.
326impl TryFrom<&Scalar> for usize {
327    type Error = VortexError;
328
329    fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
330        let prim = PrimitiveScalar::try_from(value)?
331            .as_::<u64>()?
332            .ok_or_else(|| vortex_err!("cannot convert Null to usize"))?;
333        Ok(usize::try_from(prim)?)
334    }
335}
336
337/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics.
338impl From<usize> for Scalar {
339    fn from(value: usize) -> Self {
340        Scalar::primitive(value as u64, Nullability::NonNullable)
341    }
342}
343
/// Binary element-wise operations on two arrays or two scalars.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryNumericOperator {
    /// Binary element-wise addition of two arrays or of two scalars.
    Add,
    /// Binary element-wise subtraction of two arrays or of two scalars.
    Sub,
    /// Same as [BinaryNumericOperator::Sub] but with the parameters flipped: `right - left`.
    RSub,
    /// Binary element-wise multiplication of two arrays or of two scalars.
    Mul,
    /// Binary element-wise division of two arrays or of two scalars.
    Div,
    /// Same as [BinaryNumericOperator::Div] but with the parameters flipped: `right / left`.
    RDiv,
    // Missing from arrow-rs:
    // Min,
    // Max,
    // Pow,
}

/// Formats the operator as its variant name (identical to the `Debug` output).
impl Display for BinaryNumericOperator {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        Debug::fmt(self, f)
    }
}

impl BinaryNumericOperator {
    /// Returns the operator that gives the same result when the two operands are exchanged,
    /// i.e. `lhs op rhs` equals `rhs op.swap() lhs`.
    ///
    /// Commutative operators (`Add`, `Mul`) map to themselves; the others map to their
    /// reversed counterpart (`Sub` <-> `RSub`, `Div` <-> `RDiv`), so `swap` is an involution.
    pub fn swap(self) -> Self {
        match self {
            BinaryNumericOperator::Add => BinaryNumericOperator::Add,
            BinaryNumericOperator::Sub => BinaryNumericOperator::RSub,
            BinaryNumericOperator::RSub => BinaryNumericOperator::Sub,
            BinaryNumericOperator::Mul => BinaryNumericOperator::Mul,
            BinaryNumericOperator::Div => BinaryNumericOperator::RDiv,
            BinaryNumericOperator::RDiv => BinaryNumericOperator::Div,
        }
    }
}
383
384impl<'a> PrimitiveScalar<'a> {
385    /// Apply the (checked) operator to self and other using SQL-style null semantics.
386    ///
387    /// If the operation overflows, Ok(None) is returned.
388    ///
389    /// If the types are incompatible (ignoring nullability), an error is returned.
390    ///
391    /// If either value is null, the result is null.
392    pub fn checked_binary_numeric(
393        &self,
394        other: &PrimitiveScalar<'a>,
395        op: BinaryNumericOperator,
396    ) -> Option<PrimitiveScalar<'a>> {
397        if !self.dtype().eq_ignore_nullability(other.dtype()) {
398            vortex_panic!("types must match: {} {}", self.dtype(), other.dtype());
399        }
400        let result_dtype = if self.dtype().is_nullable() {
401            self.dtype()
402        } else {
403            other.dtype()
404        };
405        let ptype = self.ptype();
406
407        match_each_native_ptype!(
408            self.ptype(),
409            integral: |$P| {
410                self.checked_integeral_numeric_operator::<$P>(other, result_dtype, ptype, op)
411            }
412            floating_point: |$P| {
413                let lhs = self.typed_value::<$P>();
414                let rhs = other.typed_value::<$P>();
415                let value_or_null = match (lhs, rhs) {
416                    (_, None) | (None, _) => None,
417                    (Some(lhs), Some(rhs)) => match op {
418                        BinaryNumericOperator::Add => Some(lhs + rhs),
419                        BinaryNumericOperator::Sub => Some(lhs - rhs),
420                        BinaryNumericOperator::RSub => Some(rhs - lhs),
421                        BinaryNumericOperator::Mul => Some(lhs * rhs),
422                        BinaryNumericOperator::Div => Some(lhs / rhs),
423                        BinaryNumericOperator::RDiv => Some(rhs / lhs),
424                    }
425                };
426                Some(Self { dtype: result_dtype, ptype: ptype, pvalue: value_or_null.map(PValue::from) })
427            }
428        )
429    }
430
431    fn checked_integeral_numeric_operator<
432        P: NativePType
433            + TryFrom<PValue, Error = VortexError>
434            + CheckedSub
435            + CheckedAdd
436            + CheckedMul
437            + CheckedDiv,
438    >(
439        &self,
440        other: &PrimitiveScalar<'a>,
441        result_dtype: &'a DType,
442        ptype: PType,
443        op: BinaryNumericOperator,
444    ) -> Option<PrimitiveScalar<'a>>
445    where
446        PValue: From<P>,
447    {
448        let lhs = self.typed_value::<P>();
449        let rhs = other.typed_value::<P>();
450        let value_or_null_or_overflow = match (lhs, rhs) {
451            (_, None) | (None, _) => Some(None),
452            (Some(lhs), Some(rhs)) => match op {
453                BinaryNumericOperator::Add => lhs.checked_add(&rhs).map(Some),
454                BinaryNumericOperator::Sub => lhs.checked_sub(&rhs).map(Some),
455                BinaryNumericOperator::RSub => rhs.checked_sub(&lhs).map(Some),
456                BinaryNumericOperator::Mul => lhs.checked_mul(&rhs).map(Some),
457                BinaryNumericOperator::Div => lhs.checked_div(&rhs).map(Some),
458                BinaryNumericOperator::RDiv => rhs.checked_div(&lhs).map(Some),
459            },
460        };
461
462        value_or_null_or_overflow.map(|value_or_null| Self {
463            dtype: result_dtype,
464            ptype,
465            pvalue: value_or_null.map(PValue::from),
466        })
467    }
468}
469
#[cfg(test)]
mod tests {
    use num_traits::CheckedSub;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::{InnerScalarValue, PValue, PrimitiveScalar, ScalarValue};

    /// Builds a `PrimitiveScalar` of `dtype` holding `value`, panicking if construction fails.
    fn scalar_of(dtype: &DType, value: PValue) -> PrimitiveScalar<'_> {
        PrimitiveScalar::try_new(dtype, &ScalarValue(InnerScalarValue::Primitive(value))).unwrap()
    }

    #[test]
    fn test_integer_subtract() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let lhs = scalar_of(&dtype, PValue::I32(5));
        let rhs = scalar_of(&dtype, PValue::I32(4));

        let difference = lhs.checked_sub(&rhs).unwrap();
        assert_eq!(difference.as_::<i32>().unwrap().unwrap(), 1);

        assert_eq!((lhs - rhs).as_::<i32>().unwrap().unwrap(), 1);
    }

    #[test]
    #[should_panic(expected = "PrimitiveScalar subtract: overflow or underflow")]
    #[allow(clippy::assertions_on_constants)]
    fn test_integer_subtract_overflow() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let lhs = scalar_of(&dtype, PValue::I32(i32::MIN));
        let rhs = scalar_of(&dtype, PValue::I32(i32::MAX));
        let _ = lhs - rhs;
    }

    #[test]
    fn test_float_subtract() {
        let dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
        let lhs = scalar_of(&dtype, PValue::F32(1.99f32));
        let rhs = scalar_of(&dtype, PValue::F32(1.0f32));

        let difference = lhs.checked_sub(&rhs).unwrap();
        assert_eq!(difference.as_::<f32>().unwrap().unwrap(), 0.99f32);

        assert_eq!((lhs - rhs).as_::<f32>().unwrap().unwrap(), 0.99f32);
    }
}