1use std::any::type_name;
5use std::cmp::Ordering;
6use std::fmt::{Debug, Display, Formatter};
7use std::ops::{Add, Sub};
8
9use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive};
10use vortex_dtype::half::f16;
11use vortex_dtype::{DType, NativePType, Nullability, PType, match_each_native_ptype};
12use vortex_error::{
13 VortexError, VortexExpect as _, VortexResult, VortexUnwrap, vortex_err, vortex_panic,
14};
15
16use crate::pvalue::PValue;
17use crate::{InnerScalarValue, Scalar, ScalarValue};
18
19#[derive(Debug, Clone, Copy, Hash)]
24pub struct PrimitiveScalar<'a> {
25 dtype: &'a DType,
26 ptype: PType,
27 pvalue: Option<PValue>,
28}
29
30impl Display for PrimitiveScalar<'_> {
31 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
32 match self.pvalue {
33 None => write!(f, "null"),
34 Some(pv) => write!(f, "{pv}"),
35 }
36 }
37}
38
39impl PartialEq for PrimitiveScalar<'_> {
40 fn eq(&self, other: &Self) -> bool {
41 self.dtype.eq_ignore_nullability(other.dtype) && self.pvalue == other.pvalue
42 }
43}
44
45impl Eq for PrimitiveScalar<'_> {}
46
47impl PartialOrd for PrimitiveScalar<'_> {
49 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
50 if !self.dtype.eq_ignore_nullability(other.dtype) {
51 return None;
52 }
53 self.pvalue.partial_cmp(&other.pvalue)
54 }
55}
56
57impl<'a> PrimitiveScalar<'a> {
58 pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
65 let ptype = PType::try_from(dtype)?;
66
67 let pvalue = match_each_native_ptype!(ptype, |T| {
70 if let Some(pvalue) = value.as_pvalue()? {
71 Some(PValue::from(<T>::try_from(pvalue)?))
72 } else {
73 None
74 }
75 });
76
77 Ok(Self {
78 dtype,
79 ptype,
80 pvalue,
81 })
82 }
83
84 #[inline]
86 pub fn dtype(&self) -> &'a DType {
87 self.dtype
88 }
89
90 #[inline]
92 pub fn ptype(&self) -> PType {
93 self.ptype
94 }
95
96 #[inline]
98 pub fn pvalue(&self) -> Option<PValue> {
99 self.pvalue
100 }
101
102 pub fn typed_value<T: NativePType + TryFrom<PValue, Error = VortexError>>(&self) -> Option<T> {
108 assert_eq!(
109 self.ptype,
110 T::PTYPE,
111 "Attempting to read {} scalar as {}",
112 self.ptype,
113 T::PTYPE
114 );
115
116 self.pvalue.map(|pv| pv.as_primitive::<T>().vortex_unwrap())
117 }
118
119 pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
120 let ptype = PType::try_from(dtype)?;
121 let pvalue = self
122 .pvalue
123 .vortex_expect("nullness handled in Scalar::cast");
124 Ok(match_each_native_ptype!(ptype, |Q| {
125 Scalar::primitive(
126 pvalue.as_primitive::<Q>().map_err(|err| {
127 vortex_err!(
128 "Can't cast {} scalar {} to {} (cause: {})",
129 self.ptype,
130 pvalue,
131 dtype,
132 err
133 )
134 })?,
135 dtype.nullability(),
136 )
137 }))
138 }
139
140 pub fn as_<T: FromPrimitiveOrF16>(&self) -> VortexResult<Option<T>> {
146 match self.pvalue {
147 None => Ok(None),
148 Some(pv) => Ok(Some(match pv {
149 PValue::U8(v) => T::from_u8(v)
150 .ok_or_else(|| vortex_err!("Failed to cast u8 to {}", type_name::<T>())),
151 PValue::U16(v) => T::from_u16(v)
152 .ok_or_else(|| vortex_err!("Failed to cast u16 to {}", type_name::<T>())),
153 PValue::U32(v) => T::from_u32(v)
154 .ok_or_else(|| vortex_err!("Failed to cast u32 to {}", type_name::<T>())),
155 PValue::U64(v) => T::from_u64(v)
156 .ok_or_else(|| vortex_err!("Failed to cast u64 to {}", type_name::<T>())),
157 PValue::I8(v) => T::from_i8(v)
158 .ok_or_else(|| vortex_err!("Failed to cast i8 to {}", type_name::<T>())),
159 PValue::I16(v) => T::from_i16(v)
160 .ok_or_else(|| vortex_err!("Failed to cast i16 to {}", type_name::<T>())),
161 PValue::I32(v) => T::from_i32(v)
162 .ok_or_else(|| vortex_err!("Failed to cast i32 to {}", type_name::<T>())),
163 PValue::I64(v) => T::from_i64(v)
164 .ok_or_else(|| vortex_err!("Failed to cast i64 to {}", type_name::<T>())),
165 PValue::F16(v) => T::from_f16(v)
166 .ok_or_else(|| vortex_err!("Failed to cast f16 to {}", type_name::<T>())),
167 PValue::F32(v) => T::from_f32(v)
168 .ok_or_else(|| vortex_err!("Failed to cast f32 to {}", type_name::<T>())),
169 PValue::F64(v) => T::from_f64(v)
170 .ok_or_else(|| vortex_err!("Failed to cast f64 to {}", type_name::<T>())),
171 }?)),
172 }
173 }
174}
175
176pub trait FromPrimitiveOrF16: FromPrimitive {
180 fn from_f16(v: f16) -> Option<Self>;
182}
183
184macro_rules! from_primitive_or_f16_for_non_floating_point {
185 ($T:ty) => {
186 impl FromPrimitiveOrF16 for $T {
187 fn from_f16(_: f16) -> Option<Self> {
188 None
189 }
190 }
191 };
192}
193
194from_primitive_or_f16_for_non_floating_point!(usize);
195from_primitive_or_f16_for_non_floating_point!(u8);
196from_primitive_or_f16_for_non_floating_point!(u16);
197from_primitive_or_f16_for_non_floating_point!(u32);
198from_primitive_or_f16_for_non_floating_point!(u64);
199from_primitive_or_f16_for_non_floating_point!(i8);
200from_primitive_or_f16_for_non_floating_point!(i16);
201from_primitive_or_f16_for_non_floating_point!(i32);
202from_primitive_or_f16_for_non_floating_point!(i64);
203
204impl FromPrimitiveOrF16 for f16 {
205 fn from_f16(v: f16) -> Option<Self> {
206 Some(v)
207 }
208}
209
210impl FromPrimitiveOrF16 for f32 {
211 fn from_f16(v: f16) -> Option<Self> {
212 Some(v.to_f32())
213 }
214}
215
216impl FromPrimitiveOrF16 for f64 {
217 fn from_f16(v: f16) -> Option<Self> {
218 Some(v.to_f64())
219 }
220}
221
222impl<'a> TryFrom<&'a Scalar> for PrimitiveScalar<'a> {
223 type Error = VortexError;
224
225 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
226 Self::try_new(value.dtype(), value.value())
227 }
228}
229
230impl Sub for PrimitiveScalar<'_> {
231 type Output = Self;
232
233 fn sub(self, rhs: Self) -> Self::Output {
234 self.checked_sub(&rhs)
235 .vortex_expect("PrimitiveScalar subtract: overflow or underflow")
236 }
237}
238
239impl CheckedSub for PrimitiveScalar<'_> {
240 fn checked_sub(&self, rhs: &Self) -> Option<Self> {
241 self.checked_binary_numeric(rhs, NumericOperator::Sub)
242 }
243}
244
245impl Add for PrimitiveScalar<'_> {
246 type Output = Self;
247
248 fn add(self, rhs: Self) -> Self::Output {
249 self.checked_add(&rhs)
250 .vortex_expect("PrimitiveScalar add: overflow or underflow")
251 }
252}
253
254impl CheckedAdd for PrimitiveScalar<'_> {
255 fn checked_add(&self, rhs: &Self) -> Option<Self> {
256 self.checked_binary_numeric(rhs, NumericOperator::Add)
257 }
258}
259
260impl Scalar {
261 pub fn primitive<T: NativePType + Into<PValue>>(value: T, nullability: Nullability) -> Self {
263 Self::primitive_value(value.into(), T::PTYPE, nullability)
264 }
265
266 pub fn primitive_value(value: PValue, ptype: PType, nullability: Nullability) -> Self {
271 Self {
272 dtype: DType::Primitive(ptype, nullability),
273 value: ScalarValue(InnerScalarValue::Primitive(value)),
274 }
275 }
276
277 pub fn reinterpret_cast(&self, ptype: PType) -> Self {
283 let primitive = PrimitiveScalar::try_from(self).unwrap_or_else(|e| {
284 vortex_panic!(e, "Failed to reinterpret cast {} to {}", self.dtype, ptype)
285 });
286 if primitive.ptype() == ptype {
287 return self.clone();
288 }
289
290 assert_eq!(
291 primitive.ptype().byte_width(),
292 ptype.byte_width(),
293 "can't reinterpret cast between integers of two different widths"
294 );
295
296 Scalar::new(
297 DType::Primitive(ptype, self.dtype.nullability()),
298 primitive
299 .pvalue
300 .map(|p| p.reinterpret_cast(ptype))
301 .map(|x| ScalarValue(InnerScalarValue::Primitive(x)))
302 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)),
303 )
304 }
305}
306
307macro_rules! primitive_scalar {
308 ($T:ty) => {
309 impl TryFrom<&Scalar> for $T {
310 type Error = VortexError;
311
312 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
313 <Option<$T>>::try_from(value)?
314 .ok_or_else(|| vortex_err!("Can't extract present value from null scalar"))
315 }
316 }
317
318 impl TryFrom<Scalar> for $T {
319 type Error = VortexError;
320
321 fn try_from(value: Scalar) -> Result<Self, Self::Error> {
322 <$T>::try_from(&value)
323 }
324 }
325
326 impl TryFrom<&Scalar> for Option<$T> {
327 type Error = VortexError;
328
329 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
330 Ok(PrimitiveScalar::try_from(value)?.typed_value::<$T>())
331 }
332 }
333
334 impl TryFrom<Scalar> for Option<$T> {
335 type Error = VortexError;
336
337 fn try_from(value: Scalar) -> Result<Self, Self::Error> {
338 <Option<$T>>::try_from(&value)
339 }
340 }
341
342 impl From<$T> for Scalar {
343 fn from(value: $T) -> Self {
344 Scalar {
345 dtype: DType::Primitive(<$T>::PTYPE, Nullability::NonNullable),
346 value: ScalarValue(InnerScalarValue::Primitive(value.into())),
347 }
348 }
349 }
350 };
351}
352
353primitive_scalar!(u8);
354primitive_scalar!(u16);
355primitive_scalar!(u32);
356primitive_scalar!(u64);
357primitive_scalar!(i8);
358primitive_scalar!(i16);
359primitive_scalar!(i32);
360primitive_scalar!(i64);
361primitive_scalar!(f16);
362primitive_scalar!(f32);
363primitive_scalar!(f64);
364
365impl TryFrom<&Scalar> for usize {
367 type Error = VortexError;
368
369 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
370 let prim = PrimitiveScalar::try_from(value)?
371 .as_::<u64>()?
372 .ok_or_else(|| vortex_err!("cannot convert Null to usize"))?;
373 Ok(usize::try_from(prim)?)
374 }
375}
376
377impl TryFrom<&Scalar> for Option<usize> {
378 type Error = VortexError;
379
380 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
381 Ok(PrimitiveScalar::try_from(value)?
382 .as_::<u64>()?
383 .map(usize::try_from)
384 .transpose()?)
385 }
386}
387
388impl From<usize> for Scalar {
390 fn from(value: usize) -> Self {
391 Scalar::primitive(value as u64, Nullability::NonNullable)
392 }
393}
394
/// A binary arithmetic operator over numeric scalars.
///
/// The `R`-prefixed variants apply the operator with the operands reversed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumericOperator {
    /// `lhs + rhs`
    Add,
    /// `lhs - rhs`
    Sub,
    /// `rhs - lhs`
    RSub,
    /// `lhs * rhs`
    Mul,
    /// `lhs / rhs`
    Div,
    /// `rhs / lhs`
    RDiv,
}

/// Displays the variant name, identical to its `Debug` form.
impl Display for NumericOperator {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        <Self as Debug>::fmt(self, f)
    }
}

impl NumericOperator {
    /// Returns the operator that gives the same result when the operands are exchanged.
    ///
    /// `a op b` equals `b op.swap() a`.
    pub fn swap(self) -> Self {
        match self {
            Self::Sub => Self::RSub,
            Self::RSub => Self::Sub,
            Self::Div => Self::RDiv,
            Self::RDiv => Self::Div,
            // Commutative operators are their own swap.
            Self::Add => Self::Add,
            Self::Mul => Self::Mul,
        }
    }
}
437
438impl<'a> PrimitiveScalar<'a> {
439 pub fn checked_binary_numeric(
447 &self,
448 other: &PrimitiveScalar<'a>,
449 op: NumericOperator,
450 ) -> Option<PrimitiveScalar<'a>> {
451 if !self.dtype().eq_ignore_nullability(other.dtype()) {
452 vortex_panic!("types must match: {} {}", self.dtype(), other.dtype());
453 }
454 let result_dtype = if self.dtype().is_nullable() {
455 self.dtype()
456 } else {
457 other.dtype()
458 };
459 let ptype = self.ptype();
460
461 match_each_native_ptype!(
462 self.ptype(),
463 integral: |P| {
464 self.checked_integeral_numeric_operator::<P>(other, result_dtype, ptype, op)
465 },
466 floating: |P| {
467 let lhs = self.typed_value::<P>();
468 let rhs = other.typed_value::<P>();
469 let value_or_null = match (lhs, rhs) {
470 (_, None) | (None, _) => None,
471 (Some(lhs), Some(rhs)) => match op {
472 NumericOperator::Add => Some(lhs + rhs),
473 NumericOperator::Sub => Some(lhs - rhs),
474 NumericOperator::RSub => Some(rhs - lhs),
475 NumericOperator::Mul => Some(lhs * rhs),
476 NumericOperator::Div => Some(lhs / rhs),
477 NumericOperator::RDiv => Some(rhs / lhs),
478 }
479 };
480 Some(Self { dtype: result_dtype, ptype, pvalue: value_or_null.map(PValue::from) })
481 }
482 )
483 }
484
485 fn checked_integeral_numeric_operator<
486 P: NativePType
487 + TryFrom<PValue, Error = VortexError>
488 + CheckedSub
489 + CheckedAdd
490 + CheckedMul
491 + CheckedDiv,
492 >(
493 &self,
494 other: &PrimitiveScalar<'a>,
495 result_dtype: &'a DType,
496 ptype: PType,
497 op: NumericOperator,
498 ) -> Option<PrimitiveScalar<'a>>
499 where
500 PValue: From<P>,
501 {
502 let lhs = self.typed_value::<P>();
503 let rhs = other.typed_value::<P>();
504 let value_or_null_or_overflow = match (lhs, rhs) {
505 (_, None) | (None, _) => Some(None),
506 (Some(lhs), Some(rhs)) => match op {
507 NumericOperator::Add => lhs.checked_add(&rhs).map(Some),
508 NumericOperator::Sub => lhs.checked_sub(&rhs).map(Some),
509 NumericOperator::RSub => rhs.checked_sub(&lhs).map(Some),
510 NumericOperator::Mul => lhs.checked_mul(&rhs).map(Some),
511 NumericOperator::Div => lhs.checked_div(&rhs).map(Some),
512 NumericOperator::RDiv => rhs.checked_div(&lhs).map(Some),
513 },
514 };
515
516 value_or_null_or_overflow.map(|value_or_null| Self {
517 dtype: result_dtype,
518 ptype,
519 pvalue: value_or_null.map(PValue::from),
520 })
521 }
522}
523
#[cfg(test)]
mod tests {
    use num_traits::CheckedSub;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::{InnerScalarValue, PValue, PrimitiveScalar, ScalarValue};

    /// Builds a non-null primitive scalar of `dtype` holding `value`.
    fn scalar(dtype: &DType, value: PValue) -> PrimitiveScalar<'_> {
        PrimitiveScalar::try_new(dtype, &ScalarValue(InnerScalarValue::Primitive(value))).unwrap()
    }

    #[test]
    fn test_integer_subtract() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let five = scalar(&dtype, PValue::I32(5));
        let four = scalar(&dtype, PValue::I32(4));

        let checked = five.checked_sub(&four).unwrap().as_::<i32>();
        assert_eq!(checked.unwrap().unwrap(), 1);

        let unchecked = (five - four).as_::<i32>();
        assert_eq!(unchecked.unwrap().unwrap(), 1);
    }

    #[test]
    #[should_panic(expected = "PrimitiveScalar subtract: overflow or underflow")]
    fn test_integer_subtract_overflow() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let min = scalar(&dtype, PValue::I32(i32::MIN));
        let max = scalar(&dtype, PValue::I32(i32::MAX));
        let _ = min - max;
    }

    #[test]
    fn test_float_subtract() {
        let dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
        let lhs = scalar(&dtype, PValue::F32(1.99f32));
        let rhs = scalar(&dtype, PValue::F32(1.0f32));

        let checked = lhs.checked_sub(&rhs).unwrap().as_::<f32>();
        assert_eq!(checked.unwrap().unwrap(), 0.99f32);

        let unchecked = (lhs - rhs).as_::<f32>();
        assert_eq!(unchecked.unwrap().unwrap(), 0.99f32);
    }
}