vortex_compute/cast/
dvector.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::Buffer;
5use vortex_dtype::DType;
6use vortex_dtype::DecimalType;
7use vortex_dtype::NativeDecimalType;
8use vortex_dtype::PrecisionScale;
9use vortex_dtype::match_each_decimal_value_type;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_vector::Scalar;
14use vortex_vector::ScalarOps;
15use vortex_vector::Vector;
16use vortex_vector::VectorOps;
17use vortex_vector::decimal::DScalar;
18use vortex_vector::decimal::DVector;
19
20use crate::cast::Cast;
21use crate::cast::try_cast_scalar_common;
22use crate::cast::try_cast_vector_common;
23
24impl<D: NativeDecimalType> Cast for DVector<D> {
25    type Output = Vector;
26
27    /// Casts to Decimal with potentially different precision and native type.
28    fn cast(&self, target_dtype: &DType) -> VortexResult<Vector> {
29        if let Some(result) = try_cast_vector_common(self, target_dtype)? {
30            return Ok(result);
31        }
32
33        let DType::Decimal(ddt, n) = target_dtype else {
34            vortex_bail!("Cannot cast DVector to {}", target_dtype);
35        };
36
37        // Check nullability compatibility
38        if !n.is_nullable() && !self.validity().all_true() {
39            vortex_bail!(
40                "Cannot cast nullable DVector to non-nullable {}",
41                target_dtype
42            );
43        }
44
45        // Scale changes require multiplication/division by powers of 10
46        if ddt.scale() != self.scale() {
47            vortex_bail!(
48                "Casting DVector with scale {} to scale {} not yet implemented",
49                self.scale(),
50                ddt.scale()
51            );
52        }
53
54        // If the precision is the same, it's an identity cast
55        if ddt.precision() == self.precision() {
56            return Ok(self.clone().into());
57        }
58
59        // If the precision is wider, we may need to upcast the underlying type
60        if ddt.precision() > self.precision() {
61            // Need to upcast to a wider type
62            let target_type = DecimalType::smallest_decimal_value_type(ddt);
63            match_each_decimal_value_type!(target_type, |T| {
64                return upcast_dvector::<D, T>(self, ddt.precision());
65            })
66        }
67
68        // TODO(ngates): we need to rebuild the vector as that will validate all values
69        //  fit into the precision / scale.
70        vortex_bail!(
71            "Downcasting DVector from precision {} to {} not yet implemented",
72            self.precision(),
73            ddt.precision()
74        );
75    }
76}
77
78/// Upcast a DVector<D> to DVector<T> where T is wider than D.
79fn upcast_dvector<D: NativeDecimalType, T: NativeDecimalType>(
80    source: &DVector<D>,
81    target_precision: u8,
82) -> VortexResult<Vector> {
83    let target_ps = PrecisionScale::<T>::try_new(target_precision, source.scale())?;
84
85    // Upcast each element using BigCast. This should never fail since T is wider than D.
86    let elements: Buffer<T> = source
87        .elements()
88        .iter()
89        .map(|&v| T::from(v).vortex_expect("upcast should never fail"))
90        .collect();
91
92    let validity = source.validity().clone();
93
94    // SAFETY: We've upcasted from a narrower type, so all values fit.
95    Ok(unsafe { DVector::new_unchecked(target_ps, elements, validity) }.into())
96}
97
98impl<D: NativeDecimalType> Cast for DScalar<D> {
99    type Output = Scalar;
100
101    /// Casts to Decimal (identity with same precision/scale and compatible nullability).
102    fn cast(&self, target_dtype: &DType) -> VortexResult<Scalar> {
103        if let Some(result) = try_cast_scalar_common(self, target_dtype)? {
104            return Ok(result);
105        }
106
107        match target_dtype {
108            // Identity cast: same precision, scale, and compatible nullability.
109            DType::Decimal(ddt, n)
110                if ddt.precision() == self.precision()
111                    && ddt.scale() == self.scale()
112                    && (n.is_nullable() || self.is_valid()) =>
113            {
114                Ok(self.clone().into())
115            }
116            // TODO(connor): cast to different precision/scale
117            DType::Decimal(..) => {
118                vortex_bail!(
119                    "Casting DScalar to {} with different precision/scale not yet implemented",
120                    target_dtype
121                );
122            }
123            _ => {
124                vortex_bail!("Cannot cast DScalar to {}", target_dtype);
125            }
126        }
127    }
128}