Skip to main content

vortex_array/arrays/decimal/vtable/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hasher;
5
6use prost::Message;
7use vortex_buffer::Alignment;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_ensure;
11use vortex_error::vortex_panic;
12use vortex_session::VortexSession;
13
14use crate::ArrayParts;
15use crate::ArrayRef;
16use crate::ExecutionCtx;
17use crate::ExecutionResult;
18use crate::array::Array;
19use crate::array::ArrayView;
20use crate::array::VTable;
21use crate::arrays::decimal::DecimalData;
22use crate::buffer::BufferHandle;
23use crate::dtype::DType;
24use crate::dtype::DecimalType;
25use crate::dtype::NativeDecimalType;
26use crate::match_each_decimal_value_type;
27use crate::serde::ArrayChildren;
28use crate::validity::Validity;
29mod kernel;
30mod operations;
31mod validity;
32
33use std::hash::Hash;
34
35use vortex_session::registry::CachedId;
36
37use crate::EqMode;
38use crate::array::ArrayId;
39use crate::arrays::decimal::array::SLOT_NAMES;
40use crate::arrays::decimal::compute::rules::RULES;
41use crate::hash::ArrayEq;
42use crate::hash::ArrayHash;
/// A [`Decimal`]-encoded Vortex array.
///
/// This is [`Array`] parameterized by the [`Decimal`] vtable defined in this module.
pub type DecimalArray = Array<Decimal>;
45
46pub(crate) fn initialize(session: &VortexSession) {
47    kernel::initialize(session);
48}
49
/// Serialized metadata for a decimal array.
///
/// Records the physical [`DecimalType`] of the values buffer. On deserialization it is passed,
/// together with the logical decimal dtype, to `DecimalData::try_new_handle`, so the physical
/// width is not derived from the dtype there.
// NOTE(review): an earlier comment asked whether this could instead be inferred from the dtype's
// type info. TODO confirm whether the physical width is always implied by the decimal precision;
// if so this metadata is redundant.
#[derive(prost::Message)]
pub struct DecimalMetadata {
    /// Raw `i32` discriminant of the values' [`DecimalType`], written as
    /// `array.values_type() as i32` during serialization.
    #[prost(enumeration = "DecimalType", tag = "1")]
    pub(super) values_type: i32,
}
56
57impl ArrayHash for DecimalData {
58    fn array_hash<H: Hasher>(&self, state: &mut H, accuracy: EqMode) {
59        self.values.array_hash(state, accuracy);
60        std::mem::discriminant(&self.values_type).hash(state);
61    }
62}
63
64impl ArrayEq for DecimalData {
65    fn array_eq(&self, other: &Self, accuracy: EqMode) -> bool {
66        self.values.array_eq(&other.values, accuracy) && self.values_type == other.values_type
67    }
68}
69
70impl VTable for Decimal {
71    type TypedArrayData = DecimalData;
72
73    type OperationsVTable = Self;
74    type ValidityVTable = Self;
75
76    fn id(&self) -> ArrayId {
77        static ID: CachedId = CachedId::new("vortex.decimal");
78        *ID
79    }
80
81    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
82        1
83    }
84
85    fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
86        match idx {
87            0 => array.values.clone(),
88            _ => vortex_panic!("DecimalArray buffer index {idx} out of bounds"),
89        }
90    }
91
92    fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
93        match idx {
94            0 => Some("values".to_string()),
95            _ => None,
96        }
97    }
98
99    fn with_buffers(
100        &self,
101        array: ArrayView<'_, Self>,
102        buffers: &[BufferHandle],
103    ) -> VortexResult<ArrayParts<Self>> {
104        vortex_ensure!(
105            buffers.len() == 1,
106            "Expected 1 buffer, got {}",
107            buffers.len()
108        );
109        let mut data = array.data().clone();
110        data.values = buffers[0].clone();
111        Ok(
112            ArrayParts::new(self.clone(), array.dtype().clone(), array.len(), data)
113                .with_slots(array.slots().iter().cloned().collect()),
114        )
115    }
116
117    fn serialize(
118        array: ArrayView<'_, Self>,
119        _session: &VortexSession,
120    ) -> VortexResult<Option<Vec<u8>>> {
121        Ok(Some(
122            DecimalMetadata {
123                values_type: array.values_type() as i32,
124            }
125            .encode_to_vec(),
126        ))
127    }
128
129    fn validate(
130        &self,
131        data: &DecimalData,
132        dtype: &DType,
133        len: usize,
134        slots: &[Option<ArrayRef>],
135    ) -> VortexResult<()> {
136        let DType::Decimal(_, nullability) = dtype else {
137            vortex_bail!("Expected decimal dtype, got {dtype:?}");
138        };
139        vortex_ensure!(
140            data.len() == len,
141            InvalidArgument:
142            "DecimalArray length {} does not match outer length {}",
143            data.len(),
144            len
145        );
146        let validity = crate::array::child_to_validity(slots[0].as_ref(), *nullability);
147        if let Some(validity_len) = validity.maybe_len() {
148            vortex_ensure!(
149                validity_len == len,
150                InvalidArgument:
151                "DecimalArray validity len {} does not match outer length {}",
152                validity_len,
153                len
154            );
155        }
156
157        Ok(())
158    }
159
160    fn deserialize(
161        &self,
162        dtype: &DType,
163        len: usize,
164        metadata: &[u8],
165
166        buffers: &[BufferHandle],
167        children: &dyn ArrayChildren,
168        _session: &VortexSession,
169    ) -> VortexResult<ArrayParts<Self>> {
170        let metadata = DecimalMetadata::decode(metadata)?;
171        if buffers.len() != 1 {
172            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
173        }
174        let values = buffers[0].clone();
175
176        let validity = if children.is_empty() {
177            Validity::from(dtype.nullability())
178        } else if children.len() == 1 {
179            let validity = children.get(0, &Validity::DTYPE, len)?;
180            Validity::Array(validity)
181        } else {
182            vortex_bail!("Expected 0 or 1 child, got {}", children.len());
183        };
184
185        let Some(decimal_dtype) = dtype.as_decimal_opt() else {
186            vortex_bail!("Expected Decimal dtype, got {:?}", dtype)
187        };
188
189        let slots = DecimalData::make_slots(&validity, len);
190        let data = match_each_decimal_value_type!(metadata.values_type(), |D| {
191            // Check and reinterpret-cast the buffer
192            vortex_ensure!(
193                values.is_aligned_to(Alignment::of::<D>()),
194                "DecimalArray buffer not aligned for values type {:?}",
195                D::DECIMAL_TYPE
196            );
197            DecimalData::try_new_handle(values, metadata.values_type(), *decimal_dtype)
198        })?;
199        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
200    }
201
202    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
203        SLOT_NAMES[idx].to_string()
204    }
205
206    fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
207        Ok(ExecutionResult::done(array))
208    }
209
210    fn reduce_parent(
211        array: ArrayView<'_, Self>,
212        parent: &ArrayRef,
213        child_idx: usize,
214    ) -> VortexResult<Option<ArrayRef>> {
215        RULES.evaluate(array, parent, child_idx)
216    }
217}
218
/// Zero-sized marker type selecting the decimal encoding.
///
/// It implements [`VTable`], and [`DecimalArray`] is `Array<Decimal>`.
#[derive(Debug, Clone)]
pub struct Decimal;
221
// Serialization round-trip tests for the decimal encoding.
#[cfg(test)]
mod tests {
    use vortex_buffer::ByteBufferMut;
    use vortex_buffer::buffer;
    use vortex_session::registry::ReadContext;

    use crate::ArrayContext;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::Decimal;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DecimalDType;
    use crate::serde::SerializeOptions;
    use crate::serde::SerializedArray;
    use crate::validity::Validity;

    /// A non-nullable i128-backed decimal array serializes and decodes back into an array of
    /// the `Decimal` encoding.
    #[test]
    fn test_array_serde() {
        let session = array_session();
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128, 500i128],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );
        let dtype = array.dtype().clone();

        let array_ctx = ArrayContext::empty();
        let out = array
            .into_array()
            .serialize(&array_ctx, &session, &SerializeOptions::default())
            .unwrap();
        // Concat into a single buffer
        let mut concat = ByteBufferMut::empty();
        for buf in out {
            concat.extend_from_slice(buf.as_ref());
        }

        let concat = concat.freeze();

        let parts = SerializedArray::try_from(concat).unwrap();
        // The length passed here must match the 5 values serialized above.
        let decoded = parts
            .decode(&dtype, 5, &ReadContext::new(array_ctx.to_ids()), &session)
            .unwrap();
        assert!(decoded.is::<Decimal>());
    }

    /// A nullable i32-backed decimal array (with one null) round-trips through serialization,
    /// and the decoded array compares equal to the original.
    #[test]
    fn test_nullable_decimal_serde_roundtrip() {
        let session = array_session();
        let mut ctx = session.create_execution_ctx();
        let array = DecimalArray::new(
            buffer![1234567i32, 0i32, -9999999i32],
            DecimalDType::new(7, 3),
            Validity::from_iter([true, false, true]),
        );
        let dtype = array.dtype().clone();
        let len = array.len();

        let array_ctx = ArrayContext::empty();
        // Clone so the original array is still available for the final comparison.
        let out = array
            .clone()
            .into_array()
            .serialize(&array_ctx, &session, &SerializeOptions::default())
            .unwrap();
        // Concatenate the serialized buffers into one contiguous byte buffer.
        let mut concat = ByteBufferMut::empty();
        for buf in out {
            concat.extend_from_slice(buf.as_ref());
        }

        let parts = SerializedArray::try_from(concat.freeze()).unwrap();
        let decoded = parts
            .decode(&dtype, len, &ReadContext::new(array_ctx.to_ids()), &session)
            .unwrap();

        assert_arrays_eq!(decoded, array, &mut ctx);
    }
}