vortex_scalar/
primitive.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::type_name;
5use std::cmp::Ordering;
6use std::fmt::Debug;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::ops::Add;
10use std::ops::Sub;
11
12use num_traits::CheckedAdd;
13use num_traits::CheckedDiv;
14use num_traits::CheckedMul;
15use num_traits::CheckedSub;
16use vortex_dtype::DType;
17use vortex_dtype::FromPrimitiveOrF16;
18use vortex_dtype::NativePType;
19use vortex_dtype::Nullability;
20use vortex_dtype::PType;
21use vortex_dtype::half::f16;
22use vortex_dtype::match_each_native_ptype;
23use vortex_error::VortexError;
24use vortex_error::VortexExpect;
25use vortex_error::VortexResult;
26use vortex_error::vortex_err;
27use vortex_error::vortex_panic;
28
29use crate::InnerScalarValue;
30use crate::Scalar;
31use crate::ScalarValue;
32use crate::pvalue::CoercePValue;
33use crate::pvalue::PValue;
34
35/// A scalar value representing a primitive type.
36///
37/// This type provides a view into a primitive scalar value of any primitive type
38/// (integers, floats) with various bit widths.
39#[derive(Debug, Clone, Copy, Hash)]
40pub struct PrimitiveScalar<'a> {
41    dtype: &'a DType,
42    ptype: PType,
43    pvalue: Option<PValue>,
44}
45
46impl Display for PrimitiveScalar<'_> {
47    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
48        match self.pvalue {
49            None => write!(f, "null"),
50            Some(pv) => write!(f, "{pv}"),
51        }
52    }
53}
54
55impl PartialEq for PrimitiveScalar<'_> {
56    fn eq(&self, other: &Self) -> bool {
57        self.dtype.eq_ignore_nullability(other.dtype) && self.pvalue == other.pvalue
58    }
59}
60
61impl Eq for PrimitiveScalar<'_> {}
62
63/// Ord is not implemented since it's undefined for different PTypes
64impl PartialOrd for PrimitiveScalar<'_> {
65    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
66        if !self.dtype.eq_ignore_nullability(other.dtype) {
67            return None;
68        }
69        self.pvalue.partial_cmp(&other.pvalue)
70    }
71}
72
73impl<'a> PrimitiveScalar<'a> {
74    /// Creates a new primitive scalar from a data type and scalar value.
75    ///
76    /// # Errors
77    ///
78    /// Returns an error if the data type is not a primitive type or if the value
79    /// cannot be converted to the expected primitive type.
80    pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
81        let ptype = PType::try_from(dtype)?;
82
83        // Read the serialized value into the correct PValue.
84        // The serialized form may come back over the wire as e.g. any integer type.
85        let pvalue = match_each_native_ptype!(ptype, |T| {
86            value
87                .as_pvalue()?
88                .map(|pv| VortexResult::Ok(PValue::from(<T>::coerce(pv)?)))
89                .transpose()?
90        });
91
92        Ok(Self {
93            dtype,
94            ptype,
95            pvalue,
96        })
97    }
98
99    /// Returns the data type of this primitive scalar.
100    #[inline]
101    pub fn dtype(&self) -> &'a DType {
102        self.dtype
103    }
104
105    /// Returns the primitive type of this scalar.
106    #[inline]
107    pub fn ptype(&self) -> PType {
108        self.ptype
109    }
110
111    /// Returns the primitive value, or None if null.
112    #[inline]
113    pub fn pvalue(&self) -> Option<PValue> {
114        self.pvalue
115    }
116
117    /// Returns the value as a specific native primitive type.
118    ///
119    /// Returns `None` if the scalar is null, otherwise returns `Some(value)` where
120    /// value is the underlying primitive value cast to the requested type `T`.
121    ///
122    /// # Panics
123    ///
124    /// Panics if the primitive type of this scalar does not match the requested type.
125    pub fn typed_value<T: NativePType>(&self) -> Option<T> {
126        assert_eq!(
127            self.ptype,
128            T::PTYPE,
129            "Attempting to read {} scalar as {}",
130            self.ptype,
131            T::PTYPE
132        );
133
134        self.pvalue.map(|pv| pv.cast::<T>())
135    }
136
137    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
138        let ptype = PType::try_from(dtype)?;
139        let pvalue = self
140            .pvalue
141            .vortex_expect("nullness handled in Scalar::cast");
142        Ok(match_each_native_ptype!(ptype, |Q| {
143            Scalar::primitive(
144                pvalue
145                    .cast_opt::<Q>()
146                    .ok_or_else(|| vortex_err!("Cannot cast {} to {}", self.ptype, dtype))?,
147                dtype.nullability(),
148            )
149        }))
150    }
151
152    /// Returns true if the scalar is nan.
153    pub fn is_nan(&self) -> bool {
154        self.pvalue.as_ref().is_some_and(|p| p.is_nan())
155    }
156
157    /// Attempts to extract the primitive value as the given type.
158    ///
159    /// # Errors
160    ///
161    /// Panics if the cast fails due to overflow or type incompatibility. See also
162    /// `as_opt` for the checked version that does not panic.
163    ///
164    /// # Examples
165    ///
166    /// ```should_panic
167    /// # use vortex_dtype::{DType, PType};
168    /// # use vortex_scalar::Scalar;
169    /// let wide = Scalar::primitive(1000i32, false.into());
170    ///
171    /// // This succeeds
172    /// let narrow = wide.as_primitive().as_::<i16>();
173    /// assert_eq!(narrow, Some(1000i16));
174    ///
175    /// // This also succeeds
176    /// let null = Scalar::null(DType::Primitive(PType::I16, true.into()));
177    /// assert_eq!(null.as_primitive().as_::<i8>(), None);
178    ///
179    /// // This will panic, because 1000 does not fit in i8
180    /// wide.as_primitive().as_::<i8>();
181    /// ```
182    pub fn as_<T: FromPrimitiveOrF16>(&self) -> Option<T> {
183        self.as_opt::<T>().unwrap_or_else(|| {
184            vortex_panic!(
185                "cast {} to {}: value out of range",
186                self.ptype,
187                type_name::<T>()
188            )
189        })
190    }
191
192    /// Returns the inner value cast to the desired type.
193    ///
194    /// If the cast fails, `None` is returned. If the scalar represents a null, `Some(None)`
195    /// is returned. Otherwise, `Some(Some(T))` is returned for a successful non-null conversion.
196    ///
197    ///
198    /// # Examples
199    ///
200    /// ```
201    /// # use vortex_dtype::{DType, PType};
202    /// # use vortex_scalar::Scalar;
203    ///
204    /// // Non-null values
205    /// let scalar = Scalar::primitive(100i32, false.into());
206    /// let primitive = scalar.as_primitive();
207    /// assert_eq!(primitive.as_opt::<i8>(), Some(Some(100i8)));
208    ///
209    /// // Null value
210    /// let scalar = Scalar::null(DType::Primitive(PType::I32, true.into()));
211    /// let primitive = scalar.as_primitive();
212    /// assert_eq!(primitive.as_opt::<i8>(), Some(None));
213    ///
214    /// // Failed conversion: 1000 cannot fit in an i8
215    /// let scalar = Scalar::primitive(1000i32, false.into());
216    /// let primitive = scalar.as_primitive();
217    /// assert_eq!(primitive.as_opt::<i8>(), None);
218    /// ```
219    pub fn as_opt<T: FromPrimitiveOrF16>(&self) -> Option<Option<T>> {
220        if let Some(pv) = self.pvalue {
221            match pv {
222                PValue::U8(v) => T::from_u8(v),
223                PValue::U16(v) => T::from_u16(v),
224                PValue::U32(v) => T::from_u32(v),
225                PValue::U64(v) => T::from_u64(v),
226                PValue::I8(v) => T::from_i8(v),
227                PValue::I16(v) => T::from_i16(v),
228                PValue::I32(v) => T::from_i32(v),
229                PValue::I64(v) => T::from_i64(v),
230                PValue::F16(v) => T::from_f16(v),
231                PValue::F32(v) => T::from_f32(v),
232                PValue::F64(v) => T::from_f64(v),
233            }
234            .map(Some)
235        } else {
236            Some(None)
237        }
238    }
239}
240
241impl<'a> TryFrom<&'a Scalar> for PrimitiveScalar<'a> {
242    type Error = VortexError;
243
244    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
245        Self::try_new(value.dtype(), value.value())
246    }
247}
248
249impl Sub for PrimitiveScalar<'_> {
250    type Output = Self;
251
252    fn sub(self, rhs: Self) -> Self::Output {
253        self.checked_sub(&rhs)
254            .vortex_expect("PrimitiveScalar subtract: overflow or underflow")
255    }
256}
257
258impl CheckedSub for PrimitiveScalar<'_> {
259    fn checked_sub(&self, rhs: &Self) -> Option<Self> {
260        self.checked_binary_numeric(rhs, NumericOperator::Sub)
261    }
262}
263
264impl Add for PrimitiveScalar<'_> {
265    type Output = Self;
266
267    fn add(self, rhs: Self) -> Self::Output {
268        self.checked_add(&rhs)
269            .vortex_expect("PrimitiveScalar add: overflow or underflow")
270    }
271}
272
273impl CheckedAdd for PrimitiveScalar<'_> {
274    fn checked_add(&self, rhs: &Self) -> Option<Self> {
275        self.checked_binary_numeric(rhs, NumericOperator::Add)
276    }
277}
278
279impl Scalar {
280    /// Creates a new primitive scalar from a native value.
281    pub fn primitive<T: NativePType + Into<PValue>>(value: T, nullability: Nullability) -> Self {
282        Self::primitive_value(value.into(), T::PTYPE, nullability)
283    }
284
285    /// Create a PrimitiveScalar from a PValue.
286    ///
287    /// Note that an explicit PType is passed since any compatible PValue may be used as the value
288    /// for a primitive type.
289    pub fn primitive_value(value: PValue, ptype: PType, nullability: Nullability) -> Self {
290        Self::new(
291            DType::Primitive(ptype, nullability),
292            ScalarValue(InnerScalarValue::Primitive(value)),
293        )
294    }
295
296    /// Reinterprets the bytes of this scalar as a different primitive type.
297    ///
298    /// # Panics
299    ///
300    /// Panics if the scalar is not a primitive type or if the types have different byte widths.
301    pub fn reinterpret_cast(&self, ptype: PType) -> Self {
302        let primitive = PrimitiveScalar::try_from(self).unwrap_or_else(|e| {
303            vortex_panic!(
304                e,
305                "Failed to reinterpret cast {} to {}",
306                self.dtype(),
307                ptype
308            )
309        });
310        if primitive.ptype() == ptype {
311            return self.clone();
312        }
313
314        assert_eq!(
315            primitive.ptype().byte_width(),
316            ptype.byte_width(),
317            "can't reinterpret cast between integers of two different widths"
318        );
319
320        Scalar::new(
321            DType::Primitive(ptype, self.dtype().nullability()),
322            primitive
323                .pvalue
324                .map(|p| p.reinterpret_cast(ptype))
325                .map(|x| ScalarValue(InnerScalarValue::Primitive(x)))
326                .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)),
327        )
328    }
329}
330
331macro_rules! primitive_scalar {
332    ($T:ty) => {
333        impl TryFrom<&Scalar> for $T {
334            type Error = VortexError;
335
336            fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
337                <Option<$T>>::try_from(value)?
338                    .ok_or_else(|| vortex_err!("Can't extract present value from null scalar"))
339            }
340        }
341
342        impl TryFrom<Scalar> for $T {
343            type Error = VortexError;
344
345            fn try_from(value: Scalar) -> Result<Self, Self::Error> {
346                <$T>::try_from(&value)
347            }
348        }
349
350        impl TryFrom<&Scalar> for Option<$T> {
351            type Error = VortexError;
352
353            fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
354                Ok(PrimitiveScalar::try_from(value)?.typed_value::<$T>())
355            }
356        }
357
358        impl TryFrom<Scalar> for Option<$T> {
359            type Error = VortexError;
360
361            fn try_from(value: Scalar) -> Result<Self, Self::Error> {
362                <Option<$T>>::try_from(&value)
363            }
364        }
365
366        impl From<$T> for Scalar {
367            fn from(value: $T) -> Self {
368                Scalar::new(
369                    DType::Primitive(<$T>::PTYPE, Nullability::NonNullable),
370                    ScalarValue(InnerScalarValue::Primitive(value.into())),
371                )
372            }
373        }
374
375        impl From<$T> for ScalarValue {
376            fn from(value: $T) -> Self {
377                ScalarValue(InnerScalarValue::Primitive(value.into()))
378            }
379        }
380    };
381}
382
383primitive_scalar!(u8);
384primitive_scalar!(u16);
385primitive_scalar!(u32);
386primitive_scalar!(u64);
387primitive_scalar!(i8);
388primitive_scalar!(i16);
389primitive_scalar!(i32);
390primitive_scalar!(i64);
391primitive_scalar!(f16);
392primitive_scalar!(f32);
393primitive_scalar!(f64);
394
395/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics.
396impl TryFrom<&Scalar> for usize {
397    type Error = VortexError;
398
399    fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
400        let prim = PrimitiveScalar::try_from(value)?
401            .as_::<u64>()
402            .ok_or_else(|| vortex_err!("cannot convert Null to usize"))?;
403        Ok(usize::try_from(prim)?)
404    }
405}
406
407impl TryFrom<&Scalar> for Option<usize> {
408    type Error = VortexError;
409
410    fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
411        Ok(PrimitiveScalar::try_from(value)?
412            .as_::<u64>()
413            .map(usize::try_from)
414            .transpose()?)
415    }
416}
417
418/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics.
419impl From<usize> for Scalar {
420    fn from(value: usize) -> Self {
421        Scalar::primitive(value as u64, Nullability::NonNullable)
422    }
423}
424
425impl From<PValue> for ScalarValue {
426    fn from(value: PValue) -> Self {
427        ScalarValue(InnerScalarValue::Primitive(value))
428    }
429}
430
431/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics.
432impl From<usize> for ScalarValue {
433    fn from(value: usize) -> Self {
434        ScalarValue(InnerScalarValue::Primitive((value as u64).into()))
435    }
436}
437
/// Binary element-wise operations on two arrays or two scalars.
///
/// The `R`-prefixed variants are the reflected forms, whose operands are flipped. As a
/// result, `op.swap()` applied to `(right, left)` computes the same value as `op` applied
/// to `(left, right)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumericOperator {
    /// Binary element-wise addition of two arrays or of two scalars.
    ///
    /// Errs at runtime if the sum would overflow or underflow.
    Add,
    /// Binary element-wise subtraction of two arrays or of two scalars: `left - right`.
    Sub,
    /// Same as [NumericOperator::Sub] but with the parameters flipped: `right - left`.
    RSub,
    /// Binary element-wise multiplication of two arrays or of two scalars.
    Mul,
    /// Binary element-wise division of two arrays or of two scalars: `left / right`.
    Div,
    /// Same as [NumericOperator::Div] but with the parameters flipped: `right / left`.
    RDiv,
    // Not yet supported (missing from arrow-rs): Min, Max, Pow.
}

impl Display for NumericOperator {
    /// Displays the operator by its variant name, e.g. `Add` or `RSub`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        <Self as Debug>::fmt(self, f)
    }
}

impl NumericOperator {
    /// Returns the operator to use when the two operands are exchanged.
    ///
    /// The commutative operators (`Add`, `Mul`) map to themselves. `Sub`/`RSub` map to
    /// each other, as do `Div`/`RDiv`.
    pub fn swap(self) -> Self {
        match self {
            Self::Add | Self::Mul => self,
            Self::Sub => Self::RSub,
            Self::RSub => Self::Sub,
            Self::Div => Self::RDiv,
            Self::RDiv => Self::Div,
        }
    }
}
480
481impl<'a> PrimitiveScalar<'a> {
482    /// Apply the (checked) operator to self and other using SQL-style null semantics.
483    ///
484    /// If the operation overflows, Ok(None) is returned.
485    ///
486    /// If the types are incompatible (ignoring nullability), an error is returned.
487    ///
488    /// If either value is null, the result is null.
489    pub fn checked_binary_numeric(
490        &self,
491        other: &PrimitiveScalar<'a>,
492        op: NumericOperator,
493    ) -> Option<PrimitiveScalar<'a>> {
494        if !self.dtype().eq_ignore_nullability(other.dtype()) {
495            vortex_panic!("types must match: {} {}", self.dtype(), other.dtype());
496        }
497        let result_dtype = if self.dtype().is_nullable() {
498            self.dtype()
499        } else {
500            other.dtype()
501        };
502        let ptype = self.ptype();
503
504        match_each_native_ptype!(
505            self.ptype(),
506            integral: |P| {
507                self.checked_integral_numeric_operator::<P>(other, result_dtype, ptype, op)
508            },
509            floating: |P| {
510                let lhs = self.typed_value::<P>();
511                let rhs = other.typed_value::<P>();
512                let value_or_null = match (lhs, rhs) {
513                    (_, None) | (None, _) => None,
514                    (Some(lhs), Some(rhs)) => match op {
515                        NumericOperator::Add => Some(lhs + rhs),
516                        NumericOperator::Sub => Some(lhs - rhs),
517                        NumericOperator::RSub => Some(rhs - lhs),
518                        NumericOperator::Mul => Some(lhs * rhs),
519                        NumericOperator::Div => Some(lhs / rhs),
520                        NumericOperator::RDiv => Some(rhs / lhs),
521                    }
522                };
523                Some(Self { dtype: result_dtype, ptype, pvalue: value_or_null.map(PValue::from) })
524            }
525        )
526    }
527
528    fn checked_integral_numeric_operator<
529        P: NativePType
530            + TryFrom<PValue, Error = VortexError>
531            + CheckedSub
532            + CheckedAdd
533            + CheckedMul
534            + CheckedDiv,
535    >(
536        &self,
537        other: &PrimitiveScalar<'a>,
538        result_dtype: &'a DType,
539        ptype: PType,
540        op: NumericOperator,
541    ) -> Option<PrimitiveScalar<'a>>
542    where
543        PValue: From<P>,
544    {
545        let lhs = self.typed_value::<P>();
546        let rhs = other.typed_value::<P>();
547        let value_or_null_or_overflow = match (lhs, rhs) {
548            (_, None) | (None, _) => Some(None),
549            (Some(lhs), Some(rhs)) => match op {
550                NumericOperator::Add => lhs.checked_add(&rhs).map(Some),
551                NumericOperator::Sub => lhs.checked_sub(&rhs).map(Some),
552                NumericOperator::RSub => rhs.checked_sub(&lhs).map(Some),
553                NumericOperator::Mul => lhs.checked_mul(&rhs).map(Some),
554                NumericOperator::Div => lhs.checked_div(&rhs).map(Some),
555                NumericOperator::RDiv => rhs.checked_div(&lhs).map(Some),
556            },
557        };
558
559        value_or_null_or_overflow.map(|value_or_null| Self {
560            dtype: result_dtype,
561            ptype,
562            pvalue: value_or_null.map(PValue::from),
563        })
564    }
565}
566
567#[cfg(test)]
568mod tests {
569    use num_traits::CheckedSub;
570    use rstest::rstest;
571    use vortex_dtype::DType;
572    use vortex_dtype::Nullability;
573    use vortex_dtype::PType;
574    use vortex_error::VortexExpect;
575
576    use crate::InnerScalarValue;
577    use crate::PValue;
578    use crate::PrimitiveScalar;
579    use crate::ScalarValue;
580
581    #[test]
582    fn test_integer_subtract() {
583        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
584        let p_scalar1 = PrimitiveScalar::try_new(
585            &dtype,
586            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(5))),
587        )
588        .unwrap();
589        let p_scalar2 = PrimitiveScalar::try_new(
590            &dtype,
591            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(4))),
592        )
593        .unwrap();
594        let pscalar_or_overflow = p_scalar1.checked_sub(&p_scalar2);
595        let value_or_null_or_type_error = pscalar_or_overflow.unwrap().as_::<i32>();
596        assert_eq!(value_or_null_or_type_error.unwrap(), 1);
597
598        assert_eq!((p_scalar1 - p_scalar2).as_::<i32>().unwrap(), 1);
599    }
600
601    #[test]
602    #[should_panic(expected = "PrimitiveScalar subtract: overflow or underflow")]
603    fn test_integer_subtract_overflow() {
604        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
605        let p_scalar1 = PrimitiveScalar::try_new(
606            &dtype,
607            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MIN))),
608        )
609        .unwrap();
610        let p_scalar2 = PrimitiveScalar::try_new(
611            &dtype,
612            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MAX))),
613        )
614        .unwrap();
615        let _ = p_scalar1 - p_scalar2;
616    }
617
618    #[test]
619    fn test_float_subtract() {
620        let dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
621        let p_scalar1 = PrimitiveScalar::try_new(
622            &dtype,
623            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(1.99f32))),
624        )
625        .unwrap();
626        let p_scalar2 = PrimitiveScalar::try_new(
627            &dtype,
628            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(1.0f32))),
629        )
630        .unwrap();
631        let pscalar_or_overflow = p_scalar1.checked_sub(&p_scalar2).unwrap();
632        let value_or_null_or_type_error = pscalar_or_overflow.as_::<f32>();
633        assert_eq!(value_or_null_or_type_error.unwrap(), 0.99f32);
634
635        assert_eq!((p_scalar1 - p_scalar2).as_::<f32>().unwrap(), 0.99f32);
636    }
637
638    #[test]
639    fn test_primitive_scalar_equality() {
640        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
641        let scalar1 = PrimitiveScalar::try_new(
642            &dtype,
643            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))),
644        )
645        .unwrap();
646        let scalar2 = PrimitiveScalar::try_new(
647            &dtype,
648            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))),
649        )
650        .unwrap();
651        let scalar3 = PrimitiveScalar::try_new(
652            &dtype,
653            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(43))),
654        )
655        .unwrap();
656
657        assert_eq!(scalar1, scalar2);
658        assert_ne!(scalar1, scalar3);
659    }
660
661    #[test]
662    fn test_primitive_scalar_partial_ord() {
663        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
664        let scalar1 = PrimitiveScalar::try_new(
665            &dtype,
666            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))),
667        )
668        .unwrap();
669        let scalar2 = PrimitiveScalar::try_new(
670            &dtype,
671            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))),
672        )
673        .unwrap();
674
675        assert!(scalar1 < scalar2);
676        assert!(scalar2 > scalar1);
677        assert_eq!(
678            scalar1.partial_cmp(&scalar1),
679            Some(std::cmp::Ordering::Equal)
680        );
681    }
682
683    #[test]
684    fn test_primitive_scalar_null_handling() {
685        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
686        let null_scalar =
687            PrimitiveScalar::try_new(&dtype, &ScalarValue(InnerScalarValue::Null)).unwrap();
688
689        assert_eq!(null_scalar.pvalue(), None);
690        assert_eq!(null_scalar.typed_value::<i32>(), None);
691    }
692
693    #[test]
694    fn test_typed_value_correct_type() {
695        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
696        let scalar = PrimitiveScalar::try_new(
697            &dtype,
698            &ScalarValue(InnerScalarValue::Primitive(PValue::F64(3.5))),
699        )
700        .unwrap();
701
702        assert_eq!(scalar.typed_value::<f64>(), Some(3.5));
703    }
704
705    #[test]
706    #[should_panic(expected = "Attempting to read")]
707    fn test_typed_value_wrong_type() {
708        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
709        let scalar = PrimitiveScalar::try_new(
710            &dtype,
711            &ScalarValue(InnerScalarValue::Primitive(PValue::F64(3.5))),
712        )
713        .unwrap();
714
715        let _ = scalar.typed_value::<i32>();
716    }
717
718    #[rstest]
719    #[case(PType::I8, 127i32, PType::I16, true)]
720    #[case(PType::I8, 127i32, PType::I32, true)]
721    #[case(PType::I8, 127i32, PType::I64, true)]
722    #[case(PType::U8, 255i32, PType::U16, true)]
723    #[case(PType::U8, 255i32, PType::U32, true)]
724    #[case(PType::I32, 42i32, PType::F32, true)]
725    #[case(PType::I32, 42i32, PType::F64, true)]
726    // Overflow cases
727    #[case(PType::I32, 300i32, PType::U8, false)]
728    #[case(PType::I32, -1i32, PType::U32, false)]
729    #[case(PType::I32, 256i32, PType::I8, false)]
730    #[case(PType::U16, 65535i32, PType::I8, false)]
731    fn test_primitive_cast(
732        #[case] source_type: PType,
733        #[case] source_value: i32,
734        #[case] target_type: PType,
735        #[case] should_succeed: bool,
736    ) {
737        let source_pvalue = match source_type {
738            PType::I8 => PValue::I8(i8::try_from(source_value).vortex_expect("cannot cast")),
739            PType::U8 => PValue::U8(u8::try_from(source_value).vortex_expect("cannot cast")),
740            PType::U16 => PValue::U16(u16::try_from(source_value).vortex_expect("cannot cast")),
741            PType::I32 => PValue::I32(source_value),
742            _ => unreachable!("Test case uses unexpected source type"),
743        };
744
745        let dtype = DType::Primitive(source_type, Nullability::NonNullable);
746        let scalar = PrimitiveScalar::try_new(
747            &dtype,
748            &ScalarValue(InnerScalarValue::Primitive(source_pvalue)),
749        )
750        .unwrap();
751
752        let target_dtype = DType::Primitive(target_type, Nullability::NonNullable);
753        let result = scalar.cast(&target_dtype);
754
755        if should_succeed {
756            assert!(
757                result.is_ok(),
758                "Cast from {:?} to {:?} should succeed",
759                source_type,
760                target_type
761            );
762        } else {
763            assert!(
764                result.is_err(),
765                "Cast from {:?} to {:?} should fail due to overflow",
766                source_type,
767                target_type
768            );
769        }
770    }
771
772    #[test]
773    fn test_as_conversion_success() {
774        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
775        let scalar = PrimitiveScalar::try_new(
776            &dtype,
777            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))),
778        )
779        .unwrap();
780
781        assert_eq!(scalar.as_::<i64>(), Some(42i64));
782        assert_eq!(scalar.as_::<f64>(), Some(42.0));
783    }
784
785    #[test]
786    fn test_as_conversion_overflow() {
787        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
788        let scalar = PrimitiveScalar::try_new(
789            &dtype,
790            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(-1))),
791        )
792        .unwrap();
793
794        // Converting -1 to u32 should fail
795        let result = scalar.as_opt::<u32>();
796        assert!(result.is_none());
797    }
798
799    #[test]
800    fn test_as_conversion_null() {
801        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
802        let scalar =
803            PrimitiveScalar::try_new(&dtype, &ScalarValue(InnerScalarValue::Null)).unwrap();
804
805        assert_eq!(scalar.as_::<i32>(), None);
806        assert_eq!(scalar.as_::<f64>(), None);
807    }
808
809    #[test]
810    fn test_numeric_operator_swap() {
811        use crate::primitive::NumericOperator;
812
813        assert_eq!(NumericOperator::Add.swap(), NumericOperator::Add);
814        assert_eq!(NumericOperator::Sub.swap(), NumericOperator::RSub);
815        assert_eq!(NumericOperator::RSub.swap(), NumericOperator::Sub);
816        assert_eq!(NumericOperator::Mul.swap(), NumericOperator::Mul);
817        assert_eq!(NumericOperator::Div.swap(), NumericOperator::RDiv);
818        assert_eq!(NumericOperator::RDiv.swap(), NumericOperator::Div);
819    }
820
821    #[test]
822    fn test_checked_binary_numeric_add() {
823        use crate::primitive::NumericOperator;
824
825        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
826        let scalar1 = PrimitiveScalar::try_new(
827            &dtype,
828            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))),
829        )
830        .unwrap();
831        let scalar2 = PrimitiveScalar::try_new(
832            &dtype,
833            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))),
834        )
835        .unwrap();
836
837        let result = scalar1
838            .checked_binary_numeric(&scalar2, NumericOperator::Add)
839            .unwrap();
840        assert_eq!(result.typed_value::<i32>(), Some(30));
841    }
842
843    #[test]
844    fn test_checked_binary_numeric_overflow() {
845        use crate::primitive::NumericOperator;
846
847        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
848        let scalar1 = PrimitiveScalar::try_new(
849            &dtype,
850            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MAX))),
851        )
852        .unwrap();
853        let scalar2 = PrimitiveScalar::try_new(
854            &dtype,
855            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(1))),
856        )
857        .unwrap();
858
859        // Add should overflow and return None
860        let result = scalar1.checked_binary_numeric(&scalar2, NumericOperator::Add);
861        assert!(result.is_none());
862    }
863
864    #[test]
865    fn test_checked_binary_numeric_with_null() {
866        use crate::primitive::NumericOperator;
867
868        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
869        let scalar1 = PrimitiveScalar::try_new(
870            &dtype,
871            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))),
872        )
873        .unwrap();
874        let null_scalar =
875            PrimitiveScalar::try_new(&dtype, &ScalarValue(InnerScalarValue::Null)).unwrap();
876
877        // Operation with null should return null
878        let result = scalar1
879            .checked_binary_numeric(&null_scalar, NumericOperator::Add)
880            .unwrap();
881        assert_eq!(result.pvalue(), None);
882    }
883
884    #[test]
885    fn test_checked_binary_numeric_mul() {
886        use crate::primitive::NumericOperator;
887
888        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
889        let scalar1 = PrimitiveScalar::try_new(
890            &dtype,
891            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(5))),
892        )
893        .unwrap();
894        let scalar2 = PrimitiveScalar::try_new(
895            &dtype,
896            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(6))),
897        )
898        .unwrap();
899
900        let result = scalar1
901            .checked_binary_numeric(&scalar2, NumericOperator::Mul)
902            .unwrap();
903        assert_eq!(result.typed_value::<i32>(), Some(30));
904    }
905
906    #[test]
907    fn test_checked_binary_numeric_div() {
908        use crate::primitive::NumericOperator;
909
910        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
911        let scalar1 = PrimitiveScalar::try_new(
912            &dtype,
913            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))),
914        )
915        .unwrap();
916        let scalar2 = PrimitiveScalar::try_new(
917            &dtype,
918            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(4))),
919        )
920        .unwrap();
921
922        let result = scalar1
923            .checked_binary_numeric(&scalar2, NumericOperator::Div)
924            .unwrap();
925        assert_eq!(result.typed_value::<i32>(), Some(5));
926    }
927
928    #[test]
929    fn test_checked_binary_numeric_rdiv() {
930        use crate::primitive::NumericOperator;
931
932        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
933        let scalar1 = PrimitiveScalar::try_new(
934            &dtype,
935            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(4))),
936        )
937        .unwrap();
938        let scalar2 = PrimitiveScalar::try_new(
939            &dtype,
940            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))),
941        )
942        .unwrap();
943
944        // RDiv means right / left, so 20 / 4 = 5
945        let result = scalar1
946            .checked_binary_numeric(&scalar2, NumericOperator::RDiv)
947            .unwrap();
948        assert_eq!(result.typed_value::<i32>(), Some(5));
949    }
950
951    #[test]
952    fn test_checked_binary_numeric_div_by_zero() {
953        use crate::primitive::NumericOperator;
954
955        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
956        let scalar1 = PrimitiveScalar::try_new(
957            &dtype,
958            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))),
959        )
960        .unwrap();
961        let scalar2 = PrimitiveScalar::try_new(
962            &dtype,
963            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(0))),
964        )
965        .unwrap();
966
967        // Division by zero should return None for integers
968        let result = scalar1.checked_binary_numeric(&scalar2, NumericOperator::Div);
969        assert!(result.is_none());
970    }
971
972    #[test]
973    fn test_checked_binary_numeric_float_ops() {
974        use crate::primitive::NumericOperator;
975
976        let dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
977        let scalar1 = PrimitiveScalar::try_new(
978            &dtype,
979            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(10.0))),
980        )
981        .unwrap();
982        let scalar2 = PrimitiveScalar::try_new(
983            &dtype,
984            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(2.5))),
985        )
986        .unwrap();
987
988        // Test all operations with floats
989        let add_result = scalar1
990            .checked_binary_numeric(&scalar2, NumericOperator::Add)
991            .unwrap();
992        assert_eq!(add_result.typed_value::<f32>(), Some(12.5));
993
994        let sub_result = scalar1
995            .checked_binary_numeric(&scalar2, NumericOperator::Sub)
996            .unwrap();
997        assert_eq!(sub_result.typed_value::<f32>(), Some(7.5));
998
999        let mul_result = scalar1
1000            .checked_binary_numeric(&scalar2, NumericOperator::Mul)
1001            .unwrap();
1002        assert_eq!(mul_result.typed_value::<f32>(), Some(25.0));
1003
1004        let div_result = scalar1
1005            .checked_binary_numeric(&scalar2, NumericOperator::Div)
1006            .unwrap();
1007        assert_eq!(div_result.typed_value::<f32>(), Some(4.0));
1008    }
1009
1010    #[test]
1011    fn test_from_primitive_or_f16() {
1012        use vortex_dtype::half::f16;
1013
1014        use crate::primitive::FromPrimitiveOrF16;
1015
1016        // Test f16 to f32 conversion
1017        let f16_val = f16::from_f32(3.5);
1018        assert!(f32::from_f16(f16_val).is_some());
1019
1020        // Test f16 to f64 conversion
1021        assert!(f64::from_f16(f16_val).is_some());
1022
1023        // Test PValue::F16(f16) to integer conversion (should fail)
1024        assert!(i32::try_from(PValue::from(f16_val)).is_err());
1025        assert!(u32::try_from(PValue::from(f16_val)).is_err());
1026    }
1027
1028    #[test]
1029    fn test_partial_ord_different_types() {
1030        let dtype1 = DType::Primitive(PType::I32, Nullability::NonNullable);
1031        let dtype2 = DType::Primitive(PType::F32, Nullability::NonNullable);
1032
1033        let scalar1 = PrimitiveScalar::try_new(
1034            &dtype1,
1035            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))),
1036        )
1037        .unwrap();
1038        let scalar2 = PrimitiveScalar::try_new(
1039            &dtype2,
1040            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(10.0))),
1041        )
1042        .unwrap();
1043
1044        // Different types should not be comparable
1045        assert_eq!(scalar1.partial_cmp(&scalar2), None);
1046    }
1047
1048    #[test]
1049    fn test_scalar_value_from_usize() {
1050        let value: ScalarValue = 42usize.into();
1051        assert!(matches!(
1052            value.0,
1053            InnerScalarValue::Primitive(PValue::U64(42))
1054        ));
1055    }
1056
1057    #[test]
1058    fn test_getters() {
1059        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
1060        let scalar = PrimitiveScalar::try_new(
1061            &dtype,
1062            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))),
1063        )
1064        .unwrap();
1065
1066        assert_eq!(scalar.dtype(), &dtype);
1067        assert_eq!(scalar.ptype(), PType::I32);
1068        assert_eq!(scalar.pvalue(), Some(PValue::I32(42)));
1069    }
1070}