vortex_scalar/
extension.rs

1use std::fmt::{Display, Formatter};
2use std::hash::Hash;
3use std::sync::Arc;
4
5use vortex_dtype::datetime::{TemporalMetadata, is_temporal_ext_type};
6use vortex_dtype::{DType, ExtDType};
7use vortex_error::{VortexError, VortexResult, vortex_bail};
8
9use crate::{Scalar, ScalarValue};
10
11pub struct ExtScalar<'a> {
12    ext_dtype: &'a ExtDType,
13    value: &'a ScalarValue,
14}
15
16impl Display for ExtScalar<'_> {
17    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
18        // Specialized handling for date/time/timestamp builtin extension types.
19        if is_temporal_ext_type(self.ext_dtype().id()) {
20            let metadata =
21                TemporalMetadata::try_from(self.ext_dtype()).map_err(|_| std::fmt::Error)?;
22
23            let maybe_timestamp = self
24                .storage()
25                .as_primitive()
26                .as_::<i64>()
27                .and_then(|maybe_timestamp| {
28                    maybe_timestamp.map(|v| metadata.to_jiff(v)).transpose()
29                })
30                .map_err(|_| std::fmt::Error)?;
31
32            match maybe_timestamp {
33                None => write!(f, "null"),
34                Some(v) => write!(f, "{v}"),
35            }
36        } else {
37            write!(f, "{}({})", self.ext_dtype().id(), self.storage())
38        }
39    }
40}
41
42impl PartialEq for ExtScalar<'_> {
43    fn eq(&self, other: &Self) -> bool {
44        self.ext_dtype.eq_ignore_nullability(other.ext_dtype) && self.storage() == other.storage()
45    }
46}
47
48impl Eq for ExtScalar<'_> {}
49
50// Ord is not implemented since it's undefined for different Extension DTypes
51impl PartialOrd for ExtScalar<'_> {
52    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
53        if !self.ext_dtype.eq_ignore_nullability(other.ext_dtype) {
54            return None;
55        }
56        self.storage().partial_cmp(&other.storage())
57    }
58}
59
60impl Hash for ExtScalar<'_> {
61    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
62        self.ext_dtype.hash(state);
63        self.storage().hash(state);
64    }
65}
66
67impl<'a> ExtScalar<'a> {
68    pub fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
69        let DType::Extension(ext_dtype) = dtype else {
70            vortex_bail!("Expected extension scalar, found {}", dtype)
71        };
72
73        Ok(Self { ext_dtype, value })
74    }
75
76    /// Returns the storage scalar of the extension scalar.
77    pub fn storage(&self) -> Scalar {
78        Scalar::new(self.ext_dtype.storage_dtype().clone(), self.value.clone())
79    }
80
81    pub fn ext_dtype(&self) -> &'a ExtDType {
82        self.ext_dtype
83    }
84
85    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
86        if self.value.is_null() && !dtype.is_nullable() {
87            vortex_bail!(
88                "cannot cast extension dtype with id {} and storage type {} to {}",
89                self.ext_dtype.id(),
90                self.ext_dtype.storage_dtype(),
91                dtype
92            );
93        }
94
95        if self.ext_dtype.storage_dtype().eq_ignore_nullability(dtype) {
96            // Casting from an extension type to the underlying storage type is OK.
97            return Ok(Scalar::new(dtype.clone(), self.value.clone()));
98        }
99
100        if let DType::Extension(ext_dtype) = dtype {
101            if self.ext_dtype.eq_ignore_nullability(ext_dtype) {
102                return Ok(Scalar::new(dtype.clone(), self.value.clone()));
103            }
104        }
105
106        vortex_bail!(
107            "cannot cast extension dtype with id {} and storage type {} to {}",
108            self.ext_dtype.id(),
109            self.ext_dtype.storage_dtype(),
110            dtype
111        );
112    }
113}
114
115impl<'a> TryFrom<&'a Scalar> for ExtScalar<'a> {
116    type Error = VortexError;
117
118    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
119        ExtScalar::try_new(value.dtype(), &value.value)
120    }
121}
122
123impl Scalar {
124    pub fn extension(ext_dtype: Arc<ExtDType>, value: Scalar) -> Self {
125        Self {
126            dtype: DType::Extension(ext_dtype),
127            value: value.value().clone(),
128        }
129    }
130}