vortex_scalar/
scalar.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use vortex_buffer::Buffer;
9use vortex_dtype::{DECIMAL128_MAX_PRECISION, DType, Nullability};
10use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
11
12use super::*;
13
14/// A single logical item, composed of both a [`ScalarValue`] and a logical [`DType`].
15///
16/// A [`ScalarValue`] is opaque, and should be accessed via one of the type-specific scalar wrappers
17/// for example [`BoolScalar`], [`PrimitiveScalar`], etc.
18///
19/// Note that [`PartialOrd`] is implemented only for an exact match of the scalar's dtype,
20/// including nullability. When the DType does match, ordering is nulls first (lowest), then the
21/// natural ordering of the scalar value.
22#[derive(Debug, Clone)]
23pub struct Scalar {
24    /// The type of the scalar.
25    dtype: DType,
26
27    /// The value of the scalar.
28    ///
29    /// Invariant: If the `dtype` is non-nullable, then this value _cannot_ be equal to
30    /// [`ScalarValue::null()`](ScalarValue::null).
31    value: ScalarValue,
32}
33
34impl Scalar {
35    /// Creates a new scalar with the given data type and value.
36    pub fn new(dtype: DType, value: ScalarValue) -> Self {
37        if !dtype.is_nullable() {
38            assert!(
39                !value.is_null(),
40                "Tried to construct a null scalar when the `DType` is non-nullable: {dtype}",
41            );
42        }
43
44        Self { dtype, value }
45    }
46
47    /// Returns a reference to the scalar's data type.
48    #[inline]
49    pub fn dtype(&self) -> &DType {
50        &self.dtype
51    }
52
53    /// Returns a reference to the scalar's underlying value.
54    #[inline]
55    pub fn value(&self) -> &ScalarValue {
56        &self.value
57    }
58
59    /// Consumes the scalar and returns its data type and value as a tuple.
60    #[inline]
61    pub fn into_parts(self) -> (DType, ScalarValue) {
62        (self.dtype, self.value)
63    }
64
65    /// Consumes the scalar and returns its underlying [`DType`].
66    #[inline]
67    pub fn into_dtype(self) -> DType {
68        self.dtype
69    }
70
71    /// Consumes the scalar and returns its underlying [`ScalarValue`].
72    #[inline]
73    pub fn into_value(self) -> ScalarValue {
74        self.value
75    }
76
77    /// Returns true if the scalar is not null.
78    pub fn is_valid(&self) -> bool {
79        !self.value.is_null()
80    }
81
82    /// Returns true if the scalar is null.
83    pub fn is_null(&self) -> bool {
84        self.value.is_null()
85    }
86
87    /// Creates a null scalar with the given nullable data type.
88    ///
89    /// # Panics
90    ///
91    /// Panics if the data type is not nullable.
92    pub fn null(dtype: DType) -> Self {
93        assert!(
94            dtype.is_nullable(),
95            "Tried to construct a null scalar when the `DType` is non-nullable: {dtype}"
96        );
97
98        Self {
99            dtype,
100            value: ScalarValue(InnerScalarValue::Null),
101        }
102    }
103
104    /// Creates a null scalar for the given scalar type.
105    ///
106    /// The resulting scalar will have a nullable version of the type's data type.
107    pub fn null_typed<T: ScalarType>() -> Self {
108        Self {
109            dtype: T::dtype().as_nullable(),
110            value: ScalarValue(InnerScalarValue::Null),
111        }
112    }
113
114    /// Casts the scalar to the target data type.
115    ///
116    /// Returns an error if the cast is not supported or if the value cannot be represented
117    /// in the target type.
118    pub fn cast(&self, target: &DType) -> VortexResult<Self> {
119        if let DType::Extension(ext_dtype) = target {
120            let storage_scalar = self.cast_to_non_extension(ext_dtype.storage_dtype())?;
121            Ok(Scalar::extension(ext_dtype.clone(), storage_scalar))
122        } else {
123            self.cast_to_non_extension(target)
124        }
125    }
126
127    fn cast_to_non_extension(&self, target: &DType) -> VortexResult<Self> {
128        assert!(!matches!(target, DType::Extension(..)));
129
130        if self.is_null() {
131            if target.is_nullable() {
132                return Ok(Scalar::new(target.clone(), self.value.clone()));
133            }
134
135            vortex_bail!("Cannot cast null to {target}: target type is non-nullable")
136        }
137
138        match &self.dtype {
139            DType::Null => unreachable!(), // Handled by `if self.is_null()` case.
140            DType::Bool(_) => self.as_bool().cast(target),
141            DType::Primitive(..) => self.as_primitive().cast(target),
142            DType::Decimal(..) => self.as_decimal().cast(target),
143            DType::Utf8(_) => self.as_utf8().cast(target),
144            DType::Binary(_) => self.as_binary().cast(target),
145            DType::Struct(..) => self.as_struct().cast(target),
146            DType::List(..) | DType::FixedSizeList(..) => self.as_list().cast(target),
147            DType::Extension(..) => self.as_extension().cast(target),
148        }
149    }
150
151    /// Converts the scalar to have a nullable version of its data type.
152    pub fn into_nullable(self) -> Self {
153        Self {
154            dtype: self.dtype.as_nullable(),
155            value: self.value,
156        }
157    }
158
159    /// Returns the size of the scalar in bytes, uncompressed.
160    pub fn nbytes(&self) -> usize {
161        match self.dtype() {
162            DType::Null => 0,
163            DType::Bool(_) => 1,
164            DType::Primitive(ptype, _) => ptype.byte_width(),
165            DType::Decimal(dt, _) => {
166                if dt.precision() <= DECIMAL128_MAX_PRECISION {
167                    size_of::<i128>()
168                } else {
169                    size_of::<i256>()
170                }
171            }
172            DType::Binary(_) | DType::Utf8(_) => self
173                .value()
174                .as_buffer()
175                .ok()
176                .flatten()
177                .map_or(0, |s| s.len()),
178            DType::Struct(_dtype, _) => self
179                .as_struct()
180                .fields()
181                .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
182                .unwrap_or_default(),
183            DType::List(..) | DType::FixedSizeList(..) => self
184                .as_list()
185                .elements()
186                .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
187                .unwrap_or_default(),
188            DType::Extension(_ext_dtype) => self.as_extension().storage().nbytes(),
189        }
190    }
191
192    /// Creates a "default" scalar value for the given data type.
193    ///
194    /// For nullable types, returns null. For non-nullable types, returns an appropriate zero/empty
195    /// value.
196    ///
197    /// # Default Values
198    ///
199    /// Here is the list of default values for each [`DType`] (when the [`DType`] is non-nullable):
200    ///
201    /// - `Null`: `null`
202    /// - `Bool`: `false`
203    /// - `Primitive`: `0`
204    /// - `Decimal`: `0`
205    /// - `Utf8`: `""`
206    /// - `Binary`: An empty buffer
207    /// - `List`: An empty list
208    /// - `FixedSizeList`: A list (with correct size) of default values, which is determined by the
209    ///   element [`DType`]
210    /// - `Struct`: A struct where each field has a default value, which is determined by the field
211    ///   [`DType`]
212    /// - `Extension`: The default value of the storage [`DType`]
213    pub fn default_value(dtype: DType) -> Self {
214        if dtype.is_nullable() {
215            return Self::null(dtype);
216        }
217
218        match dtype {
219            DType::Null => Self::null(dtype),
220            DType::Bool(nullability) => Self::bool(false, nullability),
221            DType::Primitive(pt, nullability) => {
222                Self::primitive_value(PValue::zero(pt), pt, nullability)
223            }
224            DType::Decimal(dt, nullability) => {
225                Self::decimal(DecimalValue::from(0), dt, nullability)
226            }
227            DType::Utf8(nullability) => Self::utf8("", nullability),
228            DType::Binary(nullability) => Self::binary(Buffer::empty(), nullability),
229            DType::List(edt, nullability) => Self::list(edt, vec![], nullability),
230            DType::FixedSizeList(edt, size, nullability) => {
231                let elements = (0..size)
232                    .map(|_| Scalar::default_value(edt.as_ref().clone()))
233                    .collect();
234                Self::list(edt, elements, nullability)
235            }
236            DType::Struct(sf, nullability) => {
237                let fields: Vec<_> = sf.fields().map(Scalar::default_value).collect();
238                Self::struct_(DType::Struct(sf, nullability), fields)
239            }
240            DType::Extension(dt) => {
241                let scalar = Self::default_value(dt.storage_dtype().clone());
242                Self::extension(dt, scalar)
243            }
244        }
245    }
246}
247
248/// This implementation block contains only `TryFrom` and `From` wrappers (`as_something`).
249impl Scalar {
250    /// Returns a view of the scalar as a boolean scalar.
251    ///
252    /// # Panics
253    ///
254    /// Panics if the scalar is not a boolean type.
255    pub fn as_bool(&self) -> BoolScalar<'_> {
256        BoolScalar::try_from(self).vortex_expect("Failed to convert scalar to bool")
257    }
258
259    /// Returns a view of the scalar as a boolean scalar if it has a boolean type.
260    pub fn as_bool_opt(&self) -> Option<BoolScalar<'_>> {
261        matches!(self.dtype, DType::Bool(..)).then(|| self.as_bool())
262    }
263
264    /// Returns a view of the scalar as a primitive scalar.
265    ///
266    /// # Panics
267    ///
268    /// Panics if the scalar is not a primitive type.
269    pub fn as_primitive(&self) -> PrimitiveScalar<'_> {
270        PrimitiveScalar::try_from(self).vortex_expect("Failed to convert scalar to primitive")
271    }
272
273    /// Returns a view of the scalar as a primitive scalar if it has a primitive type.
274    pub fn as_primitive_opt(&self) -> Option<PrimitiveScalar<'_>> {
275        matches!(self.dtype, DType::Primitive(..)).then(|| self.as_primitive())
276    }
277
278    /// Returns a view of the scalar as a decimal scalar.
279    ///
280    /// # Panics
281    ///
282    /// Panics if the scalar is not a decimal type.
283    pub fn as_decimal(&self) -> DecimalScalar<'_> {
284        DecimalScalar::try_from(self).vortex_expect("Failed to convert scalar to decimal")
285    }
286
287    /// Returns a view of the scalar as a decimal scalar if it has a decimal type.
288    pub fn as_decimal_opt(&self) -> Option<DecimalScalar<'_>> {
289        matches!(self.dtype, DType::Decimal(..)).then(|| self.as_decimal())
290    }
291
292    /// Returns a view of the scalar as a UTF-8 string scalar.
293    ///
294    /// # Panics
295    ///
296    /// Panics if the scalar is not a UTF-8 type.
297    pub fn as_utf8(&self) -> Utf8Scalar<'_> {
298        Utf8Scalar::try_from(self).vortex_expect("Failed to convert scalar to utf8")
299    }
300
301    /// Returns a view of the scalar as a UTF-8 string scalar if it has a UTF-8 type.
302    pub fn as_utf8_opt(&self) -> Option<Utf8Scalar<'_>> {
303        matches!(self.dtype, DType::Utf8(..)).then(|| self.as_utf8())
304    }
305
306    /// Returns a view of the scalar as a binary scalar.
307    ///
308    /// # Panics
309    ///
310    /// Panics if the scalar is not a binary type.
311    pub fn as_binary(&self) -> BinaryScalar<'_> {
312        BinaryScalar::try_from(self).vortex_expect("Failed to convert scalar to binary")
313    }
314
315    /// Returns a view of the scalar as a binary scalar if it has a binary type.
316    pub fn as_binary_opt(&self) -> Option<BinaryScalar<'_>> {
317        matches!(self.dtype, DType::Binary(..)).then(|| self.as_binary())
318    }
319
320    /// Returns a view of the scalar as a struct scalar.
321    ///
322    /// # Panics
323    ///
324    /// Panics if the scalar is not a struct type.
325    pub fn as_struct(&self) -> StructScalar<'_> {
326        StructScalar::try_from(self).vortex_expect("Failed to convert scalar to struct")
327    }
328
329    /// Returns a view of the scalar as a struct scalar if it has a struct type.
330    pub fn as_struct_opt(&self) -> Option<StructScalar<'_>> {
331        matches!(self.dtype, DType::Struct(..)).then(|| self.as_struct())
332    }
333
334    /// Returns a view of the scalar as a list scalar.
335    ///
336    /// Note that we use [`ListScalar`] to represent **both** [`DType::List`] and
337    /// [`DType::FixedSizeList`].
338    ///
339    /// # Panics
340    ///
341    /// Panics if the scalar is not a list type.
342    pub fn as_list(&self) -> ListScalar<'_> {
343        ListScalar::try_from(self).vortex_expect("Failed to convert scalar to list")
344    }
345
346    /// Returns a view of the scalar as a list scalar if it has a list type.
347    ///
348    /// Note that we use [`ListScalar`] to represent **both** [`DType::List`] and
349    /// [`DType::FixedSizeList`].
350    pub fn as_list_opt(&self) -> Option<ListScalar<'_>> {
351        matches!(self.dtype, DType::List(..) | DType::FixedSizeList(..)).then(|| self.as_list())
352    }
353
354    /// Returns a view of the scalar as an extension scalar.
355    ///
356    /// # Panics
357    ///
358    /// Panics if the scalar is not an extension type.
359    pub fn as_extension(&self) -> ExtScalar<'_> {
360        ExtScalar::try_from(self).vortex_expect("Failed to convert scalar to extension")
361    }
362
363    /// Returns a view of the scalar as an extension scalar if it has an extension type.
364    pub fn as_extension_opt(&self) -> Option<ExtScalar<'_>> {
365        matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension())
366    }
367}
368
369/// It is common to represent a nullable type `T` as an `Option<T>`, so we implement a blanket
370/// implementation for all `Option<T>` to simply be a nullable `T`.
371impl<T> From<Option<T>> for Scalar
372where
373    T: ScalarType,
374    Scalar: From<T>,
375{
376    /// A blanket implementation for all `Option<T>`.
377    fn from(value: Option<T>) -> Self {
378        value
379            .map(Scalar::from)
380            .map(|x| x.into_nullable())
381            .unwrap_or_else(|| Scalar {
382                dtype: T::dtype().as_nullable(),
383                value: ScalarValue(InnerScalarValue::Null),
384            })
385    }
386}
387
388impl<T> From<Vec<T>> for Scalar
389where
390    T: ScalarType,
391    Scalar: From<T>,
392{
393    /// Converts a vector into a `Scalar` (where the value is a `ListScalar`).
394    fn from(vec: Vec<T>) -> Self {
395        Scalar {
396            dtype: DType::List(Arc::from(T::dtype()), Nullability::NonNullable),
397            value: ScalarValue::from(vec),
398        }
399    }
400}
401
402impl<T> TryFrom<Scalar> for Vec<T>
403where
404    T: for<'b> TryFrom<&'b Scalar, Error = VortexError>,
405{
406    type Error = VortexError;
407
408    fn try_from(value: Scalar) -> Result<Self, Self::Error> {
409        Vec::try_from(&value)
410    }
411}
412
413impl<'a, T> TryFrom<&'a Scalar> for Vec<T>
414where
415    T: for<'b> TryFrom<&'b Scalar, Error = VortexError>,
416{
417    type Error = VortexError;
418
419    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
420        ListScalar::try_from(value)?
421            .elements()
422            .ok_or_else(|| vortex_err!("Expected non-null list"))?
423            .into_iter()
424            .map(|e| T::try_from(&e))
425            .collect::<VortexResult<Vec<T>>>()
426    }
427}
428
429impl PartialEq for Scalar {
430    fn eq(&self, other: &Self) -> bool {
431        if !self.dtype.eq_ignore_nullability(&other.dtype) {
432            return false;
433        }
434
435        match self.dtype() {
436            DType::Null => true,
437            DType::Bool(_) => self.as_bool() == other.as_bool(),
438            DType::Primitive(..) => self.as_primitive() == other.as_primitive(),
439            DType::Decimal(..) => self.as_decimal() == other.as_decimal(),
440            DType::Utf8(_) => self.as_utf8() == other.as_utf8(),
441            DType::Binary(_) => self.as_binary() == other.as_binary(),
442            DType::Struct(..) => self.as_struct() == other.as_struct(),
443            DType::List(..) | DType::FixedSizeList(..) => self.as_list() == other.as_list(),
444            DType::Extension(_) => self.as_extension() == other.as_extension(),
445        }
446    }
447}
448
/// Declares [`Scalar`] equality to be a full equivalence relation.
///
/// NOTE(review): this relies on every type-specific comparison used by `PartialEq` (including
/// any floating-point primitives) being reflexive — confirm against the per-type impls.
impl Eq for Scalar {}
450
451impl PartialOrd for Scalar {
452    /// Compares two scalar values for ordering.
453    ///
454    /// # Returns
455    /// - `Some(Ordering)` if both scalars have the same data type (ignoring nullability)
456    /// - `None` if the scalars have different data types
457    ///
458    /// # Ordering Rules
459    /// When types match, the ordering follows these rules:
460    /// - Null values are considered less than all non-null values
461    /// - Non-null values are compared according to their natural ordering
462    ///
463    /// # Examples
464    /// ```ignore
465    /// // Same types compare successfully
466    /// let a = Scalar::primitive(10i32, Nullability::NonNullable);
467    /// let b = Scalar::primitive(20i32, Nullability::NonNullable);
468    /// assert_eq!(a.partial_cmp(&b), Some(Ordering::Less));
469    ///
470    /// // Different types return None
471    /// let int_scalar = Scalar::primitive(10i32, Nullability::NonNullable);
472    /// let str_scalar = Scalar::utf8("hello", Nullability::NonNullable);
473    /// assert_eq!(int_scalar.partial_cmp(&str_scalar), None);
474    ///
475    /// // Nulls are less than non-nulls
476    /// let null = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
477    /// let value = Scalar::primitive(0i32, Nullability::Nullable);
478    /// assert_eq!(null.partial_cmp(&value), Some(Ordering::Less));
479    /// ```
480    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
481        if !self.dtype().eq_ignore_nullability(other.dtype()) {
482            return None;
483        }
484        match self.dtype() {
485            DType::Null => Some(Ordering::Equal),
486            DType::Bool(_) => self.as_bool().partial_cmp(&other.as_bool()),
487            DType::Primitive(..) => self.as_primitive().partial_cmp(&other.as_primitive()),
488            DType::Decimal(..) => self.as_decimal().partial_cmp(&other.as_decimal()),
489            DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()),
490            DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()),
491            DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()),
492            DType::List(..) | DType::FixedSizeList(..) => {
493                self.as_list().partial_cmp(&other.as_list())
494            }
495            DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()),
496        }
497    }
498}
499
500impl Hash for Scalar {
501    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
502        match self.dtype() {
503            DType::Null => self.dtype().hash(state), // Hash the dtype instead of the value
504            DType::Bool(_) => self.as_bool().hash(state),
505            DType::Primitive(..) => self.as_primitive().hash(state),
506            DType::Decimal(..) => self.as_decimal().hash(state),
507            DType::Utf8(_) => self.as_utf8().hash(state),
508            DType::Binary(_) => self.as_binary().hash(state),
509            DType::Struct(..) => self.as_struct().hash(state),
510            DType::List(..) | DType::FixedSizeList(..) => self.as_list().hash(state),
511            DType::Extension(_) => self.as_extension().hash(state),
512        }
513    }
514}
515
516impl AsRef<Self> for Scalar {
517    fn as_ref(&self) -> &Self {
518        self
519    }
520}
521
522impl From<PrimitiveScalar<'_>> for Scalar {
523    fn from(pscalar: PrimitiveScalar<'_>) -> Self {
524        let dtype = pscalar.dtype().clone();
525        let value = pscalar
526            .pvalue()
527            .map(|pvalue| ScalarValue(InnerScalarValue::Primitive(pvalue)))
528            .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
529        Self::new(dtype, value)
530    }
531}
532
533impl From<DecimalScalar<'_>> for Scalar {
534    fn from(decimal_scalar: DecimalScalar<'_>) -> Self {
535        let dtype = decimal_scalar.dtype().clone();
536        let value = decimal_scalar
537            .decimal_value()
538            .map(|value| ScalarValue(InnerScalarValue::Decimal(value)))
539            .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
540        Self::new(dtype, value)
541    }
542}