vortex_scalar/
lib.rs

1use std::cmp::Ordering;
2use std::hash::Hash;
3use std::sync::Arc;
4
5pub use scalar_type::ScalarType;
6use vortex_buffer::{BufferString, ByteBuffer};
7use vortex_dtype::half::f16;
8use vortex_dtype::{DECIMAL128_MAX_PRECISION, DType, Nullability};
9#[cfg(feature = "arbitrary")]
10pub mod arbitrary;
11mod arrow;
12mod bigint;
13mod binary;
14mod bool;
15mod datafusion;
16mod decimal;
17mod display;
18mod extension;
19mod list;
20mod null;
21mod primitive;
22mod proto;
23mod pvalue;
24mod scalar_type;
25mod scalar_value;
26mod struct_;
27mod utf8;
28
29pub use bigint::*;
30pub use binary::*;
31pub use bool::*;
32pub use decimal::*;
33pub use extension::*;
34pub use list::*;
35pub use primitive::*;
36pub use pvalue::*;
37pub use scalar_value::*;
38pub use struct_::*;
39pub use utf8::*;
40use vortex_error::{VortexExpect, VortexResult, vortex_bail};
41
42/// A single logical item, composed of both a [`ScalarValue`] and a logical [`DType`].
43///
44/// A [`ScalarValue`] is opaque, and should be accessed via one of the type-specific scalar wrappers
45/// for example [`BoolScalar`], [`PrimitiveScalar`], etc.
46///
47/// Note that [`PartialOrd`] is implemented only for an exact match of the scalar's dtype,
48/// including nullability. When the DType does match, ordering is nulls first (lowest), then the
49/// natural ordering of the scalar value.
50#[derive(Debug, Clone)]
51pub struct Scalar {
52    dtype: DType,
53    value: ScalarValue,
54}
55
56impl Scalar {
57    pub fn new(dtype: DType, value: ScalarValue) -> Self {
58        Self { dtype, value }
59    }
60
61    #[inline]
62    pub fn dtype(&self) -> &DType {
63        &self.dtype
64    }
65
66    #[inline]
67    pub fn value(&self) -> &ScalarValue {
68        &self.value
69    }
70
71    #[inline]
72    pub fn into_parts(self) -> (DType, ScalarValue) {
73        (self.dtype, self.value)
74    }
75
76    #[inline]
77    pub fn into_value(self) -> ScalarValue {
78        self.value
79    }
80
81    pub fn is_valid(&self) -> bool {
82        !self.value.is_null()
83    }
84
85    pub fn is_null(&self) -> bool {
86        self.value.is_null()
87    }
88
89    pub fn null(dtype: DType) -> Self {
90        assert!(
91            dtype.is_nullable(),
92            "Creating null scalar for non-nullable DType {dtype}"
93        );
94        Self {
95            dtype,
96            value: ScalarValue(InnerScalarValue::Null),
97        }
98    }
99
100    pub fn null_typed<T: ScalarType>() -> Self {
101        Self {
102            dtype: T::dtype().as_nullable(),
103            value: ScalarValue(InnerScalarValue::Null),
104        }
105    }
106
107    pub fn cast(&self, target: &DType) -> VortexResult<Self> {
108        if let DType::Extension(ext_dtype) = target {
109            let storage_scalar = self.cast_to_non_extension(ext_dtype.storage_dtype())?;
110            Ok(Scalar::extension(ext_dtype.clone(), storage_scalar))
111        } else {
112            self.cast_to_non_extension(target)
113        }
114    }
115
116    fn cast_to_non_extension(&self, target: &DType) -> VortexResult<Self> {
117        assert!(!matches!(target, DType::Extension(..)));
118        if self.is_null() {
119            if target.is_nullable() {
120                return Ok(Scalar::new(target.clone(), self.value.clone()));
121            } else {
122                vortex_bail!("Can't cast null scalar to non-nullable type {}", target)
123            }
124        }
125
126        if self.dtype().eq_ignore_nullability(target) {
127            return Ok(Scalar::new(target.clone(), self.value.clone()));
128        }
129
130        match &self.dtype {
131            DType::Null => unreachable!(), // handled by if is_null case
132            DType::Bool(_) => self.as_bool().cast(target),
133            DType::Primitive(..) => self.as_primitive().cast(target),
134            DType::Decimal(..) => todo!("(aduffy): implement DecimalScalar casting"),
135            DType::Utf8(_) => self.as_utf8().cast(target),
136            DType::Binary(_) => self.as_binary().cast(target),
137            DType::Struct(..) => self.as_struct().cast(target),
138            DType::List(..) => self.as_list().cast(target),
139            DType::Extension(..) => self.as_extension().cast(target),
140        }
141    }
142
143    pub fn into_nullable(self) -> Self {
144        Self {
145            dtype: self.dtype.as_nullable(),
146            value: self.value,
147        }
148    }
149
150    /// Size of the scalar in bytes, uncompressed.
151    pub fn nbytes(&self) -> usize {
152        match self.dtype() {
153            DType::Null => 0,
154            DType::Bool(_) => 1,
155            DType::Primitive(ptype, _) => ptype.byte_width(),
156            DType::Decimal(dt, _) => {
157                if dt.precision() >= DECIMAL128_MAX_PRECISION {
158                    size_of::<i128>()
159                } else {
160                    size_of::<i256>()
161                }
162            }
163            DType::Binary(_) | DType::Utf8(_) => self
164                .value()
165                .as_buffer()
166                .ok()
167                .flatten()
168                .map_or(0, |s| s.len()),
169            DType::Struct(_dtype, _) => self
170                .as_struct()
171                .fields()
172                .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
173                .unwrap_or_default(),
174            DType::List(_dtype, _) => self
175                .as_list()
176                .elements()
177                .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
178                .unwrap_or_default(),
179            DType::Extension(_ext_dtype) => self.as_extension().storage().nbytes(),
180        }
181    }
182}
183
184impl Scalar {
185    pub fn as_bool(&self) -> BoolScalar {
186        BoolScalar::try_from(self).vortex_expect("Failed to convert scalar to bool")
187    }
188
189    pub fn as_bool_opt(&self) -> Option<BoolScalar> {
190        matches!(self.dtype, DType::Bool(..)).then(|| self.as_bool())
191    }
192
193    pub fn as_primitive(&self) -> PrimitiveScalar {
194        PrimitiveScalar::try_from(self).vortex_expect("Failed to convert scalar to primitive")
195    }
196
197    pub fn as_primitive_opt(&self) -> Option<PrimitiveScalar> {
198        matches!(self.dtype, DType::Primitive(..)).then(|| self.as_primitive())
199    }
200
201    pub fn as_decimal(&self) -> DecimalScalar {
202        DecimalScalar::try_from(self).vortex_expect("Failed to convert scalar to decimal")
203    }
204
205    pub fn as_decimal_opt(&self) -> Option<DecimalScalar> {
206        matches!(self.dtype, DType::Decimal(..)).then(|| self.as_decimal())
207    }
208
209    pub fn as_utf8(&self) -> Utf8Scalar {
210        Utf8Scalar::try_from(self).vortex_expect("Failed to convert scalar to utf8")
211    }
212
213    pub fn as_utf8_opt(&self) -> Option<Utf8Scalar> {
214        matches!(self.dtype, DType::Utf8(..)).then(|| self.as_utf8())
215    }
216
217    pub fn as_binary(&self) -> BinaryScalar {
218        BinaryScalar::try_from(self).vortex_expect("Failed to convert scalar to binary")
219    }
220
221    pub fn as_binary_opt(&self) -> Option<BinaryScalar> {
222        matches!(self.dtype, DType::Binary(..)).then(|| self.as_binary())
223    }
224
225    pub fn as_struct(&self) -> StructScalar {
226        StructScalar::try_from(self).vortex_expect("Failed to convert scalar to struct")
227    }
228
229    pub fn as_struct_opt(&self) -> Option<StructScalar> {
230        matches!(self.dtype, DType::Struct(..)).then(|| self.as_struct())
231    }
232
233    pub fn as_list(&self) -> ListScalar {
234        ListScalar::try_from(self).vortex_expect("Failed to convert scalar to list")
235    }
236
237    pub fn as_list_opt(&self) -> Option<ListScalar> {
238        matches!(self.dtype, DType::List(..)).then(|| self.as_list())
239    }
240
241    pub fn as_extension(&self) -> ExtScalar {
242        ExtScalar::try_from(self).vortex_expect("Failed to convert scalar to extension")
243    }
244
245    pub fn as_extension_opt(&self) -> Option<ExtScalar> {
246        matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension())
247    }
248}
249
250impl PartialEq for Scalar {
251    fn eq(&self, other: &Self) -> bool {
252        if !self.dtype.eq_ignore_nullability(&other.dtype) {
253            return false;
254        }
255
256        match self.dtype() {
257            DType::Null => true,
258            DType::Bool(_) => self.as_bool() == other.as_bool(),
259            DType::Primitive(..) => self.as_primitive() == other.as_primitive(),
260            DType::Decimal(..) => self.as_decimal() == other.as_decimal(),
261            DType::Utf8(_) => self.as_utf8() == other.as_utf8(),
262            DType::Binary(_) => self.as_binary() == other.as_binary(),
263            DType::Struct(..) => self.as_struct() == other.as_struct(),
264            DType::List(..) => self.as_list() == other.as_list(),
265            DType::Extension(_) => self.as_extension() == other.as_extension(),
266        }
267    }
268}
269
270impl Eq for Scalar {}
271
272impl PartialOrd for Scalar {
273    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
274        if !self.dtype().eq_ignore_nullability(other.dtype()) {
275            return None;
276        }
277        match self.dtype() {
278            DType::Null => Some(Ordering::Equal),
279            DType::Bool(_) => self.as_bool().partial_cmp(&other.as_bool()),
280            DType::Primitive(..) => self.as_primitive().partial_cmp(&other.as_primitive()),
281            DType::Decimal(..) => self.as_decimal().partial_cmp(&other.as_decimal()),
282            DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()),
283            DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()),
284            DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()),
285            DType::List(..) => self.as_list().partial_cmp(&other.as_list()),
286            DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()),
287        }
288    }
289}
290
291impl Hash for Scalar {
292    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
293        match self.dtype() {
294            DType::Null => self.dtype().hash(state), // Hash the dtype instead of the value
295            DType::Bool(_) => self.as_bool().hash(state),
296            DType::Primitive(..) => self.as_primitive().hash(state),
297            DType::Decimal(..) => self.as_decimal().hash(state),
298            DType::Utf8(_) => self.as_utf8().hash(state),
299            DType::Binary(_) => self.as_binary().hash(state),
300            DType::Struct(..) => self.as_struct().hash(state),
301            DType::List(..) => self.as_list().hash(state),
302            DType::Extension(_) => self.as_extension().hash(state),
303        }
304    }
305}
306
307impl AsRef<Self> for Scalar {
308    fn as_ref(&self) -> &Self {
309        self
310    }
311}
312
313impl<T> From<Option<T>> for Scalar
314where
315    T: ScalarType,
316    Scalar: From<T>,
317{
318    fn from(value: Option<T>) -> Self {
319        value
320            .map(Scalar::from)
321            .map(|x| x.into_nullable())
322            .unwrap_or_else(|| Scalar {
323                dtype: T::dtype().as_nullable(),
324                value: ScalarValue(InnerScalarValue::Null),
325            })
326    }
327}
328
329impl From<PrimitiveScalar<'_>> for Scalar {
330    fn from(pscalar: PrimitiveScalar<'_>) -> Self {
331        let dtype = pscalar.dtype().clone();
332        let value = pscalar
333            .pvalue()
334            .map(|pvalue| ScalarValue(InnerScalarValue::Primitive(pvalue)))
335            .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
336        Self::new(dtype, value)
337    }
338}
339
340impl From<DecimalScalar<'_>> for Scalar {
341    fn from(decimal_scalar: DecimalScalar<'_>) -> Self {
342        let dtype = decimal_scalar.dtype().clone();
343        let value = decimal_scalar
344            .decimal_value()
345            .map(|value| ScalarValue(InnerScalarValue::Decimal(value)))
346            .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
347        Self::new(dtype, value)
348    }
349}
350
351macro_rules! from_vec_for_scalar {
352    ($T:ty) => {
353        impl From<Vec<$T>> for Scalar {
354            fn from(value: Vec<$T>) -> Self {
355                Scalar {
356                    dtype: DType::List(Arc::from(<$T>::dtype()), Nullability::NonNullable),
357                    value: ScalarValue(InnerScalarValue::List(
358                        value
359                            .into_iter()
360                            .map(Scalar::from)
361                            .map(|s| s.into_value())
362                            .collect::<Arc<[_]>>(),
363                    )),
364                }
365            }
366        }
367    };
368}
369
370// no From<Vec<u8>> because it could either be a List or a Buffer
371from_vec_for_scalar!(u16);
372from_vec_for_scalar!(u32);
373from_vec_for_scalar!(u64);
374from_vec_for_scalar!(usize); // For usize only, we implicitly cast for better ergonomics.
375from_vec_for_scalar!(i8);
376from_vec_for_scalar!(i16);
377from_vec_for_scalar!(i32);
378from_vec_for_scalar!(i64);
379from_vec_for_scalar!(f16);
380from_vec_for_scalar!(f32);
381from_vec_for_scalar!(f64);
382from_vec_for_scalar!(String);
383from_vec_for_scalar!(BufferString);
384from_vec_for_scalar!(ByteBuffer);
385
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_dtype::{DType, ExtDType, ExtID, Nullability, PType};

    use crate::{InnerScalarValue, PValue, Scalar, ScalarValue};

    /// Builds a list dtype over a primitive element type.
    fn primitive_list(ptype: PType, elements: Nullability, list: Nullability) -> DType {
        DType::List(Arc::from(DType::Primitive(ptype, elements)), list)
    }

    /// Wraps a `u16` as a primitive scalar value.
    fn u16_value(raw: u16) -> ScalarValue {
        ScalarValue(InnerScalarValue::Primitive(PValue::U16(raw)))
    }

    /// Asserts that `scalar` casts to `target` and that the result carries exactly that dtype.
    #[track_caller]
    fn assert_casts_to(scalar: &Scalar, target: &DType) {
        let casted = scalar.cast(target).unwrap();
        assert_eq!(casted.dtype(), target);
    }

    /// A null scalar of any nullable dtype must cast to every other nullable dtype.
    #[rstest]
    fn null_can_cast_to_anything_nullable(
        #[values(
            DType::Null,
            DType::Bool(Nullability::Nullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("a"),
                Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
                None,
            ))),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("b"),
                Arc::from(DType::Utf8(Nullability::Nullable)),
                None,
            )))
        )]
        source_dtype: DType,
        #[values(
            DType::Null,
            DType::Bool(Nullability::Nullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("a"),
                Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
                None,
            ))),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("b"),
                Arc::from(DType::Utf8(Nullability::Nullable)),
                None,
            )))
        )]
        target_dtype: DType,
    ) {
        let null_scalar = Scalar::null(source_dtype);
        assert_casts_to(&null_scalar, &target_dtype);
    }

    /// List casts are applied element-wise and honour both element and list nullability.
    #[test]
    fn list_casts() {
        let nullable_u16_list =
            primitive_list(PType::U16, Nullability::Nullable, Nullability::Nullable);

        let list = Scalar::new(
            nullable_u16_list.clone(),
            ScalarValue(InnerScalarValue::List(Arc::from([u16_value(6)]))),
        );

        // Wider element type, same nullability.
        assert_casts_to(
            &list,
            &primitive_list(PType::U32, Nullability::Nullable, Nullability::Nullable),
        );

        // Non-nullable elements are fine as long as no element is null.
        assert_casts_to(
            &list,
            &primitive_list(PType::U32, Nullability::NonNullable, Nullability::Nullable),
        );

        // Non-nullable list.
        assert_casts_to(
            &list,
            &primitive_list(PType::U32, Nullability::Nullable, Nullability::NonNullable),
        );

        // Narrower element type; the single element (6) fits in a u8.
        assert_casts_to(
            &list,
            &primitive_list(PType::U8, Nullability::Nullable, Nullability::Nullable),
        );

        let list_with_null = Scalar::new(
            nullable_u16_list,
            ScalarValue(InnerScalarValue::List(Arc::from([
                u16_value(6),
                ScalarValue(InnerScalarValue::Null),
            ]))),
        );

        // A null element survives a cast to a nullable element type...
        assert_casts_to(
            &list_with_null,
            &primitive_list(PType::U8, Nullability::Nullable, Nullability::Nullable),
        );

        // ...but must be rejected when the target elements are non-nullable.
        let non_nullable_elements =
            primitive_list(PType::U32, Nullability::NonNullable, Nullability::Nullable);
        assert!(list_with_null.cast(&non_nullable_elements).is_err());
    }

    /// Casting between an extension dtype and its storage dtype, in both directions.
    #[test]
    fn cast_to_from_extension_types() {
        let apples = ExtDType::new(
            ExtID::new(Arc::from("apples")),
            Arc::from(DType::Primitive(PType::U16, Nullability::NonNullable)),
            None,
        );
        let ext_dtype = DType::Extension(Arc::from(apples.clone()));
        let ext_scalar = Scalar::new(ext_dtype.clone(), ScalarValue(InnerScalarValue::Bool(true)));
        let storage_scalar = Scalar::new(DType::clone(apples.storage_dtype()), u16_value(1000));

        // to self
        assert_casts_to(&ext_scalar, &ext_dtype);

        // to nullable self
        assert_casts_to(&ext_scalar, &ext_dtype.as_nullable());

        // cast to the storage type
        assert_casts_to(&ext_scalar, apples.storage_dtype());

        // cast to the storage type, nullable
        assert_casts_to(&ext_scalar, &apples.storage_dtype().as_nullable());

        // cast from storage type to extension
        assert_casts_to(&storage_scalar, &ext_dtype);

        // cast from storage type to extension, nullable
        assert_casts_to(&storage_scalar, &ext_dtype.as_nullable());

        // cast from *compatible* storage type to extension
        let storage_scalar_u64 = Scalar::new(
            DType::clone(apples.storage_dtype()),
            ScalarValue(InnerScalarValue::Primitive(PValue::U64(1000))),
        );
        assert_casts_to(&storage_scalar_u64, &ext_dtype);

        // cast from *incompatible* storage type to extension
        let apples_u8 = ExtDType::new(
            ExtID::new(Arc::from("apples")),
            Arc::from(DType::Primitive(PType::U8, Nullability::NonNullable)),
            None,
        );
        let incompatible_dtype = DType::Extension(Arc::from(apples_u8));
        let result = storage_scalar.cast(&incompatible_dtype);
        let reports_expected_error = match &result {
            Err(err) => err
                .to_string()
                .contains("Can't cast u16 scalar 1000u16 to u8 (cause: Cannot read primitive value U16(1000) as u8"),
            Ok(_) => false,
        };
        assert!(reports_expected_error, "{result:?}");
    }
}