1use std::cmp::Ordering;
2use std::hash::Hash;
3use std::sync::Arc;
4
5pub use scalar_type::ScalarType;
6use vortex_buffer::{BufferString, ByteBuffer};
7use vortex_dtype::half::f16;
8use vortex_dtype::{DType, Nullability};
9#[cfg(feature = "arbitrary")]
10pub mod arbitrary;
11mod arrow;
12mod binary;
13mod bool;
14mod datafusion;
15mod display;
16mod extension;
17mod list;
18mod null;
19mod primitive;
20mod pvalue;
21mod scalar_type;
22mod scalarvalue;
23#[cfg(feature = "serde")]
24mod serde;
25mod struct_;
26mod utf8;
27
28pub use binary::*;
29pub use bool::*;
30pub use extension::*;
31pub use list::*;
32pub use primitive::*;
33pub use pvalue::*;
34pub use scalarvalue::*;
35pub use struct_::*;
36pub use utf8::*;
37use vortex_error::{VortexExpect, VortexResult, vortex_bail};
38
39#[derive(Debug, Clone)]
48#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
49pub struct Scalar {
50 dtype: DType,
51 value: ScalarValue,
52}
53
54impl Scalar {
55 pub fn new(dtype: DType, value: ScalarValue) -> Self {
56 Self { dtype, value }
57 }
58
59 #[inline]
60 pub fn dtype(&self) -> &DType {
61 &self.dtype
62 }
63
64 #[inline]
65 pub fn value(&self) -> &ScalarValue {
66 &self.value
67 }
68
69 #[inline]
70 pub fn into_parts(self) -> (DType, ScalarValue) {
71 (self.dtype, self.value)
72 }
73
74 #[inline]
75 pub fn into_value(self) -> ScalarValue {
76 self.value
77 }
78
79 pub fn is_valid(&self) -> bool {
80 !self.value.is_null()
81 }
82
83 pub fn is_null(&self) -> bool {
84 self.value.is_null()
85 }
86
87 pub fn null(dtype: DType) -> Self {
88 assert!(
89 dtype.is_nullable(),
90 "Creating null scalar for non-nullable DType {}",
91 dtype
92 );
93 Self {
94 dtype,
95 value: ScalarValue(InnerScalarValue::Null),
96 }
97 }
98
99 pub fn null_typed<T: ScalarType>() -> Self {
100 Self {
101 dtype: T::dtype().as_nullable(),
102 value: ScalarValue(InnerScalarValue::Null),
103 }
104 }
105
106 pub fn cast(&self, target: &DType) -> VortexResult<Self> {
107 if let DType::Extension(ext_dtype) = target {
108 let storage_scalar = self.cast_to_non_extension(ext_dtype.storage_dtype())?;
109 Ok(Scalar::extension(ext_dtype.clone(), storage_scalar))
110 } else {
111 self.cast_to_non_extension(target)
112 }
113 }
114
115 fn cast_to_non_extension(&self, target: &DType) -> VortexResult<Self> {
116 assert!(!matches!(target, DType::Extension(..)));
117 if self.is_null() {
118 if target.is_nullable() {
119 return Ok(Scalar::new(target.clone(), self.value.clone()));
120 } else {
121 vortex_bail!("Can't cast null scalar to non-nullable type {}", target)
122 }
123 }
124
125 if self.dtype().eq_ignore_nullability(target) {
126 return Ok(Scalar::new(target.clone(), self.value.clone()));
127 }
128
129 match &self.dtype {
130 DType::Null => unreachable!(), DType::Bool(_) => self.as_bool().cast(target),
132 DType::Primitive(..) => self.as_primitive().cast(target),
133 DType::Utf8(_) => self.as_utf8().cast(target),
134 DType::Binary(_) => self.as_binary().cast(target),
135 DType::Struct(..) => self.as_struct().cast(target),
136 DType::List(..) => self.as_list().cast(target),
137 DType::Extension(..) => self.as_extension().cast(target),
138 }
139 }
140
141 pub fn into_nullable(self) -> Self {
142 Self {
143 dtype: self.dtype.as_nullable(),
144 value: self.value,
145 }
146 }
147
148 pub fn nbytes(&self) -> usize {
150 match self.dtype() {
151 DType::Null => 0,
152 DType::Bool(_) => 1,
153 DType::Primitive(ptype, _) => ptype.byte_width(),
154 DType::Binary(_) | DType::Utf8(_) => self
155 .value()
156 .as_buffer()
157 .ok()
158 .flatten()
159 .map_or(0, |s| s.len()),
160 DType::Struct(_dtype, _) => self
161 .as_struct()
162 .fields()
163 .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
164 .unwrap_or_default(),
165 DType::List(_dtype, _) => self
166 .as_list()
167 .elements()
168 .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
169 .unwrap_or_default(),
170 DType::Extension(_ext_dtype) => self.as_extension().storage().nbytes(),
171 }
172 }
173}
174
175impl Scalar {
176 pub fn as_bool(&self) -> BoolScalar {
177 BoolScalar::try_from(self).vortex_expect("Failed to convert scalar to bool")
178 }
179
180 pub fn as_bool_opt(&self) -> Option<BoolScalar> {
181 matches!(self.dtype, DType::Bool(..)).then(|| self.as_bool())
182 }
183
184 pub fn as_primitive(&self) -> PrimitiveScalar {
185 PrimitiveScalar::try_from(self).vortex_expect("Failed to convert scalar to primitive")
186 }
187
188 pub fn as_primitive_opt(&self) -> Option<PrimitiveScalar> {
189 matches!(self.dtype, DType::Primitive(..)).then(|| self.as_primitive())
190 }
191
192 pub fn as_utf8(&self) -> Utf8Scalar {
193 Utf8Scalar::try_from(self).vortex_expect("Failed to convert scalar to utf8")
194 }
195
196 pub fn as_utf8_opt(&self) -> Option<Utf8Scalar> {
197 matches!(self.dtype, DType::Utf8(..)).then(|| self.as_utf8())
198 }
199
200 pub fn as_binary(&self) -> BinaryScalar {
201 BinaryScalar::try_from(self).vortex_expect("Failed to convert scalar to binary")
202 }
203
204 pub fn as_binary_opt(&self) -> Option<BinaryScalar> {
205 matches!(self.dtype, DType::Binary(..)).then(|| self.as_binary())
206 }
207
208 pub fn as_struct(&self) -> StructScalar {
209 StructScalar::try_from(self).vortex_expect("Failed to convert scalar to struct")
210 }
211
212 pub fn as_struct_opt(&self) -> Option<StructScalar> {
213 matches!(self.dtype, DType::Struct(..)).then(|| self.as_struct())
214 }
215
216 pub fn as_list(&self) -> ListScalar {
217 ListScalar::try_from(self).vortex_expect("Failed to convert scalar to list")
218 }
219
220 pub fn as_list_opt(&self) -> Option<ListScalar> {
221 matches!(self.dtype, DType::List(..)).then(|| self.as_list())
222 }
223
224 pub fn as_extension(&self) -> ExtScalar {
225 ExtScalar::try_from(self).vortex_expect("Failed to convert scalar to extension")
226 }
227
228 pub fn as_extension_opt(&self) -> Option<ExtScalar> {
229 matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension())
230 }
231}
232
233impl PartialEq for Scalar {
234 fn eq(&self, other: &Self) -> bool {
235 if !self.dtype.eq_ignore_nullability(&other.dtype) {
236 return false;
237 }
238
239 match self.dtype() {
240 DType::Null => true,
241 DType::Bool(_) => self.as_bool() == other.as_bool(),
242 DType::Primitive(..) => self.as_primitive() == other.as_primitive(),
243 DType::Utf8(_) => self.as_utf8() == other.as_utf8(),
244 DType::Binary(_) => self.as_binary() == other.as_binary(),
245 DType::Struct(..) => self.as_struct() == other.as_struct(),
246 DType::List(..) => self.as_list() == other.as_list(),
247 DType::Extension(_) => self.as_extension() == other.as_extension(),
248 }
249 }
250}
251
252impl Eq for Scalar {}
253
254impl PartialOrd for Scalar {
255 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
256 if !self.dtype().eq_ignore_nullability(other.dtype()) {
257 return None;
258 }
259 match self.dtype() {
260 DType::Null => Some(Ordering::Equal),
261 DType::Bool(_) => self.as_bool().partial_cmp(&other.as_bool()),
262 DType::Primitive(..) => self.as_primitive().partial_cmp(&other.as_primitive()),
263 DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()),
264 DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()),
265 DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()),
266 DType::List(..) => self.as_list().partial_cmp(&other.as_list()),
267 DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()),
268 }
269 }
270}
271
272impl Hash for Scalar {
273 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
274 match self.dtype() {
275 DType::Null => self.dtype().hash(state), DType::Bool(_) => self.as_bool().hash(state),
277 DType::Primitive(..) => self.as_primitive().hash(state),
278 DType::Utf8(_) => self.as_utf8().hash(state),
279 DType::Binary(_) => self.as_binary().hash(state),
280 DType::Struct(..) => self.as_struct().hash(state),
281 DType::List(..) => self.as_list().hash(state),
282 DType::Extension(_) => self.as_extension().hash(state),
283 }
284 }
285}
286
287impl AsRef<Self> for Scalar {
288 fn as_ref(&self) -> &Self {
289 self
290 }
291}
292
293impl<T> From<Option<T>> for Scalar
294where
295 T: ScalarType,
296 Scalar: From<T>,
297{
298 fn from(value: Option<T>) -> Self {
299 value
300 .map(Scalar::from)
301 .map(|x| x.into_nullable())
302 .unwrap_or_else(|| Scalar {
303 dtype: T::dtype().as_nullable(),
304 value: ScalarValue(InnerScalarValue::Null),
305 })
306 }
307}
308
309impl From<PrimitiveScalar<'_>> for Scalar {
310 fn from(pscalar: PrimitiveScalar<'_>) -> Self {
311 let dtype = pscalar.dtype().clone();
312 let value = pscalar
313 .pvalue()
314 .map(|pvalue| ScalarValue(InnerScalarValue::Primitive(pvalue)))
315 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
316 Self::new(dtype, value)
317 }
318}
319
320macro_rules! from_vec_for_scalar {
321 ($T:ty) => {
322 impl From<Vec<$T>> for Scalar {
323 fn from(value: Vec<$T>) -> Self {
324 Scalar {
325 dtype: DType::List(Arc::from(<$T>::dtype()), Nullability::NonNullable),
326 value: ScalarValue(InnerScalarValue::List(
327 value
328 .into_iter()
329 .map(Scalar::from)
330 .map(|s| s.into_value())
331 .collect::<Arc<[_]>>(),
332 )),
333 }
334 }
335 }
336 };
337}
338
339from_vec_for_scalar!(u16);
341from_vec_for_scalar!(u32);
342from_vec_for_scalar!(u64);
343from_vec_for_scalar!(usize); from_vec_for_scalar!(i8);
345from_vec_for_scalar!(i16);
346from_vec_for_scalar!(i32);
347from_vec_for_scalar!(i64);
348from_vec_for_scalar!(f16);
349from_vec_for_scalar!(f32);
350from_vec_for_scalar!(f64);
351from_vec_for_scalar!(String);
352from_vec_for_scalar!(BufferString);
353from_vec_for_scalar!(ByteBuffer);
354
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_dtype::{DType, ExtDType, ExtID, Nullability, PType};

    use crate::{InnerScalarValue, PValue, Scalar, ScalarValue};

    /// Builds a list dtype with primitive elements.
    fn list_dtype(
        element: PType,
        element_nullability: Nullability,
        list_nullability: Nullability,
    ) -> DType {
        DType::List(
            Arc::from(DType::Primitive(element, element_nullability)),
            list_nullability,
        )
    }

    /// Builds the "apples" extension type over a non-nullable primitive storage dtype.
    fn apples_over(storage: PType) -> ExtDType {
        ExtDType::new(
            ExtID::new(Arc::from("apples")),
            Arc::from(DType::Primitive(storage, Nullability::NonNullable)),
            None,
        )
    }

    /// Asserts that casting `scalar` to `target` succeeds and yields exactly `target` as dtype.
    fn assert_casts_to(scalar: &Scalar, target: &DType) {
        assert_eq!(scalar.cast(target).unwrap().dtype(), target);
    }

    /// A null scalar of any of the listed dtypes can be cast to any of the listed dtypes.
    #[rstest]
    fn null_can_cast_to_anything_nullable(
        #[values(
            DType::Null,
            DType::Bool(Nullability::Nullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("a"),
                Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
                None,
            ))),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("b"),
                Arc::from(DType::Utf8(Nullability::Nullable)),
                None,
            )))
        )]
        source_dtype: DType,
        #[values(
            DType::Null,
            DType::Bool(Nullability::Nullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("a"),
                Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
                None,
            ))),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("b"),
                Arc::from(DType::Utf8(Nullability::Nullable)),
                None,
            )))
        )]
        target_dtype: DType,
    ) {
        let null_scalar = Scalar::null(source_dtype);
        assert_casts_to(&null_scalar, &target_dtype);
    }

    /// List scalars can be cast across element types and nullabilities, except that a list
    /// containing a null cannot be cast to a list with non-nullable elements.
    #[test]
    fn list_casts() {
        let list = Scalar::new(
            list_dtype(PType::U16, Nullability::Nullable, Nullability::Nullable),
            ScalarValue(InnerScalarValue::List(Arc::from([ScalarValue(
                InnerScalarValue::Primitive(PValue::U16(6)),
            )]))),
        );

        // Wider element type.
        let target_u32 = list_dtype(PType::U32, Nullability::Nullable, Nullability::Nullable);
        assert_casts_to(&list, &target_u32);

        // Non-nullable elements.
        let target_u32_nonnull =
            list_dtype(PType::U32, Nullability::NonNullable, Nullability::Nullable);
        assert_casts_to(&list, &target_u32_nonnull);

        // Non-nullable list.
        let target_nonnull =
            list_dtype(PType::U32, Nullability::Nullable, Nullability::NonNullable);
        assert_casts_to(&list, &target_nonnull);

        // Narrower element type.
        let target_u8 = list_dtype(PType::U8, Nullability::Nullable, Nullability::Nullable);
        assert_casts_to(&list, &target_u8);

        let list_with_null = Scalar::new(
            list_dtype(PType::U16, Nullability::Nullable, Nullability::Nullable),
            ScalarValue(InnerScalarValue::List(Arc::from([
                ScalarValue(InnerScalarValue::Primitive(PValue::U16(6))),
                ScalarValue(InnerScalarValue::Null),
            ]))),
        );

        // A null element survives a cast to nullable elements...
        let target_u8 = list_dtype(PType::U8, Nullability::Nullable, Nullability::Nullable);
        assert_casts_to(&list_with_null, &target_u8);

        // ...but not a cast to non-nullable elements.
        let target_u32_nonnull =
            list_dtype(PType::U32, Nullability::NonNullable, Nullability::Nullable);
        assert!(list_with_null.cast(&target_u32_nonnull).is_err());
    }

    /// Casts between an extension type and its storage type, in both directions.
    #[test]
    fn cast_to_from_extension_types() {
        let apples = apples_over(PType::U16);
        let ext_dtype = DType::Extension(Arc::from(apples.clone()));
        let ext_scalar = Scalar::new(ext_dtype.clone(), ScalarValue(InnerScalarValue::Bool(true)));
        let storage_scalar = Scalar::new(
            DType::clone(apples.storage_dtype()),
            ScalarValue(InnerScalarValue::Primitive(PValue::U16(1000))),
        );

        // Extension scalar to its own extension type, as-is and nullable.
        assert_casts_to(&ext_scalar, &ext_dtype);
        assert_casts_to(&ext_scalar, &ext_dtype.as_nullable());

        // Extension scalar to its storage type, as-is and nullable.
        assert_casts_to(&ext_scalar, apples.storage_dtype());
        assert_casts_to(&ext_scalar, &apples.storage_dtype().as_nullable());

        // Storage scalar to the extension type, as-is and nullable.
        assert_casts_to(&storage_scalar, &ext_dtype);
        assert_casts_to(&storage_scalar, &ext_dtype.as_nullable());

        // Storage scalar declared with the storage dtype but carrying a U64 value.
        let storage_scalar_u64 = Scalar::new(
            DType::clone(apples.storage_dtype()),
            ScalarValue(InnerScalarValue::Primitive(PValue::U64(1000))),
        );
        assert_casts_to(&storage_scalar_u64, &ext_dtype);

        // 1000 does not fit in the u8 storage of this extension type, so the cast must fail
        // with the expected message.
        let apples_u8_dtype = DType::Extension(Arc::from(apples_over(PType::U8)));
        let expected_message = "Can't cast u16 scalar 1000_u16 to u8 (cause: Cannot read primitive value U16(1000) as u8";
        let result = storage_scalar.cast(&apples_u8_dtype);
        let failed_as_expected = result
            .as_ref()
            .is_err_and(|err| err.to_string().contains(expected_message));
        assert!(failed_as_expected, "{:?}", result);
    }
}