vortex_scalar/
scalar.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use vortex_buffer::Buffer;
9use vortex_dtype::{DType, NativeDType, NativeDecimalType, Nullability, i256};
10use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
11
12use super::*;
13
14/// A single logical item, composed of both a [`ScalarValue`] and a logical [`DType`].
15///
16/// A [`ScalarValue`] is opaque, and should be accessed via one of the type-specific scalar wrappers
17/// for example [`BoolScalar`], [`PrimitiveScalar`], etc.
18///
19/// Note that [`PartialOrd`] is implemented only for an exact match of the scalar's dtype,
20/// including nullability. When the DType does match, ordering is nulls first (lowest), then the
21/// natural ordering of the scalar value.
22#[derive(Debug, Clone)]
23pub struct Scalar {
24    /// The type of the scalar.
25    dtype: DType,
26
27    /// The value of the scalar.
28    ///
29    /// Invariant: If the `dtype` is non-nullable, then this value _cannot_ be equal to
30    /// [`ScalarValue::null()`](ScalarValue::null).
31    value: ScalarValue,
32}
33
34impl Scalar {
35    /// Creates a new scalar with the given data type and value.
36    pub fn new(dtype: DType, value: ScalarValue) -> Self {
37        if !dtype.is_nullable() {
38            assert!(
39                !value.is_null(),
40                "Tried to construct a null scalar when the `DType` is non-nullable: {dtype}",
41            );
42        }
43
44        Self { dtype, value }
45    }
46
47    /// Returns a reference to the scalar's data type.
48    #[inline]
49    pub fn dtype(&self) -> &DType {
50        &self.dtype
51    }
52
53    /// Returns a reference to the scalar's underlying value.
54    #[inline]
55    pub fn value(&self) -> &ScalarValue {
56        &self.value
57    }
58
59    /// Consumes the scalar and returns its data type and value as a tuple.
60    #[inline]
61    pub fn into_parts(self) -> (DType, ScalarValue) {
62        (self.dtype, self.value)
63    }
64
65    /// Consumes the scalar and returns its underlying [`DType`].
66    #[inline]
67    pub fn into_dtype(self) -> DType {
68        self.dtype
69    }
70
71    /// Consumes the scalar and returns its underlying [`ScalarValue`].
72    #[inline]
73    pub fn into_value(self) -> ScalarValue {
74        self.value
75    }
76
77    /// Returns true if the scalar is not null.
78    pub fn is_valid(&self) -> bool {
79        !self.value.is_null()
80    }
81
82    /// Returns true if the scalar is null.
83    pub fn is_null(&self) -> bool {
84        self.value.is_null()
85    }
86
87    /// Creates a null scalar with the given nullable data type.
88    ///
89    /// # Panics
90    ///
91    /// Panics if the data type is not nullable.
92    pub fn null(dtype: DType) -> Self {
93        assert!(
94            dtype.is_nullable(),
95            "Tried to construct a null scalar when the `DType` is non-nullable: {dtype}"
96        );
97
98        Self {
99            dtype,
100            value: ScalarValue(InnerScalarValue::Null),
101        }
102    }
103
104    /// Creates a null scalar for the given scalar type.
105    ///
106    /// The resulting scalar will have a nullable version of the type's data type.
107    pub fn null_typed<T: NativeDType>() -> Self {
108        Self {
109            dtype: T::dtype().as_nullable(),
110            value: ScalarValue(InnerScalarValue::Null),
111        }
112    }
113
114    /// Casts the scalar to the target data type.
115    ///
116    /// Returns an error if the cast is not supported or if the value cannot be represented
117    /// in the target type.
118    pub fn cast(&self, target: &DType) -> VortexResult<Self> {
119        if let DType::Extension(ext_dtype) = target {
120            let storage_scalar = self.cast_to_non_extension(ext_dtype.storage_dtype())?;
121            Ok(Scalar::extension(ext_dtype.clone(), storage_scalar))
122        } else {
123            self.cast_to_non_extension(target)
124        }
125    }
126
127    fn cast_to_non_extension(&self, target: &DType) -> VortexResult<Self> {
128        assert!(!matches!(target, DType::Extension(..)));
129
130        if self.is_null() {
131            if target.is_nullable() {
132                return Ok(Scalar::new(target.clone(), self.value.clone()));
133            }
134
135            vortex_bail!("Cannot cast null to {target}: target type is non-nullable")
136        }
137
138        match &self.dtype {
139            DType::Null => unreachable!(), // Handled by `if self.is_null()` case.
140            DType::Bool(_) => self.as_bool().cast(target),
141            DType::Primitive(..) => self.as_primitive().cast(target),
142            DType::Decimal(..) => self.as_decimal().cast(target),
143            DType::Utf8(_) => self.as_utf8().cast(target),
144            DType::Binary(_) => self.as_binary().cast(target),
145            DType::Struct(..) => self.as_struct().cast(target),
146            DType::List(..) | DType::FixedSizeList(..) => self.as_list().cast(target),
147            DType::Extension(..) => self.as_extension().cast(target),
148        }
149    }
150
151    /// Converts the scalar to have a nullable version of its data type.
152    pub fn into_nullable(self) -> Self {
153        Self {
154            dtype: self.dtype.as_nullable(),
155            value: self.value,
156        }
157    }
158
159    /// Returns the size of the scalar in bytes, uncompressed.
160    pub fn nbytes(&self) -> usize {
161        match self.dtype() {
162            DType::Null => 0,
163            DType::Bool(_) => 1,
164            DType::Primitive(ptype, _) => ptype.byte_width(),
165            DType::Decimal(dt, _) => {
166                if dt.precision() <= i128::MAX_PRECISION {
167                    size_of::<i128>()
168                } else {
169                    size_of::<i256>()
170                }
171            }
172            DType::Binary(_) | DType::Utf8(_) => self
173                .value()
174                .as_buffer()
175                .ok()
176                .flatten()
177                .map_or(0, |s| s.len()),
178            DType::Struct(_dtype, _) => self
179                .as_struct()
180                .fields()
181                .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
182                .unwrap_or_default(),
183            DType::List(..) | DType::FixedSizeList(..) => self
184                .as_list()
185                .elements()
186                .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
187                .unwrap_or_default(),
188            DType::Extension(_ext_dtype) => self.as_extension().storage().nbytes(),
189        }
190    }
191
192    /// Creates a "zero"-value scalar value for the given data type.
193    ///
194    /// For nullable types the zero value is the underlying `DType`'s zero value.
195    ///
196    /// # Zero Values
197    ///
198    /// Here is the list of zero values for each [`DType`] (when the [`DType`] is non-nullable):
199    /// - `Bool`: `false`
200    /// - `Primitive`: `0`
201    /// - `Decimal`: `0`
202    /// - `Utf8`: `""`
203    /// - `Binary`: An empty buffer
204    /// - `List`: An empty list
205    /// - `FixedSizeList`: A list (with correct size) of zero values, which is determined by the
206    ///   element [`DType`]
207    /// - `Struct`: A struct where each field has a zero value, which is determined by the field
208    ///   [`DType`]
209    /// - `Extension`: The zero value of the storage [`DType`]
210    ///
211    /// This is similar to `default_value` except in its handling of nullability.
212    pub fn zero_value(dtype: DType) -> Self {
213        match dtype {
214            DType::Null => Self::null(dtype),
215            DType::Bool(nullability) => Self::bool(false, nullability),
216            DType::Primitive(pt, nullability) => {
217                Self::primitive_value(PValue::zero(pt), pt, nullability)
218            }
219            DType::Decimal(dt, nullability) => {
220                Self::decimal(DecimalValue::from(0i8), dt, nullability)
221            }
222            DType::Utf8(nullability) => Self::utf8("", nullability),
223            DType::Binary(nullability) => Self::binary(Buffer::empty(), nullability),
224            DType::List(edt, nullability) => Self::list(edt, vec![], nullability),
225            DType::FixedSizeList(edt, size, nullability) => {
226                let elements = (0..size)
227                    .map(|_| Scalar::zero_value(edt.as_ref().clone()))
228                    .collect();
229                Self::fixed_size_list(edt, elements, nullability)
230            }
231            DType::Struct(sf, nullability) => {
232                let fields: Vec<_> = sf.fields().map(Scalar::zero_value).collect();
233                Self::struct_(DType::Struct(sf, nullability), fields)
234            }
235            DType::Extension(dt) => {
236                let scalar = Self::zero_value(dt.storage_dtype().clone());
237                Self::extension(dt, scalar)
238            }
239        }
240    }
241
242    /// Creates a "default" scalar value for the given data type.
243    ///
244    /// For nullable types, returns null. For non-nullable types, returns an appropriate zero/empty
245    /// value.
246    ///
247    /// # Default Values
248    ///
249    /// Here is the list of default values for each [`DType`] (when the [`DType`] is non-nullable):
250    ///
251    /// - `Null`: `null`
252    /// - `Bool`: `false`
253    /// - `Primitive`: `0`
254    /// - `Decimal`: `0`
255    /// - `Utf8`: `""`
256    /// - `Binary`: An empty buffer
257    /// - `List`: An empty list
258    /// - `FixedSizeList`: A list (with correct size) of default values, which is determined by the
259    ///   element [`DType`]
260    /// - `Struct`: A struct where each field has a default value, which is determined by the field
261    ///   [`DType`]
262    /// - `Extension`: The default value of the storage [`DType`]
263    pub fn default_value(dtype: DType) -> Self {
264        if dtype.is_nullable() {
265            return Self::null(dtype);
266        }
267
268        match dtype {
269            DType::Null => Self::null(dtype),
270            DType::Bool(nullability) => Self::bool(false, nullability),
271            DType::Primitive(pt, nullability) => {
272                Self::primitive_value(PValue::zero(pt), pt, nullability)
273            }
274            DType::Decimal(dt, nullability) => {
275                Self::decimal(DecimalValue::from(0i8), dt, nullability)
276            }
277            DType::Utf8(nullability) => Self::utf8("", nullability),
278            DType::Binary(nullability) => Self::binary(Buffer::empty(), nullability),
279            DType::List(edt, nullability) => Self::list(edt, vec![], nullability),
280            DType::FixedSizeList(edt, size, nullability) => {
281                let elements = (0..size)
282                    .map(|_| Scalar::default_value(edt.as_ref().clone()))
283                    .collect();
284                Self::fixed_size_list(edt, elements, nullability)
285            }
286            DType::Struct(sf, nullability) => {
287                let fields: Vec<_> = sf.fields().map(Scalar::default_value).collect();
288                Self::struct_(DType::Struct(sf, nullability), fields)
289            }
290            DType::Extension(dt) => {
291                let scalar = Self::default_value(dt.storage_dtype().clone());
292                Self::extension(dt, scalar)
293            }
294        }
295    }
296}
297
298/// This implementation block contains only `TryFrom` and `From` wrappers (`as_something`).
299impl Scalar {
300    /// Returns a view of the scalar as a boolean scalar.
301    ///
302    /// # Panics
303    ///
304    /// Panics if the scalar is not a boolean type.
305    pub fn as_bool(&self) -> BoolScalar<'_> {
306        BoolScalar::try_from(self).vortex_expect("Failed to convert scalar to bool")
307    }
308
309    /// Returns a view of the scalar as a boolean scalar if it has a boolean type.
310    pub fn as_bool_opt(&self) -> Option<BoolScalar<'_>> {
311        matches!(self.dtype, DType::Bool(..)).then(|| self.as_bool())
312    }
313
314    /// Returns a view of the scalar as a primitive scalar.
315    ///
316    /// # Panics
317    ///
318    /// Panics if the scalar is not a primitive type.
319    pub fn as_primitive(&self) -> PrimitiveScalar<'_> {
320        PrimitiveScalar::try_from(self).vortex_expect("Failed to convert scalar to primitive")
321    }
322
323    /// Returns a view of the scalar as a primitive scalar if it has a primitive type.
324    pub fn as_primitive_opt(&self) -> Option<PrimitiveScalar<'_>> {
325        matches!(self.dtype, DType::Primitive(..)).then(|| self.as_primitive())
326    }
327
328    /// Returns a view of the scalar as a decimal scalar.
329    ///
330    /// # Panics
331    ///
332    /// Panics if the scalar is not a decimal type.
333    pub fn as_decimal(&self) -> DecimalScalar<'_> {
334        DecimalScalar::try_from(self).vortex_expect("Failed to convert scalar to decimal")
335    }
336
337    /// Returns a view of the scalar as a decimal scalar if it has a decimal type.
338    pub fn as_decimal_opt(&self) -> Option<DecimalScalar<'_>> {
339        matches!(self.dtype, DType::Decimal(..)).then(|| self.as_decimal())
340    }
341
342    /// Returns a view of the scalar as a UTF-8 string scalar.
343    ///
344    /// # Panics
345    ///
346    /// Panics if the scalar is not a UTF-8 type.
347    pub fn as_utf8(&self) -> Utf8Scalar<'_> {
348        Utf8Scalar::try_from(self).vortex_expect("Failed to convert scalar to utf8")
349    }
350
351    /// Returns a view of the scalar as a UTF-8 string scalar if it has a UTF-8 type.
352    pub fn as_utf8_opt(&self) -> Option<Utf8Scalar<'_>> {
353        matches!(self.dtype, DType::Utf8(..)).then(|| self.as_utf8())
354    }
355
356    /// Returns a view of the scalar as a binary scalar.
357    ///
358    /// # Panics
359    ///
360    /// Panics if the scalar is not a binary type.
361    pub fn as_binary(&self) -> BinaryScalar<'_> {
362        BinaryScalar::try_from(self).vortex_expect("Failed to convert scalar to binary")
363    }
364
365    /// Returns a view of the scalar as a binary scalar if it has a binary type.
366    pub fn as_binary_opt(&self) -> Option<BinaryScalar<'_>> {
367        matches!(self.dtype, DType::Binary(..)).then(|| self.as_binary())
368    }
369
370    /// Returns a view of the scalar as a struct scalar.
371    ///
372    /// # Panics
373    ///
374    /// Panics if the scalar is not a struct type.
375    pub fn as_struct(&self) -> StructScalar<'_> {
376        StructScalar::try_from(self).vortex_expect("Failed to convert scalar to struct")
377    }
378
379    /// Returns a view of the scalar as a struct scalar if it has a struct type.
380    pub fn as_struct_opt(&self) -> Option<StructScalar<'_>> {
381        matches!(self.dtype, DType::Struct(..)).then(|| self.as_struct())
382    }
383
384    /// Returns a view of the scalar as a list scalar.
385    ///
386    /// Note that we use [`ListScalar`] to represent **both** [`DType::List`] and
387    /// [`DType::FixedSizeList`].
388    ///
389    /// # Panics
390    ///
391    /// Panics if the scalar is not a list type.
392    pub fn as_list(&self) -> ListScalar<'_> {
393        ListScalar::try_from(self).vortex_expect("Failed to convert scalar to list")
394    }
395
396    /// Returns a view of the scalar as a list scalar if it has a list type.
397    ///
398    /// Note that we use [`ListScalar`] to represent **both** [`DType::List`] and
399    /// [`DType::FixedSizeList`].
400    pub fn as_list_opt(&self) -> Option<ListScalar<'_>> {
401        matches!(self.dtype, DType::List(..) | DType::FixedSizeList(..)).then(|| self.as_list())
402    }
403
404    /// Returns a view of the scalar as an extension scalar.
405    ///
406    /// # Panics
407    ///
408    /// Panics if the scalar is not an extension type.
409    pub fn as_extension(&self) -> ExtScalar<'_> {
410        ExtScalar::try_from(self).vortex_expect("Failed to convert scalar to extension")
411    }
412
413    /// Returns a view of the scalar as an extension scalar if it has an extension type.
414    pub fn as_extension_opt(&self) -> Option<ExtScalar<'_>> {
415        matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension())
416    }
417}
418
419/// It is common to represent a nullable type `T` as an `Option<T>`, so we implement a blanket
420/// implementation for all `Option<T>` to simply be a nullable `T`.
421impl<T> From<Option<T>> for Scalar
422where
423    T: NativeDType,
424    Scalar: From<T>,
425{
426    /// A blanket implementation for all `Option<T>`.
427    fn from(value: Option<T>) -> Self {
428        value
429            .map(Scalar::from)
430            .map(|x| x.into_nullable())
431            .unwrap_or_else(|| Scalar {
432                dtype: T::dtype().as_nullable(),
433                value: ScalarValue(InnerScalarValue::Null),
434            })
435    }
436}
437
438impl<T> From<Vec<T>> for Scalar
439where
440    T: NativeDType,
441    Scalar: From<T>,
442{
443    /// Converts a vector into a `Scalar` (where the value is a `ListScalar`).
444    fn from(vec: Vec<T>) -> Self {
445        Scalar {
446            dtype: DType::List(Arc::from(T::dtype()), Nullability::NonNullable),
447            value: ScalarValue::from(vec),
448        }
449    }
450}
451
452impl<T> TryFrom<Scalar> for Vec<T>
453where
454    T: for<'b> TryFrom<&'b Scalar, Error = VortexError>,
455{
456    type Error = VortexError;
457
458    fn try_from(value: Scalar) -> Result<Self, Self::Error> {
459        Vec::try_from(&value)
460    }
461}
462
463impl<'a, T> TryFrom<&'a Scalar> for Vec<T>
464where
465    T: for<'b> TryFrom<&'b Scalar, Error = VortexError>,
466{
467    type Error = VortexError;
468
469    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
470        ListScalar::try_from(value)?
471            .elements()
472            .ok_or_else(|| vortex_err!("Expected non-null list"))?
473            .into_iter()
474            .map(|e| T::try_from(&e))
475            .collect::<VortexResult<Vec<T>>>()
476    }
477}
478
479impl PartialEq for Scalar {
480    fn eq(&self, other: &Self) -> bool {
481        if !self.dtype.eq_ignore_nullability(&other.dtype) {
482            return false;
483        }
484
485        match self.dtype() {
486            DType::Null => true,
487            DType::Bool(_) => self.as_bool() == other.as_bool(),
488            DType::Primitive(..) => self.as_primitive() == other.as_primitive(),
489            DType::Decimal(..) => self.as_decimal() == other.as_decimal(),
490            DType::Utf8(_) => self.as_utf8() == other.as_utf8(),
491            DType::Binary(_) => self.as_binary() == other.as_binary(),
492            DType::Struct(..) => self.as_struct() == other.as_struct(),
493            DType::List(..) | DType::FixedSizeList(..) => self.as_list() == other.as_list(),
494            DType::Extension(_) => self.as_extension() == other.as_extension(),
495        }
496    }
497}
498
499impl Eq for Scalar {}
500
501impl PartialOrd for Scalar {
502    /// Compares two scalar values for ordering.
503    ///
504    /// # Returns
505    /// - `Some(Ordering)` if both scalars have the same data type (ignoring nullability)
506    /// - `None` if the scalars have different data types
507    ///
508    /// # Ordering Rules
509    /// When types match, the ordering follows these rules:
510    /// - Null values are considered less than all non-null values
511    /// - Non-null values are compared according to their natural ordering
512    ///
513    /// # Examples
514    /// ```ignore
515    /// // Same types compare successfully
516    /// let a = Scalar::primitive(10i32, Nullability::NonNullable);
517    /// let b = Scalar::primitive(20i32, Nullability::NonNullable);
518    /// assert_eq!(a.partial_cmp(&b), Some(Ordering::Less));
519    ///
520    /// // Different types return None
521    /// let int_scalar = Scalar::primitive(10i32, Nullability::NonNullable);
522    /// let str_scalar = Scalar::utf8("hello", Nullability::NonNullable);
523    /// assert_eq!(int_scalar.partial_cmp(&str_scalar), None);
524    ///
525    /// // Nulls are less than non-nulls
526    /// let null = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
527    /// let value = Scalar::primitive(0i32, Nullability::Nullable);
528    /// assert_eq!(null.partial_cmp(&value), Some(Ordering::Less));
529    /// ```
530    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
531        if !self.dtype().eq_ignore_nullability(other.dtype()) {
532            return None;
533        }
534        match self.dtype() {
535            DType::Null => Some(Ordering::Equal),
536            DType::Bool(_) => self.as_bool().partial_cmp(&other.as_bool()),
537            DType::Primitive(..) => self.as_primitive().partial_cmp(&other.as_primitive()),
538            DType::Decimal(..) => self.as_decimal().partial_cmp(&other.as_decimal()),
539            DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()),
540            DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()),
541            DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()),
542            DType::List(..) | DType::FixedSizeList(..) => {
543                self.as_list().partial_cmp(&other.as_list())
544            }
545            DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()),
546        }
547    }
548}
549
550impl Hash for Scalar {
551    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
552        match self.dtype() {
553            DType::Null => self.dtype().hash(state), // Hash the dtype instead of the value
554            DType::Bool(_) => self.as_bool().hash(state),
555            DType::Primitive(..) => self.as_primitive().hash(state),
556            DType::Decimal(..) => self.as_decimal().hash(state),
557            DType::Utf8(_) => self.as_utf8().hash(state),
558            DType::Binary(_) => self.as_binary().hash(state),
559            DType::Struct(..) => self.as_struct().hash(state),
560            DType::List(..) | DType::FixedSizeList(..) => self.as_list().hash(state),
561            DType::Extension(_) => self.as_extension().hash(state),
562        }
563    }
564}
565
566impl AsRef<Self> for Scalar {
567    fn as_ref(&self) -> &Self {
568        self
569    }
570}
571
572impl From<PrimitiveScalar<'_>> for Scalar {
573    fn from(pscalar: PrimitiveScalar<'_>) -> Self {
574        let dtype = pscalar.dtype().clone();
575        let value = pscalar
576            .pvalue()
577            .map(|pvalue| ScalarValue(InnerScalarValue::Primitive(pvalue)))
578            .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
579        Self::new(dtype, value)
580    }
581}
582
583impl From<DecimalScalar<'_>> for Scalar {
584    fn from(decimal_scalar: DecimalScalar<'_>) -> Self {
585        let dtype = decimal_scalar.dtype().clone();
586        let value = decimal_scalar
587            .decimal_value()
588            .map(|value| ScalarValue(InnerScalarValue::Decimal(value)))
589            .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
590        Self::new(dtype, value)
591    }
592}