vortex_scalar/
primitive.rs

1use std::any::type_name;
2use std::cmp::Ordering;
3use std::fmt::{Debug, Display};
4use std::ops::{Add, Sub};
5
6use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive};
7use vortex_dtype::half::f16;
8use vortex_dtype::{DType, NativePType, Nullability, PType, match_each_native_ptype};
9use vortex_error::{
10    VortexError, VortexExpect as _, VortexResult, VortexUnwrap, vortex_err, vortex_panic,
11};
12
13use crate::pvalue::PValue;
14use crate::{InnerScalarValue, Scalar, ScalarValue};
15
16#[derive(Debug, Clone, Copy, Hash)]
17pub struct PrimitiveScalar<'a> {
18    dtype: &'a DType,
19    ptype: PType,
20    pvalue: Option<PValue>,
21}
22
23impl PartialEq for PrimitiveScalar<'_> {
24    fn eq(&self, other: &Self) -> bool {
25        self.dtype.eq_ignore_nullability(other.dtype) && self.pvalue == other.pvalue
26    }
27}
28
/// Marker impl: declares the `PartialEq` relation on `PrimitiveScalar` to be a full
/// equivalence relation.
impl Eq for PrimitiveScalar<'_> {}
30
31/// Ord is not implemented since it's undefined for different PTypes
32impl PartialOrd for PrimitiveScalar<'_> {
33    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
34        if !self.dtype.eq_ignore_nullability(other.dtype) {
35            return None;
36        }
37        self.pvalue.partial_cmp(&other.pvalue)
38    }
39}
40
41impl<'a> PrimitiveScalar<'a> {
42    pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
43        let ptype = PType::try_from(dtype)?;
44
45        // Read the serialized value into the correct PValue.
46        // The serialized form may come back over the wire as e.g. any integer type.
47        let pvalue = match_each_native_ptype!(ptype, |$T| {
48            if let Some(pvalue) = value.as_pvalue()? {
49                Some(PValue::from(<$T>::try_from(pvalue)?))
50            } else {
51                None
52            }
53        });
54
55        Ok(Self {
56            dtype,
57            ptype,
58            pvalue,
59        })
60    }
61
62    #[inline]
63    pub fn dtype(&self) -> &'a DType {
64        self.dtype
65    }
66
67    #[inline]
68    pub fn ptype(&self) -> PType {
69        self.ptype
70    }
71
72    #[inline]
73    pub fn pvalue(&self) -> Option<PValue> {
74        self.pvalue
75    }
76
77    pub fn typed_value<T: NativePType + TryFrom<PValue, Error = VortexError>>(&self) -> Option<T> {
78        assert_eq!(
79            self.ptype,
80            T::PTYPE,
81            "Attempting to read {} scalar as {}",
82            self.ptype,
83            T::PTYPE
84        );
85
86        self.pvalue.map(|pv| pv.as_primitive::<T>().vortex_unwrap())
87    }
88
89    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
90        let ptype = PType::try_from(dtype)?;
91        let pvalue = self
92            .pvalue
93            .vortex_expect("nullness handled in Scalar::cast");
94        Ok(match_each_native_ptype!(ptype, |$Q| {
95            Scalar::primitive(
96                pvalue
97                    .as_primitive::<$Q>()
98                    .map_err(|err| vortex_err!("Can't cast {} scalar {} to {} (cause: {})", self.ptype, pvalue, dtype, err))?,
99                dtype.nullability()
100            )
101        }))
102    }
103
104    /// Attempt to extract the primitive value as the given type.
105    /// Fails on a bad cast.
106    pub fn as_<T: FromPrimitiveOrF16>(&self) -> VortexResult<Option<T>> {
107        match self.pvalue {
108            None => Ok(None),
109            Some(pv) => Ok(Some(match pv {
110                PValue::U8(v) => T::from_u8(v)
111                    .ok_or_else(|| vortex_err!("Failed to cast u8 to {}", type_name::<T>())),
112                PValue::U16(v) => T::from_u16(v)
113                    .ok_or_else(|| vortex_err!("Failed to cast u16 to {}", type_name::<T>())),
114                PValue::U32(v) => T::from_u32(v)
115                    .ok_or_else(|| vortex_err!("Failed to cast u32 to {}", type_name::<T>())),
116                PValue::U64(v) => T::from_u64(v)
117                    .ok_or_else(|| vortex_err!("Failed to cast u64 to {}", type_name::<T>())),
118                PValue::I8(v) => T::from_i8(v)
119                    .ok_or_else(|| vortex_err!("Failed to cast i8 to {}", type_name::<T>())),
120                PValue::I16(v) => T::from_i16(v)
121                    .ok_or_else(|| vortex_err!("Failed to cast i16 to {}", type_name::<T>())),
122                PValue::I32(v) => T::from_i32(v)
123                    .ok_or_else(|| vortex_err!("Failed to cast i32 to {}", type_name::<T>())),
124                PValue::I64(v) => T::from_i64(v)
125                    .ok_or_else(|| vortex_err!("Failed to cast i64 to {}", type_name::<T>())),
126                PValue::F16(v) => T::from_f16(v)
127                    .ok_or_else(|| vortex_err!("Failed to cast f16 to {}", type_name::<T>())),
128                PValue::F32(v) => T::from_f32(v)
129                    .ok_or_else(|| vortex_err!("Failed to cast f32 to {}", type_name::<T>())),
130                PValue::F64(v) => T::from_f64(v)
131                    .ok_or_else(|| vortex_err!("Failed to cast f64 to {}", type_name::<T>())),
132            }?)),
133        }
134    }
135}
136
/// Extension of [`FromPrimitive`] that adds a conversion from [`f16`].
///
/// This is the bound required by `PrimitiveScalar::as_`, which may hold an `f16`
/// value in addition to the types `FromPrimitive` already covers.
pub trait FromPrimitiveOrF16: FromPrimitive {
    /// Converts an `f16` into `Self`, returning `None` when the implementing type does
    /// not provide such a conversion.
    fn from_f16(v: f16) -> Option<Self>;
}
140
/// Implements [`FromPrimitiveOrF16`] for a non-floating-point type.
///
/// The generated `from_f16` rejects every input by returning `None`.
macro_rules! from_primitive_or_f16_for_non_floating_point {
    ($Int:ty) => {
        impl FromPrimitiveOrF16 for $Int {
            fn from_f16(_value: f16) -> Option<Self> {
                Option::None
            }
        }
    };
}
150
// Integer targets: `from_f16` returns `None` for each of the types below.
from_primitive_or_f16_for_non_floating_point!(usize);
from_primitive_or_f16_for_non_floating_point!(u8);
from_primitive_or_f16_for_non_floating_point!(u16);
from_primitive_or_f16_for_non_floating_point!(u32);
from_primitive_or_f16_for_non_floating_point!(u64);
from_primitive_or_f16_for_non_floating_point!(i8);
from_primitive_or_f16_for_non_floating_point!(i16);
from_primitive_or_f16_for_non_floating_point!(i32);
from_primitive_or_f16_for_non_floating_point!(i64);
160
161impl FromPrimitiveOrF16 for f16 {
162    fn from_f16(v: f16) -> Option<Self> {
163        Some(v)
164    }
165}
166
167impl FromPrimitiveOrF16 for f32 {
168    fn from_f16(v: f16) -> Option<Self> {
169        Some(v.to_f32())
170    }
171}
172
173impl FromPrimitiveOrF16 for f64 {
174    fn from_f16(v: f16) -> Option<Self> {
175        Some(v.to_f64())
176    }
177}
178
179impl<'a> TryFrom<&'a Scalar> for PrimitiveScalar<'a> {
180    type Error = VortexError;
181
182    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
183        Self::try_new(value.dtype(), value.value())
184    }
185}
186
187impl Sub for PrimitiveScalar<'_> {
188    type Output = Self;
189
190    fn sub(self, rhs: Self) -> Self::Output {
191        self.checked_sub(&rhs)
192            .vortex_expect("PrimitiveScalar subtract: overflow or underflow")
193    }
194}
195
196impl CheckedSub for PrimitiveScalar<'_> {
197    fn checked_sub(&self, rhs: &Self) -> Option<Self> {
198        self.checked_binary_numeric(rhs, BinaryNumericOperator::Sub)
199    }
200}
201
202impl Add for PrimitiveScalar<'_> {
203    type Output = Self;
204
205    fn add(self, rhs: Self) -> Self::Output {
206        self.checked_add(&rhs)
207            .vortex_expect("PrimitiveScalar add: overflow or underflow")
208    }
209}
210
211impl CheckedAdd for PrimitiveScalar<'_> {
212    fn checked_add(&self, rhs: &Self) -> Option<Self> {
213        self.checked_binary_numeric(rhs, BinaryNumericOperator::Add)
214    }
215}
216
217impl Scalar {
218    pub fn primitive<T: NativePType + Into<PValue>>(value: T, nullability: Nullability) -> Self {
219        Self::primitive_value(value.into(), T::PTYPE, nullability)
220    }
221
222    /// Create a PrimitiveScalar from a PValue.
223    ///
224    /// Note that an explicit PType is passed since any compatible PValue may be used as the value
225    /// for a primitive type.
226    pub fn primitive_value(value: PValue, ptype: PType, nullability: Nullability) -> Self {
227        Self {
228            dtype: DType::Primitive(ptype, nullability),
229            value: ScalarValue(InnerScalarValue::Primitive(value)),
230        }
231    }
232
233    pub fn reinterpret_cast(&self, ptype: PType) -> Self {
234        let primitive = PrimitiveScalar::try_from(self).unwrap_or_else(|e| {
235            vortex_panic!(e, "Failed to reinterpret cast {} to {}", self.dtype, ptype)
236        });
237        if primitive.ptype() == ptype {
238            return self.clone();
239        }
240
241        assert_eq!(
242            primitive.ptype().byte_width(),
243            ptype.byte_width(),
244            "can't reinterpret cast between integers of two different widths"
245        );
246
247        Scalar::new(
248            DType::Primitive(ptype, self.dtype.nullability()),
249            primitive
250                .pvalue
251                .map(|p| p.reinterpret_cast(ptype))
252                .map(|x| ScalarValue(InnerScalarValue::Primitive(x)))
253                .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)),
254        )
255    }
256}
257
/// Generates the conversions between [`Scalar`] and the native primitive type `$T`:
///
/// * `TryFrom<&Scalar>` / `TryFrom<Scalar>` for `$T` (a null scalar is an error),
/// * `TryFrom<&Scalar>` / `TryFrom<Scalar>` for `Option<$T>` (a null scalar is `None`),
/// * `From<$T>` for `Scalar` (always non-nullable).
///
/// The reading conversions go through `PrimitiveScalar::typed_value`, which panics
/// when the scalar's primitive type is not `$T`.
macro_rules! primitive_scalar {
    ($T:ty) => {
        impl TryFrom<&Scalar> for $T {
            type Error = VortexError;

            fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
                match <Option<$T>>::try_from(value)? {
                    Some(present) => Ok(present),
                    None => Err(vortex_err!("Can't extract present value from null scalar")),
                }
            }
        }

        impl TryFrom<Scalar> for $T {
            type Error = VortexError;

            fn try_from(value: Scalar) -> Result<Self, Self::Error> {
                let borrowed = &value;
                <$T>::try_from(borrowed)
            }
        }

        impl TryFrom<&Scalar> for Option<$T> {
            type Error = VortexError;

            fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
                let primitive = PrimitiveScalar::try_from(value)?;
                Ok(primitive.typed_value::<$T>())
            }
        }

        impl TryFrom<Scalar> for Option<$T> {
            type Error = VortexError;

            fn try_from(value: Scalar) -> Result<Self, Self::Error> {
                let borrowed = &value;
                <Option<$T>>::try_from(borrowed)
            }
        }

        impl From<$T> for Scalar {
            fn from(value: $T) -> Self {
                Scalar::primitive(value, Nullability::NonNullable)
            }
        }
    };
}
303
// Generate the `Scalar` <-> native value conversions for each primitive type below.
primitive_scalar!(u8);
primitive_scalar!(u16);
primitive_scalar!(u32);
primitive_scalar!(u64);
primitive_scalar!(i8);
primitive_scalar!(i16);
primitive_scalar!(i32);
primitive_scalar!(i64);
primitive_scalar!(f16);
primitive_scalar!(f32);
primitive_scalar!(f64);
315
316/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics.
317impl TryFrom<&Scalar> for usize {
318    type Error = VortexError;
319
320    fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
321        let prim = PrimitiveScalar::try_from(value)?
322            .as_::<u64>()?
323            .ok_or_else(|| vortex_err!("cannot convert Null to usize"))?;
324        Ok(usize::try_from(prim)?)
325    }
326}
327
328/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics.
329impl From<usize> for Scalar {
330    fn from(value: usize) -> Self {
331        Scalar::primitive(value as u64, Nullability::NonNullable)
332    }
333}
334
/// Binary element-wise operations on two arrays or two scalars.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinaryNumericOperator {
    /// Element-wise addition: `left + right`.
    Add,
    /// Element-wise subtraction: `left - right`.
    Sub,
    /// Same as [BinaryNumericOperator::Sub] but with the parameters flipped: `right - left`.
    RSub,
    /// Element-wise multiplication: `left * right`.
    Mul,
    /// Element-wise division: `left / right`.
    Div,
    /// Same as [BinaryNumericOperator::Div] but with the parameters flipped: `right / left`.
    RDiv,
    // Missing from arrow-rs:
    // Min,
    // Max,
    // Pow,
}
355
356impl Display for BinaryNumericOperator {
357    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358        Debug::fmt(self, f)
359    }
360}
361
362impl BinaryNumericOperator {
363    pub fn swap(self) -> Self {
364        match self {
365            BinaryNumericOperator::Add => BinaryNumericOperator::Add,
366            BinaryNumericOperator::Sub => BinaryNumericOperator::RSub,
367            BinaryNumericOperator::RSub => BinaryNumericOperator::Sub,
368            BinaryNumericOperator::Mul => BinaryNumericOperator::Mul,
369            BinaryNumericOperator::Div => BinaryNumericOperator::RDiv,
370            BinaryNumericOperator::RDiv => BinaryNumericOperator::Div,
371        }
372    }
373}
374
375impl<'a> PrimitiveScalar<'a> {
376    /// Apply the (checked) operator to self and other using SQL-style null semantics.
377    ///
378    /// If the operation overflows, Ok(None) is returned.
379    ///
380    /// If the types are incompatible (ignoring nullability), an error is returned.
381    ///
382    /// If either value is null, the result is null.
383    pub fn checked_binary_numeric(
384        &self,
385        other: &PrimitiveScalar<'a>,
386        op: BinaryNumericOperator,
387    ) -> Option<PrimitiveScalar<'a>> {
388        if !self.dtype().eq_ignore_nullability(other.dtype()) {
389            vortex_panic!("types must match: {} {}", self.dtype(), other.dtype());
390        }
391        let result_dtype = if self.dtype().is_nullable() {
392            self.dtype()
393        } else {
394            other.dtype()
395        };
396        let ptype = self.ptype();
397
398        match_each_native_ptype!(
399            self.ptype(),
400            integral: |$P| {
401                self.checked_integeral_numeric_operator::<$P>(other, result_dtype, ptype, op)
402            }
403            floating_point: |$P| {
404                let lhs = self.typed_value::<$P>();
405                let rhs = other.typed_value::<$P>();
406                let value_or_null = match (lhs, rhs) {
407                    (_, None) | (None, _) => None,
408                    (Some(lhs), Some(rhs)) => match op {
409                        BinaryNumericOperator::Add => Some(lhs + rhs),
410                        BinaryNumericOperator::Sub => Some(lhs - rhs),
411                        BinaryNumericOperator::RSub => Some(rhs - lhs),
412                        BinaryNumericOperator::Mul => Some(lhs * rhs),
413                        BinaryNumericOperator::Div => Some(lhs / rhs),
414                        BinaryNumericOperator::RDiv => Some(rhs / lhs),
415                    }
416                };
417                Some(Self { dtype: result_dtype, ptype: ptype, pvalue: value_or_null.map(PValue::from) })
418            }
419        )
420    }
421
422    fn checked_integeral_numeric_operator<
423        P: NativePType
424            + TryFrom<PValue, Error = VortexError>
425            + CheckedSub
426            + CheckedAdd
427            + CheckedMul
428            + CheckedDiv,
429    >(
430        &self,
431        other: &PrimitiveScalar<'a>,
432        result_dtype: &'a DType,
433        ptype: PType,
434        op: BinaryNumericOperator,
435    ) -> Option<PrimitiveScalar<'a>>
436    where
437        PValue: From<P>,
438    {
439        let lhs = self.typed_value::<P>();
440        let rhs = other.typed_value::<P>();
441        let value_or_null_or_overflow = match (lhs, rhs) {
442            (_, None) | (None, _) => Some(None),
443            (Some(lhs), Some(rhs)) => match op {
444                BinaryNumericOperator::Add => lhs.checked_add(&rhs).map(Some),
445                BinaryNumericOperator::Sub => lhs.checked_sub(&rhs).map(Some),
446                BinaryNumericOperator::RSub => rhs.checked_sub(&lhs).map(Some),
447                BinaryNumericOperator::Mul => lhs.checked_mul(&rhs).map(Some),
448                BinaryNumericOperator::Div => lhs.checked_div(&rhs).map(Some),
449                BinaryNumericOperator::RDiv => rhs.checked_div(&lhs).map(Some),
450            },
451        };
452
453        value_or_null_or_overflow.map(|value_or_null| Self {
454            dtype: result_dtype,
455            ptype,
456            pvalue: value_or_null.map(PValue::from),
457        })
458    }
459}
460
#[cfg(test)]
mod tests {
    use num_traits::CheckedSub;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::{InnerScalarValue, PValue, PrimitiveScalar, ScalarValue};

    /// Wraps a primitive value in the `ScalarValue` form that `try_new` accepts.
    fn scalar_value(pvalue: PValue) -> ScalarValue {
        ScalarValue(InnerScalarValue::Primitive(pvalue))
    }

    #[test]
    fn test_integer_subtract() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let minuend = PrimitiveScalar::try_new(&dtype, &scalar_value(PValue::I32(5))).unwrap();
        let subtrahend = PrimitiveScalar::try_new(&dtype, &scalar_value(PValue::I32(4))).unwrap();

        // Checked path.
        let checked = minuend.checked_sub(&subtrahend).unwrap();
        assert_eq!(checked.as_::<i32>().unwrap().unwrap(), 1);

        // Operator path.
        let via_operator = minuend - subtrahend;
        assert_eq!(via_operator.as_::<i32>().unwrap().unwrap(), 1);
    }

    #[test]
    #[should_panic(expected = "PrimitiveScalar subtract: overflow or underflow")]
    #[allow(clippy::assertions_on_constants)]
    fn test_integer_subtract_overflow() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let smallest =
            PrimitiveScalar::try_new(&dtype, &scalar_value(PValue::I32(i32::MIN))).unwrap();
        let largest =
            PrimitiveScalar::try_new(&dtype, &scalar_value(PValue::I32(i32::MAX))).unwrap();

        // `i32::MIN - i32::MAX` does not fit in an i32, so the operator must panic.
        let _ = smallest - largest;
    }

    #[test]
    fn test_float_subtract() {
        let dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
        let minuend =
            PrimitiveScalar::try_new(&dtype, &scalar_value(PValue::F32(1.99f32))).unwrap();
        let subtrahend =
            PrimitiveScalar::try_new(&dtype, &scalar_value(PValue::F32(1.0f32))).unwrap();

        // Checked path.
        let checked = minuend.checked_sub(&subtrahend).unwrap();
        assert_eq!(checked.as_::<f32>().unwrap().unwrap(), 0.99f32);

        // Operator path.
        let via_operator = minuend - subtrahend;
        assert_eq!(via_operator.as_::<f32>().unwrap().unwrap(), 0.99f32);
    }
}