vortex_array/arrays/decimal/
serde.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::{Alignment, Buffer, ByteBuffer};
5use vortex_dtype::DType;
6#[cfg(test)]
7use vortex_dtype::DecimalDType;
8use vortex_error::{VortexResult, vortex_bail, vortex_ensure};
9use vortex_scalar::{DecimalValueType, NativeDecimalType, match_each_decimal_value_type};
10
11use super::{DecimalArray, DecimalEncoding};
12use crate::ProstMetadata;
13use crate::arrays::DecimalVTable;
14use crate::serde::ArrayChildren;
15use crate::validity::Validity;
16use crate::vtable::SerdeVTable;
17
/// Serialized metadata for a [`DecimalArray`].
///
/// Records which native integer type backs the values buffer. This is not implied by the
/// decimal dtype alone: the physical width is chosen when the array is built, e.g. a
/// precision-10 decimal may be backed by `i128` values even though a narrower type would fit.
#[derive(prost::Message)]
pub struct DecimalMetadata {
    /// The [`DecimalValueType`] of the values buffer, stored as its `i32` discriminant.
    #[prost(enumeration = "DecimalValueType", tag = "1")]
    pub(super) values_type: i32,
}
24
25impl SerdeVTable<DecimalVTable> for DecimalVTable {
26    type Metadata = ProstMetadata<DecimalMetadata>;
27
28    fn metadata(array: &DecimalArray) -> VortexResult<Option<Self::Metadata>> {
29        Ok(Some(ProstMetadata(DecimalMetadata {
30            values_type: array.values_type() as i32,
31        })))
32    }
33
34    fn build(
35        _encoding: &DecimalEncoding,
36        dtype: &DType,
37        len: usize,
38        metadata: &DecimalMetadata,
39        buffers: &[ByteBuffer],
40        children: &dyn ArrayChildren,
41    ) -> VortexResult<DecimalArray> {
42        if buffers.len() != 1 {
43            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
44        }
45        let buffer = buffers[0].clone();
46
47        let validity = if children.is_empty() {
48            Validity::from(dtype.nullability())
49        } else if children.len() == 1 {
50            let validity = children.get(0, &Validity::DTYPE, len)?;
51            Validity::Array(validity)
52        } else {
53            vortex_bail!("Expected 0 or 1 child, got {}", children.len());
54        };
55
56        let Some(decimal_dtype) = dtype.as_decimal_opt() else {
57            vortex_bail!("Expected Decimal dtype, got {:?}", dtype)
58        };
59
60        match_each_decimal_value_type!(metadata.values_type(), |D| {
61            // Check and reinterpret-cast the buffer
62            vortex_ensure!(
63                buffer.is_aligned(Alignment::of::<D>()),
64                "DecimalArray buffer not aligned for values type {:?}",
65                D::VALUES_TYPE
66            );
67            let buffer = Buffer::<D>::from_byte_buffer(buffer);
68            DecimalArray::try_new::<D>(buffer, *decimal_dtype, validity)
69        })
70    }
71}
72
#[cfg(test)]
mod tests {
    use vortex_buffer::{ByteBufferMut, buffer};

    use super::*;
    use crate::serde::{ArrayParts, SerializeOptions};
    use crate::{ArrayContext, EncodingRef, IntoArray};

    /// Round-trips a non-nullable, `i128`-backed decimal array through serialization.
    ///
    /// Checks that the decoded array keeps the decimal encoding, the length and the dtype.
    #[test]
    fn test_array_serde() {
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128, 500i128],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );
        let dtype = array.dtype().clone();
        let ctx = ArrayContext::empty().with(EncodingRef::new_ref(DecimalEncoding.as_ref()));
        let out = array
            .into_array()
            .serialize(&ctx, &SerializeOptions::default())
            .unwrap();

        // `serialize` yields a sequence of buffers; `ArrayParts` is parsed from one contiguous
        // buffer, so concatenate them first.
        let mut concat = ByteBufferMut::empty();
        for buf in out {
            concat.extend_from_slice(buf.as_ref());
        }
        let parts = ArrayParts::try_from(concat.freeze()).unwrap();

        let decoded = parts.decode(&ctx, &dtype, 5).unwrap();
        assert_eq!(decoded.encoding_id(), DecimalEncoding.id());
        assert_eq!(decoded.len(), 5);
        assert_eq!(decoded.dtype(), &dtype);
    }
}