1use std::cmp::Ordering;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use vortex_buffer::Buffer;
9use vortex_dtype::{DType, NativeDType, NativeDecimalType, Nullability, i256};
10use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
11
12use super::*;
13
14#[derive(Debug, Clone)]
23pub struct Scalar {
24 dtype: DType,
26
27 value: ScalarValue,
32}
33
34impl Scalar {
35 pub fn new(dtype: DType, value: ScalarValue) -> Self {
37 if !dtype.is_nullable() {
38 assert!(
39 !value.is_null(),
40 "Tried to construct a null scalar when the `DType` is non-nullable: {dtype}",
41 );
42 }
43
44 Self { dtype, value }
45 }
46
47 #[inline]
49 pub fn dtype(&self) -> &DType {
50 &self.dtype
51 }
52
53 #[inline]
55 pub fn value(&self) -> &ScalarValue {
56 &self.value
57 }
58
59 #[inline]
61 pub fn into_parts(self) -> (DType, ScalarValue) {
62 (self.dtype, self.value)
63 }
64
65 #[inline]
67 pub fn into_dtype(self) -> DType {
68 self.dtype
69 }
70
71 #[inline]
73 pub fn into_value(self) -> ScalarValue {
74 self.value
75 }
76
77 pub fn is_valid(&self) -> bool {
79 !self.value.is_null()
80 }
81
82 pub fn is_null(&self) -> bool {
84 self.value.is_null()
85 }
86
87 pub fn null(dtype: DType) -> Self {
93 assert!(
94 dtype.is_nullable(),
95 "Tried to construct a null scalar when the `DType` is non-nullable: {dtype}"
96 );
97
98 Self {
99 dtype,
100 value: ScalarValue(InnerScalarValue::Null),
101 }
102 }
103
104 pub fn null_typed<T: NativeDType>() -> Self {
108 Self {
109 dtype: T::dtype().as_nullable(),
110 value: ScalarValue(InnerScalarValue::Null),
111 }
112 }
113
114 pub fn cast(&self, target: &DType) -> VortexResult<Self> {
119 if let DType::Extension(ext_dtype) = target {
120 let storage_scalar = self.cast_to_non_extension(ext_dtype.storage_dtype())?;
121 Ok(Scalar::extension(ext_dtype.clone(), storage_scalar))
122 } else {
123 self.cast_to_non_extension(target)
124 }
125 }
126
127 fn cast_to_non_extension(&self, target: &DType) -> VortexResult<Self> {
128 assert!(!matches!(target, DType::Extension(..)));
129
130 if self.is_null() {
131 if target.is_nullable() {
132 return Ok(Scalar::new(target.clone(), self.value.clone()));
133 }
134
135 vortex_bail!("Cannot cast null to {target}: target type is non-nullable")
136 }
137
138 match &self.dtype {
139 DType::Null => unreachable!(), DType::Bool(_) => self.as_bool().cast(target),
141 DType::Primitive(..) => self.as_primitive().cast(target),
142 DType::Decimal(..) => self.as_decimal().cast(target),
143 DType::Utf8(_) => self.as_utf8().cast(target),
144 DType::Binary(_) => self.as_binary().cast(target),
145 DType::Struct(..) => self.as_struct().cast(target),
146 DType::List(..) | DType::FixedSizeList(..) => self.as_list().cast(target),
147 DType::Extension(..) => self.as_extension().cast(target),
148 }
149 }
150
151 pub fn into_nullable(self) -> Self {
153 Self {
154 dtype: self.dtype.as_nullable(),
155 value: self.value,
156 }
157 }
158
159 pub fn nbytes(&self) -> usize {
161 match self.dtype() {
162 DType::Null => 0,
163 DType::Bool(_) => 1,
164 DType::Primitive(ptype, _) => ptype.byte_width(),
165 DType::Decimal(dt, _) => {
166 if dt.precision() <= i128::MAX_PRECISION {
167 size_of::<i128>()
168 } else {
169 size_of::<i256>()
170 }
171 }
172 DType::Binary(_) | DType::Utf8(_) => self
173 .value()
174 .as_buffer()
175 .ok()
176 .flatten()
177 .map_or(0, |s| s.len()),
178 DType::Struct(_dtype, _) => self
179 .as_struct()
180 .fields()
181 .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
182 .unwrap_or_default(),
183 DType::List(..) | DType::FixedSizeList(..) => self
184 .as_list()
185 .elements()
186 .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
187 .unwrap_or_default(),
188 DType::Extension(_ext_dtype) => self.as_extension().storage().nbytes(),
189 }
190 }
191
192 pub fn zero_value(dtype: DType) -> Self {
213 match dtype {
214 DType::Null => Self::null(dtype),
215 DType::Bool(nullability) => Self::bool(false, nullability),
216 DType::Primitive(pt, nullability) => {
217 Self::primitive_value(PValue::zero(pt), pt, nullability)
218 }
219 DType::Decimal(dt, nullability) => {
220 Self::decimal(DecimalValue::from(0i8), dt, nullability)
221 }
222 DType::Utf8(nullability) => Self::utf8("", nullability),
223 DType::Binary(nullability) => Self::binary(Buffer::empty(), nullability),
224 DType::List(edt, nullability) => Self::list(edt, vec![], nullability),
225 DType::FixedSizeList(edt, size, nullability) => {
226 let elements = (0..size)
227 .map(|_| Scalar::zero_value(edt.as_ref().clone()))
228 .collect();
229 Self::fixed_size_list(edt, elements, nullability)
230 }
231 DType::Struct(sf, nullability) => {
232 let fields: Vec<_> = sf.fields().map(Scalar::zero_value).collect();
233 Self::struct_(DType::Struct(sf, nullability), fields)
234 }
235 DType::Extension(dt) => {
236 let scalar = Self::zero_value(dt.storage_dtype().clone());
237 Self::extension(dt, scalar)
238 }
239 }
240 }
241
242 pub fn default_value(dtype: DType) -> Self {
264 if dtype.is_nullable() {
265 return Self::null(dtype);
266 }
267
268 match dtype {
269 DType::Null => Self::null(dtype),
270 DType::Bool(nullability) => Self::bool(false, nullability),
271 DType::Primitive(pt, nullability) => {
272 Self::primitive_value(PValue::zero(pt), pt, nullability)
273 }
274 DType::Decimal(dt, nullability) => {
275 Self::decimal(DecimalValue::from(0i8), dt, nullability)
276 }
277 DType::Utf8(nullability) => Self::utf8("", nullability),
278 DType::Binary(nullability) => Self::binary(Buffer::empty(), nullability),
279 DType::List(edt, nullability) => Self::list(edt, vec![], nullability),
280 DType::FixedSizeList(edt, size, nullability) => {
281 let elements = (0..size)
282 .map(|_| Scalar::default_value(edt.as_ref().clone()))
283 .collect();
284 Self::fixed_size_list(edt, elements, nullability)
285 }
286 DType::Struct(sf, nullability) => {
287 let fields: Vec<_> = sf.fields().map(Scalar::default_value).collect();
288 Self::struct_(DType::Struct(sf, nullability), fields)
289 }
290 DType::Extension(dt) => {
291 let scalar = Self::default_value(dt.storage_dtype().clone());
292 Self::extension(dt, scalar)
293 }
294 }
295 }
296}
297
298impl Scalar {
300 pub fn as_bool(&self) -> BoolScalar<'_> {
306 BoolScalar::try_from(self).vortex_expect("Failed to convert scalar to bool")
307 }
308
309 pub fn as_bool_opt(&self) -> Option<BoolScalar<'_>> {
311 matches!(self.dtype, DType::Bool(..)).then(|| self.as_bool())
312 }
313
314 pub fn as_primitive(&self) -> PrimitiveScalar<'_> {
320 PrimitiveScalar::try_from(self).vortex_expect("Failed to convert scalar to primitive")
321 }
322
323 pub fn as_primitive_opt(&self) -> Option<PrimitiveScalar<'_>> {
325 matches!(self.dtype, DType::Primitive(..)).then(|| self.as_primitive())
326 }
327
328 pub fn as_decimal(&self) -> DecimalScalar<'_> {
334 DecimalScalar::try_from(self).vortex_expect("Failed to convert scalar to decimal")
335 }
336
337 pub fn as_decimal_opt(&self) -> Option<DecimalScalar<'_>> {
339 matches!(self.dtype, DType::Decimal(..)).then(|| self.as_decimal())
340 }
341
342 pub fn as_utf8(&self) -> Utf8Scalar<'_> {
348 Utf8Scalar::try_from(self).vortex_expect("Failed to convert scalar to utf8")
349 }
350
351 pub fn as_utf8_opt(&self) -> Option<Utf8Scalar<'_>> {
353 matches!(self.dtype, DType::Utf8(..)).then(|| self.as_utf8())
354 }
355
356 pub fn as_binary(&self) -> BinaryScalar<'_> {
362 BinaryScalar::try_from(self).vortex_expect("Failed to convert scalar to binary")
363 }
364
365 pub fn as_binary_opt(&self) -> Option<BinaryScalar<'_>> {
367 matches!(self.dtype, DType::Binary(..)).then(|| self.as_binary())
368 }
369
370 pub fn as_struct(&self) -> StructScalar<'_> {
376 StructScalar::try_from(self).vortex_expect("Failed to convert scalar to struct")
377 }
378
379 pub fn as_struct_opt(&self) -> Option<StructScalar<'_>> {
381 matches!(self.dtype, DType::Struct(..)).then(|| self.as_struct())
382 }
383
384 pub fn as_list(&self) -> ListScalar<'_> {
393 ListScalar::try_from(self).vortex_expect("Failed to convert scalar to list")
394 }
395
396 pub fn as_list_opt(&self) -> Option<ListScalar<'_>> {
401 matches!(self.dtype, DType::List(..) | DType::FixedSizeList(..)).then(|| self.as_list())
402 }
403
404 pub fn as_extension(&self) -> ExtScalar<'_> {
410 ExtScalar::try_from(self).vortex_expect("Failed to convert scalar to extension")
411 }
412
413 pub fn as_extension_opt(&self) -> Option<ExtScalar<'_>> {
415 matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension())
416 }
417}
418
419impl<T> From<Option<T>> for Scalar
422where
423 T: NativeDType,
424 Scalar: From<T>,
425{
426 fn from(value: Option<T>) -> Self {
428 value
429 .map(Scalar::from)
430 .map(|x| x.into_nullable())
431 .unwrap_or_else(|| Scalar {
432 dtype: T::dtype().as_nullable(),
433 value: ScalarValue(InnerScalarValue::Null),
434 })
435 }
436}
437
438impl<T> From<Vec<T>> for Scalar
439where
440 T: NativeDType,
441 Scalar: From<T>,
442{
443 fn from(vec: Vec<T>) -> Self {
445 Scalar {
446 dtype: DType::List(Arc::from(T::dtype()), Nullability::NonNullable),
447 value: ScalarValue::from(vec),
448 }
449 }
450}
451
452impl<T> TryFrom<Scalar> for Vec<T>
453where
454 T: for<'b> TryFrom<&'b Scalar, Error = VortexError>,
455{
456 type Error = VortexError;
457
458 fn try_from(value: Scalar) -> Result<Self, Self::Error> {
459 Vec::try_from(&value)
460 }
461}
462
463impl<'a, T> TryFrom<&'a Scalar> for Vec<T>
464where
465 T: for<'b> TryFrom<&'b Scalar, Error = VortexError>,
466{
467 type Error = VortexError;
468
469 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
470 ListScalar::try_from(value)?
471 .elements()
472 .ok_or_else(|| vortex_err!("Expected non-null list"))?
473 .into_iter()
474 .map(|e| T::try_from(&e))
475 .collect::<VortexResult<Vec<T>>>()
476 }
477}
478
479impl PartialEq for Scalar {
480 fn eq(&self, other: &Self) -> bool {
481 if !self.dtype.eq_ignore_nullability(&other.dtype) {
482 return false;
483 }
484
485 match self.dtype() {
486 DType::Null => true,
487 DType::Bool(_) => self.as_bool() == other.as_bool(),
488 DType::Primitive(..) => self.as_primitive() == other.as_primitive(),
489 DType::Decimal(..) => self.as_decimal() == other.as_decimal(),
490 DType::Utf8(_) => self.as_utf8() == other.as_utf8(),
491 DType::Binary(_) => self.as_binary() == other.as_binary(),
492 DType::Struct(..) => self.as_struct() == other.as_struct(),
493 DType::List(..) | DType::FixedSizeList(..) => self.as_list() == other.as_list(),
494 DType::Extension(_) => self.as_extension() == other.as_extension(),
495 }
496 }
497}
498
// Marker impl: declares that the `PartialEq` impl for `Scalar` is a full equivalence relation.
// NOTE(review): this relies on every typed view's `==` being reflexive — confirm for float primitives.
impl Eq for Scalar {}
500
501impl PartialOrd for Scalar {
502 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
531 if !self.dtype().eq_ignore_nullability(other.dtype()) {
532 return None;
533 }
534 match self.dtype() {
535 DType::Null => Some(Ordering::Equal),
536 DType::Bool(_) => self.as_bool().partial_cmp(&other.as_bool()),
537 DType::Primitive(..) => self.as_primitive().partial_cmp(&other.as_primitive()),
538 DType::Decimal(..) => self.as_decimal().partial_cmp(&other.as_decimal()),
539 DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()),
540 DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()),
541 DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()),
542 DType::List(..) | DType::FixedSizeList(..) => {
543 self.as_list().partial_cmp(&other.as_list())
544 }
545 DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()),
546 }
547 }
548}
549
550impl Hash for Scalar {
551 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
552 match self.dtype() {
553 DType::Null => self.dtype().hash(state), DType::Bool(_) => self.as_bool().hash(state),
555 DType::Primitive(..) => self.as_primitive().hash(state),
556 DType::Decimal(..) => self.as_decimal().hash(state),
557 DType::Utf8(_) => self.as_utf8().hash(state),
558 DType::Binary(_) => self.as_binary().hash(state),
559 DType::Struct(..) => self.as_struct().hash(state),
560 DType::List(..) | DType::FixedSizeList(..) => self.as_list().hash(state),
561 DType::Extension(_) => self.as_extension().hash(state),
562 }
563 }
564}
565
/// Identity `AsRef`: a `Scalar` can be borrowed as itself, so code generic over
/// `AsRef<Scalar>` also accepts a plain `Scalar`.
impl AsRef<Self> for Scalar {
    fn as_ref(&self) -> &Self {
        self
    }
}
571
572impl From<PrimitiveScalar<'_>> for Scalar {
573 fn from(pscalar: PrimitiveScalar<'_>) -> Self {
574 let dtype = pscalar.dtype().clone();
575 let value = pscalar
576 .pvalue()
577 .map(|pvalue| ScalarValue(InnerScalarValue::Primitive(pvalue)))
578 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
579 Self::new(dtype, value)
580 }
581}
582
583impl From<DecimalScalar<'_>> for Scalar {
584 fn from(decimal_scalar: DecimalScalar<'_>) -> Self {
585 let dtype = decimal_scalar.dtype().clone();
586 let value = decimal_scalar
587 .decimal_value()
588 .map(|value| ScalarValue(InnerScalarValue::Decimal(value)))
589 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
590 Self::new(dtype, value)
591 }
592}