1use std::cmp::Ordering;
2use std::hash::Hash;
3use std::sync::Arc;
4
5pub use scalar_type::ScalarType;
6use vortex_buffer::{BufferString, ByteBuffer};
7use vortex_dtype::half::f16;
8use vortex_dtype::{DECIMAL128_MAX_PRECISION, DType, Nullability};
9#[cfg(feature = "arbitrary")]
10pub mod arbitrary;
11mod arrow;
12mod bigint;
13mod binary;
14mod bool;
15mod datafusion;
16mod decimal;
17mod display;
18mod extension;
19mod list;
20mod null;
21mod primitive;
22mod proto;
23mod pvalue;
24mod scalar_type;
25mod scalar_value;
26mod struct_;
27mod utf8;
28
29pub use bigint::*;
30pub use binary::*;
31pub use bool::*;
32pub use decimal::*;
33pub use extension::*;
34pub use list::*;
35pub use primitive::*;
36pub use pvalue::*;
37pub use scalar_value::*;
38pub use struct_::*;
39pub use utf8::*;
40use vortex_error::{VortexExpect, VortexResult, vortex_bail};
41
42#[derive(Debug, Clone)]
51pub struct Scalar {
52 dtype: DType,
53 value: ScalarValue,
54}
55
56impl Scalar {
57 pub fn new(dtype: DType, value: ScalarValue) -> Self {
58 Self { dtype, value }
59 }
60
61 #[inline]
62 pub fn dtype(&self) -> &DType {
63 &self.dtype
64 }
65
66 #[inline]
67 pub fn value(&self) -> &ScalarValue {
68 &self.value
69 }
70
71 #[inline]
72 pub fn into_parts(self) -> (DType, ScalarValue) {
73 (self.dtype, self.value)
74 }
75
76 #[inline]
77 pub fn into_value(self) -> ScalarValue {
78 self.value
79 }
80
81 pub fn is_valid(&self) -> bool {
82 !self.value.is_null()
83 }
84
85 pub fn is_null(&self) -> bool {
86 self.value.is_null()
87 }
88
89 pub fn null(dtype: DType) -> Self {
90 assert!(
91 dtype.is_nullable(),
92 "Creating null scalar for non-nullable DType {dtype}"
93 );
94 Self {
95 dtype,
96 value: ScalarValue(InnerScalarValue::Null),
97 }
98 }
99
100 pub fn null_typed<T: ScalarType>() -> Self {
101 Self {
102 dtype: T::dtype().as_nullable(),
103 value: ScalarValue(InnerScalarValue::Null),
104 }
105 }
106
107 pub fn cast(&self, target: &DType) -> VortexResult<Self> {
108 if let DType::Extension(ext_dtype) = target {
109 let storage_scalar = self.cast_to_non_extension(ext_dtype.storage_dtype())?;
110 Ok(Scalar::extension(ext_dtype.clone(), storage_scalar))
111 } else {
112 self.cast_to_non_extension(target)
113 }
114 }
115
116 fn cast_to_non_extension(&self, target: &DType) -> VortexResult<Self> {
117 assert!(!matches!(target, DType::Extension(..)));
118 if self.is_null() {
119 if target.is_nullable() {
120 return Ok(Scalar::new(target.clone(), self.value.clone()));
121 } else {
122 vortex_bail!("Can't cast null scalar to non-nullable type {}", target)
123 }
124 }
125
126 if self.dtype().eq_ignore_nullability(target) {
127 return Ok(Scalar::new(target.clone(), self.value.clone()));
128 }
129
130 match &self.dtype {
131 DType::Null => unreachable!(), DType::Bool(_) => self.as_bool().cast(target),
133 DType::Primitive(..) => self.as_primitive().cast(target),
134 DType::Decimal(..) => todo!("(aduffy): implement DecimalScalar casting"),
135 DType::Utf8(_) => self.as_utf8().cast(target),
136 DType::Binary(_) => self.as_binary().cast(target),
137 DType::Struct(..) => self.as_struct().cast(target),
138 DType::List(..) => self.as_list().cast(target),
139 DType::Extension(..) => self.as_extension().cast(target),
140 }
141 }
142
143 pub fn into_nullable(self) -> Self {
144 Self {
145 dtype: self.dtype.as_nullable(),
146 value: self.value,
147 }
148 }
149
150 pub fn nbytes(&self) -> usize {
152 match self.dtype() {
153 DType::Null => 0,
154 DType::Bool(_) => 1,
155 DType::Primitive(ptype, _) => ptype.byte_width(),
156 DType::Decimal(dt, _) => {
157 if dt.precision() >= DECIMAL128_MAX_PRECISION {
158 size_of::<i128>()
159 } else {
160 size_of::<i256>()
161 }
162 }
163 DType::Binary(_) | DType::Utf8(_) => self
164 .value()
165 .as_buffer()
166 .ok()
167 .flatten()
168 .map_or(0, |s| s.len()),
169 DType::Struct(_dtype, _) => self
170 .as_struct()
171 .fields()
172 .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
173 .unwrap_or_default(),
174 DType::List(_dtype, _) => self
175 .as_list()
176 .elements()
177 .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
178 .unwrap_or_default(),
179 DType::Extension(_ext_dtype) => self.as_extension().storage().nbytes(),
180 }
181 }
182}
183
184impl Scalar {
185 pub fn as_bool(&self) -> BoolScalar {
186 BoolScalar::try_from(self).vortex_expect("Failed to convert scalar to bool")
187 }
188
189 pub fn as_bool_opt(&self) -> Option<BoolScalar> {
190 matches!(self.dtype, DType::Bool(..)).then(|| self.as_bool())
191 }
192
193 pub fn as_primitive(&self) -> PrimitiveScalar {
194 PrimitiveScalar::try_from(self).vortex_expect("Failed to convert scalar to primitive")
195 }
196
197 pub fn as_primitive_opt(&self) -> Option<PrimitiveScalar> {
198 matches!(self.dtype, DType::Primitive(..)).then(|| self.as_primitive())
199 }
200
201 pub fn as_decimal(&self) -> DecimalScalar {
202 DecimalScalar::try_from(self).vortex_expect("Failed to convert scalar to decimal")
203 }
204
205 pub fn as_decimal_opt(&self) -> Option<DecimalScalar> {
206 matches!(self.dtype, DType::Decimal(..)).then(|| self.as_decimal())
207 }
208
209 pub fn as_utf8(&self) -> Utf8Scalar {
210 Utf8Scalar::try_from(self).vortex_expect("Failed to convert scalar to utf8")
211 }
212
213 pub fn as_utf8_opt(&self) -> Option<Utf8Scalar> {
214 matches!(self.dtype, DType::Utf8(..)).then(|| self.as_utf8())
215 }
216
217 pub fn as_binary(&self) -> BinaryScalar {
218 BinaryScalar::try_from(self).vortex_expect("Failed to convert scalar to binary")
219 }
220
221 pub fn as_binary_opt(&self) -> Option<BinaryScalar> {
222 matches!(self.dtype, DType::Binary(..)).then(|| self.as_binary())
223 }
224
225 pub fn as_struct(&self) -> StructScalar {
226 StructScalar::try_from(self).vortex_expect("Failed to convert scalar to struct")
227 }
228
229 pub fn as_struct_opt(&self) -> Option<StructScalar> {
230 matches!(self.dtype, DType::Struct(..)).then(|| self.as_struct())
231 }
232
233 pub fn as_list(&self) -> ListScalar {
234 ListScalar::try_from(self).vortex_expect("Failed to convert scalar to list")
235 }
236
237 pub fn as_list_opt(&self) -> Option<ListScalar> {
238 matches!(self.dtype, DType::List(..)).then(|| self.as_list())
239 }
240
241 pub fn as_extension(&self) -> ExtScalar {
242 ExtScalar::try_from(self).vortex_expect("Failed to convert scalar to extension")
243 }
244
245 pub fn as_extension_opt(&self) -> Option<ExtScalar> {
246 matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension())
247 }
248}
249
250impl PartialEq for Scalar {
251 fn eq(&self, other: &Self) -> bool {
252 if !self.dtype.eq_ignore_nullability(&other.dtype) {
253 return false;
254 }
255
256 match self.dtype() {
257 DType::Null => true,
258 DType::Bool(_) => self.as_bool() == other.as_bool(),
259 DType::Primitive(..) => self.as_primitive() == other.as_primitive(),
260 DType::Decimal(..) => self.as_decimal() == other.as_decimal(),
261 DType::Utf8(_) => self.as_utf8() == other.as_utf8(),
262 DType::Binary(_) => self.as_binary() == other.as_binary(),
263 DType::Struct(..) => self.as_struct() == other.as_struct(),
264 DType::List(..) => self.as_list() == other.as_list(),
265 DType::Extension(_) => self.as_extension() == other.as_extension(),
266 }
267 }
268}
269
270impl Eq for Scalar {}
271
272impl PartialOrd for Scalar {
273 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
274 if !self.dtype().eq_ignore_nullability(other.dtype()) {
275 return None;
276 }
277 match self.dtype() {
278 DType::Null => Some(Ordering::Equal),
279 DType::Bool(_) => self.as_bool().partial_cmp(&other.as_bool()),
280 DType::Primitive(..) => self.as_primitive().partial_cmp(&other.as_primitive()),
281 DType::Decimal(..) => self.as_decimal().partial_cmp(&other.as_decimal()),
282 DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()),
283 DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()),
284 DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()),
285 DType::List(..) => self.as_list().partial_cmp(&other.as_list()),
286 DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()),
287 }
288 }
289}
290
291impl Hash for Scalar {
292 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
293 match self.dtype() {
294 DType::Null => self.dtype().hash(state), DType::Bool(_) => self.as_bool().hash(state),
296 DType::Primitive(..) => self.as_primitive().hash(state),
297 DType::Decimal(..) => self.as_decimal().hash(state),
298 DType::Utf8(_) => self.as_utf8().hash(state),
299 DType::Binary(_) => self.as_binary().hash(state),
300 DType::Struct(..) => self.as_struct().hash(state),
301 DType::List(..) => self.as_list().hash(state),
302 DType::Extension(_) => self.as_extension().hash(state),
303 }
304 }
305}
306
307impl AsRef<Self> for Scalar {
308 fn as_ref(&self) -> &Self {
309 self
310 }
311}
312
313impl<T> From<Option<T>> for Scalar
314where
315 T: ScalarType,
316 Scalar: From<T>,
317{
318 fn from(value: Option<T>) -> Self {
319 value
320 .map(Scalar::from)
321 .map(|x| x.into_nullable())
322 .unwrap_or_else(|| Scalar {
323 dtype: T::dtype().as_nullable(),
324 value: ScalarValue(InnerScalarValue::Null),
325 })
326 }
327}
328
329impl From<PrimitiveScalar<'_>> for Scalar {
330 fn from(pscalar: PrimitiveScalar<'_>) -> Self {
331 let dtype = pscalar.dtype().clone();
332 let value = pscalar
333 .pvalue()
334 .map(|pvalue| ScalarValue(InnerScalarValue::Primitive(pvalue)))
335 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
336 Self::new(dtype, value)
337 }
338}
339
340impl From<DecimalScalar<'_>> for Scalar {
341 fn from(decimal_scalar: DecimalScalar<'_>) -> Self {
342 let dtype = decimal_scalar.dtype().clone();
343 let value = decimal_scalar
344 .decimal_value()
345 .map(|value| ScalarValue(InnerScalarValue::Decimal(value)))
346 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
347 Self::new(dtype, value)
348 }
349}
350
351macro_rules! from_vec_for_scalar {
352 ($T:ty) => {
353 impl From<Vec<$T>> for Scalar {
354 fn from(value: Vec<$T>) -> Self {
355 Scalar {
356 dtype: DType::List(Arc::from(<$T>::dtype()), Nullability::NonNullable),
357 value: ScalarValue(InnerScalarValue::List(
358 value
359 .into_iter()
360 .map(Scalar::from)
361 .map(|s| s.into_value())
362 .collect::<Arc<[_]>>(),
363 )),
364 }
365 }
366 }
367 };
368}
369
370from_vec_for_scalar!(u16);
372from_vec_for_scalar!(u32);
373from_vec_for_scalar!(u64);
374from_vec_for_scalar!(usize); from_vec_for_scalar!(i8);
376from_vec_for_scalar!(i16);
377from_vec_for_scalar!(i32);
378from_vec_for_scalar!(i64);
379from_vec_for_scalar!(f16);
380from_vec_for_scalar!(f32);
381from_vec_for_scalar!(f64);
382from_vec_for_scalar!(String);
383from_vec_for_scalar!(BufferString);
384from_vec_for_scalar!(ByteBuffer);
385
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_dtype::{DType, ExtDType, ExtID, Nullability, PType};

    use crate::{InnerScalarValue, PValue, Scalar, ScalarValue};

    /// Builds a list dtype over a primitive element type.
    fn list_of(element: PType, element_nullability: Nullability, list_nullability: Nullability) -> DType {
        DType::List(
            Arc::from(DType::Primitive(element, element_nullability)),
            list_nullability,
        )
    }

    /// Asserts that casting `scalar` to `target` succeeds and yields exactly `target`.
    fn assert_cast_yields(scalar: &Scalar, target: &DType) {
        let casted = scalar.cast(target).unwrap();
        assert_eq!(casted.dtype(), target);
    }

    /// A null scalar of any nullable dtype must cast to any other nullable dtype.
    #[rstest]
    fn null_can_cast_to_anything_nullable(
        #[values(
            DType::Null,
            DType::Bool(Nullability::Nullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("a"),
                Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
                None,
            ))),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("b"),
                Arc::from(DType::Utf8(Nullability::Nullable)),
                None,
            )))
        )]
        source_dtype: DType,
        #[values(
            DType::Null,
            DType::Bool(Nullability::Nullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("a"),
                Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
                None,
            ))),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("b"),
                Arc::from(DType::Utf8(Nullability::Nullable)),
                None,
            )))
        )]
        target_dtype: DType,
    ) {
        let null_scalar = Scalar::null(source_dtype);
        assert_cast_yields(&null_scalar, &target_dtype);
    }

    #[test]
    fn list_casts() {
        use Nullability::{NonNullable, Nullable};

        let list = Scalar::new(
            list_of(PType::U16, Nullable, Nullable),
            ScalarValue(InnerScalarValue::List(Arc::from([ScalarValue(
                InnerScalarValue::Primitive(PValue::U16(6)),
            )]))),
        );

        // Widen the element type.
        assert_cast_yields(&list, &list_of(PType::U32, Nullable, Nullable));
        // Widen and make the elements non-nullable (the list holds no nulls).
        assert_cast_yields(&list, &list_of(PType::U32, NonNullable, Nullable));
        // Widen and make the list itself non-nullable.
        assert_cast_yields(&list, &list_of(PType::U32, Nullable, NonNullable));
        // Narrow the element type; the value 6 fits in a u8.
        assert_cast_yields(&list, &list_of(PType::U8, Nullable, Nullable));

        let list_with_null = Scalar::new(
            list_of(PType::U16, Nullable, Nullable),
            ScalarValue(InnerScalarValue::List(Arc::from([
                ScalarValue(InnerScalarValue::Primitive(PValue::U16(6))),
                ScalarValue(InnerScalarValue::Null),
            ]))),
        );

        // Nullable elements keep the null entry.
        assert_cast_yields(&list_with_null, &list_of(PType::U8, Nullable, Nullable));

        // A null entry cannot be cast to a non-nullable element type.
        let non_nullable_elements = list_of(PType::U32, NonNullable, Nullable);
        assert!(list_with_null.cast(&non_nullable_elements).is_err());
    }

    #[test]
    fn cast_to_from_extension_types() {
        let apples = ExtDType::new(
            ExtID::new(Arc::from("apples")),
            Arc::from(DType::Primitive(PType::U16, Nullability::NonNullable)),
            None,
        );
        let ext_dtype = DType::Extension(Arc::from(apples.clone()));
        let ext_scalar = Scalar::new(ext_dtype.clone(), ScalarValue(InnerScalarValue::Bool(true)));
        let storage_scalar = Scalar::new(
            DType::clone(apples.storage_dtype()),
            ScalarValue(InnerScalarValue::Primitive(PValue::U16(1000))),
        );

        // Extension scalar to its own dtype, with and without nullability.
        assert_cast_yields(&ext_scalar, &ext_dtype);
        assert_cast_yields(&ext_scalar, &ext_dtype.as_nullable());

        // Extension scalar to its storage dtype, with and without nullability.
        assert_cast_yields(&ext_scalar, apples.storage_dtype());
        assert_cast_yields(&ext_scalar, &apples.storage_dtype().as_nullable());

        // Storage scalar to the extension dtype, with and without nullability.
        assert_cast_yields(&storage_scalar, &ext_dtype);
        assert_cast_yields(&storage_scalar, &ext_dtype.as_nullable());

        // Storage scalar tagged with the u16 storage dtype but holding a `PValue::U64`.
        let storage_scalar_u64 = Scalar::new(
            DType::clone(apples.storage_dtype()),
            ScalarValue(InnerScalarValue::Primitive(PValue::U64(1000))),
        );
        assert_cast_yields(&storage_scalar_u64, &ext_dtype);

        // Extension over u8 storage: 1000 does not fit, so the cast must fail.
        let apples_u8 = ExtDType::new(
            ExtID::new(Arc::from("apples")),
            Arc::from(DType::Primitive(PType::U8, Nullability::NonNullable)),
            None,
        );
        let narrow_ext_dtype = DType::Extension(Arc::from(apples_u8));
        let result = storage_scalar.cast(&narrow_ext_dtype);
        let failed_as_expected = matches!(
            &result,
            Err(err) if err
                .to_string()
                .contains("Can't cast u16 scalar 1000u16 to u8 (cause: Cannot read primitive value U16(1000) as u8")
        );
        assert!(failed_as_expected, "{result:?}");
    }
}
571}