1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hasher;
8
9use prost::Message as _;
10use vortex_array::Array;
11use vortex_array::ArrayEq;
12use vortex_array::ArrayHash;
13use vortex_array::ArrayId;
14use vortex_array::ArrayParts;
15use vortex_array::ArrayRef;
16use vortex_array::ArrayView;
17use vortex_array::Canonical;
18use vortex_array::ExecutionCtx;
19use vortex_array::ExecutionResult;
20use vortex_array::IntoArray;
21use vortex_array::Precision;
22use vortex_array::array_slots;
23use vortex_array::buffer::BufferHandle;
24use vortex_array::builders::ArrayBuilder;
25use vortex_array::builders::VarBinViewBuilder;
26use vortex_array::dtype::DType;
27use vortex_array::dtype::Nullability;
28use vortex_array::dtype::PType;
29use vortex_array::serde::ArrayChildren;
30use vortex_array::validity::Validity;
31use vortex_array::vtable::VTable;
32use vortex_array::vtable::ValidityVTable;
33use vortex_array::vtable::child_to_validity;
34use vortex_array::vtable::validity_to_child;
35use vortex_buffer::ByteBuffer;
36use vortex_error::VortexResult;
37use vortex_error::vortex_bail;
38use vortex_error::vortex_ensure;
39use vortex_error::vortex_err;
40use vortex_error::vortex_panic;
41use vortex_session::VortexSession;
42use vortex_session::registry::CachedId;
43
44use crate::canonical::canonicalize_onpair;
45use crate::canonical::onpair_decode_views;
46use crate::kernel::PARENT_KERNELS;
47use crate::rules::RULES;
48
/// An [`Array`] specialised to the [`OnPair`] encoding.
pub type OnPairArray = Array<OnPair>;
51
/// Metadata stored with a serialized `OnPair` array, encoded as a protobuf message.
///
/// `serialize` fills it in from the array's children, and `deserialize` reads it back to
/// recover the dtype and length of each child.
#[derive(Clone, prost::Message)]
pub struct OnPairMetadata {
    /// `PType` of the `uncompressed_lengths` child, as its protobuf enumeration value.
    #[prost(enumeration = "PType", tag = "1")]
    pub uncompressed_lengths_ptype: i32,
    /// The array's `bits` setting; array validation only accepts values in `9..=16`.
    // NOTE(review): presumably the width of a token code in bits; confirm against the codec.
    #[prost(uint32, tag = "2")]
    pub bits: u32,
    /// Dictionary entry count, written as `dict_offsets.len() - 1` (saturating at zero).
    #[prost(uint32, tag = "3")]
    pub dict_size: u32,
    /// Length of the `codes` child.
    #[prost(uint64, tag = "4")]
    pub total_tokens: u64,
    /// `PType` of the `dict_offsets` child, as its protobuf enumeration value.
    #[prost(enumeration = "PType", tag = "5")]
    pub dict_offsets_ptype: i32,
    /// `PType` of the `codes` child, as its protobuf enumeration value.
    #[prost(enumeration = "PType", tag = "6")]
    pub codes_ptype: i32,
    /// `PType` of the `codes_offsets` child, as its protobuf enumeration value.
    #[prost(enumeration = "PType", tag = "7")]
    pub codes_offsets_ptype: i32,
}
94
95impl OnPairMetadata {
96 pub fn get_uncompressed_lengths_ptype(&self) -> VortexResult<PType> {
97 PType::try_from(self.uncompressed_lengths_ptype)
98 .map_err(|_| vortex_err!("Invalid PType {}", self.uncompressed_lengths_ptype))
99 }
100}
101
/// Child arrays of an `OnPair` array, in slot order.
///
/// Validation requires every non-validity child to be a non-nullable integer array.
#[array_slots(OnPair)]
pub struct OnPairSlots {
    /// Dictionary offsets; deserialized with `dict_size + 1` entries.
    // NOTE(review): presumably delimits entries inside the `dict_bytes` buffer; confirm.
    pub dict_offsets: ArrayRef,
    /// Token codes; its length is recorded as `total_tokens` in the metadata.
    pub codes: ArrayRef,
    /// Offsets into `codes`; must have one more entry than `uncompressed_lengths`.
    pub codes_offsets: ArrayRef,
    /// One entry per row; its length defines the length of the array.
    // NOTE(review): presumably the decoded byte length of each row; confirm.
    pub uncompressed_lengths: ArrayRef,
    /// Optional validity child, produced by `validity_to_child`.
    pub validity: Option<ArrayRef>,
}
120
/// Non-child state of an `OnPair` array: the dictionary buffer plus scalar settings.
#[derive(Clone)]
pub struct OnPairData {
    /// Dictionary bytes; exposed as the array's only buffer (index 0, named `dict_bytes`).
    dict_bytes: BufferHandle,
    /// The `bits` setting; array validation only accepts values in `9..=16`.
    bits: u32,
    /// Number of rows; validation requires it to equal the outer array length.
    len: usize,
}
145
146impl OnPairData {
147 pub fn new(dict_bytes: BufferHandle, bits: u32, len: usize) -> Self {
148 Self {
149 dict_bytes,
150 bits,
151 len,
152 }
153 }
154
155 pub fn len(&self) -> usize {
156 self.len
157 }
158
159 pub fn is_empty(&self) -> bool {
160 self.len == 0
161 }
162
163 pub fn bits(&self) -> u32 {
164 self.bits
165 }
166
167 pub fn dict_bytes(&self) -> &ByteBuffer {
168 self.dict_bytes.as_host()
169 }
170
171 pub fn dict_bytes_handle(&self) -> &BufferHandle {
172 &self.dict_bytes
173 }
174}
175
176impl Display for OnPairData {
177 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
178 write!(
179 f,
180 "len: {}, bits: {}, dict_bytes_len: {}",
181 self.len,
182 self.bits,
183 self.dict_bytes.len()
184 )
185 }
186}
187
188impl Debug for OnPairData {
189 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
190 f.debug_struct("OnPairData")
191 .field("len", &self.len)
192 .field("bits", &self.bits)
193 .field("dict_bytes_len", &self.dict_bytes.len())
194 .finish()
195 }
196}
197
198impl ArrayHash for OnPairData {
199 fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
200 self.dict_bytes.as_host().array_hash(state, precision);
201 state.write_u32(self.bits);
202 }
203}
204
205impl ArrayEq for OnPairData {
206 fn array_eq(&self, other: &Self, precision: Precision) -> bool {
207 self.bits == other.bits
208 && self
209 .dict_bytes
210 .as_host()
211 .array_eq(other.dict_bytes.as_host(), precision)
212 }
213}
214
/// Marker type for the OnPair encoding; its `VTable` impl registers it as `vortex.onpair`.
#[derive(Clone, Debug)]
pub struct OnPair;
218
219impl OnPair {
220 #[expect(clippy::too_many_arguments, reason = "every child is a real input")]
222 pub fn try_new(
223 dtype: DType,
224 dict_bytes: BufferHandle,
225 dict_offsets: ArrayRef,
226 codes: ArrayRef,
227 codes_offsets: ArrayRef,
228 uncompressed_lengths: ArrayRef,
229 validity: Validity,
230 bits: u32,
231 ) -> VortexResult<OnPairArray> {
232 validate_parts(
233 &dtype,
234 &dict_offsets,
235 &codes,
236 &codes_offsets,
237 &uncompressed_lengths,
238 bits,
239 )?;
240 let len = uncompressed_lengths.len();
241 let data = OnPairData::new(dict_bytes, bits, len);
242 let slots = OnPairSlots {
243 dict_offsets,
244 codes,
245 codes_offsets,
246 uncompressed_lengths,
247 validity: validity_to_child(&validity, len),
248 }
249 .into_slots();
250 Ok(unsafe {
251 Array::from_parts_unchecked(ArrayParts::new(OnPair, dtype, len, data).with_slots(slots))
252 })
253 }
254
255 #[expect(clippy::too_many_arguments, reason = "every child is a real input")]
256 pub(crate) unsafe fn new_unchecked(
257 dtype: DType,
258 dict_bytes: BufferHandle,
259 dict_offsets: ArrayRef,
260 codes: ArrayRef,
261 codes_offsets: ArrayRef,
262 uncompressed_lengths: ArrayRef,
263 validity: Validity,
264 bits: u32,
265 ) -> OnPairArray {
266 let len = uncompressed_lengths.len();
267 let data = OnPairData::new(dict_bytes, bits, len);
268 let slots = OnPairSlots {
269 dict_offsets,
270 codes,
271 codes_offsets,
272 uncompressed_lengths,
273 validity: validity_to_child(&validity, len),
274 }
275 .into_slots();
276 unsafe {
277 Array::from_parts_unchecked(ArrayParts::new(OnPair, dtype, len, data).with_slots(slots))
278 }
279 }
280}
281
282fn validate_parts(
283 dtype: &DType,
284 dict_offsets: &ArrayRef,
285 codes: &ArrayRef,
286 codes_offsets: &ArrayRef,
287 uncompressed_lengths: &ArrayRef,
288 bits: u32,
289) -> VortexResult<()> {
290 vortex_ensure!(
291 matches!(dtype, DType::Binary(_) | DType::Utf8(_)),
292 "OnPair arrays must be Binary or Utf8, found {dtype}"
293 );
294 vortex_ensure!((9..=16).contains(&bits), "bits {bits} out of range [9, 16]");
295
296 if !dict_offsets.dtype().is_int() || dict_offsets.dtype().is_nullable() {
297 vortex_bail!(InvalidArgument: "dict_offsets must be non-nullable integer");
298 }
299 if !codes.dtype().is_int() || codes.dtype().is_nullable() {
300 vortex_bail!(InvalidArgument: "codes must be non-nullable integer");
301 }
302 if !codes_offsets.dtype().is_int() || codes_offsets.dtype().is_nullable() {
303 vortex_bail!(InvalidArgument: "codes_offsets must be non-nullable integer");
304 }
305 if !uncompressed_lengths.dtype().is_int() || uncompressed_lengths.dtype().is_nullable() {
306 vortex_bail!(InvalidArgument: "uncompressed_lengths must be non-nullable integer");
307 }
308 if codes_offsets.len() != uncompressed_lengths.len() + 1 {
309 vortex_bail!(InvalidArgument:
310 "codes_offsets.len ({}) != uncompressed_lengths.len + 1 ({})",
311 codes_offsets.len(),
312 uncompressed_lengths.len() + 1
313 );
314 }
315 Ok(())
316}
317
318impl VTable for OnPair {
319 type TypedArrayData = OnPairData;
320 type OperationsVTable = Self;
321 type ValidityVTable = Self;
322
323 fn id(&self) -> ArrayId {
324 static ID: CachedId = CachedId::new("vortex.onpair");
325 *ID
326 }
327
328 fn validate(
329 &self,
330 data: &Self::TypedArrayData,
331 dtype: &DType,
332 len: usize,
333 slots: &[Option<ArrayRef>],
334 ) -> VortexResult<()> {
335 let s = OnPairSlotsView::from_slots(slots);
336 validate_parts(
337 dtype,
338 s.dict_offsets,
339 s.codes,
340 s.codes_offsets,
341 s.uncompressed_lengths,
342 data.bits,
343 )?;
344 if s.uncompressed_lengths.len() != len {
345 vortex_bail!(InvalidArgument: "uncompressed_lengths must have same len as outer array");
346 }
347 if data.len != len {
348 vortex_bail!(InvalidArgument: "OnPairData len {} != outer len {}", data.len, len);
349 }
350 Ok(())
351 }
352
353 fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
354 1
355 }
356
357 fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
358 match idx {
359 0 => array.dict_bytes_handle().clone(),
360 _ => vortex_panic!("OnPairArray buffer index {idx} out of bounds"),
361 }
362 }
363
364 fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
365 match idx {
366 0 => Some("dict_bytes".to_string()),
367 _ => vortex_panic!("OnPairArray buffer_name index {idx} out of bounds"),
368 }
369 }
370
371 fn serialize(
372 array: ArrayView<'_, Self>,
373 _session: &VortexSession,
374 ) -> VortexResult<Option<Vec<u8>>> {
375 let dict_size = u32::try_from(array.dict_offsets().len().saturating_sub(1))
376 .map_err(|_| vortex_err!("OnPair dict_size exceeds u32"))?;
377 let total_tokens = array.codes().len() as u64;
378 Ok(Some(
379 OnPairMetadata {
380 uncompressed_lengths_ptype: array.uncompressed_lengths().dtype().as_ptype().into(),
381 bits: array.bits(),
382 dict_size,
383 total_tokens,
384 dict_offsets_ptype: array.dict_offsets().dtype().as_ptype().into(),
385 codes_ptype: array.codes().dtype().as_ptype().into(),
386 codes_offsets_ptype: array.codes_offsets().dtype().as_ptype().into(),
387 }
388 .encode_to_vec(),
389 ))
390 }
391
392 fn deserialize(
393 &self,
394 dtype: &DType,
395 len: usize,
396 metadata: &[u8],
397 buffers: &[BufferHandle],
398 children: &dyn ArrayChildren,
399 _session: &VortexSession,
400 ) -> VortexResult<ArrayParts<Self>> {
401 if buffers.len() != 1 {
402 vortex_bail!(InvalidArgument: "Expected 1 buffer, got {}", buffers.len());
403 }
404 let metadata = OnPairMetadata::decode(metadata)?;
405 let uncompressed_ptype = metadata.get_uncompressed_lengths_ptype()?;
406
407 let dict_offsets_len = metadata.dict_size as usize + 1;
411 let total_tokens = usize::try_from(metadata.total_tokens)
412 .map_err(|_| vortex_err!("total_tokens {} overflows usize", metadata.total_tokens))?;
413 let dict_offsets_ptype = PType::try_from(metadata.dict_offsets_ptype).map_err(|_| {
417 vortex_err!("invalid dict_offsets_ptype {}", metadata.dict_offsets_ptype)
418 })?;
419 let codes_ptype = PType::try_from(metadata.codes_ptype)
420 .map_err(|_| vortex_err!("invalid codes_ptype {}", metadata.codes_ptype))?;
421 let codes_offsets_ptype = PType::try_from(metadata.codes_offsets_ptype).map_err(|_| {
422 vortex_err!(
423 "invalid codes_offsets_ptype {}",
424 metadata.codes_offsets_ptype
425 )
426 })?;
427 let dict_offsets = children.get(
428 0,
429 &DType::Primitive(dict_offsets_ptype, Nullability::NonNullable),
430 dict_offsets_len,
431 )?;
432 let codes = children.get(
433 1,
434 &DType::Primitive(codes_ptype, Nullability::NonNullable),
435 total_tokens,
436 )?;
437 let codes_offsets = children.get(
438 2,
439 &DType::Primitive(codes_offsets_ptype, Nullability::NonNullable),
440 len + 1,
441 )?;
442 let uncompressed_lengths = children.get(
443 3,
444 &DType::Primitive(uncompressed_ptype, Nullability::NonNullable),
445 len,
446 )?;
447 let validity = match children.len() {
448 4 => Validity::from(dtype.nullability()),
449 5 => Validity::Array(children.get(4, &Validity::DTYPE, len)?),
450 other => vortex_bail!(InvalidArgument: "Expected 4 or 5 children, got {other}"),
451 };
452
453 let data = OnPairData::new(buffers[0].clone(), metadata.bits, len);
454 let slots = OnPairSlots {
455 dict_offsets,
456 codes,
457 codes_offsets,
458 uncompressed_lengths,
459 validity: validity_to_child(&validity, len),
460 }
461 .into_slots();
462 Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
463 }
464
465 fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
466 OnPairSlots::NAMES[idx].to_string()
467 }
468
469 fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
470 canonicalize_onpair(array.as_view(), ctx).map(ExecutionResult::done)
471 }
472
473 fn append_to_builder(
474 array: ArrayView<'_, Self>,
475 builder: &mut dyn ArrayBuilder,
476 ctx: &mut ExecutionCtx,
477 ) -> VortexResult<()> {
478 let Some(builder) = builder.as_any_mut().downcast_mut::<VarBinViewBuilder>() else {
479 builder.extend_from_array(
480 &array
481 .array()
482 .clone()
483 .execute::<Canonical>(ctx)?
484 .into_array(),
485 );
486 return Ok(());
487 };
488
489 let next_buffer_index = builder.completed_block_count() + u32::from(builder.in_progress());
490 let (buffers, views) = onpair_decode_views(array, next_buffer_index, ctx)?;
491 builder.push_buffer_and_adjusted_views(
492 &buffers,
493 &views,
494 array
495 .array()
496 .validity()?
497 .execute_mask(array.array().len(), ctx)?,
498 );
499 Ok(())
500 }
501
502 fn execute_parent(
503 array: ArrayView<'_, Self>,
504 parent: &ArrayRef,
505 child_idx: usize,
506 ctx: &mut ExecutionCtx,
507 ) -> VortexResult<Option<ArrayRef>> {
508 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
509 }
510
511 fn reduce_parent(
512 array: ArrayView<'_, Self>,
513 parent: &ArrayRef,
514 child_idx: usize,
515 ) -> VortexResult<Option<ArrayRef>> {
516 RULES.evaluate(array, parent, child_idx)
517 }
518}
519
520impl ValidityVTable<OnPair> for OnPair {
521 fn validity(array: ArrayView<'_, OnPair>) -> VortexResult<Validity> {
522 Ok(child_to_validity(
523 array.slots()[OnPairSlots::VALIDITY].as_ref(),
524 array.dtype().nullability(),
525 ))
526 }
527}
528
529pub trait OnPairArrayExt: OnPairArraySlotsExt {
531 fn array_validity(&self) -> Validity {
532 child_to_validity(
533 self.as_ref().slots()[OnPairSlots::VALIDITY].as_ref(),
534 self.as_ref().dtype().nullability(),
535 )
536 }
537}
538
// Blanket impl: every type implementing `OnPairArraySlotsExt` also gets `OnPairArrayExt`.
impl<T: OnPairArraySlotsExt> OnPairArrayExt for T {}