vortex_scalar/
lib.rs

1use std::cmp::Ordering;
2use std::hash::Hash;
3use std::sync::Arc;
4
5pub use scalar_type::ScalarType;
6use vortex_buffer::{BufferString, ByteBuffer};
7use vortex_dtype::half::f16;
8use vortex_dtype::{DECIMAL128_MAX_PRECISION, DType, Nullability};
9#[cfg(feature = "arbitrary")]
10pub mod arbitrary;
11mod arrow;
12mod bigint;
13mod binary;
14mod bool;
15mod decimal;
16mod display;
17mod extension;
18mod list;
19mod null;
20mod primitive;
21mod proto;
22mod pvalue;
23mod scalar_type;
24mod scalar_value;
25mod struct_;
26mod utf8;
27
28pub use bigint::*;
29pub use binary::*;
30pub use bool::*;
31pub use decimal::*;
32pub use extension::*;
33pub use list::*;
34pub use primitive::*;
35pub use pvalue::*;
36pub use scalar_value::*;
37pub use struct_::*;
38pub use utf8::*;
39use vortex_error::{VortexExpect, VortexResult, vortex_bail};
40
/// A single logical item, composed of both a [`ScalarValue`] and a logical [`DType`].
///
/// A [`ScalarValue`] is opaque, and should be accessed via one of the type-specific scalar
/// wrappers, for example [`BoolScalar`], [`PrimitiveScalar`], etc.
///
/// Note that [`PartialEq`] and [`PartialOrd`] compare dtypes while *ignoring* nullability: two
/// scalars whose dtypes differ only in nullability are still comparable. For any other dtype
/// mismatch, equality is `false` and `partial_cmp` returns `None`. When the dtypes do match, the
/// comparison is delegated to the type-specific scalar wrapper; the intended ordering is nulls
/// first (lowest), then the natural ordering of the scalar value.
#[derive(Debug, Clone)]
pub struct Scalar {
    /// Logical type of the scalar, including its nullability.
    dtype: DType,
    /// Opaque value; may be the null value.
    value: ScalarValue,
}
54
55impl Scalar {
56    pub fn new(dtype: DType, value: ScalarValue) -> Self {
57        Self { dtype, value }
58    }
59
60    #[inline]
61    pub fn dtype(&self) -> &DType {
62        &self.dtype
63    }
64
65    #[inline]
66    pub fn value(&self) -> &ScalarValue {
67        &self.value
68    }
69
70    #[inline]
71    pub fn into_parts(self) -> (DType, ScalarValue) {
72        (self.dtype, self.value)
73    }
74
75    #[inline]
76    pub fn into_value(self) -> ScalarValue {
77        self.value
78    }
79
80    pub fn is_valid(&self) -> bool {
81        !self.value.is_null()
82    }
83
84    pub fn is_null(&self) -> bool {
85        self.value.is_null()
86    }
87
88    pub fn null(dtype: DType) -> Self {
89        assert!(
90            dtype.is_nullable(),
91            "Creating null scalar for non-nullable DType {dtype}"
92        );
93        Self {
94            dtype,
95            value: ScalarValue(InnerScalarValue::Null),
96        }
97    }
98
99    pub fn null_typed<T: ScalarType>() -> Self {
100        Self {
101            dtype: T::dtype().as_nullable(),
102            value: ScalarValue(InnerScalarValue::Null),
103        }
104    }
105
106    pub fn cast(&self, target: &DType) -> VortexResult<Self> {
107        if let DType::Extension(ext_dtype) = target {
108            let storage_scalar = self.cast_to_non_extension(ext_dtype.storage_dtype())?;
109            Ok(Scalar::extension(ext_dtype.clone(), storage_scalar))
110        } else {
111            self.cast_to_non_extension(target)
112        }
113    }
114
115    fn cast_to_non_extension(&self, target: &DType) -> VortexResult<Self> {
116        assert!(!matches!(target, DType::Extension(..)));
117        if self.is_null() {
118            if target.is_nullable() {
119                return Ok(Scalar::new(target.clone(), self.value.clone()));
120            } else {
121                vortex_bail!("Can't cast null scalar to non-nullable type {}", target)
122            }
123        }
124
125        if self.dtype().eq_ignore_nullability(target) {
126            return Ok(Scalar::new(target.clone(), self.value.clone()));
127        }
128
129        match &self.dtype {
130            DType::Null => unreachable!(), // handled by if is_null case
131            DType::Bool(_) => self.as_bool().cast(target),
132            DType::Primitive(..) => self.as_primitive().cast(target),
133            DType::Decimal(..) => todo!("(aduffy): implement DecimalScalar casting"),
134            DType::Utf8(_) => self.as_utf8().cast(target),
135            DType::Binary(_) => self.as_binary().cast(target),
136            DType::Struct(..) => self.as_struct().cast(target),
137            DType::List(..) => self.as_list().cast(target),
138            DType::Extension(..) => self.as_extension().cast(target),
139        }
140    }
141
142    pub fn into_nullable(self) -> Self {
143        Self {
144            dtype: self.dtype.as_nullable(),
145            value: self.value,
146        }
147    }
148
149    /// Size of the scalar in bytes, uncompressed.
150    pub fn nbytes(&self) -> usize {
151        match self.dtype() {
152            DType::Null => 0,
153            DType::Bool(_) => 1,
154            DType::Primitive(ptype, _) => ptype.byte_width(),
155            DType::Decimal(dt, _) => {
156                if dt.precision() >= DECIMAL128_MAX_PRECISION {
157                    size_of::<i128>()
158                } else {
159                    size_of::<i256>()
160                }
161            }
162            DType::Binary(_) | DType::Utf8(_) => self
163                .value()
164                .as_buffer()
165                .ok()
166                .flatten()
167                .map_or(0, |s| s.len()),
168            DType::Struct(_dtype, _) => self
169                .as_struct()
170                .fields()
171                .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
172                .unwrap_or_default(),
173            DType::List(_dtype, _) => self
174                .as_list()
175                .elements()
176                .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
177                .unwrap_or_default(),
178            DType::Extension(_ext_dtype) => self.as_extension().storage().nbytes(),
179        }
180    }
181}
182
183impl Scalar {
184    pub fn as_bool(&self) -> BoolScalar {
185        BoolScalar::try_from(self).vortex_expect("Failed to convert scalar to bool")
186    }
187
188    pub fn as_bool_opt(&self) -> Option<BoolScalar> {
189        matches!(self.dtype, DType::Bool(..)).then(|| self.as_bool())
190    }
191
192    pub fn as_primitive(&self) -> PrimitiveScalar {
193        PrimitiveScalar::try_from(self).vortex_expect("Failed to convert scalar to primitive")
194    }
195
196    pub fn as_primitive_opt(&self) -> Option<PrimitiveScalar> {
197        matches!(self.dtype, DType::Primitive(..)).then(|| self.as_primitive())
198    }
199
200    pub fn as_decimal(&self) -> DecimalScalar {
201        DecimalScalar::try_from(self).vortex_expect("Failed to convert scalar to decimal")
202    }
203
204    pub fn as_decimal_opt(&self) -> Option<DecimalScalar> {
205        matches!(self.dtype, DType::Decimal(..)).then(|| self.as_decimal())
206    }
207
208    pub fn as_utf8(&self) -> Utf8Scalar {
209        Utf8Scalar::try_from(self).vortex_expect("Failed to convert scalar to utf8")
210    }
211
212    pub fn as_utf8_opt(&self) -> Option<Utf8Scalar> {
213        matches!(self.dtype, DType::Utf8(..)).then(|| self.as_utf8())
214    }
215
216    pub fn as_binary(&self) -> BinaryScalar {
217        BinaryScalar::try_from(self).vortex_expect("Failed to convert scalar to binary")
218    }
219
220    pub fn as_binary_opt(&self) -> Option<BinaryScalar> {
221        matches!(self.dtype, DType::Binary(..)).then(|| self.as_binary())
222    }
223
224    pub fn as_struct(&self) -> StructScalar {
225        StructScalar::try_from(self).vortex_expect("Failed to convert scalar to struct")
226    }
227
228    pub fn as_struct_opt(&self) -> Option<StructScalar> {
229        matches!(self.dtype, DType::Struct(..)).then(|| self.as_struct())
230    }
231
232    pub fn as_list(&self) -> ListScalar {
233        ListScalar::try_from(self).vortex_expect("Failed to convert scalar to list")
234    }
235
236    pub fn as_list_opt(&self) -> Option<ListScalar> {
237        matches!(self.dtype, DType::List(..)).then(|| self.as_list())
238    }
239
240    pub fn as_extension(&self) -> ExtScalar {
241        ExtScalar::try_from(self).vortex_expect("Failed to convert scalar to extension")
242    }
243
244    pub fn as_extension_opt(&self) -> Option<ExtScalar> {
245        matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension())
246    }
247}
248
249impl PartialEq for Scalar {
250    fn eq(&self, other: &Self) -> bool {
251        if !self.dtype.eq_ignore_nullability(&other.dtype) {
252            return false;
253        }
254
255        match self.dtype() {
256            DType::Null => true,
257            DType::Bool(_) => self.as_bool() == other.as_bool(),
258            DType::Primitive(..) => self.as_primitive() == other.as_primitive(),
259            DType::Decimal(..) => self.as_decimal() == other.as_decimal(),
260            DType::Utf8(_) => self.as_utf8() == other.as_utf8(),
261            DType::Binary(_) => self.as_binary() == other.as_binary(),
262            DType::Struct(..) => self.as_struct() == other.as_struct(),
263            DType::List(..) => self.as_list() == other.as_list(),
264            DType::Extension(_) => self.as_extension() == other.as_extension(),
265        }
266    }
267}
268
// Marker impl: declares that the `PartialEq` implementation for `Scalar` is a full equivalence
// relation, which allows `Scalar` to be used where `Eq` is required.
impl Eq for Scalar {}
270
271impl PartialOrd for Scalar {
272    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
273        if !self.dtype().eq_ignore_nullability(other.dtype()) {
274            return None;
275        }
276        match self.dtype() {
277            DType::Null => Some(Ordering::Equal),
278            DType::Bool(_) => self.as_bool().partial_cmp(&other.as_bool()),
279            DType::Primitive(..) => self.as_primitive().partial_cmp(&other.as_primitive()),
280            DType::Decimal(..) => self.as_decimal().partial_cmp(&other.as_decimal()),
281            DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()),
282            DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()),
283            DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()),
284            DType::List(..) => self.as_list().partial_cmp(&other.as_list()),
285            DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()),
286        }
287    }
288}
289
290impl Hash for Scalar {
291    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
292        match self.dtype() {
293            DType::Null => self.dtype().hash(state), // Hash the dtype instead of the value
294            DType::Bool(_) => self.as_bool().hash(state),
295            DType::Primitive(..) => self.as_primitive().hash(state),
296            DType::Decimal(..) => self.as_decimal().hash(state),
297            DType::Utf8(_) => self.as_utf8().hash(state),
298            DType::Binary(_) => self.as_binary().hash(state),
299            DType::Struct(..) => self.as_struct().hash(state),
300            DType::List(..) => self.as_list().hash(state),
301            DType::Extension(_) => self.as_extension().hash(state),
302        }
303    }
304}
305
/// Identity conversion, so that APIs generic over `AsRef<Scalar>` can take a `Scalar` directly.
impl AsRef<Self> for Scalar {
    fn as_ref(&self) -> &Self {
        self
    }
}
311
312impl<T> From<Option<T>> for Scalar
313where
314    T: ScalarType,
315    Scalar: From<T>,
316{
317    fn from(value: Option<T>) -> Self {
318        value
319            .map(Scalar::from)
320            .map(|x| x.into_nullable())
321            .unwrap_or_else(|| Scalar {
322                dtype: T::dtype().as_nullable(),
323                value: ScalarValue(InnerScalarValue::Null),
324            })
325    }
326}
327
328impl From<PrimitiveScalar<'_>> for Scalar {
329    fn from(pscalar: PrimitiveScalar<'_>) -> Self {
330        let dtype = pscalar.dtype().clone();
331        let value = pscalar
332            .pvalue()
333            .map(|pvalue| ScalarValue(InnerScalarValue::Primitive(pvalue)))
334            .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
335        Self::new(dtype, value)
336    }
337}
338
339impl From<DecimalScalar<'_>> for Scalar {
340    fn from(decimal_scalar: DecimalScalar<'_>) -> Self {
341        let dtype = decimal_scalar.dtype().clone();
342        let value = decimal_scalar
343            .decimal_value()
344            .map(|value| ScalarValue(InnerScalarValue::Decimal(value)))
345            .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
346        Self::new(dtype, value)
347    }
348}
349
/// Implements `From<Vec<$T>> for Scalar`.
///
/// The result is a non-nullable list scalar whose element dtype is `<$T>::dtype()` and whose
/// elements are the values of the per-item `Scalar::from` conversions.
macro_rules! from_vec_for_scalar {
    ($T:ty) => {
        impl From<Vec<$T>> for Scalar {
            fn from(value: Vec<$T>) -> Self {
                let element_dtype: Arc<DType> = Arc::from(<$T>::dtype());
                let elements: Arc<[ScalarValue]> = value
                    .into_iter()
                    .map(|item| Scalar::from(item).into_value())
                    .collect();
                Scalar::new(
                    DType::List(element_dtype, Nullability::NonNullable),
                    ScalarValue(InnerScalarValue::List(elements)),
                )
            }
        }
    };
}
368
// Generate `From<Vec<T>> for Scalar` for each supported element type.
// There is deliberately no `From<Vec<u8>>`, because a `Vec<u8>` could mean either a List or a
// Buffer.
from_vec_for_scalar!(u16);
from_vec_for_scalar!(u32);
from_vec_for_scalar!(u64);
// For usize only, we implicitly cast for better ergonomics.
// NOTE(review): the cast itself happens in the `usize` scalar conversion, which is not defined
// in this file — confirm the target type there.
from_vec_for_scalar!(usize);
from_vec_for_scalar!(i8);
from_vec_for_scalar!(i16);
from_vec_for_scalar!(i32);
from_vec_for_scalar!(i64);
from_vec_for_scalar!(f16);
from_vec_for_scalar!(f32);
from_vec_for_scalar!(f64);
from_vec_for_scalar!(String);
from_vec_for_scalar!(BufferString);
from_vec_for_scalar!(ByteBuffer);
384
// Unit tests for `Scalar::cast`: null casts, list casts, and casts to/from extension dtypes.
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_dtype::{DType, ExtDType, ExtID, Nullability, PType};

    use crate::{InnerScalarValue, PValue, Scalar, ScalarValue};

    // Every (source, target) combination of the nullable dtypes listed below is exercised: a
    // null scalar of the source dtype must cast to the target dtype.
    #[rstest]
    fn null_can_cast_to_anything_nullable(
        #[values(
            DType::Null,
            DType::Bool(Nullability::Nullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("a"),
                Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
                None,
            ))),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("b"),
                Arc::from(DType::Utf8(Nullability::Nullable)),
                None,
            )))
        )]
        source_dtype: DType,
        #[values(
            DType::Null,
            DType::Bool(Nullability::Nullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("a"),
                Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
                None,
            ))),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("b"),
                Arc::from(DType::Utf8(Nullability::Nullable)),
                None,
            )))
        )]
        target_dtype: DType,
    ) {
        assert_eq!(
            Scalar::null(source_dtype)
                .cast(&target_dtype)
                .unwrap()
                .dtype(),
            &target_dtype
        );
    }

    // Casting list scalars: element widening/narrowing and nullability changes.
    #[test]
    fn list_casts() {
        // A nullable list of nullable u16 holding a single non-null element.
        let list = Scalar::new(
            DType::List(
                Arc::from(DType::Primitive(PType::U16, Nullability::Nullable)),
                Nullability::Nullable,
            ),
            ScalarValue(InnerScalarValue::List(Arc::from([ScalarValue(
                InnerScalarValue::Primitive(PValue::U16(6)),
            )]))),
        );

        // widen the element type
        let target_u32 = DType::List(
            Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
            Nullability::Nullable,
        );
        assert_eq!(list.cast(&target_u32).unwrap().dtype(), &target_u32);

        // widen and make the elements non-nullable (the list has no null elements)
        let target_u32_nonnull = DType::List(
            Arc::from(DType::Primitive(PType::U32, Nullability::NonNullable)),
            Nullability::Nullable,
        );
        assert_eq!(
            list.cast(&target_u32_nonnull).unwrap().dtype(),
            &target_u32_nonnull
        );

        // make the list itself non-nullable
        let target_nonnull = DType::List(
            Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
            Nullability::NonNullable,
        );
        assert_eq!(list.cast(&target_nonnull).unwrap().dtype(), &target_nonnull);

        // narrow the element type (the value 6 is cast to u8)
        let target_u8 = DType::List(
            Arc::from(DType::Primitive(PType::U8, Nullability::Nullable)),
            Nullability::Nullable,
        );
        assert_eq!(list.cast(&target_u8).unwrap().dtype(), &target_u8);

        // Same list shape, but now with a null element.
        let list_with_null = Scalar::new(
            DType::List(
                Arc::from(DType::Primitive(PType::U16, Nullability::Nullable)),
                Nullability::Nullable,
            ),
            ScalarValue(InnerScalarValue::List(Arc::from([
                ScalarValue(InnerScalarValue::Primitive(PValue::U16(6))),
                ScalarValue(InnerScalarValue::Null),
            ]))),
        );
        // nullable element target: the null element is allowed
        let target_u8 = DType::List(
            Arc::from(DType::Primitive(PType::U8, Nullability::Nullable)),
            Nullability::Nullable,
        );
        assert_eq!(list_with_null.cast(&target_u8).unwrap().dtype(), &target_u8);

        // non-nullable element target: the null element makes the cast fail
        let target_u32_nonnull = DType::List(
            Arc::from(DType::Primitive(PType::U32, Nullability::NonNullable)),
            Nullability::Nullable,
        );
        assert!(list_with_null.cast(&target_u32_nonnull).is_err());
    }

    // Casting between an extension dtype and its storage dtype, in both directions.
    #[test]
    fn cast_to_from_extension_types() {
        let apples = ExtDType::new(
            ExtID::new(Arc::from("apples")),
            Arc::from(DType::Primitive(PType::U16, Nullability::NonNullable)),
            None,
        );
        let ext_dtype = DType::Extension(Arc::from(apples.clone()));
        // NOTE(review): the storage dtype of `apples` is u16 but this value is a Bool. The
        // assertions below only check the resulting dtype, so the mismatch is not exercised —
        // confirm whether a `PValue::U16` value was intended here.
        let ext_scalar = Scalar::new(ext_dtype.clone(), ScalarValue(InnerScalarValue::Bool(true)));
        let storage_scalar = Scalar::new(
            DType::clone(apples.storage_dtype()),
            ScalarValue(InnerScalarValue::Primitive(PValue::U16(1000))),
        );

        // to self
        let expected_dtype = &ext_dtype;
        let actual = ext_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // to nullable self
        let expected_dtype = &ext_dtype.as_nullable();
        let actual = ext_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // cast to the storage type
        let expected_dtype = apples.storage_dtype();
        let actual = ext_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // cast to the storage type, nullable
        let expected_dtype = &apples.storage_dtype().as_nullable();
        let actual = ext_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // cast from storage type to extension
        let expected_dtype = &ext_dtype;
        let actual = storage_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // cast from storage type to extension, nullable
        let expected_dtype = &ext_dtype.as_nullable();
        let actual = storage_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // cast from *compatible* storage type to extension
        // (the dtype is the u16 storage dtype, while the value is stored as a U64)
        let storage_scalar_u64 = Scalar::new(
            DType::clone(apples.storage_dtype()),
            ScalarValue(InnerScalarValue::Primitive(PValue::U64(1000))),
        );
        let expected_dtype = &ext_dtype;
        let actual = storage_scalar_u64.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // cast from *incompatible* storage type to extension
        // (1000 does not fit in the u8 storage of this second extension dtype)
        let apples_u8 = ExtDType::new(
            ExtID::new(Arc::from("apples")),
            Arc::from(DType::Primitive(PType::U8, Nullability::NonNullable)),
            None,
        );
        let expected_dtype = &DType::Extension(Arc::from(apples_u8));
        let result = storage_scalar.cast(expected_dtype);
        assert!(
            result.as_ref().is_err_and(|err| {
                err
                    .to_string()
                    .contains("Can't cast u16 scalar 1000u16 to u8 (cause: Cannot read primitive value U16(1000) as u8")
            }),
            "{result:?}"
        );
    }
}