vortex_array/arrays/decimal/vtable/
mod.rs1use std::hash::Hasher;
5
6use prost::Message;
7use vortex_buffer::Alignment;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_ensure;
11use vortex_error::vortex_panic;
12use vortex_session::VortexSession;
13
14use crate::ArrayParts;
15use crate::ArrayRef;
16use crate::ExecutionCtx;
17use crate::ExecutionResult;
18use crate::array::Array;
19use crate::array::ArrayView;
20use crate::array::VTable;
21use crate::arrays::decimal::DecimalData;
22use crate::buffer::BufferHandle;
23use crate::dtype::DType;
24use crate::dtype::DecimalType;
25use crate::dtype::NativeDecimalType;
26use crate::match_each_decimal_value_type;
27use crate::serde::ArrayChildren;
28use crate::validity::Validity;
29mod kernel;
30mod operations;
31mod validity;
32
33use std::hash::Hash;
34
35use vortex_session::registry::CachedId;
36
37use crate::EqMode;
38use crate::array::ArrayId;
39use crate::arrays::decimal::array::SLOT_NAMES;
40use crate::arrays::decimal::compute::rules::RULES;
41use crate::hash::ArrayEq;
42use crate::hash::ArrayHash;
43pub type DecimalArray = Array<Decimal>;
45
46pub(crate) fn initialize(session: &VortexSession) {
47 kernel::initialize(session);
48}
49
50#[derive(prost::Message)]
52pub struct DecimalMetadata {
53 #[prost(enumeration = "DecimalType", tag = "1")]
54 pub(super) values_type: i32,
55}
56
57impl ArrayHash for DecimalData {
58 fn array_hash<H: Hasher>(&self, state: &mut H, accuracy: EqMode) {
59 self.values.array_hash(state, accuracy);
60 std::mem::discriminant(&self.values_type).hash(state);
61 }
62}
63
64impl ArrayEq for DecimalData {
65 fn array_eq(&self, other: &Self, accuracy: EqMode) -> bool {
66 self.values.array_eq(&other.values, accuracy) && self.values_type == other.values_type
67 }
68}
69
70impl VTable for Decimal {
71 type TypedArrayData = DecimalData;
72
73 type OperationsVTable = Self;
74 type ValidityVTable = Self;
75
76 fn id(&self) -> ArrayId {
77 static ID: CachedId = CachedId::new("vortex.decimal");
78 *ID
79 }
80
81 fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
82 1
83 }
84
85 fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
86 match idx {
87 0 => array.values.clone(),
88 _ => vortex_panic!("DecimalArray buffer index {idx} out of bounds"),
89 }
90 }
91
92 fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
93 match idx {
94 0 => Some("values".to_string()),
95 _ => None,
96 }
97 }
98
99 fn with_buffers(
100 &self,
101 array: ArrayView<'_, Self>,
102 buffers: &[BufferHandle],
103 ) -> VortexResult<ArrayParts<Self>> {
104 vortex_ensure!(
105 buffers.len() == 1,
106 "Expected 1 buffer, got {}",
107 buffers.len()
108 );
109 let mut data = array.data().clone();
110 data.values = buffers[0].clone();
111 Ok(
112 ArrayParts::new(self.clone(), array.dtype().clone(), array.len(), data)
113 .with_slots(array.slots().iter().cloned().collect()),
114 )
115 }
116
117 fn serialize(
118 array: ArrayView<'_, Self>,
119 _session: &VortexSession,
120 ) -> VortexResult<Option<Vec<u8>>> {
121 Ok(Some(
122 DecimalMetadata {
123 values_type: array.values_type() as i32,
124 }
125 .encode_to_vec(),
126 ))
127 }
128
129 fn validate(
130 &self,
131 data: &DecimalData,
132 dtype: &DType,
133 len: usize,
134 slots: &[Option<ArrayRef>],
135 ) -> VortexResult<()> {
136 let DType::Decimal(_, nullability) = dtype else {
137 vortex_bail!("Expected decimal dtype, got {dtype:?}");
138 };
139 vortex_ensure!(
140 data.len() == len,
141 InvalidArgument:
142 "DecimalArray length {} does not match outer length {}",
143 data.len(),
144 len
145 );
146 let validity = crate::array::child_to_validity(slots[0].as_ref(), *nullability);
147 if let Some(validity_len) = validity.maybe_len() {
148 vortex_ensure!(
149 validity_len == len,
150 InvalidArgument:
151 "DecimalArray validity len {} does not match outer length {}",
152 validity_len,
153 len
154 );
155 }
156
157 Ok(())
158 }
159
160 fn deserialize(
161 &self,
162 dtype: &DType,
163 len: usize,
164 metadata: &[u8],
165
166 buffers: &[BufferHandle],
167 children: &dyn ArrayChildren,
168 _session: &VortexSession,
169 ) -> VortexResult<ArrayParts<Self>> {
170 let metadata = DecimalMetadata::decode(metadata)?;
171 if buffers.len() != 1 {
172 vortex_bail!("Expected 1 buffer, got {}", buffers.len());
173 }
174 let values = buffers[0].clone();
175
176 let validity = if children.is_empty() {
177 Validity::from(dtype.nullability())
178 } else if children.len() == 1 {
179 let validity = children.get(0, &Validity::DTYPE, len)?;
180 Validity::Array(validity)
181 } else {
182 vortex_bail!("Expected 0 or 1 child, got {}", children.len());
183 };
184
185 let Some(decimal_dtype) = dtype.as_decimal_opt() else {
186 vortex_bail!("Expected Decimal dtype, got {:?}", dtype)
187 };
188
189 let slots = DecimalData::make_slots(&validity, len);
190 let data = match_each_decimal_value_type!(metadata.values_type(), |D| {
191 vortex_ensure!(
193 values.is_aligned_to(Alignment::of::<D>()),
194 "DecimalArray buffer not aligned for values type {:?}",
195 D::DECIMAL_TYPE
196 );
197 DecimalData::try_new_handle(values, metadata.values_type(), *decimal_dtype)
198 })?;
199 Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
200 }
201
202 fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
203 SLOT_NAMES[idx].to_string()
204 }
205
206 fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
207 Ok(ExecutionResult::done(array))
208 }
209
210 fn reduce_parent(
211 array: ArrayView<'_, Self>,
212 parent: &ArrayRef,
213 child_idx: usize,
214 ) -> VortexResult<Option<ArrayRef>> {
215 RULES.evaluate(array, parent, child_idx)
216 }
217}
218
/// Marker type for the decimal encoding; its [`VTable`] implementation has id `"vortex.decimal"`.
#[derive(Debug, Clone)]
pub struct Decimal;
221
#[cfg(test)]
mod tests {
    use vortex_buffer::ByteBufferMut;
    use vortex_buffer::buffer;
    use vortex_session::registry::ReadContext;

    use crate::ArrayContext;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::Decimal;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DecimalDType;
    use crate::serde::SerializeOptions;
    use crate::serde::SerializedArray;
    use crate::validity::Validity;

    /// Non-nullable i128 decimals must round-trip to a `Decimal` array with identical values.
    #[test]
    fn test_array_serde() {
        let session = array_session();
        let mut ctx = session.create_execution_ctx();
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128, 500i128],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );
        let dtype = array.dtype().clone();
        let len = array.len();

        let array_ctx = ArrayContext::empty();
        let out = array
            .clone()
            .into_array()
            .serialize(&array_ctx, &session, &SerializeOptions::default())
            .unwrap();
        let mut concat = ByteBufferMut::empty();
        for buf in out {
            concat.extend_from_slice(buf.as_ref());
        }

        let parts = SerializedArray::try_from(concat.freeze()).unwrap();
        let decoded = parts
            .decode(&dtype, len, &ReadContext::new(array_ctx.to_ids()), &session)
            .unwrap();
        assert!(decoded.is::<Decimal>());
        // Checking the encoding alone would not catch a values/metadata mismatch.
        assert_arrays_eq!(decoded, array, &mut ctx);
    }

    /// Nullable i32 decimals (with a validity child) must round-trip with identical values.
    #[test]
    fn test_nullable_decimal_serde_roundtrip() {
        let session = array_session();
        let mut ctx = session.create_execution_ctx();
        let array = DecimalArray::new(
            buffer![1234567i32, 0i32, -9999999i32],
            DecimalDType::new(7, 3),
            Validity::from_iter([true, false, true]),
        );
        let dtype = array.dtype().clone();
        let len = array.len();

        let array_ctx = ArrayContext::empty();
        let out = array
            .clone()
            .into_array()
            .serialize(&array_ctx, &session, &SerializeOptions::default())
            .unwrap();
        let mut concat = ByteBufferMut::empty();
        for buf in out {
            concat.extend_from_slice(buf.as_ref());
        }

        let parts = SerializedArray::try_from(concat.freeze()).unwrap();
        let decoded = parts
            .decode(&dtype, len, &ReadContext::new(array_ctx.to_ids()), &session)
            .unwrap();

        assert_arrays_eq!(decoded, array, &mut ctx);
    }
}