1use std::cmp::Ordering;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use vortex_buffer::Buffer;
9use vortex_dtype::{DECIMAL128_MAX_PRECISION, DType, Nullability};
10use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
11
12use super::*;
13
14#[derive(Debug, Clone)]
23pub struct Scalar {
24 dtype: DType,
26
27 value: ScalarValue,
32}
33
34impl Scalar {
35 pub fn new(dtype: DType, value: ScalarValue) -> Self {
37 if !dtype.is_nullable() {
38 assert!(
39 !value.is_null(),
40 "Tried to construct a null scalar when the `DType` is non-nullable: {dtype}",
41 );
42 }
43
44 Self { dtype, value }
45 }
46
47 #[inline]
49 pub fn dtype(&self) -> &DType {
50 &self.dtype
51 }
52
53 #[inline]
55 pub fn value(&self) -> &ScalarValue {
56 &self.value
57 }
58
59 #[inline]
61 pub fn into_parts(self) -> (DType, ScalarValue) {
62 (self.dtype, self.value)
63 }
64
65 #[inline]
67 pub fn into_dtype(self) -> DType {
68 self.dtype
69 }
70
71 #[inline]
73 pub fn into_value(self) -> ScalarValue {
74 self.value
75 }
76
77 pub fn is_valid(&self) -> bool {
79 !self.value.is_null()
80 }
81
82 pub fn is_null(&self) -> bool {
84 self.value.is_null()
85 }
86
87 pub fn null(dtype: DType) -> Self {
93 assert!(
94 dtype.is_nullable(),
95 "Tried to construct a null scalar when the `DType` is non-nullable: {dtype}"
96 );
97
98 Self {
99 dtype,
100 value: ScalarValue(InnerScalarValue::Null),
101 }
102 }
103
104 pub fn null_typed<T: ScalarType>() -> Self {
108 Self {
109 dtype: T::dtype().as_nullable(),
110 value: ScalarValue(InnerScalarValue::Null),
111 }
112 }
113
114 pub fn cast(&self, target: &DType) -> VortexResult<Self> {
119 if let DType::Extension(ext_dtype) = target {
120 let storage_scalar = self.cast_to_non_extension(ext_dtype.storage_dtype())?;
121 Ok(Scalar::extension(ext_dtype.clone(), storage_scalar))
122 } else {
123 self.cast_to_non_extension(target)
124 }
125 }
126
127 fn cast_to_non_extension(&self, target: &DType) -> VortexResult<Self> {
128 assert!(!matches!(target, DType::Extension(..)));
129
130 if self.is_null() {
131 if target.is_nullable() {
132 return Ok(Scalar::new(target.clone(), self.value.clone()));
133 }
134
135 vortex_bail!("Cannot cast null to {target}: target type is non-nullable")
136 }
137
138 match &self.dtype {
139 DType::Null => unreachable!(), DType::Bool(_) => self.as_bool().cast(target),
141 DType::Primitive(..) => self.as_primitive().cast(target),
142 DType::Decimal(..) => self.as_decimal().cast(target),
143 DType::Utf8(_) => self.as_utf8().cast(target),
144 DType::Binary(_) => self.as_binary().cast(target),
145 DType::Struct(..) => self.as_struct().cast(target),
146 DType::List(..) | DType::FixedSizeList(..) => self.as_list().cast(target),
147 DType::Extension(..) => self.as_extension().cast(target),
148 }
149 }
150
151 pub fn into_nullable(self) -> Self {
153 Self {
154 dtype: self.dtype.as_nullable(),
155 value: self.value,
156 }
157 }
158
159 pub fn nbytes(&self) -> usize {
161 match self.dtype() {
162 DType::Null => 0,
163 DType::Bool(_) => 1,
164 DType::Primitive(ptype, _) => ptype.byte_width(),
165 DType::Decimal(dt, _) => {
166 if dt.precision() <= DECIMAL128_MAX_PRECISION {
167 size_of::<i128>()
168 } else {
169 size_of::<i256>()
170 }
171 }
172 DType::Binary(_) | DType::Utf8(_) => self
173 .value()
174 .as_buffer()
175 .ok()
176 .flatten()
177 .map_or(0, |s| s.len()),
178 DType::Struct(_dtype, _) => self
179 .as_struct()
180 .fields()
181 .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
182 .unwrap_or_default(),
183 DType::List(..) | DType::FixedSizeList(..) => self
184 .as_list()
185 .elements()
186 .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
187 .unwrap_or_default(),
188 DType::Extension(_ext_dtype) => self.as_extension().storage().nbytes(),
189 }
190 }
191
192 pub fn default_value(dtype: DType) -> Self {
214 if dtype.is_nullable() {
215 return Self::null(dtype);
216 }
217
218 match dtype {
219 DType::Null => Self::null(dtype),
220 DType::Bool(nullability) => Self::bool(false, nullability),
221 DType::Primitive(pt, nullability) => {
222 Self::primitive_value(PValue::zero(pt), pt, nullability)
223 }
224 DType::Decimal(dt, nullability) => {
225 Self::decimal(DecimalValue::from(0), dt, nullability)
226 }
227 DType::Utf8(nullability) => Self::utf8("", nullability),
228 DType::Binary(nullability) => Self::binary(Buffer::empty(), nullability),
229 DType::List(edt, nullability) => Self::list(edt, vec![], nullability),
230 DType::FixedSizeList(edt, size, nullability) => {
231 let elements = (0..size)
232 .map(|_| Scalar::default_value(edt.as_ref().clone()))
233 .collect();
234 Self::list(edt, elements, nullability)
235 }
236 DType::Struct(sf, nullability) => {
237 let fields: Vec<_> = sf.fields().map(Scalar::default_value).collect();
238 Self::struct_(DType::Struct(sf, nullability), fields)
239 }
240 DType::Extension(dt) => {
241 let scalar = Self::default_value(dt.storage_dtype().clone());
242 Self::extension(dt, scalar)
243 }
244 }
245 }
246}
247
248impl Scalar {
250 pub fn as_bool(&self) -> BoolScalar<'_> {
256 BoolScalar::try_from(self).vortex_expect("Failed to convert scalar to bool")
257 }
258
259 pub fn as_bool_opt(&self) -> Option<BoolScalar<'_>> {
261 matches!(self.dtype, DType::Bool(..)).then(|| self.as_bool())
262 }
263
264 pub fn as_primitive(&self) -> PrimitiveScalar<'_> {
270 PrimitiveScalar::try_from(self).vortex_expect("Failed to convert scalar to primitive")
271 }
272
273 pub fn as_primitive_opt(&self) -> Option<PrimitiveScalar<'_>> {
275 matches!(self.dtype, DType::Primitive(..)).then(|| self.as_primitive())
276 }
277
278 pub fn as_decimal(&self) -> DecimalScalar<'_> {
284 DecimalScalar::try_from(self).vortex_expect("Failed to convert scalar to decimal")
285 }
286
287 pub fn as_decimal_opt(&self) -> Option<DecimalScalar<'_>> {
289 matches!(self.dtype, DType::Decimal(..)).then(|| self.as_decimal())
290 }
291
292 pub fn as_utf8(&self) -> Utf8Scalar<'_> {
298 Utf8Scalar::try_from(self).vortex_expect("Failed to convert scalar to utf8")
299 }
300
301 pub fn as_utf8_opt(&self) -> Option<Utf8Scalar<'_>> {
303 matches!(self.dtype, DType::Utf8(..)).then(|| self.as_utf8())
304 }
305
306 pub fn as_binary(&self) -> BinaryScalar<'_> {
312 BinaryScalar::try_from(self).vortex_expect("Failed to convert scalar to binary")
313 }
314
315 pub fn as_binary_opt(&self) -> Option<BinaryScalar<'_>> {
317 matches!(self.dtype, DType::Binary(..)).then(|| self.as_binary())
318 }
319
320 pub fn as_struct(&self) -> StructScalar<'_> {
326 StructScalar::try_from(self).vortex_expect("Failed to convert scalar to struct")
327 }
328
329 pub fn as_struct_opt(&self) -> Option<StructScalar<'_>> {
331 matches!(self.dtype, DType::Struct(..)).then(|| self.as_struct())
332 }
333
334 pub fn as_list(&self) -> ListScalar<'_> {
343 ListScalar::try_from(self).vortex_expect("Failed to convert scalar to list")
344 }
345
346 pub fn as_list_opt(&self) -> Option<ListScalar<'_>> {
351 matches!(self.dtype, DType::List(..) | DType::FixedSizeList(..)).then(|| self.as_list())
352 }
353
354 pub fn as_extension(&self) -> ExtScalar<'_> {
360 ExtScalar::try_from(self).vortex_expect("Failed to convert scalar to extension")
361 }
362
363 pub fn as_extension_opt(&self) -> Option<ExtScalar<'_>> {
365 matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension())
366 }
367}
368
369impl<T> From<Option<T>> for Scalar
372where
373 T: ScalarType,
374 Scalar: From<T>,
375{
376 fn from(value: Option<T>) -> Self {
378 value
379 .map(Scalar::from)
380 .map(|x| x.into_nullable())
381 .unwrap_or_else(|| Scalar {
382 dtype: T::dtype().as_nullable(),
383 value: ScalarValue(InnerScalarValue::Null),
384 })
385 }
386}
387
388impl<T> From<Vec<T>> for Scalar
389where
390 T: ScalarType,
391 Scalar: From<T>,
392{
393 fn from(vec: Vec<T>) -> Self {
395 Scalar {
396 dtype: DType::List(Arc::from(T::dtype()), Nullability::NonNullable),
397 value: ScalarValue::from(vec),
398 }
399 }
400}
401
402impl<T> TryFrom<Scalar> for Vec<T>
403where
404 T: for<'b> TryFrom<&'b Scalar, Error = VortexError>,
405{
406 type Error = VortexError;
407
408 fn try_from(value: Scalar) -> Result<Self, Self::Error> {
409 Vec::try_from(&value)
410 }
411}
412
413impl<'a, T> TryFrom<&'a Scalar> for Vec<T>
414where
415 T: for<'b> TryFrom<&'b Scalar, Error = VortexError>,
416{
417 type Error = VortexError;
418
419 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
420 ListScalar::try_from(value)?
421 .elements()
422 .ok_or_else(|| vortex_err!("Expected non-null list"))?
423 .into_iter()
424 .map(|e| T::try_from(&e))
425 .collect::<VortexResult<Vec<T>>>()
426 }
427}
428
429impl PartialEq for Scalar {
430 fn eq(&self, other: &Self) -> bool {
431 if !self.dtype.eq_ignore_nullability(&other.dtype) {
432 return false;
433 }
434
435 match self.dtype() {
436 DType::Null => true,
437 DType::Bool(_) => self.as_bool() == other.as_bool(),
438 DType::Primitive(..) => self.as_primitive() == other.as_primitive(),
439 DType::Decimal(..) => self.as_decimal() == other.as_decimal(),
440 DType::Utf8(_) => self.as_utf8() == other.as_utf8(),
441 DType::Binary(_) => self.as_binary() == other.as_binary(),
442 DType::Struct(..) => self.as_struct() == other.as_struct(),
443 DType::List(..) | DType::FixedSizeList(..) => self.as_list() == other.as_list(),
444 DType::Extension(_) => self.as_extension() == other.as_extension(),
445 }
446 }
447}
448
449impl Eq for Scalar {}
450
451impl PartialOrd for Scalar {
452 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
481 if !self.dtype().eq_ignore_nullability(other.dtype()) {
482 return None;
483 }
484 match self.dtype() {
485 DType::Null => Some(Ordering::Equal),
486 DType::Bool(_) => self.as_bool().partial_cmp(&other.as_bool()),
487 DType::Primitive(..) => self.as_primitive().partial_cmp(&other.as_primitive()),
488 DType::Decimal(..) => self.as_decimal().partial_cmp(&other.as_decimal()),
489 DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()),
490 DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()),
491 DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()),
492 DType::List(..) | DType::FixedSizeList(..) => {
493 self.as_list().partial_cmp(&other.as_list())
494 }
495 DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()),
496 }
497 }
498}
499
500impl Hash for Scalar {
501 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
502 match self.dtype() {
503 DType::Null => self.dtype().hash(state), DType::Bool(_) => self.as_bool().hash(state),
505 DType::Primitive(..) => self.as_primitive().hash(state),
506 DType::Decimal(..) => self.as_decimal().hash(state),
507 DType::Utf8(_) => self.as_utf8().hash(state),
508 DType::Binary(_) => self.as_binary().hash(state),
509 DType::Struct(..) => self.as_struct().hash(state),
510 DType::List(..) | DType::FixedSizeList(..) => self.as_list().hash(state),
511 DType::Extension(_) => self.as_extension().hash(state),
512 }
513 }
514}
515
516impl AsRef<Self> for Scalar {
517 fn as_ref(&self) -> &Self {
518 self
519 }
520}
521
522impl From<PrimitiveScalar<'_>> for Scalar {
523 fn from(pscalar: PrimitiveScalar<'_>) -> Self {
524 let dtype = pscalar.dtype().clone();
525 let value = pscalar
526 .pvalue()
527 .map(|pvalue| ScalarValue(InnerScalarValue::Primitive(pvalue)))
528 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
529 Self::new(dtype, value)
530 }
531}
532
533impl From<DecimalScalar<'_>> for Scalar {
534 fn from(decimal_scalar: DecimalScalar<'_>) -> Self {
535 let dtype = decimal_scalar.dtype().clone();
536 let value = decimal_scalar
537 .decimal_value()
538 .map(|value| ScalarValue(InnerScalarValue::Decimal(value)))
539 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
540 Self::new(dtype, value)
541 }
542}