vortex_vector/decimal/
scalar.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DecimalDType;
5use vortex_dtype::DecimalType;
6use vortex_dtype::DecimalTypeDowncast;
7use vortex_dtype::DecimalTypeUpcast;
8use vortex_dtype::NativeDecimalType;
9use vortex_dtype::PrecisionScale;
10use vortex_dtype::i256;
11use vortex_dtype::match_each_decimal_value_type;
12use vortex_error::VortexExpect;
13use vortex_error::vortex_panic;
14
15use crate::Scalar;
16use crate::ScalarOps;
17use crate::VectorMut;
18use crate::VectorMutOps;
19use crate::decimal::DVectorMut;
20
21/// Represents a decimal scalar value.
22#[derive(Clone, Debug, PartialEq, Eq)]
23pub enum DecimalScalar {
24    /// 8-bit decimal scalar.
25    D8(DScalar<i8>),
26    /// 16-bit decimal scalar.
27    D16(DScalar<i16>),
28    /// 32-bit decimal scalar.
29    D32(DScalar<i32>),
30    /// 64-bit decimal scalar.
31    D64(DScalar<i64>),
32    /// 128-bit decimal scalar.
33    D128(DScalar<i128>),
34    /// 256-bit decimal scalar.
35    D256(DScalar<i256>),
36}
37
38impl DecimalScalar {
39    /// Returns the precision of the decimal scalar.
40    pub fn precision(&self) -> u8 {
41        match self {
42            DecimalScalar::D8(v) => v.ps.precision(),
43            DecimalScalar::D16(v) => v.ps.precision(),
44            DecimalScalar::D32(v) => v.ps.precision(),
45            DecimalScalar::D64(v) => v.ps.precision(),
46            DecimalScalar::D128(v) => v.ps.precision(),
47            DecimalScalar::D256(v) => v.ps.precision(),
48        }
49    }
50
51    /// Returns the scale of the decimal scalar.
52    pub fn scale(&self) -> i8 {
53        match self {
54            DecimalScalar::D8(v) => v.ps.scale(),
55            DecimalScalar::D16(v) => v.ps.scale(),
56            DecimalScalar::D32(v) => v.ps.scale(),
57            DecimalScalar::D64(v) => v.ps.scale(),
58            DecimalScalar::D128(v) => v.ps.scale(),
59            DecimalScalar::D256(v) => v.ps.scale(),
60        }
61    }
62
63    /// Creates a zero decimal scalar of the given [`DecimalDType`].
64    pub fn zero(decimal_dtype: &DecimalDType) -> Self {
65        let decimal_type = DecimalType::smallest_decimal_value_type(decimal_dtype);
66        match_each_decimal_value_type!(decimal_type, |D| {
67            DScalar::<D>::zero(decimal_dtype).into()
68        })
69    }
70
71    /// Creates a null decimal scalar of the given [`DecimalDType`].
72    pub fn null(decimal_dtype: &DecimalDType) -> Self {
73        let decimal_type = DecimalType::smallest_decimal_value_type(decimal_dtype);
74        match_each_decimal_value_type!(decimal_type, |D| {
75            DScalar::<D>::null(decimal_dtype).into()
76        })
77    }
78}
79
80impl ScalarOps for DecimalScalar {
81    fn is_valid(&self) -> bool {
82        match self {
83            DecimalScalar::D8(v) => v.is_valid(),
84            DecimalScalar::D16(v) => v.is_valid(),
85            DecimalScalar::D32(v) => v.is_valid(),
86            DecimalScalar::D64(v) => v.is_valid(),
87            DecimalScalar::D128(v) => v.is_valid(),
88            DecimalScalar::D256(v) => v.is_valid(),
89        }
90    }
91
92    fn mask_validity(&mut self, mask: bool) {
93        match self {
94            DecimalScalar::D8(v) => v.mask_validity(mask),
95            DecimalScalar::D16(v) => v.mask_validity(mask),
96            DecimalScalar::D32(v) => v.mask_validity(mask),
97            DecimalScalar::D64(v) => v.mask_validity(mask),
98            DecimalScalar::D128(v) => v.mask_validity(mask),
99            DecimalScalar::D256(v) => v.mask_validity(mask),
100        }
101    }
102
103    fn repeat(&self, n: usize) -> VectorMut {
104        match self {
105            DecimalScalar::D8(v) => v.repeat(n),
106            DecimalScalar::D16(v) => v.repeat(n),
107            DecimalScalar::D32(v) => v.repeat(n),
108            DecimalScalar::D64(v) => v.repeat(n),
109            DecimalScalar::D128(v) => v.repeat(n),
110            DecimalScalar::D256(v) => v.repeat(n),
111        }
112    }
113}
114
115impl From<DecimalScalar> for Scalar {
116    fn from(val: DecimalScalar) -> Self {
117        Scalar::Decimal(val)
118    }
119}
120
121/// Represents a decimal scalar value with a specific native decimal type.
122#[derive(Clone, Debug, PartialEq, Eq, Hash)]
123pub struct DScalar<D> {
124    ps: PrecisionScale<D>,
125    value: Option<D>,
126}
127
128impl<D: NativeDecimalType> DScalar<D> {
129    /// Creates a new decimal scalar with the given precision/scale and value.
130    ///
131    /// Returns `None` if the value is not valid for the given precision/scale.
132    pub fn maybe_new(ps: PrecisionScale<D>, value: Option<D>) -> Option<Self> {
133        Some(match value {
134            None => Self { ps, value: None },
135            Some(v) => {
136                if !ps.is_valid(v) {
137                    return None;
138                }
139                Self { ps, value: Some(v) }
140            }
141        })
142    }
143
144    /// Creates a new decimal scalar with the given precision/scale and value without validation.
145    ///
146    /// # Safety
147    ///
148    /// The caller must ensure that the value is valid for the given precision/scale.
149    pub unsafe fn new_unchecked(ps: PrecisionScale<D>, value: Option<D>) -> Self {
150        Self { ps, value }
151    }
152
153    /// Returns the value of the decimal scalar, or `None` if the scalar is null.
154    pub fn value(&self) -> Option<D> {
155        self.value
156    }
157
158    /// Get the precision/scale of the decimal scalar.
159    pub fn precision_scale(&self) -> PrecisionScale<D> {
160        self.ps
161    }
162
163    /// Returns the precision of the decimal scalar.
164    pub fn precision(&self) -> u8 {
165        self.ps.precision()
166    }
167
168    /// Returns the scale of the decimal scalar.
169    pub fn scale(&self) -> i8 {
170        self.ps.scale()
171    }
172
173    /// Creates a zero decimal scalar of the given [`DecimalDType`].
174    pub fn zero(decimal_dtype: &DecimalDType) -> Self {
175        let ps = PrecisionScale::<D>::new(decimal_dtype.precision(), decimal_dtype.scale());
176        // SAFETY: Zero (default) is always valid for any precision/scale.
177        unsafe { DScalar::<D>::new_unchecked(ps, Some(D::default())) }
178    }
179
180    /// Creates a null decimal scalar of the given [`DecimalDType`].
181    pub fn null(decimal_dtype: &DecimalDType) -> Self {
182        let ps = PrecisionScale::<D>::new(decimal_dtype.precision(), decimal_dtype.scale());
183        // SAFETY: None is always valid for any precision/scale.
184        unsafe { DScalar::<D>::new_unchecked(ps, None) }
185    }
186}
187
188impl<D: NativeDecimalType> ScalarOps for DScalar<D> {
189    fn is_valid(&self) -> bool {
190        self.value.is_some()
191    }
192
193    fn mask_validity(&mut self, mask: bool) {
194        if !mask {
195            self.value = None;
196        }
197    }
198
199    fn repeat(&self, n: usize) -> VectorMut {
200        let mut vec = DVectorMut::with_capacity(self.ps, n);
201        match &self.value {
202            None => vec.append_nulls(n),
203            Some(v) => vec.try_append_n(*v, n).vortex_expect("known to fit"),
204        }
205        vec.into()
206    }
207}
208
209impl<D: NativeDecimalType> From<DScalar<D>> for Scalar {
210    fn from(value: DScalar<D>) -> Self {
211        Scalar::Decimal(D::upcast(value))
212    }
213}
214
215impl<D: NativeDecimalType> From<DScalar<D>> for DecimalScalar {
216    fn from(value: DScalar<D>) -> Self {
217        D::upcast(value)
218    }
219}
220
221impl DecimalTypeUpcast for DecimalScalar {
222    type Input<T: NativeDecimalType> = DScalar<T>;
223
224    fn from_i8(input: Self::Input<i8>) -> Self {
225        DecimalScalar::D8(input)
226    }
227
228    fn from_i16(input: Self::Input<i16>) -> Self {
229        DecimalScalar::D16(input)
230    }
231
232    fn from_i32(input: Self::Input<i32>) -> Self {
233        DecimalScalar::D32(input)
234    }
235
236    fn from_i64(input: Self::Input<i64>) -> Self {
237        DecimalScalar::D64(input)
238    }
239
240    fn from_i128(input: Self::Input<i128>) -> Self {
241        DecimalScalar::D128(input)
242    }
243
244    fn from_i256(input: Self::Input<i256>) -> Self {
245        DecimalScalar::D256(input)
246    }
247}
248
249impl DecimalTypeDowncast for DecimalScalar {
250    type Output<T: NativeDecimalType> = DScalar<T>;
251
252    fn into_i8(self) -> Self::Output<i8> {
253        if let Self::D8(v) = self {
254            return v;
255        }
256        vortex_panic!("Expected DecimalScalar::D8, got {self:?}");
257    }
258
259    fn into_i16(self) -> Self::Output<i16> {
260        if let Self::D16(v) = self {
261            return v;
262        }
263        vortex_panic!("Expected DecimalScalar::D16, got {self:?}");
264    }
265
266    fn into_i32(self) -> Self::Output<i32> {
267        if let Self::D32(v) = self {
268            return v;
269        }
270        vortex_panic!("Expected DecimalScalar::D32, got {self:?}");
271    }
272
273    fn into_i64(self) -> Self::Output<i64> {
274        if let Self::D64(v) = self {
275            return v;
276        }
277        vortex_panic!("Expected DecimalScalar::D64, got {self:?}");
278    }
279
280    fn into_i128(self) -> Self::Output<i128> {
281        if let Self::D128(v) = self {
282            return v;
283        }
284        vortex_panic!("Expected DecimalScalar::D128, got {self:?}");
285    }
286
287    fn into_i256(self) -> Self::Output<i256> {
288        if let Self::D256(v) = self {
289            return v;
290        }
291        vortex_panic!("Expected DecimalScalar::D256, got {self:?}");
292    }
293}