vortex_array/arrays/decimal/
serde.rs

1use vortex_buffer::{Alignment, Buffer, ByteBuffer};
2use vortex_dtype::{DType, DecimalDType};
3use vortex_error::{VortexResult, vortex_bail};
4use vortex_scalar::{DecimalValueType, NativeDecimalType, match_each_decimal_value_type};
5
6use super::{DecimalArray, DecimalEncoding};
7use crate::ProstMetadata;
8use crate::arrays::DecimalVTable;
9use crate::serde::ArrayChildren;
10use crate::validity::Validity;
11use crate::vtable::SerdeVTable;
12
13// The type of the values can be determined by looking at the type info...right?
14#[derive(prost::Message)]
15pub struct DecimalMetadata {
16    #[prost(enumeration = "DecimalValueType", tag = "1")]
17    pub(super) values_type: i32,
18}
19
20impl SerdeVTable<DecimalVTable> for DecimalVTable {
21    type Metadata = ProstMetadata<DecimalMetadata>;
22
23    fn metadata(array: &DecimalArray) -> VortexResult<Option<Self::Metadata>> {
24        Ok(Some(ProstMetadata(DecimalMetadata {
25            values_type: array.values_type() as i32,
26        })))
27    }
28
29    fn build(
30        _encoding: &DecimalEncoding,
31        dtype: &DType,
32        len: usize,
33        metadata: &DecimalMetadata,
34        buffers: &[ByteBuffer],
35        children: &dyn ArrayChildren,
36    ) -> VortexResult<DecimalArray> {
37        if buffers.len() != 1 {
38            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
39        }
40        let buffer = buffers[0].clone();
41
42        let validity = if children.is_empty() {
43            Validity::from(dtype.nullability())
44        } else if children.len() == 1 {
45            let validity = children.get(0, &Validity::DTYPE, len)?;
46            Validity::Array(validity)
47        } else {
48            vortex_bail!("Expected 0 or 1 child, got {}", children.len());
49        };
50
51        let Some(decimal_dtype) = dtype.as_decimal() else {
52            vortex_bail!("Expected Decimal dtype, got {:?}", dtype)
53        };
54
55        match_each_decimal_value_type!(metadata.values_type(), |D| {
56            check_and_build_decimal::<D>(len, buffer, *decimal_dtype, validity)
57        })
58    }
59}
60
61fn check_and_build_decimal<T: NativeDecimalType>(
62    array_len: usize,
63    buffer: ByteBuffer,
64    decimal_dtype: DecimalDType,
65    validity: Validity,
66) -> VortexResult<DecimalArray> {
67    // Assuming 16-byte alignment for decimal values
68    if !buffer.is_aligned(Alignment::of::<T>()) {
69        vortex_bail!("Buffer is not aligned to 16-byte boundary");
70    }
71
72    let buffer = Buffer::<T>::from_byte_buffer(buffer);
73    if buffer.len() != array_len {
74        vortex_bail!(
75            "Buffer length {} does not match expected length {} for decimal values",
76            buffer.len(),
77            array_len,
78        );
79    }
80
81    Ok(DecimalArray::new(buffer, decimal_dtype, validity))
82}
83
#[cfg(test)]
mod tests {
    use vortex_buffer::{ByteBufferMut, buffer};

    use super::*;
    use crate::serde::{ArrayParts, SerializeOptions};
    use crate::{ArrayContext, EncodingRef, IntoArray};

    /// Serializes a non-nullable decimal array, decodes it from the concatenated bytes, and
    /// checks that the decoded array still uses the decimal encoding.
    #[test]
    fn test_array_serde() {
        let ctx = ArrayContext::empty().with(EncodingRef::new_ref(DecimalEncoding.as_ref()));

        let values = buffer![100i128, 200i128, 300i128, 400i128, 500i128];
        let array = DecimalArray::new(values, DecimalDType::new(10, 2), Validity::NonNullable);
        let dtype = array.dtype().clone();

        let segments = array
            .into_array()
            .serialize(&ctx, &SerializeOptions::default())
            .unwrap();

        // Stitch the serialized segments into one contiguous buffer before decoding.
        let mut bytes = ByteBufferMut::empty();
        for segment in segments {
            bytes.extend(segment.as_ref());
        }

        let parts = ArrayParts::try_from(bytes.freeze()).unwrap();
        let decoded = parts.decode(&ctx, &dtype, 5).unwrap();

        assert_eq!(decoded.encoding_id(), DecimalEncoding.id());
    }
}