vortex_scalar/
primitive.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::type_name;
5use std::cmp::Ordering;
6use std::fmt::{Debug, Display, Formatter};
7use std::ops::{Add, Sub};
8
9use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub};
10use vortex_dtype::half::f16;
11use vortex_dtype::{
12    DType, FromPrimitiveOrF16, NativePType, Nullability, PType, match_each_native_ptype,
13};
14use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_err, vortex_panic};
15
16use crate::pvalue::{CoercePValue, PValue};
17use crate::{InnerScalarValue, Scalar, ScalarValue};
18
19/// A scalar value representing a primitive type.
20///
21/// This type provides a view into a primitive scalar value of any primitive type
22/// (integers, floats) with various bit widths.
23#[derive(Debug, Clone, Copy, Hash)]
24pub struct PrimitiveScalar<'a> {
25    dtype: &'a DType,
26    ptype: PType,
27    pvalue: Option<PValue>,
28}
29
30impl Display for PrimitiveScalar<'_> {
31    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
32        match self.pvalue {
33            None => write!(f, "null"),
34            Some(pv) => write!(f, "{pv}"),
35        }
36    }
37}
38
39impl PartialEq for PrimitiveScalar<'_> {
40    fn eq(&self, other: &Self) -> bool {
41        self.dtype.eq_ignore_nullability(other.dtype) && self.pvalue == other.pvalue
42    }
43}
44
// Marker impl: equality as defined by `PartialEq` (dtype ignoring nullability + value) is treated
// as a total equivalence relation.
impl Eq for PrimitiveScalar<'_> {}
46
47/// Ord is not implemented since it's undefined for different PTypes
48impl PartialOrd for PrimitiveScalar<'_> {
49    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
50        if !self.dtype.eq_ignore_nullability(other.dtype) {
51            return None;
52        }
53        self.pvalue.partial_cmp(&other.pvalue)
54    }
55}
56
57impl<'a> PrimitiveScalar<'a> {
58    /// Creates a new primitive scalar from a data type and scalar value.
59    ///
60    /// # Errors
61    ///
62    /// Returns an error if the data type is not a primitive type or if the value
63    /// cannot be converted to the expected primitive type.
64    pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
65        let ptype = PType::try_from(dtype)?;
66
67        // Read the serialized value into the correct PValue.
68        // The serialized form may come back over the wire as e.g. any integer type.
69        let pvalue = match_each_native_ptype!(ptype, |T| {
70            value
71                .as_pvalue()?
72                .map(|pv| VortexResult::Ok(PValue::from(<T>::coerce(pv)?)))
73                .transpose()?
74        });
75
76        Ok(Self {
77            dtype,
78            ptype,
79            pvalue,
80        })
81    }
82
83    /// Returns the data type of this primitive scalar.
84    #[inline]
85    pub fn dtype(&self) -> &'a DType {
86        self.dtype
87    }
88
89    /// Returns the primitive type of this scalar.
90    #[inline]
91    pub fn ptype(&self) -> PType {
92        self.ptype
93    }
94
95    /// Returns the primitive value, or None if null.
96    #[inline]
97    pub fn pvalue(&self) -> Option<PValue> {
98        self.pvalue
99    }
100
101    /// Returns the value as a specific native primitive type.
102    ///
103    /// Returns `None` if the scalar is null, otherwise returns `Some(value)` where
104    /// value is the underlying primitive value cast to the requested type `T`.
105    ///
106    /// # Panics
107    ///
108    /// Panics if the primitive type of this scalar does not match the requested type.
109    pub fn typed_value<T: NativePType>(&self) -> Option<T> {
110        assert_eq!(
111            self.ptype,
112            T::PTYPE,
113            "Attempting to read {} scalar as {}",
114            self.ptype,
115            T::PTYPE
116        );
117
118        self.pvalue.map(|pv| pv.cast::<T>())
119    }
120
121    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
122        let ptype = PType::try_from(dtype)?;
123        let pvalue = self
124            .pvalue
125            .vortex_expect("nullness handled in Scalar::cast");
126        Ok(match_each_native_ptype!(ptype, |Q| {
127            Scalar::primitive(
128                pvalue
129                    .cast_opt::<Q>()
130                    .ok_or_else(|| vortex_err!("Cannot cast {} to {}", self.ptype, dtype))?,
131                dtype.nullability(),
132            )
133        }))
134    }
135
136    /// Returns true if the scalar is nan.
137    pub fn is_nan(&self) -> bool {
138        self.pvalue.as_ref().is_some_and(|p| p.is_nan())
139    }
140
141    /// Attempts to extract the primitive value as the given type.
142    ///
143    /// # Errors
144    ///
145    /// Panics if the cast fails due to overflow or type incompatibility. See also
146    /// `as_opt` for the checked version that does not panic.
147    ///
148    /// # Examples
149    ///
150    /// ```should_panic
151    /// # use vortex_dtype::{DType, PType};
152    /// # use vortex_scalar::Scalar;
153    /// let wide = Scalar::primitive(1000i32, false.into());
154    ///
155    /// // This succeeds
156    /// let narrow = wide.as_primitive().as_::<i16>();
157    /// assert_eq!(narrow, Some(1000i16));
158    ///
159    /// // This also succeeds
160    /// let null = Scalar::null(DType::Primitive(PType::I16, true.into()));
161    /// assert_eq!(null.as_primitive().as_::<i8>(), None);
162    ///
163    /// // This will panic, because 1000 does not fit in i8
164    /// wide.as_primitive().as_::<i8>();
165    /// ```
166    pub fn as_<T: FromPrimitiveOrF16>(&self) -> Option<T> {
167        self.as_opt::<T>().unwrap_or_else(|| {
168            vortex_panic!(
169                "cast {} to {}: value out of range",
170                self.ptype,
171                type_name::<T>()
172            )
173        })
174    }
175
176    /// Returns the inner value cast to the desired type.
177    ///
178    /// If the cast fails, `None` is returned. If the scalar represents a null, `Some(None)`
179    /// is returned. Otherwise, `Some(Some(T))` is returned for a successful non-null conversion.
180    ///
181    ///
182    /// # Examples
183    ///
184    /// ```
185    /// # use vortex_dtype::{DType, PType};
186    /// # use vortex_scalar::Scalar;
187    ///
188    /// // Non-null values
189    /// let scalar = Scalar::primitive(100i32, false.into());
190    /// let primitive = scalar.as_primitive();
191    /// assert_eq!(primitive.as_opt::<i8>(), Some(Some(100i8)));
192    ///
193    /// // Null value
194    /// let scalar = Scalar::null(DType::Primitive(PType::I32, true.into()));
195    /// let primitive = scalar.as_primitive();
196    /// assert_eq!(primitive.as_opt::<i8>(), Some(None));
197    ///
198    /// // Failed conversion: 1000 cannot fit in an i8
199    /// let scalar = Scalar::primitive(1000i32, false.into());
200    /// let primitive = scalar.as_primitive();
201    /// assert_eq!(primitive.as_opt::<i8>(), None);
202    /// ```
203    pub fn as_opt<T: FromPrimitiveOrF16>(&self) -> Option<Option<T>> {
204        if let Some(pv) = self.pvalue {
205            match pv {
206                PValue::U8(v) => T::from_u8(v),
207                PValue::U16(v) => T::from_u16(v),
208                PValue::U32(v) => T::from_u32(v),
209                PValue::U64(v) => T::from_u64(v),
210                PValue::I8(v) => T::from_i8(v),
211                PValue::I16(v) => T::from_i16(v),
212                PValue::I32(v) => T::from_i32(v),
213                PValue::I64(v) => T::from_i64(v),
214                PValue::F16(v) => T::from_f16(v),
215                PValue::F32(v) => T::from_f32(v),
216                PValue::F64(v) => T::from_f64(v),
217            }
218            .map(Some)
219        } else {
220            Some(None)
221        }
222    }
223}
224
225impl<'a> TryFrom<&'a Scalar> for PrimitiveScalar<'a> {
226    type Error = VortexError;
227
228    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
229        Self::try_new(value.dtype(), value.value())
230    }
231}
232
233impl Sub for PrimitiveScalar<'_> {
234    type Output = Self;
235
236    fn sub(self, rhs: Self) -> Self::Output {
237        self.checked_sub(&rhs)
238            .vortex_expect("PrimitiveScalar subtract: overflow or underflow")
239    }
240}
241
242impl CheckedSub for PrimitiveScalar<'_> {
243    fn checked_sub(&self, rhs: &Self) -> Option<Self> {
244        self.checked_binary_numeric(rhs, NumericOperator::Sub)
245    }
246}
247
248impl Add for PrimitiveScalar<'_> {
249    type Output = Self;
250
251    fn add(self, rhs: Self) -> Self::Output {
252        self.checked_add(&rhs)
253            .vortex_expect("PrimitiveScalar add: overflow or underflow")
254    }
255}
256
257impl CheckedAdd for PrimitiveScalar<'_> {
258    fn checked_add(&self, rhs: &Self) -> Option<Self> {
259        self.checked_binary_numeric(rhs, NumericOperator::Add)
260    }
261}
262
263impl Scalar {
264    /// Creates a new primitive scalar from a native value.
265    pub fn primitive<T: NativePType + Into<PValue>>(value: T, nullability: Nullability) -> Self {
266        Self::primitive_value(value.into(), T::PTYPE, nullability)
267    }
268
269    /// Create a PrimitiveScalar from a PValue.
270    ///
271    /// Note that an explicit PType is passed since any compatible PValue may be used as the value
272    /// for a primitive type.
273    pub fn primitive_value(value: PValue, ptype: PType, nullability: Nullability) -> Self {
274        Self::new(
275            DType::Primitive(ptype, nullability),
276            ScalarValue(InnerScalarValue::Primitive(value)),
277        )
278    }
279
280    /// Reinterprets the bytes of this scalar as a different primitive type.
281    ///
282    /// # Panics
283    ///
284    /// Panics if the scalar is not a primitive type or if the types have different byte widths.
285    pub fn reinterpret_cast(&self, ptype: PType) -> Self {
286        let primitive = PrimitiveScalar::try_from(self).unwrap_or_else(|e| {
287            vortex_panic!(
288                e,
289                "Failed to reinterpret cast {} to {}",
290                self.dtype(),
291                ptype
292            )
293        });
294        if primitive.ptype() == ptype {
295            return self.clone();
296        }
297
298        assert_eq!(
299            primitive.ptype().byte_width(),
300            ptype.byte_width(),
301            "can't reinterpret cast between integers of two different widths"
302        );
303
304        Scalar::new(
305            DType::Primitive(ptype, self.dtype().nullability()),
306            primitive
307                .pvalue
308                .map(|p| p.reinterpret_cast(ptype))
309                .map(|x| ScalarValue(InnerScalarValue::Primitive(x)))
310                .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)),
311        )
312    }
313}
314
/// Generates the conversions between `Scalar`/`ScalarValue` and a native primitive type `$T`:
///
/// - `TryFrom<&Scalar>` / `TryFrom<Scalar>` for `Option<$T>`: a null scalar becomes `None`.
/// - `TryFrom<&Scalar>` / `TryFrom<Scalar>` for `$T`: a null scalar is an error.
/// - `From<$T>` for `ScalarValue`, and for `Scalar` (the latter is non-nullable).
///
/// Reading goes through `PrimitiveScalar::typed_value`, which panics if the scalar's primitive
/// type is not exactly `$T`.
macro_rules! primitive_scalar {
    ($T:ty) => {
        impl TryFrom<&Scalar> for Option<$T> {
            type Error = VortexError;

            fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
                let primitive = PrimitiveScalar::try_from(value)?;
                Ok(primitive.typed_value::<$T>())
            }
        }

        impl TryFrom<Scalar> for Option<$T> {
            type Error = VortexError;

            fn try_from(value: Scalar) -> Result<Self, Self::Error> {
                <Option<$T>>::try_from(&value)
            }
        }

        impl TryFrom<&Scalar> for $T {
            type Error = VortexError;

            fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
                match <Option<$T>>::try_from(value)? {
                    Some(present) => Ok(present),
                    None => Err(vortex_err!("Can't extract present value from null scalar")),
                }
            }
        }

        impl TryFrom<Scalar> for $T {
            type Error = VortexError;

            fn try_from(value: Scalar) -> Result<Self, Self::Error> {
                <$T>::try_from(&value)
            }
        }

        impl From<$T> for ScalarValue {
            fn from(value: $T) -> Self {
                let pvalue: PValue = value.into();
                ScalarValue(InnerScalarValue::Primitive(pvalue))
            }
        }

        impl From<$T> for Scalar {
            fn from(value: $T) -> Self {
                let dtype = DType::Primitive(<$T>::PTYPE, Nullability::NonNullable);
                Scalar::new(dtype, ScalarValue::from(value))
            }
        }
    };
}
366
367primitive_scalar!(u8);
368primitive_scalar!(u16);
369primitive_scalar!(u32);
370primitive_scalar!(u64);
371primitive_scalar!(i8);
372primitive_scalar!(i16);
373primitive_scalar!(i32);
374primitive_scalar!(i64);
375primitive_scalar!(f16);
376primitive_scalar!(f32);
377primitive_scalar!(f64);
378
379/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics.
380impl TryFrom<&Scalar> for usize {
381    type Error = VortexError;
382
383    fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
384        let prim = PrimitiveScalar::try_from(value)?
385            .as_::<u64>()
386            .ok_or_else(|| vortex_err!("cannot convert Null to usize"))?;
387        Ok(usize::try_from(prim)?)
388    }
389}
390
391impl TryFrom<&Scalar> for Option<usize> {
392    type Error = VortexError;
393
394    fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
395        Ok(PrimitiveScalar::try_from(value)?
396            .as_::<u64>()
397            .map(usize::try_from)
398            .transpose()?)
399    }
400}
401
402/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics.
403impl From<usize> for Scalar {
404    fn from(value: usize) -> Self {
405        Scalar::primitive(value as u64, Nullability::NonNullable)
406    }
407}
408
409impl From<PValue> for ScalarValue {
410    fn from(value: PValue) -> Self {
411        ScalarValue(InnerScalarValue::Primitive(value))
412    }
413}
414
415/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics.
416impl From<usize> for ScalarValue {
417    fn from(value: usize) -> Self {
418        ScalarValue(InnerScalarValue::Primitive((value as u64).into()))
419    }
420}
421
/// Binary element-wise operations on two arrays or two scalars.
///
/// The `R*` variants apply the operation with the operands flipped, which lets callers keep a
/// fixed operand order; see `NumericOperator::swap`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumericOperator {
    /// Element-wise addition: `left + right`.
    ///
    /// Errs at runtime if the sum would overflow or underflow.
    Add,
    /// Element-wise subtraction: `left - right`.
    Sub,
    /// Same as [NumericOperator::Sub] but with the parameters flipped: `right - left`.
    RSub,
    /// Element-wise multiplication: `left * right`.
    Mul,
    /// Element-wise division: `left / right`.
    Div,
    /// Same as [NumericOperator::Div] but with the parameters flipped: `right / left`.
    RDiv,
    // Missing from arrow-rs:
    // Min,
    // Max,
    // Pow,
}
444
445impl Display for NumericOperator {
446    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
447        Debug::fmt(self, f)
448    }
449}
450
451impl NumericOperator {
452    /// Returns the operator with swapped operands (e.g., Sub becomes RSub).
453    pub fn swap(self) -> Self {
454        match self {
455            NumericOperator::Add => NumericOperator::Add,
456            NumericOperator::Sub => NumericOperator::RSub,
457            NumericOperator::RSub => NumericOperator::Sub,
458            NumericOperator::Mul => NumericOperator::Mul,
459            NumericOperator::Div => NumericOperator::RDiv,
460            NumericOperator::RDiv => NumericOperator::Div,
461        }
462    }
463}
464
465impl<'a> PrimitiveScalar<'a> {
466    /// Apply the (checked) operator to self and other using SQL-style null semantics.
467    ///
468    /// If the operation overflows, Ok(None) is returned.
469    ///
470    /// If the types are incompatible (ignoring nullability), an error is returned.
471    ///
472    /// If either value is null, the result is null.
473    pub fn checked_binary_numeric(
474        &self,
475        other: &PrimitiveScalar<'a>,
476        op: NumericOperator,
477    ) -> Option<PrimitiveScalar<'a>> {
478        if !self.dtype().eq_ignore_nullability(other.dtype()) {
479            vortex_panic!("types must match: {} {}", self.dtype(), other.dtype());
480        }
481        let result_dtype = if self.dtype().is_nullable() {
482            self.dtype()
483        } else {
484            other.dtype()
485        };
486        let ptype = self.ptype();
487
488        match_each_native_ptype!(
489            self.ptype(),
490            integral: |P| {
491                self.checked_integral_numeric_operator::<P>(other, result_dtype, ptype, op)
492            },
493            floating: |P| {
494                let lhs = self.typed_value::<P>();
495                let rhs = other.typed_value::<P>();
496                let value_or_null = match (lhs, rhs) {
497                    (_, None) | (None, _) => None,
498                    (Some(lhs), Some(rhs)) => match op {
499                        NumericOperator::Add => Some(lhs + rhs),
500                        NumericOperator::Sub => Some(lhs - rhs),
501                        NumericOperator::RSub => Some(rhs - lhs),
502                        NumericOperator::Mul => Some(lhs * rhs),
503                        NumericOperator::Div => Some(lhs / rhs),
504                        NumericOperator::RDiv => Some(rhs / lhs),
505                    }
506                };
507                Some(Self { dtype: result_dtype, ptype, pvalue: value_or_null.map(PValue::from) })
508            }
509        )
510    }
511
512    fn checked_integral_numeric_operator<
513        P: NativePType
514            + TryFrom<PValue, Error = VortexError>
515            + CheckedSub
516            + CheckedAdd
517            + CheckedMul
518            + CheckedDiv,
519    >(
520        &self,
521        other: &PrimitiveScalar<'a>,
522        result_dtype: &'a DType,
523        ptype: PType,
524        op: NumericOperator,
525    ) -> Option<PrimitiveScalar<'a>>
526    where
527        PValue: From<P>,
528    {
529        let lhs = self.typed_value::<P>();
530        let rhs = other.typed_value::<P>();
531        let value_or_null_or_overflow = match (lhs, rhs) {
532            (_, None) | (None, _) => Some(None),
533            (Some(lhs), Some(rhs)) => match op {
534                NumericOperator::Add => lhs.checked_add(&rhs).map(Some),
535                NumericOperator::Sub => lhs.checked_sub(&rhs).map(Some),
536                NumericOperator::RSub => rhs.checked_sub(&lhs).map(Some),
537                NumericOperator::Mul => lhs.checked_mul(&rhs).map(Some),
538                NumericOperator::Div => lhs.checked_div(&rhs).map(Some),
539                NumericOperator::RDiv => rhs.checked_div(&lhs).map(Some),
540            },
541        };
542
543        value_or_null_or_overflow.map(|value_or_null| Self {
544            dtype: result_dtype,
545            ptype,
546            pvalue: value_or_null.map(PValue::from),
547        })
548    }
549}
550
551#[cfg(test)]
552mod tests {
553    use num_traits::CheckedSub;
554    use rstest::rstest;
555    use vortex_dtype::{DType, Nullability, PType};
556    use vortex_error::VortexExpect;
557
558    use crate::{InnerScalarValue, PValue, PrimitiveScalar, ScalarValue};
559
560    #[test]
561    fn test_integer_subtract() {
562        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
563        let p_scalar1 = PrimitiveScalar::try_new(
564            &dtype,
565            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(5))),
566        )
567        .unwrap();
568        let p_scalar2 = PrimitiveScalar::try_new(
569            &dtype,
570            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(4))),
571        )
572        .unwrap();
573        let pscalar_or_overflow = p_scalar1.checked_sub(&p_scalar2);
574        let value_or_null_or_type_error = pscalar_or_overflow.unwrap().as_::<i32>();
575        assert_eq!(value_or_null_or_type_error.unwrap(), 1);
576
577        assert_eq!((p_scalar1 - p_scalar2).as_::<i32>().unwrap(), 1);
578    }
579
580    #[test]
581    #[should_panic(expected = "PrimitiveScalar subtract: overflow or underflow")]
582    fn test_integer_subtract_overflow() {
583        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
584        let p_scalar1 = PrimitiveScalar::try_new(
585            &dtype,
586            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MIN))),
587        )
588        .unwrap();
589        let p_scalar2 = PrimitiveScalar::try_new(
590            &dtype,
591            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MAX))),
592        )
593        .unwrap();
594        let _ = p_scalar1 - p_scalar2;
595    }
596
597    #[test]
598    fn test_float_subtract() {
599        let dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
600        let p_scalar1 = PrimitiveScalar::try_new(
601            &dtype,
602            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(1.99f32))),
603        )
604        .unwrap();
605        let p_scalar2 = PrimitiveScalar::try_new(
606            &dtype,
607            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(1.0f32))),
608        )
609        .unwrap();
610        let pscalar_or_overflow = p_scalar1.checked_sub(&p_scalar2).unwrap();
611        let value_or_null_or_type_error = pscalar_or_overflow.as_::<f32>();
612        assert_eq!(value_or_null_or_type_error.unwrap(), 0.99f32);
613
614        assert_eq!((p_scalar1 - p_scalar2).as_::<f32>().unwrap(), 0.99f32);
615    }
616
617    #[test]
618    fn test_primitive_scalar_equality() {
619        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
620        let scalar1 = PrimitiveScalar::try_new(
621            &dtype,
622            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))),
623        )
624        .unwrap();
625        let scalar2 = PrimitiveScalar::try_new(
626            &dtype,
627            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))),
628        )
629        .unwrap();
630        let scalar3 = PrimitiveScalar::try_new(
631            &dtype,
632            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(43))),
633        )
634        .unwrap();
635
636        assert_eq!(scalar1, scalar2);
637        assert_ne!(scalar1, scalar3);
638    }
639
640    #[test]
641    fn test_primitive_scalar_partial_ord() {
642        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
643        let scalar1 = PrimitiveScalar::try_new(
644            &dtype,
645            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))),
646        )
647        .unwrap();
648        let scalar2 = PrimitiveScalar::try_new(
649            &dtype,
650            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))),
651        )
652        .unwrap();
653
654        assert!(scalar1 < scalar2);
655        assert!(scalar2 > scalar1);
656        assert_eq!(
657            scalar1.partial_cmp(&scalar1),
658            Some(std::cmp::Ordering::Equal)
659        );
660    }
661
662    #[test]
663    fn test_primitive_scalar_null_handling() {
664        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
665        let null_scalar =
666            PrimitiveScalar::try_new(&dtype, &ScalarValue(InnerScalarValue::Null)).unwrap();
667
668        assert_eq!(null_scalar.pvalue(), None);
669        assert_eq!(null_scalar.typed_value::<i32>(), None);
670    }
671
672    #[test]
673    fn test_typed_value_correct_type() {
674        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
675        let scalar = PrimitiveScalar::try_new(
676            &dtype,
677            &ScalarValue(InnerScalarValue::Primitive(PValue::F64(3.5))),
678        )
679        .unwrap();
680
681        assert_eq!(scalar.typed_value::<f64>(), Some(3.5));
682    }
683
684    #[test]
685    #[should_panic(expected = "Attempting to read")]
686    fn test_typed_value_wrong_type() {
687        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
688        let scalar = PrimitiveScalar::try_new(
689            &dtype,
690            &ScalarValue(InnerScalarValue::Primitive(PValue::F64(3.5))),
691        )
692        .unwrap();
693
694        let _ = scalar.typed_value::<i32>();
695    }
696
697    #[rstest]
698    #[case(PType::I8, 127i32, PType::I16, true)]
699    #[case(PType::I8, 127i32, PType::I32, true)]
700    #[case(PType::I8, 127i32, PType::I64, true)]
701    #[case(PType::U8, 255i32, PType::U16, true)]
702    #[case(PType::U8, 255i32, PType::U32, true)]
703    #[case(PType::I32, 42i32, PType::F32, true)]
704    #[case(PType::I32, 42i32, PType::F64, true)]
705    // Overflow cases
706    #[case(PType::I32, 300i32, PType::U8, false)]
707    #[case(PType::I32, -1i32, PType::U32, false)]
708    #[case(PType::I32, 256i32, PType::I8, false)]
709    #[case(PType::U16, 65535i32, PType::I8, false)]
710    fn test_primitive_cast(
711        #[case] source_type: PType,
712        #[case] source_value: i32,
713        #[case] target_type: PType,
714        #[case] should_succeed: bool,
715    ) {
716        let source_pvalue = match source_type {
717            PType::I8 => PValue::I8(i8::try_from(source_value).vortex_expect("cannot cast")),
718            PType::U8 => PValue::U8(u8::try_from(source_value).vortex_expect("cannot cast")),
719            PType::U16 => PValue::U16(u16::try_from(source_value).vortex_expect("cannot cast")),
720            PType::I32 => PValue::I32(source_value),
721            _ => unreachable!("Test case uses unexpected source type"),
722        };
723
724        let dtype = DType::Primitive(source_type, Nullability::NonNullable);
725        let scalar = PrimitiveScalar::try_new(
726            &dtype,
727            &ScalarValue(InnerScalarValue::Primitive(source_pvalue)),
728        )
729        .unwrap();
730
731        let target_dtype = DType::Primitive(target_type, Nullability::NonNullable);
732        let result = scalar.cast(&target_dtype);
733
734        if should_succeed {
735            assert!(
736                result.is_ok(),
737                "Cast from {:?} to {:?} should succeed",
738                source_type,
739                target_type
740            );
741        } else {
742            assert!(
743                result.is_err(),
744                "Cast from {:?} to {:?} should fail due to overflow",
745                source_type,
746                target_type
747            );
748        }
749    }
750
751    #[test]
752    fn test_as_conversion_success() {
753        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
754        let scalar = PrimitiveScalar::try_new(
755            &dtype,
756            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))),
757        )
758        .unwrap();
759
760        assert_eq!(scalar.as_::<i64>(), Some(42i64));
761        assert_eq!(scalar.as_::<f64>(), Some(42.0));
762    }
763
764    #[test]
765    fn test_as_conversion_overflow() {
766        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
767        let scalar = PrimitiveScalar::try_new(
768            &dtype,
769            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(-1))),
770        )
771        .unwrap();
772
773        // Converting -1 to u32 should fail
774        let result = scalar.as_opt::<u32>();
775        assert!(result.is_none());
776    }
777
778    #[test]
779    fn test_as_conversion_null() {
780        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
781        let scalar =
782            PrimitiveScalar::try_new(&dtype, &ScalarValue(InnerScalarValue::Null)).unwrap();
783
784        assert_eq!(scalar.as_::<i32>(), None);
785        assert_eq!(scalar.as_::<f64>(), None);
786    }
787
788    #[test]
789    fn test_numeric_operator_swap() {
790        use crate::primitive::NumericOperator;
791
792        assert_eq!(NumericOperator::Add.swap(), NumericOperator::Add);
793        assert_eq!(NumericOperator::Sub.swap(), NumericOperator::RSub);
794        assert_eq!(NumericOperator::RSub.swap(), NumericOperator::Sub);
795        assert_eq!(NumericOperator::Mul.swap(), NumericOperator::Mul);
796        assert_eq!(NumericOperator::Div.swap(), NumericOperator::RDiv);
797        assert_eq!(NumericOperator::RDiv.swap(), NumericOperator::Div);
798    }
799
800    #[test]
801    fn test_checked_binary_numeric_add() {
802        use crate::primitive::NumericOperator;
803
804        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
805        let scalar1 = PrimitiveScalar::try_new(
806            &dtype,
807            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))),
808        )
809        .unwrap();
810        let scalar2 = PrimitiveScalar::try_new(
811            &dtype,
812            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))),
813        )
814        .unwrap();
815
816        let result = scalar1
817            .checked_binary_numeric(&scalar2, NumericOperator::Add)
818            .unwrap();
819        assert_eq!(result.typed_value::<i32>(), Some(30));
820    }
821
822    #[test]
823    fn test_checked_binary_numeric_overflow() {
824        use crate::primitive::NumericOperator;
825
826        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
827        let scalar1 = PrimitiveScalar::try_new(
828            &dtype,
829            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MAX))),
830        )
831        .unwrap();
832        let scalar2 = PrimitiveScalar::try_new(
833            &dtype,
834            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(1))),
835        )
836        .unwrap();
837
838        // Add should overflow and return None
839        let result = scalar1.checked_binary_numeric(&scalar2, NumericOperator::Add);
840        assert!(result.is_none());
841    }
842
843    #[test]
844    fn test_checked_binary_numeric_with_null() {
845        use crate::primitive::NumericOperator;
846
847        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
848        let scalar1 = PrimitiveScalar::try_new(
849            &dtype,
850            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))),
851        )
852        .unwrap();
853        let null_scalar =
854            PrimitiveScalar::try_new(&dtype, &ScalarValue(InnerScalarValue::Null)).unwrap();
855
856        // Operation with null should return null
857        let result = scalar1
858            .checked_binary_numeric(&null_scalar, NumericOperator::Add)
859            .unwrap();
860        assert_eq!(result.pvalue(), None);
861    }
862
863    #[test]
864    fn test_checked_binary_numeric_mul() {
865        use crate::primitive::NumericOperator;
866
867        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
868        let scalar1 = PrimitiveScalar::try_new(
869            &dtype,
870            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(5))),
871        )
872        .unwrap();
873        let scalar2 = PrimitiveScalar::try_new(
874            &dtype,
875            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(6))),
876        )
877        .unwrap();
878
879        let result = scalar1
880            .checked_binary_numeric(&scalar2, NumericOperator::Mul)
881            .unwrap();
882        assert_eq!(result.typed_value::<i32>(), Some(30));
883    }
884
885    #[test]
886    fn test_checked_binary_numeric_div() {
887        use crate::primitive::NumericOperator;
888
889        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
890        let scalar1 = PrimitiveScalar::try_new(
891            &dtype,
892            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))),
893        )
894        .unwrap();
895        let scalar2 = PrimitiveScalar::try_new(
896            &dtype,
897            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(4))),
898        )
899        .unwrap();
900
901        let result = scalar1
902            .checked_binary_numeric(&scalar2, NumericOperator::Div)
903            .unwrap();
904        assert_eq!(result.typed_value::<i32>(), Some(5));
905    }
906
907    #[test]
908    fn test_checked_binary_numeric_rdiv() {
909        use crate::primitive::NumericOperator;
910
911        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
912        let scalar1 = PrimitiveScalar::try_new(
913            &dtype,
914            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(4))),
915        )
916        .unwrap();
917        let scalar2 = PrimitiveScalar::try_new(
918            &dtype,
919            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))),
920        )
921        .unwrap();
922
923        // RDiv means right / left, so 20 / 4 = 5
924        let result = scalar1
925            .checked_binary_numeric(&scalar2, NumericOperator::RDiv)
926            .unwrap();
927        assert_eq!(result.typed_value::<i32>(), Some(5));
928    }
929
930    #[test]
931    fn test_checked_binary_numeric_div_by_zero() {
932        use crate::primitive::NumericOperator;
933
934        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
935        let scalar1 = PrimitiveScalar::try_new(
936            &dtype,
937            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))),
938        )
939        .unwrap();
940        let scalar2 = PrimitiveScalar::try_new(
941            &dtype,
942            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(0))),
943        )
944        .unwrap();
945
946        // Division by zero should return None for integers
947        let result = scalar1.checked_binary_numeric(&scalar2, NumericOperator::Div);
948        assert!(result.is_none());
949    }
950
951    #[test]
952    fn test_checked_binary_numeric_float_ops() {
953        use crate::primitive::NumericOperator;
954
955        let dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
956        let scalar1 = PrimitiveScalar::try_new(
957            &dtype,
958            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(10.0))),
959        )
960        .unwrap();
961        let scalar2 = PrimitiveScalar::try_new(
962            &dtype,
963            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(2.5))),
964        )
965        .unwrap();
966
967        // Test all operations with floats
968        let add_result = scalar1
969            .checked_binary_numeric(&scalar2, NumericOperator::Add)
970            .unwrap();
971        assert_eq!(add_result.typed_value::<f32>(), Some(12.5));
972
973        let sub_result = scalar1
974            .checked_binary_numeric(&scalar2, NumericOperator::Sub)
975            .unwrap();
976        assert_eq!(sub_result.typed_value::<f32>(), Some(7.5));
977
978        let mul_result = scalar1
979            .checked_binary_numeric(&scalar2, NumericOperator::Mul)
980            .unwrap();
981        assert_eq!(mul_result.typed_value::<f32>(), Some(25.0));
982
983        let div_result = scalar1
984            .checked_binary_numeric(&scalar2, NumericOperator::Div)
985            .unwrap();
986        assert_eq!(div_result.typed_value::<f32>(), Some(4.0));
987    }
988
989    #[test]
990    fn test_from_primitive_or_f16() {
991        use vortex_dtype::half::f16;
992
993        use crate::primitive::FromPrimitiveOrF16;
994
995        // Test f16 to f32 conversion
996        let f16_val = f16::from_f32(3.5);
997        assert!(f32::from_f16(f16_val).is_some());
998
999        // Test f16 to f64 conversion
1000        assert!(f64::from_f16(f16_val).is_some());
1001
1002        // Test PValue::F16(f16) to integer conversion (should fail)
1003        assert!(i32::try_from(PValue::from(f16_val)).is_err());
1004        assert!(u32::try_from(PValue::from(f16_val)).is_err());
1005    }
1006
1007    #[test]
1008    fn test_partial_ord_different_types() {
1009        let dtype1 = DType::Primitive(PType::I32, Nullability::NonNullable);
1010        let dtype2 = DType::Primitive(PType::F32, Nullability::NonNullable);
1011
1012        let scalar1 = PrimitiveScalar::try_new(
1013            &dtype1,
1014            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))),
1015        )
1016        .unwrap();
1017        let scalar2 = PrimitiveScalar::try_new(
1018            &dtype2,
1019            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(10.0))),
1020        )
1021        .unwrap();
1022
1023        // Different types should not be comparable
1024        assert_eq!(scalar1.partial_cmp(&scalar2), None);
1025    }
1026
1027    #[test]
1028    fn test_scalar_value_from_usize() {
1029        let value: ScalarValue = 42usize.into();
1030        assert!(matches!(
1031            value.0,
1032            InnerScalarValue::Primitive(PValue::U64(42))
1033        ));
1034    }
1035
1036    #[test]
1037    fn test_getters() {
1038        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
1039        let scalar = PrimitiveScalar::try_new(
1040            &dtype,
1041            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))),
1042        )
1043        .unwrap();
1044
1045        assert_eq!(scalar.dtype(), &dtype);
1046        assert_eq!(scalar.ptype(), PType::I32);
1047        assert_eq!(scalar.pvalue(), Some(PValue::I32(42)));
1048    }
1049}