vortex_scalar/
primitive.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::type_name;
5use std::cmp::Ordering;
6use std::fmt::{Debug, Display, Formatter};
7use std::ops::{Add, Sub};
8
9use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive};
10use vortex_dtype::half::f16;
11use vortex_dtype::{DType, NativePType, Nullability, PType, match_each_native_ptype};
12use vortex_error::{
13    VortexError, VortexExpect as _, VortexResult, VortexUnwrap, vortex_err, vortex_panic,
14};
15
16use crate::pvalue::{CoercePValue, PValue};
17use crate::{InnerScalarValue, Scalar, ScalarValue};
18
19/// A scalar value representing a primitive type.
20///
21/// This type provides a view into a primitive scalar value of any primitive type
22/// (integers, floats) with various bit widths.
23#[derive(Debug, Clone, Copy, Hash)]
24pub struct PrimitiveScalar<'a> {
25    dtype: &'a DType,
26    ptype: PType,
27    pvalue: Option<PValue>,
28}
29
30impl Display for PrimitiveScalar<'_> {
31    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
32        match self.pvalue {
33            None => write!(f, "null"),
34            Some(pv) => write!(f, "{pv}"),
35        }
36    }
37}
38
39impl PartialEq for PrimitiveScalar<'_> {
40    fn eq(&self, other: &Self) -> bool {
41        self.dtype.eq_ignore_nullability(other.dtype) && self.pvalue == other.pvalue
42    }
43}
44
45impl Eq for PrimitiveScalar<'_> {}
46
47/// Ord is not implemented since it's undefined for different PTypes
48impl PartialOrd for PrimitiveScalar<'_> {
49    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
50        if !self.dtype.eq_ignore_nullability(other.dtype) {
51            return None;
52        }
53        self.pvalue.partial_cmp(&other.pvalue)
54    }
55}
56
57impl<'a> PrimitiveScalar<'a> {
58    /// Creates a new primitive scalar from a data type and scalar value.
59    ///
60    /// # Errors
61    ///
62    /// Returns an error if the data type is not a primitive type or if the value
63    /// cannot be converted to the expected primitive type.
64    pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
65        let ptype = PType::try_from(dtype)?;
66
67        // Read the serialized value into the correct PValue.
68        // The serialized form may come back over the wire as e.g. any integer type.
69        let pvalue = match_each_native_ptype!(ptype, |T| {
70            value
71                .as_pvalue()?
72                .map(|pv| VortexResult::Ok(PValue::from(<T>::coerce(pv)?)))
73                .transpose()?
74        });
75
76        Ok(Self {
77            dtype,
78            ptype,
79            pvalue,
80        })
81    }
82
83    /// Returns the data type of this primitive scalar.
84    #[inline]
85    pub fn dtype(&self) -> &'a DType {
86        self.dtype
87    }
88
89    /// Returns the primitive type of this scalar.
90    #[inline]
91    pub fn ptype(&self) -> PType {
92        self.ptype
93    }
94
95    /// Returns the primitive value, or None if null.
96    #[inline]
97    pub fn pvalue(&self) -> Option<PValue> {
98        self.pvalue
99    }
100
101    /// Returns the value as a specific native primitive type.
102    ///
103    /// Returns `None` if the scalar is null, otherwise returns `Some(value)` where
104    /// value is the underlying primitive value cast to the requested type `T`.
105    ///
106    /// # Panics
107    ///
108    /// Panics if the primitive type of this scalar does not match the requested type.
109    /// For example, attempting to read an i32 scalar as f32 will panic.
110    pub fn typed_value<T: NativePType + TryFrom<PValue, Error = VortexError>>(&self) -> Option<T> {
111        assert_eq!(
112            self.ptype,
113            T::PTYPE,
114            "Attempting to read {} scalar as {}",
115            self.ptype,
116            T::PTYPE
117        );
118
119        self.pvalue.map(|pv| pv.as_primitive::<T>().vortex_unwrap())
120    }
121
122    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
123        let ptype = PType::try_from(dtype)?;
124        let pvalue = self
125            .pvalue
126            .vortex_expect("nullness handled in Scalar::cast");
127        Ok(match_each_native_ptype!(ptype, |Q| {
128            Scalar::primitive(
129                pvalue.as_primitive::<Q>().map_err(|err| {
130                    vortex_err!("Cannot cast {} to {}: {}", self.ptype, dtype, err)
131                })?,
132                dtype.nullability(),
133            )
134        }))
135    }
136
137    /// Attempts to extract the primitive value as the given type.
138    ///
139    /// # Errors
140    ///
141    /// Returns an error if the cast fails due to overflow or type incompatibility.
142    pub fn as_<T: FromPrimitiveOrF16>(&self) -> VortexResult<Option<T>> {
143        match self.pvalue {
144            None => Ok(None),
145            Some(pv) => Ok(Some(match pv {
146                PValue::U8(v) => T::from_u8(v).ok_or_else(|| {
147                    vortex_err!("Cannot cast u8 to {}: value out of range", type_name::<T>())
148                }),
149                PValue::U16(v) => T::from_u16(v).ok_or_else(|| {
150                    vortex_err!(
151                        "Cannot cast u16 to {}: value out of range",
152                        type_name::<T>()
153                    )
154                }),
155                PValue::U32(v) => T::from_u32(v).ok_or_else(|| {
156                    vortex_err!(
157                        "Cannot cast u32 to {}: value out of range",
158                        type_name::<T>()
159                    )
160                }),
161                PValue::U64(v) => T::from_u64(v).ok_or_else(|| {
162                    vortex_err!(
163                        "Cannot cast u64 to {}: value out of range",
164                        type_name::<T>()
165                    )
166                }),
167                PValue::I8(v) => T::from_i8(v).ok_or_else(|| {
168                    vortex_err!("Cannot cast i8 to {}: value out of range", type_name::<T>())
169                }),
170                PValue::I16(v) => T::from_i16(v).ok_or_else(|| {
171                    vortex_err!(
172                        "Cannot cast i16 to {}: value out of range",
173                        type_name::<T>()
174                    )
175                }),
176                PValue::I32(v) => T::from_i32(v).ok_or_else(|| {
177                    vortex_err!(
178                        "Cannot cast i32 to {}: value out of range",
179                        type_name::<T>()
180                    )
181                }),
182                PValue::I64(v) => T::from_i64(v).ok_or_else(|| {
183                    vortex_err!(
184                        "Cannot cast i64 to {}: value out of range",
185                        type_name::<T>()
186                    )
187                }),
188                PValue::F16(v) => T::from_f16(v).ok_or_else(|| {
189                    vortex_err!(
190                        "Cannot cast f16 to {}: value out of range",
191                        type_name::<T>()
192                    )
193                }),
194                PValue::F32(v) => T::from_f32(v).ok_or_else(|| {
195                    vortex_err!(
196                        "Cannot cast f32 to {}: value out of range",
197                        type_name::<T>()
198                    )
199                }),
200                PValue::F64(v) => T::from_f64(v).ok_or_else(|| {
201                    vortex_err!(
202                        "Cannot cast f64 to {}: value out of range",
203                        type_name::<T>()
204                    )
205                }),
206            }?)),
207        }
208    }
209}
210
211/// A trait for types that can be created from primitive values, including f16.
212///
213/// This extends the `FromPrimitive` trait to also support conversion from f16 values.
214pub trait FromPrimitiveOrF16: FromPrimitive {
215    /// Converts an f16 value to this type, returning None if the conversion fails.
216    fn from_f16(v: f16) -> Option<Self>;
217}
218
219macro_rules! from_primitive_or_f16_for_non_floating_point {
220    ($T:ty) => {
221        impl FromPrimitiveOrF16 for $T {
222            fn from_f16(_: f16) -> Option<Self> {
223                None
224            }
225        }
226    };
227}
228
229from_primitive_or_f16_for_non_floating_point!(usize);
230from_primitive_or_f16_for_non_floating_point!(u8);
231from_primitive_or_f16_for_non_floating_point!(u16);
232from_primitive_or_f16_for_non_floating_point!(u32);
233from_primitive_or_f16_for_non_floating_point!(u64);
234from_primitive_or_f16_for_non_floating_point!(i8);
235from_primitive_or_f16_for_non_floating_point!(i16);
236from_primitive_or_f16_for_non_floating_point!(i32);
237from_primitive_or_f16_for_non_floating_point!(i64);
238
239impl FromPrimitiveOrF16 for f16 {
240    fn from_f16(v: f16) -> Option<Self> {
241        Some(v)
242    }
243}
244
245impl FromPrimitiveOrF16 for f32 {
246    fn from_f16(v: f16) -> Option<Self> {
247        Some(v.to_f32())
248    }
249}
250
251impl FromPrimitiveOrF16 for f64 {
252    fn from_f16(v: f16) -> Option<Self> {
253        Some(v.to_f64())
254    }
255}
256
257impl<'a> TryFrom<&'a Scalar> for PrimitiveScalar<'a> {
258    type Error = VortexError;
259
260    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
261        Self::try_new(value.dtype(), value.value())
262    }
263}
264
265impl Sub for PrimitiveScalar<'_> {
266    type Output = Self;
267
268    fn sub(self, rhs: Self) -> Self::Output {
269        self.checked_sub(&rhs)
270            .vortex_expect("PrimitiveScalar subtract: overflow or underflow")
271    }
272}
273
274impl CheckedSub for PrimitiveScalar<'_> {
275    fn checked_sub(&self, rhs: &Self) -> Option<Self> {
276        self.checked_binary_numeric(rhs, NumericOperator::Sub)
277    }
278}
279
280impl Add for PrimitiveScalar<'_> {
281    type Output = Self;
282
283    fn add(self, rhs: Self) -> Self::Output {
284        self.checked_add(&rhs)
285            .vortex_expect("PrimitiveScalar add: overflow or underflow")
286    }
287}
288
289impl CheckedAdd for PrimitiveScalar<'_> {
290    fn checked_add(&self, rhs: &Self) -> Option<Self> {
291        self.checked_binary_numeric(rhs, NumericOperator::Add)
292    }
293}
294
295impl Scalar {
296    /// Creates a new primitive scalar from a native value.
297    pub fn primitive<T: NativePType + Into<PValue>>(value: T, nullability: Nullability) -> Self {
298        Self::primitive_value(value.into(), T::PTYPE, nullability)
299    }
300
301    /// Create a PrimitiveScalar from a PValue.
302    ///
303    /// Note that an explicit PType is passed since any compatible PValue may be used as the value
304    /// for a primitive type.
305    pub fn primitive_value(value: PValue, ptype: PType, nullability: Nullability) -> Self {
306        Self::new(
307            DType::Primitive(ptype, nullability),
308            ScalarValue(InnerScalarValue::Primitive(value)),
309        )
310    }
311
312    /// Reinterprets the bytes of this scalar as a different primitive type.
313    ///
314    /// # Panics
315    ///
316    /// Panics if the scalar is not a primitive type or if the types have different byte widths.
317    pub fn reinterpret_cast(&self, ptype: PType) -> Self {
318        let primitive = PrimitiveScalar::try_from(self).unwrap_or_else(|e| {
319            vortex_panic!(e, "Failed to reinterpret cast {} to {}", self.dtype, ptype)
320        });
321        if primitive.ptype() == ptype {
322            return self.clone();
323        }
324
325        assert_eq!(
326            primitive.ptype().byte_width(),
327            ptype.byte_width(),
328            "can't reinterpret cast between integers of two different widths"
329        );
330
331        Scalar::new(
332            DType::Primitive(ptype, self.dtype.nullability()),
333            primitive
334                .pvalue
335                .map(|p| p.reinterpret_cast(ptype))
336                .map(|x| ScalarValue(InnerScalarValue::Primitive(x)))
337                .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)),
338        )
339    }
340}
341
342macro_rules! primitive_scalar {
343    ($T:ty) => {
344        impl TryFrom<&Scalar> for $T {
345            type Error = VortexError;
346
347            fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
348                <Option<$T>>::try_from(value)?
349                    .ok_or_else(|| vortex_err!("Can't extract present value from null scalar"))
350            }
351        }
352
353        impl TryFrom<Scalar> for $T {
354            type Error = VortexError;
355
356            fn try_from(value: Scalar) -> Result<Self, Self::Error> {
357                <$T>::try_from(&value)
358            }
359        }
360
361        impl TryFrom<&Scalar> for Option<$T> {
362            type Error = VortexError;
363
364            fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
365                Ok(PrimitiveScalar::try_from(value)?.typed_value::<$T>())
366            }
367        }
368
369        impl TryFrom<Scalar> for Option<$T> {
370            type Error = VortexError;
371
372            fn try_from(value: Scalar) -> Result<Self, Self::Error> {
373                <Option<$T>>::try_from(&value)
374            }
375        }
376
377        impl From<$T> for Scalar {
378            fn from(value: $T) -> Self {
379                Scalar::new(
380                    DType::Primitive(<$T>::PTYPE, Nullability::NonNullable),
381                    ScalarValue(InnerScalarValue::Primitive(value.into())),
382                )
383            }
384        }
385    };
386}
387
388primitive_scalar!(u8);
389primitive_scalar!(u16);
390primitive_scalar!(u32);
391primitive_scalar!(u64);
392primitive_scalar!(i8);
393primitive_scalar!(i16);
394primitive_scalar!(i32);
395primitive_scalar!(i64);
396primitive_scalar!(f16);
397primitive_scalar!(f32);
398primitive_scalar!(f64);
399
400/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics.
401impl TryFrom<&Scalar> for usize {
402    type Error = VortexError;
403
404    fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
405        let prim = PrimitiveScalar::try_from(value)?
406            .as_::<u64>()?
407            .ok_or_else(|| vortex_err!("cannot convert Null to usize"))?;
408        Ok(usize::try_from(prim)?)
409    }
410}
411
412impl TryFrom<&Scalar> for Option<usize> {
413    type Error = VortexError;
414
415    fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
416        Ok(PrimitiveScalar::try_from(value)?
417            .as_::<u64>()?
418            .map(usize::try_from)
419            .transpose()?)
420    }
421}
422
423/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics.
424impl From<usize> for Scalar {
425    fn from(value: usize) -> Self {
426        Scalar::primitive(value as u64, Nullability::NonNullable)
427    }
428}
429
430macro_rules! primitive_scalar {
431    ($T:ty) => {
432        impl From<$T> for ScalarValue {
433            fn from(value: $T) -> Self {
434                ScalarValue(InnerScalarValue::Primitive(value.into()))
435            }
436        }
437    };
438}
439
440primitive_scalar!(u8);
441primitive_scalar!(u16);
442primitive_scalar!(u32);
443primitive_scalar!(u64);
444primitive_scalar!(i8);
445primitive_scalar!(i16);
446primitive_scalar!(i32);
447primitive_scalar!(i64);
448primitive_scalar!(f16);
449primitive_scalar!(f32);
450primitive_scalar!(f64);
451
452impl From<PValue> for ScalarValue {
453    fn from(value: PValue) -> Self {
454        ScalarValue(InnerScalarValue::Primitive(value))
455    }
456}
457
458/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics.
459impl From<usize> for ScalarValue {
460    fn from(value: usize) -> Self {
461        ScalarValue(InnerScalarValue::Primitive((value as u64).into()))
462    }
463}
464
/// Arithmetic operators applied element-wise to a pair of arrays or to a pair of scalars.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumericOperator {
    /// Addition: `left + right`.
    ///
    /// Errs at runtime if the sum would overflow or underflow.
    Add,
    /// Subtraction: `left - right`.
    Sub,
    /// [NumericOperator::Sub] with the operands flipped: `right - left`.
    RSub,
    /// Multiplication: `left * right`.
    Mul,
    /// Division: `left / right`.
    Div,
    /// [NumericOperator::Div] with the operands flipped: `right / left`.
    RDiv,
    // Not yet supported here (also missing from arrow-rs): Min, Max, Pow.
}
487
488impl Display for NumericOperator {
489    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
490        Debug::fmt(self, f)
491    }
492}
493
494impl NumericOperator {
495    /// Returns the operator with swapped operands (e.g., Sub becomes RSub).
496    pub fn swap(self) -> Self {
497        match self {
498            NumericOperator::Add => NumericOperator::Add,
499            NumericOperator::Sub => NumericOperator::RSub,
500            NumericOperator::RSub => NumericOperator::Sub,
501            NumericOperator::Mul => NumericOperator::Mul,
502            NumericOperator::Div => NumericOperator::RDiv,
503            NumericOperator::RDiv => NumericOperator::Div,
504        }
505    }
506}
507
508impl<'a> PrimitiveScalar<'a> {
509    /// Apply the (checked) operator to self and other using SQL-style null semantics.
510    ///
511    /// If the operation overflows, Ok(None) is returned.
512    ///
513    /// If the types are incompatible (ignoring nullability), an error is returned.
514    ///
515    /// If either value is null, the result is null.
516    pub fn checked_binary_numeric(
517        &self,
518        other: &PrimitiveScalar<'a>,
519        op: NumericOperator,
520    ) -> Option<PrimitiveScalar<'a>> {
521        if !self.dtype().eq_ignore_nullability(other.dtype()) {
522            vortex_panic!("types must match: {} {}", self.dtype(), other.dtype());
523        }
524        let result_dtype = if self.dtype().is_nullable() {
525            self.dtype()
526        } else {
527            other.dtype()
528        };
529        let ptype = self.ptype();
530
531        match_each_native_ptype!(
532            self.ptype(),
533            integral: |P| {
534                self.checked_integeral_numeric_operator::<P>(other, result_dtype, ptype, op)
535            },
536            floating: |P| {
537                let lhs = self.typed_value::<P>();
538                let rhs = other.typed_value::<P>();
539                let value_or_null = match (lhs, rhs) {
540                    (_, None) | (None, _) => None,
541                    (Some(lhs), Some(rhs)) => match op {
542                        NumericOperator::Add => Some(lhs + rhs),
543                        NumericOperator::Sub => Some(lhs - rhs),
544                        NumericOperator::RSub => Some(rhs - lhs),
545                        NumericOperator::Mul => Some(lhs * rhs),
546                        NumericOperator::Div => Some(lhs / rhs),
547                        NumericOperator::RDiv => Some(rhs / lhs),
548                    }
549                };
550                Some(Self { dtype: result_dtype, ptype, pvalue: value_or_null.map(PValue::from) })
551            }
552        )
553    }
554
555    fn checked_integeral_numeric_operator<
556        P: NativePType
557            + TryFrom<PValue, Error = VortexError>
558            + CheckedSub
559            + CheckedAdd
560            + CheckedMul
561            + CheckedDiv,
562    >(
563        &self,
564        other: &PrimitiveScalar<'a>,
565        result_dtype: &'a DType,
566        ptype: PType,
567        op: NumericOperator,
568    ) -> Option<PrimitiveScalar<'a>>
569    where
570        PValue: From<P>,
571    {
572        let lhs = self.typed_value::<P>();
573        let rhs = other.typed_value::<P>();
574        let value_or_null_or_overflow = match (lhs, rhs) {
575            (_, None) | (None, _) => Some(None),
576            (Some(lhs), Some(rhs)) => match op {
577                NumericOperator::Add => lhs.checked_add(&rhs).map(Some),
578                NumericOperator::Sub => lhs.checked_sub(&rhs).map(Some),
579                NumericOperator::RSub => rhs.checked_sub(&lhs).map(Some),
580                NumericOperator::Mul => lhs.checked_mul(&rhs).map(Some),
581                NumericOperator::Div => lhs.checked_div(&rhs).map(Some),
582                NumericOperator::RDiv => rhs.checked_div(&lhs).map(Some),
583            },
584        };
585
586        value_or_null_or_overflow.map(|value_or_null| Self {
587            dtype: result_dtype,
588            ptype,
589            pvalue: value_or_null.map(PValue::from),
590        })
591    }
592}
593
594#[cfg(test)]
595#[allow(clippy::cast_possible_truncation)]
596mod tests {
597    use num_traits::CheckedSub;
598    use rstest::rstest;
599    use vortex_dtype::{DType, Nullability, PType};
600
601    use crate::{InnerScalarValue, PValue, PrimitiveScalar, ScalarValue};
602
603    #[test]
604    fn test_integer_subtract() {
605        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
606        let p_scalar1 = PrimitiveScalar::try_new(
607            &dtype,
608            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(5))),
609        )
610        .unwrap();
611        let p_scalar2 = PrimitiveScalar::try_new(
612            &dtype,
613            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(4))),
614        )
615        .unwrap();
616        let pscalar_or_overflow = p_scalar1.checked_sub(&p_scalar2);
617        let value_or_null_or_type_error = pscalar_or_overflow.unwrap().as_::<i32>();
618        assert_eq!(value_or_null_or_type_error.unwrap().unwrap(), 1);
619
620        assert_eq!((p_scalar1 - p_scalar2).as_::<i32>().unwrap().unwrap(), 1);
621    }
622
623    #[test]
624    #[should_panic(expected = "PrimitiveScalar subtract: overflow or underflow")]
625    #[allow(clippy::assertions_on_constants)]
626    fn test_integer_subtract_overflow() {
627        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
628        let p_scalar1 = PrimitiveScalar::try_new(
629            &dtype,
630            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MIN))),
631        )
632        .unwrap();
633        let p_scalar2 = PrimitiveScalar::try_new(
634            &dtype,
635            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MAX))),
636        )
637        .unwrap();
638        let _ = p_scalar1 - p_scalar2;
639    }
640
641    #[test]
642    fn test_float_subtract() {
643        let dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
644        let p_scalar1 = PrimitiveScalar::try_new(
645            &dtype,
646            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(1.99f32))),
647        )
648        .unwrap();
649        let p_scalar2 = PrimitiveScalar::try_new(
650            &dtype,
651            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(1.0f32))),
652        )
653        .unwrap();
654        let pscalar_or_overflow = p_scalar1.checked_sub(&p_scalar2).unwrap();
655        let value_or_null_or_type_error = pscalar_or_overflow.as_::<f32>();
656        assert_eq!(value_or_null_or_type_error.unwrap().unwrap(), 0.99f32);
657
658        assert_eq!(
659            (p_scalar1 - p_scalar2).as_::<f32>().unwrap().unwrap(),
660            0.99f32
661        );
662    }
663
664    #[test]
665    fn test_primitive_scalar_equality() {
666        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
667        let scalar1 = PrimitiveScalar::try_new(
668            &dtype,
669            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))),
670        )
671        .unwrap();
672        let scalar2 = PrimitiveScalar::try_new(
673            &dtype,
674            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))),
675        )
676        .unwrap();
677        let scalar3 = PrimitiveScalar::try_new(
678            &dtype,
679            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(43))),
680        )
681        .unwrap();
682
683        assert_eq!(scalar1, scalar2);
684        assert_ne!(scalar1, scalar3);
685    }
686
687    #[test]
688    fn test_primitive_scalar_partial_ord() {
689        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
690        let scalar1 = PrimitiveScalar::try_new(
691            &dtype,
692            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))),
693        )
694        .unwrap();
695        let scalar2 = PrimitiveScalar::try_new(
696            &dtype,
697            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))),
698        )
699        .unwrap();
700
701        assert!(scalar1 < scalar2);
702        assert!(scalar2 > scalar1);
703        assert_eq!(
704            scalar1.partial_cmp(&scalar1),
705            Some(std::cmp::Ordering::Equal)
706        );
707    }
708
709    #[test]
710    fn test_primitive_scalar_null_handling() {
711        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
712        let null_scalar =
713            PrimitiveScalar::try_new(&dtype, &ScalarValue(InnerScalarValue::Null)).unwrap();
714
715        assert_eq!(null_scalar.pvalue(), None);
716        assert_eq!(null_scalar.typed_value::<i32>(), None);
717    }
718
719    #[test]
720    fn test_typed_value_correct_type() {
721        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
722        let scalar = PrimitiveScalar::try_new(
723            &dtype,
724            &ScalarValue(InnerScalarValue::Primitive(PValue::F64(3.5))),
725        )
726        .unwrap();
727
728        assert_eq!(scalar.typed_value::<f64>(), Some(3.5));
729    }
730
731    #[test]
732    #[should_panic(expected = "Attempting to read")]
733    fn test_typed_value_wrong_type() {
734        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
735        let scalar = PrimitiveScalar::try_new(
736            &dtype,
737            &ScalarValue(InnerScalarValue::Primitive(PValue::F64(3.5))),
738        )
739        .unwrap();
740
741        let _ = scalar.typed_value::<i32>();
742    }
743
744    #[rstest]
745    #[case(PType::I8, 127i32, PType::I16, true)]
746    #[case(PType::I8, 127i32, PType::I32, true)]
747    #[case(PType::I8, 127i32, PType::I64, true)]
748    #[case(PType::U8, 255i32, PType::U16, true)]
749    #[case(PType::U8, 255i32, PType::U32, true)]
750    #[case(PType::I32, 42i32, PType::F32, true)]
751    #[case(PType::I32, 42i32, PType::F64, true)]
752    // Overflow cases
753    #[case(PType::I32, 300i32, PType::U8, false)]
754    #[case(PType::I32, -1i32, PType::U32, false)]
755    #[case(PType::I32, 256i32, PType::I8, false)]
756    #[case(PType::U16, 65535i32, PType::I8, false)]
757    fn test_primitive_cast(
758        #[case] source_type: PType,
759        #[case] source_value: i32,
760        #[case] target_type: PType,
761        #[case] should_succeed: bool,
762    ) {
763        let source_pvalue = match source_type {
764            PType::I8 => PValue::I8(source_value as i8),
765            PType::U8 => PValue::U8(source_value as u8),
766            PType::U16 => PValue::U16(source_value as u16),
767            PType::I32 => PValue::I32(source_value),
768            _ => unreachable!("Test case uses unexpected source type"),
769        };
770
771        let dtype = DType::Primitive(source_type, Nullability::NonNullable);
772        let scalar = PrimitiveScalar::try_new(
773            &dtype,
774            &ScalarValue(InnerScalarValue::Primitive(source_pvalue)),
775        )
776        .unwrap();
777
778        let target_dtype = DType::Primitive(target_type, Nullability::NonNullable);
779        let result = scalar.cast(&target_dtype);
780
781        if should_succeed {
782            assert!(
783                result.is_ok(),
784                "Cast from {:?} to {:?} should succeed",
785                source_type,
786                target_type
787            );
788        } else {
789            assert!(
790                result.is_err(),
791                "Cast from {:?} to {:?} should fail due to overflow",
792                source_type,
793                target_type
794            );
795        }
796    }
797
798    #[test]
799    fn test_as_conversion_success() {
800        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
801        let scalar = PrimitiveScalar::try_new(
802            &dtype,
803            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))),
804        )
805        .unwrap();
806
807        assert_eq!(scalar.as_::<i64>().unwrap(), Some(42i64));
808        assert_eq!(scalar.as_::<f64>().unwrap(), Some(42.0));
809    }
810
811    #[test]
812    fn test_as_conversion_overflow() {
813        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
814        let scalar = PrimitiveScalar::try_new(
815            &dtype,
816            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(-1))),
817        )
818        .unwrap();
819
820        // Converting -1 to u32 should fail
821        let result = scalar.as_::<u32>();
822        assert!(result.is_err());
823    }
824
825    #[test]
826    fn test_as_conversion_null() {
827        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
828        let scalar =
829            PrimitiveScalar::try_new(&dtype, &ScalarValue(InnerScalarValue::Null)).unwrap();
830
831        assert_eq!(scalar.as_::<i32>().unwrap(), None);
832        assert_eq!(scalar.as_::<f64>().unwrap(), None);
833    }
834
835    #[test]
836    fn test_numeric_operator_swap() {
837        use crate::primitive::NumericOperator;
838
839        assert_eq!(NumericOperator::Add.swap(), NumericOperator::Add);
840        assert_eq!(NumericOperator::Sub.swap(), NumericOperator::RSub);
841        assert_eq!(NumericOperator::RSub.swap(), NumericOperator::Sub);
842        assert_eq!(NumericOperator::Mul.swap(), NumericOperator::Mul);
843        assert_eq!(NumericOperator::Div.swap(), NumericOperator::RDiv);
844        assert_eq!(NumericOperator::RDiv.swap(), NumericOperator::Div);
845    }
846
847    #[test]
848    fn test_checked_binary_numeric_add() {
849        use crate::primitive::NumericOperator;
850
851        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
852        let scalar1 = PrimitiveScalar::try_new(
853            &dtype,
854            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))),
855        )
856        .unwrap();
857        let scalar2 = PrimitiveScalar::try_new(
858            &dtype,
859            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))),
860        )
861        .unwrap();
862
863        let result = scalar1
864            .checked_binary_numeric(&scalar2, NumericOperator::Add)
865            .unwrap();
866        assert_eq!(result.typed_value::<i32>(), Some(30));
867    }
868
869    #[test]
870    fn test_checked_binary_numeric_overflow() {
871        use crate::primitive::NumericOperator;
872
873        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
874        let scalar1 = PrimitiveScalar::try_new(
875            &dtype,
876            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MAX))),
877        )
878        .unwrap();
879        let scalar2 = PrimitiveScalar::try_new(
880            &dtype,
881            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(1))),
882        )
883        .unwrap();
884
885        // Add should overflow and return None
886        let result = scalar1.checked_binary_numeric(&scalar2, NumericOperator::Add);
887        assert!(result.is_none());
888    }
889
890    #[test]
891    fn test_checked_binary_numeric_with_null() {
892        use crate::primitive::NumericOperator;
893
894        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
895        let scalar1 = PrimitiveScalar::try_new(
896            &dtype,
897            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))),
898        )
899        .unwrap();
900        let null_scalar =
901            PrimitiveScalar::try_new(&dtype, &ScalarValue(InnerScalarValue::Null)).unwrap();
902
903        // Operation with null should return null
904        let result = scalar1
905            .checked_binary_numeric(&null_scalar, NumericOperator::Add)
906            .unwrap();
907        assert_eq!(result.pvalue(), None);
908    }
909
910    #[test]
911    fn test_checked_binary_numeric_mul() {
912        use crate::primitive::NumericOperator;
913
914        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
915        let scalar1 = PrimitiveScalar::try_new(
916            &dtype,
917            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(5))),
918        )
919        .unwrap();
920        let scalar2 = PrimitiveScalar::try_new(
921            &dtype,
922            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(6))),
923        )
924        .unwrap();
925
926        let result = scalar1
927            .checked_binary_numeric(&scalar2, NumericOperator::Mul)
928            .unwrap();
929        assert_eq!(result.typed_value::<i32>(), Some(30));
930    }
931
932    #[test]
933    fn test_checked_binary_numeric_div() {
934        use crate::primitive::NumericOperator;
935
936        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
937        let scalar1 = PrimitiveScalar::try_new(
938            &dtype,
939            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))),
940        )
941        .unwrap();
942        let scalar2 = PrimitiveScalar::try_new(
943            &dtype,
944            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(4))),
945        )
946        .unwrap();
947
948        let result = scalar1
949            .checked_binary_numeric(&scalar2, NumericOperator::Div)
950            .unwrap();
951        assert_eq!(result.typed_value::<i32>(), Some(5));
952    }
953
954    #[test]
955    fn test_checked_binary_numeric_rdiv() {
956        use crate::primitive::NumericOperator;
957
958        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
959        let scalar1 = PrimitiveScalar::try_new(
960            &dtype,
961            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(4))),
962        )
963        .unwrap();
964        let scalar2 = PrimitiveScalar::try_new(
965            &dtype,
966            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))),
967        )
968        .unwrap();
969
970        // RDiv means right / left, so 20 / 4 = 5
971        let result = scalar1
972            .checked_binary_numeric(&scalar2, NumericOperator::RDiv)
973            .unwrap();
974        assert_eq!(result.typed_value::<i32>(), Some(5));
975    }
976
977    #[test]
978    fn test_checked_binary_numeric_div_by_zero() {
979        use crate::primitive::NumericOperator;
980
981        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
982        let scalar1 = PrimitiveScalar::try_new(
983            &dtype,
984            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))),
985        )
986        .unwrap();
987        let scalar2 = PrimitiveScalar::try_new(
988            &dtype,
989            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(0))),
990        )
991        .unwrap();
992
993        // Division by zero should return None for integers
994        let result = scalar1.checked_binary_numeric(&scalar2, NumericOperator::Div);
995        assert!(result.is_none());
996    }
997
998    #[test]
999    fn test_checked_binary_numeric_float_ops() {
1000        use crate::primitive::NumericOperator;
1001
1002        let dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
1003        let scalar1 = PrimitiveScalar::try_new(
1004            &dtype,
1005            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(10.0))),
1006        )
1007        .unwrap();
1008        let scalar2 = PrimitiveScalar::try_new(
1009            &dtype,
1010            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(2.5))),
1011        )
1012        .unwrap();
1013
1014        // Test all operations with floats
1015        let add_result = scalar1
1016            .checked_binary_numeric(&scalar2, NumericOperator::Add)
1017            .unwrap();
1018        assert_eq!(add_result.typed_value::<f32>(), Some(12.5));
1019
1020        let sub_result = scalar1
1021            .checked_binary_numeric(&scalar2, NumericOperator::Sub)
1022            .unwrap();
1023        assert_eq!(sub_result.typed_value::<f32>(), Some(7.5));
1024
1025        let mul_result = scalar1
1026            .checked_binary_numeric(&scalar2, NumericOperator::Mul)
1027            .unwrap();
1028        assert_eq!(mul_result.typed_value::<f32>(), Some(25.0));
1029
1030        let div_result = scalar1
1031            .checked_binary_numeric(&scalar2, NumericOperator::Div)
1032            .unwrap();
1033        assert_eq!(div_result.typed_value::<f32>(), Some(4.0));
1034    }
1035
1036    #[test]
1037    fn test_from_primitive_or_f16() {
1038        use vortex_dtype::half::f16;
1039
1040        use crate::primitive::FromPrimitiveOrF16;
1041
1042        // Test f16 to f32 conversion
1043        let f16_val = f16::from_f32(3.5);
1044        assert!(f32::from_f16(f16_val).is_some());
1045
1046        // Test f16 to f64 conversion
1047        assert!(f64::from_f16(f16_val).is_some());
1048
1049        // Test f16 to integer conversion (should fail)
1050        assert!(i32::from_f16(f16_val).is_none());
1051        assert!(u32::from_f16(f16_val).is_none());
1052    }
1053
1054    #[test]
1055    fn test_partial_ord_different_types() {
1056        let dtype1 = DType::Primitive(PType::I32, Nullability::NonNullable);
1057        let dtype2 = DType::Primitive(PType::F32, Nullability::NonNullable);
1058
1059        let scalar1 = PrimitiveScalar::try_new(
1060            &dtype1,
1061            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))),
1062        )
1063        .unwrap();
1064        let scalar2 = PrimitiveScalar::try_new(
1065            &dtype2,
1066            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(10.0))),
1067        )
1068        .unwrap();
1069
1070        // Different types should not be comparable
1071        assert_eq!(scalar1.partial_cmp(&scalar2), None);
1072    }
1073
1074    #[test]
1075    fn test_scalar_value_from_usize() {
1076        let value: ScalarValue = 42usize.into();
1077        assert!(matches!(
1078            value.0,
1079            InnerScalarValue::Primitive(PValue::U64(42))
1080        ));
1081    }
1082
1083    #[test]
1084    fn test_getters() {
1085        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
1086        let scalar = PrimitiveScalar::try_new(
1087            &dtype,
1088            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))),
1089        )
1090        .unwrap();
1091
1092        assert_eq!(scalar.dtype(), &dtype);
1093        assert_eq!(scalar.ptype(), PType::I32);
1094        assert_eq!(scalar.pvalue(), Some(PValue::I32(42)));
1095    }
1096}