1use std::any::type_name;
5use std::cmp::Ordering;
6use std::fmt::{Debug, Display, Formatter};
7use std::ops::{Add, Sub};
8
9use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive};
10use vortex_dtype::half::f16;
11use vortex_dtype::{DType, NativePType, Nullability, PType, match_each_native_ptype};
12use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_err, vortex_panic};
13
14use crate::pvalue::{CoercePValue, PValue};
15use crate::{InnerScalarValue, Scalar, ScalarValue};
16
17#[derive(Debug, Clone, Copy, Hash)]
22pub struct PrimitiveScalar<'a> {
23 dtype: &'a DType,
24 ptype: PType,
25 pvalue: Option<PValue>,
26}
27
28impl Display for PrimitiveScalar<'_> {
29 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
30 match self.pvalue {
31 None => write!(f, "null"),
32 Some(pv) => write!(f, "{pv}"),
33 }
34 }
35}
36
37impl PartialEq for PrimitiveScalar<'_> {
38 fn eq(&self, other: &Self) -> bool {
39 self.dtype.eq_ignore_nullability(other.dtype) && self.pvalue == other.pvalue
40 }
41}
42
43impl Eq for PrimitiveScalar<'_> {}
44
45impl PartialOrd for PrimitiveScalar<'_> {
47 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
48 if !self.dtype.eq_ignore_nullability(other.dtype) {
49 return None;
50 }
51 self.pvalue.partial_cmp(&other.pvalue)
52 }
53}
54
55impl<'a> PrimitiveScalar<'a> {
56 pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
63 let ptype = PType::try_from(dtype)?;
64
65 let pvalue = match_each_native_ptype!(ptype, |T| {
68 value
69 .as_pvalue()?
70 .map(|pv| VortexResult::Ok(PValue::from(<T>::coerce(pv)?)))
71 .transpose()?
72 });
73
74 Ok(Self {
75 dtype,
76 ptype,
77 pvalue,
78 })
79 }
80
81 #[inline]
83 pub fn dtype(&self) -> &'a DType {
84 self.dtype
85 }
86
87 #[inline]
89 pub fn ptype(&self) -> PType {
90 self.ptype
91 }
92
93 #[inline]
95 pub fn pvalue(&self) -> Option<PValue> {
96 self.pvalue
97 }
98
99 pub fn typed_value<T: NativePType + TryFrom<PValue, Error = VortexError>>(&self) -> Option<T> {
108 assert_eq!(
109 self.ptype,
110 T::PTYPE,
111 "Attempting to read {} scalar as {}",
112 self.ptype,
113 T::PTYPE
114 );
115
116 self.pvalue.map(|pv| pv.as_primitive::<T>())
117 }
118
119 pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
120 let ptype = PType::try_from(dtype)?;
121 let pvalue = self
122 .pvalue
123 .vortex_expect("nullness handled in Scalar::cast");
124 Ok(match_each_native_ptype!(ptype, |Q| {
125 Scalar::primitive(
126 pvalue
127 .as_primitive_opt::<Q>()
128 .ok_or_else(|| vortex_err!("Cannot cast {} to {}", self.ptype, dtype))?,
129 dtype.nullability(),
130 )
131 }))
132 }
133
134 pub fn as_<T: FromPrimitiveOrF16>(&self) -> Option<T> {
160 self.as_opt::<T>().unwrap_or_else(|| {
161 vortex_panic!(
162 "cast {} to {}: value out of range",
163 self.ptype,
164 type_name::<T>()
165 )
166 })
167 }
168
169 pub fn as_opt<T: FromPrimitiveOrF16>(&self) -> Option<Option<T>> {
197 if let Some(pv) = self.pvalue {
198 match pv {
199 PValue::U8(v) => T::from_u8(v),
200 PValue::U16(v) => T::from_u16(v),
201 PValue::U32(v) => T::from_u32(v),
202 PValue::U64(v) => T::from_u64(v),
203 PValue::I8(v) => T::from_i8(v),
204 PValue::I16(v) => T::from_i16(v),
205 PValue::I32(v) => T::from_i32(v),
206 PValue::I64(v) => T::from_i64(v),
207 PValue::F16(v) => T::from_f16(v),
208 PValue::F32(v) => T::from_f32(v),
209 PValue::F64(v) => T::from_f64(v),
210 }
211 .map(Some)
212 } else {
213 Some(None)
214 }
215 }
216}
217
/// Conversion into a numeric type from any primitive source value, including `f16`.
///
/// Adds an `f16` source conversion alongside the conversions that [`FromPrimitive`] provides,
/// so that `PrimitiveScalar::as_opt` can handle every [`PValue`] variant.
pub trait FromPrimitiveOrF16: FromPrimitive {
    /// Converts `v` into `Self`, or returns `None` if the implementing type does not accept it.
    fn from_f16(v: f16) -> Option<Self>;
}
225
226macro_rules! from_primitive_or_f16_for_non_floating_point {
227 ($T:ty) => {
228 impl FromPrimitiveOrF16 for $T {
229 fn from_f16(_: f16) -> Option<Self> {
230 None
231 }
232 }
233 };
234}
235
236from_primitive_or_f16_for_non_floating_point!(usize);
237from_primitive_or_f16_for_non_floating_point!(u8);
238from_primitive_or_f16_for_non_floating_point!(u16);
239from_primitive_or_f16_for_non_floating_point!(u32);
240from_primitive_or_f16_for_non_floating_point!(u64);
241from_primitive_or_f16_for_non_floating_point!(i8);
242from_primitive_or_f16_for_non_floating_point!(i16);
243from_primitive_or_f16_for_non_floating_point!(i32);
244from_primitive_or_f16_for_non_floating_point!(i64);
245
246impl FromPrimitiveOrF16 for f16 {
247 fn from_f16(v: f16) -> Option<Self> {
248 Some(v)
249 }
250}
251
252impl FromPrimitiveOrF16 for f32 {
253 fn from_f16(v: f16) -> Option<Self> {
254 Some(v.to_f32())
255 }
256}
257
258impl FromPrimitiveOrF16 for f64 {
259 fn from_f16(v: f16) -> Option<Self> {
260 Some(v.to_f64())
261 }
262}
263
264impl<'a> TryFrom<&'a Scalar> for PrimitiveScalar<'a> {
265 type Error = VortexError;
266
267 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
268 Self::try_new(value.dtype(), value.value())
269 }
270}
271
272impl Sub for PrimitiveScalar<'_> {
273 type Output = Self;
274
275 fn sub(self, rhs: Self) -> Self::Output {
276 self.checked_sub(&rhs)
277 .vortex_expect("PrimitiveScalar subtract: overflow or underflow")
278 }
279}
280
281impl CheckedSub for PrimitiveScalar<'_> {
282 fn checked_sub(&self, rhs: &Self) -> Option<Self> {
283 self.checked_binary_numeric(rhs, NumericOperator::Sub)
284 }
285}
286
287impl Add for PrimitiveScalar<'_> {
288 type Output = Self;
289
290 fn add(self, rhs: Self) -> Self::Output {
291 self.checked_add(&rhs)
292 .vortex_expect("PrimitiveScalar add: overflow or underflow")
293 }
294}
295
296impl CheckedAdd for PrimitiveScalar<'_> {
297 fn checked_add(&self, rhs: &Self) -> Option<Self> {
298 self.checked_binary_numeric(rhs, NumericOperator::Add)
299 }
300}
301
302impl Scalar {
303 pub fn primitive<T: NativePType + Into<PValue>>(value: T, nullability: Nullability) -> Self {
305 Self::primitive_value(value.into(), T::PTYPE, nullability)
306 }
307
308 pub fn primitive_value(value: PValue, ptype: PType, nullability: Nullability) -> Self {
313 Self::new(
314 DType::Primitive(ptype, nullability),
315 ScalarValue(InnerScalarValue::Primitive(value)),
316 )
317 }
318
319 pub fn reinterpret_cast(&self, ptype: PType) -> Self {
325 let primitive = PrimitiveScalar::try_from(self).unwrap_or_else(|e| {
326 vortex_panic!(
327 e,
328 "Failed to reinterpret cast {} to {}",
329 self.dtype(),
330 ptype
331 )
332 });
333 if primitive.ptype() == ptype {
334 return self.clone();
335 }
336
337 assert_eq!(
338 primitive.ptype().byte_width(),
339 ptype.byte_width(),
340 "can't reinterpret cast between integers of two different widths"
341 );
342
343 Scalar::new(
344 DType::Primitive(ptype, self.dtype().nullability()),
345 primitive
346 .pvalue
347 .map(|p| p.reinterpret_cast(ptype))
348 .map(|x| ScalarValue(InnerScalarValue::Primitive(x)))
349 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)),
350 )
351 }
352}
353
354macro_rules! primitive_scalar {
355 ($T:ty) => {
356 impl TryFrom<&Scalar> for $T {
357 type Error = VortexError;
358
359 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
360 <Option<$T>>::try_from(value)?
361 .ok_or_else(|| vortex_err!("Can't extract present value from null scalar"))
362 }
363 }
364
365 impl TryFrom<Scalar> for $T {
366 type Error = VortexError;
367
368 fn try_from(value: Scalar) -> Result<Self, Self::Error> {
369 <$T>::try_from(&value)
370 }
371 }
372
373 impl TryFrom<&Scalar> for Option<$T> {
374 type Error = VortexError;
375
376 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
377 Ok(PrimitiveScalar::try_from(value)?.typed_value::<$T>())
378 }
379 }
380
381 impl TryFrom<Scalar> for Option<$T> {
382 type Error = VortexError;
383
384 fn try_from(value: Scalar) -> Result<Self, Self::Error> {
385 <Option<$T>>::try_from(&value)
386 }
387 }
388
389 impl From<$T> for Scalar {
390 fn from(value: $T) -> Self {
391 Scalar::new(
392 DType::Primitive(<$T>::PTYPE, Nullability::NonNullable),
393 ScalarValue(InnerScalarValue::Primitive(value.into())),
394 )
395 }
396 }
397
398 impl From<$T> for ScalarValue {
399 fn from(value: $T) -> Self {
400 ScalarValue(InnerScalarValue::Primitive(value.into()))
401 }
402 }
403 };
404}
405
406primitive_scalar!(u8);
407primitive_scalar!(u16);
408primitive_scalar!(u32);
409primitive_scalar!(u64);
410primitive_scalar!(i8);
411primitive_scalar!(i16);
412primitive_scalar!(i32);
413primitive_scalar!(i64);
414primitive_scalar!(f16);
415primitive_scalar!(f32);
416primitive_scalar!(f64);
417
418impl TryFrom<&Scalar> for usize {
420 type Error = VortexError;
421
422 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
423 let prim = PrimitiveScalar::try_from(value)?
424 .as_::<u64>()
425 .ok_or_else(|| vortex_err!("cannot convert Null to usize"))?;
426 Ok(usize::try_from(prim)?)
427 }
428}
429
430impl TryFrom<&Scalar> for Option<usize> {
431 type Error = VortexError;
432
433 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
434 Ok(PrimitiveScalar::try_from(value)?
435 .as_::<u64>()
436 .map(usize::try_from)
437 .transpose()?)
438 }
439}
440
441impl From<usize> for Scalar {
443 fn from(value: usize) -> Self {
444 Scalar::primitive(value as u64, Nullability::NonNullable)
445 }
446}
447
448impl From<PValue> for ScalarValue {
449 fn from(value: PValue) -> Self {
450 ScalarValue(InnerScalarValue::Primitive(value))
451 }
452}
453
454impl From<usize> for ScalarValue {
456 fn from(value: usize) -> Self {
457 ScalarValue(InnerScalarValue::Primitive((value as u64).into()))
458 }
459}
460
/// Binary arithmetic operators, as applied for example by
/// `PrimitiveScalar::checked_binary_numeric`.
///
/// The `R*` variants apply the operator with the two operands exchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumericOperator {
    /// `lhs + rhs`.
    Add,
    /// `lhs - rhs`.
    Sub,
    /// Reversed subtraction: `rhs - lhs`.
    RSub,
    /// `lhs * rhs`.
    Mul,
    /// `lhs / rhs`.
    Div,
    /// Reversed division: `rhs / lhs`.
    RDiv,
}
483
484impl Display for NumericOperator {
485 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
486 Debug::fmt(self, f)
487 }
488}
489
490impl NumericOperator {
491 pub fn swap(self) -> Self {
493 match self {
494 NumericOperator::Add => NumericOperator::Add,
495 NumericOperator::Sub => NumericOperator::RSub,
496 NumericOperator::RSub => NumericOperator::Sub,
497 NumericOperator::Mul => NumericOperator::Mul,
498 NumericOperator::Div => NumericOperator::RDiv,
499 NumericOperator::RDiv => NumericOperator::Div,
500 }
501 }
502}
503
504impl<'a> PrimitiveScalar<'a> {
505 pub fn checked_binary_numeric(
513 &self,
514 other: &PrimitiveScalar<'a>,
515 op: NumericOperator,
516 ) -> Option<PrimitiveScalar<'a>> {
517 if !self.dtype().eq_ignore_nullability(other.dtype()) {
518 vortex_panic!("types must match: {} {}", self.dtype(), other.dtype());
519 }
520 let result_dtype = if self.dtype().is_nullable() {
521 self.dtype()
522 } else {
523 other.dtype()
524 };
525 let ptype = self.ptype();
526
527 match_each_native_ptype!(
528 self.ptype(),
529 integral: |P| {
530 self.checked_integral_numeric_operator::<P>(other, result_dtype, ptype, op)
531 },
532 floating: |P| {
533 let lhs = self.typed_value::<P>();
534 let rhs = other.typed_value::<P>();
535 let value_or_null = match (lhs, rhs) {
536 (_, None) | (None, _) => None,
537 (Some(lhs), Some(rhs)) => match op {
538 NumericOperator::Add => Some(lhs + rhs),
539 NumericOperator::Sub => Some(lhs - rhs),
540 NumericOperator::RSub => Some(rhs - lhs),
541 NumericOperator::Mul => Some(lhs * rhs),
542 NumericOperator::Div => Some(lhs / rhs),
543 NumericOperator::RDiv => Some(rhs / lhs),
544 }
545 };
546 Some(Self { dtype: result_dtype, ptype, pvalue: value_or_null.map(PValue::from) })
547 }
548 )
549 }
550
551 fn checked_integral_numeric_operator<
552 P: NativePType
553 + TryFrom<PValue, Error = VortexError>
554 + CheckedSub
555 + CheckedAdd
556 + CheckedMul
557 + CheckedDiv,
558 >(
559 &self,
560 other: &PrimitiveScalar<'a>,
561 result_dtype: &'a DType,
562 ptype: PType,
563 op: NumericOperator,
564 ) -> Option<PrimitiveScalar<'a>>
565 where
566 PValue: From<P>,
567 {
568 let lhs = self.typed_value::<P>();
569 let rhs = other.typed_value::<P>();
570 let value_or_null_or_overflow = match (lhs, rhs) {
571 (_, None) | (None, _) => Some(None),
572 (Some(lhs), Some(rhs)) => match op {
573 NumericOperator::Add => lhs.checked_add(&rhs).map(Some),
574 NumericOperator::Sub => lhs.checked_sub(&rhs).map(Some),
575 NumericOperator::RSub => rhs.checked_sub(&lhs).map(Some),
576 NumericOperator::Mul => lhs.checked_mul(&rhs).map(Some),
577 NumericOperator::Div => lhs.checked_div(&rhs).map(Some),
578 NumericOperator::RDiv => rhs.checked_div(&lhs).map(Some),
579 },
580 };
581
582 value_or_null_or_overflow.map(|value_or_null| Self {
583 dtype: result_dtype,
584 ptype,
585 pvalue: value_or_null.map(PValue::from),
586 })
587 }
588}
589
#[cfg(test)]
#[allow(clippy::cast_possible_truncation)]
mod tests {
    use num_traits::CheckedSub;
    use rstest::rstest;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::primitive::NumericOperator;
    use crate::{InnerScalarValue, PValue, PrimitiveScalar, ScalarValue};

    /// Builds a non-null scalar of `dtype` holding `pvalue`, panicking if construction fails.
    fn new_scalar(dtype: &DType, pvalue: PValue) -> PrimitiveScalar<'_> {
        PrimitiveScalar::try_new(dtype, &ScalarValue(InnerScalarValue::Primitive(pvalue)))
            .unwrap()
    }

    /// Builds a null scalar of `dtype`, panicking if construction fails.
    fn new_null(dtype: &DType) -> PrimitiveScalar<'_> {
        PrimitiveScalar::try_new(dtype, &ScalarValue(InnerScalarValue::Null)).unwrap()
    }

    #[test]
    fn test_integer_subtract() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let five = new_scalar(&dtype, PValue::I32(5));
        let four = new_scalar(&dtype, PValue::I32(4));

        let checked = five.checked_sub(&four).unwrap();
        assert_eq!(checked.as_::<i32>().unwrap(), 1);

        assert_eq!((five - four).as_::<i32>().unwrap(), 1);
    }

    #[test]
    #[should_panic(expected = "PrimitiveScalar subtract: overflow or underflow")]
    fn test_integer_subtract_overflow() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let lhs = new_scalar(&dtype, PValue::I32(i32::MIN));
        let rhs = new_scalar(&dtype, PValue::I32(i32::MAX));
        let _ = lhs - rhs;
    }

    #[test]
    fn test_float_subtract() {
        let dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
        let lhs = new_scalar(&dtype, PValue::F32(1.99f32));
        let rhs = new_scalar(&dtype, PValue::F32(1.0f32));

        let checked = lhs.checked_sub(&rhs).unwrap();
        assert_eq!(checked.as_::<f32>().unwrap(), 0.99f32);

        assert_eq!((lhs - rhs).as_::<f32>().unwrap(), 0.99f32);
    }

    #[test]
    fn test_primitive_scalar_equality() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let first = new_scalar(&dtype, PValue::I32(42));
        let same = new_scalar(&dtype, PValue::I32(42));
        let different = new_scalar(&dtype, PValue::I32(43));

        assert_eq!(first, same);
        assert_ne!(first, different);
    }

    #[test]
    fn test_primitive_scalar_partial_ord() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let ten = new_scalar(&dtype, PValue::I32(10));
        let twenty = new_scalar(&dtype, PValue::I32(20));

        assert!(ten < twenty);
        assert!(twenty > ten);
        assert_eq!(ten.partial_cmp(&ten), Some(std::cmp::Ordering::Equal));
    }

    #[test]
    fn test_primitive_scalar_null_handling() {
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let null_scalar = new_null(&dtype);

        assert_eq!(null_scalar.pvalue(), None);
        assert_eq!(null_scalar.typed_value::<i32>(), None);
    }

    #[test]
    fn test_typed_value_correct_type() {
        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let value = new_scalar(&dtype, PValue::F64(3.5));

        assert_eq!(value.typed_value::<f64>(), Some(3.5));
    }

    #[test]
    #[should_panic(expected = "Attempting to read")]
    fn test_typed_value_wrong_type() {
        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let value = new_scalar(&dtype, PValue::F64(3.5));

        let _ = value.typed_value::<i32>();
    }

    #[rstest]
    #[case(PType::I8, 127i32, PType::I16, true)]
    #[case(PType::I8, 127i32, PType::I32, true)]
    #[case(PType::I8, 127i32, PType::I64, true)]
    #[case(PType::U8, 255i32, PType::U16, true)]
    #[case(PType::U8, 255i32, PType::U32, true)]
    #[case(PType::I32, 42i32, PType::F32, true)]
    #[case(PType::I32, 42i32, PType::F64, true)]
    #[case(PType::I32, 300i32, PType::U8, false)]
    #[case(PType::I32, -1i32, PType::U32, false)]
    #[case(PType::I32, 256i32, PType::I8, false)]
    #[case(PType::U16, 65535i32, PType::I8, false)]
    fn test_primitive_cast(
        #[case] source_type: PType,
        #[case] source_value: i32,
        #[case] target_type: PType,
        #[case] should_succeed: bool,
    ) {
        let source_pvalue = match source_type {
            PType::I8 => PValue::I8(source_value as i8),
            PType::U8 => PValue::U8(source_value as u8),
            PType::U16 => PValue::U16(source_value as u16),
            PType::I32 => PValue::I32(source_value),
            _ => unreachable!("Test case uses unexpected source type"),
        };

        let dtype = DType::Primitive(source_type, Nullability::NonNullable);
        let target_dtype = DType::Primitive(target_type, Nullability::NonNullable);
        let result = new_scalar(&dtype, source_pvalue).cast(&target_dtype);

        if should_succeed {
            assert!(
                result.is_ok(),
                "Cast from {:?} to {:?} should succeed",
                source_type,
                target_type
            );
        } else {
            assert!(
                result.is_err(),
                "Cast from {:?} to {:?} should fail due to overflow",
                source_type,
                target_type
            );
        }
    }

    #[test]
    fn test_as_conversion_success() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let value = new_scalar(&dtype, PValue::I32(42));

        assert_eq!(value.as_::<i64>(), Some(42i64));
        assert_eq!(value.as_::<f64>(), Some(42.0));
    }

    #[test]
    fn test_as_conversion_overflow() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let negative = new_scalar(&dtype, PValue::I32(-1));

        assert!(negative.as_opt::<u32>().is_none());
    }

    #[test]
    fn test_as_conversion_null() {
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let null_scalar = new_null(&dtype);

        assert_eq!(null_scalar.as_::<i32>(), None);
        assert_eq!(null_scalar.as_::<f64>(), None);
    }

    #[test]
    fn test_numeric_operator_swap() {
        assert_eq!(NumericOperator::Add.swap(), NumericOperator::Add);
        assert_eq!(NumericOperator::Sub.swap(), NumericOperator::RSub);
        assert_eq!(NumericOperator::RSub.swap(), NumericOperator::Sub);
        assert_eq!(NumericOperator::Mul.swap(), NumericOperator::Mul);
        assert_eq!(NumericOperator::Div.swap(), NumericOperator::RDiv);
        assert_eq!(NumericOperator::RDiv.swap(), NumericOperator::Div);
    }

    #[test]
    fn test_checked_binary_numeric_add() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let ten = new_scalar(&dtype, PValue::I32(10));
        let twenty = new_scalar(&dtype, PValue::I32(20));

        let sum = ten
            .checked_binary_numeric(&twenty, NumericOperator::Add)
            .unwrap();
        assert_eq!(sum.typed_value::<i32>(), Some(30));
    }

    #[test]
    fn test_checked_binary_numeric_overflow() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let max = new_scalar(&dtype, PValue::I32(i32::MAX));
        let one = new_scalar(&dtype, PValue::I32(1));

        assert!(
            max.checked_binary_numeric(&one, NumericOperator::Add)
                .is_none()
        );
    }

    #[test]
    fn test_checked_binary_numeric_with_null() {
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let ten = new_scalar(&dtype, PValue::I32(10));
        let missing = new_null(&dtype);

        let result = ten
            .checked_binary_numeric(&missing, NumericOperator::Add)
            .unwrap();
        assert_eq!(result.pvalue(), None);
    }

    #[test]
    fn test_checked_binary_numeric_mul() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let five = new_scalar(&dtype, PValue::I32(5));
        let six = new_scalar(&dtype, PValue::I32(6));

        let product = five
            .checked_binary_numeric(&six, NumericOperator::Mul)
            .unwrap();
        assert_eq!(product.typed_value::<i32>(), Some(30));
    }

    #[test]
    fn test_checked_binary_numeric_div() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let twenty = new_scalar(&dtype, PValue::I32(20));
        let four = new_scalar(&dtype, PValue::I32(4));

        let quotient = twenty
            .checked_binary_numeric(&four, NumericOperator::Div)
            .unwrap();
        assert_eq!(quotient.typed_value::<i32>(), Some(5));
    }

    #[test]
    fn test_checked_binary_numeric_rdiv() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let four = new_scalar(&dtype, PValue::I32(4));
        let twenty = new_scalar(&dtype, PValue::I32(20));

        let quotient = four
            .checked_binary_numeric(&twenty, NumericOperator::RDiv)
            .unwrap();
        assert_eq!(quotient.typed_value::<i32>(), Some(5));
    }

    #[test]
    fn test_checked_binary_numeric_div_by_zero() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let ten = new_scalar(&dtype, PValue::I32(10));
        let zero = new_scalar(&dtype, PValue::I32(0));

        assert!(
            ten.checked_binary_numeric(&zero, NumericOperator::Div)
                .is_none()
        );
    }

    #[test]
    fn test_checked_binary_numeric_float_ops() {
        let dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
        let lhs = new_scalar(&dtype, PValue::F32(10.0));
        let rhs = new_scalar(&dtype, PValue::F32(2.5));

        let apply = |op: NumericOperator| {
            lhs.checked_binary_numeric(&rhs, op)
                .unwrap()
                .typed_value::<f32>()
        };

        assert_eq!(apply(NumericOperator::Add), Some(12.5));
        assert_eq!(apply(NumericOperator::Sub), Some(7.5));
        assert_eq!(apply(NumericOperator::Mul), Some(25.0));
        assert_eq!(apply(NumericOperator::Div), Some(4.0));
    }

    #[test]
    fn test_from_primitive_or_f16() {
        use vortex_dtype::half::f16;

        use crate::primitive::FromPrimitiveOrF16;

        let f16_val = f16::from_f32(3.5);

        assert!(f32::from_f16(f16_val).is_some());
        assert!(f64::from_f16(f16_val).is_some());

        assert!(i32::from_f16(f16_val).is_none());
        assert!(u32::from_f16(f16_val).is_none());
    }

    #[test]
    fn test_partial_ord_different_types() {
        let int_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let float_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);

        let int_scalar = new_scalar(&int_dtype, PValue::I32(10));
        let float_scalar = new_scalar(&float_dtype, PValue::F32(10.0));

        assert_eq!(int_scalar.partial_cmp(&float_scalar), None);
    }

    #[test]
    fn test_scalar_value_from_usize() {
        let value: ScalarValue = 42usize.into();
        assert!(matches!(
            value.0,
            InnerScalarValue::Primitive(PValue::U64(42))
        ));
    }

    #[test]
    fn test_getters() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let value = new_scalar(&dtype, PValue::I32(42));

        assert_eq!(value.dtype(), &dtype);
        assert_eq!(value.ptype(), PType::I32);
        assert_eq!(value.pvalue(), Some(PValue::I32(42)));
    }
}
1089}