1use std::any::type_name;
2use std::cmp::Ordering;
3use std::fmt::{Debug, Display};
4use std::ops::{Add, Sub};
5
6use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive};
7use vortex_dtype::half::f16;
8use vortex_dtype::{DType, NativePType, Nullability, PType, match_each_native_ptype};
9use vortex_error::{
10 VortexError, VortexExpect as _, VortexResult, VortexUnwrap, vortex_err, vortex_panic,
11};
12
13use crate::pvalue::PValue;
14use crate::{InnerScalarValue, Scalar, ScalarValue};
15
16#[derive(Debug, Clone, Copy, Hash)]
17pub struct PrimitiveScalar<'a> {
18 dtype: &'a DType,
19 ptype: PType,
20 pvalue: Option<PValue>,
21}
22
23impl PartialEq for PrimitiveScalar<'_> {
24 fn eq(&self, other: &Self) -> bool {
25 self.dtype.eq_ignore_nullability(other.dtype) && self.pvalue == other.pvalue
26 }
27}
28
29impl Eq for PrimitiveScalar<'_> {}
30
31impl PartialOrd for PrimitiveScalar<'_> {
33 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
34 if !self.dtype.eq_ignore_nullability(other.dtype) {
35 return None;
36 }
37 self.pvalue.partial_cmp(&other.pvalue)
38 }
39}
40
41impl<'a> PrimitiveScalar<'a> {
42 pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
43 let ptype = PType::try_from(dtype)?;
44
45 let pvalue = match_each_native_ptype!(ptype, |$T| {
48 if let Some(pvalue) = value.as_pvalue()? {
49 Some(PValue::from(<$T>::try_from(pvalue)?))
50 } else {
51 None
52 }
53 });
54
55 Ok(Self {
56 dtype,
57 ptype,
58 pvalue,
59 })
60 }
61
62 #[inline]
63 pub fn dtype(&self) -> &'a DType {
64 self.dtype
65 }
66
67 #[inline]
68 pub fn ptype(&self) -> PType {
69 self.ptype
70 }
71
72 #[inline]
73 pub fn pvalue(&self) -> Option<PValue> {
74 self.pvalue
75 }
76
77 pub fn typed_value<T: NativePType + TryFrom<PValue, Error = VortexError>>(&self) -> Option<T> {
78 assert_eq!(
79 self.ptype,
80 T::PTYPE,
81 "Attempting to read {} scalar as {}",
82 self.ptype,
83 T::PTYPE
84 );
85
86 self.pvalue.map(|pv| pv.as_primitive::<T>().vortex_unwrap())
87 }
88
89 pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
90 let ptype = PType::try_from(dtype)?;
91 let pvalue = self
92 .pvalue
93 .vortex_expect("nullness handled in Scalar::cast");
94 Ok(match_each_native_ptype!(ptype, |$Q| {
95 Scalar::primitive(
96 pvalue
97 .as_primitive::<$Q>()
98 .map_err(|err| vortex_err!("Can't cast {} scalar {} to {} (cause: {})", self.ptype, pvalue, dtype, err))?,
99 dtype.nullability()
100 )
101 }))
102 }
103
104 pub fn as_<T: FromPrimitiveOrF16>(&self) -> VortexResult<Option<T>> {
107 match self.pvalue {
108 None => Ok(None),
109 Some(pv) => Ok(Some(match pv {
110 PValue::U8(v) => T::from_u8(v)
111 .ok_or_else(|| vortex_err!("Failed to cast u8 to {}", type_name::<T>())),
112 PValue::U16(v) => T::from_u16(v)
113 .ok_or_else(|| vortex_err!("Failed to cast u16 to {}", type_name::<T>())),
114 PValue::U32(v) => T::from_u32(v)
115 .ok_or_else(|| vortex_err!("Failed to cast u32 to {}", type_name::<T>())),
116 PValue::U64(v) => T::from_u64(v)
117 .ok_or_else(|| vortex_err!("Failed to cast u64 to {}", type_name::<T>())),
118 PValue::I8(v) => T::from_i8(v)
119 .ok_or_else(|| vortex_err!("Failed to cast i8 to {}", type_name::<T>())),
120 PValue::I16(v) => T::from_i16(v)
121 .ok_or_else(|| vortex_err!("Failed to cast i16 to {}", type_name::<T>())),
122 PValue::I32(v) => T::from_i32(v)
123 .ok_or_else(|| vortex_err!("Failed to cast i32 to {}", type_name::<T>())),
124 PValue::I64(v) => T::from_i64(v)
125 .ok_or_else(|| vortex_err!("Failed to cast i64 to {}", type_name::<T>())),
126 PValue::F16(v) => T::from_f16(v)
127 .ok_or_else(|| vortex_err!("Failed to cast f16 to {}", type_name::<T>())),
128 PValue::F32(v) => T::from_f32(v)
129 .ok_or_else(|| vortex_err!("Failed to cast f32 to {}", type_name::<T>())),
130 PValue::F64(v) => T::from_f64(v)
131 .ok_or_else(|| vortex_err!("Failed to cast f64 to {}", type_name::<T>())),
132 }?)),
133 }
134 }
135}
136
137pub trait FromPrimitiveOrF16: FromPrimitive {
138 fn from_f16(v: f16) -> Option<Self>;
139}
140
141macro_rules! from_primitive_or_f16_for_non_floating_point {
142 ($T:ty) => {
143 impl FromPrimitiveOrF16 for $T {
144 fn from_f16(_: f16) -> Option<Self> {
145 None
146 }
147 }
148 };
149}
150
151from_primitive_or_f16_for_non_floating_point!(usize);
152from_primitive_or_f16_for_non_floating_point!(u8);
153from_primitive_or_f16_for_non_floating_point!(u16);
154from_primitive_or_f16_for_non_floating_point!(u32);
155from_primitive_or_f16_for_non_floating_point!(u64);
156from_primitive_or_f16_for_non_floating_point!(i8);
157from_primitive_or_f16_for_non_floating_point!(i16);
158from_primitive_or_f16_for_non_floating_point!(i32);
159from_primitive_or_f16_for_non_floating_point!(i64);
160
161impl FromPrimitiveOrF16 for f16 {
162 fn from_f16(v: f16) -> Option<Self> {
163 Some(v)
164 }
165}
166
167impl FromPrimitiveOrF16 for f32 {
168 fn from_f16(v: f16) -> Option<Self> {
169 Some(v.to_f32())
170 }
171}
172
173impl FromPrimitiveOrF16 for f64 {
174 fn from_f16(v: f16) -> Option<Self> {
175 Some(v.to_f64())
176 }
177}
178
179impl<'a> TryFrom<&'a Scalar> for PrimitiveScalar<'a> {
180 type Error = VortexError;
181
182 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
183 Self::try_new(value.dtype(), value.value())
184 }
185}
186
187impl Sub for PrimitiveScalar<'_> {
188 type Output = Self;
189
190 fn sub(self, rhs: Self) -> Self::Output {
191 self.checked_sub(&rhs)
192 .vortex_expect("PrimitiveScalar subtract: overflow or underflow")
193 }
194}
195
196impl CheckedSub for PrimitiveScalar<'_> {
197 fn checked_sub(&self, rhs: &Self) -> Option<Self> {
198 self.checked_binary_numeric(rhs, BinaryNumericOperator::Sub)
199 }
200}
201
202impl Add for PrimitiveScalar<'_> {
203 type Output = Self;
204
205 fn add(self, rhs: Self) -> Self::Output {
206 self.checked_add(&rhs)
207 .vortex_expect("PrimitiveScalar add: overflow or underflow")
208 }
209}
210
211impl CheckedAdd for PrimitiveScalar<'_> {
212 fn checked_add(&self, rhs: &Self) -> Option<Self> {
213 self.checked_binary_numeric(rhs, BinaryNumericOperator::Add)
214 }
215}
216
217impl Scalar {
218 pub fn primitive<T: NativePType + Into<PValue>>(value: T, nullability: Nullability) -> Self {
219 Self::primitive_value(value.into(), T::PTYPE, nullability)
220 }
221
222 pub fn primitive_value(value: PValue, ptype: PType, nullability: Nullability) -> Self {
227 Self {
228 dtype: DType::Primitive(ptype, nullability),
229 value: ScalarValue(InnerScalarValue::Primitive(value)),
230 }
231 }
232
233 pub fn reinterpret_cast(&self, ptype: PType) -> Self {
234 let primitive = PrimitiveScalar::try_from(self).unwrap_or_else(|e| {
235 vortex_panic!(e, "Failed to reinterpret cast {} to {}", self.dtype, ptype)
236 });
237 if primitive.ptype() == ptype {
238 return self.clone();
239 }
240
241 assert_eq!(
242 primitive.ptype().byte_width(),
243 ptype.byte_width(),
244 "can't reinterpret cast between integers of two different widths"
245 );
246
247 Scalar::new(
248 DType::Primitive(ptype, self.dtype.nullability()),
249 primitive
250 .pvalue
251 .map(|p| p.reinterpret_cast(ptype))
252 .map(|x| ScalarValue(InnerScalarValue::Primitive(x)))
253 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)),
254 )
255 }
256}
257
258macro_rules! primitive_scalar {
259 ($T:ty) => {
260 impl TryFrom<&Scalar> for $T {
261 type Error = VortexError;
262
263 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
264 <Option<$T>>::try_from(value)?
265 .ok_or_else(|| vortex_err!("Can't extract present value from null scalar"))
266 }
267 }
268
269 impl TryFrom<Scalar> for $T {
270 type Error = VortexError;
271
272 fn try_from(value: Scalar) -> Result<Self, Self::Error> {
273 <$T>::try_from(&value)
274 }
275 }
276
277 impl TryFrom<&Scalar> for Option<$T> {
278 type Error = VortexError;
279
280 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
281 Ok(PrimitiveScalar::try_from(value)?.typed_value::<$T>())
282 }
283 }
284
285 impl TryFrom<Scalar> for Option<$T> {
286 type Error = VortexError;
287
288 fn try_from(value: Scalar) -> Result<Self, Self::Error> {
289 <Option<$T>>::try_from(&value)
290 }
291 }
292
293 impl From<$T> for Scalar {
294 fn from(value: $T) -> Self {
295 Scalar {
296 dtype: DType::Primitive(<$T>::PTYPE, Nullability::NonNullable),
297 value: ScalarValue(InnerScalarValue::Primitive(value.into())),
298 }
299 }
300 }
301 };
302}
303
304primitive_scalar!(u8);
305primitive_scalar!(u16);
306primitive_scalar!(u32);
307primitive_scalar!(u64);
308primitive_scalar!(i8);
309primitive_scalar!(i16);
310primitive_scalar!(i32);
311primitive_scalar!(i64);
312primitive_scalar!(f16);
313primitive_scalar!(f32);
314primitive_scalar!(f64);
315
316impl TryFrom<&Scalar> for usize {
318 type Error = VortexError;
319
320 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
321 let prim = PrimitiveScalar::try_from(value)?
322 .as_::<u64>()?
323 .ok_or_else(|| vortex_err!("cannot convert Null to usize"))?;
324 Ok(usize::try_from(prim)?)
325 }
326}
327
328impl From<usize> for Scalar {
330 fn from(value: usize) -> Self {
331 Scalar::primitive(value as u64, Nullability::NonNullable)
332 }
333}
334
/// Binary arithmetic operators accepted by `PrimitiveScalar::checked_binary_numeric`.
///
/// The `R`-prefixed variants apply the operator with the operands reversed. For example,
/// `lhs RSub rhs` computes `rhs - lhs`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryNumericOperator {
    /// `lhs + rhs`
    Add,
    /// `lhs - rhs`
    Sub,
    /// `rhs - lhs`
    RSub,
    /// `lhs * rhs`
    Mul,
    /// `lhs / rhs`
    Div,
    /// `rhs / lhs`
    RDiv,
}

/// Formats the operator as its variant name (its `Debug` representation).
impl Display for BinaryNumericOperator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self:?}")
    }
}

impl BinaryNumericOperator {
    /// Returns the operator that gives the same result with the operands exchanged, i.e.
    /// `a op b == b op.swap() a`.
    ///
    /// Commutative operators map to themselves; the others map to their reversed form.
    pub fn swap(self) -> Self {
        match self {
            Self::Add | Self::Mul => self,
            Self::Sub => Self::RSub,
            Self::RSub => Self::Sub,
            Self::Div => Self::RDiv,
            Self::RDiv => Self::Div,
        }
    }
}
374
375impl<'a> PrimitiveScalar<'a> {
376 pub fn checked_binary_numeric(
384 &self,
385 other: &PrimitiveScalar<'a>,
386 op: BinaryNumericOperator,
387 ) -> Option<PrimitiveScalar<'a>> {
388 if !self.dtype().eq_ignore_nullability(other.dtype()) {
389 vortex_panic!("types must match: {} {}", self.dtype(), other.dtype());
390 }
391 let result_dtype = if self.dtype().is_nullable() {
392 self.dtype()
393 } else {
394 other.dtype()
395 };
396 let ptype = self.ptype();
397
398 match_each_native_ptype!(
399 self.ptype(),
400 integral: |$P| {
401 self.checked_integeral_numeric_operator::<$P>(other, result_dtype, ptype, op)
402 }
403 floating_point: |$P| {
404 let lhs = self.typed_value::<$P>();
405 let rhs = other.typed_value::<$P>();
406 let value_or_null = match (lhs, rhs) {
407 (_, None) | (None, _) => None,
408 (Some(lhs), Some(rhs)) => match op {
409 BinaryNumericOperator::Add => Some(lhs + rhs),
410 BinaryNumericOperator::Sub => Some(lhs - rhs),
411 BinaryNumericOperator::RSub => Some(rhs - lhs),
412 BinaryNumericOperator::Mul => Some(lhs * rhs),
413 BinaryNumericOperator::Div => Some(lhs / rhs),
414 BinaryNumericOperator::RDiv => Some(rhs / lhs),
415 }
416 };
417 Some(Self { dtype: result_dtype, ptype: ptype, pvalue: value_or_null.map(PValue::from) })
418 }
419 )
420 }
421
422 fn checked_integeral_numeric_operator<
423 P: NativePType
424 + TryFrom<PValue, Error = VortexError>
425 + CheckedSub
426 + CheckedAdd
427 + CheckedMul
428 + CheckedDiv,
429 >(
430 &self,
431 other: &PrimitiveScalar<'a>,
432 result_dtype: &'a DType,
433 ptype: PType,
434 op: BinaryNumericOperator,
435 ) -> Option<PrimitiveScalar<'a>>
436 where
437 PValue: From<P>,
438 {
439 let lhs = self.typed_value::<P>();
440 let rhs = other.typed_value::<P>();
441 let value_or_null_or_overflow = match (lhs, rhs) {
442 (_, None) | (None, _) => Some(None),
443 (Some(lhs), Some(rhs)) => match op {
444 BinaryNumericOperator::Add => lhs.checked_add(&rhs).map(Some),
445 BinaryNumericOperator::Sub => lhs.checked_sub(&rhs).map(Some),
446 BinaryNumericOperator::RSub => rhs.checked_sub(&lhs).map(Some),
447 BinaryNumericOperator::Mul => lhs.checked_mul(&rhs).map(Some),
448 BinaryNumericOperator::Div => lhs.checked_div(&rhs).map(Some),
449 BinaryNumericOperator::RDiv => rhs.checked_div(&lhs).map(Some),
450 },
451 };
452
453 value_or_null_or_overflow.map(|value_or_null| Self {
454 dtype: result_dtype,
455 ptype,
456 pvalue: value_or_null.map(PValue::from),
457 })
458 }
459}
460
#[cfg(test)]
mod tests {
    use num_traits::{CheckedAdd, CheckedSub};
    use vortex_dtype::{DType, Nullability, PType};

    use super::BinaryNumericOperator;
    use crate::{InnerScalarValue, PValue, PrimitiveScalar, ScalarValue};

    /// Builds a present `i32` primitive scalar view with the given dtype.
    fn i32_scalar(dtype: &DType, value: i32) -> PrimitiveScalar<'_> {
        PrimitiveScalar::try_new(
            dtype,
            &ScalarValue(InnerScalarValue::Primitive(PValue::I32(value))),
        )
        .unwrap()
    }

    #[test]
    fn test_integer_subtract() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let p_scalar1 = i32_scalar(&dtype, 5);
        let p_scalar2 = i32_scalar(&dtype, 4);
        let pscalar_or_overflow = p_scalar1.checked_sub(&p_scalar2);
        let value_or_null_or_type_error = pscalar_or_overflow.unwrap().as_::<i32>();
        assert_eq!(value_or_null_or_type_error.unwrap().unwrap(), 1);

        assert_eq!((p_scalar1 - p_scalar2).as_::<i32>().unwrap().unwrap(), 1);
    }

    #[test]
    #[should_panic(expected = "PrimitiveScalar subtract: overflow or underflow")]
    fn test_integer_subtract_overflow() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let p_scalar1 = i32_scalar(&dtype, i32::MIN);
        let p_scalar2 = i32_scalar(&dtype, i32::MAX);
        let _ = p_scalar1 - p_scalar2;
    }

    #[test]
    fn test_integer_add_overflow_is_none() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let max = i32_scalar(&dtype, i32::MAX);
        let one = i32_scalar(&dtype, 1);
        assert!(max.checked_add(&one).is_none());
    }

    #[test]
    fn test_reversed_operators() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let five = i32_scalar(&dtype, 5);
        let two = i32_scalar(&dtype, 2);

        // RSub computes rhs - lhs: 2 - 5.
        let rsub = five
            .checked_binary_numeric(&two, BinaryNumericOperator::RSub)
            .unwrap();
        assert_eq!(rsub.as_::<i32>().unwrap(), Some(-3));

        // RDiv computes rhs / lhs: 5 / 2.
        let rdiv = two
            .checked_binary_numeric(&five, BinaryNumericOperator::RDiv)
            .unwrap();
        assert_eq!(rdiv.as_::<i32>().unwrap(), Some(2));
    }

    #[test]
    fn test_integer_divide_by_zero_is_none() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let five = i32_scalar(&dtype, 5);
        let zero = i32_scalar(&dtype, 0);
        assert!(
            five.checked_binary_numeric(&zero, BinaryNumericOperator::Div)
                .is_none()
        );
    }

    #[test]
    fn test_null_operand_yields_nullable_null() {
        let non_nullable = DType::Primitive(PType::I32, Nullability::NonNullable);
        let nullable = DType::Primitive(PType::I32, Nullability::Nullable);
        let five = i32_scalar(&non_nullable, 5);
        let null =
            PrimitiveScalar::try_new(&nullable, &ScalarValue(InnerScalarValue::Null)).unwrap();

        let result = five.checked_sub(&null).unwrap();
        assert!(result.pvalue().is_none());
        assert!(result.dtype().is_nullable());
    }

    #[test]
    fn test_swap_is_an_involution() {
        use BinaryNumericOperator as Op;
        for op in [Op::Add, Op::Sub, Op::RSub, Op::Mul, Op::Div, Op::RDiv] {
            assert_eq!(op.swap().swap(), op);
        }
    }

    #[test]
    fn test_float_subtract() {
        let dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
        let p_scalar1 = PrimitiveScalar::try_new(
            &dtype,
            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(1.99f32))),
        )
        .unwrap();
        let p_scalar2 = PrimitiveScalar::try_new(
            &dtype,
            &ScalarValue(InnerScalarValue::Primitive(PValue::F32(1.0f32))),
        )
        .unwrap();
        let pscalar_or_overflow = p_scalar1.checked_sub(&p_scalar2).unwrap();
        let value_or_null_or_type_error = pscalar_or_overflow.as_::<f32>();
        assert_eq!(value_or_null_or_type_error.unwrap().unwrap(), 0.99f32);

        assert_eq!(
            (p_scalar1 - p_scalar2).as_::<f32>().unwrap().unwrap(),
            0.99f32
        );
    }
}