vortex_array/arrays/decimal/vtable/
mod.rs1use std::hash::Hasher;
5
6use kernel::PARENT_KERNELS;
7use prost::Message;
8use vortex_buffer::Alignment;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_panic;
13use vortex_session::VortexSession;
14
15use crate::ArrayRef;
16use crate::ExecutionCtx;
17use crate::ExecutionResult;
18use crate::array::Array;
19use crate::array::ArrayView;
20use crate::array::VTable;
21use crate::arrays::decimal::DecimalData;
22use crate::buffer::BufferHandle;
23use crate::dtype::DType;
24use crate::dtype::DecimalType;
25use crate::dtype::NativeDecimalType;
26use crate::match_each_decimal_value_type;
27use crate::serde::ArrayChildren;
28use crate::validity::Validity;
29mod kernel;
30mod operations;
31mod validity;
32
33use std::hash::Hash;
34
35use vortex_session::registry::CachedId;
36
37use crate::Precision;
38use crate::array::ArrayId;
39use crate::arrays::decimal::array::SLOT_NAMES;
40use crate::arrays::decimal::compute::rules::RULES;
41use crate::hash::ArrayEq;
42use crate::hash::ArrayHash;
43pub type DecimalArray = Array<Decimal>;
45
/// Serialized metadata for a decimal array.
///
/// Produced by `Decimal::serialize` and decoded again in `Decimal::deserialize`.
#[derive(prost::Message)]
pub struct DecimalMetadata {
    /// Physical storage type of the values buffer, stored as the raw `i32`
    /// discriminant of [`DecimalType`] (protobuf field tag 1).
    #[prost(enumeration = "DecimalType", tag = "1")]
    pub(super) values_type: i32,
}
52
53impl ArrayHash for DecimalData {
54 fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
55 self.values.array_hash(state, precision);
56 std::mem::discriminant(&self.values_type).hash(state);
57 }
58}
59
60impl ArrayEq for DecimalData {
61 fn array_eq(&self, other: &Self, precision: Precision) -> bool {
62 self.values.array_eq(&other.values, precision) && self.values_type == other.values_type
63 }
64}
65
66impl VTable for Decimal {
67 type ArrayData = DecimalData;
68
69 type OperationsVTable = Self;
70 type ValidityVTable = Self;
71
72 fn id(&self) -> ArrayId {
73 static ID: CachedId = CachedId::new("vortex.decimal");
74 *ID
75 }
76
77 fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
78 1
79 }
80
81 fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
82 match idx {
83 0 => array.values.clone(),
84 _ => vortex_panic!("DecimalArray buffer index {idx} out of bounds"),
85 }
86 }
87
88 fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
89 match idx {
90 0 => Some("values".to_string()),
91 _ => None,
92 }
93 }
94
95 fn serialize(
96 array: ArrayView<'_, Self>,
97 _session: &VortexSession,
98 ) -> VortexResult<Option<Vec<u8>>> {
99 Ok(Some(
100 DecimalMetadata {
101 values_type: array.values_type() as i32,
102 }
103 .encode_to_vec(),
104 ))
105 }
106
107 fn validate(
108 &self,
109 data: &DecimalData,
110 dtype: &DType,
111 len: usize,
112 slots: &[Option<ArrayRef>],
113 ) -> VortexResult<()> {
114 let DType::Decimal(_, nullability) = dtype else {
115 vortex_bail!("Expected decimal dtype, got {dtype:?}");
116 };
117 vortex_ensure!(
118 data.len() == len,
119 InvalidArgument:
120 "DecimalArray length {} does not match outer length {}",
121 data.len(),
122 len
123 );
124 let validity = crate::array::child_to_validity(&slots[0], *nullability);
125 if let Some(validity_len) = validity.maybe_len() {
126 vortex_ensure!(
127 validity_len == len,
128 InvalidArgument:
129 "DecimalArray validity len {} does not match outer length {}",
130 validity_len,
131 len
132 );
133 }
134
135 Ok(())
136 }
137
138 fn deserialize(
139 &self,
140 dtype: &DType,
141 len: usize,
142 metadata: &[u8],
143
144 buffers: &[BufferHandle],
145 children: &dyn ArrayChildren,
146 _session: &VortexSession,
147 ) -> VortexResult<crate::array::ArrayParts<Self>> {
148 let metadata = DecimalMetadata::decode(metadata)?;
149 if buffers.len() != 1 {
150 vortex_bail!("Expected 1 buffer, got {}", buffers.len());
151 }
152 let values = buffers[0].clone();
153
154 let validity = if children.is_empty() {
155 Validity::from(dtype.nullability())
156 } else if children.len() == 1 {
157 let validity = children.get(0, &Validity::DTYPE, len)?;
158 Validity::Array(validity)
159 } else {
160 vortex_bail!("Expected 0 or 1 child, got {}", children.len());
161 };
162
163 let Some(decimal_dtype) = dtype.as_decimal_opt() else {
164 vortex_bail!("Expected Decimal dtype, got {:?}", dtype)
165 };
166
167 let slots = DecimalData::make_slots(&validity, len);
168 let data = match_each_decimal_value_type!(metadata.values_type(), |D| {
169 vortex_ensure!(
171 values.is_aligned_to(Alignment::of::<D>()),
172 "DecimalArray buffer not aligned for values type {:?}",
173 D::DECIMAL_TYPE
174 );
175 DecimalData::try_new_handle(values, metadata.values_type(), *decimal_dtype)
176 })?;
177 Ok(crate::array::ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
178 }
179
180 fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
181 SLOT_NAMES[idx].to_string()
182 }
183
184 fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
185 Ok(ExecutionResult::done(array))
186 }
187
188 fn reduce_parent(
189 array: ArrayView<'_, Self>,
190 parent: &ArrayRef,
191 child_idx: usize,
192 ) -> VortexResult<Option<ArrayRef>> {
193 RULES.evaluate(array, parent, child_idx)
194 }
195
196 fn execute_parent(
197 array: ArrayView<'_, Self>,
198 parent: &ArrayRef,
199 child_idx: usize,
200 ctx: &mut ExecutionCtx,
201 ) -> VortexResult<Option<ArrayRef>> {
202 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
203 }
204}
205
/// Marker type selecting the decimal array encoding (id `vortex.decimal`).
#[derive(Debug, Clone)]
pub struct Decimal;
208
#[cfg(test)]
mod tests {
    use vortex_buffer::ByteBufferMut;
    use vortex_buffer::buffer;
    use vortex_session::registry::ReadContext;

    use crate::ArrayContext;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::arrays::Decimal;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DecimalDType;
    use crate::serde::SerializeOptions;
    use crate::serde::SerializedArray;
    use crate::validity::Validity;

    /// Serializes `array`, concatenates the output buffers and decodes them back
    /// with the array's own dtype and length.
    fn roundtrip(array: &DecimalArray) -> ArrayRef {
        let dtype = array.dtype().clone();
        let len = array.len();

        let ctx = ArrayContext::empty();
        let serialized = array
            .clone()
            .into_array()
            .serialize(&ctx, &LEGACY_SESSION, &SerializeOptions::default())
            .unwrap();
        let mut concat = ByteBufferMut::empty();
        for buf in serialized {
            concat.extend_from_slice(buf.as_ref());
        }

        let parts = SerializedArray::try_from(concat.freeze()).unwrap();
        parts
            .decode(&dtype, len, &ReadContext::new(ctx.to_ids()), &LEGACY_SESSION)
            .unwrap()
    }

    #[test]
    fn test_array_serde() {
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128, 500i128],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        let decoded = roundtrip(&array);

        // The encoding must survive the roundtrip, and so must the values.
        assert!(decoded.is::<Decimal>());
        assert_arrays_eq!(decoded, array);
    }

    #[test]
    fn test_nullable_decimal_serde_roundtrip() {
        let array = DecimalArray::new(
            buffer![1234567i32, 0i32, -9999999i32],
            DecimalDType::new(7, 3),
            Validity::from_iter([true, false, true]),
        );

        let decoded = roundtrip(&array);

        assert!(decoded.is::<Decimal>());
        assert_arrays_eq!(decoded, array);
    }
}