1use std::any::type_name;
5use std::cmp::Ordering;
6use std::fmt::{Debug, Display, Formatter};
7use std::ops::{Add, Sub};
8
9use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive};
10use vortex_dtype::half::f16;
11use vortex_dtype::{DType, NativePType, Nullability, PType, match_each_native_ptype};
12use vortex_error::{
13 VortexError, VortexExpect as _, VortexResult, VortexUnwrap, vortex_err, vortex_panic,
14};
15
16use crate::pvalue::{CoercePValue, PValue};
17use crate::{InnerScalarValue, Scalar, ScalarValue};
18
19#[derive(Debug, Clone, Copy, Hash)]
24pub struct PrimitiveScalar<'a> {
25 dtype: &'a DType,
26 ptype: PType,
27 pvalue: Option<PValue>,
28}
29
30impl Display for PrimitiveScalar<'_> {
31 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
32 match self.pvalue {
33 None => write!(f, "null"),
34 Some(pv) => write!(f, "{pv}"),
35 }
36 }
37}
38
39impl PartialEq for PrimitiveScalar<'_> {
40 fn eq(&self, other: &Self) -> bool {
41 self.dtype.eq_ignore_nullability(other.dtype) && self.pvalue == other.pvalue
42 }
43}
44
45impl Eq for PrimitiveScalar<'_> {}
46
47impl PartialOrd for PrimitiveScalar<'_> {
49 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
50 if !self.dtype.eq_ignore_nullability(other.dtype) {
51 return None;
52 }
53 self.pvalue.partial_cmp(&other.pvalue)
54 }
55}
56
57impl<'a> PrimitiveScalar<'a> {
58 pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
65 let ptype = PType::try_from(dtype)?;
66
67 let pvalue = match_each_native_ptype!(ptype, |T| {
70 value
71 .as_pvalue()?
72 .map(|pv| VortexResult::Ok(PValue::from(<T>::coerce(pv)?)))
73 .transpose()?
74 });
75
76 Ok(Self {
77 dtype,
78 ptype,
79 pvalue,
80 })
81 }
82
83 #[inline]
85 pub fn dtype(&self) -> &'a DType {
86 self.dtype
87 }
88
89 #[inline]
91 pub fn ptype(&self) -> PType {
92 self.ptype
93 }
94
95 #[inline]
97 pub fn pvalue(&self) -> Option<PValue> {
98 self.pvalue
99 }
100
101 pub fn typed_value<T: NativePType + TryFrom<PValue, Error = VortexError>>(&self) -> Option<T> {
111 assert_eq!(
112 self.ptype,
113 T::PTYPE,
114 "Attempting to read {} scalar as {}",
115 self.ptype,
116 T::PTYPE
117 );
118
119 self.pvalue.map(|pv| pv.as_primitive::<T>().vortex_unwrap())
120 }
121
122 pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
123 let ptype = PType::try_from(dtype)?;
124 let pvalue = self
125 .pvalue
126 .vortex_expect("nullness handled in Scalar::cast");
127 Ok(match_each_native_ptype!(ptype, |Q| {
128 Scalar::primitive(
129 pvalue.as_primitive::<Q>().map_err(|err| {
130 vortex_err!("Cannot cast {} to {}: {}", self.ptype, dtype, err)
131 })?,
132 dtype.nullability(),
133 )
134 }))
135 }
136
137 pub fn as_<T: FromPrimitiveOrF16>(&self) -> VortexResult<Option<T>> {
143 match self.pvalue {
144 None => Ok(None),
145 Some(pv) => Ok(Some(match pv {
146 PValue::U8(v) => T::from_u8(v).ok_or_else(|| {
147 vortex_err!("Cannot cast u8 to {}: value out of range", type_name::<T>())
148 }),
149 PValue::U16(v) => T::from_u16(v).ok_or_else(|| {
150 vortex_err!(
151 "Cannot cast u16 to {}: value out of range",
152 type_name::<T>()
153 )
154 }),
155 PValue::U32(v) => T::from_u32(v).ok_or_else(|| {
156 vortex_err!(
157 "Cannot cast u32 to {}: value out of range",
158 type_name::<T>()
159 )
160 }),
161 PValue::U64(v) => T::from_u64(v).ok_or_else(|| {
162 vortex_err!(
163 "Cannot cast u64 to {}: value out of range",
164 type_name::<T>()
165 )
166 }),
167 PValue::I8(v) => T::from_i8(v).ok_or_else(|| {
168 vortex_err!("Cannot cast i8 to {}: value out of range", type_name::<T>())
169 }),
170 PValue::I16(v) => T::from_i16(v).ok_or_else(|| {
171 vortex_err!(
172 "Cannot cast i16 to {}: value out of range",
173 type_name::<T>()
174 )
175 }),
176 PValue::I32(v) => T::from_i32(v).ok_or_else(|| {
177 vortex_err!(
178 "Cannot cast i32 to {}: value out of range",
179 type_name::<T>()
180 )
181 }),
182 PValue::I64(v) => T::from_i64(v).ok_or_else(|| {
183 vortex_err!(
184 "Cannot cast i64 to {}: value out of range",
185 type_name::<T>()
186 )
187 }),
188 PValue::F16(v) => T::from_f16(v).ok_or_else(|| {
189 vortex_err!(
190 "Cannot cast f16 to {}: value out of range",
191 type_name::<T>()
192 )
193 }),
194 PValue::F32(v) => T::from_f32(v).ok_or_else(|| {
195 vortex_err!(
196 "Cannot cast f32 to {}: value out of range",
197 type_name::<T>()
198 )
199 }),
200 PValue::F64(v) => T::from_f64(v).ok_or_else(|| {
201 vortex_err!(
202 "Cannot cast f64 to {}: value out of range",
203 type_name::<T>()
204 )
205 }),
206 }?)),
207 }
208 }
209}
210
/// Extension of [`FromPrimitive`] that adds a conversion from [`f16`].
///
/// It is the bound used by `PrimitiveScalar::as_`, which needs a conversion
/// for every `PValue` variant, including `F16`.
pub trait FromPrimitiveOrF16: FromPrimitive {
    /// Converts an `f16` into `Self`, or returns `None` when the implementing
    /// type provides no such conversion.
    fn from_f16(v: f16) -> Option<Self>;
}
218
219macro_rules! from_primitive_or_f16_for_non_floating_point {
220 ($T:ty) => {
221 impl FromPrimitiveOrF16 for $T {
222 fn from_f16(_: f16) -> Option<Self> {
223 None
224 }
225 }
226 };
227}
228
229from_primitive_or_f16_for_non_floating_point!(usize);
230from_primitive_or_f16_for_non_floating_point!(u8);
231from_primitive_or_f16_for_non_floating_point!(u16);
232from_primitive_or_f16_for_non_floating_point!(u32);
233from_primitive_or_f16_for_non_floating_point!(u64);
234from_primitive_or_f16_for_non_floating_point!(i8);
235from_primitive_or_f16_for_non_floating_point!(i16);
236from_primitive_or_f16_for_non_floating_point!(i32);
237from_primitive_or_f16_for_non_floating_point!(i64);
238
/// `f16` to `f16` is the identity conversion and always succeeds.
impl FromPrimitiveOrF16 for f16 {
    fn from_f16(v: f16) -> Option<Self> {
        Some(v)
    }
}
244
245impl FromPrimitiveOrF16 for f32 {
246 fn from_f16(v: f16) -> Option<Self> {
247 Some(v.to_f32())
248 }
249}
250
251impl FromPrimitiveOrF16 for f64 {
252 fn from_f16(v: f16) -> Option<Self> {
253 Some(v.to_f64())
254 }
255}
256
257impl<'a> TryFrom<&'a Scalar> for PrimitiveScalar<'a> {
258 type Error = VortexError;
259
260 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
261 Self::try_new(value.dtype(), value.value())
262 }
263}
264
265impl Sub for PrimitiveScalar<'_> {
266 type Output = Self;
267
268 fn sub(self, rhs: Self) -> Self::Output {
269 self.checked_sub(&rhs)
270 .vortex_expect("PrimitiveScalar subtract: overflow or underflow")
271 }
272}
273
274impl CheckedSub for PrimitiveScalar<'_> {
275 fn checked_sub(&self, rhs: &Self) -> Option<Self> {
276 self.checked_binary_numeric(rhs, NumericOperator::Sub)
277 }
278}
279
280impl Add for PrimitiveScalar<'_> {
281 type Output = Self;
282
283 fn add(self, rhs: Self) -> Self::Output {
284 self.checked_add(&rhs)
285 .vortex_expect("PrimitiveScalar add: overflow or underflow")
286 }
287}
288
289impl CheckedAdd for PrimitiveScalar<'_> {
290 fn checked_add(&self, rhs: &Self) -> Option<Self> {
291 self.checked_binary_numeric(rhs, NumericOperator::Add)
292 }
293}
294
295impl Scalar {
296 pub fn primitive<T: NativePType + Into<PValue>>(value: T, nullability: Nullability) -> Self {
298 Self::primitive_value(value.into(), T::PTYPE, nullability)
299 }
300
301 pub fn primitive_value(value: PValue, ptype: PType, nullability: Nullability) -> Self {
306 Self::new(
307 DType::Primitive(ptype, nullability),
308 ScalarValue(InnerScalarValue::Primitive(value)),
309 )
310 }
311
312 pub fn reinterpret_cast(&self, ptype: PType) -> Self {
318 let primitive = PrimitiveScalar::try_from(self).unwrap_or_else(|e| {
319 vortex_panic!(e, "Failed to reinterpret cast {} to {}", self.dtype, ptype)
320 });
321 if primitive.ptype() == ptype {
322 return self.clone();
323 }
324
325 assert_eq!(
326 primitive.ptype().byte_width(),
327 ptype.byte_width(),
328 "can't reinterpret cast between integers of two different widths"
329 );
330
331 Scalar::new(
332 DType::Primitive(ptype, self.dtype.nullability()),
333 primitive
334 .pvalue
335 .map(|p| p.reinterpret_cast(ptype))
336 .map(|x| ScalarValue(InnerScalarValue::Primitive(x)))
337 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)),
338 )
339 }
340}
341
342macro_rules! primitive_scalar {
343 ($T:ty) => {
344 impl TryFrom<&Scalar> for $T {
345 type Error = VortexError;
346
347 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
348 <Option<$T>>::try_from(value)?
349 .ok_or_else(|| vortex_err!("Can't extract present value from null scalar"))
350 }
351 }
352
353 impl TryFrom<Scalar> for $T {
354 type Error = VortexError;
355
356 fn try_from(value: Scalar) -> Result<Self, Self::Error> {
357 <$T>::try_from(&value)
358 }
359 }
360
361 impl TryFrom<&Scalar> for Option<$T> {
362 type Error = VortexError;
363
364 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
365 Ok(PrimitiveScalar::try_from(value)?.typed_value::<$T>())
366 }
367 }
368
369 impl TryFrom<Scalar> for Option<$T> {
370 type Error = VortexError;
371
372 fn try_from(value: Scalar) -> Result<Self, Self::Error> {
373 <Option<$T>>::try_from(&value)
374 }
375 }
376
377 impl From<$T> for Scalar {
378 fn from(value: $T) -> Self {
379 Scalar::new(
380 DType::Primitive(<$T>::PTYPE, Nullability::NonNullable),
381 ScalarValue(InnerScalarValue::Primitive(value.into())),
382 )
383 }
384 }
385 };
386}
387
388primitive_scalar!(u8);
389primitive_scalar!(u16);
390primitive_scalar!(u32);
391primitive_scalar!(u64);
392primitive_scalar!(i8);
393primitive_scalar!(i16);
394primitive_scalar!(i32);
395primitive_scalar!(i64);
396primitive_scalar!(f16);
397primitive_scalar!(f32);
398primitive_scalar!(f64);
399
400impl TryFrom<&Scalar> for usize {
402 type Error = VortexError;
403
404 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
405 let prim = PrimitiveScalar::try_from(value)?
406 .as_::<u64>()?
407 .ok_or_else(|| vortex_err!("cannot convert Null to usize"))?;
408 Ok(usize::try_from(prim)?)
409 }
410}
411
412impl TryFrom<&Scalar> for Option<usize> {
413 type Error = VortexError;
414
415 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
416 Ok(PrimitiveScalar::try_from(value)?
417 .as_::<u64>()?
418 .map(usize::try_from)
419 .transpose()?)
420 }
421}
422
423impl From<usize> for Scalar {
425 fn from(value: usize) -> Self {
426 Scalar::primitive(value as u64, Nullability::NonNullable)
427 }
428}
429
430macro_rules! primitive_scalar {
431 ($T:ty) => {
432 impl From<$T> for ScalarValue {
433 fn from(value: $T) -> Self {
434 ScalarValue(InnerScalarValue::Primitive(value.into()))
435 }
436 }
437 };
438}
439
440primitive_scalar!(u8);
441primitive_scalar!(u16);
442primitive_scalar!(u32);
443primitive_scalar!(u64);
444primitive_scalar!(i8);
445primitive_scalar!(i16);
446primitive_scalar!(i32);
447primitive_scalar!(i64);
448primitive_scalar!(f16);
449primitive_scalar!(f32);
450primitive_scalar!(f64);
451
452impl From<PValue> for ScalarValue {
453 fn from(value: PValue) -> Self {
454 ScalarValue(InnerScalarValue::Primitive(value))
455 }
456}
457
458impl From<usize> for ScalarValue {
460 fn from(value: usize) -> Self {
461 ScalarValue(InnerScalarValue::Primitive((value as u64).into()))
462 }
463}
464
/// Binary arithmetic operators that can be applied to a pair of primitive
/// scalars.
///
/// The `R`-prefixed variants apply the same operation with the operands in
/// reverse order.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumericOperator {
    /// `lhs + rhs`.
    Add,
    /// `lhs - rhs`.
    Sub,
    /// `rhs - lhs`.
    RSub,
    /// `lhs * rhs`.
    Mul,
    /// `lhs / rhs`.
    Div,
    /// `rhs / lhs`.
    RDiv,
}
487
488impl Display for NumericOperator {
489 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
490 Debug::fmt(self, f)
491 }
492}
493
494impl NumericOperator {
495 pub fn swap(self) -> Self {
497 match self {
498 NumericOperator::Add => NumericOperator::Add,
499 NumericOperator::Sub => NumericOperator::RSub,
500 NumericOperator::RSub => NumericOperator::Sub,
501 NumericOperator::Mul => NumericOperator::Mul,
502 NumericOperator::Div => NumericOperator::RDiv,
503 NumericOperator::RDiv => NumericOperator::Div,
504 }
505 }
506}
507
508impl<'a> PrimitiveScalar<'a> {
509 pub fn checked_binary_numeric(
517 &self,
518 other: &PrimitiveScalar<'a>,
519 op: NumericOperator,
520 ) -> Option<PrimitiveScalar<'a>> {
521 if !self.dtype().eq_ignore_nullability(other.dtype()) {
522 vortex_panic!("types must match: {} {}", self.dtype(), other.dtype());
523 }
524 let result_dtype = if self.dtype().is_nullable() {
525 self.dtype()
526 } else {
527 other.dtype()
528 };
529 let ptype = self.ptype();
530
531 match_each_native_ptype!(
532 self.ptype(),
533 integral: |P| {
534 self.checked_integeral_numeric_operator::<P>(other, result_dtype, ptype, op)
535 },
536 floating: |P| {
537 let lhs = self.typed_value::<P>();
538 let rhs = other.typed_value::<P>();
539 let value_or_null = match (lhs, rhs) {
540 (_, None) | (None, _) => None,
541 (Some(lhs), Some(rhs)) => match op {
542 NumericOperator::Add => Some(lhs + rhs),
543 NumericOperator::Sub => Some(lhs - rhs),
544 NumericOperator::RSub => Some(rhs - lhs),
545 NumericOperator::Mul => Some(lhs * rhs),
546 NumericOperator::Div => Some(lhs / rhs),
547 NumericOperator::RDiv => Some(rhs / lhs),
548 }
549 };
550 Some(Self { dtype: result_dtype, ptype, pvalue: value_or_null.map(PValue::from) })
551 }
552 )
553 }
554
555 fn checked_integeral_numeric_operator<
556 P: NativePType
557 + TryFrom<PValue, Error = VortexError>
558 + CheckedSub
559 + CheckedAdd
560 + CheckedMul
561 + CheckedDiv,
562 >(
563 &self,
564 other: &PrimitiveScalar<'a>,
565 result_dtype: &'a DType,
566 ptype: PType,
567 op: NumericOperator,
568 ) -> Option<PrimitiveScalar<'a>>
569 where
570 PValue: From<P>,
571 {
572 let lhs = self.typed_value::<P>();
573 let rhs = other.typed_value::<P>();
574 let value_or_null_or_overflow = match (lhs, rhs) {
575 (_, None) | (None, _) => Some(None),
576 (Some(lhs), Some(rhs)) => match op {
577 NumericOperator::Add => lhs.checked_add(&rhs).map(Some),
578 NumericOperator::Sub => lhs.checked_sub(&rhs).map(Some),
579 NumericOperator::RSub => rhs.checked_sub(&lhs).map(Some),
580 NumericOperator::Mul => lhs.checked_mul(&rhs).map(Some),
581 NumericOperator::Div => lhs.checked_div(&rhs).map(Some),
582 NumericOperator::RDiv => rhs.checked_div(&lhs).map(Some),
583 },
584 };
585
586 value_or_null_or_overflow.map(|value_or_null| Self {
587 dtype: result_dtype,
588 ptype,
589 pvalue: value_or_null.map(PValue::from),
590 })
591 }
592}
593
#[cfg(test)]
#[allow(clippy::cast_possible_truncation)]
mod tests {
    use num_traits::CheckedSub;
    use rstest::rstest;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::primitive::NumericOperator;
    use crate::{InnerScalarValue, PValue, PrimitiveScalar, ScalarValue};

    /// Non-nullable primitive dtype for `ptype`.
    fn non_nullable(ptype: PType) -> DType {
        DType::Primitive(ptype, Nullability::NonNullable)
    }

    /// Nullable primitive dtype for `ptype`.
    fn nullable(ptype: PType) -> DType {
        DType::Primitive(ptype, Nullability::Nullable)
    }

    /// Primitive scalar view holding `value`.
    fn scalar_of(dtype: &DType, value: PValue) -> PrimitiveScalar<'_> {
        PrimitiveScalar::try_new(dtype, &ScalarValue(InnerScalarValue::Primitive(value))).unwrap()
    }

    /// Null primitive scalar view.
    fn null_of(dtype: &DType) -> PrimitiveScalar<'_> {
        PrimitiveScalar::try_new(dtype, &ScalarValue(InnerScalarValue::Null)).unwrap()
    }

    #[test]
    fn test_integer_subtract() {
        let dtype = non_nullable(PType::I32);
        let minuend = scalar_of(&dtype, PValue::I32(5));
        let subtrahend = scalar_of(&dtype, PValue::I32(4));

        let difference = minuend.checked_sub(&subtrahend).unwrap();
        assert_eq!(difference.as_::<i32>().unwrap().unwrap(), 1);

        assert_eq!((minuend - subtrahend).as_::<i32>().unwrap().unwrap(), 1);
    }

    #[test]
    #[should_panic(expected = "PrimitiveScalar subtract: overflow or underflow")]
    #[allow(clippy::assertions_on_constants)]
    fn test_integer_subtract_overflow() {
        let dtype = non_nullable(PType::I32);
        let smallest = scalar_of(&dtype, PValue::I32(i32::MIN));
        let largest = scalar_of(&dtype, PValue::I32(i32::MAX));

        let _ = smallest - largest;
    }

    #[test]
    fn test_float_subtract() {
        let dtype = non_nullable(PType::F32);
        let minuend = scalar_of(&dtype, PValue::F32(1.99f32));
        let subtrahend = scalar_of(&dtype, PValue::F32(1.0f32));

        let difference = minuend.checked_sub(&subtrahend).unwrap();
        assert_eq!(difference.as_::<f32>().unwrap().unwrap(), 0.99f32);

        assert_eq!(
            (minuend - subtrahend).as_::<f32>().unwrap().unwrap(),
            0.99f32
        );
    }

    #[test]
    fn test_primitive_scalar_equality() {
        let dtype = non_nullable(PType::I32);
        let first = scalar_of(&dtype, PValue::I32(42));
        let same = scalar_of(&dtype, PValue::I32(42));
        let different = scalar_of(&dtype, PValue::I32(43));

        assert_eq!(first, same);
        assert_ne!(first, different);
    }

    #[test]
    fn test_primitive_scalar_partial_ord() {
        let dtype = non_nullable(PType::I32);
        let smaller = scalar_of(&dtype, PValue::I32(10));
        let larger = scalar_of(&dtype, PValue::I32(20));

        assert!(smaller < larger);
        assert!(larger > smaller);
        assert_eq!(
            smaller.partial_cmp(&smaller),
            Some(std::cmp::Ordering::Equal)
        );
    }

    #[test]
    fn test_primitive_scalar_null_handling() {
        let dtype = nullable(PType::I32);
        let null_scalar = null_of(&dtype);

        assert_eq!(null_scalar.pvalue(), None);
        assert_eq!(null_scalar.typed_value::<i32>(), None);
    }

    #[test]
    fn test_typed_value_correct_type() {
        let dtype = non_nullable(PType::F64);
        let scalar = scalar_of(&dtype, PValue::F64(3.5));

        assert_eq!(scalar.typed_value::<f64>(), Some(3.5));
    }

    #[test]
    #[should_panic(expected = "Attempting to read")]
    fn test_typed_value_wrong_type() {
        let dtype = non_nullable(PType::F64);
        let scalar = scalar_of(&dtype, PValue::F64(3.5));

        let _ = scalar.typed_value::<i32>();
    }

    #[rstest]
    #[case(PType::I8, 127i32, PType::I16, true)]
    #[case(PType::I8, 127i32, PType::I32, true)]
    #[case(PType::I8, 127i32, PType::I64, true)]
    #[case(PType::U8, 255i32, PType::U16, true)]
    #[case(PType::U8, 255i32, PType::U32, true)]
    #[case(PType::I32, 42i32, PType::F32, true)]
    #[case(PType::I32, 42i32, PType::F64, true)]
    #[case(PType::I32, 300i32, PType::U8, false)]
    #[case(PType::I32, -1i32, PType::U32, false)]
    #[case(PType::I32, 256i32, PType::I8, false)]
    #[case(PType::U16, 65535i32, PType::I8, false)]
    fn test_primitive_cast(
        #[case] source_type: PType,
        #[case] source_value: i32,
        #[case] target_type: PType,
        #[case] should_succeed: bool,
    ) {
        // Each case's value fits its source type, so the `as` casts are lossless.
        let source_pvalue = match source_type {
            PType::I8 => PValue::I8(source_value as i8),
            PType::U8 => PValue::U8(source_value as u8),
            PType::U16 => PValue::U16(source_value as u16),
            PType::I32 => PValue::I32(source_value),
            _ => unreachable!("Test case uses unexpected source type"),
        };

        let dtype = non_nullable(source_type);
        let scalar = scalar_of(&dtype, source_pvalue);

        let target_dtype = non_nullable(target_type);
        let result = scalar.cast(&target_dtype);

        if should_succeed {
            assert!(
                result.is_ok(),
                "Cast from {:?} to {:?} should succeed",
                source_type,
                target_type
            );
        } else {
            assert!(
                result.is_err(),
                "Cast from {:?} to {:?} should fail due to overflow",
                source_type,
                target_type
            );
        }
    }

    #[test]
    fn test_as_conversion_success() {
        let dtype = non_nullable(PType::I32);
        let scalar = scalar_of(&dtype, PValue::I32(42));

        assert_eq!(scalar.as_::<i64>().unwrap(), Some(42i64));
        assert_eq!(scalar.as_::<f64>().unwrap(), Some(42.0));
    }

    #[test]
    fn test_as_conversion_overflow() {
        let dtype = non_nullable(PType::I32);
        let scalar = scalar_of(&dtype, PValue::I32(-1));

        // A negative value has no unsigned representation.
        assert!(scalar.as_::<u32>().is_err());
    }

    #[test]
    fn test_as_conversion_null() {
        let dtype = nullable(PType::I32);
        let scalar = null_of(&dtype);

        assert_eq!(scalar.as_::<i32>().unwrap(), None);
        assert_eq!(scalar.as_::<f64>().unwrap(), None);
    }

    #[test]
    fn test_numeric_operator_swap() {
        let expectations = [
            (NumericOperator::Add, NumericOperator::Add),
            (NumericOperator::Sub, NumericOperator::RSub),
            (NumericOperator::RSub, NumericOperator::Sub),
            (NumericOperator::Mul, NumericOperator::Mul),
            (NumericOperator::Div, NumericOperator::RDiv),
            (NumericOperator::RDiv, NumericOperator::Div),
        ];

        for (op, swapped) in expectations {
            assert_eq!(op.swap(), swapped);
        }
    }

    #[test]
    fn test_checked_binary_numeric_add() {
        let dtype = non_nullable(PType::I32);
        let lhs = scalar_of(&dtype, PValue::I32(10));
        let rhs = scalar_of(&dtype, PValue::I32(20));

        let result = lhs
            .checked_binary_numeric(&rhs, NumericOperator::Add)
            .unwrap();
        assert_eq!(result.typed_value::<i32>(), Some(30));
    }

    #[test]
    fn test_checked_binary_numeric_overflow() {
        let dtype = non_nullable(PType::I32);
        let lhs = scalar_of(&dtype, PValue::I32(i32::MAX));
        let rhs = scalar_of(&dtype, PValue::I32(1));

        // Overflow is reported as `None`, not as a panic.
        let result = lhs.checked_binary_numeric(&rhs, NumericOperator::Add);
        assert!(result.is_none());
    }

    #[test]
    fn test_checked_binary_numeric_with_null() {
        let dtype = nullable(PType::I32);
        let lhs = scalar_of(&dtype, PValue::I32(10));
        let null_scalar = null_of(&dtype);

        // A null operand produces a null result.
        let result = lhs
            .checked_binary_numeric(&null_scalar, NumericOperator::Add)
            .unwrap();
        assert_eq!(result.pvalue(), None);
    }

    #[test]
    fn test_checked_binary_numeric_mul() {
        let dtype = non_nullable(PType::I32);
        let lhs = scalar_of(&dtype, PValue::I32(5));
        let rhs = scalar_of(&dtype, PValue::I32(6));

        let result = lhs
            .checked_binary_numeric(&rhs, NumericOperator::Mul)
            .unwrap();
        assert_eq!(result.typed_value::<i32>(), Some(30));
    }

    #[test]
    fn test_checked_binary_numeric_div() {
        let dtype = non_nullable(PType::I32);
        let lhs = scalar_of(&dtype, PValue::I32(20));
        let rhs = scalar_of(&dtype, PValue::I32(4));

        let result = lhs
            .checked_binary_numeric(&rhs, NumericOperator::Div)
            .unwrap();
        assert_eq!(result.typed_value::<i32>(), Some(5));
    }

    #[test]
    fn test_checked_binary_numeric_rdiv() {
        let dtype = non_nullable(PType::I32);
        let lhs = scalar_of(&dtype, PValue::I32(4));
        let rhs = scalar_of(&dtype, PValue::I32(20));

        // Reverse division computes rhs / lhs, i.e. 20 / 4.
        let result = lhs
            .checked_binary_numeric(&rhs, NumericOperator::RDiv)
            .unwrap();
        assert_eq!(result.typed_value::<i32>(), Some(5));
    }

    #[test]
    fn test_checked_binary_numeric_div_by_zero() {
        let dtype = non_nullable(PType::I32);
        let lhs = scalar_of(&dtype, PValue::I32(10));
        let zero = scalar_of(&dtype, PValue::I32(0));

        // Integer division by zero is reported as `None`.
        let result = lhs.checked_binary_numeric(&zero, NumericOperator::Div);
        assert!(result.is_none());
    }

    #[test]
    fn test_checked_binary_numeric_float_ops() {
        let dtype = non_nullable(PType::F32);
        let lhs = scalar_of(&dtype, PValue::F32(10.0));
        let rhs = scalar_of(&dtype, PValue::F32(2.5));

        let expectations = [
            (NumericOperator::Add, 12.5f32),
            (NumericOperator::Sub, 7.5f32),
            (NumericOperator::Mul, 25.0f32),
            (NumericOperator::Div, 4.0f32),
        ];

        for (op, expected) in expectations {
            let result = lhs.checked_binary_numeric(&rhs, op).unwrap();
            assert_eq!(result.typed_value::<f32>(), Some(expected));
        }
    }

    #[test]
    fn test_from_primitive_or_f16() {
        use vortex_dtype::half::f16;

        use crate::primitive::FromPrimitiveOrF16;

        let f16_val = f16::from_f32(3.5);

        // Floating-point targets accept an f16.
        assert!(f32::from_f16(f16_val).is_some());
        assert!(f64::from_f16(f16_val).is_some());

        // Integer targets do not.
        assert!(i32::from_f16(f16_val).is_none());
        assert!(u32::from_f16(f16_val).is_none());
    }

    #[test]
    fn test_partial_ord_different_types() {
        let int_dtype = non_nullable(PType::I32);
        let float_dtype = non_nullable(PType::F32);

        let int_scalar = scalar_of(&int_dtype, PValue::I32(10));
        let float_scalar = scalar_of(&float_dtype, PValue::F32(10.0));

        // Scalars of different primitive types are incomparable.
        assert_eq!(int_scalar.partial_cmp(&float_scalar), None);
    }

    #[test]
    fn test_scalar_value_from_usize() {
        let value: ScalarValue = 42usize.into();
        assert!(matches!(
            value.0,
            InnerScalarValue::Primitive(PValue::U64(42))
        ));
    }

    #[test]
    fn test_getters() {
        let dtype = non_nullable(PType::I32);
        let scalar = scalar_of(&dtype, PValue::I32(42));

        assert_eq!(scalar.dtype(), &dtype);
        assert_eq!(scalar.ptype(), PType::I32);
        assert_eq!(scalar.pvalue(), Some(PValue::I32(42)));
    }
}