1#![deny(missing_docs)]
11
12use std::cmp::Ordering;
13use std::hash::Hash;
14use std::sync::Arc;
15
16pub use scalar_type::ScalarType;
17use vortex_buffer::{Buffer, BufferString, ByteBuffer};
18use vortex_dtype::half::f16;
19use vortex_dtype::{DECIMAL128_MAX_PRECISION, DType, Nullability};
20#[cfg(feature = "arbitrary")]
21pub mod arbitrary;
22mod arrow;
23mod bigint;
24mod binary;
25mod bool;
26mod decimal;
27mod display;
28mod extension;
29mod list;
30mod null;
31mod primitive;
32mod proto;
33mod pvalue;
34mod scalar_type;
35mod scalar_value;
36mod struct_;
37mod utf8;
38
39pub use bigint::*;
40pub use binary::*;
41pub use bool::*;
42pub use decimal::*;
43pub use extension::*;
44pub use list::*;
45pub use primitive::*;
46pub use pvalue::*;
47pub use scalar_value::*;
48pub use struct_::*;
49pub use utf8::*;
50use vortex_error::{VortexExpect, VortexResult, vortex_bail};
51
/// A single logical value together with the [`DType`] that describes it.
///
/// The `dtype` carries the logical type (including nullability), while `value` holds the
/// untyped payload. A null scalar is one whose `value` reports itself as null.
#[derive(Debug, Clone)]
pub struct Scalar {
    /// Logical type of this scalar.
    dtype: DType,
    /// Untyped payload, interpreted according to `dtype`.
    value: ScalarValue,
}
65
66impl Scalar {
67 pub fn new(dtype: DType, value: ScalarValue) -> Self {
72 Self { dtype, value }
73 }
74
75 #[inline]
77 pub fn dtype(&self) -> &DType {
78 &self.dtype
79 }
80
81 #[inline]
83 pub fn value(&self) -> &ScalarValue {
84 &self.value
85 }
86
87 #[inline]
89 pub fn into_parts(self) -> (DType, ScalarValue) {
90 (self.dtype, self.value)
91 }
92
93 #[inline]
95 pub fn into_value(self) -> ScalarValue {
96 self.value
97 }
98
99 pub fn is_valid(&self) -> bool {
101 !self.value.is_null()
102 }
103
104 pub fn is_null(&self) -> bool {
106 self.value.is_null()
107 }
108
109 pub fn null(dtype: DType) -> Self {
115 assert!(
116 dtype.is_nullable(),
117 "Creating null scalar for non-nullable DType {dtype}"
118 );
119 Self {
120 dtype,
121 value: ScalarValue(InnerScalarValue::Null),
122 }
123 }
124
125 pub fn null_typed<T: ScalarType>() -> Self {
129 Self {
130 dtype: T::dtype().as_nullable(),
131 value: ScalarValue(InnerScalarValue::Null),
132 }
133 }
134
135 pub fn cast(&self, target: &DType) -> VortexResult<Self> {
140 if let DType::Extension(ext_dtype) = target {
141 let storage_scalar = self.cast_to_non_extension(ext_dtype.storage_dtype())?;
142 Ok(Scalar::extension(ext_dtype.clone(), storage_scalar))
143 } else {
144 self.cast_to_non_extension(target)
145 }
146 }
147
148 fn cast_to_non_extension(&self, target: &DType) -> VortexResult<Self> {
149 assert!(!matches!(target, DType::Extension(..)));
150 if self.is_null() {
151 if target.is_nullable() {
152 return Ok(Scalar::new(target.clone(), self.value.clone()));
153 } else {
154 vortex_bail!("Can't cast null scalar to non-nullable type {}", target)
155 }
156 }
157
158 if self.dtype().eq_ignore_nullability(target) {
159 return Ok(Scalar::new(target.clone(), self.value.clone()));
160 }
161
162 match &self.dtype {
163 DType::Null => unreachable!(), DType::Bool(_) => self.as_bool().cast(target),
165 DType::Primitive(..) => self.as_primitive().cast(target),
166 DType::Decimal(..) => todo!("(aduffy): implement DecimalScalar casting"),
167 DType::Utf8(_) => self.as_utf8().cast(target),
168 DType::Binary(_) => self.as_binary().cast(target),
169 DType::Struct(..) => self.as_struct().cast(target),
170 DType::List(..) => self.as_list().cast(target),
171 DType::Extension(..) => self.as_extension().cast(target),
172 }
173 }
174
175 pub fn into_nullable(self) -> Self {
177 Self {
178 dtype: self.dtype.as_nullable(),
179 value: self.value,
180 }
181 }
182
183 pub fn nbytes(&self) -> usize {
185 match self.dtype() {
186 DType::Null => 0,
187 DType::Bool(_) => 1,
188 DType::Primitive(ptype, _) => ptype.byte_width(),
189 DType::Decimal(dt, _) => {
190 if dt.precision() >= DECIMAL128_MAX_PRECISION {
191 size_of::<i128>()
192 } else {
193 size_of::<i256>()
194 }
195 }
196 DType::Binary(_) | DType::Utf8(_) => self
197 .value()
198 .as_buffer()
199 .ok()
200 .flatten()
201 .map_or(0, |s| s.len()),
202 DType::Struct(_dtype, _) => self
203 .as_struct()
204 .fields()
205 .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
206 .unwrap_or_default(),
207 DType::List(_dtype, _) => self
208 .as_list()
209 .elements()
210 .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
211 .unwrap_or_default(),
212 DType::Extension(_ext_dtype) => self.as_extension().storage().nbytes(),
213 }
214 }
215
216 pub fn default_value(dtype: DType) -> Self {
221 if dtype.is_nullable() {
222 return Self::null(dtype);
223 }
224
225 match dtype {
226 DType::Null => Self::null(dtype),
227 DType::Bool(nullability) => Self::bool(false, nullability),
228 DType::Primitive(pt, nullability) => {
229 Self::primitive_value(PValue::zero(pt), pt, nullability)
230 }
231 DType::Decimal(dt, nullability) => {
232 Self::decimal(DecimalValue::from(0), dt, nullability)
233 }
234 DType::Utf8(nullability) => Self::utf8("", nullability),
235 DType::Binary(nullability) => Self::binary(Buffer::empty(), nullability),
236 DType::Struct(sf, nullability) => {
237 let fields: Vec<_> = sf.fields().map(Scalar::default_value).collect();
238 Self::struct_(DType::Struct(sf, nullability), fields)
239 }
240 DType::List(dt, nullability) => Self::list(dt, vec![], nullability),
241 DType::Extension(dt) => {
242 let scalar = Self::default_value(dt.storage_dtype().clone());
243 Self::extension(dt, scalar)
244 }
245 }
246 }
247}
248
249impl Scalar {
250 pub fn as_bool(&self) -> BoolScalar<'_> {
256 BoolScalar::try_from(self).vortex_expect("Failed to convert scalar to bool")
257 }
258
259 pub fn as_bool_opt(&self) -> Option<BoolScalar<'_>> {
261 matches!(self.dtype, DType::Bool(..)).then(|| self.as_bool())
262 }
263
264 pub fn as_primitive(&self) -> PrimitiveScalar<'_> {
270 PrimitiveScalar::try_from(self).vortex_expect("Failed to convert scalar to primitive")
271 }
272
273 pub fn as_primitive_opt(&self) -> Option<PrimitiveScalar<'_>> {
275 matches!(self.dtype, DType::Primitive(..)).then(|| self.as_primitive())
276 }
277
278 pub fn as_decimal(&self) -> DecimalScalar<'_> {
284 DecimalScalar::try_from(self).vortex_expect("Failed to convert scalar to decimal")
285 }
286
287 pub fn as_decimal_opt(&self) -> Option<DecimalScalar<'_>> {
289 matches!(self.dtype, DType::Decimal(..)).then(|| self.as_decimal())
290 }
291
292 pub fn as_utf8(&self) -> Utf8Scalar<'_> {
298 Utf8Scalar::try_from(self).vortex_expect("Failed to convert scalar to utf8")
299 }
300
301 pub fn as_utf8_opt(&self) -> Option<Utf8Scalar<'_>> {
303 matches!(self.dtype, DType::Utf8(..)).then(|| self.as_utf8())
304 }
305
306 pub fn as_binary(&self) -> BinaryScalar<'_> {
312 BinaryScalar::try_from(self).vortex_expect("Failed to convert scalar to binary")
313 }
314
315 pub fn as_binary_opt(&self) -> Option<BinaryScalar<'_>> {
317 matches!(self.dtype, DType::Binary(..)).then(|| self.as_binary())
318 }
319
320 pub fn as_struct(&self) -> StructScalar<'_> {
326 StructScalar::try_from(self).vortex_expect("Failed to convert scalar to struct")
327 }
328
329 pub fn as_struct_opt(&self) -> Option<StructScalar<'_>> {
331 matches!(self.dtype, DType::Struct(..)).then(|| self.as_struct())
332 }
333
334 pub fn as_list(&self) -> ListScalar<'_> {
340 ListScalar::try_from(self).vortex_expect("Failed to convert scalar to list")
341 }
342
343 pub fn as_list_opt(&self) -> Option<ListScalar<'_>> {
345 matches!(self.dtype, DType::List(..)).then(|| self.as_list())
346 }
347
348 pub fn as_extension(&self) -> ExtScalar<'_> {
354 ExtScalar::try_from(self).vortex_expect("Failed to convert scalar to extension")
355 }
356
357 pub fn as_extension_opt(&self) -> Option<ExtScalar<'_>> {
359 matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension())
360 }
361}
362
363impl PartialEq for Scalar {
364 fn eq(&self, other: &Self) -> bool {
365 if !self.dtype.eq_ignore_nullability(&other.dtype) {
366 return false;
367 }
368
369 match self.dtype() {
370 DType::Null => true,
371 DType::Bool(_) => self.as_bool() == other.as_bool(),
372 DType::Primitive(..) => self.as_primitive() == other.as_primitive(),
373 DType::Decimal(..) => self.as_decimal() == other.as_decimal(),
374 DType::Utf8(_) => self.as_utf8() == other.as_utf8(),
375 DType::Binary(_) => self.as_binary() == other.as_binary(),
376 DType::Struct(..) => self.as_struct() == other.as_struct(),
377 DType::List(..) => self.as_list() == other.as_list(),
378 DType::Extension(_) => self.as_extension() == other.as_extension(),
379 }
380 }
381}
382
// Marker impl: declares that the `PartialEq` implementation for `Scalar` is a full
// equivalence relation.
// NOTE(review): this relies on every typed scalar view (floats included) comparing
// reflexively — confirm against the view implementations.
impl Eq for Scalar {}
384
385impl PartialOrd for Scalar {
386 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
387 if !self.dtype().eq_ignore_nullability(other.dtype()) {
388 return None;
389 }
390 match self.dtype() {
391 DType::Null => Some(Ordering::Equal),
392 DType::Bool(_) => self.as_bool().partial_cmp(&other.as_bool()),
393 DType::Primitive(..) => self.as_primitive().partial_cmp(&other.as_primitive()),
394 DType::Decimal(..) => self.as_decimal().partial_cmp(&other.as_decimal()),
395 DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()),
396 DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()),
397 DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()),
398 DType::List(..) => self.as_list().partial_cmp(&other.as_list()),
399 DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()),
400 }
401 }
402}
403
404impl Hash for Scalar {
405 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
406 match self.dtype() {
407 DType::Null => self.dtype().hash(state), DType::Bool(_) => self.as_bool().hash(state),
409 DType::Primitive(..) => self.as_primitive().hash(state),
410 DType::Decimal(..) => self.as_decimal().hash(state),
411 DType::Utf8(_) => self.as_utf8().hash(state),
412 DType::Binary(_) => self.as_binary().hash(state),
413 DType::Struct(..) => self.as_struct().hash(state),
414 DType::List(..) => self.as_list().hash(state),
415 DType::Extension(_) => self.as_extension().hash(state),
416 }
417 }
418}
419
// Identity conversion: borrowing a `Scalar` as a `Scalar` simply returns `self`, so
// code that is generic over `AsRef<Scalar>` can take a `Scalar` directly.
impl AsRef<Self> for Scalar {
    fn as_ref(&self) -> &Self {
        self
    }
}
425
426impl<T> From<Option<T>> for Scalar
427where
428 T: ScalarType,
429 Scalar: From<T>,
430{
431 fn from(value: Option<T>) -> Self {
432 value
433 .map(Scalar::from)
434 .map(|x| x.into_nullable())
435 .unwrap_or_else(|| Scalar {
436 dtype: T::dtype().as_nullable(),
437 value: ScalarValue(InnerScalarValue::Null),
438 })
439 }
440}
441
442impl From<PrimitiveScalar<'_>> for Scalar {
443 fn from(pscalar: PrimitiveScalar<'_>) -> Self {
444 let dtype = pscalar.dtype().clone();
445 let value = pscalar
446 .pvalue()
447 .map(|pvalue| ScalarValue(InnerScalarValue::Primitive(pvalue)))
448 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
449 Self::new(dtype, value)
450 }
451}
452
453impl From<DecimalScalar<'_>> for Scalar {
454 fn from(decimal_scalar: DecimalScalar<'_>) -> Self {
455 let dtype = decimal_scalar.dtype().clone();
456 let value = decimal_scalar
457 .decimal_value()
458 .map(|value| ScalarValue(InnerScalarValue::Decimal(value)))
459 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
460 Self::new(dtype, value)
461 }
462}
463
/// Implements `From<Vec<$T>>` for [`Scalar`].
///
/// The result is a non-nullable list scalar whose element type is `<$T>::dtype()` and
/// whose elements are the values of the vector, each converted through `Scalar::from`.
macro_rules! from_vec_for_scalar {
    ($T:ty) => {
        impl From<Vec<$T>> for Scalar {
            fn from(value: Vec<$T>) -> Self {
                let element_dtype: Arc<DType> = Arc::from(<$T>::dtype());
                // Convert each item to a scalar and keep only its untyped value; the
                // element type is carried once by the list dtype.
                let elements: Arc<[_]> = value
                    .into_iter()
                    .map(|item| Scalar::from(item).into_value())
                    .collect();
                Scalar {
                    dtype: DType::List(element_dtype, Nullability::NonNullable),
                    value: ScalarValue(InnerScalarValue::List(elements)),
                }
            }
        }
    };
}
482
// Generate the `From<Vec<T>>` conversions into list scalars for each supported element
// type.
// NOTE(review): `u8` is not in this list — presumably because `Vec<u8>` would be
// ambiguous with a byte buffer; confirm before adding it.
from_vec_for_scalar!(u16);
from_vec_for_scalar!(u32);
from_vec_for_scalar!(u64);
from_vec_for_scalar!(usize);
from_vec_for_scalar!(i8);
from_vec_for_scalar!(i16);
from_vec_for_scalar!(i32);
from_vec_for_scalar!(i64);
from_vec_for_scalar!(f16);
from_vec_for_scalar!(f32);
from_vec_for_scalar!(f64);
from_vec_for_scalar!(String);
from_vec_for_scalar!(BufferString);
from_vec_for_scalar!(ByteBuffer);
498
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_dtype::{DType, ExtDType, ExtID, Nullability, PType};

    use crate::{InnerScalarValue, PValue, Scalar, ScalarValue};

    // A null scalar of any nullable type must cast to every other nullable type,
    // including extension types over nullable storage.
    #[rstest]
    fn null_can_cast_to_anything_nullable(
        #[values(
            DType::Null,
            DType::Bool(Nullability::Nullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("a"),
                Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
                None,
            ))),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("b"),
                Arc::from(DType::Utf8(Nullability::Nullable)),
                None,
            )))
        )]
        source_dtype: DType,
        #[values(
            DType::Null,
            DType::Bool(Nullability::Nullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("a"),
                Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
                None,
            ))),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("b"),
                Arc::from(DType::Utf8(Nullability::Nullable)),
                None,
            )))
        )]
        target_dtype: DType,
    ) {
        let casted = Scalar::null(source_dtype).cast(&target_dtype).unwrap();
        assert_eq!(casted.dtype(), &target_dtype);
    }

    #[test]
    fn list_casts() {
        // Builds `List<Primitive(ptype, elements)>` with the given list nullability.
        let list_of = |ptype: PType, elements: Nullability, list: Nullability| {
            DType::List(Arc::from(DType::Primitive(ptype, elements)), list)
        };

        let list = Scalar::new(
            list_of(PType::U16, Nullability::Nullable, Nullability::Nullable),
            ScalarValue(InnerScalarValue::List(Arc::from([ScalarValue(
                InnerScalarValue::Primitive(PValue::U16(6)),
            )]))),
        );

        // Widen the element type.
        let target_u32 = list_of(PType::U32, Nullability::Nullable, Nullability::Nullable);
        assert_eq!(list.cast(&target_u32).unwrap().dtype(), &target_u32);

        // Non-nullable elements are fine as long as no element is null.
        let target_u32_nonnull =
            list_of(PType::U32, Nullability::NonNullable, Nullability::Nullable);
        assert_eq!(
            list.cast(&target_u32_nonnull).unwrap().dtype(),
            &target_u32_nonnull
        );

        // A non-null list can become a non-nullable list.
        let target_nonnull = list_of(PType::U32, Nullability::Nullable, Nullability::NonNullable);
        assert_eq!(list.cast(&target_nonnull).unwrap().dtype(), &target_nonnull);

        // Narrow the element type; the value 6 still fits.
        let target_u8 = list_of(PType::U8, Nullability::Nullable, Nullability::Nullable);
        assert_eq!(list.cast(&target_u8).unwrap().dtype(), &target_u8);

        let list_with_null = Scalar::new(
            list_of(PType::U16, Nullability::Nullable, Nullability::Nullable),
            ScalarValue(InnerScalarValue::List(Arc::from([
                ScalarValue(InnerScalarValue::Primitive(PValue::U16(6))),
                ScalarValue(InnerScalarValue::Null),
            ]))),
        );

        // Null elements survive a cast to another nullable element type...
        let target_u8 = list_of(PType::U8, Nullability::Nullable, Nullability::Nullable);
        assert_eq!(list_with_null.cast(&target_u8).unwrap().dtype(), &target_u8);

        // ...but cannot be cast to a non-nullable element type.
        let target_u32_nonnull =
            list_of(PType::U32, Nullability::NonNullable, Nullability::Nullable);
        assert!(list_with_null.cast(&target_u32_nonnull).is_err());
    }

    #[test]
    fn cast_to_from_extension_types() {
        let apples = ExtDType::new(
            ExtID::new(Arc::from("apples")),
            Arc::from(DType::Primitive(PType::U16, Nullability::NonNullable)),
            None,
        );
        let ext_dtype = DType::Extension(Arc::from(apples.clone()));
        let ext_scalar = Scalar::new(ext_dtype.clone(), ScalarValue(InnerScalarValue::Bool(true)));
        let storage_scalar = Scalar::new(
            DType::clone(apples.storage_dtype()),
            ScalarValue(InnerScalarValue::Primitive(PValue::U16(1000))),
        );
        // Same storage dtype as `storage_scalar`, but carrying a u64-tagged value.
        let storage_scalar_u64 = Scalar::new(
            DType::clone(apples.storage_dtype()),
            ScalarValue(InnerScalarValue::Primitive(PValue::U64(1000))),
        );

        let ext_nullable = ext_dtype.as_nullable();
        let storage_dtype: &DType = apples.storage_dtype();
        let storage_nullable = storage_dtype.as_nullable();

        // Every (scalar, target) pair below must cast successfully to exactly `target`.
        let successful_casts: [(&Scalar, &DType); 7] = [
            // extension -> same extension, and its nullable form
            (&ext_scalar, &ext_dtype),
            (&ext_scalar, &ext_nullable),
            // extension -> storage type, and its nullable form
            (&ext_scalar, storage_dtype),
            (&ext_scalar, &storage_nullable),
            // storage -> extension, and its nullable form
            (&storage_scalar, &ext_dtype),
            (&storage_scalar, &ext_nullable),
            // storage-typed scalar with a u64-tagged value -> extension
            (&storage_scalar_u64, &ext_dtype),
        ];
        for (scalar, target) in successful_casts {
            let actual = scalar.cast(target).unwrap();
            assert_eq!(actual.dtype(), target);
        }

        // Casting to an extension whose storage is too narrow for the value must fail.
        let apples_u8 = ExtDType::new(
            ExtID::new(Arc::from("apples")),
            Arc::from(DType::Primitive(PType::U8, Nullability::NonNullable)),
            None,
        );
        let narrow_ext_dtype = DType::Extension(Arc::from(apples_u8));
        let result = storage_scalar.cast(&narrow_ext_dtype);
        assert!(
            result.as_ref().is_err_and(|err| {
                err
                    .to_string()
                    .contains("Can't cast u16 scalar 1000u16 to u8 (cause: Cannot read primitive value U16(1000) as u8")
            }),
            "{result:?}"
        );
    }

    #[test]
    fn default_value_for_complex_dtype() {
        let struct_dtype = DType::struct_(
            [
                ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
                (
                    "b",
                    DType::list(
                        DType::Primitive(PType::I8, Nullability::Nullable),
                        Nullability::NonNullable,
                    ),
                ),
                ("c", DType::Primitive(PType::I32, Nullability::Nullable)),
            ],
            Nullability::NonNullable,
        );

        let default_scalar = Scalar::default_value(struct_dtype.clone());
        assert_eq!(default_scalar.dtype(), &struct_dtype);

        let as_struct = default_scalar.as_struct();

        // Non-nullable primitive defaults to zero.
        let a_field = as_struct.field("a").unwrap();
        assert_eq!(a_field.as_primitive().pvalue().unwrap(), PValue::I32(0));

        // Non-nullable list defaults to the empty list.
        let b_field = as_struct.field("b").unwrap();
        assert!(b_field.as_list().is_empty());

        // Nullable field defaults to null.
        let c_field = as_struct.field("c").unwrap();
        assert!(c_field.is_null());
    }
}