// vortex_scalar/primitive.rs

1use std::any::type_name;
2use std::cmp::Ordering;
3use std::fmt::{Debug, Display, Formatter};
4use std::ops::{Add, Sub};
5
6use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive};
7use vortex_dtype::half::f16;
8use vortex_dtype::{DType, NativePType, Nullability, PType, match_each_native_ptype};
9use vortex_error::{
10    VortexError, VortexExpect as _, VortexResult, VortexUnwrap, vortex_err, vortex_panic,
11};
12
13use crate::pvalue::PValue;
14use crate::{InnerScalarValue, Scalar, ScalarValue};
15
16#[derive(Debug, Clone, Copy, Hash)]
17pub struct PrimitiveScalar<'a> {
18    dtype: &'a DType,
19    ptype: PType,
20    pvalue: Option<PValue>,
21}
22
23impl Display for PrimitiveScalar<'_> {
24    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
25        match self.pvalue {
26            None => write!(f, "null"),
27            Some(pv) => write!(f, "{pv}"),
28        }
29    }
30}
31
32impl PartialEq for PrimitiveScalar<'_> {
33    fn eq(&self, other: &Self) -> bool {
34        self.dtype.eq_ignore_nullability(other.dtype) && self.pvalue == other.pvalue
35    }
36}
37
38impl Eq for PrimitiveScalar<'_> {}
39
40/// Ord is not implemented since it's undefined for different PTypes
41impl PartialOrd for PrimitiveScalar<'_> {
42    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
43        if !self.dtype.eq_ignore_nullability(other.dtype) {
44            return None;
45        }
46        self.pvalue.partial_cmp(&other.pvalue)
47    }
48}
49
50impl<'a> PrimitiveScalar<'a> {
51    pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
52        let ptype = PType::try_from(dtype)?;
53
54        // Read the serialized value into the correct PValue.
55        // The serialized form may come back over the wire as e.g. any integer type.
56        let pvalue = match_each_native_ptype!(ptype, |T| {
57            if let Some(pvalue) = value.as_pvalue()? {
58                Some(PValue::from(<T>::try_from(pvalue)?))
59            } else {
60                None
61            }
62        });
63
64        Ok(Self {
65            dtype,
66            ptype,
67            pvalue,
68        })
69    }
70
71    #[inline]
72    pub fn dtype(&self) -> &'a DType {
73        self.dtype
74    }
75
76    #[inline]
77    pub fn ptype(&self) -> PType {
78        self.ptype
79    }
80
81    #[inline]
82    pub fn pvalue(&self) -> Option<PValue> {
83        self.pvalue
84    }
85
86    pub fn typed_value<T: NativePType + TryFrom<PValue, Error = VortexError>>(&self) -> Option<T> {
87        assert_eq!(
88            self.ptype,
89            T::PTYPE,
90            "Attempting to read {} scalar as {}",
91            self.ptype,
92            T::PTYPE
93        );
94
95        self.pvalue.map(|pv| pv.as_primitive::<T>().vortex_unwrap())
96    }
97
98    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
99        let ptype = PType::try_from(dtype)?;
100        let pvalue = self
101            .pvalue
102            .vortex_expect("nullness handled in Scalar::cast");
103        Ok(match_each_native_ptype!(ptype, |Q| {
104            Scalar::primitive(
105                pvalue.as_primitive::<Q>().map_err(|err| {
106                    vortex_err!(
107                        "Can't cast {} scalar {} to {} (cause: {})",
108                        self.ptype,
109                        pvalue,
110                        dtype,
111                        err
112                    )
113                })?,
114                dtype.nullability(),
115            )
116        }))
117    }
118
119    /// Attempt to extract the primitive value as the given type.
120    /// Fails on a bad cast.
121    pub fn as_<T: FromPrimitiveOrF16>(&self) -> VortexResult<Option<T>> {
122        match self.pvalue {
123            None => Ok(None),
124            Some(pv) => Ok(Some(match pv {
125                PValue::U8(v) => T::from_u8(v)
126                    .ok_or_else(|| vortex_err!("Failed to cast u8 to {}", type_name::<T>())),
127                PValue::U16(v) => T::from_u16(v)
128                    .ok_or_else(|| vortex_err!("Failed to cast u16 to {}", type_name::<T>())),
129                PValue::U32(v) => T::from_u32(v)
130                    .ok_or_else(|| vortex_err!("Failed to cast u32 to {}", type_name::<T>())),
131                PValue::U64(v) => T::from_u64(v)
132                    .ok_or_else(|| vortex_err!("Failed to cast u64 to {}", type_name::<T>())),
133                PValue::I8(v) => T::from_i8(v)
134                    .ok_or_else(|| vortex_err!("Failed to cast i8 to {}", type_name::<T>())),
135                PValue::I16(v) => T::from_i16(v)
136                    .ok_or_else(|| vortex_err!("Failed to cast i16 to {}", type_name::<T>())),
137                PValue::I32(v) => T::from_i32(v)
138                    .ok_or_else(|| vortex_err!("Failed to cast i32 to {}", type_name::<T>())),
139                PValue::I64(v) => T::from_i64(v)
140                    .ok_or_else(|| vortex_err!("Failed to cast i64 to {}", type_name::<T>())),
141                PValue::F16(v) => T::from_f16(v)
142                    .ok_or_else(|| vortex_err!("Failed to cast f16 to {}", type_name::<T>())),
143                PValue::F32(v) => T::from_f32(v)
144                    .ok_or_else(|| vortex_err!("Failed to cast f32 to {}", type_name::<T>())),
145                PValue::F64(v) => T::from_f64(v)
146                    .ok_or_else(|| vortex_err!("Failed to cast f64 to {}", type_name::<T>())),
147            }?)),
148        }
149    }
150}
151
152pub trait FromPrimitiveOrF16: FromPrimitive {
153    fn from_f16(v: f16) -> Option<Self>;
154}
155
156macro_rules! from_primitive_or_f16_for_non_floating_point {
157    ($T:ty) => {
158        impl FromPrimitiveOrF16 for $T {
159            fn from_f16(_: f16) -> Option<Self> {
160                None
161            }
162        }
163    };
164}
165
166from_primitive_or_f16_for_non_floating_point!(usize);
167from_primitive_or_f16_for_non_floating_point!(u8);
168from_primitive_or_f16_for_non_floating_point!(u16);
169from_primitive_or_f16_for_non_floating_point!(u32);
170from_primitive_or_f16_for_non_floating_point!(u64);
171from_primitive_or_f16_for_non_floating_point!(i8);
172from_primitive_or_f16_for_non_floating_point!(i16);
173from_primitive_or_f16_for_non_floating_point!(i32);
174from_primitive_or_f16_for_non_floating_point!(i64);
175
176impl FromPrimitiveOrF16 for f16 {
177    fn from_f16(v: f16) -> Option<Self> {
178        Some(v)
179    }
180}
181
182impl FromPrimitiveOrF16 for f32 {
183    fn from_f16(v: f16) -> Option<Self> {
184        Some(v.to_f32())
185    }
186}
187
188impl FromPrimitiveOrF16 for f64 {
189    fn from_f16(v: f16) -> Option<Self> {
190        Some(v.to_f64())
191    }
192}
193
194impl<'a> TryFrom<&'a Scalar> for PrimitiveScalar<'a> {
195    type Error = VortexError;
196
197    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
198        Self::try_new(value.dtype(), value.value())
199    }
200}
201
202impl Sub for PrimitiveScalar<'_> {
203    type Output = Self;
204
205    fn sub(self, rhs: Self) -> Self::Output {
206        self.checked_sub(&rhs)
207            .vortex_expect("PrimitiveScalar subtract: overflow or underflow")
208    }
209}
210
211impl CheckedSub for PrimitiveScalar<'_> {
212    fn checked_sub(&self, rhs: &Self) -> Option<Self> {
213        self.checked_binary_numeric(rhs, NumericOperator::Sub)
214    }
215}
216
217impl Add for PrimitiveScalar<'_> {
218    type Output = Self;
219
220    fn add(self, rhs: Self) -> Self::Output {
221        self.checked_add(&rhs)
222            .vortex_expect("PrimitiveScalar add: overflow or underflow")
223    }
224}
225
226impl CheckedAdd for PrimitiveScalar<'_> {
227    fn checked_add(&self, rhs: &Self) -> Option<Self> {
228        self.checked_binary_numeric(rhs, NumericOperator::Add)
229    }
230}
231
232impl Scalar {
233    pub fn primitive<T: NativePType + Into<PValue>>(value: T, nullability: Nullability) -> Self {
234        Self::primitive_value(value.into(), T::PTYPE, nullability)
235    }
236
237    /// Create a PrimitiveScalar from a PValue.
238    ///
239    /// Note that an explicit PType is passed since any compatible PValue may be used as the value
240    /// for a primitive type.
241    pub fn primitive_value(value: PValue, ptype: PType, nullability: Nullability) -> Self {
242        Self {
243            dtype: DType::Primitive(ptype, nullability),
244            value: ScalarValue(InnerScalarValue::Primitive(value)),
245        }
246    }
247
248    pub fn reinterpret_cast(&self, ptype: PType) -> Self {
249        let primitive = PrimitiveScalar::try_from(self).unwrap_or_else(|e| {
250            vortex_panic!(e, "Failed to reinterpret cast {} to {}", self.dtype, ptype)
251        });
252        if primitive.ptype() == ptype {
253            return self.clone();
254        }
255
256        assert_eq!(
257            primitive.ptype().byte_width(),
258            ptype.byte_width(),
259            "can't reinterpret cast between integers of two different widths"
260        );
261
262        Scalar::new(
263            DType::Primitive(ptype, self.dtype.nullability()),
264            primitive
265                .pvalue
266                .map(|p| p.reinterpret_cast(ptype))
267                .map(|x| ScalarValue(InnerScalarValue::Primitive(x)))
268                .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)),
269        )
270    }
271}
272
273macro_rules! primitive_scalar {
274    ($T:ty) => {
275        impl TryFrom<&Scalar> for $T {
276            type Error = VortexError;
277
278            fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
279                <Option<$T>>::try_from(value)?
280                    .ok_or_else(|| vortex_err!("Can't extract present value from null scalar"))
281            }
282        }
283
284        impl TryFrom<Scalar> for $T {
285            type Error = VortexError;
286
287            fn try_from(value: Scalar) -> Result<Self, Self::Error> {
288                <$T>::try_from(&value)
289            }
290        }
291
292        impl TryFrom<&Scalar> for Option<$T> {
293            type Error = VortexError;
294
295            fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
296                Ok(PrimitiveScalar::try_from(value)?.typed_value::<$T>())
297            }
298        }
299
300        impl TryFrom<Scalar> for Option<$T> {
301            type Error = VortexError;
302
303            fn try_from(value: Scalar) -> Result<Self, Self::Error> {
304                <Option<$T>>::try_from(&value)
305            }
306        }
307
308        impl From<$T> for Scalar {
309            fn from(value: $T) -> Self {
310                Scalar {
311                    dtype: DType::Primitive(<$T>::PTYPE, Nullability::NonNullable),
312                    value: ScalarValue(InnerScalarValue::Primitive(value.into())),
313                }
314            }
315        }
316    };
317}
318
319primitive_scalar!(u8);
320primitive_scalar!(u16);
321primitive_scalar!(u32);
322primitive_scalar!(u64);
323primitive_scalar!(i8);
324primitive_scalar!(i16);
325primitive_scalar!(i32);
326primitive_scalar!(i64);
327primitive_scalar!(f16);
328primitive_scalar!(f32);
329primitive_scalar!(f64);
330
331/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics.
332impl TryFrom<&Scalar> for usize {
333    type Error = VortexError;
334
335    fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
336        let prim = PrimitiveScalar::try_from(value)?
337            .as_::<u64>()?
338            .ok_or_else(|| vortex_err!("cannot convert Null to usize"))?;
339        Ok(usize::try_from(prim)?)
340    }
341}
342
343impl TryFrom<&Scalar> for Option<usize> {
344    type Error = VortexError;
345
346    fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
347        Ok(PrimitiveScalar::try_from(value)?
348            .as_::<u64>()?
349            .map(usize::try_from)
350            .transpose()?)
351    }
352}
353
354/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics.
355impl From<usize> for Scalar {
356    fn from(value: usize) -> Self {
357        Scalar::primitive(value as u64, Nullability::NonNullable)
358    }
359}
360
/// Binary element-wise operations on two arrays or two scalars.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumericOperator {
    /// Binary element-wise addition of two arrays or of two scalars.
    Add,
    /// Binary element-wise subtraction of two arrays or of two scalars.
    Sub,
    /// Same as [NumericOperator::Sub] but with the parameters flipped: `right - left`.
    RSub,
    /// Binary element-wise multiplication of two arrays or of two scalars.
    Mul,
    /// Binary element-wise division of two arrays or of two scalars.
    Div,
    /// Same as [NumericOperator::Div] but with the parameters flipped: `right / left`.
    RDiv,
    // Missing from arrow-rs:
    // Min,
    // Max,
    // Pow,
}
381
382impl Display for NumericOperator {
383    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
384        Debug::fmt(self, f)
385    }
386}
387
388impl NumericOperator {
389    pub fn swap(self) -> Self {
390        match self {
391            NumericOperator::Add => NumericOperator::Add,
392            NumericOperator::Sub => NumericOperator::RSub,
393            NumericOperator::RSub => NumericOperator::Sub,
394            NumericOperator::Mul => NumericOperator::Mul,
395            NumericOperator::Div => NumericOperator::RDiv,
396            NumericOperator::RDiv => NumericOperator::Div,
397        }
398    }
399}
400
401impl<'a> PrimitiveScalar<'a> {
402    /// Apply the (checked) operator to self and other using SQL-style null semantics.
403    ///
404    /// If the operation overflows, Ok(None) is returned.
405    ///
406    /// If the types are incompatible (ignoring nullability), an error is returned.
407    ///
408    /// If either value is null, the result is null.
409    pub fn checked_binary_numeric(
410        &self,
411        other: &PrimitiveScalar<'a>,
412        op: NumericOperator,
413    ) -> Option<PrimitiveScalar<'a>> {
414        if !self.dtype().eq_ignore_nullability(other.dtype()) {
415            vortex_panic!("types must match: {} {}", self.dtype(), other.dtype());
416        }
417        let result_dtype = if self.dtype().is_nullable() {
418            self.dtype()
419        } else {
420            other.dtype()
421        };
422        let ptype = self.ptype();
423
424        match_each_native_ptype!(
425            self.ptype(),
426            integral: |P| {
427                self.checked_integeral_numeric_operator::<P>(other, result_dtype, ptype, op)
428            },
429            floating_point: |P| {
430                let lhs = self.typed_value::<P>();
431                let rhs = other.typed_value::<P>();
432                let value_or_null = match (lhs, rhs) {
433                    (_, None) | (None, _) => None,
434                    (Some(lhs), Some(rhs)) => match op {
435                        NumericOperator::Add => Some(lhs + rhs),
436                        NumericOperator::Sub => Some(lhs - rhs),
437                        NumericOperator::RSub => Some(rhs - lhs),
438                        NumericOperator::Mul => Some(lhs * rhs),
439                        NumericOperator::Div => Some(lhs / rhs),
440                        NumericOperator::RDiv => Some(rhs / lhs),
441                    }
442                };
443                Some(Self { dtype: result_dtype, ptype, pvalue: value_or_null.map(PValue::from) })
444            }
445        )
446    }
447
448    fn checked_integeral_numeric_operator<
449        P: NativePType
450            + TryFrom<PValue, Error = VortexError>
451            + CheckedSub
452            + CheckedAdd
453            + CheckedMul
454            + CheckedDiv,
455    >(
456        &self,
457        other: &PrimitiveScalar<'a>,
458        result_dtype: &'a DType,
459        ptype: PType,
460        op: NumericOperator,
461    ) -> Option<PrimitiveScalar<'a>>
462    where
463        PValue: From<P>,
464    {
465        let lhs = self.typed_value::<P>();
466        let rhs = other.typed_value::<P>();
467        let value_or_null_or_overflow = match (lhs, rhs) {
468            (_, None) | (None, _) => Some(None),
469            (Some(lhs), Some(rhs)) => match op {
470                NumericOperator::Add => lhs.checked_add(&rhs).map(Some),
471                NumericOperator::Sub => lhs.checked_sub(&rhs).map(Some),
472                NumericOperator::RSub => rhs.checked_sub(&lhs).map(Some),
473                NumericOperator::Mul => lhs.checked_mul(&rhs).map(Some),
474                NumericOperator::Div => lhs.checked_div(&rhs).map(Some),
475                NumericOperator::RDiv => rhs.checked_div(&lhs).map(Some),
476            },
477        };
478
479        value_or_null_or_overflow.map(|value_or_null| Self {
480            dtype: result_dtype,
481            ptype,
482            pvalue: value_or_null.map(PValue::from),
483        })
484    }
485}
486
#[cfg(test)]
mod tests {
    use num_traits::CheckedSub;
    use vortex_dtype::{DType, Nullability, PType};

    use super::NumericOperator;
    use crate::{InnerScalarValue, PValue, PrimitiveScalar, ScalarValue};

    /// Builds a non-null `PrimitiveScalar` view of `value` with the given dtype.
    fn scalar(dtype: &DType, value: PValue) -> PrimitiveScalar<'_> {
        PrimitiveScalar::try_new(dtype, &ScalarValue(InnerScalarValue::Primitive(value))).unwrap()
    }

    #[test]
    fn test_integer_subtract() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let p_scalar1 = scalar(&dtype, PValue::I32(5));
        let p_scalar2 = scalar(&dtype, PValue::I32(4));
        let pscalar_or_overflow = p_scalar1.checked_sub(&p_scalar2);
        let value_or_null_or_type_error = pscalar_or_overflow.unwrap().as_::<i32>();
        assert_eq!(value_or_null_or_type_error.unwrap().unwrap(), 1);

        assert_eq!((p_scalar1 - p_scalar2).as_::<i32>().unwrap().unwrap(), 1);
    }

    #[test]
    #[should_panic(expected = "PrimitiveScalar subtract: overflow or underflow")]
    fn test_integer_subtract_overflow() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let p_scalar1 = scalar(&dtype, PValue::I32(i32::MIN));
        let p_scalar2 = scalar(&dtype, PValue::I32(i32::MAX));
        let _ = p_scalar1 - p_scalar2;
    }

    #[test]
    fn test_integer_reverse_subtract() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let lhs = scalar(&dtype, PValue::I32(5));
        let rhs = scalar(&dtype, PValue::I32(4));
        let result = lhs
            .checked_binary_numeric(&rhs, NumericOperator::RSub)
            .unwrap();
        assert_eq!(result.as_::<i32>().unwrap(), Some(-1));
    }

    #[test]
    fn test_integer_divide_by_zero_is_none() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let lhs = scalar(&dtype, PValue::I32(5));
        let zero = scalar(&dtype, PValue::I32(0));
        assert!(
            lhs.checked_binary_numeric(&zero, NumericOperator::Div)
                .is_none()
        );
    }

    #[test]
    fn test_float_subtract() {
        let dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
        let p_scalar1 = scalar(&dtype, PValue::F32(1.99f32));
        let p_scalar2 = scalar(&dtype, PValue::F32(1.0f32));
        let pscalar_or_overflow = p_scalar1.checked_sub(&p_scalar2).unwrap();
        let value_or_null_or_type_error = pscalar_or_overflow.as_::<f32>();
        assert_eq!(value_or_null_or_type_error.unwrap().unwrap(), 0.99f32);

        assert_eq!(
            (p_scalar1 - p_scalar2).as_::<f32>().unwrap().unwrap(),
            0.99f32
        );
    }
}