vortex_array/arrays/decimal/serde.rs

1use vortex_buffer::{Alignment, Buffer, ByteBuffer};
2use vortex_dtype::{DType, DecimalDType};
3use vortex_error::{VortexResult, vortex_bail};
4use vortex_scalar::i256;
5
6use super::{DecimalArray, DecimalEncoding};
7use crate::arrays::NativeDecimalType;
8use crate::serde::ArrayParts;
9use crate::validity::Validity;
10use crate::vtable::EncodingVTable;
11use crate::{
12    Array, ArrayContext, ArrayRef, Canonical, DeserializeMetadata, EncodingId, ProstMetadata,
13};
14
15/// Type of the decimal values.
16#[derive(Clone, Copy, Debug, prost::Enumeration, PartialEq, Eq)]
17#[repr(u8)]
18#[non_exhaustive]
19pub enum DecimalValueType {
20    I8 = 0,
21    I16 = 1,
22    I32 = 2,
23    I64 = 3,
24    I128 = 4,
25    I256 = 5,
26}
27
/// Serialized metadata for a decimal array.
///
/// The native storage width is recorded explicitly. `decode` takes the element type from this
/// field and takes precision/scale from the array's [`DType`], treating the two independently.
// NOTE(review): presumably a given precision may be stored in more than one native width, which
// is why the width is not inferred from the dtype alone — confirm against the writer side.
#[derive(prost::Message)]
pub struct DecimalMetadata {
    /// Raw discriminant of a [`DecimalValueType`], stored as `i32` because that is the field type
    /// prost uses for `enumeration` fields.
    #[prost(enumeration = "DecimalValueType", tag = "1")]
    pub(super) values_type: i32,
}
34
35impl EncodingVTable for DecimalEncoding {
36    fn id(&self) -> EncodingId {
37        EncodingId::new_ref("vortex.decimal")
38    }
39
40    fn decode(
41        &self,
42        parts: &ArrayParts,
43        ctx: &ArrayContext,
44        dtype: DType,
45        len: usize,
46    ) -> VortexResult<ArrayRef> {
47        if parts.nbuffers() != 1 {
48            vortex_bail!("Expected 1 buffer, got {}", parts.nbuffers());
49        }
50        let buffer = parts.buffer(0)?;
51
52        let validity = if parts.nchildren() == 0 {
53            Validity::from(dtype.nullability())
54        } else if parts.nchildren() == 1 {
55            let validity = parts.child(0).decode(ctx, Validity::DTYPE, len)?;
56            Validity::Array(validity)
57        } else {
58            vortex_bail!("Expected 0 or 1 child, got {}", parts.nchildren());
59        };
60
61        let decimal_dtype = match &dtype {
62            DType::Decimal(decimal_dtype, _) => *decimal_dtype,
63            _ => vortex_bail!("Expected Decimal dtype, got {:?}", dtype),
64        };
65
66        let metadata = ProstMetadata::<DecimalMetadata>::deserialize(parts.metadata())?;
67        match metadata.values_type() {
68            DecimalValueType::I8 => {
69                check_and_build_decimal::<i8>(len, buffer, decimal_dtype, validity)
70            }
71            DecimalValueType::I16 => {
72                check_and_build_decimal::<i16>(len, buffer, decimal_dtype, validity)
73            }
74            DecimalValueType::I32 => {
75                check_and_build_decimal::<i32>(len, buffer, decimal_dtype, validity)
76            }
77            DecimalValueType::I64 => {
78                check_and_build_decimal::<i64>(len, buffer, decimal_dtype, validity)
79            }
80            DecimalValueType::I128 => {
81                check_and_build_decimal::<i128>(len, buffer, decimal_dtype, validity)
82            }
83            DecimalValueType::I256 => {
84                check_and_build_decimal::<i256>(len, buffer, decimal_dtype, validity)
85            }
86        }
87    }
88
89    fn encode(
90        &self,
91        input: &Canonical,
92        _like: Option<&dyn Array>,
93    ) -> VortexResult<Option<ArrayRef>> {
94        Ok(Some(input.clone().into_decimal()?.into_array()))
95    }
96}
97
98fn check_and_build_decimal<T: NativeDecimalType>(
99    array_len: usize,
100    buffer: ByteBuffer,
101    decimal_dtype: DecimalDType,
102    validity: Validity,
103) -> VortexResult<ArrayRef> {
104    // Assuming 16-byte alignment for decimal values
105    if !buffer.is_aligned(Alignment::of::<T>()) {
106        vortex_bail!("Buffer is not aligned to 16-byte boundary");
107    }
108
109    let buffer = Buffer::<T>::from_byte_buffer(buffer);
110    if buffer.len() != array_len {
111        vortex_bail!(
112            "Buffer length {} does not match expected length {} for decimal values",
113            buffer.len(),
114            array_len,
115        );
116    }
117
118    Ok(DecimalArray::new(buffer, decimal_dtype, validity).into_array())
119}
120
/// Matches on a `DecimalValue`, binding the contained native integer to a caller-chosen
/// identifier and evaluating the same body for every variant.
///
/// Usage: `match_each_decimal_value!(value, |$v| { /* use $v */ })`.
///
/// The body is expanded once per variant, so it must compile for every native width.
/// `DecimalValue` is referenced unqualified, so it must be in scope at the call site.
#[macro_export]
macro_rules! match_each_decimal_value {
    ($self:expr, | $_:tt $value:ident | $($body:tt)*) => ({
        // Local helper that splices the caller's body with `$value` bound to the arm's payload.
        // The `$_` token smuggles a literal `$` so the inner macro can declare its own
        // metavariable.
        macro_rules! __with__ {( $_ $value:ident ) => ( $($body)* )}
        match $self {
            DecimalValue::I8(v) => __with__! { v },
            DecimalValue::I16(v) => __with__! { v },
            DecimalValue::I32(v) => __with__! { v },
            DecimalValue::I64(v) => __with__! { v },
            DecimalValue::I128(v) => __with__! { v },
            DecimalValue::I256(v) => __with__! { v },
        }
    });
}
136
/// Macro to match over each decimal value type, binding the corresponding native type (from
/// `DecimalValueType`).
///
/// Two forms are supported:
///
/// - `match_each_decimal_value_type!(vt, |$T| body)` binds `$T` to the native integer type
///   (`i8`, `i16`, `i32`, `i64`, `i128` or `i256`) for the matched variant.
/// - `match_each_decimal_value_type!(vt, |($T, $V)| body)` also binds `$V` to the variant
///   identifier `I8` … `I256`. Inside the expansion these resolve to `vortex_scalar::DecimalValue`
///   variants because of a glob import.
///
/// The body is expanded once per variant, so it must compile for every native width. The
/// expansion names `vortex_scalar` directly rather than through `$crate`, so the calling crate
/// needs `vortex_scalar` as a dependency.
#[macro_export]
macro_rules! match_each_decimal_value_type {
    // Single-binding form: `$enc` becomes the native integer type of the matched variant.
    ($self:expr, | $_:tt $enc:ident | $($body:tt)*) => ({
        macro_rules! __with__ {( $_ $enc:ident ) => ( $($body)* )}
        use $crate::arrays::DecimalValueType;
        use vortex_scalar::i256;
        match $self {
            DecimalValueType::I8 => __with__! { i8 },
            DecimalValueType::I16 => __with__! { i16 },
            DecimalValueType::I32 => __with__! { i32 },
            DecimalValueType::I64 => __with__! { i64 },
            DecimalValueType::I128 => __with__! { i128 },
            DecimalValueType::I256 => __with__! { i256 },
        }
    });
    // Two-binding form: `$enc` is the native type and `$dv_path` the matching `DecimalValue`
    // variant name.
    ($self:expr, | ($_0:tt $enc:ident, $_1:tt $dv_path:ident) | $($body:tt)*) => ({
        macro_rules! __with2__ { ( $_0 $enc:ident, $_1 $dv_path:ident ) => ( $($body)* ) }
        use $crate::arrays::DecimalValueType;
        use vortex_scalar::i256;
        // Glob import so the bare `I8` … `I256` idents passed below name `DecimalValue`
        // variants. NOTE(review): this also brings those names into scope for the caller's
        // body.
        use vortex_scalar::DecimalValue::*;

        match $self {
            DecimalValueType::I8 => __with2__! { i8, I8 },
            DecimalValueType::I16 => __with2__! { i16, I16 },
            DecimalValueType::I32 => __with2__! { i32, I32 },
            DecimalValueType::I64 => __with2__! { i64, I64 },
            DecimalValueType::I128 => __with2__! { i128, I128 },
            DecimalValueType::I256 => __with2__! { i256, I256 },
        }
    });
}
169
#[cfg(test)]
mod tests {
    use vortex_buffer::{ByteBufferMut, buffer};

    use super::*;
    use crate::Encoding;
    use crate::serde::SerializeOptions;

    /// Round-trips a non-nullable `i128`-backed decimal array through serialization. Checks that
    /// the decoded array keeps the decimal encoding, the original length and the original dtype.
    #[test]
    fn test_array_serde() {
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128, 500i128],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );
        let dtype = array.dtype().clone();
        let len = array.len();
        let ctx = ArrayContext::empty().with(DecimalEncoding.vtable());
        let out = array
            .into_array()
            .serialize(&ctx, &SerializeOptions::default());

        // Concatenate the serialized segments into a single contiguous buffer.
        let mut concat = ByteBufferMut::empty();
        for buf in out {
            concat.extend(buf.as_ref());
        }
        let concat = concat.freeze();

        let parts = ArrayParts::try_from(concat).unwrap();
        let decoded = parts.decode(&ctx, dtype.clone(), len).unwrap();

        assert_eq!(decoded.encoding(), DecimalEncoding.id());
        assert_eq!(decoded.len(), len);
        assert_eq!(decoded.dtype(), &dtype);
    }
}