// vortex_scalar/extension.rs
use std::hash::Hash;
use std::sync::Arc;
use vortex_dtype::{DType, ExtDType};
use vortex_error::{vortex_bail, VortexError, VortexResult};
use crate::{Scalar, ScalarValue};
/// A borrowed view over a scalar whose dtype is an extension type.
///
/// Both fields are references with lifetime `'a`; the view owns neither the
/// extension dtype nor the underlying value.
pub struct ExtScalar<'a> {
/// The extension dtype this scalar is typed as.
ext_dtype: &'a ExtDType,
/// The underlying scalar value, paired with the extension's storage dtype
/// whenever a storage scalar is materialized.
value: &'a ScalarValue,
}
impl PartialEq for ExtScalar<'_> {
    /// Two extension scalars are equal when their extension dtypes are equal
    /// and their storage scalars compare equal.
    fn eq(&self, rhs: &Self) -> bool {
        let same_ext_dtype = self.ext_dtype == rhs.ext_dtype;
        // `&&` short-circuits, so the storage scalars are only built when the
        // extension dtypes already match.
        same_ext_dtype && self.storage() == rhs.storage()
    }
}
// Marker impl: declares that the `PartialEq` comparison on `ExtScalar` is a
// full equivalence relation.
impl Eq for ExtScalar<'_> {}
impl PartialOrd for ExtScalar<'_> {
    /// Orders two extension scalars by their storage scalars.
    ///
    /// Returns `None` when the extension dtypes differ; otherwise the result
    /// is whatever comparing the two storage scalars yields.
    fn partial_cmp(&self, rhs: &Self) -> Option<std::cmp::Ordering> {
        let dtype_mismatch = self.ext_dtype != rhs.ext_dtype;
        if dtype_mismatch {
            None
        } else {
            let lhs_storage = self.storage();
            let rhs_storage = rhs.storage();
            lhs_storage.partial_cmp(&rhs_storage)
        }
    }
}
impl Hash for ExtScalar<'_> {
    /// Feeds the extension dtype, followed by the storage scalar, into `hasher`.
    fn hash<H>(&self, hasher: &mut H)
    where
        H: std::hash::Hasher,
    {
        let storage = self.storage();
        // Order matters for the resulting hash: dtype first, then storage.
        Hash::hash(self.ext_dtype, hasher);
        Hash::hash(&storage, hasher);
    }
}
impl<'a> ExtScalar<'a> {
    /// Builds an extension scalar view from a dtype and a borrowed value.
    ///
    /// # Errors
    ///
    /// Returns an error if `dtype` is not a `DType::Extension`.
    pub fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult<Self> {
        if let DType::Extension(ext_dtype) = dtype {
            return Ok(Self { ext_dtype, value });
        }
        vortex_bail!("Expected extension scalar, found {}", dtype)
    }

    /// Materializes the storage scalar: a new `Scalar` built from a clone of
    /// the extension's storage dtype and a clone of the underlying value.
    pub fn storage(&self) -> Scalar {
        let storage_dtype = self.ext_dtype.storage_dtype().clone();
        let storage_value = self.value.clone();
        Scalar::new(storage_dtype, storage_value)
    }

    /// Casts this extension scalar to `dtype`.
    ///
    /// The cast succeeds, producing a scalar of `dtype` carrying a clone of
    /// the underlying value, when either:
    /// - `dtype` equals the extension's storage dtype, ignoring nullability, or
    /// - `dtype` is an extension dtype equal to this one, ignoring nullability.
    ///
    /// # Errors
    ///
    /// Returns an error when the value is null and `dtype` is not nullable, or
    /// when neither of the conditions above holds.
    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
        // A null value can never be placed into a non-nullable target.
        let null_into_non_nullable = self.value.is_null() && !dtype.is_nullable();

        // Checked in order: storage dtype first, then a matching extension dtype.
        let castable = !null_into_non_nullable
            && (self.ext_dtype.storage_dtype().eq_ignore_nullability(dtype)
                || matches!(
                    dtype,
                    DType::Extension(target) if self.ext_dtype.eq_ignore_nullability(target)
                ));

        if castable {
            return Ok(Scalar::new(dtype.clone(), self.value.clone()));
        }

        vortex_bail!(
            "cannot cast extension dtype with id {} and storage type {} to {}",
            self.ext_dtype.id(),
            self.ext_dtype.storage_dtype(),
            dtype
        );
    }
}
impl<'a> TryFrom<&'a Scalar> for ExtScalar<'a> {
type Error = VortexError;
fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
ExtScalar::try_new(value.dtype(), &value.value)
}
}
impl Scalar {
    /// Wraps `value` as a scalar of the extension type `ext_dtype`.
    ///
    /// The returned scalar carries `value`'s underlying value and has dtype
    /// `DType::Extension(ext_dtype)`. This function does not check that the
    /// dtype of `value` matches the storage dtype of `ext_dtype`; the caller is
    /// responsible for passing a compatible scalar.
    pub fn extension(ext_dtype: Arc<ExtDType>, value: Scalar) -> Self {
        Self {
            dtype: DType::Extension(ext_dtype),
            // `value` is owned here, so move its inner value out instead of
            // cloning it and then dropping the original.
            value: value.value,
        }
    }
}