vortex_scalar/
lib.rs

1use std::cmp::Ordering;
2use std::hash::Hash;
3use std::sync::Arc;
4
5pub use scalar_type::ScalarType;
6use vortex_buffer::{BufferString, ByteBuffer};
7use vortex_dtype::half::f16;
8use vortex_dtype::{DType, Nullability};
9#[cfg(feature = "arbitrary")]
10pub mod arbitrary;
11mod arrow;
12mod binary;
13mod bool;
14mod datafusion;
15mod display;
16mod extension;
17mod list;
18mod null;
19mod primitive;
20mod pvalue;
21mod scalar_type;
22mod scalarvalue;
23#[cfg(feature = "serde")]
24mod serde;
25mod struct_;
26mod utf8;
27
28pub use binary::*;
29pub use bool::*;
30pub use extension::*;
31pub use list::*;
32pub use primitive::*;
33pub use pvalue::*;
34pub use scalarvalue::*;
35pub use struct_::*;
36pub use utf8::*;
37use vortex_error::{VortexExpect, VortexResult, vortex_bail};
38
/// A single logical item, composed of both a [`ScalarValue`] and a logical [`DType`].
///
/// A [`ScalarValue`] is opaque, and should be accessed via one of the type-specific scalar
/// wrappers, for example [`BoolScalar`], [`PrimitiveScalar`], etc.
///
/// Note that [`PartialEq`] and [`PartialOrd`] compare the two dtypes with
/// `DType::eq_ignore_nullability`: scalars whose dtypes differ only in nullability are still
/// comparable, while any other dtype mismatch yields `false` / `None`. When the dtypes do match,
/// the comparison is delegated to the type-specific wrapper, which is expected to order nulls
/// first (lowest), then by the natural ordering of the scalar value.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
pub struct Scalar {
    /// Logical type of the scalar (the dtype variants carry the nullability).
    dtype: DType,
    /// Opaque value; the null constructors in this file store `InnerScalarValue::Null` here.
    value: ScalarValue,
}
53
54impl Scalar {
55    pub fn new(dtype: DType, value: ScalarValue) -> Self {
56        Self { dtype, value }
57    }
58
59    #[inline]
60    pub fn dtype(&self) -> &DType {
61        &self.dtype
62    }
63
64    #[inline]
65    pub fn value(&self) -> &ScalarValue {
66        &self.value
67    }
68
69    #[inline]
70    pub fn into_parts(self) -> (DType, ScalarValue) {
71        (self.dtype, self.value)
72    }
73
74    #[inline]
75    pub fn into_value(self) -> ScalarValue {
76        self.value
77    }
78
79    pub fn is_valid(&self) -> bool {
80        !self.value.is_null()
81    }
82
83    pub fn is_null(&self) -> bool {
84        self.value.is_null()
85    }
86
87    pub fn null(dtype: DType) -> Self {
88        assert!(
89            dtype.is_nullable(),
90            "Creating null scalar for non-nullable DType {}",
91            dtype
92        );
93        Self {
94            dtype,
95            value: ScalarValue(InnerScalarValue::Null),
96        }
97    }
98
99    pub fn null_typed<T: ScalarType>() -> Self {
100        Self {
101            dtype: T::dtype().as_nullable(),
102            value: ScalarValue(InnerScalarValue::Null),
103        }
104    }
105
106    pub fn cast(&self, target: &DType) -> VortexResult<Self> {
107        if let DType::Extension(ext_dtype) = target {
108            let storage_scalar = self.cast_to_non_extension(ext_dtype.storage_dtype())?;
109            Ok(Scalar::extension(ext_dtype.clone(), storage_scalar))
110        } else {
111            self.cast_to_non_extension(target)
112        }
113    }
114
115    fn cast_to_non_extension(&self, target: &DType) -> VortexResult<Self> {
116        assert!(!matches!(target, DType::Extension(..)));
117        if self.is_null() {
118            if target.is_nullable() {
119                return Ok(Scalar::new(target.clone(), self.value.clone()));
120            } else {
121                vortex_bail!("Can't cast null scalar to non-nullable type {}", target)
122            }
123        }
124
125        if self.dtype().eq_ignore_nullability(target) {
126            return Ok(Scalar::new(target.clone(), self.value.clone()));
127        }
128
129        match &self.dtype {
130            DType::Null => unreachable!(), // handled by if is_null case
131            DType::Bool(_) => self.as_bool().cast(target),
132            DType::Primitive(..) => self.as_primitive().cast(target),
133            DType::Utf8(_) => self.as_utf8().cast(target),
134            DType::Binary(_) => self.as_binary().cast(target),
135            DType::Struct(..) => self.as_struct().cast(target),
136            DType::List(..) => self.as_list().cast(target),
137            DType::Extension(..) => self.as_extension().cast(target),
138        }
139    }
140
141    pub fn into_nullable(self) -> Self {
142        Self {
143            dtype: self.dtype.as_nullable(),
144            value: self.value,
145        }
146    }
147
148    /// Size of the scalar in bytes, uncompressed.
149    pub fn nbytes(&self) -> usize {
150        match self.dtype() {
151            DType::Null => 0,
152            DType::Bool(_) => 1,
153            DType::Primitive(ptype, _) => ptype.byte_width(),
154            DType::Binary(_) | DType::Utf8(_) => self
155                .value()
156                .as_buffer()
157                .ok()
158                .flatten()
159                .map_or(0, |s| s.len()),
160            DType::Struct(_dtype, _) => self
161                .as_struct()
162                .fields()
163                .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
164                .unwrap_or_default(),
165            DType::List(_dtype, _) => self
166                .as_list()
167                .elements()
168                .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
169                .unwrap_or_default(),
170            DType::Extension(_ext_dtype) => self.as_extension().storage().nbytes(),
171        }
172    }
173}
174
175impl Scalar {
176    pub fn as_bool(&self) -> BoolScalar {
177        BoolScalar::try_from(self).vortex_expect("Failed to convert scalar to bool")
178    }
179
180    pub fn as_bool_opt(&self) -> Option<BoolScalar> {
181        matches!(self.dtype, DType::Bool(..)).then(|| self.as_bool())
182    }
183
184    pub fn as_primitive(&self) -> PrimitiveScalar {
185        PrimitiveScalar::try_from(self).vortex_expect("Failed to convert scalar to primitive")
186    }
187
188    pub fn as_primitive_opt(&self) -> Option<PrimitiveScalar> {
189        matches!(self.dtype, DType::Primitive(..)).then(|| self.as_primitive())
190    }
191
192    pub fn as_utf8(&self) -> Utf8Scalar {
193        Utf8Scalar::try_from(self).vortex_expect("Failed to convert scalar to utf8")
194    }
195
196    pub fn as_utf8_opt(&self) -> Option<Utf8Scalar> {
197        matches!(self.dtype, DType::Utf8(..)).then(|| self.as_utf8())
198    }
199
200    pub fn as_binary(&self) -> BinaryScalar {
201        BinaryScalar::try_from(self).vortex_expect("Failed to convert scalar to binary")
202    }
203
204    pub fn as_binary_opt(&self) -> Option<BinaryScalar> {
205        matches!(self.dtype, DType::Binary(..)).then(|| self.as_binary())
206    }
207
208    pub fn as_struct(&self) -> StructScalar {
209        StructScalar::try_from(self).vortex_expect("Failed to convert scalar to struct")
210    }
211
212    pub fn as_struct_opt(&self) -> Option<StructScalar> {
213        matches!(self.dtype, DType::Struct(..)).then(|| self.as_struct())
214    }
215
216    pub fn as_list(&self) -> ListScalar {
217        ListScalar::try_from(self).vortex_expect("Failed to convert scalar to list")
218    }
219
220    pub fn as_list_opt(&self) -> Option<ListScalar> {
221        matches!(self.dtype, DType::List(..)).then(|| self.as_list())
222    }
223
224    pub fn as_extension(&self) -> ExtScalar {
225        ExtScalar::try_from(self).vortex_expect("Failed to convert scalar to extension")
226    }
227
228    pub fn as_extension_opt(&self) -> Option<ExtScalar> {
229        matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension())
230    }
231}
232
233impl PartialEq for Scalar {
234    fn eq(&self, other: &Self) -> bool {
235        if !self.dtype.eq_ignore_nullability(&other.dtype) {
236            return false;
237        }
238
239        match self.dtype() {
240            DType::Null => true,
241            DType::Bool(_) => self.as_bool() == other.as_bool(),
242            DType::Primitive(..) => self.as_primitive() == other.as_primitive(),
243            DType::Utf8(_) => self.as_utf8() == other.as_utf8(),
244            DType::Binary(_) => self.as_binary() == other.as_binary(),
245            DType::Struct(..) => self.as_struct() == other.as_struct(),
246            DType::List(..) => self.as_list() == other.as_list(),
247            DType::Extension(_) => self.as_extension() == other.as_extension(),
248        }
249    }
250}
251
// Marker impl declaring that `Scalar` equality is a full equivalence relation.
// NOTE(review): this relies on the type-specific wrappers' `==` being reflexive (e.g. for
// floating-point primitives) — not verifiable from this file; confirm.
impl Eq for Scalar {}
253
254impl PartialOrd for Scalar {
255    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
256        if !self.dtype().eq_ignore_nullability(other.dtype()) {
257            return None;
258        }
259        match self.dtype() {
260            DType::Null => Some(Ordering::Equal),
261            DType::Bool(_) => self.as_bool().partial_cmp(&other.as_bool()),
262            DType::Primitive(..) => self.as_primitive().partial_cmp(&other.as_primitive()),
263            DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()),
264            DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()),
265            DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()),
266            DType::List(..) => self.as_list().partial_cmp(&other.as_list()),
267            DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()),
268        }
269    }
270}
271
272impl Hash for Scalar {
273    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
274        match self.dtype() {
275            DType::Null => self.dtype().hash(state), // Hash the dtype instead of the value
276            DType::Bool(_) => self.as_bool().hash(state),
277            DType::Primitive(..) => self.as_primitive().hash(state),
278            DType::Utf8(_) => self.as_utf8().hash(state),
279            DType::Binary(_) => self.as_binary().hash(state),
280            DType::Struct(..) => self.as_struct().hash(state),
281            DType::List(..) => self.as_list().hash(state),
282            DType::Extension(_) => self.as_extension().hash(state),
283        }
284    }
285}
286
287impl AsRef<Self> for Scalar {
288    fn as_ref(&self) -> &Self {
289        self
290    }
291}
292
293impl<T> From<Option<T>> for Scalar
294where
295    T: ScalarType,
296    Scalar: From<T>,
297{
298    fn from(value: Option<T>) -> Self {
299        value
300            .map(Scalar::from)
301            .map(|x| x.into_nullable())
302            .unwrap_or_else(|| Scalar {
303                dtype: T::dtype().as_nullable(),
304                value: ScalarValue(InnerScalarValue::Null),
305            })
306    }
307}
308
309impl From<PrimitiveScalar<'_>> for Scalar {
310    fn from(pscalar: PrimitiveScalar<'_>) -> Self {
311        let dtype = pscalar.dtype().clone();
312        let value = pscalar
313            .pvalue()
314            .map(|pvalue| ScalarValue(InnerScalarValue::Primitive(pvalue)))
315            .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
316        Self::new(dtype, value)
317    }
318}
319
320macro_rules! from_vec_for_scalar {
321    ($T:ty) => {
322        impl From<Vec<$T>> for Scalar {
323            fn from(value: Vec<$T>) -> Self {
324                Scalar {
325                    dtype: DType::List(Arc::from(<$T>::dtype()), Nullability::NonNullable),
326                    value: ScalarValue(InnerScalarValue::List(
327                        value
328                            .into_iter()
329                            .map(Scalar::from)
330                            .map(|s| s.into_value())
331                            .collect::<Arc<[_]>>(),
332                    )),
333                }
334            }
335        }
336    };
337}
338
339// no From<Vec<u8>> because it could either be a List or a Buffer
340from_vec_for_scalar!(u16);
341from_vec_for_scalar!(u32);
342from_vec_for_scalar!(u64);
343from_vec_for_scalar!(usize); // For usize only, we implicitly cast for better ergonomics.
344from_vec_for_scalar!(i8);
345from_vec_for_scalar!(i16);
346from_vec_for_scalar!(i32);
347from_vec_for_scalar!(i64);
348from_vec_for_scalar!(f16);
349from_vec_for_scalar!(f32);
350from_vec_for_scalar!(f64);
351from_vec_for_scalar!(String);
352from_vec_for_scalar!(BufferString);
353from_vec_for_scalar!(ByteBuffer);
354
// Unit tests for `Scalar::cast` across null, list and extension dtypes.
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_dtype::{DType, ExtDType, ExtID, Nullability, PType};

    use crate::{InnerScalarValue, PValue, Scalar, ScalarValue};

    // Every combination of the nullable source and target dtypes listed below is exercised:
    // casting a null scalar must succeed and the result must carry the target dtype.
    #[rstest]
    fn null_can_cast_to_anything_nullable(
        #[values(
            DType::Null,
            DType::Bool(Nullability::Nullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("a"),
                Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
                None,
            ))),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("b"),
                Arc::from(DType::Utf8(Nullability::Nullable)),
                None,
            )))
        )]
        source_dtype: DType,
        #[values(
            DType::Null,
            DType::Bool(Nullability::Nullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("a"),
                Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
                None,
            ))),
            DType::Extension(Arc::from(ExtDType::new(
                ExtID::from("b"),
                Arc::from(DType::Utf8(Nullability::Nullable)),
                None,
            )))
        )]
        target_dtype: DType,
    ) {
        assert_eq!(
            Scalar::null(source_dtype)
                .cast(&target_dtype)
                .unwrap()
                .dtype(),
            &target_dtype
        );
    }

    // Casting list scalars: element type and nullability changes on a list without nulls,
    // then the same on a list that contains a null element.
    #[test]
    fn list_casts() {
        // A nullable list of nullable u16 holding the single element 6.
        let list = Scalar::new(
            DType::List(
                Arc::from(DType::Primitive(PType::U16, Nullability::Nullable)),
                Nullability::Nullable,
            ),
            ScalarValue(InnerScalarValue::List(Arc::from([ScalarValue(
                InnerScalarValue::Primitive(PValue::U16(6)),
            )]))),
        );

        // Widen the element type.
        let target_u32 = DType::List(
            Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
            Nullability::Nullable,
        );
        assert_eq!(list.cast(&target_u32).unwrap().dtype(), &target_u32);

        // Widen the element type and make the elements non-nullable.
        let target_u32_nonnull = DType::List(
            Arc::from(DType::Primitive(PType::U32, Nullability::NonNullable)),
            Nullability::Nullable,
        );
        assert_eq!(
            list.cast(&target_u32_nonnull).unwrap().dtype(),
            &target_u32_nonnull
        );

        // Make the list itself non-nullable.
        let target_nonnull = DType::List(
            Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)),
            Nullability::NonNullable,
        );
        assert_eq!(list.cast(&target_nonnull).unwrap().dtype(), &target_nonnull);

        // Narrow the element type; the value 6 fits in a u8.
        let target_u8 = DType::List(
            Arc::from(DType::Primitive(PType::U8, Nullability::Nullable)),
            Nullability::Nullable,
        );
        assert_eq!(list.cast(&target_u8).unwrap().dtype(), &target_u8);

        // Same source dtype, but now with a null element after the 6.
        let list_with_null = Scalar::new(
            DType::List(
                Arc::from(DType::Primitive(PType::U16, Nullability::Nullable)),
                Nullability::Nullable,
            ),
            ScalarValue(InnerScalarValue::List(Arc::from([
                ScalarValue(InnerScalarValue::Primitive(PValue::U16(6))),
                ScalarValue(InnerScalarValue::Null),
            ]))),
        );
        // Nullable elements: the cast is expected to succeed despite the null element.
        let target_u8 = DType::List(
            Arc::from(DType::Primitive(PType::U8, Nullability::Nullable)),
            Nullability::Nullable,
        );
        assert_eq!(list_with_null.cast(&target_u8).unwrap().dtype(), &target_u8);

        // Non-nullable elements: the cast is expected to fail because of the null element.
        let target_u32_nonnull = DType::List(
            Arc::from(DType::Primitive(PType::U32, Nullability::NonNullable)),
            Nullability::Nullable,
        );
        assert!(list_with_null.cast(&target_u32_nonnull).is_err());
    }

    // Casting between an extension dtype ("apples", stored as non-nullable u16) and its
    // storage dtype, in both directions and with nullability changes.
    #[test]
    fn cast_to_from_extension_types() {
        let apples = ExtDType::new(
            ExtID::new(Arc::from("apples")),
            Arc::from(DType::Primitive(PType::U16, Nullability::NonNullable)),
            None,
        );
        let ext_dtype = DType::Extension(Arc::from(apples.clone()));
        // NOTE(review): the value is `Bool(true)` although the extension's storage dtype is
        // u16; the assertions below only inspect the resulting dtype — confirm this mismatch
        // is intentional.
        let ext_scalar = Scalar::new(ext_dtype.clone(), ScalarValue(InnerScalarValue::Bool(true)));
        let storage_scalar = Scalar::new(
            DType::clone(apples.storage_dtype()),
            ScalarValue(InnerScalarValue::Primitive(PValue::U16(1000))),
        );

        // to self
        let expected_dtype = &ext_dtype;
        let actual = ext_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // to nullable self
        let expected_dtype = &ext_dtype.as_nullable();
        let actual = ext_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // cast to the storage type
        let expected_dtype = apples.storage_dtype();
        let actual = ext_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // cast to the storage type, nullable
        let expected_dtype = &apples.storage_dtype().as_nullable();
        let actual = ext_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // cast from storage type to extension
        let expected_dtype = &ext_dtype;
        let actual = storage_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // cast from storage type to extension, nullable
        let expected_dtype = &ext_dtype.as_nullable();
        let actual = storage_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // cast from *compatible* storage type to extension
        // (the declared dtype is u16 while the physical value is a `PValue::U64`)
        let storage_scalar_u64 = Scalar::new(
            DType::clone(apples.storage_dtype()),
            ScalarValue(InnerScalarValue::Primitive(PValue::U64(1000))),
        );
        let expected_dtype = &ext_dtype;
        let actual = storage_scalar_u64.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // cast from *incompatible* storage type to extension
        // (1000 does not fit into the u8 storage of this second "apples" extension)
        let apples_u8 = ExtDType::new(
            ExtID::new(Arc::from("apples")),
            Arc::from(DType::Primitive(PType::U8, Nullability::NonNullable)),
            None,
        );
        let expected_dtype = &DType::Extension(Arc::from(apples_u8));
        let result = storage_scalar.cast(expected_dtype);
        assert!(
            result.as_ref().is_err_and(|err| {
                err
                    .to_string()
                    .contains("Can't cast u16 scalar 1000_u16 to u8 (cause: Cannot read primitive value U16(1000) as u8")
            }),
            "{:?}",
            result
        );
    }
}
539        );
540    }
541}