Skip to main content

vortex_array/arrays/decimal/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hasher;
5
6use kernel::PARENT_KERNELS;
7use prost::Message;
8use vortex_buffer::Alignment;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_panic;
13use vortex_session::VortexSession;
14
15use crate::ArrayRef;
16use crate::ExecutionCtx;
17use crate::ExecutionResult;
18use crate::array::Array;
19use crate::array::ArrayView;
20use crate::array::VTable;
21use crate::arrays::decimal::DecimalData;
22use crate::buffer::BufferHandle;
23use crate::dtype::DType;
24use crate::dtype::DecimalType;
25use crate::dtype::NativeDecimalType;
26use crate::match_each_decimal_value_type;
27use crate::serde::ArrayChildren;
28use crate::validity::Validity;
29mod kernel;
30mod operations;
31mod validity;
32
33use std::hash::Hash;
34
35use crate::Precision;
36use crate::array::ArrayId;
37use crate::arrays::decimal::array::SLOT_NAMES;
38use crate::arrays::decimal::compute::rules::RULES;
39use crate::hash::ArrayEq;
40use crate::hash::ArrayHash;
41/// A [`Decimal`]-encoded Vortex array.
42pub type DecimalArray = Array<Decimal>;
43
44// The type of the values can be determined by looking at the type info...right?
45#[derive(prost::Message)]
46pub struct DecimalMetadata {
47    #[prost(enumeration = "DecimalType", tag = "1")]
48    pub(super) values_type: i32,
49}
50
51impl ArrayHash for DecimalData {
52    fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
53        self.values.array_hash(state, precision);
54        std::mem::discriminant(&self.values_type).hash(state);
55    }
56}
57
58impl ArrayEq for DecimalData {
59    fn array_eq(&self, other: &Self, precision: Precision) -> bool {
60        self.values.array_eq(&other.values, precision) && self.values_type == other.values_type
61    }
62}
63
64impl VTable for Decimal {
65    type ArrayData = DecimalData;
66
67    type OperationsVTable = Self;
68    type ValidityVTable = Self;
69
70    fn id(&self) -> ArrayId {
71        Self::ID
72    }
73
74    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
75        1
76    }
77
78    fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
79        match idx {
80            0 => array.values.clone(),
81            _ => vortex_panic!("DecimalArray buffer index {idx} out of bounds"),
82        }
83    }
84
85    fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
86        match idx {
87            0 => Some("values".to_string()),
88            _ => None,
89        }
90    }
91
92    fn serialize(
93        array: ArrayView<'_, Self>,
94        _session: &VortexSession,
95    ) -> VortexResult<Option<Vec<u8>>> {
96        Ok(Some(
97            DecimalMetadata {
98                values_type: array.values_type() as i32,
99            }
100            .encode_to_vec(),
101        ))
102    }
103
104    fn validate(
105        &self,
106        data: &DecimalData,
107        dtype: &DType,
108        len: usize,
109        slots: &[Option<ArrayRef>],
110    ) -> VortexResult<()> {
111        let DType::Decimal(_, nullability) = dtype else {
112            vortex_bail!("Expected decimal dtype, got {dtype:?}");
113        };
114        vortex_ensure!(
115            data.len() == len,
116            InvalidArgument:
117            "DecimalArray length {} does not match outer length {}",
118            data.len(),
119            len
120        );
121        let validity = crate::array::child_to_validity(&slots[0], *nullability);
122        if let Some(validity_len) = validity.maybe_len() {
123            vortex_ensure!(
124                validity_len == len,
125                InvalidArgument:
126                "DecimalArray validity len {} does not match outer length {}",
127                validity_len,
128                len
129            );
130        }
131
132        Ok(())
133    }
134
135    fn deserialize(
136        &self,
137        dtype: &DType,
138        len: usize,
139        metadata: &[u8],
140
141        buffers: &[BufferHandle],
142        children: &dyn ArrayChildren,
143        _session: &VortexSession,
144    ) -> VortexResult<crate::array::ArrayParts<Self>> {
145        let metadata = DecimalMetadata::decode(metadata)?;
146        if buffers.len() != 1 {
147            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
148        }
149        let values = buffers[0].clone();
150
151        let validity = if children.is_empty() {
152            Validity::from(dtype.nullability())
153        } else if children.len() == 1 {
154            let validity = children.get(0, &Validity::DTYPE, len)?;
155            Validity::Array(validity)
156        } else {
157            vortex_bail!("Expected 0 or 1 child, got {}", children.len());
158        };
159
160        let Some(decimal_dtype) = dtype.as_decimal_opt() else {
161            vortex_bail!("Expected Decimal dtype, got {:?}", dtype)
162        };
163
164        let slots = DecimalData::make_slots(&validity, len);
165        let data = match_each_decimal_value_type!(metadata.values_type(), |D| {
166            // Check and reinterpret-cast the buffer
167            vortex_ensure!(
168                values.is_aligned_to(Alignment::of::<D>()),
169                "DecimalArray buffer not aligned for values type {:?}",
170                D::DECIMAL_TYPE
171            );
172            DecimalData::try_new_handle(values, metadata.values_type(), *decimal_dtype)
173        })?;
174        Ok(crate::array::ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
175    }
176
177    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
178        SLOT_NAMES[idx].to_string()
179    }
180
181    fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
182        Ok(ExecutionResult::done(array))
183    }
184
185    fn reduce_parent(
186        array: ArrayView<'_, Self>,
187        parent: &ArrayRef,
188        child_idx: usize,
189    ) -> VortexResult<Option<ArrayRef>> {
190        RULES.evaluate(array, parent, child_idx)
191    }
192
193    fn execute_parent(
194        array: ArrayView<'_, Self>,
195        parent: &ArrayRef,
196        child_idx: usize,
197        ctx: &mut ExecutionCtx,
198    ) -> VortexResult<Option<ArrayRef>> {
199        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
200    }
201}
202
/// Marker type for the decimal encoding; it carries no data and implements [`VTable`].
#[derive(Clone, Debug)]
pub struct Decimal;
205
206impl Decimal {
207    pub const ID: ArrayId = ArrayId::new_ref("vortex.decimal");
208}
209
#[cfg(test)]
mod tests {
    use vortex_buffer::ByteBufferMut;
    use vortex_buffer::buffer;
    use vortex_session::registry::ReadContext;

    use crate::ArrayContext;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::arrays::Decimal;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DecimalDType;
    use crate::serde::SerializeOptions;
    use crate::serde::SerializedArray;
    use crate::validity::Validity;

    /// A non-nullable `i128` decimal array, once serialized and decoded, is still
    /// [`Decimal`]-encoded.
    #[test]
    fn test_array_serde() {
        let original = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128, 500i128],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );
        let dtype = original.dtype().clone();
        let ctx = ArrayContext::empty();

        let segments = original
            .into_array()
            .serialize(&ctx, &LEGACY_SESSION, &SerializeOptions::default())
            .unwrap();

        // `serialize` yields a sequence of buffers; join them into one before parsing.
        let mut bytes = ByteBufferMut::empty();
        segments
            .into_iter()
            .for_each(|segment| bytes.extend_from_slice(segment.as_ref()));
        let bytes = bytes.freeze();

        let parts = SerializedArray::try_from(bytes).unwrap();
        let read_ctx = ReadContext::new(ctx.to_ids());
        let decoded = parts
            .decode(&dtype, 5, &read_ctx, &LEGACY_SESSION)
            .unwrap();

        assert!(decoded.is::<Decimal>());
    }

    /// A nullable `i32` decimal array compares equal to itself after a serialize/decode
    /// round trip.
    #[test]
    fn test_nullable_decimal_serde_roundtrip() {
        let original = DecimalArray::new(
            buffer![1234567i32, 0i32, -9999999i32],
            DecimalDType::new(7, 3),
            Validity::from_iter([true, false, true]),
        );
        let dtype = original.dtype().clone();
        let len = original.len();
        let ctx = ArrayContext::empty();

        // Serialize a clone so the original is still available for the final comparison.
        let segments = original
            .clone()
            .into_array()
            .serialize(&ctx, &LEGACY_SESSION, &SerializeOptions::default())
            .unwrap();

        let mut bytes = ByteBufferMut::empty();
        segments
            .into_iter()
            .for_each(|segment| bytes.extend_from_slice(segment.as_ref()));

        let parts = SerializedArray::try_from(bytes.freeze()).unwrap();
        let read_ctx = ReadContext::new(ctx.to_ids());
        let decoded = parts
            .decode(&dtype, len, &read_ctx, &LEGACY_SESSION)
            .unwrap();

        assert_arrays_eq!(decoded, original);
    }
}