vortex_scalar/
primitive.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::type_name;
5use std::cmp::Ordering;
6use std::fmt::{Debug, Display, Formatter};
7use std::ops::{Add, Sub};
8
9use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive};
10use vortex_dtype::half::f16;
11use vortex_dtype::{DType, NativePType, Nullability, PType, match_each_native_ptype};
12use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_err, vortex_panic};
13
14use crate::pvalue::{CoercePValue, PValue};
15use crate::{InnerScalarValue, Scalar, ScalarValue};
16
17/// A scalar value representing a primitive type.
18///
19/// This type provides a view into a primitive scalar value of any primitive type
20/// (integers, floats) with various bit widths.
21#[derive(Debug, Clone, Copy, Hash)]
22pub struct PrimitiveScalar<'a> {
23    dtype: &'a DType,
24    ptype: PType,
25    pvalue: Option<PValue>,
26}
27
28impl Display for PrimitiveScalar<'_> {
29    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
30        match self.pvalue {
31            None => write!(f, "null"),
32            Some(pv) => write!(f, "{pv}"),
33        }
34    }
35}
36
37impl PartialEq for PrimitiveScalar<'_> {
38    fn eq(&self, other: &Self) -> bool {
39        self.dtype.eq_ignore_nullability(other.dtype) && self.pvalue == other.pvalue
40    }
41}
42
43impl Eq for PrimitiveScalar<'_> {}
44
45/// Ord is not implemented since it's undefined for different PTypes
46impl PartialOrd for PrimitiveScalar<'_> {
47    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
48        if !self.dtype.eq_ignore_nullability(other.dtype) {
49            return None;
50        }
51        self.pvalue.partial_cmp(&other.pvalue)
52    }
53}
54
55impl<'a> PrimitiveScalar<'a> {
56    /// Creates a new primitive scalar from a data type and scalar value.
57    ///
58    /// # Errors
59    ///
60    /// Returns an error if the data type is not a primitive type or if the value
61    /// cannot be converted to the expected primitive type.
62    pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
63        let ptype = PType::try_from(dtype)?;
64
65        // Read the serialized value into the correct PValue.
66        // The serialized form may come back over the wire as e.g. any integer type.
67        let pvalue = match_each_native_ptype!(ptype, |T| {
68            value
69                .as_pvalue()?
70                .map(|pv| VortexResult::Ok(PValue::from(<T>::coerce(pv)?)))
71                .transpose()?
72        });
73
74        Ok(Self {
75            dtype,
76            ptype,
77            pvalue,
78        })
79    }
80
81    /// Returns the data type of this primitive scalar.
82    #[inline]
83    pub fn dtype(&self) -> &'a DType {
84        self.dtype
85    }
86
87    /// Returns the primitive type of this scalar.
88    #[inline]
89    pub fn ptype(&self) -> PType {
90        self.ptype
91    }
92
93    /// Returns the primitive value, or None if null.
94    #[inline]
95    pub fn pvalue(&self) -> Option<PValue> {
96        self.pvalue
97    }
98
99    /// Returns the value as a specific native primitive type.
100    ///
101    /// Returns `None` if the scalar is null, otherwise returns `Some(value)` where
102    /// value is the underlying primitive value cast to the requested type `T`.
103    ///
104    /// # Panics
105    ///
106    /// Panics if the primitive type of this scalar does not match the requested type.
107    pub fn typed_value<T: NativePType + TryFrom<PValue, Error = VortexError>>(&self) -> Option<T> {
108        assert_eq!(
109            self.ptype,
110            T::PTYPE,
111            "Attempting to read {} scalar as {}",
112            self.ptype,
113            T::PTYPE
114        );
115
116        self.pvalue.map(|pv| pv.as_primitive::<T>())
117    }
118
119    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
120        let ptype = PType::try_from(dtype)?;
121        let pvalue = self
122            .pvalue
123            .vortex_expect("nullness handled in Scalar::cast");
124        Ok(match_each_native_ptype!(ptype, |Q| {
125            Scalar::primitive(
126                pvalue
127                    .as_primitive_opt::<Q>()
128                    .ok_or_else(|| vortex_err!("Cannot cast {} to {}", self.ptype, dtype))?,
129                dtype.nullability(),
130            )
131        }))
132    }
133
134    /// Attempts to extract the primitive value as the given type.
135    ///
136    /// # Errors
137    ///
138    /// Panics if the cast fails due to overflow or type incompatibility. See also
139    /// `as_opt` for the checked version that does not panic.
140    ///
141    /// # Examples
142    ///
143    /// ```should_panic
144    /// # use vortex_dtype::{DType, PType};
145    /// # use vortex_scalar::Scalar;
146    /// let wide = Scalar::primitive(1000i32, false.into());
147    ///
148    /// // This succeeds
149    /// let narrow = wide.as_primitive().as_::<i16>();
150    /// assert_eq!(narrow, Some(1000i16));
151    ///
152    /// // This also succeeds
153    /// let null = Scalar::null(DType::Primitive(PType::I16, true.into()));
154    /// assert_eq!(null.as_primitive().as_::<i8>(), None);
155    ///
156    /// // This will panic, because 1000 does not fit in i8
157    /// wide.as_primitive().as_::<i8>();
158    /// ```
159    pub fn as_<T: FromPrimitiveOrF16>(&self) -> Option<T> {
160        self.as_opt::<T>().unwrap_or_else(|| {
161            vortex_panic!(
162                "cast {} to {}: value out of range",
163                self.ptype,
164                type_name::<T>()
165            )
166        })
167    }
168
169    /// Returns the inner value cast to the desired type.
170    ///
171    /// If the cast fails, `None` is returned. If the scalar represents a null, `Some(None)`
172    /// is returned. Otherwise, `Some(Some(T))` is returned for a successful non-null conversion.
173    ///
174    ///
175    /// # Examples
176    ///
177    /// ```
178    /// # use vortex_dtype::{DType, PType};
179    /// # use vortex_scalar::Scalar;
180    ///
181    /// // Non-null values
182    /// let scalar = Scalar::primitive(100i32, false.into());
183    /// let primitive = scalar.as_primitive();
184    /// assert_eq!(primitive.as_opt::<i8>(), Some(Some(100i8)));
185    ///
186    /// // Null value
187    /// let scalar = Scalar::null(DType::Primitive(PType::I32, true.into()));
188    /// let primitive = scalar.as_primitive();
189    /// assert_eq!(primitive.as_opt::<i8>(), Some(None));
190    ///
191    /// // Failed conversion: 1000 cannot fit in an i8
192    /// let scalar = Scalar::primitive(1000i32, false.into());
193    /// let primitive = scalar.as_primitive();
194    /// assert_eq!(primitive.as_opt::<i8>(), None);
195    /// ```
196    pub fn as_opt<T: FromPrimitiveOrF16>(&self) -> Option<Option<T>> {
197        if let Some(pv) = self.pvalue {
198            match pv {
199                PValue::U8(v) => T::from_u8(v),
200                PValue::U16(v) => T::from_u16(v),
201                PValue::U32(v) => T::from_u32(v),
202                PValue::U64(v) => T::from_u64(v),
203                PValue::I8(v) => T::from_i8(v),
204                PValue::I16(v) => T::from_i16(v),
205                PValue::I32(v) => T::from_i32(v),
206                PValue::I64(v) => T::from_i64(v),
207                PValue::F16(v) => T::from_f16(v),
208                PValue::F32(v) => T::from_f32(v),
209                PValue::F64(v) => T::from_f64(v),
210            }
211            .map(Some)
212        } else {
213            Some(None)
214        }
215    }
216}
217
218/// A trait for types that can be created from primitive values, including f16.
219///
220/// This extends the `FromPrimitive` trait to also support conversion from f16 values.
221pub trait FromPrimitiveOrF16: FromPrimitive {
222    /// Converts an f16 value to this type, returning None if the conversion fails.
223    fn from_f16(v: f16) -> Option<Self>;
224}
225
226macro_rules! from_primitive_or_f16_for_non_floating_point {
227    ($T:ty) => {
228        impl FromPrimitiveOrF16 for $T {
229            fn from_f16(_: f16) -> Option<Self> {
230                None
231            }
232        }
233    };
234}
235
236from_primitive_or_f16_for_non_floating_point!(usize);
237from_primitive_or_f16_for_non_floating_point!(u8);
238from_primitive_or_f16_for_non_floating_point!(u16);
239from_primitive_or_f16_for_non_floating_point!(u32);
240from_primitive_or_f16_for_non_floating_point!(u64);
241from_primitive_or_f16_for_non_floating_point!(i8);
242from_primitive_or_f16_for_non_floating_point!(i16);
243from_primitive_or_f16_for_non_floating_point!(i32);
244from_primitive_or_f16_for_non_floating_point!(i64);
245
246impl FromPrimitiveOrF16 for f16 {
247    fn from_f16(v: f16) -> Option<Self> {
248        Some(v)
249    }
250}
251
252impl FromPrimitiveOrF16 for f32 {
253    fn from_f16(v: f16) -> Option<Self> {
254        Some(v.to_f32())
255    }
256}
257
258impl FromPrimitiveOrF16 for f64 {
259    fn from_f16(v: f16) -> Option<Self> {
260        Some(v.to_f64())
261    }
262}
263
264impl<'a> TryFrom<&'a Scalar> for PrimitiveScalar<'a> {
265    type Error = VortexError;
266
267    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
268        Self::try_new(value.dtype(), value.value())
269    }
270}
271
272impl Sub for PrimitiveScalar<'_> {
273    type Output = Self;
274
275    fn sub(self, rhs: Self) -> Self::Output {
276        self.checked_sub(&rhs)
277            .vortex_expect("PrimitiveScalar subtract: overflow or underflow")
278    }
279}
280
281impl CheckedSub for PrimitiveScalar<'_> {
282    fn checked_sub(&self, rhs: &Self) -> Option<Self> {
283        self.checked_binary_numeric(rhs, NumericOperator::Sub)
284    }
285}
286
287impl Add for PrimitiveScalar<'_> {
288    type Output = Self;
289
290    fn add(self, rhs: Self) -> Self::Output {
291        self.checked_add(&rhs)
292            .vortex_expect("PrimitiveScalar add: overflow or underflow")
293    }
294}
295
296impl CheckedAdd for PrimitiveScalar<'_> {
297    fn checked_add(&self, rhs: &Self) -> Option<Self> {
298        self.checked_binary_numeric(rhs, NumericOperator::Add)
299    }
300}
301
302impl Scalar {
303    /// Creates a new primitive scalar from a native value.
304    pub fn primitive<T: NativePType + Into<PValue>>(value: T, nullability: Nullability) -> Self {
305        Self::primitive_value(value.into(), T::PTYPE, nullability)
306    }
307
308    /// Create a PrimitiveScalar from a PValue.
309    ///
310    /// Note that an explicit PType is passed since any compatible PValue may be used as the value
311    /// for a primitive type.
312    pub fn primitive_value(value: PValue, ptype: PType, nullability: Nullability) -> Self {
313        Self::new(
314            DType::Primitive(ptype, nullability),
315            ScalarValue(InnerScalarValue::Primitive(value)),
316        )
317    }
318
319    /// Reinterprets the bytes of this scalar as a different primitive type.
320    ///
321    /// # Panics
322    ///
323    /// Panics if the scalar is not a primitive type or if the types have different byte widths.
324    pub fn reinterpret_cast(&self, ptype: PType) -> Self {
325        let primitive = PrimitiveScalar::try_from(self).unwrap_or_else(|e| {
326            vortex_panic!(
327                e,
328                "Failed to reinterpret cast {} to {}",
329                self.dtype(),
330                ptype
331            )
332        });
333        if primitive.ptype() == ptype {
334            return self.clone();
335        }
336
337        assert_eq!(
338            primitive.ptype().byte_width(),
339            ptype.byte_width(),
340            "can't reinterpret cast between integers of two different widths"
341        );
342
343        Scalar::new(
344            DType::Primitive(ptype, self.dtype().nullability()),
345            primitive
346                .pvalue
347                .map(|p| p.reinterpret_cast(ptype))
348                .map(|x| ScalarValue(InnerScalarValue::Primitive(x)))
349                .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)),
350        )
351    }
352}
353
354macro_rules! primitive_scalar {
355    ($T:ty) => {
356        impl TryFrom<&Scalar> for $T {
357            type Error = VortexError;
358
359            fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
360                <Option<$T>>::try_from(value)?
361                    .ok_or_else(|| vortex_err!("Can't extract present value from null scalar"))
362            }
363        }
364
365        impl TryFrom<Scalar> for $T {
366            type Error = VortexError;
367
368            fn try_from(value: Scalar) -> Result<Self, Self::Error> {
369                <$T>::try_from(&value)
370            }
371        }
372
373        impl TryFrom<&Scalar> for Option<$T> {
374            type Error = VortexError;
375
376            fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
377                Ok(PrimitiveScalar::try_from(value)?.typed_value::<$T>())
378            }
379        }
380
381        impl TryFrom<Scalar> for Option<$T> {
382            type Error = VortexError;
383
384            fn try_from(value: Scalar) -> Result<Self, Self::Error> {
385                <Option<$T>>::try_from(&value)
386            }
387        }
388
389        impl From<$T> for Scalar {
390            fn from(value: $T) -> Self {
391                Scalar::new(
392                    DType::Primitive(<$T>::PTYPE, Nullability::NonNullable),
393                    ScalarValue(InnerScalarValue::Primitive(value.into())),
394                )
395            }
396        }
397
398        impl From<$T> for ScalarValue {
399            fn from(value: $T) -> Self {
400                ScalarValue(InnerScalarValue::Primitive(value.into()))
401            }
402        }
403    };
404}
405
406primitive_scalar!(u8);
407primitive_scalar!(u16);
408primitive_scalar!(u32);
409primitive_scalar!(u64);
410primitive_scalar!(i8);
411primitive_scalar!(i16);
412primitive_scalar!(i32);
413primitive_scalar!(i64);
414primitive_scalar!(f16);
415primitive_scalar!(f32);
416primitive_scalar!(f64);
417
418/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics.
419impl TryFrom<&Scalar> for usize {
420    type Error = VortexError;
421
422    fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
423        let prim = PrimitiveScalar::try_from(value)?
424            .as_::<u64>()
425            .ok_or_else(|| vortex_err!("cannot convert Null to usize"))?;
426        Ok(usize::try_from(prim)?)
427    }
428}
429
430impl TryFrom<&Scalar> for Option<usize> {
431    type Error = VortexError;
432
433    fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
434        Ok(PrimitiveScalar::try_from(value)?
435            .as_::<u64>()
436            .map(usize::try_from)
437            .transpose()?)
438    }
439}
440
441/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics.
442impl From<usize> for Scalar {
443    fn from(value: usize) -> Self {
444        Scalar::primitive(value as u64, Nullability::NonNullable)
445    }
446}
447
448impl From<PValue> for ScalarValue {
449    fn from(value: PValue) -> Self {
450        ScalarValue(InnerScalarValue::Primitive(value))
451    }
452}
453
454/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics.
455impl From<usize> for ScalarValue {
456    fn from(value: usize) -> Self {
457        ScalarValue(InnerScalarValue::Primitive((value as u64).into()))
458    }
459}
460
/// Binary element-wise operations on two arrays or two scalars.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumericOperator {
    /// Element-wise addition of two arrays or of two scalars.
    ///
    /// Fails at runtime if the sum would overflow or underflow.
    Add,
    /// Element-wise subtraction of two arrays or of two scalars: `left - right`.
    Sub,
    /// Same as [NumericOperator::Sub] but with the parameters flipped: `right - left`.
    RSub,
    /// Element-wise multiplication of two arrays or of two scalars.
    Mul,
    /// Element-wise division of two arrays or of two scalars: `left / right`.
    Div,
    /// Same as [NumericOperator::Div] but with the parameters flipped: `right / left`.
    RDiv,
    // Missing from arrow-rs:
    // Min,
    // Max,
    // Pow,
}

/// Formats the operator by its variant name, exactly as its `Debug` output.
impl Display for NumericOperator {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        <Self as Debug>::fmt(self, f)
    }
}

impl NumericOperator {
    /// Returns the operator that yields the same result when its operands are exchanged.
    ///
    /// Commutative operators (`Add`, `Mul`) map to themselves; `Sub`/`RSub` and
    /// `Div`/`RDiv` map to each other.
    pub fn swap(self) -> Self {
        match self {
            Self::Add | Self::Mul => self,
            Self::Sub => Self::RSub,
            Self::RSub => Self::Sub,
            Self::Div => Self::RDiv,
            Self::RDiv => Self::Div,
        }
    }
}
503
504impl<'a> PrimitiveScalar<'a> {
505    /// Apply the (checked) operator to self and other using SQL-style null semantics.
506    ///
507    /// If the operation overflows, Ok(None) is returned.
508    ///
509    /// If the types are incompatible (ignoring nullability), an error is returned.
510    ///
511    /// If either value is null, the result is null.
512    pub fn checked_binary_numeric(
513        &self,
514        other: &PrimitiveScalar<'a>,
515        op: NumericOperator,
516    ) -> Option<PrimitiveScalar<'a>> {
517        if !self.dtype().eq_ignore_nullability(other.dtype()) {
518            vortex_panic!("types must match: {} {}", self.dtype(), other.dtype());
519        }
520        let result_dtype = if self.dtype().is_nullable() {
521            self.dtype()
522        } else {
523            other.dtype()
524        };
525        let ptype = self.ptype();
526
527        match_each_native_ptype!(
528            self.ptype(),
529            integral: |P| {
530                self.checked_integral_numeric_operator::<P>(other, result_dtype, ptype, op)
531            },
532            floating: |P| {
533                let lhs = self.typed_value::<P>();
534                let rhs = other.typed_value::<P>();
535                let value_or_null = match (lhs, rhs) {
536                    (_, None) | (None, _) => None,
537                    (Some(lhs), Some(rhs)) => match op {
538                        NumericOperator::Add => Some(lhs + rhs),
539                        NumericOperator::Sub => Some(lhs - rhs),
540                        NumericOperator::RSub => Some(rhs - lhs),
541                        NumericOperator::Mul => Some(lhs * rhs),
542                        NumericOperator::Div => Some(lhs / rhs),
543                        NumericOperator::RDiv => Some(rhs / lhs),
544                    }
545                };
546                Some(Self { dtype: result_dtype, ptype, pvalue: value_or_null.map(PValue::from) })
547            }
548        )
549    }
550
551    fn checked_integral_numeric_operator<
552        P: NativePType
553            + TryFrom<PValue, Error = VortexError>
554            + CheckedSub
555            + CheckedAdd
556            + CheckedMul
557            + CheckedDiv,
558    >(
559        &self,
560        other: &PrimitiveScalar<'a>,
561        result_dtype: &'a DType,
562        ptype: PType,
563        op: NumericOperator,
564    ) -> Option<PrimitiveScalar<'a>>
565    where
566        PValue: From<P>,
567    {
568        let lhs = self.typed_value::<P>();
569        let rhs = other.typed_value::<P>();
570        let value_or_null_or_overflow = match (lhs, rhs) {
571            (_, None) | (None, _) => Some(None),
572            (Some(lhs), Some(rhs)) => match op {
573                NumericOperator::Add => lhs.checked_add(&rhs).map(Some),
574                NumericOperator::Sub => lhs.checked_sub(&rhs).map(Some),
575                NumericOperator::RSub => rhs.checked_sub(&lhs).map(Some),
576                NumericOperator::Mul => lhs.checked_mul(&rhs).map(Some),
577                NumericOperator::Div => lhs.checked_div(&rhs).map(Some),
578                NumericOperator::RDiv => rhs.checked_div(&lhs).map(Some),
579            },
580        };
581
582        value_or_null_or_overflow.map(|value_or_null| Self {
583            dtype: result_dtype,
584            ptype,
585            pvalue: value_or_null.map(PValue::from),
586        })
587    }
588}
589
590#[cfg(test)]
591#[allow(clippy::cast_possible_truncation)]
592mod tests {
593    use num_traits::CheckedSub;
594    use rstest::rstest;
595    use vortex_dtype::{DType, Nullability, PType};
596
597    use crate::{InnerScalarValue, PValue, PrimitiveScalar, ScalarValue};
598
599    #[test]
600    fn test_integer_subtract() {
601        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
602        let p_scalar1 = PrimitiveScalar::try_new(
603            &dtype,
604            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(5))),
605        )
606        .unwrap();
607        let p_scalar2 = PrimitiveScalar::try_new(
608            &dtype,
609            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(4))),
610        )
611        .unwrap();
612        let pscalar_or_overflow = p_scalar1.checked_sub(&p_scalar2);
613        let value_or_null_or_type_error = pscalar_or_overflow.unwrap().as_::<i32>();
614        assert_eq!(value_or_null_or_type_error.unwrap(), 1);
615
616        assert_eq!((p_scalar1 - p_scalar2).as_::<i32>().unwrap(), 1);
617    }
618
619    #[test]
620    #[should_panic(expected = "PrimitiveScalar subtract: overflow or underflow")]
621    #[allow(clippy::assertions_on_constants)]
622    fn test_integer_subtract_overflow() {
623        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
624        let p_scalar1 = PrimitiveScalar::try_new(
625            &dtype,
626            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MIN))),
627        )
628        .unwrap();
629        let p_scalar2 = PrimitiveScalar::try_new(
630            &dtype,
631            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MAX))),
632        )
633        .unwrap();
634        let _ = p_scalar1 - p_scalar2;
635    }
636
637    #[test]
638    fn test_float_subtract() {
639        let dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
640        let p_scalar1 = PrimitiveScalar::try_new(
641            &dtype,
642            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(1.99f32))),
643        )
644        .unwrap();
645        let p_scalar2 = PrimitiveScalar::try_new(
646            &dtype,
647            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(1.0f32))),
648        )
649        .unwrap();
650        let pscalar_or_overflow = p_scalar1.checked_sub(&p_scalar2).unwrap();
651        let value_or_null_or_type_error = pscalar_or_overflow.as_::<f32>();
652        assert_eq!(value_or_null_or_type_error.unwrap(), 0.99f32);
653
654        assert_eq!((p_scalar1 - p_scalar2).as_::<f32>().unwrap(), 0.99f32);
655    }
656
657    #[test]
658    fn test_primitive_scalar_equality() {
659        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
660        let scalar1 = PrimitiveScalar::try_new(
661            &dtype,
662            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))),
663        )
664        .unwrap();
665        let scalar2 = PrimitiveScalar::try_new(
666            &dtype,
667            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))),
668        )
669        .unwrap();
670        let scalar3 = PrimitiveScalar::try_new(
671            &dtype,
672            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(43))),
673        )
674        .unwrap();
675
676        assert_eq!(scalar1, scalar2);
677        assert_ne!(scalar1, scalar3);
678    }
679
680    #[test]
681    fn test_primitive_scalar_partial_ord() {
682        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
683        let scalar1 = PrimitiveScalar::try_new(
684            &dtype,
685            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))),
686        )
687        .unwrap();
688        let scalar2 = PrimitiveScalar::try_new(
689            &dtype,
690            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))),
691        )
692        .unwrap();
693
694        assert!(scalar1 < scalar2);
695        assert!(scalar2 > scalar1);
696        assert_eq!(
697            scalar1.partial_cmp(&scalar1),
698            Some(std::cmp::Ordering::Equal)
699        );
700    }
701
702    #[test]
703    fn test_primitive_scalar_null_handling() {
704        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
705        let null_scalar =
706            PrimitiveScalar::try_new(&dtype, &ScalarValue(InnerScalarValue::Null)).unwrap();
707
708        assert_eq!(null_scalar.pvalue(), None);
709        assert_eq!(null_scalar.typed_value::<i32>(), None);
710    }
711
712    #[test]
713    fn test_typed_value_correct_type() {
714        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
715        let scalar = PrimitiveScalar::try_new(
716            &dtype,
717            &ScalarValue(InnerScalarValue::Primitive(PValue::F64(3.5))),
718        )
719        .unwrap();
720
721        assert_eq!(scalar.typed_value::<f64>(), Some(3.5));
722    }
723
724    #[test]
725    #[should_panic(expected = "Attempting to read")]
726    fn test_typed_value_wrong_type() {
727        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
728        let scalar = PrimitiveScalar::try_new(
729            &dtype,
730            &ScalarValue(InnerScalarValue::Primitive(PValue::F64(3.5))),
731        )
732        .unwrap();
733
734        let _ = scalar.typed_value::<i32>();
735    }
736
737    #[rstest]
738    #[case(PType::I8, 127i32, PType::I16, true)]
739    #[case(PType::I8, 127i32, PType::I32, true)]
740    #[case(PType::I8, 127i32, PType::I64, true)]
741    #[case(PType::U8, 255i32, PType::U16, true)]
742    #[case(PType::U8, 255i32, PType::U32, true)]
743    #[case(PType::I32, 42i32, PType::F32, true)]
744    #[case(PType::I32, 42i32, PType::F64, true)]
745    // Overflow cases
746    #[case(PType::I32, 300i32, PType::U8, false)]
747    #[case(PType::I32, -1i32, PType::U32, false)]
748    #[case(PType::I32, 256i32, PType::I8, false)]
749    #[case(PType::U16, 65535i32, PType::I8, false)]
750    fn test_primitive_cast(
751        #[case] source_type: PType,
752        #[case] source_value: i32,
753        #[case] target_type: PType,
754        #[case] should_succeed: bool,
755    ) {
756        let source_pvalue = match source_type {
757            PType::I8 => PValue::I8(source_value as i8),
758            PType::U8 => PValue::U8(source_value as u8),
759            PType::U16 => PValue::U16(source_value as u16),
760            PType::I32 => PValue::I32(source_value),
761            _ => unreachable!("Test case uses unexpected source type"),
762        };
763
764        let dtype = DType::Primitive(source_type, Nullability::NonNullable);
765        let scalar = PrimitiveScalar::try_new(
766            &dtype,
767            &ScalarValue(InnerScalarValue::Primitive(source_pvalue)),
768        )
769        .unwrap();
770
771        let target_dtype = DType::Primitive(target_type, Nullability::NonNullable);
772        let result = scalar.cast(&target_dtype);
773
774        if should_succeed {
775            assert!(
776                result.is_ok(),
777                "Cast from {:?} to {:?} should succeed",
778                source_type,
779                target_type
780            );
781        } else {
782            assert!(
783                result.is_err(),
784                "Cast from {:?} to {:?} should fail due to overflow",
785                source_type,
786                target_type
787            );
788        }
789    }
790
791    #[test]
792    fn test_as_conversion_success() {
793        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
794        let scalar = PrimitiveScalar::try_new(
795            &dtype,
796            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))),
797        )
798        .unwrap();
799
800        assert_eq!(scalar.as_::<i64>(), Some(42i64));
801        assert_eq!(scalar.as_::<f64>(), Some(42.0));
802    }
803
804    #[test]
805    fn test_as_conversion_overflow() {
806        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
807        let scalar = PrimitiveScalar::try_new(
808            &dtype,
809            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(-1))),
810        )
811        .unwrap();
812
813        // Converting -1 to u32 should fail
814        let result = scalar.as_opt::<u32>();
815        assert!(result.is_none());
816    }
817
818    #[test]
819    fn test_as_conversion_null() {
820        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
821        let scalar =
822            PrimitiveScalar::try_new(&dtype, &ScalarValue(InnerScalarValue::Null)).unwrap();
823
824        assert_eq!(scalar.as_::<i32>(), None);
825        assert_eq!(scalar.as_::<f64>(), None);
826    }
827
828    #[test]
829    fn test_numeric_operator_swap() {
830        use crate::primitive::NumericOperator;
831
832        assert_eq!(NumericOperator::Add.swap(), NumericOperator::Add);
833        assert_eq!(NumericOperator::Sub.swap(), NumericOperator::RSub);
834        assert_eq!(NumericOperator::RSub.swap(), NumericOperator::Sub);
835        assert_eq!(NumericOperator::Mul.swap(), NumericOperator::Mul);
836        assert_eq!(NumericOperator::Div.swap(), NumericOperator::RDiv);
837        assert_eq!(NumericOperator::RDiv.swap(), NumericOperator::Div);
838    }
839
840    #[test]
841    fn test_checked_binary_numeric_add() {
842        use crate::primitive::NumericOperator;
843
844        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
845        let scalar1 = PrimitiveScalar::try_new(
846            &dtype,
847            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))),
848        )
849        .unwrap();
850        let scalar2 = PrimitiveScalar::try_new(
851            &dtype,
852            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))),
853        )
854        .unwrap();
855
856        let result = scalar1
857            .checked_binary_numeric(&scalar2, NumericOperator::Add)
858            .unwrap();
859        assert_eq!(result.typed_value::<i32>(), Some(30));
860    }
861
862    #[test]
863    fn test_checked_binary_numeric_overflow() {
864        use crate::primitive::NumericOperator;
865
866        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
867        let scalar1 = PrimitiveScalar::try_new(
868            &dtype,
869            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MAX))),
870        )
871        .unwrap();
872        let scalar2 = PrimitiveScalar::try_new(
873            &dtype,
874            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(1))),
875        )
876        .unwrap();
877
878        // Add should overflow and return None
879        let result = scalar1.checked_binary_numeric(&scalar2, NumericOperator::Add);
880        assert!(result.is_none());
881    }
882
883    #[test]
884    fn test_checked_binary_numeric_with_null() {
885        use crate::primitive::NumericOperator;
886
887        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
888        let scalar1 = PrimitiveScalar::try_new(
889            &dtype,
890            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))),
891        )
892        .unwrap();
893        let null_scalar =
894            PrimitiveScalar::try_new(&dtype, &ScalarValue(InnerScalarValue::Null)).unwrap();
895
896        // Operation with null should return null
897        let result = scalar1
898            .checked_binary_numeric(&null_scalar, NumericOperator::Add)
899            .unwrap();
900        assert_eq!(result.pvalue(), None);
901    }
902
903    #[test]
904    fn test_checked_binary_numeric_mul() {
905        use crate::primitive::NumericOperator;
906
907        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
908        let scalar1 = PrimitiveScalar::try_new(
909            &dtype,
910            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(5))),
911        )
912        .unwrap();
913        let scalar2 = PrimitiveScalar::try_new(
914            &dtype,
915            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(6))),
916        )
917        .unwrap();
918
919        let result = scalar1
920            .checked_binary_numeric(&scalar2, NumericOperator::Mul)
921            .unwrap();
922        assert_eq!(result.typed_value::<i32>(), Some(30));
923    }
924
925    #[test]
926    fn test_checked_binary_numeric_div() {
927        use crate::primitive::NumericOperator;
928
929        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
930        let scalar1 = PrimitiveScalar::try_new(
931            &dtype,
932            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))),
933        )
934        .unwrap();
935        let scalar2 = PrimitiveScalar::try_new(
936            &dtype,
937            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(4))),
938        )
939        .unwrap();
940
941        let result = scalar1
942            .checked_binary_numeric(&scalar2, NumericOperator::Div)
943            .unwrap();
944        assert_eq!(result.typed_value::<i32>(), Some(5));
945    }
946
947    #[test]
948    fn test_checked_binary_numeric_rdiv() {
949        use crate::primitive::NumericOperator;
950
951        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
952        let scalar1 = PrimitiveScalar::try_new(
953            &dtype,
954            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(4))),
955        )
956        .unwrap();
957        let scalar2 = PrimitiveScalar::try_new(
958            &dtype,
959            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))),
960        )
961        .unwrap();
962
963        // RDiv means right / left, so 20 / 4 = 5
964        let result = scalar1
965            .checked_binary_numeric(&scalar2, NumericOperator::RDiv)
966            .unwrap();
967        assert_eq!(result.typed_value::<i32>(), Some(5));
968    }
969
970    #[test]
971    fn test_checked_binary_numeric_div_by_zero() {
972        use crate::primitive::NumericOperator;
973
974        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
975        let scalar1 = PrimitiveScalar::try_new(
976            &dtype,
977            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))),
978        )
979        .unwrap();
980        let scalar2 = PrimitiveScalar::try_new(
981            &dtype,
982            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(0))),
983        )
984        .unwrap();
985
986        // Division by zero should return None for integers
987        let result = scalar1.checked_binary_numeric(&scalar2, NumericOperator::Div);
988        assert!(result.is_none());
989    }
990
991    #[test]
992    fn test_checked_binary_numeric_float_ops() {
993        use crate::primitive::NumericOperator;
994
995        let dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
996        let scalar1 = PrimitiveScalar::try_new(
997            &dtype,
998            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(10.0))),
999        )
1000        .unwrap();
1001        let scalar2 = PrimitiveScalar::try_new(
1002            &dtype,
1003            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(2.5))),
1004        )
1005        .unwrap();
1006
1007        // Test all operations with floats
1008        let add_result = scalar1
1009            .checked_binary_numeric(&scalar2, NumericOperator::Add)
1010            .unwrap();
1011        assert_eq!(add_result.typed_value::<f32>(), Some(12.5));
1012
1013        let sub_result = scalar1
1014            .checked_binary_numeric(&scalar2, NumericOperator::Sub)
1015            .unwrap();
1016        assert_eq!(sub_result.typed_value::<f32>(), Some(7.5));
1017
1018        let mul_result = scalar1
1019            .checked_binary_numeric(&scalar2, NumericOperator::Mul)
1020            .unwrap();
1021        assert_eq!(mul_result.typed_value::<f32>(), Some(25.0));
1022
1023        let div_result = scalar1
1024            .checked_binary_numeric(&scalar2, NumericOperator::Div)
1025            .unwrap();
1026        assert_eq!(div_result.typed_value::<f32>(), Some(4.0));
1027    }
1028
1029    #[test]
1030    fn test_from_primitive_or_f16() {
1031        use vortex_dtype::half::f16;
1032
1033        use crate::primitive::FromPrimitiveOrF16;
1034
1035        // Test f16 to f32 conversion
1036        let f16_val = f16::from_f32(3.5);
1037        assert!(f32::from_f16(f16_val).is_some());
1038
1039        // Test f16 to f64 conversion
1040        assert!(f64::from_f16(f16_val).is_some());
1041
1042        // Test f16 to integer conversion (should fail)
1043        assert!(i32::from_f16(f16_val).is_none());
1044        assert!(u32::from_f16(f16_val).is_none());
1045    }
1046
1047    #[test]
1048    fn test_partial_ord_different_types() {
1049        let dtype1 = DType::Primitive(PType::I32, Nullability::NonNullable);
1050        let dtype2 = DType::Primitive(PType::F32, Nullability::NonNullable);
1051
1052        let scalar1 = PrimitiveScalar::try_new(
1053            &dtype1,
1054            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))),
1055        )
1056        .unwrap();
1057        let scalar2 = PrimitiveScalar::try_new(
1058            &dtype2,
1059            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(10.0))),
1060        )
1061        .unwrap();
1062
1063        // Different types should not be comparable
1064        assert_eq!(scalar1.partial_cmp(&scalar2), None);
1065    }
1066
1067    #[test]
1068    fn test_scalar_value_from_usize() {
1069        let value: ScalarValue = 42usize.into();
1070        assert!(matches!(
1071            value.0,
1072            InnerScalarValue::Primitive(PValue::U64(42))
1073        ));
1074    }
1075
1076    #[test]
1077    fn test_getters() {
1078        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
1079        let scalar = PrimitiveScalar::try_new(
1080            &dtype,
1081            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))),
1082        )
1083        .unwrap();
1084
1085        assert_eq!(scalar.dtype(), &dtype);
1086        assert_eq!(scalar.ptype(), PType::I32);
1087        assert_eq!(scalar.pvalue(), Some(PValue::I32(42)));
1088    }
1089}