vortex_scalar/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Scalar values and types for the Vortex system.
5//!
6//! This crate provides scalar types and values that can be used to represent individual
7//! data elements in the Vortex array system. Scalars are composed of a logical data type
8//! ([`DType`]) and a value ([`ScalarValue`]).
9
10#![deny(missing_docs)]
11
12use std::cmp::Ordering;
13use std::hash::Hash;
14use std::sync::Arc;
15
16pub use scalar_type::ScalarType;
17use vortex_buffer::{Buffer, BufferString, ByteBuffer};
18use vortex_dtype::half::f16;
19use vortex_dtype::{DECIMAL128_MAX_PRECISION, DType, Nullability};
20#[cfg(feature = "arbitrary")]
21pub mod arbitrary;
22mod arrow;
23mod bigint;
24mod binary;
25mod bool;
26mod decimal;
27mod display;
28mod extension;
29mod list;
30mod null;
31mod primitive;
32mod proto;
33mod pvalue;
34mod scalar_type;
35mod scalar_value;
36mod struct_;
37#[cfg(test)]
38mod tests;
39mod utf8;
40
41pub use bigint::*;
42pub use binary::*;
43pub use bool::*;
44pub use decimal::*;
45pub use extension::*;
46pub use list::*;
47pub use primitive::*;
48pub use pvalue::*;
49pub use scalar_value::*;
50pub use struct_::*;
51pub use utf8::*;
52use vortex_error::{VortexExpect, VortexResult, vortex_bail};
53
54/// A single logical item, composed of both a [`ScalarValue`] and a logical [`DType`].
55///
56/// A [`ScalarValue`] is opaque, and should be accessed via one of the type-specific scalar wrappers
57/// for example [`BoolScalar`], [`PrimitiveScalar`], etc.
58///
59/// Note that [`PartialOrd`] is implemented only for an exact match of the scalar's dtype,
60/// including nullability. When the DType does match, ordering is nulls first (lowest), then the
61/// natural ordering of the scalar value.
62#[derive(Debug, Clone)]
63pub struct Scalar {
64    dtype: DType,
65    value: ScalarValue,
66}
67
68impl Scalar {
69    /// Creates a new scalar with the given data type and value.
70    pub fn new(dtype: DType, value: ScalarValue) -> Self {
71        Self { dtype, value }
72    }
73
74    /// Returns a reference to the scalar's data type.
75    #[inline]
76    pub fn dtype(&self) -> &DType {
77        &self.dtype
78    }
79
80    /// Returns a reference to the scalar's underlying value.
81    #[inline]
82    pub fn value(&self) -> &ScalarValue {
83        &self.value
84    }
85
86    /// Consumes the scalar and returns its data type and value as a tuple.
87    #[inline]
88    pub fn into_parts(self) -> (DType, ScalarValue) {
89        (self.dtype, self.value)
90    }
91
92    /// Consumes the scalar and returns its underlying value.
93    #[inline]
94    pub fn into_value(self) -> ScalarValue {
95        self.value
96    }
97
98    /// Returns true if the scalar is not null.
99    pub fn is_valid(&self) -> bool {
100        !self.value.is_null()
101    }
102
103    /// Returns true if the scalar is null.
104    pub fn is_null(&self) -> bool {
105        self.value.is_null()
106    }
107
108    /// Creates a null scalar with the given nullable data type.
109    ///
110    /// # Panics
111    ///
112    /// Panics if the data type is not nullable.
113    pub fn null(dtype: DType) -> Self {
114        assert!(
115            dtype.is_nullable(),
116            "Creating null scalar for non-nullable DType {dtype}"
117        );
118        Self {
119            dtype,
120            value: ScalarValue(InnerScalarValue::Null),
121        }
122    }
123
124    /// Creates a null scalar for the given scalar type.
125    ///
126    /// The resulting scalar will have a nullable version of the type's data type.
127    pub fn null_typed<T: ScalarType>() -> Self {
128        Self {
129            dtype: T::dtype().as_nullable(),
130            value: ScalarValue(InnerScalarValue::Null),
131        }
132    }
133
134    /// Casts the scalar to the target data type.
135    ///
136    /// Returns an error if the cast is not supported or if the value cannot be represented
137    /// in the target type.
138    pub fn cast(&self, target: &DType) -> VortexResult<Self> {
139        if let DType::Extension(ext_dtype) = target {
140            let storage_scalar = self.cast_to_non_extension(ext_dtype.storage_dtype())?;
141            Ok(Scalar::extension(ext_dtype.clone(), storage_scalar))
142        } else {
143            self.cast_to_non_extension(target)
144        }
145    }
146
147    fn cast_to_non_extension(&self, target: &DType) -> VortexResult<Self> {
148        assert!(!matches!(target, DType::Extension(..)));
149        if self.is_null() {
150            if target.is_nullable() {
151                return Ok(Scalar::new(target.clone(), self.value.clone()));
152            } else {
153                vortex_bail!(
154                    "Cannot cast null to {}: target type is non-nullable",
155                    target
156                )
157            }
158        }
159
160        if self.dtype().eq_ignore_nullability(target) {
161            return Ok(Scalar::new(target.clone(), self.value.clone()));
162        }
163
164        match &self.dtype {
165            DType::Null => unreachable!(), // handled by if is_null case
166            DType::Bool(_) => self.as_bool().cast(target),
167            DType::Primitive(..) => self.as_primitive().cast(target),
168            DType::Decimal(..) => self.as_decimal().cast(target),
169            DType::Utf8(_) => self.as_utf8().cast(target),
170            DType::Binary(_) => self.as_binary().cast(target),
171            DType::Struct(..) => self.as_struct().cast(target),
172            DType::List(..) => self.as_list().cast(target),
173            DType::Extension(..) => self.as_extension().cast(target),
174        }
175    }
176
177    /// Converts the scalar to have a nullable version of its data type.
178    pub fn into_nullable(self) -> Self {
179        Self {
180            dtype: self.dtype.as_nullable(),
181            value: self.value,
182        }
183    }
184
185    /// Returns the size of the scalar in bytes, uncompressed.
186    pub fn nbytes(&self) -> usize {
187        match self.dtype() {
188            DType::Null => 0,
189            DType::Bool(_) => 1,
190            DType::Primitive(ptype, _) => ptype.byte_width(),
191            DType::Decimal(dt, _) => {
192                if dt.precision() <= DECIMAL128_MAX_PRECISION {
193                    size_of::<i128>()
194                } else {
195                    size_of::<i256>()
196                }
197            }
198            DType::Binary(_) | DType::Utf8(_) => self
199                .value()
200                .as_buffer()
201                .ok()
202                .flatten()
203                .map_or(0, |s| s.len()),
204            DType::Struct(_dtype, _) => self
205                .as_struct()
206                .fields()
207                .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
208                .unwrap_or_default(),
209            DType::List(_dtype, _) => self
210                .as_list()
211                .elements()
212                .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
213                .unwrap_or_default(),
214            DType::Extension(_ext_dtype) => self.as_extension().storage().nbytes(),
215        }
216    }
217
218    /// Creates a "default" scalar value for the given data type.
219    ///
220    /// For nullable types, returns null. For non-nullable types, returns
221    /// an appropriate zero/empty value.
222    pub fn default_value(dtype: DType) -> Self {
223        if dtype.is_nullable() {
224            return Self::null(dtype);
225        }
226
227        match dtype {
228            DType::Null => Self::null(dtype),
229            DType::Bool(nullability) => Self::bool(false, nullability),
230            DType::Primitive(pt, nullability) => {
231                Self::primitive_value(PValue::zero(pt), pt, nullability)
232            }
233            DType::Decimal(dt, nullability) => {
234                Self::decimal(DecimalValue::from(0), dt, nullability)
235            }
236            DType::Utf8(nullability) => Self::utf8("", nullability),
237            DType::Binary(nullability) => Self::binary(Buffer::empty(), nullability),
238            DType::Struct(sf, nullability) => {
239                let fields: Vec<_> = sf.fields().map(Scalar::default_value).collect();
240                Self::struct_(DType::Struct(sf, nullability), fields)
241            }
242            DType::List(dt, nullability) => Self::list(dt, vec![], nullability),
243            DType::Extension(dt) => {
244                let scalar = Self::default_value(dt.storage_dtype().clone());
245                Self::extension(dt, scalar)
246            }
247        }
248    }
249}
250
251impl Scalar {
252    /// Returns a view of the scalar as a boolean scalar.
253    ///
254    /// # Panics
255    ///
256    /// Panics if the scalar is not a boolean type.
257    pub fn as_bool(&self) -> BoolScalar<'_> {
258        BoolScalar::try_from(self).vortex_expect("Failed to convert scalar to bool")
259    }
260
261    /// Returns a view of the scalar as a boolean scalar if it has a boolean type.
262    pub fn as_bool_opt(&self) -> Option<BoolScalar<'_>> {
263        matches!(self.dtype, DType::Bool(..)).then(|| self.as_bool())
264    }
265
266    /// Returns a view of the scalar as a primitive scalar.
267    ///
268    /// # Panics
269    ///
270    /// Panics if the scalar is not a primitive type.
271    pub fn as_primitive(&self) -> PrimitiveScalar<'_> {
272        PrimitiveScalar::try_from(self).vortex_expect("Failed to convert scalar to primitive")
273    }
274
275    /// Returns a view of the scalar as a primitive scalar if it has a primitive type.
276    pub fn as_primitive_opt(&self) -> Option<PrimitiveScalar<'_>> {
277        matches!(self.dtype, DType::Primitive(..)).then(|| self.as_primitive())
278    }
279
280    /// Returns a view of the scalar as a decimal scalar.
281    ///
282    /// # Panics
283    ///
284    /// Panics if the scalar is not a decimal type.
285    pub fn as_decimal(&self) -> DecimalScalar<'_> {
286        DecimalScalar::try_from(self).vortex_expect("Failed to convert scalar to decimal")
287    }
288
289    /// Returns a view of the scalar as a decimal scalar if it has a decimal type.
290    pub fn as_decimal_opt(&self) -> Option<DecimalScalar<'_>> {
291        matches!(self.dtype, DType::Decimal(..)).then(|| self.as_decimal())
292    }
293
294    /// Returns a view of the scalar as a UTF-8 string scalar.
295    ///
296    /// # Panics
297    ///
298    /// Panics if the scalar is not a UTF-8 type.
299    pub fn as_utf8(&self) -> Utf8Scalar<'_> {
300        Utf8Scalar::try_from(self).vortex_expect("Failed to convert scalar to utf8")
301    }
302
303    /// Returns a view of the scalar as a UTF-8 string scalar if it has a UTF-8 type.
304    pub fn as_utf8_opt(&self) -> Option<Utf8Scalar<'_>> {
305        matches!(self.dtype, DType::Utf8(..)).then(|| self.as_utf8())
306    }
307
308    /// Returns a view of the scalar as a binary scalar.
309    ///
310    /// # Panics
311    ///
312    /// Panics if the scalar is not a binary type.
313    pub fn as_binary(&self) -> BinaryScalar<'_> {
314        BinaryScalar::try_from(self).vortex_expect("Failed to convert scalar to binary")
315    }
316
317    /// Returns a view of the scalar as a binary scalar if it has a binary type.
318    pub fn as_binary_opt(&self) -> Option<BinaryScalar<'_>> {
319        matches!(self.dtype, DType::Binary(..)).then(|| self.as_binary())
320    }
321
322    /// Returns a view of the scalar as a struct scalar.
323    ///
324    /// # Panics
325    ///
326    /// Panics if the scalar is not a struct type.
327    pub fn as_struct(&self) -> StructScalar<'_> {
328        StructScalar::try_from(self).vortex_expect("Failed to convert scalar to struct")
329    }
330
331    /// Returns a view of the scalar as a struct scalar if it has a struct type.
332    pub fn as_struct_opt(&self) -> Option<StructScalar<'_>> {
333        matches!(self.dtype, DType::Struct(..)).then(|| self.as_struct())
334    }
335
336    /// Returns a view of the scalar as a list scalar.
337    ///
338    /// # Panics
339    ///
340    /// Panics if the scalar is not a list type.
341    pub fn as_list(&self) -> ListScalar<'_> {
342        ListScalar::try_from(self).vortex_expect("Failed to convert scalar to list")
343    }
344
345    /// Returns a view of the scalar as a list scalar if it has a list type.
346    pub fn as_list_opt(&self) -> Option<ListScalar<'_>> {
347        matches!(self.dtype, DType::List(..)).then(|| self.as_list())
348    }
349
350    /// Returns a view of the scalar as an extension scalar.
351    ///
352    /// # Panics
353    ///
354    /// Panics if the scalar is not an extension type.
355    pub fn as_extension(&self) -> ExtScalar<'_> {
356        ExtScalar::try_from(self).vortex_expect("Failed to convert scalar to extension")
357    }
358
359    /// Returns a view of the scalar as an extension scalar if it has an extension type.
360    pub fn as_extension_opt(&self) -> Option<ExtScalar<'_>> {
361        matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension())
362    }
363}
364
365impl PartialEq for Scalar {
366    fn eq(&self, other: &Self) -> bool {
367        if !self.dtype.eq_ignore_nullability(&other.dtype) {
368            return false;
369        }
370
371        match self.dtype() {
372            DType::Null => true,
373            DType::Bool(_) => self.as_bool() == other.as_bool(),
374            DType::Primitive(..) => self.as_primitive() == other.as_primitive(),
375            DType::Decimal(..) => self.as_decimal() == other.as_decimal(),
376            DType::Utf8(_) => self.as_utf8() == other.as_utf8(),
377            DType::Binary(_) => self.as_binary() == other.as_binary(),
378            DType::Struct(..) => self.as_struct() == other.as_struct(),
379            DType::List(..) => self.as_list() == other.as_list(),
380            DType::Extension(_) => self.as_extension() == other.as_extension(),
381        }
382    }
383}
384
385impl Eq for Scalar {}
386
387impl PartialOrd for Scalar {
388    /// Compares two scalar values for ordering.
389    ///
390    /// # Returns
391    /// - `Some(Ordering)` if both scalars have the same data type (ignoring nullability)
392    /// - `None` if the scalars have different data types
393    ///
394    /// # Ordering Rules
395    /// When types match, the ordering follows these rules:
396    /// - Null values are considered less than all non-null values
397    /// - Non-null values are compared according to their natural ordering
398    ///
399    /// # Examples
400    /// ```ignore
401    /// // Same types compare successfully
402    /// let a = Scalar::primitive(10i32, Nullability::NonNullable);
403    /// let b = Scalar::primitive(20i32, Nullability::NonNullable);
404    /// assert_eq!(a.partial_cmp(&b), Some(Ordering::Less));
405    ///
406    /// // Different types return None
407    /// let int_scalar = Scalar::primitive(10i32, Nullability::NonNullable);
408    /// let str_scalar = Scalar::utf8("hello", Nullability::NonNullable);
409    /// assert_eq!(int_scalar.partial_cmp(&str_scalar), None);
410    ///
411    /// // Nulls are less than non-nulls
412    /// let null = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
413    /// let value = Scalar::primitive(0i32, Nullability::Nullable);
414    /// assert_eq!(null.partial_cmp(&value), Some(Ordering::Less));
415    /// ```
416    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
417        if !self.dtype().eq_ignore_nullability(other.dtype()) {
418            return None;
419        }
420        match self.dtype() {
421            DType::Null => Some(Ordering::Equal),
422            DType::Bool(_) => self.as_bool().partial_cmp(&other.as_bool()),
423            DType::Primitive(..) => self.as_primitive().partial_cmp(&other.as_primitive()),
424            DType::Decimal(..) => self.as_decimal().partial_cmp(&other.as_decimal()),
425            DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()),
426            DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()),
427            DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()),
428            DType::List(..) => self.as_list().partial_cmp(&other.as_list()),
429            DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()),
430        }
431    }
432}
433
434impl Hash for Scalar {
435    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
436        match self.dtype() {
437            DType::Null => self.dtype().hash(state), // Hash the dtype instead of the value
438            DType::Bool(_) => self.as_bool().hash(state),
439            DType::Primitive(..) => self.as_primitive().hash(state),
440            DType::Decimal(..) => self.as_decimal().hash(state),
441            DType::Utf8(_) => self.as_utf8().hash(state),
442            DType::Binary(_) => self.as_binary().hash(state),
443            DType::Struct(..) => self.as_struct().hash(state),
444            DType::List(..) => self.as_list().hash(state),
445            DType::Extension(_) => self.as_extension().hash(state),
446        }
447    }
448}
449
450impl AsRef<Self> for Scalar {
451    fn as_ref(&self) -> &Self {
452        self
453    }
454}
455
456impl<T> From<Option<T>> for Scalar
457where
458    T: ScalarType,
459    Scalar: From<T>,
460{
461    fn from(value: Option<T>) -> Self {
462        value
463            .map(Scalar::from)
464            .map(|x| x.into_nullable())
465            .unwrap_or_else(|| Scalar {
466                dtype: T::dtype().as_nullable(),
467                value: ScalarValue(InnerScalarValue::Null),
468            })
469    }
470}
471
472impl From<PrimitiveScalar<'_>> for Scalar {
473    fn from(pscalar: PrimitiveScalar<'_>) -> Self {
474        let dtype = pscalar.dtype().clone();
475        let value = pscalar
476            .pvalue()
477            .map(|pvalue| ScalarValue(InnerScalarValue::Primitive(pvalue)))
478            .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
479        Self::new(dtype, value)
480    }
481}
482
483impl From<DecimalScalar<'_>> for Scalar {
484    fn from(decimal_scalar: DecimalScalar<'_>) -> Self {
485        let dtype = decimal_scalar.dtype().clone();
486        let value = decimal_scalar
487            .decimal_value()
488            .map(|value| ScalarValue(InnerScalarValue::Decimal(value)))
489            .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
490        Self::new(dtype, value)
491    }
492}
493
494macro_rules! from_vec_for_scalar {
495    ($T:ty) => {
496        impl From<Vec<$T>> for Scalar {
497            fn from(value: Vec<$T>) -> Self {
498                Scalar {
499                    dtype: DType::List(Arc::from(<$T>::dtype()), Nullability::NonNullable),
500                    value: ScalarValue(InnerScalarValue::List(
501                        value
502                            .into_iter()
503                            .map(Scalar::from)
504                            .map(|s| s.into_value())
505                            .collect::<Arc<[_]>>(),
506                    )),
507                }
508            }
509        }
510    };
511}
512
513// no From<Vec<u8>> because it could either be a List or a Buffer
514from_vec_for_scalar!(u16);
515from_vec_for_scalar!(u32);
516from_vec_for_scalar!(u64);
517from_vec_for_scalar!(usize); // For usize only, we implicitly cast for better ergonomics.
518from_vec_for_scalar!(i8);
519from_vec_for_scalar!(i16);
520from_vec_for_scalar!(i32);
521from_vec_for_scalar!(i64);
522from_vec_for_scalar!(f16);
523from_vec_for_scalar!(f32);
524from_vec_for_scalar!(f64);
525from_vec_for_scalar!(String);
526from_vec_for_scalar!(BufferString);
527from_vec_for_scalar!(ByteBuffer);