1use std::cmp::Ordering;
2use std::hash::Hash;
3use std::sync::Arc;
4
5pub use scalar_type::ScalarType;
6use vortex_buffer::{BufferString, ByteBuffer};
7use vortex_dtype::half::f16;
8use vortex_dtype::{DECIMAL128_MAX_PRECISION, DType, Nullability};
9#[cfg(feature = "arbitrary")]
10pub mod arbitrary;
11mod arrow;
12mod bigint;
13mod binary;
14mod bool;
15mod decimal;
16mod display;
17mod extension;
18mod list;
19mod null;
20mod primitive;
21mod proto;
22mod pvalue;
23mod scalar_type;
24mod scalar_value;
25mod struct_;
26mod utf8;
27
28pub use bigint::*;
29pub use binary::*;
30pub use bool::*;
31pub use decimal::*;
32pub use extension::*;
33pub use list::*;
34pub use primitive::*;
35pub use pvalue::*;
36pub use scalar_value::*;
37pub use struct_::*;
38pub use utf8::*;
39use vortex_error::{VortexExpect, VortexResult, vortex_bail};
40
41#[derive(Debug, Clone)]
50pub struct Scalar {
51 dtype: DType,
52 value: ScalarValue,
53}
54
55impl Scalar {
56 pub fn new(dtype: DType, value: ScalarValue) -> Self {
57 Self { dtype, value }
58 }
59
60 #[inline]
61 pub fn dtype(&self) -> &DType {
62 &self.dtype
63 }
64
65 #[inline]
66 pub fn value(&self) -> &ScalarValue {
67 &self.value
68 }
69
70 #[inline]
71 pub fn into_parts(self) -> (DType, ScalarValue) {
72 (self.dtype, self.value)
73 }
74
75 #[inline]
76 pub fn into_value(self) -> ScalarValue {
77 self.value
78 }
79
80 pub fn is_valid(&self) -> bool {
81 !self.value.is_null()
82 }
83
84 pub fn is_null(&self) -> bool {
85 self.value.is_null()
86 }
87
88 pub fn null(dtype: DType) -> Self {
89 assert!(
90 dtype.is_nullable(),
91 "Creating null scalar for non-nullable DType {dtype}"
92 );
93 Self {
94 dtype,
95 value: ScalarValue(InnerScalarValue::Null),
96 }
97 }
98
99 pub fn null_typed<T: ScalarType>() -> Self {
100 Self {
101 dtype: T::dtype().as_nullable(),
102 value: ScalarValue(InnerScalarValue::Null),
103 }
104 }
105
106 pub fn cast(&self, target: &DType) -> VortexResult<Self> {
107 if let DType::Extension(ext_dtype) = target {
108 let storage_scalar = self.cast_to_non_extension(ext_dtype.storage_dtype())?;
109 Ok(Scalar::extension(ext_dtype.clone(), storage_scalar))
110 } else {
111 self.cast_to_non_extension(target)
112 }
113 }
114
115 fn cast_to_non_extension(&self, target: &DType) -> VortexResult<Self> {
116 assert!(!matches!(target, DType::Extension(..)));
117 if self.is_null() {
118 if target.is_nullable() {
119 return Ok(Scalar::new(target.clone(), self.value.clone()));
120 } else {
121 vortex_bail!("Can't cast null scalar to non-nullable type {}", target)
122 }
123 }
124
125 if self.dtype().eq_ignore_nullability(target) {
126 return Ok(Scalar::new(target.clone(), self.value.clone()));
127 }
128
129 match &self.dtype {
130 DType::Null => unreachable!(), DType::Bool(_) => self.as_bool().cast(target),
132 DType::Primitive(..) => self.as_primitive().cast(target),
133 DType::Decimal(..) => todo!("(aduffy): implement DecimalScalar casting"),
134 DType::Utf8(_) => self.as_utf8().cast(target),
135 DType::Binary(_) => self.as_binary().cast(target),
136 DType::Struct(..) => self.as_struct().cast(target),
137 DType::List(..) => self.as_list().cast(target),
138 DType::Extension(..) => self.as_extension().cast(target),
139 }
140 }
141
142 pub fn into_nullable(self) -> Self {
143 Self {
144 dtype: self.dtype.as_nullable(),
145 value: self.value,
146 }
147 }
148
149 pub fn nbytes(&self) -> usize {
151 match self.dtype() {
152 DType::Null => 0,
153 DType::Bool(_) => 1,
154 DType::Primitive(ptype, _) => ptype.byte_width(),
155 DType::Decimal(dt, _) => {
156 if dt.precision() >= DECIMAL128_MAX_PRECISION {
157 size_of::<i128>()
158 } else {
159 size_of::<i256>()
160 }
161 }
162 DType::Binary(_) | DType::Utf8(_) => self
163 .value()
164 .as_buffer()
165 .ok()
166 .flatten()
167 .map_or(0, |s| s.len()),
168 DType::Struct(_dtype, _) => self
169 .as_struct()
170 .fields()
171 .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
172 .unwrap_or_default(),
173 DType::List(_dtype, _) => self
174 .as_list()
175 .elements()
176 .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
177 .unwrap_or_default(),
178 DType::Extension(_ext_dtype) => self.as_extension().storage().nbytes(),
179 }
180 }
181}
182
183impl Scalar {
184 pub fn as_bool(&self) -> BoolScalar {
185 BoolScalar::try_from(self).vortex_expect("Failed to convert scalar to bool")
186 }
187
188 pub fn as_bool_opt(&self) -> Option<BoolScalar> {
189 matches!(self.dtype, DType::Bool(..)).then(|| self.as_bool())
190 }
191
192 pub fn as_primitive(&self) -> PrimitiveScalar {
193 PrimitiveScalar::try_from(self).vortex_expect("Failed to convert scalar to primitive")
194 }
195
196 pub fn as_primitive_opt(&self) -> Option<PrimitiveScalar> {
197 matches!(self.dtype, DType::Primitive(..)).then(|| self.as_primitive())
198 }
199
200 pub fn as_decimal(&self) -> DecimalScalar {
201 DecimalScalar::try_from(self).vortex_expect("Failed to convert scalar to decimal")
202 }
203
204 pub fn as_decimal_opt(&self) -> Option<DecimalScalar> {
205 matches!(self.dtype, DType::Decimal(..)).then(|| self.as_decimal())
206 }
207
208 pub fn as_utf8(&self) -> Utf8Scalar {
209 Utf8Scalar::try_from(self).vortex_expect("Failed to convert scalar to utf8")
210 }
211
212 pub fn as_utf8_opt(&self) -> Option<Utf8Scalar> {
213 matches!(self.dtype, DType::Utf8(..)).then(|| self.as_utf8())
214 }
215
216 pub fn as_binary(&self) -> BinaryScalar {
217 BinaryScalar::try_from(self).vortex_expect("Failed to convert scalar to binary")
218 }
219
220 pub fn as_binary_opt(&self) -> Option<BinaryScalar> {
221 matches!(self.dtype, DType::Binary(..)).then(|| self.as_binary())
222 }
223
224 pub fn as_struct(&self) -> StructScalar {
225 StructScalar::try_from(self).vortex_expect("Failed to convert scalar to struct")
226 }
227
228 pub fn as_struct_opt(&self) -> Option<StructScalar> {
229 matches!(self.dtype, DType::Struct(..)).then(|| self.as_struct())
230 }
231
232 pub fn as_list(&self) -> ListScalar {
233 ListScalar::try_from(self).vortex_expect("Failed to convert scalar to list")
234 }
235
236 pub fn as_list_opt(&self) -> Option<ListScalar> {
237 matches!(self.dtype, DType::List(..)).then(|| self.as_list())
238 }
239
240 pub fn as_extension(&self) -> ExtScalar {
241 ExtScalar::try_from(self).vortex_expect("Failed to convert scalar to extension")
242 }
243
244 pub fn as_extension_opt(&self) -> Option<ExtScalar> {
245 matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension())
246 }
247}
248
249impl PartialEq for Scalar {
250 fn eq(&self, other: &Self) -> bool {
251 if !self.dtype.eq_ignore_nullability(&other.dtype) {
252 return false;
253 }
254
255 match self.dtype() {
256 DType::Null => true,
257 DType::Bool(_) => self.as_bool() == other.as_bool(),
258 DType::Primitive(..) => self.as_primitive() == other.as_primitive(),
259 DType::Decimal(..) => self.as_decimal() == other.as_decimal(),
260 DType::Utf8(_) => self.as_utf8() == other.as_utf8(),
261 DType::Binary(_) => self.as_binary() == other.as_binary(),
262 DType::Struct(..) => self.as_struct() == other.as_struct(),
263 DType::List(..) => self.as_list() == other.as_list(),
264 DType::Extension(_) => self.as_extension() == other.as_extension(),
265 }
266 }
267}
268
269impl Eq for Scalar {}
270
271impl PartialOrd for Scalar {
272 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
273 if !self.dtype().eq_ignore_nullability(other.dtype()) {
274 return None;
275 }
276 match self.dtype() {
277 DType::Null => Some(Ordering::Equal),
278 DType::Bool(_) => self.as_bool().partial_cmp(&other.as_bool()),
279 DType::Primitive(..) => self.as_primitive().partial_cmp(&other.as_primitive()),
280 DType::Decimal(..) => self.as_decimal().partial_cmp(&other.as_decimal()),
281 DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()),
282 DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()),
283 DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()),
284 DType::List(..) => self.as_list().partial_cmp(&other.as_list()),
285 DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()),
286 }
287 }
288}
289
290impl Hash for Scalar {
291 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
292 match self.dtype() {
293 DType::Null => self.dtype().hash(state), DType::Bool(_) => self.as_bool().hash(state),
295 DType::Primitive(..) => self.as_primitive().hash(state),
296 DType::Decimal(..) => self.as_decimal().hash(state),
297 DType::Utf8(_) => self.as_utf8().hash(state),
298 DType::Binary(_) => self.as_binary().hash(state),
299 DType::Struct(..) => self.as_struct().hash(state),
300 DType::List(..) => self.as_list().hash(state),
301 DType::Extension(_) => self.as_extension().hash(state),
302 }
303 }
304}
305
306impl AsRef<Self> for Scalar {
307 fn as_ref(&self) -> &Self {
308 self
309 }
310}
311
312impl<T> From<Option<T>> for Scalar
313where
314 T: ScalarType,
315 Scalar: From<T>,
316{
317 fn from(value: Option<T>) -> Self {
318 value
319 .map(Scalar::from)
320 .map(|x| x.into_nullable())
321 .unwrap_or_else(|| Scalar {
322 dtype: T::dtype().as_nullable(),
323 value: ScalarValue(InnerScalarValue::Null),
324 })
325 }
326}
327
328impl From<PrimitiveScalar<'_>> for Scalar {
329 fn from(pscalar: PrimitiveScalar<'_>) -> Self {
330 let dtype = pscalar.dtype().clone();
331 let value = pscalar
332 .pvalue()
333 .map(|pvalue| ScalarValue(InnerScalarValue::Primitive(pvalue)))
334 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
335 Self::new(dtype, value)
336 }
337}
338
339impl From<DecimalScalar<'_>> for Scalar {
340 fn from(decimal_scalar: DecimalScalar<'_>) -> Self {
341 let dtype = decimal_scalar.dtype().clone();
342 let value = decimal_scalar
343 .decimal_value()
344 .map(|value| ScalarValue(InnerScalarValue::Decimal(value)))
345 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
346 Self::new(dtype, value)
347 }
348}
349
350macro_rules! from_vec_for_scalar {
351 ($T:ty) => {
352 impl From<Vec<$T>> for Scalar {
353 fn from(value: Vec<$T>) -> Self {
354 Scalar {
355 dtype: DType::List(Arc::from(<$T>::dtype()), Nullability::NonNullable),
356 value: ScalarValue(InnerScalarValue::List(
357 value
358 .into_iter()
359 .map(Scalar::from)
360 .map(|s| s.into_value())
361 .collect::<Arc<[_]>>(),
362 )),
363 }
364 }
365 }
366 };
367}
368
369from_vec_for_scalar!(u16);
371from_vec_for_scalar!(u32);
372from_vec_for_scalar!(u64);
373from_vec_for_scalar!(usize); from_vec_for_scalar!(i8);
375from_vec_for_scalar!(i16);
376from_vec_for_scalar!(i32);
377from_vec_for_scalar!(i64);
378from_vec_for_scalar!(f16);
379from_vec_for_scalar!(f32);
380from_vec_for_scalar!(f64);
381from_vec_for_scalar!(String);
382from_vec_for_scalar!(BufferString);
383from_vec_for_scalar!(ByteBuffer);
384
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_dtype::{DType, ExtDType, ExtID, Nullability, PType};

    use crate::{InnerScalarValue, PValue, Scalar, ScalarValue};

    /// A null scalar of any nullable dtype (including extension dtypes) can be cast to
    /// any other nullable dtype, and the result carries the target dtype.
    #[rstest]
    fn null_can_cast_to_anything_nullable(
        #[values(
            DType::Null,
            DType::Bool(Nullability::Nullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("a"),
                Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
                None,
            ))),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("b"),
                Arc::from(DType::Utf8(Nullability::Nullable)),
                None,
            )))
        )]
        source_dtype: DType,
        #[values(
            DType::Null,
            DType::Bool(Nullability::Nullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("a"),
                Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
                None,
            ))),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("b"),
                Arc::from(DType::Utf8(Nullability::Nullable)),
                None,
            )))
        )]
        target_dtype: DType,
    ) {
        assert_eq!(
            Scalar::null(source_dtype)
                .cast(&target_dtype)
                .unwrap()
                .dtype(),
            &target_dtype
        );
    }

    /// List scalars cast element-wise; element and list nullability are checked.
    #[test]
    fn list_casts() {
        let list = Scalar::new(
            DType::List(
                Arc::from(DType::Primitive(PType::U16, Nullability::Nullable)),
                Nullability::Nullable,
            ),
            ScalarValue(InnerScalarValue::List(Arc::from([ScalarValue(
                InnerScalarValue::Primitive(PValue::U16(6)),
            )]))),
        );

        // Widen the element type.
        let target_u32 = DType::List(
            Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
            Nullability::Nullable,
        );
        assert_eq!(list.cast(&target_u32).unwrap().dtype(), &target_u32);

        // Non-nullable elements are fine when no element is null.
        let target_u32_nonnull = DType::List(
            Arc::from(DType::Primitive(PType::U32, Nullability::NonNullable)),
            Nullability::Nullable,
        );
        assert_eq!(
            list.cast(&target_u32_nonnull).unwrap().dtype(),
            &target_u32_nonnull
        );

        // A non-nullable list is fine when the list itself is not null.
        let target_nonnull = DType::List(
            Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
            Nullability::NonNullable,
        );
        assert_eq!(list.cast(&target_nonnull).unwrap().dtype(), &target_nonnull);

        // Narrowing succeeds when every element fits in the target type.
        let target_u8 = DType::List(
            Arc::from(DType::Primitive(PType::U8, Nullability::Nullable)),
            Nullability::Nullable,
        );
        assert_eq!(list.cast(&target_u8).unwrap().dtype(), &target_u8);

        let list_with_null = Scalar::new(
            DType::List(
                Arc::from(DType::Primitive(PType::U16, Nullability::Nullable)),
                Nullability::Nullable,
            ),
            ScalarValue(InnerScalarValue::List(Arc::from([
                ScalarValue(InnerScalarValue::Primitive(PValue::U16(6))),
                ScalarValue(InnerScalarValue::Null),
            ]))),
        );
        // Null elements survive a cast to nullable elements.
        let target_u8 = DType::List(
            Arc::from(DType::Primitive(PType::U8, Nullability::Nullable)),
            Nullability::Nullable,
        );
        assert_eq!(list_with_null.cast(&target_u8).unwrap().dtype(), &target_u8);

        // Null elements cannot be cast to non-nullable elements.
        let target_u32_nonnull = DType::List(
            Arc::from(DType::Primitive(PType::U32, Nullability::NonNullable)),
            Nullability::Nullable,
        );
        assert!(list_with_null.cast(&target_u32_nonnull).is_err());
    }

    /// Casting between an extension type, its storage type, and compatible or
    /// incompatible storage types.
    #[test]
    fn cast_to_from_extension_types() {
        let apples = ExtDType::new(
            ExtID::new(Arc::from("apples")),
            Arc::from(DType::Primitive(PType::U16, Nullability::NonNullable)),
            None,
        );
        let ext_dtype = DType::Extension(Arc::from(apples.clone()));
        // The extension's storage is u16, so give the extension scalar a u16 value.
        // A mismatched value (e.g. a bool) would only pass because casts
        // short-circuit on the dtype.
        let ext_scalar = Scalar::new(
            ext_dtype.clone(),
            ScalarValue(InnerScalarValue::Primitive(PValue::U16(1000))),
        );
        let storage_scalar = Scalar::new(
            DType::clone(apples.storage_dtype()),
            ScalarValue(InnerScalarValue::Primitive(PValue::U16(1000))),
        );

        // Extension to itself.
        let expected_dtype = &ext_dtype;
        let actual = ext_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // Extension to its nullable self.
        let expected_dtype = &ext_dtype.as_nullable();
        let actual = ext_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // Extension to its storage type.
        let expected_dtype = apples.storage_dtype();
        let actual = ext_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // Extension to its nullable storage type.
        let expected_dtype = &apples.storage_dtype().as_nullable();
        let actual = ext_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // Storage type to extension.
        let expected_dtype = &ext_dtype;
        let actual = storage_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // Storage type to nullable extension.
        let expected_dtype = &ext_dtype.as_nullable();
        let actual = storage_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // Compatible storage value to extension.
        let storage_scalar_u64 = Scalar::new(
            DType::clone(apples.storage_dtype()),
            ScalarValue(InnerScalarValue::Primitive(PValue::U64(1000))),
        );
        let expected_dtype = &ext_dtype;
        let actual = storage_scalar_u64.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // Incompatible storage type: 1000 does not fit in a u8.
        let apples_u8 = ExtDType::new(
            ExtID::new(Arc::from("apples")),
            Arc::from(DType::Primitive(PType::U8, Nullability::NonNullable)),
            None,
        );
        let expected_dtype = &DType::Extension(Arc::from(apples_u8));
        let result = storage_scalar.cast(expected_dtype);
        assert!(
            result.as_ref().is_err_and(|err| {
                err
                    .to_string()
                    .contains("Can't cast u16 scalar 1000u16 to u8 (cause: Cannot read primitive value U16(1000) as u8")
            }),
            "{result:?}"
        );
    }
}