vortex_array/arrays/decimal/vtable/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::Alignment;
5use vortex_buffer::Buffer;
6use vortex_buffer::BufferHandle;
7use vortex_dtype::DType;
8use vortex_dtype::NativeDecimalType;
9use vortex_dtype::PrecisionScale;
10use vortex_dtype::match_each_decimal_value_type;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_ensure;
14use vortex_scalar::DecimalType;
15use vortex_vector::Vector;
16use vortex_vector::decimal::DVector;
17
18use crate::DeserializeMetadata;
19use crate::ProstMetadata;
20use crate::SerializeMetadata;
21use crate::arrays::DecimalArray;
22use crate::execution::ExecutionCtx;
23use crate::serde::ArrayChildren;
24use crate::validity::Validity;
25use crate::vtable;
26use crate::vtable::ArrayVTableExt;
27use crate::vtable::NotSupported;
28use crate::vtable::VTable;
29use crate::vtable::ValidityVTableFromValidityHelper;
30
31mod array;
32mod canonical;
33mod operations;
34pub mod operator;
35mod validity;
36mod visitor;
37
38pub use operator::DecimalMaskedValidityRule;
39
40use crate::vtable::ArrayId;
41use crate::vtable::ArrayVTable;
42
// Invoke the crate's `vtable!` macro for the Decimal encoding.
// NOTE(review): presumably generates glue tying `DecimalArray` to `DecimalVTable`
// (the latter is declared explicitly further down in this file) — confirm against
// the `vtable!` macro definition.
vtable!(Decimal);
44
/// Serialized metadata for a [`DecimalArray`].
///
/// The logical decimal dtype (precision and scale) is not enough on its own to recover
/// the physical integer width of the values buffer. For example, this module's serde test
/// stores precision-10 values as `i128`. `values_type` records that storage type, so that
/// `build` can reinterpret the raw bytes as the right native type.
#[derive(prost::Message)]
pub struct DecimalMetadata {
    /// The [`DecimalType`] of the values buffer, stored as its prost enum discriminant
    /// (written as `array.values_type() as i32` when serializing).
    #[prost(enumeration = "DecimalType", tag = "1")]
    pub(super) values_type: i32,
}
51
52impl VTable for DecimalVTable {
53    type Array = DecimalArray;
54
55    type Metadata = ProstMetadata<DecimalMetadata>;
56
57    type ArrayVTable = Self;
58    type CanonicalVTable = Self;
59    type OperationsVTable = Self;
60    type ValidityVTable = ValidityVTableFromValidityHelper;
61    type VisitorVTable = Self;
62    type ComputeVTable = NotSupported;
63    type EncodeVTable = NotSupported;
64
65    fn id(&self) -> ArrayId {
66        ArrayId::new_ref("vortex.decimal")
67    }
68
69    fn encoding(_array: &Self::Array) -> ArrayVTable {
70        DecimalVTable.as_vtable()
71    }
72
73    fn metadata(array: &DecimalArray) -> VortexResult<Self::Metadata> {
74        Ok(ProstMetadata(DecimalMetadata {
75            values_type: array.values_type() as i32,
76        }))
77    }
78
79    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
80        Ok(Some(metadata.serialize()))
81    }
82
83    fn deserialize(bytes: &[u8]) -> VortexResult<Self::Metadata> {
84        let metadata = ProstMetadata::<DecimalMetadata>::deserialize(bytes)?;
85        Ok(ProstMetadata(metadata))
86    }
87
88    fn build(
89        &self,
90        dtype: &DType,
91        len: usize,
92        metadata: &Self::Metadata,
93        buffers: &[BufferHandle],
94        children: &dyn ArrayChildren,
95    ) -> VortexResult<DecimalArray> {
96        if buffers.len() != 1 {
97            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
98        }
99        let buffer = buffers[0].clone().try_to_bytes()?;
100
101        let validity = if children.is_empty() {
102            Validity::from(dtype.nullability())
103        } else if children.len() == 1 {
104            let validity = children.get(0, &Validity::DTYPE, len)?;
105            Validity::Array(validity)
106        } else {
107            vortex_bail!("Expected 0 or 1 child, got {}", children.len());
108        };
109
110        let Some(decimal_dtype) = dtype.as_decimal_opt() else {
111            vortex_bail!("Expected Decimal dtype, got {:?}", dtype)
112        };
113
114        match_each_decimal_value_type!(metadata.values_type(), |D| {
115            // Check and reinterpret-cast the buffer
116            vortex_ensure!(
117                buffer.is_aligned(Alignment::of::<D>()),
118                "DecimalArray buffer not aligned for values type {:?}",
119                D::DECIMAL_TYPE
120            );
121            let buffer = Buffer::<D>::from_byte_buffer(buffer);
122            DecimalArray::try_new::<D>(buffer, *decimal_dtype, validity)
123        })
124    }
125
126    fn batch_execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
127        match_each_decimal_value_type!(array.values_type(), |D| {
128            Ok(unsafe {
129                DVector::<D>::new_unchecked(
130                    PrecisionScale::new_unchecked(array.precision(), array.scale()),
131                    array.buffer::<D>(),
132                    array.validity_mask(),
133                )
134            }
135            .into())
136        })
137    }
138}
139
/// Stateless vtable for the `vortex.decimal` encoding (see its `VTable` impl in this file).
#[derive(Debug)]
pub struct DecimalVTable;
142
#[cfg(test)]
mod tests {
    use vortex_buffer::ByteBufferMut;
    use vortex_buffer::buffer;
    use vortex_dtype::DecimalDType;

    use crate::ArrayContext;
    use crate::IntoArray;
    use crate::arrays::DecimalArray;
    use crate::arrays::DecimalVTable;
    use crate::serde::ArrayParts;
    use crate::serde::SerializeOptions;
    use crate::validity::Validity;
    use crate::vtable::ArrayVTableExt;

    /// Round-trips a non-nullable, `i128`-backed decimal array through serialization.
    /// Checks that the decoded array keeps the same encoding, length and dtype.
    #[test]
    fn test_array_serde() {
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128, 500i128],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );
        let dtype = array.dtype().clone();
        let ctx = ArrayContext::empty().with(DecimalVTable.as_vtable());
        let out = array
            .into_array()
            .serialize(&ctx, &SerializeOptions::default())
            .unwrap();

        // `serialize` yields a sequence of buffers; join them into the single contiguous
        // buffer that `ArrayParts` parses.
        let mut concat = ByteBufferMut::empty();
        for buf in out {
            concat.extend_from_slice(buf.as_ref());
        }
        let concat = concat.freeze();

        let parts = ArrayParts::try_from(concat).unwrap();

        let decoded = parts.decode(&ctx, &dtype, 5).unwrap();
        assert!(decoded.is::<DecimalVTable>());
        assert_eq!(decoded.len(), 5);
        assert_eq!(decoded.dtype(), &dtype);
    }
}
184}