vortex_scalar/
primitive.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::type_name;
5use std::cmp::Ordering;
6use std::fmt::{Debug, Display, Formatter};
7use std::ops::{Add, Sub};
8
9use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive};
10use vortex_dtype::half::f16;
11use vortex_dtype::{DType, NativePType, Nullability, PType, match_each_native_ptype};
12use vortex_error::{
13    VortexError, VortexExpect as _, VortexResult, VortexUnwrap, vortex_err, vortex_panic,
14};
15
16use crate::pvalue::PValue;
17use crate::{InnerScalarValue, Scalar, ScalarValue};
18
19/// A scalar value representing a primitive type.
20///
21/// This type provides a view into a primitive scalar value of any primitive type
22/// (integers, floats) with various bit widths.
23#[derive(Debug, Clone, Copy, Hash)]
24pub struct PrimitiveScalar<'a> {
25    dtype: &'a DType,
26    ptype: PType,
27    pvalue: Option<PValue>,
28}
29
30impl Display for PrimitiveScalar<'_> {
31    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
32        match self.pvalue {
33            None => write!(f, "null"),
34            Some(pv) => write!(f, "{pv}"),
35        }
36    }
37}
38
39impl PartialEq for PrimitiveScalar<'_> {
40    fn eq(&self, other: &Self) -> bool {
41        self.dtype.eq_ignore_nullability(other.dtype) && self.pvalue == other.pvalue
42    }
43}
44
45impl Eq for PrimitiveScalar<'_> {}
46
47/// Ord is not implemented since it's undefined for different PTypes
48impl PartialOrd for PrimitiveScalar<'_> {
49    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
50        if !self.dtype.eq_ignore_nullability(other.dtype) {
51            return None;
52        }
53        self.pvalue.partial_cmp(&other.pvalue)
54    }
55}
56
57impl<'a> PrimitiveScalar<'a> {
58    /// Creates a new primitive scalar from a data type and scalar value.
59    ///
60    /// # Errors
61    ///
62    /// Returns an error if the data type is not a primitive type or if the value
63    /// cannot be converted to the expected primitive type.
64    pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
65        let ptype = PType::try_from(dtype)?;
66
67        // Read the serialized value into the correct PValue.
68        // The serialized form may come back over the wire as e.g. any integer type.
69        let pvalue = match_each_native_ptype!(ptype, |T| {
70            if let Some(pvalue) = value.as_pvalue()? {
71                Some(PValue::from(<T>::try_from(pvalue)?))
72            } else {
73                None
74            }
75        });
76
77        Ok(Self {
78            dtype,
79            ptype,
80            pvalue,
81        })
82    }
83
84    /// Returns the data type of this primitive scalar.
85    #[inline]
86    pub fn dtype(&self) -> &'a DType {
87        self.dtype
88    }
89
90    /// Returns the primitive type of this scalar.
91    #[inline]
92    pub fn ptype(&self) -> PType {
93        self.ptype
94    }
95
96    /// Returns the primitive value, or None if null.
97    #[inline]
98    pub fn pvalue(&self) -> Option<PValue> {
99        self.pvalue
100    }
101
102    /// Returns the value as a specific native primitive type.
103    ///
104    /// # Panics
105    ///
106    /// Panics if the primitive type of this scalar does not match the requested type.
107    pub fn typed_value<T: NativePType + TryFrom<PValue, Error = VortexError>>(&self) -> Option<T> {
108        assert_eq!(
109            self.ptype,
110            T::PTYPE,
111            "Attempting to read {} scalar as {}",
112            self.ptype,
113            T::PTYPE
114        );
115
116        self.pvalue.map(|pv| pv.as_primitive::<T>().vortex_unwrap())
117    }
118
119    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
120        let ptype = PType::try_from(dtype)?;
121        let pvalue = self
122            .pvalue
123            .vortex_expect("nullness handled in Scalar::cast");
124        Ok(match_each_native_ptype!(ptype, |Q| {
125            Scalar::primitive(
126                pvalue.as_primitive::<Q>().map_err(|err| {
127                    vortex_err!(
128                        "Can't cast {} scalar {} to {} (cause: {})",
129                        self.ptype,
130                        pvalue,
131                        dtype,
132                        err
133                    )
134                })?,
135                dtype.nullability(),
136            )
137        }))
138    }
139
140    /// Attempts to extract the primitive value as the given type.
141    ///
142    /// # Errors
143    ///
144    /// Returns an error if the cast fails due to overflow or type incompatibility.
145    pub fn as_<T: FromPrimitiveOrF16>(&self) -> VortexResult<Option<T>> {
146        match self.pvalue {
147            None => Ok(None),
148            Some(pv) => Ok(Some(match pv {
149                PValue::U8(v) => T::from_u8(v)
150                    .ok_or_else(|| vortex_err!("Failed to cast u8 to {}", type_name::<T>())),
151                PValue::U16(v) => T::from_u16(v)
152                    .ok_or_else(|| vortex_err!("Failed to cast u16 to {}", type_name::<T>())),
153                PValue::U32(v) => T::from_u32(v)
154                    .ok_or_else(|| vortex_err!("Failed to cast u32 to {}", type_name::<T>())),
155                PValue::U64(v) => T::from_u64(v)
156                    .ok_or_else(|| vortex_err!("Failed to cast u64 to {}", type_name::<T>())),
157                PValue::I8(v) => T::from_i8(v)
158                    .ok_or_else(|| vortex_err!("Failed to cast i8 to {}", type_name::<T>())),
159                PValue::I16(v) => T::from_i16(v)
160                    .ok_or_else(|| vortex_err!("Failed to cast i16 to {}", type_name::<T>())),
161                PValue::I32(v) => T::from_i32(v)
162                    .ok_or_else(|| vortex_err!("Failed to cast i32 to {}", type_name::<T>())),
163                PValue::I64(v) => T::from_i64(v)
164                    .ok_or_else(|| vortex_err!("Failed to cast i64 to {}", type_name::<T>())),
165                PValue::F16(v) => T::from_f16(v)
166                    .ok_or_else(|| vortex_err!("Failed to cast f16 to {}", type_name::<T>())),
167                PValue::F32(v) => T::from_f32(v)
168                    .ok_or_else(|| vortex_err!("Failed to cast f32 to {}", type_name::<T>())),
169                PValue::F64(v) => T::from_f64(v)
170                    .ok_or_else(|| vortex_err!("Failed to cast f64 to {}", type_name::<T>())),
171            }?)),
172        }
173    }
174}
175
176/// A trait for types that can be created from primitive values, including f16.
177///
178/// This extends the `FromPrimitive` trait to also support conversion from f16 values.
179pub trait FromPrimitiveOrF16: FromPrimitive {
180    /// Converts an f16 value to this type, returning None if the conversion fails.
181    fn from_f16(v: f16) -> Option<Self>;
182}
183
184macro_rules! from_primitive_or_f16_for_non_floating_point {
185    ($T:ty) => {
186        impl FromPrimitiveOrF16 for $T {
187            fn from_f16(_: f16) -> Option<Self> {
188                None
189            }
190        }
191    };
192}
193
194from_primitive_or_f16_for_non_floating_point!(usize);
195from_primitive_or_f16_for_non_floating_point!(u8);
196from_primitive_or_f16_for_non_floating_point!(u16);
197from_primitive_or_f16_for_non_floating_point!(u32);
198from_primitive_or_f16_for_non_floating_point!(u64);
199from_primitive_or_f16_for_non_floating_point!(i8);
200from_primitive_or_f16_for_non_floating_point!(i16);
201from_primitive_or_f16_for_non_floating_point!(i32);
202from_primitive_or_f16_for_non_floating_point!(i64);
203
204impl FromPrimitiveOrF16 for f16 {
205    fn from_f16(v: f16) -> Option<Self> {
206        Some(v)
207    }
208}
209
210impl FromPrimitiveOrF16 for f32 {
211    fn from_f16(v: f16) -> Option<Self> {
212        Some(v.to_f32())
213    }
214}
215
216impl FromPrimitiveOrF16 for f64 {
217    fn from_f16(v: f16) -> Option<Self> {
218        Some(v.to_f64())
219    }
220}
221
222impl<'a> TryFrom<&'a Scalar> for PrimitiveScalar<'a> {
223    type Error = VortexError;
224
225    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
226        Self::try_new(value.dtype(), value.value())
227    }
228}
229
230impl Sub for PrimitiveScalar<'_> {
231    type Output = Self;
232
233    fn sub(self, rhs: Self) -> Self::Output {
234        self.checked_sub(&rhs)
235            .vortex_expect("PrimitiveScalar subtract: overflow or underflow")
236    }
237}
238
239impl CheckedSub for PrimitiveScalar<'_> {
240    fn checked_sub(&self, rhs: &Self) -> Option<Self> {
241        self.checked_binary_numeric(rhs, NumericOperator::Sub)
242    }
243}
244
245impl Add for PrimitiveScalar<'_> {
246    type Output = Self;
247
248    fn add(self, rhs: Self) -> Self::Output {
249        self.checked_add(&rhs)
250            .vortex_expect("PrimitiveScalar add: overflow or underflow")
251    }
252}
253
254impl CheckedAdd for PrimitiveScalar<'_> {
255    fn checked_add(&self, rhs: &Self) -> Option<Self> {
256        self.checked_binary_numeric(rhs, NumericOperator::Add)
257    }
258}
259
260impl Scalar {
261    /// Creates a new primitive scalar from a native value.
262    pub fn primitive<T: NativePType + Into<PValue>>(value: T, nullability: Nullability) -> Self {
263        Self::primitive_value(value.into(), T::PTYPE, nullability)
264    }
265
266    /// Create a PrimitiveScalar from a PValue.
267    ///
268    /// Note that an explicit PType is passed since any compatible PValue may be used as the value
269    /// for a primitive type.
270    pub fn primitive_value(value: PValue, ptype: PType, nullability: Nullability) -> Self {
271        Self {
272            dtype: DType::Primitive(ptype, nullability),
273            value: ScalarValue(InnerScalarValue::Primitive(value)),
274        }
275    }
276
277    /// Reinterprets the bytes of this scalar as a different primitive type.
278    ///
279    /// # Panics
280    ///
281    /// Panics if the scalar is not a primitive type or if the types have different byte widths.
282    pub fn reinterpret_cast(&self, ptype: PType) -> Self {
283        let primitive = PrimitiveScalar::try_from(self).unwrap_or_else(|e| {
284            vortex_panic!(e, "Failed to reinterpret cast {} to {}", self.dtype, ptype)
285        });
286        if primitive.ptype() == ptype {
287            return self.clone();
288        }
289
290        assert_eq!(
291            primitive.ptype().byte_width(),
292            ptype.byte_width(),
293            "can't reinterpret cast between integers of two different widths"
294        );
295
296        Scalar::new(
297            DType::Primitive(ptype, self.dtype.nullability()),
298            primitive
299                .pvalue
300                .map(|p| p.reinterpret_cast(ptype))
301                .map(|x| ScalarValue(InnerScalarValue::Primitive(x)))
302                .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)),
303        )
304    }
305}
306
307macro_rules! primitive_scalar {
308    ($T:ty) => {
309        impl TryFrom<&Scalar> for $T {
310            type Error = VortexError;
311
312            fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
313                <Option<$T>>::try_from(value)?
314                    .ok_or_else(|| vortex_err!("Can't extract present value from null scalar"))
315            }
316        }
317
318        impl TryFrom<Scalar> for $T {
319            type Error = VortexError;
320
321            fn try_from(value: Scalar) -> Result<Self, Self::Error> {
322                <$T>::try_from(&value)
323            }
324        }
325
326        impl TryFrom<&Scalar> for Option<$T> {
327            type Error = VortexError;
328
329            fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
330                Ok(PrimitiveScalar::try_from(value)?.typed_value::<$T>())
331            }
332        }
333
334        impl TryFrom<Scalar> for Option<$T> {
335            type Error = VortexError;
336
337            fn try_from(value: Scalar) -> Result<Self, Self::Error> {
338                <Option<$T>>::try_from(&value)
339            }
340        }
341
342        impl From<$T> for Scalar {
343            fn from(value: $T) -> Self {
344                Scalar {
345                    dtype: DType::Primitive(<$T>::PTYPE, Nullability::NonNullable),
346                    value: ScalarValue(InnerScalarValue::Primitive(value.into())),
347                }
348            }
349        }
350    };
351}
352
353primitive_scalar!(u8);
354primitive_scalar!(u16);
355primitive_scalar!(u32);
356primitive_scalar!(u64);
357primitive_scalar!(i8);
358primitive_scalar!(i16);
359primitive_scalar!(i32);
360primitive_scalar!(i64);
361primitive_scalar!(f16);
362primitive_scalar!(f32);
363primitive_scalar!(f64);
364
365/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics.
366impl TryFrom<&Scalar> for usize {
367    type Error = VortexError;
368
369    fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
370        let prim = PrimitiveScalar::try_from(value)?
371            .as_::<u64>()?
372            .ok_or_else(|| vortex_err!("cannot convert Null to usize"))?;
373        Ok(usize::try_from(prim)?)
374    }
375}
376
377impl TryFrom<&Scalar> for Option<usize> {
378    type Error = VortexError;
379
380    fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
381        Ok(PrimitiveScalar::try_from(value)?
382            .as_::<u64>()?
383            .map(usize::try_from)
384            .transpose()?)
385    }
386}
387
388/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics.
389impl From<usize> for Scalar {
390    fn from(value: usize) -> Self {
391        Scalar::primitive(value as u64, Nullability::NonNullable)
392    }
393}
394
/// Binary element-wise operations on two arrays or two scalars.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumericOperator {
    /// Binary element-wise addition of two arrays or of two scalars.
    ///
    /// Errs at runtime if the sum would overflow or underflow.
    Add,
    /// Binary element-wise subtraction of two arrays or of two scalars.
    Sub,
    /// Same as [NumericOperator::Sub] but with the parameters flipped: `right - left`.
    RSub,
    /// Binary element-wise multiplication of two arrays or of two scalars.
    Mul,
    /// Binary element-wise division of two arrays or of two scalars.
    Div,
    /// Same as [NumericOperator::Div] but with the parameters flipped: `right / left`.
    RDiv,
    // Missing from arrow-rs:
    // Min,
    // Max,
    // Pow,
}
417
418impl Display for NumericOperator {
419    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
420        Debug::fmt(self, f)
421    }
422}
423
424impl NumericOperator {
425    /// Returns the operator with swapped operands (e.g., Sub becomes RSub).
426    pub fn swap(self) -> Self {
427        match self {
428            NumericOperator::Add => NumericOperator::Add,
429            NumericOperator::Sub => NumericOperator::RSub,
430            NumericOperator::RSub => NumericOperator::Sub,
431            NumericOperator::Mul => NumericOperator::Mul,
432            NumericOperator::Div => NumericOperator::RDiv,
433            NumericOperator::RDiv => NumericOperator::Div,
434        }
435    }
436}
437
438impl<'a> PrimitiveScalar<'a> {
439    /// Apply the (checked) operator to self and other using SQL-style null semantics.
440    ///
441    /// If the operation overflows, Ok(None) is returned.
442    ///
443    /// If the types are incompatible (ignoring nullability), an error is returned.
444    ///
445    /// If either value is null, the result is null.
446    pub fn checked_binary_numeric(
447        &self,
448        other: &PrimitiveScalar<'a>,
449        op: NumericOperator,
450    ) -> Option<PrimitiveScalar<'a>> {
451        if !self.dtype().eq_ignore_nullability(other.dtype()) {
452            vortex_panic!("types must match: {} {}", self.dtype(), other.dtype());
453        }
454        let result_dtype = if self.dtype().is_nullable() {
455            self.dtype()
456        } else {
457            other.dtype()
458        };
459        let ptype = self.ptype();
460
461        match_each_native_ptype!(
462            self.ptype(),
463            integral: |P| {
464                self.checked_integeral_numeric_operator::<P>(other, result_dtype, ptype, op)
465            },
466            floating: |P| {
467                let lhs = self.typed_value::<P>();
468                let rhs = other.typed_value::<P>();
469                let value_or_null = match (lhs, rhs) {
470                    (_, None) | (None, _) => None,
471                    (Some(lhs), Some(rhs)) => match op {
472                        NumericOperator::Add => Some(lhs + rhs),
473                        NumericOperator::Sub => Some(lhs - rhs),
474                        NumericOperator::RSub => Some(rhs - lhs),
475                        NumericOperator::Mul => Some(lhs * rhs),
476                        NumericOperator::Div => Some(lhs / rhs),
477                        NumericOperator::RDiv => Some(rhs / lhs),
478                    }
479                };
480                Some(Self { dtype: result_dtype, ptype, pvalue: value_or_null.map(PValue::from) })
481            }
482        )
483    }
484
485    fn checked_integeral_numeric_operator<
486        P: NativePType
487            + TryFrom<PValue, Error = VortexError>
488            + CheckedSub
489            + CheckedAdd
490            + CheckedMul
491            + CheckedDiv,
492    >(
493        &self,
494        other: &PrimitiveScalar<'a>,
495        result_dtype: &'a DType,
496        ptype: PType,
497        op: NumericOperator,
498    ) -> Option<PrimitiveScalar<'a>>
499    where
500        PValue: From<P>,
501    {
502        let lhs = self.typed_value::<P>();
503        let rhs = other.typed_value::<P>();
504        let value_or_null_or_overflow = match (lhs, rhs) {
505            (_, None) | (None, _) => Some(None),
506            (Some(lhs), Some(rhs)) => match op {
507                NumericOperator::Add => lhs.checked_add(&rhs).map(Some),
508                NumericOperator::Sub => lhs.checked_sub(&rhs).map(Some),
509                NumericOperator::RSub => rhs.checked_sub(&lhs).map(Some),
510                NumericOperator::Mul => lhs.checked_mul(&rhs).map(Some),
511                NumericOperator::Div => lhs.checked_div(&rhs).map(Some),
512                NumericOperator::RDiv => rhs.checked_div(&lhs).map(Some),
513            },
514        };
515
516        value_or_null_or_overflow.map(|value_or_null| Self {
517            dtype: result_dtype,
518            ptype,
519            pvalue: value_or_null.map(PValue::from),
520        })
521    }
522}
523
#[cfg(test)]
mod tests {
    use num_traits::CheckedSub;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::{InnerScalarValue, PValue, PrimitiveScalar, ScalarValue};

    /// Builds a primitive scalar view over `dtype` holding `value`.
    fn scalar_of(dtype: &DType, value: PValue) -> PrimitiveScalar<'_> {
        let raw = ScalarValue(InnerScalarValue::Primitive(value));
        PrimitiveScalar::try_new(dtype, &raw).unwrap()
    }

    #[test]
    fn test_integer_subtract() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let five = scalar_of(&dtype, PValue::I32(5));
        let four = scalar_of(&dtype, PValue::I32(4));

        // Checked path: Option (overflow) -> Result (type error) -> Option (null).
        let checked = five.checked_sub(&four).unwrap();
        assert_eq!(checked.as_::<i32>().unwrap().unwrap(), 1);

        // Operator path.
        let via_operator = five - four;
        assert_eq!(via_operator.as_::<i32>().unwrap().unwrap(), 1);
    }

    #[test]
    #[should_panic(expected = "PrimitiveScalar subtract: overflow or underflow")]
    #[allow(clippy::assertions_on_constants)]
    fn test_integer_subtract_overflow() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let smallest = scalar_of(&dtype, PValue::I32(i32::MIN));
        let largest = scalar_of(&dtype, PValue::I32(i32::MAX));
        let _ = smallest - largest;
    }

    #[test]
    fn test_float_subtract() {
        let dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
        let minuend = scalar_of(&dtype, PValue::F32(1.99f32));
        let subtrahend = scalar_of(&dtype, PValue::F32(1.0f32));

        let checked = minuend.checked_sub(&subtrahend).unwrap();
        assert_eq!(checked.as_::<f32>().unwrap().unwrap(), 0.99f32);

        let via_operator = minuend - subtrahend;
        assert_eq!(via_operator.as_::<f32>().unwrap().unwrap(), 0.99f32);
    }
}