1use std::any::type_name;
5use std::cmp::Ordering;
6use std::fmt::Debug;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::ops::Add;
10use std::ops::Sub;
11
12use num_traits::CheckedAdd;
13use num_traits::CheckedDiv;
14use num_traits::CheckedMul;
15use num_traits::CheckedSub;
16use vortex_dtype::DType;
17use vortex_dtype::FromPrimitiveOrF16;
18use vortex_dtype::NativePType;
19use vortex_dtype::Nullability;
20use vortex_dtype::PType;
21use vortex_dtype::half::f16;
22use vortex_dtype::match_each_native_ptype;
23use vortex_error::VortexError;
24use vortex_error::VortexExpect;
25use vortex_error::VortexResult;
26use vortex_error::vortex_err;
27use vortex_error::vortex_panic;
28
29use crate::InnerScalarValue;
30use crate::Scalar;
31use crate::ScalarValue;
32use crate::pvalue::CoercePValue;
33use crate::pvalue::PValue;
34
35#[derive(Debug, Clone, Copy, Hash)]
40pub struct PrimitiveScalar<'a> {
41 dtype: &'a DType,
42 ptype: PType,
43 pvalue: Option<PValue>,
44}
45
46impl Display for PrimitiveScalar<'_> {
47 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
48 match self.pvalue {
49 None => write!(f, "null"),
50 Some(pv) => write!(f, "{pv}"),
51 }
52 }
53}
54
55impl PartialEq for PrimitiveScalar<'_> {
56 fn eq(&self, other: &Self) -> bool {
57 self.dtype.eq_ignore_nullability(other.dtype) && self.pvalue == other.pvalue
58 }
59}
60
61impl Eq for PrimitiveScalar<'_> {}
62
63impl PartialOrd for PrimitiveScalar<'_> {
65 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
66 if !self.dtype.eq_ignore_nullability(other.dtype) {
67 return None;
68 }
69 self.pvalue.partial_cmp(&other.pvalue)
70 }
71}
72
73impl<'a> PrimitiveScalar<'a> {
74 pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
81 let ptype = PType::try_from(dtype)?;
82
83 let pvalue = match_each_native_ptype!(ptype, |T| {
86 value
87 .as_pvalue()?
88 .map(|pv| VortexResult::Ok(PValue::from(<T>::coerce(pv)?)))
89 .transpose()?
90 });
91
92 Ok(Self {
93 dtype,
94 ptype,
95 pvalue,
96 })
97 }
98
99 #[inline]
101 pub fn dtype(&self) -> &'a DType {
102 self.dtype
103 }
104
105 #[inline]
107 pub fn ptype(&self) -> PType {
108 self.ptype
109 }
110
111 #[inline]
113 pub fn pvalue(&self) -> Option<PValue> {
114 self.pvalue
115 }
116
117 pub fn typed_value<T: NativePType>(&self) -> Option<T> {
126 assert_eq!(
127 self.ptype,
128 T::PTYPE,
129 "Attempting to read {} scalar as {}",
130 self.ptype,
131 T::PTYPE
132 );
133
134 self.pvalue.map(|pv| pv.cast::<T>())
135 }
136
137 pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
138 let ptype = PType::try_from(dtype)?;
139 let pvalue = self
140 .pvalue
141 .vortex_expect("nullness handled in Scalar::cast");
142 Ok(match_each_native_ptype!(ptype, |Q| {
143 Scalar::primitive(
144 pvalue
145 .cast_opt::<Q>()
146 .ok_or_else(|| vortex_err!("Cannot cast {} to {}", self.ptype, dtype))?,
147 dtype.nullability(),
148 )
149 }))
150 }
151
152 pub fn is_nan(&self) -> bool {
154 self.pvalue.as_ref().is_some_and(|p| p.is_nan())
155 }
156
157 pub fn as_<T: FromPrimitiveOrF16>(&self) -> Option<T> {
183 self.as_opt::<T>().unwrap_or_else(|| {
184 vortex_panic!(
185 "cast {} to {}: value out of range",
186 self.ptype,
187 type_name::<T>()
188 )
189 })
190 }
191
192 pub fn as_opt<T: FromPrimitiveOrF16>(&self) -> Option<Option<T>> {
220 if let Some(pv) = self.pvalue {
221 match pv {
222 PValue::U8(v) => T::from_u8(v),
223 PValue::U16(v) => T::from_u16(v),
224 PValue::U32(v) => T::from_u32(v),
225 PValue::U64(v) => T::from_u64(v),
226 PValue::I8(v) => T::from_i8(v),
227 PValue::I16(v) => T::from_i16(v),
228 PValue::I32(v) => T::from_i32(v),
229 PValue::I64(v) => T::from_i64(v),
230 PValue::F16(v) => T::from_f16(v),
231 PValue::F32(v) => T::from_f32(v),
232 PValue::F64(v) => T::from_f64(v),
233 }
234 .map(Some)
235 } else {
236 Some(None)
237 }
238 }
239}
240
241impl<'a> TryFrom<&'a Scalar> for PrimitiveScalar<'a> {
242 type Error = VortexError;
243
244 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
245 Self::try_new(value.dtype(), value.value())
246 }
247}
248
249impl Sub for PrimitiveScalar<'_> {
250 type Output = Self;
251
252 fn sub(self, rhs: Self) -> Self::Output {
253 self.checked_sub(&rhs)
254 .vortex_expect("PrimitiveScalar subtract: overflow or underflow")
255 }
256}
257
258impl CheckedSub for PrimitiveScalar<'_> {
259 fn checked_sub(&self, rhs: &Self) -> Option<Self> {
260 self.checked_binary_numeric(rhs, NumericOperator::Sub)
261 }
262}
263
264impl Add for PrimitiveScalar<'_> {
265 type Output = Self;
266
267 fn add(self, rhs: Self) -> Self::Output {
268 self.checked_add(&rhs)
269 .vortex_expect("PrimitiveScalar add: overflow or underflow")
270 }
271}
272
273impl CheckedAdd for PrimitiveScalar<'_> {
274 fn checked_add(&self, rhs: &Self) -> Option<Self> {
275 self.checked_binary_numeric(rhs, NumericOperator::Add)
276 }
277}
278
279impl Scalar {
280 pub fn primitive<T: NativePType + Into<PValue>>(value: T, nullability: Nullability) -> Self {
282 Self::primitive_value(value.into(), T::PTYPE, nullability)
283 }
284
285 pub fn primitive_value(value: PValue, ptype: PType, nullability: Nullability) -> Self {
290 Self::new(
291 DType::Primitive(ptype, nullability),
292 ScalarValue(InnerScalarValue::Primitive(value)),
293 )
294 }
295
296 pub fn reinterpret_cast(&self, ptype: PType) -> Self {
302 let primitive = PrimitiveScalar::try_from(self).unwrap_or_else(|e| {
303 vortex_panic!(
304 e,
305 "Failed to reinterpret cast {} to {}",
306 self.dtype(),
307 ptype
308 )
309 });
310 if primitive.ptype() == ptype {
311 return self.clone();
312 }
313
314 assert_eq!(
315 primitive.ptype().byte_width(),
316 ptype.byte_width(),
317 "can't reinterpret cast between integers of two different widths"
318 );
319
320 Scalar::new(
321 DType::Primitive(ptype, self.dtype().nullability()),
322 primitive
323 .pvalue
324 .map(|p| p.reinterpret_cast(ptype))
325 .map(|x| ScalarValue(InnerScalarValue::Primitive(x)))
326 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)),
327 )
328 }
329}
330
// Generates the `Scalar`/`ScalarValue` conversions for one native primitive type `$T`:
// fallible extraction (`$T` and `Option<$T>`, from owned or borrowed `Scalar`) and
// infallible construction (`Scalar` and `ScalarValue` from `$T`).
macro_rules! primitive_scalar {
    ($T:ty) => {
        impl TryFrom<&Scalar> for $T {
            type Error = VortexError;

            // Extracts a present value; a null scalar is an error.
            fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
                match <Option<$T>>::try_from(value)? {
                    Some(present) => Ok(present),
                    None => Err(vortex_err!("Can't extract present value from null scalar")),
                }
            }
        }

        impl TryFrom<Scalar> for $T {
            type Error = VortexError;

            fn try_from(value: Scalar) -> Result<Self, Self::Error> {
                <$T as TryFrom<&Scalar>>::try_from(&value)
            }
        }

        impl TryFrom<&Scalar> for Option<$T> {
            type Error = VortexError;

            // Null maps to `Ok(None)`; a mismatched primitive type panics in `typed_value`.
            fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
                let primitive = PrimitiveScalar::try_from(value)?;
                Ok(primitive.typed_value::<$T>())
            }
        }

        impl TryFrom<Scalar> for Option<$T> {
            type Error = VortexError;

            fn try_from(value: Scalar) -> Result<Self, Self::Error> {
                <Option<$T> as TryFrom<&Scalar>>::try_from(&value)
            }
        }

        impl From<$T> for Scalar {
            // Always produces a non-nullable scalar.
            fn from(value: $T) -> Self {
                let dtype = DType::Primitive(<$T>::PTYPE, Nullability::NonNullable);
                let pvalue: PValue = value.into();
                Scalar::new(dtype, ScalarValue(InnerScalarValue::Primitive(pvalue)))
            }
        }

        impl From<$T> for ScalarValue {
            fn from(value: $T) -> Self {
                let pvalue: PValue = value.into();
                ScalarValue(InnerScalarValue::Primitive(pvalue))
            }
        }
    };
}
382
383primitive_scalar!(u8);
384primitive_scalar!(u16);
385primitive_scalar!(u32);
386primitive_scalar!(u64);
387primitive_scalar!(i8);
388primitive_scalar!(i16);
389primitive_scalar!(i32);
390primitive_scalar!(i64);
391primitive_scalar!(f16);
392primitive_scalar!(f32);
393primitive_scalar!(f64);
394
395impl TryFrom<&Scalar> for usize {
397 type Error = VortexError;
398
399 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
400 let prim = PrimitiveScalar::try_from(value)?
401 .as_::<u64>()
402 .ok_or_else(|| vortex_err!("cannot convert Null to usize"))?;
403 Ok(usize::try_from(prim)?)
404 }
405}
406
407impl TryFrom<&Scalar> for Option<usize> {
408 type Error = VortexError;
409
410 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
411 Ok(PrimitiveScalar::try_from(value)?
412 .as_::<u64>()
413 .map(usize::try_from)
414 .transpose()?)
415 }
416}
417
418impl From<usize> for Scalar {
420 fn from(value: usize) -> Self {
421 Scalar::primitive(value as u64, Nullability::NonNullable)
422 }
423}
424
425impl From<PValue> for ScalarValue {
426 fn from(value: PValue) -> Self {
427 ScalarValue(InnerScalarValue::Primitive(value))
428 }
429}
430
431impl From<usize> for ScalarValue {
433 fn from(value: usize) -> Self {
434 ScalarValue(InnerScalarValue::Primitive((value as u64).into()))
435 }
436}
437
/// A binary arithmetic operator applied to two numeric scalars.
///
/// The `R`-prefixed variants apply the operation with the operands reversed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumericOperator {
    /// `lhs + rhs`
    Add,
    /// `lhs - rhs`
    Sub,
    /// `rhs - lhs`
    RSub,
    /// `lhs * rhs`
    Mul,
    /// `lhs / rhs`
    Div,
    /// `rhs / lhs`
    RDiv,
}

impl Display for NumericOperator {
    /// Displays the operator as its variant name (identical to the `Debug` output).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        <Self as Debug>::fmt(self, f)
    }
}

impl NumericOperator {
    /// Returns the operator that gives the same result with its operands exchanged.
    ///
    /// Commutative operators map to themselves; `Sub`/`RSub` and `Div`/`RDiv` map to each other.
    pub fn swap(self) -> Self {
        match self {
            Self::Add => Self::Add,
            Self::Mul => Self::Mul,
            Self::Sub => Self::RSub,
            Self::RSub => Self::Sub,
            Self::Div => Self::RDiv,
            Self::RDiv => Self::Div,
        }
    }
}
480
481impl<'a> PrimitiveScalar<'a> {
482 pub fn checked_binary_numeric(
490 &self,
491 other: &PrimitiveScalar<'a>,
492 op: NumericOperator,
493 ) -> Option<PrimitiveScalar<'a>> {
494 if !self.dtype().eq_ignore_nullability(other.dtype()) {
495 vortex_panic!("types must match: {} {}", self.dtype(), other.dtype());
496 }
497 let result_dtype = if self.dtype().is_nullable() {
498 self.dtype()
499 } else {
500 other.dtype()
501 };
502 let ptype = self.ptype();
503
504 match_each_native_ptype!(
505 self.ptype(),
506 integral: |P| {
507 self.checked_integral_numeric_operator::<P>(other, result_dtype, ptype, op)
508 },
509 floating: |P| {
510 let lhs = self.typed_value::<P>();
511 let rhs = other.typed_value::<P>();
512 let value_or_null = match (lhs, rhs) {
513 (_, None) | (None, _) => None,
514 (Some(lhs), Some(rhs)) => match op {
515 NumericOperator::Add => Some(lhs + rhs),
516 NumericOperator::Sub => Some(lhs - rhs),
517 NumericOperator::RSub => Some(rhs - lhs),
518 NumericOperator::Mul => Some(lhs * rhs),
519 NumericOperator::Div => Some(lhs / rhs),
520 NumericOperator::RDiv => Some(rhs / lhs),
521 }
522 };
523 Some(Self { dtype: result_dtype, ptype, pvalue: value_or_null.map(PValue::from) })
524 }
525 )
526 }
527
528 fn checked_integral_numeric_operator<
529 P: NativePType
530 + TryFrom<PValue, Error = VortexError>
531 + CheckedSub
532 + CheckedAdd
533 + CheckedMul
534 + CheckedDiv,
535 >(
536 &self,
537 other: &PrimitiveScalar<'a>,
538 result_dtype: &'a DType,
539 ptype: PType,
540 op: NumericOperator,
541 ) -> Option<PrimitiveScalar<'a>>
542 where
543 PValue: From<P>,
544 {
545 let lhs = self.typed_value::<P>();
546 let rhs = other.typed_value::<P>();
547 let value_or_null_or_overflow = match (lhs, rhs) {
548 (_, None) | (None, _) => Some(None),
549 (Some(lhs), Some(rhs)) => match op {
550 NumericOperator::Add => lhs.checked_add(&rhs).map(Some),
551 NumericOperator::Sub => lhs.checked_sub(&rhs).map(Some),
552 NumericOperator::RSub => rhs.checked_sub(&lhs).map(Some),
553 NumericOperator::Mul => lhs.checked_mul(&rhs).map(Some),
554 NumericOperator::Div => lhs.checked_div(&rhs).map(Some),
555 NumericOperator::RDiv => rhs.checked_div(&lhs).map(Some),
556 },
557 };
558
559 value_or_null_or_overflow.map(|value_or_null| Self {
560 dtype: result_dtype,
561 ptype,
562 pvalue: value_or_null.map(PValue::from),
563 })
564 }
565}
566
#[cfg(test)]
mod tests {
    use num_traits::CheckedSub;
    use rstest::rstest;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexExpect;

    use crate::InnerScalarValue;
    use crate::PValue;
    use crate::PrimitiveScalar;
    use crate::ScalarValue;
    use crate::primitive::NumericOperator;

    /// A non-nullable primitive dtype of `ptype`.
    fn non_nullable(ptype: PType) -> DType {
        DType::Primitive(ptype, Nullability::NonNullable)
    }

    /// A `PrimitiveScalar` over `dtype` holding the present value `value`.
    fn present(dtype: &DType, value: PValue) -> PrimitiveScalar<'_> {
        PrimitiveScalar::try_new(dtype, &ScalarValue(InnerScalarValue::Primitive(value))).unwrap()
    }

    /// A null `PrimitiveScalar` over `dtype`.
    fn null(dtype: &DType) -> PrimitiveScalar<'_> {
        PrimitiveScalar::try_new(dtype, &ScalarValue(InnerScalarValue::Null)).unwrap()
    }

    /// Applies `op` to two present `i32` values of `dtype`.
    fn i32_op(
        dtype: &DType,
        lhs: i32,
        rhs: i32,
        op: NumericOperator,
    ) -> Option<PrimitiveScalar<'_>> {
        let lhs = present(dtype, PValue::I32(lhs));
        let rhs = present(dtype, PValue::I32(rhs));
        lhs.checked_binary_numeric(&rhs, op)
    }

    #[test]
    fn test_integer_subtract() {
        let dtype = non_nullable(PType::I32);
        let five = present(&dtype, PValue::I32(5));
        let four = present(&dtype, PValue::I32(4));

        let difference = five.checked_sub(&four).unwrap();
        assert_eq!(difference.as_::<i32>().unwrap(), 1);

        assert_eq!((five - four).as_::<i32>().unwrap(), 1);
    }

    #[test]
    #[should_panic(expected = "PrimitiveScalar subtract: overflow or underflow")]
    fn test_integer_subtract_overflow() {
        let dtype = non_nullable(PType::I32);
        let min = present(&dtype, PValue::I32(i32::MIN));
        let max = present(&dtype, PValue::I32(i32::MAX));
        let _ = min - max;
    }

    #[test]
    fn test_float_subtract() {
        let dtype = non_nullable(PType::F32);
        let lhs = present(&dtype, PValue::F32(1.99f32));
        let rhs = present(&dtype, PValue::F32(1.0f32));

        let difference = lhs.checked_sub(&rhs).unwrap();
        assert_eq!(difference.as_::<f32>().unwrap(), 0.99f32);

        assert_eq!((lhs - rhs).as_::<f32>().unwrap(), 0.99f32);
    }

    #[test]
    fn test_primitive_scalar_equality() {
        let dtype = non_nullable(PType::I32);
        let a = present(&dtype, PValue::I32(42));
        let b = present(&dtype, PValue::I32(42));
        let c = present(&dtype, PValue::I32(43));

        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn test_primitive_scalar_partial_ord() {
        let dtype = non_nullable(PType::I32);
        let ten = present(&dtype, PValue::I32(10));
        let twenty = present(&dtype, PValue::I32(20));

        assert!(ten < twenty);
        assert!(twenty > ten);
        assert_eq!(ten.partial_cmp(&ten), Some(std::cmp::Ordering::Equal));
    }

    #[test]
    fn test_primitive_scalar_null_handling() {
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let scalar = null(&dtype);

        assert_eq!(scalar.pvalue(), None);
        assert_eq!(scalar.typed_value::<i32>(), None);
    }

    #[test]
    fn test_typed_value_correct_type() {
        let dtype = non_nullable(PType::F64);
        let scalar = present(&dtype, PValue::F64(3.5));

        assert_eq!(scalar.typed_value::<f64>(), Some(3.5));
    }

    #[test]
    #[should_panic(expected = "Attempting to read")]
    fn test_typed_value_wrong_type() {
        let dtype = non_nullable(PType::F64);
        let scalar = present(&dtype, PValue::F64(3.5));

        let _ = scalar.typed_value::<i32>();
    }

    #[rstest]
    #[case(PType::I8, 127i32, PType::I16, true)]
    #[case(PType::I8, 127i32, PType::I32, true)]
    #[case(PType::I8, 127i32, PType::I64, true)]
    #[case(PType::U8, 255i32, PType::U16, true)]
    #[case(PType::U8, 255i32, PType::U32, true)]
    #[case(PType::I32, 42i32, PType::F32, true)]
    #[case(PType::I32, 42i32, PType::F64, true)]
    #[case(PType::I32, 300i32, PType::U8, false)]
    #[case(PType::I32, -1i32, PType::U32, false)]
    #[case(PType::I32, 256i32, PType::I8, false)]
    #[case(PType::U16, 65535i32, PType::I8, false)]
    fn test_primitive_cast(
        #[case] source_type: PType,
        #[case] source_value: i32,
        #[case] target_type: PType,
        #[case] should_succeed: bool,
    ) {
        let source_pvalue = match source_type {
            PType::I8 => PValue::I8(i8::try_from(source_value).vortex_expect("cannot cast")),
            PType::U8 => PValue::U8(u8::try_from(source_value).vortex_expect("cannot cast")),
            PType::U16 => PValue::U16(u16::try_from(source_value).vortex_expect("cannot cast")),
            PType::I32 => PValue::I32(source_value),
            _ => unreachable!("Test case uses unexpected source type"),
        };

        let dtype = non_nullable(source_type);
        let scalar = present(&dtype, source_pvalue);

        let result = scalar.cast(&non_nullable(target_type));

        if should_succeed {
            assert!(
                result.is_ok(),
                "Cast from {:?} to {:?} should succeed",
                source_type,
                target_type
            );
        } else {
            assert!(
                result.is_err(),
                "Cast from {:?} to {:?} should fail due to overflow",
                source_type,
                target_type
            );
        }
    }

    #[test]
    fn test_as_conversion_success() {
        let dtype = non_nullable(PType::I32);
        let scalar = present(&dtype, PValue::I32(42));

        assert_eq!(scalar.as_::<i64>(), Some(42i64));
        assert_eq!(scalar.as_::<f64>(), Some(42.0));
    }

    #[test]
    fn test_as_conversion_overflow() {
        let dtype = non_nullable(PType::I32);
        let scalar = present(&dtype, PValue::I32(-1));

        // -1 has no u32 representation, so the fallible conversion reports failure.
        assert!(scalar.as_opt::<u32>().is_none());
    }

    #[test]
    fn test_as_conversion_null() {
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let scalar = null(&dtype);

        assert_eq!(scalar.as_::<i32>(), None);
        assert_eq!(scalar.as_::<f64>(), None);
    }

    #[test]
    fn test_numeric_operator_swap() {
        let expectations = [
            (NumericOperator::Add, NumericOperator::Add),
            (NumericOperator::Sub, NumericOperator::RSub),
            (NumericOperator::RSub, NumericOperator::Sub),
            (NumericOperator::Mul, NumericOperator::Mul),
            (NumericOperator::Div, NumericOperator::RDiv),
            (NumericOperator::RDiv, NumericOperator::Div),
        ];
        for (op, swapped) in expectations {
            assert_eq!(op.swap(), swapped);
        }
    }

    #[test]
    fn test_checked_binary_numeric_add() {
        let dtype = non_nullable(PType::I32);
        let result = i32_op(&dtype, 10, 20, NumericOperator::Add).unwrap();
        assert_eq!(result.typed_value::<i32>(), Some(30));
    }

    #[test]
    fn test_checked_binary_numeric_overflow() {
        let dtype = non_nullable(PType::I32);
        // i32::MAX + 1 overflows, so the checked operation reports failure.
        assert!(i32_op(&dtype, i32::MAX, 1, NumericOperator::Add).is_none());
    }

    #[test]
    fn test_checked_binary_numeric_with_null() {
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let ten = present(&dtype, PValue::I32(10));
        let missing = null(&dtype);

        // A null operand produces a (successful) null result.
        let result = ten
            .checked_binary_numeric(&missing, NumericOperator::Add)
            .unwrap();
        assert_eq!(result.pvalue(), None);
    }

    #[test]
    fn test_checked_binary_numeric_mul() {
        let dtype = non_nullable(PType::I32);
        let result = i32_op(&dtype, 5, 6, NumericOperator::Mul).unwrap();
        assert_eq!(result.typed_value::<i32>(), Some(30));
    }

    #[test]
    fn test_checked_binary_numeric_div() {
        let dtype = non_nullable(PType::I32);
        let result = i32_op(&dtype, 20, 4, NumericOperator::Div).unwrap();
        assert_eq!(result.typed_value::<i32>(), Some(5));
    }

    #[test]
    fn test_checked_binary_numeric_rdiv() {
        let dtype = non_nullable(PType::I32);
        // RDiv divides the right operand by the left: 20 / 4.
        let result = i32_op(&dtype, 4, 20, NumericOperator::RDiv).unwrap();
        assert_eq!(result.typed_value::<i32>(), Some(5));
    }

    #[test]
    fn test_checked_binary_numeric_div_by_zero() {
        let dtype = non_nullable(PType::I32);
        // Integer division by zero is reported as a failed checked operation.
        assert!(i32_op(&dtype, 10, 0, NumericOperator::Div).is_none());
    }

    #[test]
    fn test_checked_binary_numeric_float_ops() {
        let dtype = non_nullable(PType::F32);
        let ten = present(&dtype, PValue::F32(10.0));
        let two_and_a_half = present(&dtype, PValue::F32(2.5));

        let expectations = [
            (NumericOperator::Add, 12.5f32),
            (NumericOperator::Sub, 7.5f32),
            (NumericOperator::Mul, 25.0f32),
            (NumericOperator::Div, 4.0f32),
        ];
        for (op, expected) in expectations {
            let result = ten.checked_binary_numeric(&two_and_a_half, op).unwrap();
            assert_eq!(result.typed_value::<f32>(), Some(expected));
        }
    }

    #[test]
    fn test_from_primitive_or_f16() {
        use vortex_dtype::half::f16;

        use crate::primitive::FromPrimitiveOrF16;

        let f16_val = f16::from_f32(3.5);
        assert!(f32::from_f16(f16_val).is_some());
        assert!(f64::from_f16(f16_val).is_some());

        assert!(i32::try_from(PValue::from(f16_val)).is_err());
        assert!(u32::try_from(PValue::from(f16_val)).is_err());
    }

    #[test]
    fn test_partial_ord_different_types() {
        let int_dtype = non_nullable(PType::I32);
        let float_dtype = non_nullable(PType::F32);

        let int_scalar = present(&int_dtype, PValue::I32(10));
        let float_scalar = present(&float_dtype, PValue::F32(10.0));

        // Scalars of different primitive types are incomparable.
        assert_eq!(int_scalar.partial_cmp(&float_scalar), None);
    }

    #[test]
    fn test_scalar_value_from_usize() {
        let value: ScalarValue = 42usize.into();
        assert!(matches!(
            value.0,
            InnerScalarValue::Primitive(PValue::U64(42))
        ));
    }

    #[test]
    fn test_getters() {
        let dtype = non_nullable(PType::I32);
        let scalar = present(&dtype, PValue::I32(42));

        assert_eq!(scalar.dtype(), &dtype);
        assert_eq!(scalar.ptype(), PType::I32);
        assert_eq!(scalar.pvalue(), Some(PValue::I32(42)));
    }
}
1070}