Skip to main content

vortex_array/scalar/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Scalar casting between [`DType`]s.
5
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_ensure;
10
11use crate::dtype::DType;
12use crate::scalar::Scalar;
13
14impl Scalar {
15    /// Cast this scalar to another data type.
16    ///
17    /// # Errors
18    ///
19    /// Returns an error if the cast is not supported or if a null value is cast to a non-nullable
20    /// type.
21    pub fn cast(&self, target_dtype: &DType) -> VortexResult<Scalar> {
22        // If the types are the same, return a clone.
23        if self.dtype() == target_dtype {
24            return Ok(self.clone());
25        }
26
27        // Check for solely nullability casting.
28        if self.dtype().eq_ignore_nullability(target_dtype) {
29            // Cast from non-nullable to nullable or vice versa.
30            // The `try_new` will handle nullability checks.
31            return Scalar::try_new(target_dtype.clone(), self.value().cloned());
32        }
33
34        // Null can be cast into any nullable type as null.
35        // Note that the `matches` clause is technically unnecessary here, just protective.
36        if self.value().is_none() || matches!(self.dtype(), DType::Null) {
37            vortex_ensure!(
38                target_dtype.is_nullable(),
39                "Cannot cast null to {target_dtype}: target type is non-nullable"
40            );
41
42            return Scalar::try_new(target_dtype.clone(), self.value().cloned());
43        }
44
45        // TODO(connor): This isn't really correct for extension types.
46        // If the target is an extension type, then we want to cast to its storage type.
47        if let Some(ext_dtype) = target_dtype.as_extension_opt() {
48            let cast_storage_scalar_value = self.cast(ext_dtype.storage_dtype())?.into_value();
49            return Scalar::try_new(target_dtype.clone(), cast_storage_scalar_value);
50        }
51
52        match &self.dtype() {
53            DType::Null => unreachable!("Handled by the if case above"),
54            DType::Bool(_) => self.as_bool().cast(target_dtype),
55            DType::Primitive(..) => self.as_primitive().cast(target_dtype),
56            DType::Decimal(..) => self.as_decimal().cast(target_dtype),
57            DType::Utf8(_) => self.as_utf8().cast(target_dtype),
58            DType::Binary(_) => self.as_binary().cast(target_dtype),
59            DType::Struct(..) => self.as_struct().cast(target_dtype),
60            DType::List(..) | DType::FixedSizeList(..) => self.as_list().cast(target_dtype),
61            DType::Extension(..) => self.as_extension().cast(target_dtype),
62            DType::Variant(_) => vortex_bail!("Variant scalars can't be cast to {target_dtype}"),
63        }
64    }
65
66    /// Cast the scalar into a nullable version of its current type.
67    pub fn into_nullable(self) -> Scalar {
68        let (dtype, value) = self.into_parts();
69        Self::try_new(dtype.as_nullable(), value)
70            .vortex_expect("Casting to nullable should always succeed")
71    }
72}