1use std::cmp::Ordering;
2use std::hash::Hash;
3use std::sync::Arc;
4
5pub use scalar_type::ScalarType;
6use vortex_buffer::{Buffer, BufferString, ByteBuffer};
7use vortex_dtype::half::f16;
8use vortex_dtype::{DECIMAL128_MAX_PRECISION, DType, Nullability};
9#[cfg(feature = "arbitrary")]
10pub mod arbitrary;
11mod arrow;
12mod bigint;
13mod binary;
14mod bool;
15mod decimal;
16mod display;
17mod extension;
18mod list;
19mod null;
20mod primitive;
21mod proto;
22mod pvalue;
23mod scalar_type;
24mod scalar_value;
25mod struct_;
26mod utf8;
27
28pub use bigint::*;
29pub use binary::*;
30pub use bool::*;
31pub use decimal::*;
32pub use extension::*;
33pub use list::*;
34pub use primitive::*;
35pub use pvalue::*;
36pub use scalar_value::*;
37pub use struct_::*;
38pub use utf8::*;
39use vortex_error::{VortexExpect, VortexResult, vortex_bail};
40
/// A single value paired with the logical [`DType`] that describes it.
///
/// The payload is stored as a [`ScalarValue`]. Typed access goes through the
/// `as_*` view methods (for example [`Scalar::as_bool`] or
/// [`Scalar::as_primitive`]).
#[derive(Debug, Clone)]
pub struct Scalar {
    /// Logical type of the value, including its nullability.
    dtype: DType,
    /// The stored value; a null scalar holds `InnerScalarValue::Null`.
    value: ScalarValue,
}
54
55impl Scalar {
56 pub fn new(dtype: DType, value: ScalarValue) -> Self {
57 Self { dtype, value }
58 }
59
60 #[inline]
61 pub fn dtype(&self) -> &DType {
62 &self.dtype
63 }
64
65 #[inline]
66 pub fn value(&self) -> &ScalarValue {
67 &self.value
68 }
69
70 #[inline]
71 pub fn into_parts(self) -> (DType, ScalarValue) {
72 (self.dtype, self.value)
73 }
74
75 #[inline]
76 pub fn into_value(self) -> ScalarValue {
77 self.value
78 }
79
80 pub fn is_valid(&self) -> bool {
81 !self.value.is_null()
82 }
83
84 pub fn is_null(&self) -> bool {
85 self.value.is_null()
86 }
87
88 pub fn null(dtype: DType) -> Self {
89 assert!(
90 dtype.is_nullable(),
91 "Creating null scalar for non-nullable DType {dtype}"
92 );
93 Self {
94 dtype,
95 value: ScalarValue(InnerScalarValue::Null),
96 }
97 }
98
99 pub fn null_typed<T: ScalarType>() -> Self {
100 Self {
101 dtype: T::dtype().as_nullable(),
102 value: ScalarValue(InnerScalarValue::Null),
103 }
104 }
105
106 pub fn cast(&self, target: &DType) -> VortexResult<Self> {
107 if let DType::Extension(ext_dtype) = target {
108 let storage_scalar = self.cast_to_non_extension(ext_dtype.storage_dtype())?;
109 Ok(Scalar::extension(ext_dtype.clone(), storage_scalar))
110 } else {
111 self.cast_to_non_extension(target)
112 }
113 }
114
115 fn cast_to_non_extension(&self, target: &DType) -> VortexResult<Self> {
116 assert!(!matches!(target, DType::Extension(..)));
117 if self.is_null() {
118 if target.is_nullable() {
119 return Ok(Scalar::new(target.clone(), self.value.clone()));
120 } else {
121 vortex_bail!("Can't cast null scalar to non-nullable type {}", target)
122 }
123 }
124
125 if self.dtype().eq_ignore_nullability(target) {
126 return Ok(Scalar::new(target.clone(), self.value.clone()));
127 }
128
129 match &self.dtype {
130 DType::Null => unreachable!(), DType::Bool(_) => self.as_bool().cast(target),
132 DType::Primitive(..) => self.as_primitive().cast(target),
133 DType::Decimal(..) => todo!("(aduffy): implement DecimalScalar casting"),
134 DType::Utf8(_) => self.as_utf8().cast(target),
135 DType::Binary(_) => self.as_binary().cast(target),
136 DType::Struct(..) => self.as_struct().cast(target),
137 DType::List(..) => self.as_list().cast(target),
138 DType::Extension(..) => self.as_extension().cast(target),
139 }
140 }
141
142 pub fn into_nullable(self) -> Self {
143 Self {
144 dtype: self.dtype.as_nullable(),
145 value: self.value,
146 }
147 }
148
149 pub fn nbytes(&self) -> usize {
151 match self.dtype() {
152 DType::Null => 0,
153 DType::Bool(_) => 1,
154 DType::Primitive(ptype, _) => ptype.byte_width(),
155 DType::Decimal(dt, _) => {
156 if dt.precision() >= DECIMAL128_MAX_PRECISION {
157 size_of::<i128>()
158 } else {
159 size_of::<i256>()
160 }
161 }
162 DType::Binary(_) | DType::Utf8(_) => self
163 .value()
164 .as_buffer()
165 .ok()
166 .flatten()
167 .map_or(0, |s| s.len()),
168 DType::Struct(_dtype, _) => self
169 .as_struct()
170 .fields()
171 .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
172 .unwrap_or_default(),
173 DType::List(_dtype, _) => self
174 .as_list()
175 .elements()
176 .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
177 .unwrap_or_default(),
178 DType::Extension(_ext_dtype) => self.as_extension().storage().nbytes(),
179 }
180 }
181
182 pub fn default_value(dtype: DType) -> Self {
184 if dtype.is_nullable() {
185 return Self::null(dtype);
186 }
187
188 match dtype {
189 DType::Null => Self::null(dtype),
190 DType::Bool(nullability) => Self::bool(false, nullability),
191 DType::Primitive(pt, nullability) => {
192 Self::primitive_value(PValue::zero(pt), pt, nullability)
193 }
194 DType::Decimal(dt, nullability) => {
195 Self::decimal(DecimalValue::from(0), dt, nullability)
196 }
197 DType::Utf8(nullability) => Self::utf8("", nullability),
198 DType::Binary(nullability) => Self::binary(Buffer::empty(), nullability),
199 DType::Struct(sf, nullability) => {
200 let fields: Vec<_> = sf.fields().map(Scalar::default_value).collect();
201 Self::struct_(DType::Struct(sf, nullability), fields)
202 }
203 DType::List(dt, nullability) => Self::list(dt, vec![], nullability),
204 DType::Extension(dt) => {
205 let scalar = Self::default_value(dt.storage_dtype().clone());
206 Self::extension(dt, scalar)
207 }
208 }
209 }
210}
211
212impl Scalar {
213 pub fn as_bool(&self) -> BoolScalar<'_> {
214 BoolScalar::try_from(self).vortex_expect("Failed to convert scalar to bool")
215 }
216
217 pub fn as_bool_opt(&self) -> Option<BoolScalar<'_>> {
218 matches!(self.dtype, DType::Bool(..)).then(|| self.as_bool())
219 }
220
221 pub fn as_primitive(&self) -> PrimitiveScalar<'_> {
222 PrimitiveScalar::try_from(self).vortex_expect("Failed to convert scalar to primitive")
223 }
224
225 pub fn as_primitive_opt(&self) -> Option<PrimitiveScalar<'_>> {
226 matches!(self.dtype, DType::Primitive(..)).then(|| self.as_primitive())
227 }
228
229 pub fn as_decimal(&self) -> DecimalScalar<'_> {
230 DecimalScalar::try_from(self).vortex_expect("Failed to convert scalar to decimal")
231 }
232
233 pub fn as_decimal_opt(&self) -> Option<DecimalScalar<'_>> {
234 matches!(self.dtype, DType::Decimal(..)).then(|| self.as_decimal())
235 }
236
237 pub fn as_utf8(&self) -> Utf8Scalar<'_> {
238 Utf8Scalar::try_from(self).vortex_expect("Failed to convert scalar to utf8")
239 }
240
241 pub fn as_utf8_opt(&self) -> Option<Utf8Scalar<'_>> {
242 matches!(self.dtype, DType::Utf8(..)).then(|| self.as_utf8())
243 }
244
245 pub fn as_binary(&self) -> BinaryScalar<'_> {
246 BinaryScalar::try_from(self).vortex_expect("Failed to convert scalar to binary")
247 }
248
249 pub fn as_binary_opt(&self) -> Option<BinaryScalar<'_>> {
250 matches!(self.dtype, DType::Binary(..)).then(|| self.as_binary())
251 }
252
253 pub fn as_struct(&self) -> StructScalar<'_> {
254 StructScalar::try_from(self).vortex_expect("Failed to convert scalar to struct")
255 }
256
257 pub fn as_struct_opt(&self) -> Option<StructScalar<'_>> {
258 matches!(self.dtype, DType::Struct(..)).then(|| self.as_struct())
259 }
260
261 pub fn as_list(&self) -> ListScalar<'_> {
262 ListScalar::try_from(self).vortex_expect("Failed to convert scalar to list")
263 }
264
265 pub fn as_list_opt(&self) -> Option<ListScalar<'_>> {
266 matches!(self.dtype, DType::List(..)).then(|| self.as_list())
267 }
268
269 pub fn as_extension(&self) -> ExtScalar<'_> {
270 ExtScalar::try_from(self).vortex_expect("Failed to convert scalar to extension")
271 }
272
273 pub fn as_extension_opt(&self) -> Option<ExtScalar<'_>> {
274 matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension())
275 }
276}
277
278impl PartialEq for Scalar {
279 fn eq(&self, other: &Self) -> bool {
280 if !self.dtype.eq_ignore_nullability(&other.dtype) {
281 return false;
282 }
283
284 match self.dtype() {
285 DType::Null => true,
286 DType::Bool(_) => self.as_bool() == other.as_bool(),
287 DType::Primitive(..) => self.as_primitive() == other.as_primitive(),
288 DType::Decimal(..) => self.as_decimal() == other.as_decimal(),
289 DType::Utf8(_) => self.as_utf8() == other.as_utf8(),
290 DType::Binary(_) => self.as_binary() == other.as_binary(),
291 DType::Struct(..) => self.as_struct() == other.as_struct(),
292 DType::List(..) => self.as_list() == other.as_list(),
293 DType::Extension(_) => self.as_extension() == other.as_extension(),
294 }
295 }
296}
297
// Marker impl: promotes the `PartialEq` implementation above to `Eq`, so that
// `Scalar` can be used wherever full equality is required.
impl Eq for Scalar {}
299
300impl PartialOrd for Scalar {
301 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
302 if !self.dtype().eq_ignore_nullability(other.dtype()) {
303 return None;
304 }
305 match self.dtype() {
306 DType::Null => Some(Ordering::Equal),
307 DType::Bool(_) => self.as_bool().partial_cmp(&other.as_bool()),
308 DType::Primitive(..) => self.as_primitive().partial_cmp(&other.as_primitive()),
309 DType::Decimal(..) => self.as_decimal().partial_cmp(&other.as_decimal()),
310 DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()),
311 DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()),
312 DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()),
313 DType::List(..) => self.as_list().partial_cmp(&other.as_list()),
314 DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()),
315 }
316 }
317}
318
319impl Hash for Scalar {
320 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
321 match self.dtype() {
322 DType::Null => self.dtype().hash(state), DType::Bool(_) => self.as_bool().hash(state),
324 DType::Primitive(..) => self.as_primitive().hash(state),
325 DType::Decimal(..) => self.as_decimal().hash(state),
326 DType::Utf8(_) => self.as_utf8().hash(state),
327 DType::Binary(_) => self.as_binary().hash(state),
328 DType::Struct(..) => self.as_struct().hash(state),
329 DType::List(..) => self.as_list().hash(state),
330 DType::Extension(_) => self.as_extension().hash(state),
331 }
332 }
333}
334
335impl AsRef<Self> for Scalar {
336 fn as_ref(&self) -> &Self {
337 self
338 }
339}
340
341impl<T> From<Option<T>> for Scalar
342where
343 T: ScalarType,
344 Scalar: From<T>,
345{
346 fn from(value: Option<T>) -> Self {
347 value
348 .map(Scalar::from)
349 .map(|x| x.into_nullable())
350 .unwrap_or_else(|| Scalar {
351 dtype: T::dtype().as_nullable(),
352 value: ScalarValue(InnerScalarValue::Null),
353 })
354 }
355}
356
357impl From<PrimitiveScalar<'_>> for Scalar {
358 fn from(pscalar: PrimitiveScalar<'_>) -> Self {
359 let dtype = pscalar.dtype().clone();
360 let value = pscalar
361 .pvalue()
362 .map(|pvalue| ScalarValue(InnerScalarValue::Primitive(pvalue)))
363 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
364 Self::new(dtype, value)
365 }
366}
367
368impl From<DecimalScalar<'_>> for Scalar {
369 fn from(decimal_scalar: DecimalScalar<'_>) -> Self {
370 let dtype = decimal_scalar.dtype().clone();
371 let value = decimal_scalar
372 .decimal_value()
373 .map(|value| ScalarValue(InnerScalarValue::Decimal(value)))
374 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
375 Self::new(dtype, value)
376 }
377}
378
/// Generates `impl From<Vec<$T>> for Scalar`.
///
/// The result is a non-nullable list scalar whose element dtype is
/// `<$T>::dtype()` and whose elements are the values of the per-item scalars.
macro_rules! from_vec_for_scalar {
    ($T:ty) => {
        impl From<Vec<$T>> for Scalar {
            fn from(value: Vec<$T>) -> Self {
                let element_dtype = Arc::from(<$T>::dtype());
                // Only the value of each per-item scalar is kept; the list dtype
                // already carries the element dtype.
                let elements: Arc<[_]> = value
                    .into_iter()
                    .map(|item| Scalar::from(item).into_value())
                    .collect();

                Scalar {
                    dtype: DType::List(element_dtype, Nullability::NonNullable),
                    value: ScalarValue(InnerScalarValue::List(elements)),
                }
            }
        }
    };
}
397
// One `From<Vec<T>> for Scalar` impl is generated for each element type below.
// NOTE(review): `u8` is absent from this list — presumably because a `Vec<u8>`
// is ambiguous between a list of bytes and a binary buffer; confirm before
// adding it.
from_vec_for_scalar!(u16);
from_vec_for_scalar!(u32);
from_vec_for_scalar!(u64);
from_vec_for_scalar!(usize); from_vec_for_scalar!(i8);
from_vec_for_scalar!(i16);
from_vec_for_scalar!(i32);
from_vec_for_scalar!(i64);
from_vec_for_scalar!(f16);
from_vec_for_scalar!(f32);
from_vec_for_scalar!(f64);
from_vec_for_scalar!(String);
from_vec_for_scalar!(BufferString);
from_vec_for_scalar!(ByteBuffer);
413
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_dtype::{DType, ExtDType, ExtID, Nullability, PType};

    use crate::{InnerScalarValue, PValue, Scalar, ScalarValue};

    /// Builds a list dtype whose elements are primitives of type `ptype`.
    fn primitive_list(ptype: PType, elements: Nullability, list: Nullability) -> DType {
        DType::List(Arc::from(DType::Primitive(ptype, elements)), list)
    }

    /// Wraps a `u16` as a primitive scalar value.
    fn u16_value(raw: u16) -> ScalarValue {
        ScalarValue(InnerScalarValue::Primitive(PValue::U16(raw)))
    }

    /// A null scalar of any nullable dtype can be cast to any other nullable dtype.
    #[rstest]
    fn null_can_cast_to_anything_nullable(
        #[values(
            DType::Null,
            DType::Bool(Nullability::Nullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("a"),
                Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
                None,
            ))),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("b"),
                Arc::from(DType::Utf8(Nullability::Nullable)),
                None,
            )))
        )]
        source_dtype: DType,
        #[values(
            DType::Null,
            DType::Bool(Nullability::Nullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("a"),
                Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
                None,
            ))),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("b"),
                Arc::from(DType::Utf8(Nullability::Nullable)),
                None,
            )))
        )]
        target_dtype: DType,
    ) {
        let casted = Scalar::null(source_dtype).cast(&target_dtype).unwrap();
        assert_eq!(casted.dtype(), &target_dtype);
    }

    /// List scalars cast element-wise; a null element blocks casts to
    /// non-nullable element dtypes.
    #[test]
    fn list_casts() {
        let single = Scalar::new(
            primitive_list(PType::U16, Nullability::Nullable, Nullability::Nullable),
            ScalarValue(InnerScalarValue::List(Arc::from([u16_value(6)]))),
        );

        // Every one of these targets must be reachable from the one-element list.
        let reachable = [
            primitive_list(PType::U32, Nullability::Nullable, Nullability::Nullable),
            primitive_list(PType::U32, Nullability::NonNullable, Nullability::Nullable),
            primitive_list(PType::U32, Nullability::Nullable, Nullability::NonNullable),
            primitive_list(PType::U8, Nullability::Nullable, Nullability::Nullable),
        ];
        for target in &reachable {
            assert_eq!(single.cast(target).unwrap().dtype(), target);
        }

        let with_null = Scalar::new(
            primitive_list(PType::U16, Nullability::Nullable, Nullability::Nullable),
            ScalarValue(InnerScalarValue::List(Arc::from([
                u16_value(6),
                ScalarValue(InnerScalarValue::Null),
            ]))),
        );

        // Nullable elements can still absorb the null entry.
        let nullable_u8 = primitive_list(PType::U8, Nullability::Nullable, Nullability::Nullable);
        assert_eq!(with_null.cast(&nullable_u8).unwrap().dtype(), &nullable_u8);

        // Non-nullable elements cannot.
        let non_null_u32 =
            primitive_list(PType::U32, Nullability::NonNullable, Nullability::Nullable);
        assert!(with_null.cast(&non_null_u32).is_err());
    }

    /// Casts between an extension dtype and its storage dtype, in both directions.
    #[test]
    fn cast_to_from_extension_types() {
        let apples = ExtDType::new(
            ExtID::new(Arc::from("apples")),
            Arc::from(DType::Primitive(PType::U16, Nullability::NonNullable)),
            None,
        );
        let ext_dtype = DType::Extension(Arc::from(apples.clone()));
        let storage_dtype = DType::clone(apples.storage_dtype());

        // Only dtypes are asserted on for this scalar, so its payload is arbitrary.
        let ext_scalar = Scalar::new(ext_dtype.clone(), ScalarValue(InnerScalarValue::Bool(true)));
        let storage_scalar = Scalar::new(storage_dtype.clone(), u16_value(1000));

        // Extension scalar -> itself, its nullable form, its storage dtype, and
        // the nullable storage dtype.
        for target in [
            ext_dtype.clone(),
            ext_dtype.as_nullable(),
            storage_dtype.clone(),
            storage_dtype.as_nullable(),
        ] {
            assert_eq!(ext_scalar.cast(&target).unwrap().dtype(), &target);
        }

        // Storage scalar -> extension dtype, non-nullable and nullable.
        for target in [ext_dtype.clone(), ext_dtype.as_nullable()] {
            assert_eq!(storage_scalar.cast(&target).unwrap().dtype(), &target);
        }

        // A u64 payload labelled with the u16 storage dtype still casts to the
        // extension dtype.
        let storage_scalar_u64 = Scalar::new(
            storage_dtype.clone(),
            ScalarValue(InnerScalarValue::Primitive(PValue::U64(1000))),
        );
        assert_eq!(
            storage_scalar_u64.cast(&ext_dtype).unwrap().dtype(),
            &ext_dtype
        );

        // 1000 does not fit in the u8 storage of this extension, so the cast fails.
        let apples_u8 = ExtDType::new(
            ExtID::new(Arc::from("apples")),
            Arc::from(DType::Primitive(PType::U8, Nullability::NonNullable)),
            None,
        );
        let u8_ext_dtype = DType::Extension(Arc::from(apples_u8));
        let result = storage_scalar.cast(&u8_ext_dtype);
        let has_expected_message = match &result {
            Err(err) => err
                .to_string()
                .contains("Can't cast u16 scalar 1000u16 to u8 (cause: Cannot read primitive value U16(1000) as u8"),
            Ok(_) => false,
        };
        assert!(has_expected_message, "{result:?}");
    }

    /// Defaults of a non-nullable struct are built field by field.
    #[test]
    fn default_value_for_complex_dtype() {
        let struct_dtype = DType::struct_(
            [
                ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
                (
                    "b",
                    DType::list(
                        DType::Primitive(PType::I8, Nullability::Nullable),
                        Nullability::NonNullable,
                    ),
                ),
                ("c", DType::Primitive(PType::I32, Nullability::Nullable)),
            ],
            Nullability::NonNullable,
        );

        let default_scalar = Scalar::default_value(struct_dtype.clone());
        assert_eq!(default_scalar.dtype(), &struct_dtype);

        let view = default_scalar.as_struct();

        // Non-nullable primitive -> zero.
        let zero = view.field("a").unwrap();
        assert_eq!(zero.as_primitive().pvalue().unwrap(), PValue::I32(0));

        // Non-nullable list -> empty list.
        let empty = view.field("b").unwrap();
        assert!(empty.as_list().is_empty());

        // Nullable primitive -> null.
        let missing = view.field("c").unwrap();
        assert!(missing.is_null());
    }
}