vortex_scalar/
lib.rs

1use std::cmp::Ordering;
2use std::hash::Hash;
3use std::sync::Arc;
4
5pub use scalar_type::ScalarType;
6use vortex_buffer::{Buffer, BufferString, ByteBuffer};
7use vortex_dtype::half::f16;
8use vortex_dtype::{DECIMAL128_MAX_PRECISION, DType, Nullability};
9#[cfg(feature = "arbitrary")]
10pub mod arbitrary;
11mod arrow;
12mod bigint;
13mod binary;
14mod bool;
15mod decimal;
16mod display;
17mod extension;
18mod list;
19mod null;
20mod primitive;
21mod proto;
22mod pvalue;
23mod scalar_type;
24mod scalar_value;
25mod struct_;
26mod utf8;
27
28pub use bigint::*;
29pub use binary::*;
30pub use bool::*;
31pub use decimal::*;
32pub use extension::*;
33pub use list::*;
34pub use primitive::*;
35pub use pvalue::*;
36pub use scalar_value::*;
37pub use struct_::*;
38pub use utf8::*;
39use vortex_error::{VortexExpect, VortexResult, vortex_bail};
40
41/// A single logical item, composed of both a [`ScalarValue`] and a logical [`DType`].
42///
43/// A [`ScalarValue`] is opaque, and should be accessed via one of the type-specific scalar wrappers
44/// for example [`BoolScalar`], [`PrimitiveScalar`], etc.
45///
46/// Note that [`PartialOrd`] is implemented only for an exact match of the scalar's dtype,
47/// including nullability. When the DType does match, ordering is nulls first (lowest), then the
48/// natural ordering of the scalar value.
49#[derive(Debug, Clone)]
50pub struct Scalar {
51    dtype: DType,
52    value: ScalarValue,
53}
54
55impl Scalar {
56    pub fn new(dtype: DType, value: ScalarValue) -> Self {
57        Self { dtype, value }
58    }
59
60    #[inline]
61    pub fn dtype(&self) -> &DType {
62        &self.dtype
63    }
64
65    #[inline]
66    pub fn value(&self) -> &ScalarValue {
67        &self.value
68    }
69
70    #[inline]
71    pub fn into_parts(self) -> (DType, ScalarValue) {
72        (self.dtype, self.value)
73    }
74
75    #[inline]
76    pub fn into_value(self) -> ScalarValue {
77        self.value
78    }
79
80    pub fn is_valid(&self) -> bool {
81        !self.value.is_null()
82    }
83
84    pub fn is_null(&self) -> bool {
85        self.value.is_null()
86    }
87
88    pub fn null(dtype: DType) -> Self {
89        assert!(
90            dtype.is_nullable(),
91            "Creating null scalar for non-nullable DType {dtype}"
92        );
93        Self {
94            dtype,
95            value: ScalarValue(InnerScalarValue::Null),
96        }
97    }
98
99    pub fn null_typed<T: ScalarType>() -> Self {
100        Self {
101            dtype: T::dtype().as_nullable(),
102            value: ScalarValue(InnerScalarValue::Null),
103        }
104    }
105
106    pub fn cast(&self, target: &DType) -> VortexResult<Self> {
107        if let DType::Extension(ext_dtype) = target {
108            let storage_scalar = self.cast_to_non_extension(ext_dtype.storage_dtype())?;
109            Ok(Scalar::extension(ext_dtype.clone(), storage_scalar))
110        } else {
111            self.cast_to_non_extension(target)
112        }
113    }
114
115    fn cast_to_non_extension(&self, target: &DType) -> VortexResult<Self> {
116        assert!(!matches!(target, DType::Extension(..)));
117        if self.is_null() {
118            if target.is_nullable() {
119                return Ok(Scalar::new(target.clone(), self.value.clone()));
120            } else {
121                vortex_bail!("Can't cast null scalar to non-nullable type {}", target)
122            }
123        }
124
125        if self.dtype().eq_ignore_nullability(target) {
126            return Ok(Scalar::new(target.clone(), self.value.clone()));
127        }
128
129        match &self.dtype {
130            DType::Null => unreachable!(), // handled by if is_null case
131            DType::Bool(_) => self.as_bool().cast(target),
132            DType::Primitive(..) => self.as_primitive().cast(target),
133            DType::Decimal(..) => todo!("(aduffy): implement DecimalScalar casting"),
134            DType::Utf8(_) => self.as_utf8().cast(target),
135            DType::Binary(_) => self.as_binary().cast(target),
136            DType::Struct(..) => self.as_struct().cast(target),
137            DType::List(..) => self.as_list().cast(target),
138            DType::Extension(..) => self.as_extension().cast(target),
139        }
140    }
141
142    pub fn into_nullable(self) -> Self {
143        Self {
144            dtype: self.dtype.as_nullable(),
145            value: self.value,
146        }
147    }
148
149    /// Size of the scalar in bytes, uncompressed.
150    pub fn nbytes(&self) -> usize {
151        match self.dtype() {
152            DType::Null => 0,
153            DType::Bool(_) => 1,
154            DType::Primitive(ptype, _) => ptype.byte_width(),
155            DType::Decimal(dt, _) => {
156                if dt.precision() >= DECIMAL128_MAX_PRECISION {
157                    size_of::<i128>()
158                } else {
159                    size_of::<i256>()
160                }
161            }
162            DType::Binary(_) | DType::Utf8(_) => self
163                .value()
164                .as_buffer()
165                .ok()
166                .flatten()
167                .map_or(0, |s| s.len()),
168            DType::Struct(_dtype, _) => self
169                .as_struct()
170                .fields()
171                .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
172                .unwrap_or_default(),
173            DType::List(_dtype, _) => self
174                .as_list()
175                .elements()
176                .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
177                .unwrap_or_default(),
178            DType::Extension(_ext_dtype) => self.as_extension().storage().nbytes(),
179        }
180    }
181
182    /// Create a "default" scalar value for the given data type.
183    pub fn default_value(dtype: DType) -> Self {
184        if dtype.is_nullable() {
185            return Self::null(dtype);
186        }
187
188        match dtype {
189            DType::Null => Self::null(dtype),
190            DType::Bool(nullability) => Self::bool(false, nullability),
191            DType::Primitive(pt, nullability) => {
192                Self::primitive_value(PValue::zero(pt), pt, nullability)
193            }
194            DType::Decimal(dt, nullability) => {
195                Self::decimal(DecimalValue::from(0), dt, nullability)
196            }
197            DType::Utf8(nullability) => Self::utf8("", nullability),
198            DType::Binary(nullability) => Self::binary(Buffer::empty(), nullability),
199            DType::Struct(sf, nullability) => {
200                let fields: Vec<_> = sf.fields().map(Scalar::default_value).collect();
201                Self::struct_(DType::Struct(sf, nullability), fields)
202            }
203            DType::List(dt, nullability) => Self::list(dt, vec![], nullability),
204            DType::Extension(dt) => {
205                let scalar = Self::default_value(dt.storage_dtype().clone());
206                Self::extension(dt, scalar)
207            }
208        }
209    }
210}
211
212impl Scalar {
213    pub fn as_bool(&self) -> BoolScalar<'_> {
214        BoolScalar::try_from(self).vortex_expect("Failed to convert scalar to bool")
215    }
216
217    pub fn as_bool_opt(&self) -> Option<BoolScalar<'_>> {
218        matches!(self.dtype, DType::Bool(..)).then(|| self.as_bool())
219    }
220
221    pub fn as_primitive(&self) -> PrimitiveScalar<'_> {
222        PrimitiveScalar::try_from(self).vortex_expect("Failed to convert scalar to primitive")
223    }
224
225    pub fn as_primitive_opt(&self) -> Option<PrimitiveScalar<'_>> {
226        matches!(self.dtype, DType::Primitive(..)).then(|| self.as_primitive())
227    }
228
229    pub fn as_decimal(&self) -> DecimalScalar<'_> {
230        DecimalScalar::try_from(self).vortex_expect("Failed to convert scalar to decimal")
231    }
232
233    pub fn as_decimal_opt(&self) -> Option<DecimalScalar<'_>> {
234        matches!(self.dtype, DType::Decimal(..)).then(|| self.as_decimal())
235    }
236
237    pub fn as_utf8(&self) -> Utf8Scalar<'_> {
238        Utf8Scalar::try_from(self).vortex_expect("Failed to convert scalar to utf8")
239    }
240
241    pub fn as_utf8_opt(&self) -> Option<Utf8Scalar<'_>> {
242        matches!(self.dtype, DType::Utf8(..)).then(|| self.as_utf8())
243    }
244
245    pub fn as_binary(&self) -> BinaryScalar<'_> {
246        BinaryScalar::try_from(self).vortex_expect("Failed to convert scalar to binary")
247    }
248
249    pub fn as_binary_opt(&self) -> Option<BinaryScalar<'_>> {
250        matches!(self.dtype, DType::Binary(..)).then(|| self.as_binary())
251    }
252
253    pub fn as_struct(&self) -> StructScalar<'_> {
254        StructScalar::try_from(self).vortex_expect("Failed to convert scalar to struct")
255    }
256
257    pub fn as_struct_opt(&self) -> Option<StructScalar<'_>> {
258        matches!(self.dtype, DType::Struct(..)).then(|| self.as_struct())
259    }
260
261    pub fn as_list(&self) -> ListScalar<'_> {
262        ListScalar::try_from(self).vortex_expect("Failed to convert scalar to list")
263    }
264
265    pub fn as_list_opt(&self) -> Option<ListScalar<'_>> {
266        matches!(self.dtype, DType::List(..)).then(|| self.as_list())
267    }
268
269    pub fn as_extension(&self) -> ExtScalar<'_> {
270        ExtScalar::try_from(self).vortex_expect("Failed to convert scalar to extension")
271    }
272
273    pub fn as_extension_opt(&self) -> Option<ExtScalar<'_>> {
274        matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension())
275    }
276}
277
278impl PartialEq for Scalar {
279    fn eq(&self, other: &Self) -> bool {
280        if !self.dtype.eq_ignore_nullability(&other.dtype) {
281            return false;
282        }
283
284        match self.dtype() {
285            DType::Null => true,
286            DType::Bool(_) => self.as_bool() == other.as_bool(),
287            DType::Primitive(..) => self.as_primitive() == other.as_primitive(),
288            DType::Decimal(..) => self.as_decimal() == other.as_decimal(),
289            DType::Utf8(_) => self.as_utf8() == other.as_utf8(),
290            DType::Binary(_) => self.as_binary() == other.as_binary(),
291            DType::Struct(..) => self.as_struct() == other.as_struct(),
292            DType::List(..) => self.as_list() == other.as_list(),
293            DType::Extension(_) => self.as_extension() == other.as_extension(),
294        }
295    }
296}
297
298impl Eq for Scalar {}
299
300impl PartialOrd for Scalar {
301    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
302        if !self.dtype().eq_ignore_nullability(other.dtype()) {
303            return None;
304        }
305        match self.dtype() {
306            DType::Null => Some(Ordering::Equal),
307            DType::Bool(_) => self.as_bool().partial_cmp(&other.as_bool()),
308            DType::Primitive(..) => self.as_primitive().partial_cmp(&other.as_primitive()),
309            DType::Decimal(..) => self.as_decimal().partial_cmp(&other.as_decimal()),
310            DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()),
311            DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()),
312            DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()),
313            DType::List(..) => self.as_list().partial_cmp(&other.as_list()),
314            DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()),
315        }
316    }
317}
318
319impl Hash for Scalar {
320    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
321        match self.dtype() {
322            DType::Null => self.dtype().hash(state), // Hash the dtype instead of the value
323            DType::Bool(_) => self.as_bool().hash(state),
324            DType::Primitive(..) => self.as_primitive().hash(state),
325            DType::Decimal(..) => self.as_decimal().hash(state),
326            DType::Utf8(_) => self.as_utf8().hash(state),
327            DType::Binary(_) => self.as_binary().hash(state),
328            DType::Struct(..) => self.as_struct().hash(state),
329            DType::List(..) => self.as_list().hash(state),
330            DType::Extension(_) => self.as_extension().hash(state),
331        }
332    }
333}
334
335impl AsRef<Self> for Scalar {
336    fn as_ref(&self) -> &Self {
337        self
338    }
339}
340
341impl<T> From<Option<T>> for Scalar
342where
343    T: ScalarType,
344    Scalar: From<T>,
345{
346    fn from(value: Option<T>) -> Self {
347        value
348            .map(Scalar::from)
349            .map(|x| x.into_nullable())
350            .unwrap_or_else(|| Scalar {
351                dtype: T::dtype().as_nullable(),
352                value: ScalarValue(InnerScalarValue::Null),
353            })
354    }
355}
356
357impl From<PrimitiveScalar<'_>> for Scalar {
358    fn from(pscalar: PrimitiveScalar<'_>) -> Self {
359        let dtype = pscalar.dtype().clone();
360        let value = pscalar
361            .pvalue()
362            .map(|pvalue| ScalarValue(InnerScalarValue::Primitive(pvalue)))
363            .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
364        Self::new(dtype, value)
365    }
366}
367
368impl From<DecimalScalar<'_>> for Scalar {
369    fn from(decimal_scalar: DecimalScalar<'_>) -> Self {
370        let dtype = decimal_scalar.dtype().clone();
371        let value = decimal_scalar
372            .decimal_value()
373            .map(|value| ScalarValue(InnerScalarValue::Decimal(value)))
374            .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
375        Self::new(dtype, value)
376    }
377}
378
379macro_rules! from_vec_for_scalar {
380    ($T:ty) => {
381        impl From<Vec<$T>> for Scalar {
382            fn from(value: Vec<$T>) -> Self {
383                Scalar {
384                    dtype: DType::List(Arc::from(<$T>::dtype()), Nullability::NonNullable),
385                    value: ScalarValue(InnerScalarValue::List(
386                        value
387                            .into_iter()
388                            .map(Scalar::from)
389                            .map(|s| s.into_value())
390                            .collect::<Arc<[_]>>(),
391                    )),
392                }
393            }
394        }
395    };
396}
397
398// no From<Vec<u8>> because it could either be a List or a Buffer
399from_vec_for_scalar!(u16);
400from_vec_for_scalar!(u32);
401from_vec_for_scalar!(u64);
402from_vec_for_scalar!(usize); // For usize only, we implicitly cast for better ergonomics.
403from_vec_for_scalar!(i8);
404from_vec_for_scalar!(i16);
405from_vec_for_scalar!(i32);
406from_vec_for_scalar!(i64);
407from_vec_for_scalar!(f16);
408from_vec_for_scalar!(f32);
409from_vec_for_scalar!(f64);
410from_vec_for_scalar!(String);
411from_vec_for_scalar!(BufferString);
412from_vec_for_scalar!(ByteBuffer);
413
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_dtype::{DType, ExtDType, ExtID, Nullability, PType};

    use crate::{InnerScalarValue, PValue, Scalar, ScalarValue};

    /// Builds a list dtype over a primitive element type.
    fn primitive_list(ptype: PType, element: Nullability, list: Nullability) -> DType {
        DType::List(Arc::from(DType::Primitive(ptype, element)), list)
    }

    /// Asserts that casting `scalar` to `target` succeeds and yields exactly the `target` dtype.
    fn assert_casts_to(scalar: &Scalar, target: &DType) {
        let casted = scalar.cast(target).unwrap();
        assert_eq!(casted.dtype(), target);
    }

    // A null scalar of any nullable dtype must be castable to every other nullable dtype.
    #[rstest]
    fn null_can_cast_to_anything_nullable(
        #[values(
            DType::Null,
            DType::Bool(Nullability::Nullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("a"),
                Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
                None,
            ))),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("b"),
                Arc::from(DType::Utf8(Nullability::Nullable)),
                None,
            )))
        )]
        source_dtype: DType,
        #[values(
            DType::Null,
            DType::Bool(Nullability::Nullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("a"),
                Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
                None,
            ))),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("b"),
                Arc::from(DType::Utf8(Nullability::Nullable)),
                None,
            )))
        )]
        target_dtype: DType,
    ) {
        let null_scalar = Scalar::null(source_dtype);
        let casted = null_scalar.cast(&target_dtype).unwrap();
        assert_eq!(casted.dtype(), &target_dtype);
    }

    #[test]
    fn list_casts() {
        let source_dtype = || {
            primitive_list(PType::U16, Nullability::Nullable, Nullability::Nullable)
        };

        // A single non-null u16 element.
        let list = Scalar::new(
            source_dtype(),
            ScalarValue(InnerScalarValue::List(Arc::from([ScalarValue(
                InnerScalarValue::Primitive(PValue::U16(6)),
            )]))),
        );

        let target_u32 = primitive_list(PType::U32, Nullability::Nullable, Nullability::Nullable);
        let target_u32_nonnull =
            primitive_list(PType::U32, Nullability::NonNullable, Nullability::Nullable);
        let target_nonnull =
            primitive_list(PType::U32, Nullability::Nullable, Nullability::NonNullable);
        let target_u8 = primitive_list(PType::U8, Nullability::Nullable, Nullability::Nullable);

        assert_casts_to(&list, &target_u32);
        assert_casts_to(&list, &target_u32_nonnull);
        assert_casts_to(&list, &target_nonnull);
        assert_casts_to(&list, &target_u8);

        // A list that also contains a null element.
        let list_with_null = Scalar::new(
            source_dtype(),
            ScalarValue(InnerScalarValue::List(Arc::from([
                ScalarValue(InnerScalarValue::Primitive(PValue::U16(6))),
                ScalarValue(InnerScalarValue::Null),
            ]))),
        );
        assert_casts_to(&list_with_null, &target_u8);

        // The null element cannot be represented with non-nullable elements.
        assert!(list_with_null.cast(&target_u32_nonnull).is_err());
    }

    #[test]
    fn cast_to_from_extension_types() {
        let apples = ExtDType::new(
            ExtID::new(Arc::from("apples")),
            Arc::from(DType::Primitive(PType::U16, Nullability::NonNullable)),
            None,
        );
        let ext_dtype = DType::Extension(Arc::from(apples.clone()));
        let nullable_ext_dtype = ext_dtype.as_nullable();
        let storage_dtype = apples.storage_dtype().clone();
        let nullable_storage_dtype = apples.storage_dtype().as_nullable();

        let ext_scalar = Scalar::new(ext_dtype.clone(), ScalarValue(InnerScalarValue::Bool(true)));
        let storage_scalar = Scalar::new(
            storage_dtype.clone(),
            ScalarValue(InnerScalarValue::Primitive(PValue::U16(1000))),
        );

        // to self
        assert_casts_to(&ext_scalar, &ext_dtype);

        // to nullable self
        assert_casts_to(&ext_scalar, &nullable_ext_dtype);

        // cast to the storage type
        assert_casts_to(&ext_scalar, &storage_dtype);

        // cast to the storage type, nullable
        assert_casts_to(&ext_scalar, &nullable_storage_dtype);

        // cast from storage type to extension
        assert_casts_to(&storage_scalar, &ext_dtype);

        // cast from storage type to extension, nullable
        assert_casts_to(&storage_scalar, &nullable_ext_dtype);

        // cast from *compatible* storage type to extension
        let storage_scalar_u64 = Scalar::new(
            storage_dtype.clone(),
            ScalarValue(InnerScalarValue::Primitive(PValue::U64(1000))),
        );
        assert_casts_to(&storage_scalar_u64, &ext_dtype);

        // cast from *incompatible* storage type to extension
        let apples_u8 = ExtDType::new(
            ExtID::new(Arc::from("apples")),
            Arc::from(DType::Primitive(PType::U8, Nullability::NonNullable)),
            None,
        );
        let incompatible_dtype = DType::Extension(Arc::from(apples_u8));
        let result = storage_scalar.cast(&incompatible_dtype);
        let reports_cause = result.as_ref().is_err_and(|err| {
            err
                .to_string()
                .contains("Can't cast u16 scalar 1000u16 to u8 (cause: Cannot read primitive value U16(1000) as u8")
        });
        assert!(reports_cause, "{result:?}");
    }

    #[test]
    fn default_value_for_complex_dtype() {
        let struct_dtype = DType::struct_(
            [
                ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
                (
                    "b",
                    DType::list(
                        DType::Primitive(PType::I8, Nullability::Nullable),
                        Nullability::NonNullable,
                    ),
                ),
                ("c", DType::Primitive(PType::I32, Nullability::Nullable)),
            ],
            Nullability::NonNullable,
        );

        let default = Scalar::default_value(struct_dtype.clone());
        assert_eq!(default.dtype(), &struct_dtype);

        let fields = default.as_struct();

        // Non-nullable primitive defaults to zero.
        let a_field = fields.field("a").unwrap();
        assert_eq!(a_field.as_primitive().pvalue().unwrap(), PValue::I32(0));

        // Non-nullable list defaults to the empty list.
        let b_field = fields.field("b").unwrap();
        assert!(b_field.as_list().is_empty());

        // Nullable field defaults to null.
        let c_field = fields.field("c").unwrap();
        assert!(c_field.is_null());
    }
}