vortex_scalar/
extension.rs

1use std::hash::Hash;
2use std::sync::Arc;
3
4use vortex_dtype::{DType, ExtDType};
5use vortex_error::{VortexError, VortexResult, vortex_bail};
6
7use crate::{Scalar, ScalarValue};
8
9pub struct ExtScalar<'a> {
10    ext_dtype: &'a ExtDType,
11    value: &'a ScalarValue,
12}
13
14impl PartialEq for ExtScalar<'_> {
15    fn eq(&self, other: &Self) -> bool {
16        self.ext_dtype.eq_ignore_nullability(other.ext_dtype) && self.storage() == other.storage()
17    }
18}
19
20impl Eq for ExtScalar<'_> {}
21
22// Ord is not implemented since it's undefined for different Extension DTypes
23impl PartialOrd for ExtScalar<'_> {
24    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
25        if !self.ext_dtype.eq_ignore_nullability(other.ext_dtype) {
26            return None;
27        }
28        self.storage().partial_cmp(&other.storage())
29    }
30}
31
32impl Hash for ExtScalar<'_> {
33    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
34        self.ext_dtype.hash(state);
35        self.storage().hash(state);
36    }
37}
38
39impl<'a> ExtScalar<'a> {
40    pub fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
41        let DType::Extension(ext_dtype) = dtype else {
42            vortex_bail!("Expected extension scalar, found {}", dtype)
43        };
44
45        Ok(Self { ext_dtype, value })
46    }
47
48    /// Returns the storage scalar of the extension scalar.
49    pub fn storage(&self) -> Scalar {
50        Scalar::new(self.ext_dtype.storage_dtype().clone(), self.value.clone())
51    }
52
53    pub fn ext_dtype(&self) -> &'a ExtDType {
54        self.ext_dtype
55    }
56
57    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
58        if self.value.is_null() && !dtype.is_nullable() {
59            vortex_bail!(
60                "cannot cast extension dtype with id {} and storage type {} to {}",
61                self.ext_dtype.id(),
62                self.ext_dtype.storage_dtype(),
63                dtype
64            );
65        }
66
67        if self.ext_dtype.storage_dtype().eq_ignore_nullability(dtype) {
68            // Casting from an extension type to the underlying storage type is OK.
69            return Ok(Scalar::new(dtype.clone(), self.value.clone()));
70        }
71
72        if let DType::Extension(ext_dtype) = dtype {
73            if self.ext_dtype.eq_ignore_nullability(ext_dtype) {
74                return Ok(Scalar::new(dtype.clone(), self.value.clone()));
75            }
76        }
77
78        vortex_bail!(
79            "cannot cast extension dtype with id {} and storage type {} to {}",
80            self.ext_dtype.id(),
81            self.ext_dtype.storage_dtype(),
82            dtype
83        );
84    }
85}
86
87impl<'a> TryFrom<&'a Scalar> for ExtScalar<'a> {
88    type Error = VortexError;
89
90    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
91        ExtScalar::try_new(value.dtype(), &value.value)
92    }
93}
94
95impl Scalar {
96    pub fn extension(ext_dtype: Arc<ExtDType>, value: Scalar) -> Self {
97        Self {
98            dtype: DType::Extension(ext_dtype),
99            value: value.value().clone(),
100        }
101    }
102}