// vortex_array/arrays/decimal/serde.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use vortex_buffer::{Alignment, Buffer, ByteBuffer};
5use vortex_dtype::{DType, DecimalDType};
6use vortex_error::{VortexResult, vortex_bail};
7use vortex_scalar::{DecimalValueType, NativeDecimalType, match_each_decimal_value_type};
8
9use super::{DecimalArray, DecimalEncoding};
10use crate::ProstMetadata;
11use crate::arrays::DecimalVTable;
12use crate::serde::ArrayChildren;
13use crate::validity::Validity;
14use crate::vtable::SerdeVTable;
15
16// The type of the values can be determined by looking at the type info...right?
17#[derive(prost::Message)]
18pub struct DecimalMetadata {
19    #[prost(enumeration = "DecimalValueType", tag = "1")]
20    pub(super) values_type: i32,
21}
22
23impl SerdeVTable<DecimalVTable> for DecimalVTable {
24    type Metadata = ProstMetadata<DecimalMetadata>;
25
26    fn metadata(array: &DecimalArray) -> VortexResult<Option<Self::Metadata>> {
27        Ok(Some(ProstMetadata(DecimalMetadata {
28            values_type: array.values_type() as i32,
29        })))
30    }
31
32    fn build(
33        _encoding: &DecimalEncoding,
34        dtype: &DType,
35        len: usize,
36        metadata: &DecimalMetadata,
37        buffers: &[ByteBuffer],
38        children: &dyn ArrayChildren,
39    ) -> VortexResult<DecimalArray> {
40        if buffers.len() != 1 {
41            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
42        }
43        let buffer = buffers[0].clone();
44
45        let validity = if children.is_empty() {
46            Validity::from(dtype.nullability())
47        } else if children.len() == 1 {
48            let validity = children.get(0, &Validity::DTYPE, len)?;
49            Validity::Array(validity)
50        } else {
51            vortex_bail!("Expected 0 or 1 child, got {}", children.len());
52        };
53
54        let Some(decimal_dtype) = dtype.as_decimal() else {
55            vortex_bail!("Expected Decimal dtype, got {:?}", dtype)
56        };
57
58        match_each_decimal_value_type!(metadata.values_type(), |D| {
59            check_and_build_decimal::<D>(len, buffer, *decimal_dtype, validity)
60        })
61    }
62}
63
64fn check_and_build_decimal<T: NativeDecimalType>(
65    array_len: usize,
66    buffer: ByteBuffer,
67    decimal_dtype: DecimalDType,
68    validity: Validity,
69) -> VortexResult<DecimalArray> {
70    // Assuming 16-byte alignment for decimal values
71    if !buffer.is_aligned(Alignment::of::<T>()) {
72        vortex_bail!("Buffer is not aligned to 16-byte boundary");
73    }
74
75    let buffer = Buffer::<T>::from_byte_buffer(buffer);
76    if buffer.len() != array_len {
77        vortex_bail!(
78            "Buffer length {} does not match expected length {} for decimal values",
79            buffer.len(),
80            array_len,
81        );
82    }
83
84    Ok(DecimalArray::new(buffer, decimal_dtype, validity))
85}
86
#[cfg(test)]
mod tests {
    use vortex_buffer::{ByteBufferMut, buffer};

    use super::*;
    use crate::serde::{ArrayParts, SerializeOptions};
    use crate::{ArrayContext, EncodingRef, IntoArray};

    /// Serializes a non-nullable `i128` decimal array, decodes it again, and checks that the
    /// encoding, length and dtype all survive the round trip.
    #[test]
    fn test_array_serde() {
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128, 500i128],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );
        let dtype = array.dtype().clone();
        let ctx = ArrayContext::empty().with(EncodingRef::new_ref(DecimalEncoding.as_ref()));
        let out = array
            .into_array()
            .serialize(&ctx, &SerializeOptions::default())
            .unwrap();
        // Concatenate the serialized segments into a single contiguous buffer.
        let mut concat = ByteBufferMut::empty();
        for buf in out {
            concat.extend_from_slice(buf.as_ref());
        }

        let concat = concat.freeze();

        let parts = ArrayParts::try_from(concat).unwrap();

        let decoded = parts.decode(&ctx, &dtype, 5).unwrap();
        assert_eq!(decoded.encoding_id(), DecimalEncoding.id());
        assert_eq!(decoded.len(), 5);
        assert_eq!(decoded.dtype(), &dtype);
    }

    /// A values buffer holding fewer values than requested must be rejected with an error.
    #[test]
    fn test_build_rejects_length_mismatch() {
        let bytes = buffer![1i128, 2i128].into_byte_buffer();
        let result = check_and_build_decimal::<i128>(
            3,
            bytes,
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );
        assert!(result.is_err());
    }

    /// A values buffer whose byte length is not a multiple of the value size must be rejected
    /// with an error, not reinterpreted.
    #[test]
    fn test_build_rejects_truncated_buffer() {
        let bytes = buffer![1i128, 2i128].into_byte_buffer().slice(0..17);
        let result = check_and_build_decimal::<i128>(
            1,
            bytes,
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );
        assert!(result.is_err());
    }
}
121}