1use std::any::type_name;
2use std::cmp::Ordering;
3use std::fmt::{Debug, Display, Formatter};
4use std::ops::{Add, Sub};
5
6use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive};
7use vortex_dtype::half::f16;
8use vortex_dtype::{DType, NativePType, Nullability, PType, match_each_native_ptype};
9use vortex_error::{
10 VortexError, VortexExpect as _, VortexResult, VortexUnwrap, vortex_err, vortex_panic,
11};
12
13use crate::pvalue::PValue;
14use crate::{InnerScalarValue, Scalar, ScalarValue};
15
16#[derive(Debug, Clone, Copy, Hash)]
17pub struct PrimitiveScalar<'a> {
18 dtype: &'a DType,
19 ptype: PType,
20 pvalue: Option<PValue>,
21}
22
23impl Display for PrimitiveScalar<'_> {
24 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
25 match self.pvalue {
26 None => write!(f, "null"),
27 Some(pv) => write!(f, "{}", pv),
28 }
29 }
30}
31
32impl PartialEq for PrimitiveScalar<'_> {
33 fn eq(&self, other: &Self) -> bool {
34 self.dtype.eq_ignore_nullability(other.dtype) && self.pvalue == other.pvalue
35 }
36}
37
38impl Eq for PrimitiveScalar<'_> {}
39
40impl PartialOrd for PrimitiveScalar<'_> {
42 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
43 if !self.dtype.eq_ignore_nullability(other.dtype) {
44 return None;
45 }
46 self.pvalue.partial_cmp(&other.pvalue)
47 }
48}
49
50impl<'a> PrimitiveScalar<'a> {
51 pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
52 let ptype = PType::try_from(dtype)?;
53
54 let pvalue = match_each_native_ptype!(ptype, |$T| {
57 if let Some(pvalue) = value.as_pvalue()? {
58 Some(PValue::from(<$T>::try_from(pvalue)?))
59 } else {
60 None
61 }
62 });
63
64 Ok(Self {
65 dtype,
66 ptype,
67 pvalue,
68 })
69 }
70
71 #[inline]
72 pub fn dtype(&self) -> &'a DType {
73 self.dtype
74 }
75
76 #[inline]
77 pub fn ptype(&self) -> PType {
78 self.ptype
79 }
80
81 #[inline]
82 pub fn pvalue(&self) -> Option<PValue> {
83 self.pvalue
84 }
85
86 pub fn typed_value<T: NativePType + TryFrom<PValue, Error = VortexError>>(&self) -> Option<T> {
87 assert_eq!(
88 self.ptype,
89 T::PTYPE,
90 "Attempting to read {} scalar as {}",
91 self.ptype,
92 T::PTYPE
93 );
94
95 self.pvalue.map(|pv| pv.as_primitive::<T>().vortex_unwrap())
96 }
97
98 pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
99 let ptype = PType::try_from(dtype)?;
100 let pvalue = self
101 .pvalue
102 .vortex_expect("nullness handled in Scalar::cast");
103 Ok(match_each_native_ptype!(ptype, |$Q| {
104 Scalar::primitive(
105 pvalue
106 .as_primitive::<$Q>()
107 .map_err(|err| vortex_err!("Can't cast {} scalar {} to {} (cause: {})", self.ptype, pvalue, dtype, err))?,
108 dtype.nullability()
109 )
110 }))
111 }
112
113 pub fn as_<T: FromPrimitiveOrF16>(&self) -> VortexResult<Option<T>> {
116 match self.pvalue {
117 None => Ok(None),
118 Some(pv) => Ok(Some(match pv {
119 PValue::U8(v) => T::from_u8(v)
120 .ok_or_else(|| vortex_err!("Failed to cast u8 to {}", type_name::<T>())),
121 PValue::U16(v) => T::from_u16(v)
122 .ok_or_else(|| vortex_err!("Failed to cast u16 to {}", type_name::<T>())),
123 PValue::U32(v) => T::from_u32(v)
124 .ok_or_else(|| vortex_err!("Failed to cast u32 to {}", type_name::<T>())),
125 PValue::U64(v) => T::from_u64(v)
126 .ok_or_else(|| vortex_err!("Failed to cast u64 to {}", type_name::<T>())),
127 PValue::I8(v) => T::from_i8(v)
128 .ok_or_else(|| vortex_err!("Failed to cast i8 to {}", type_name::<T>())),
129 PValue::I16(v) => T::from_i16(v)
130 .ok_or_else(|| vortex_err!("Failed to cast i16 to {}", type_name::<T>())),
131 PValue::I32(v) => T::from_i32(v)
132 .ok_or_else(|| vortex_err!("Failed to cast i32 to {}", type_name::<T>())),
133 PValue::I64(v) => T::from_i64(v)
134 .ok_or_else(|| vortex_err!("Failed to cast i64 to {}", type_name::<T>())),
135 PValue::F16(v) => T::from_f16(v)
136 .ok_or_else(|| vortex_err!("Failed to cast f16 to {}", type_name::<T>())),
137 PValue::F32(v) => T::from_f32(v)
138 .ok_or_else(|| vortex_err!("Failed to cast f32 to {}", type_name::<T>())),
139 PValue::F64(v) => T::from_f64(v)
140 .ok_or_else(|| vortex_err!("Failed to cast f64 to {}", type_name::<T>())),
141 }?)),
142 }
143 }
144}
145
146pub trait FromPrimitiveOrF16: FromPrimitive {
147 fn from_f16(v: f16) -> Option<Self>;
148}
149
150macro_rules! from_primitive_or_f16_for_non_floating_point {
151 ($T:ty) => {
152 impl FromPrimitiveOrF16 for $T {
153 fn from_f16(_: f16) -> Option<Self> {
154 None
155 }
156 }
157 };
158}
159
160from_primitive_or_f16_for_non_floating_point!(usize);
161from_primitive_or_f16_for_non_floating_point!(u8);
162from_primitive_or_f16_for_non_floating_point!(u16);
163from_primitive_or_f16_for_non_floating_point!(u32);
164from_primitive_or_f16_for_non_floating_point!(u64);
165from_primitive_or_f16_for_non_floating_point!(i8);
166from_primitive_or_f16_for_non_floating_point!(i16);
167from_primitive_or_f16_for_non_floating_point!(i32);
168from_primitive_or_f16_for_non_floating_point!(i64);
169
170impl FromPrimitiveOrF16 for f16 {
171 fn from_f16(v: f16) -> Option<Self> {
172 Some(v)
173 }
174}
175
176impl FromPrimitiveOrF16 for f32 {
177 fn from_f16(v: f16) -> Option<Self> {
178 Some(v.to_f32())
179 }
180}
181
182impl FromPrimitiveOrF16 for f64 {
183 fn from_f16(v: f16) -> Option<Self> {
184 Some(v.to_f64())
185 }
186}
187
188impl<'a> TryFrom<&'a Scalar> for PrimitiveScalar<'a> {
189 type Error = VortexError;
190
191 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
192 Self::try_new(value.dtype(), value.value())
193 }
194}
195
196impl Sub for PrimitiveScalar<'_> {
197 type Output = Self;
198
199 fn sub(self, rhs: Self) -> Self::Output {
200 self.checked_sub(&rhs)
201 .vortex_expect("PrimitiveScalar subtract: overflow or underflow")
202 }
203}
204
205impl CheckedSub for PrimitiveScalar<'_> {
206 fn checked_sub(&self, rhs: &Self) -> Option<Self> {
207 self.checked_binary_numeric(rhs, BinaryNumericOperator::Sub)
208 }
209}
210
211impl Add for PrimitiveScalar<'_> {
212 type Output = Self;
213
214 fn add(self, rhs: Self) -> Self::Output {
215 self.checked_add(&rhs)
216 .vortex_expect("PrimitiveScalar add: overflow or underflow")
217 }
218}
219
220impl CheckedAdd for PrimitiveScalar<'_> {
221 fn checked_add(&self, rhs: &Self) -> Option<Self> {
222 self.checked_binary_numeric(rhs, BinaryNumericOperator::Add)
223 }
224}
225
226impl Scalar {
227 pub fn primitive<T: NativePType + Into<PValue>>(value: T, nullability: Nullability) -> Self {
228 Self::primitive_value(value.into(), T::PTYPE, nullability)
229 }
230
231 pub fn primitive_value(value: PValue, ptype: PType, nullability: Nullability) -> Self {
236 Self {
237 dtype: DType::Primitive(ptype, nullability),
238 value: ScalarValue(InnerScalarValue::Primitive(value)),
239 }
240 }
241
242 pub fn reinterpret_cast(&self, ptype: PType) -> Self {
243 let primitive = PrimitiveScalar::try_from(self).unwrap_or_else(|e| {
244 vortex_panic!(e, "Failed to reinterpret cast {} to {}", self.dtype, ptype)
245 });
246 if primitive.ptype() == ptype {
247 return self.clone();
248 }
249
250 assert_eq!(
251 primitive.ptype().byte_width(),
252 ptype.byte_width(),
253 "can't reinterpret cast between integers of two different widths"
254 );
255
256 Scalar::new(
257 DType::Primitive(ptype, self.dtype.nullability()),
258 primitive
259 .pvalue
260 .map(|p| p.reinterpret_cast(ptype))
261 .map(|x| ScalarValue(InnerScalarValue::Primitive(x)))
262 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)),
263 )
264 }
265}
266
267macro_rules! primitive_scalar {
268 ($T:ty) => {
269 impl TryFrom<&Scalar> for $T {
270 type Error = VortexError;
271
272 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
273 <Option<$T>>::try_from(value)?
274 .ok_or_else(|| vortex_err!("Can't extract present value from null scalar"))
275 }
276 }
277
278 impl TryFrom<Scalar> for $T {
279 type Error = VortexError;
280
281 fn try_from(value: Scalar) -> Result<Self, Self::Error> {
282 <$T>::try_from(&value)
283 }
284 }
285
286 impl TryFrom<&Scalar> for Option<$T> {
287 type Error = VortexError;
288
289 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
290 Ok(PrimitiveScalar::try_from(value)?.typed_value::<$T>())
291 }
292 }
293
294 impl TryFrom<Scalar> for Option<$T> {
295 type Error = VortexError;
296
297 fn try_from(value: Scalar) -> Result<Self, Self::Error> {
298 <Option<$T>>::try_from(&value)
299 }
300 }
301
302 impl From<$T> for Scalar {
303 fn from(value: $T) -> Self {
304 Scalar {
305 dtype: DType::Primitive(<$T>::PTYPE, Nullability::NonNullable),
306 value: ScalarValue(InnerScalarValue::Primitive(value.into())),
307 }
308 }
309 }
310 };
311}
312
313primitive_scalar!(u8);
314primitive_scalar!(u16);
315primitive_scalar!(u32);
316primitive_scalar!(u64);
317primitive_scalar!(i8);
318primitive_scalar!(i16);
319primitive_scalar!(i32);
320primitive_scalar!(i64);
321primitive_scalar!(f16);
322primitive_scalar!(f32);
323primitive_scalar!(f64);
324
325impl TryFrom<&Scalar> for usize {
327 type Error = VortexError;
328
329 fn try_from(value: &Scalar) -> Result<Self, Self::Error> {
330 let prim = PrimitiveScalar::try_from(value)?
331 .as_::<u64>()?
332 .ok_or_else(|| vortex_err!("cannot convert Null to usize"))?;
333 Ok(usize::try_from(prim)?)
334 }
335}
336
337impl From<usize> for Scalar {
339 fn from(value: usize) -> Self {
340 Scalar::primitive(value as u64, Nullability::NonNullable)
341 }
342}
343
/// Binary arithmetic operators, where `lhs` and `rhs` denote the left and right operands.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryNumericOperator {
    /// `lhs + rhs`
    Add,
    /// `lhs - rhs`
    Sub,
    /// Reversed subtraction: `rhs - lhs`
    RSub,
    /// `lhs * rhs`
    Mul,
    /// `lhs / rhs`
    Div,
    /// Reversed division: `rhs / lhs`
    RDiv,
}
364
365impl Display for BinaryNumericOperator {
366 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
367 Debug::fmt(self, f)
368 }
369}
370
371impl BinaryNumericOperator {
372 pub fn swap(self) -> Self {
373 match self {
374 BinaryNumericOperator::Add => BinaryNumericOperator::Add,
375 BinaryNumericOperator::Sub => BinaryNumericOperator::RSub,
376 BinaryNumericOperator::RSub => BinaryNumericOperator::Sub,
377 BinaryNumericOperator::Mul => BinaryNumericOperator::Mul,
378 BinaryNumericOperator::Div => BinaryNumericOperator::RDiv,
379 BinaryNumericOperator::RDiv => BinaryNumericOperator::Div,
380 }
381 }
382}
383
384impl<'a> PrimitiveScalar<'a> {
385 pub fn checked_binary_numeric(
393 &self,
394 other: &PrimitiveScalar<'a>,
395 op: BinaryNumericOperator,
396 ) -> Option<PrimitiveScalar<'a>> {
397 if !self.dtype().eq_ignore_nullability(other.dtype()) {
398 vortex_panic!("types must match: {} {}", self.dtype(), other.dtype());
399 }
400 let result_dtype = if self.dtype().is_nullable() {
401 self.dtype()
402 } else {
403 other.dtype()
404 };
405 let ptype = self.ptype();
406
407 match_each_native_ptype!(
408 self.ptype(),
409 integral: |$P| {
410 self.checked_integeral_numeric_operator::<$P>(other, result_dtype, ptype, op)
411 }
412 floating_point: |$P| {
413 let lhs = self.typed_value::<$P>();
414 let rhs = other.typed_value::<$P>();
415 let value_or_null = match (lhs, rhs) {
416 (_, None) | (None, _) => None,
417 (Some(lhs), Some(rhs)) => match op {
418 BinaryNumericOperator::Add => Some(lhs + rhs),
419 BinaryNumericOperator::Sub => Some(lhs - rhs),
420 BinaryNumericOperator::RSub => Some(rhs - lhs),
421 BinaryNumericOperator::Mul => Some(lhs * rhs),
422 BinaryNumericOperator::Div => Some(lhs / rhs),
423 BinaryNumericOperator::RDiv => Some(rhs / lhs),
424 }
425 };
426 Some(Self { dtype: result_dtype, ptype: ptype, pvalue: value_or_null.map(PValue::from) })
427 }
428 )
429 }
430
431 fn checked_integeral_numeric_operator<
432 P: NativePType
433 + TryFrom<PValue, Error = VortexError>
434 + CheckedSub
435 + CheckedAdd
436 + CheckedMul
437 + CheckedDiv,
438 >(
439 &self,
440 other: &PrimitiveScalar<'a>,
441 result_dtype: &'a DType,
442 ptype: PType,
443 op: BinaryNumericOperator,
444 ) -> Option<PrimitiveScalar<'a>>
445 where
446 PValue: From<P>,
447 {
448 let lhs = self.typed_value::<P>();
449 let rhs = other.typed_value::<P>();
450 let value_or_null_or_overflow = match (lhs, rhs) {
451 (_, None) | (None, _) => Some(None),
452 (Some(lhs), Some(rhs)) => match op {
453 BinaryNumericOperator::Add => lhs.checked_add(&rhs).map(Some),
454 BinaryNumericOperator::Sub => lhs.checked_sub(&rhs).map(Some),
455 BinaryNumericOperator::RSub => rhs.checked_sub(&lhs).map(Some),
456 BinaryNumericOperator::Mul => lhs.checked_mul(&rhs).map(Some),
457 BinaryNumericOperator::Div => lhs.checked_div(&rhs).map(Some),
458 BinaryNumericOperator::RDiv => rhs.checked_div(&lhs).map(Some),
459 },
460 };
461
462 value_or_null_or_overflow.map(|value_or_null| Self {
463 dtype: result_dtype,
464 ptype,
465 pvalue: value_or_null.map(PValue::from),
466 })
467 }
468}
469
#[cfg(test)]
mod tests {
    use num_traits::CheckedSub;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::{InnerScalarValue, PValue, PrimitiveScalar, ScalarValue};

    /// Wraps a `PValue` in a non-null primitive `ScalarValue`.
    fn present(pvalue: PValue) -> ScalarValue {
        ScalarValue(InnerScalarValue::Primitive(pvalue))
    }

    #[test]
    fn test_integer_subtract() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let five = PrimitiveScalar::try_new(&dtype, &present(PValue::I32(5))).unwrap();
        let four = PrimitiveScalar::try_new(&dtype, &present(PValue::I32(4))).unwrap();

        let difference = five.checked_sub(&four).unwrap();
        assert_eq!(difference.as_::<i32>().unwrap().unwrap(), 1);

        assert_eq!((five - four).as_::<i32>().unwrap().unwrap(), 1);
    }

    #[test]
    #[should_panic(expected = "PrimitiveScalar subtract: overflow or underflow")]
    #[allow(clippy::assertions_on_constants)]
    fn test_integer_subtract_overflow() {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let smallest = PrimitiveScalar::try_new(&dtype, &present(PValue::I32(i32::MIN))).unwrap();
        let largest = PrimitiveScalar::try_new(&dtype, &present(PValue::I32(i32::MAX))).unwrap();

        let _ = smallest - largest;
    }

    #[test]
    fn test_float_subtract() {
        let dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
        let minuend = PrimitiveScalar::try_new(&dtype, &present(PValue::F32(1.99f32))).unwrap();
        let subtrahend = PrimitiveScalar::try_new(&dtype, &present(PValue::F32(1.0f32))).unwrap();

        let difference = minuend.checked_sub(&subtrahend).unwrap();
        assert_eq!(difference.as_::<f32>().unwrap().unwrap(), 0.99f32);

        assert_eq!(
            (minuend - subtrahend).as_::<f32>().unwrap().unwrap(),
            0.99f32
        );
    }
}