1use std::any::type_name;
2use std::cmp::Ordering;
3use std::fmt::{Debug, Display, Formatter};
4use std::ops::{Add, Sub};
5
6use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive};
7use vortex_dtype::half::f16;
8use vortex_dtype::{DType, NativePType, Nullability, PType, match_each_native_ptype};
9use vortex_error::{
10 VortexError, VortexExpect as _, VortexResult, VortexUnwrap, vortex_err, vortex_panic,
11};
12
13use crate::pvalue::PValue;
14use crate::{InnerScalarValue, Scalar, ScalarValue};
15
16#[derive(Debug, Clone, Copy, Hash)]
17pub struct PrimitiveScalar<'a> {
18 dtype: &'a DType,
19 ptype: PType,
20 pvalue: Option<PValue>,
21}
22
23impl Display for PrimitiveScalar<'_> {
24 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
25 match self.pvalue {
26 None => write!(f, "null"),
27 Some(pv) => write!(f, "{pv}"),
28 }
29 }
30}
31
32impl PartialEq for PrimitiveScalar<'_> {
33 fn eq(&self, other: &Self) -> bool {
34 self.dtype.eq_ignore_nullability(other.dtype) && self.pvalue == other.pvalue
35 }
36}
37
38impl Eq for PrimitiveScalar<'_> {}
39
40impl PartialOrd for PrimitiveScalar<'_> {
42 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
43 if !self.dtype.eq_ignore_nullability(other.dtype) {
44 return None;
45 }
46 self.pvalue.partial_cmp(&other.pvalue)
47 }
48}
49
50impl<'a> PrimitiveScalar<'a> {
51 pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
52 let ptype = PType::try_from(dtype)?;
53
54 let pvalue = match_each_native_ptype!(ptype, |T| {
57 if let Some(pvalue) = value.as_pvalue()? {
58 Some(PValue::from(<T>::try_from(pvalue)?))
59 } else {
60 None
61 }
62 });
63
64 Ok(Self {
65 dtype,
66 ptype,
67 pvalue,
68 })
69 }
70
71 #[inline]
72 pub fn dtype(&self) -> &'a DType {
73 self.dtype
74 }
75
76 #[inline]
77 pub fn ptype(&self) -> PType {
78 self.ptype
79 }
80
81 #[inline]
82 pub fn pvalue(&self) -> Option<PValue> {
83 self.pvalue
84 }
85
86 pub fn typed_value<T: NativePType + TryFrom<PValue, Error = VortexError>>(&self) -> Option<T> {
87 assert_eq!(
88 self.ptype,
89 T::PTYPE,
90 "Attempting to read {} scalar as {}",
91 self.ptype,
92 T::PTYPE
93 );
94
95 self.pvalue.map(|pv| pv.as_primitive::<T>().vortex_unwrap())
96 }
97
98 pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
99 let ptype = PType::try_from(dtype)?;
100 let pvalue = self
101 .pvalue
102 .vortex_expect("nullness handled in Scalar::cast");
103 Ok(match_each_native_ptype!(ptype, |Q| {
104 Scalar::primitive(
105 pvalue.as_primitive::<Q>().map_err(|err| {
106 vortex_err!(
107 "Can't cast {} scalar {} to {} (cause: {})",
108 self.ptype,
109 pvalue,
110 dtype,
111 err
112 )
113 })?,
114 dtype.nullability(),
115 )
116 }))
117 }
118
119 pub fn as_<T: FromPrimitiveOrF16>(&self) -> VortexResult<Option<T>> {
122 match self.pvalue {
123 None => Ok(None),
124 Some(pv) => Ok(Some(match pv {
125 PValue::U8(v) => T::from_u8(v)
126 .ok_or_else(|| vortex_err!("Failed to cast u8 to {}", type_name::<T>())),
127 PValue::U16(v) => T::from_u16(v)
128 .ok_or_else(|| vortex_err!("Failed to cast u16 to {}", type_name::<T>())),
129 PValue::U32(v) => T::from_u32(v)
130 .ok_or_else(|| vortex_err!("Failed to cast u32 to {}", type_name::<T>())),
131 PValue::U64(v) => T::from_u64(v)
132 .ok_or_else(|| vortex_err!("Failed to cast u64 to {}", type_name::<T>())),
133 PValue::I8(v) => T::from_i8(v)
134 .ok_or_else(|| vortex_err!("Failed to cast i8 to {}", type_name::<T>())),
135 PValue::I16(v) => T::from_i16(v)
136 .ok_or_else(|| vortex_err!("Failed to cast i16 to {}", type_name::<T>())),
137 PValue::I32(v) => T::from_i32(v)
138 .ok_or_else(|| vortex_err!("Failed to cast i32 to {}", type_name::<T>())),
139 PValue::I64(v) => T::from_i64(v)
140 .ok_or_else(|| vortex_err!("Failed to cast i64 to {}", type_name::<T>())),
141 PValue::F16(v) => T::from_f16(v)
142 .ok_or_else(|| vortex_err!("Failed to cast f16 to {}", type_name::<T>())),
143 PValue::F32(v) => T::from_f32(v)
144 .ok_or_else(|| vortex_err!("Failed to cast f32 to {}", type_name::<T>())),
145 PValue::F64(v) => T::from_f64(v)
146 .ok_or_else(|| vortex_err!("Failed to cast f64 to {}", type_name::<T>())),
147 }?)),
148 }
149 }
150}
151
/// Extension of [`FromPrimitive`] that can also convert from an `f16`.
///
/// `FromPrimitive` has no `from_f16` method, so this trait adds one. It is the bound used by
/// [`PrimitiveScalar::as_`] so that `f16` scalars can be converted as well.
pub trait FromPrimitiveOrF16: FromPrimitive {
    /// Converts an `f16` into `Self`, returning `None` when the implementing type does not
    /// support the conversion for this value.
    fn from_f16(v: f16) -> Option<Self>;
}
155
156macro_rules! from_primitive_or_f16_for_non_floating_point {
157 ($T:ty) => {
158 impl FromPrimitiveOrF16 for $T {
159 fn from_f16(_: f16) -> Option<Self> {
160 None
161 }
162 }
163 };
164}
165
166from_primitive_or_f16_for_non_floating_point!(usize);
167from_primitive_or_f16_for_non_floating_point!(u8);
168from_primitive_or_f16_for_non_floating_point!(u16);
169from_primitive_or_f16_for_non_floating_point!(u32);
170from_primitive_or_f16_for_non_floating_point!(u64);
171from_primitive_or_f16_for_non_floating_point!(i8);
172from_primitive_or_f16_for_non_floating_point!(i16);
173from_primitive_or_f16_for_non_floating_point!(i32);
174from_primitive_or_f16_for_non_floating_point!(i64);
175
176impl FromPrimitiveOrF16 for f16 {
177 fn from_f16(v: f16) -> Option<Self> {
178 Some(v)
179 }
180}
181
182impl FromPrimitiveOrF16 for f32 {
183 fn from_f16(v: f16) -> Option<Self> {
184 Some(v.to_f32())
185 }
186}
187
188impl FromPrimitiveOrF16 for f64 {
189 fn from_f16(v: f16) -> Option<Self> {
190 Some(v.to_f64())
191 }
192}
193
194impl<'a> TryFrom<&'a Scalar> for PrimitiveScalar<'a> {
195 type Error = VortexError;
196
197 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
198 Self::try_new(value.dtype(), value.value())
199 }
200}
201
202impl Sub for PrimitiveScalar<'_> {
203 type Output = Self;
204
205 fn sub(self, rhs: Self) -> Self::Output {
206 self.checked_sub(&rhs)
207 .vortex_expect("PrimitiveScalar subtract: overflow or underflow")
208 }
209}
210
211impl CheckedSub for PrimitiveScalar<'_> {
212 fn checked_sub(&self, rhs: &Self) -> Option<Self> {
213 self.checked_binary_numeric(rhs, NumericOperator::Sub)
214 }
215}
216
217impl Add for PrimitiveScalar<'_> {
218 type Output = Self;
219
220 fn add(self, rhs: Self) -> Self::Output {
221 self.checked_add(&rhs)
222 .vortex_expect("PrimitiveScalar add: overflow or underflow")
223 }
224}
225
226impl CheckedAdd for PrimitiveScalar<'_> {
227 fn checked_add(&self, rhs: &Self) -> Option<Self> {
228 self.checked_binary_numeric(rhs, NumericOperator::Add)
229 }
230}
231
232impl Scalar {
233 pub fn primitive<T: NativePType + Into<PValue>>(value: T, nullability: Nullability) -> Self {
234 Self::primitive_value(value.into(), T::PTYPE, nullability)
235 }
236
237 pub fn primitive_value(value: PValue, ptype: PType, nullability: Nullability) -> Self {
242 Self {
243 dtype: DType::Primitive(ptype, nullability),
244 value: ScalarValue(InnerScalarValue::Primitive(value)),
245 }
246 }
247
248 pub fn reinterpret_cast(&self, ptype: PType) -> Self {
249 let primitive = PrimitiveScalar::try_from(self).unwrap_or_else(|e| {
250 vortex_panic!(e, "Failed to reinterpret cast {} to {}", self.dtype, ptype)
251 });
252 if primitive.ptype() == ptype {
253 return self.clone();
254 }
255
256 assert_eq!(
257 primitive.ptype().byte_width(),
258 ptype.byte_width(),
259 "can't reinterpret cast between integers of two different widths"
260 );
261
262 Scalar::new(
263 DType::Primitive(ptype, self.dtype.nullability()),
264 primitive
265 .pvalue
266 .map(|p| p.reinterpret_cast(ptype))
267 .map(|x| ScalarValue(InnerScalarValue::Primitive(x)))
268 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)),
269 )
270 }
271}
272
273macro_rules! primitive_scalar {
274 ($T:ty) => {
275 impl TryFrom<&Scalar> for $T {
276 type Error = VortexError;
277
278 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
279 <Option<$T>>::try_from(value)?
280 .ok_or_else(|| vortex_err!("Can't extract present value from null scalar"))
281 }
282 }
283
284 impl TryFrom<Scalar> for $T {
285 type Error = VortexError;
286
287 fn try_from(value: Scalar) -> Result<Self, Self::Error> {
288 <$T>::try_from(&value)
289 }
290 }
291
292 impl TryFrom<&Scalar> for Option<$T> {
293 type Error = VortexError;
294
295 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
296 Ok(PrimitiveScalar::try_from(value)?.typed_value::<$T>())
297 }
298 }
299
300 impl TryFrom<Scalar> for Option<$T> {
301 type Error = VortexError;
302
303 fn try_from(value: Scalar) -> Result<Self, Self::Error> {
304 <Option<$T>>::try_from(&value)
305 }
306 }
307
308 impl From<$T> for Scalar {
309 fn from(value: $T) -> Self {
310 Scalar {
311 dtype: DType::Primitive(<$T>::PTYPE, Nullability::NonNullable),
312 value: ScalarValue(InnerScalarValue::Primitive(value.into())),
313 }
314 }
315 }
316 };
317}
318
319primitive_scalar!(u8);
320primitive_scalar!(u16);
321primitive_scalar!(u32);
322primitive_scalar!(u64);
323primitive_scalar!(i8);
324primitive_scalar!(i16);
325primitive_scalar!(i32);
326primitive_scalar!(i64);
327primitive_scalar!(f16);
328primitive_scalar!(f32);
329primitive_scalar!(f64);
330
331impl TryFrom<&Scalar> for usize {
333 type Error = VortexError;
334
335 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
336 let prim = PrimitiveScalar::try_from(value)?
337 .as_::<u64>()?
338 .ok_or_else(|| vortex_err!("cannot convert Null to usize"))?;
339 Ok(usize::try_from(prim)?)
340 }
341}
342
343impl TryFrom<&Scalar> for Option<usize> {
344 type Error = VortexError;
345
346 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
347 Ok(PrimitiveScalar::try_from(value)?
348 .as_::<u64>()?
349 .map(usize::try_from)
350 .transpose()?)
351 }
352}
353
354impl From<usize> for Scalar {
356 fn from(value: usize) -> Self {
357 Scalar::primitive(value as u64, Nullability::NonNullable)
358 }
359}
360
/// Binary arithmetic operators understood by `PrimitiveScalar::checked_binary_numeric`.
///
/// The `R`-prefixed variants apply the operation with the operands reversed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumericOperator {
    /// `lhs + rhs`
    Add,
    /// `lhs - rhs`
    Sub,
    /// `rhs - lhs`
    RSub,
    /// `lhs * rhs`
    Mul,
    /// `lhs / rhs`
    Div,
    /// `rhs / lhs`
    RDiv,
}

/// Displays the operator as its variant name, delegating to `Debug` (formatter flags included).
impl Display for NumericOperator {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        <Self as Debug>::fmt(self, f)
    }
}

impl NumericOperator {
    /// Returns the operator that produces the same result when the operands are exchanged.
    ///
    /// Commutative operators map to themselves. `Sub`/`RSub` and `Div`/`RDiv` map to each
    /// other, so `swap` is its own inverse.
    pub fn swap(self) -> Self {
        match self {
            Self::Add | Self::Mul => self,
            Self::Sub => Self::RSub,
            Self::RSub => Self::Sub,
            Self::Div => Self::RDiv,
            Self::RDiv => Self::Div,
        }
    }
}
400
401impl<'a> PrimitiveScalar<'a> {
402 pub fn checked_binary_numeric(
410 &self,
411 other: &PrimitiveScalar<'a>,
412 op: NumericOperator,
413 ) -> Option<PrimitiveScalar<'a>> {
414 if !self.dtype().eq_ignore_nullability(other.dtype()) {
415 vortex_panic!("types must match: {} {}", self.dtype(), other.dtype());
416 }
417 let result_dtype = if self.dtype().is_nullable() {
418 self.dtype()
419 } else {
420 other.dtype()
421 };
422 let ptype = self.ptype();
423
424 match_each_native_ptype!(
425 self.ptype(),
426 integral: |P| {
427 self.checked_integeral_numeric_operator::<P>(other, result_dtype, ptype, op)
428 },
429 floating_point: |P| {
430 let lhs = self.typed_value::<P>();
431 let rhs = other.typed_value::<P>();
432 let value_or_null = match (lhs, rhs) {
433 (_, None) | (None, _) => None,
434 (Some(lhs), Some(rhs)) => match op {
435 NumericOperator::Add => Some(lhs + rhs),
436 NumericOperator::Sub => Some(lhs - rhs),
437 NumericOperator::RSub => Some(rhs - lhs),
438 NumericOperator::Mul => Some(lhs * rhs),
439 NumericOperator::Div => Some(lhs / rhs),
440 NumericOperator::RDiv => Some(rhs / lhs),
441 }
442 };
443 Some(Self { dtype: result_dtype, ptype, pvalue: value_or_null.map(PValue::from) })
444 }
445 )
446 }
447
448 fn checked_integeral_numeric_operator<
449 P: NativePType
450 + TryFrom<PValue, Error = VortexError>
451 + CheckedSub
452 + CheckedAdd
453 + CheckedMul
454 + CheckedDiv,
455 >(
456 &self,
457 other: &PrimitiveScalar<'a>,
458 result_dtype: &'a DType,
459 ptype: PType,
460 op: NumericOperator,
461 ) -> Option<PrimitiveScalar<'a>>
462 where
463 PValue: From<P>,
464 {
465 let lhs = self.typed_value::<P>();
466 let rhs = other.typed_value::<P>();
467 let value_or_null_or_overflow = match (lhs, rhs) {
468 (_, None) | (None, _) => Some(None),
469 (Some(lhs), Some(rhs)) => match op {
470 NumericOperator::Add => lhs.checked_add(&rhs).map(Some),
471 NumericOperator::Sub => lhs.checked_sub(&rhs).map(Some),
472 NumericOperator::RSub => rhs.checked_sub(&lhs).map(Some),
473 NumericOperator::Mul => lhs.checked_mul(&rhs).map(Some),
474 NumericOperator::Div => lhs.checked_div(&rhs).map(Some),
475 NumericOperator::RDiv => rhs.checked_div(&lhs).map(Some),
476 },
477 };
478
479 value_or_null_or_overflow.map(|value_or_null| Self {
480 dtype: result_dtype,
481 ptype,
482 pvalue: value_or_null.map(PValue::from),
483 })
484 }
485}
486
#[cfg(test)]
mod tests {
    use num_traits::CheckedSub;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::{InnerScalarValue, PValue, PrimitiveScalar, ScalarValue};

    /// Wraps `pvalue` as a [`PrimitiveScalar`] of `dtype`, panicking if construction fails.
    fn make_scalar(dtype: &DType, pvalue: PValue) -> PrimitiveScalar<'_> {
        PrimitiveScalar::try_new(dtype, &ScalarValue(InnerScalarValue::Primitive(pvalue)))
            .unwrap()
    }

    #[test]
    fn test_integer_subtract() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let five = make_scalar(&dtype, PValue::I32(5));
        let four = make_scalar(&dtype, PValue::I32(4));

        let checked = five.checked_sub(&four).unwrap();
        assert_eq!(checked.as_::<i32>().unwrap().unwrap(), 1);

        let via_operator = five - four;
        assert_eq!(via_operator.as_::<i32>().unwrap().unwrap(), 1);
    }

    #[test]
    #[should_panic(expected = "PrimitiveScalar subtract: overflow or underflow")]
    fn test_integer_subtract_overflow() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let min = make_scalar(&dtype, PValue::I32(i32::MIN));
        let max = make_scalar(&dtype, PValue::I32(i32::MAX));
        let _ = min - max;
    }

    #[test]
    fn test_float_subtract() {
        let dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
        let lhs = make_scalar(&dtype, PValue::F32(1.99f32));
        let rhs = make_scalar(&dtype, PValue::F32(1.0f32));

        let checked = lhs.checked_sub(&rhs).unwrap();
        assert_eq!(checked.as_::<f32>().unwrap().unwrap(), 0.99f32);

        let via_operator = lhs - rhs;
        assert_eq!(via_operator.as_::<f32>().unwrap().unwrap(), 0.99f32);
    }
}