Skip to main content

vortex_onpair/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hasher;
8
9use prost::Message as _;
10use vortex_array::Array;
11use vortex_array::ArrayEq;
12use vortex_array::ArrayHash;
13use vortex_array::ArrayId;
14use vortex_array::ArrayParts;
15use vortex_array::ArrayRef;
16use vortex_array::ArrayView;
17use vortex_array::Canonical;
18use vortex_array::ExecutionCtx;
19use vortex_array::ExecutionResult;
20use vortex_array::IntoArray;
21use vortex_array::Precision;
22use vortex_array::array_slots;
23use vortex_array::buffer::BufferHandle;
24use vortex_array::builders::ArrayBuilder;
25use vortex_array::builders::VarBinViewBuilder;
26use vortex_array::dtype::DType;
27use vortex_array::dtype::Nullability;
28use vortex_array::dtype::PType;
29use vortex_array::serde::ArrayChildren;
30use vortex_array::validity::Validity;
31use vortex_array::vtable::VTable;
32use vortex_array::vtable::ValidityVTable;
33use vortex_array::vtable::child_to_validity;
34use vortex_array::vtable::validity_to_child;
35use vortex_buffer::ByteBuffer;
36use vortex_error::VortexResult;
37use vortex_error::vortex_bail;
38use vortex_error::vortex_ensure;
39use vortex_error::vortex_err;
40use vortex_error::vortex_panic;
41use vortex_session::VortexSession;
42use vortex_session::registry::CachedId;
43
44use crate::canonical::canonicalize_onpair;
45use crate::canonical::onpair_decode_views;
46use crate::kernel::PARENT_KERNELS;
47use crate::rules::RULES;
48
49/// An [`OnPair`]-encoded Vortex array.
50pub type OnPairArray = Array<OnPair>;
51
52/// Wire-format metadata persisted alongside the OnPair buffer + slot children.
53///
54/// On disk the layout is FSST-shape:
55///
56/// * Buffer 0 — `dict_bytes`: the dictionary blob built by the OnPair trainer,
57///   padded with `onpair::MAX_TOKEN_SIZE` trailing zero
58///   bytes so the over-copy decoder can read 16 bytes past the last token.
59/// * Slots — see [`OnPairSlots`].
60///
61/// The four integer slot children flow through the standard `compress_child`
62/// pipeline (see `vortex-btrblocks::schemes::string::OnPairScheme`), so any
63/// encoding registered with the compressor can re-encode them — exactly the
64/// same shape as FSST's `codes` `VarBinArray`.
65#[derive(Clone, prost::Message)]
66pub struct OnPairMetadata {
67    /// Width of the per-row primitive `uncompressed_lengths` child.
68    #[prost(enumeration = "PType", tag = "1")]
69    pub uncompressed_lengths_ptype: i32,
70    /// Bits-per-token the column was compressed with (9..=16). Every value
71    /// in the `codes` child only uses its low `bits` bits.
72    #[prost(uint32, tag = "2")]
73    pub bits: u32,
74    /// Number of dictionary tokens. `dict_offsets` has length `dict_size + 1`.
75    /// Bounded by `2^bits ≤ 2^16 = 65_536`, so `u32` is comfortably wide.
76    #[prost(uint32, tag = "3")]
77    pub dict_size: u32,
78    /// Total number of tokens across all rows. `codes` has this length;
79    /// `codes_offsets.last() == total_tokens`.
80    #[prost(uint64, tag = "4")]
81    pub total_tokens: u64,
82    /// PType of the `dict_offsets` slot child (defaults to U32, may be
83    /// narrowed to U16/U8 by the cascading compressor when values fit).
84    #[prost(enumeration = "PType", tag = "5")]
85    pub dict_offsets_ptype: i32,
86    /// PType of the `codes` slot child (typically U16, may be narrowed to U8
87    /// when `bits <= 8`).
88    #[prost(enumeration = "PType", tag = "6")]
89    pub codes_ptype: i32,
90    /// PType of the `codes_offsets` slot child.
91    #[prost(enumeration = "PType", tag = "7")]
92    pub codes_offsets_ptype: i32,
93}
94
95impl OnPairMetadata {
96    pub fn get_uncompressed_lengths_ptype(&self) -> VortexResult<PType> {
97        PType::try_from(self.uncompressed_lengths_ptype)
98            .map_err(|_| vortex_err!("Invalid PType {}", self.uncompressed_lengths_ptype))
99    }
100}
101
/// Slot children of an OnPair-encoded array.
///
/// NOTE(review): `deserialize` reads children 0–4 in exactly this field order, and
/// `OnPairSlots::VALIDITY` indexes the last field. The `array_slots` macro therefore
/// presumably assigns slot indices by declaration order. Do not reorder these fields
/// without checking the macro and the serialized layout.
#[array_slots(OnPair)]
pub struct OnPairSlots {
    /// Dictionary token boundaries: `PrimitiveArray<u32>` of length `dict_size + 1`.
    /// The cascading compressor may narrow the ptype to U16/U8.
    pub dict_offsets: ArrayRef,
    /// Token codes for all rows, concatenated (`total_tokens` entries):
    /// `PrimitiveArray<u16>`. Each value only uses its low `bits` bits;
    /// downstream `FastLanes::BitPacking` losslessly shrinks the child to
    /// exactly `bits`-bit codes on disk.
    pub codes: ArrayRef,
    /// Per-row offsets into `codes`: `PrimitiveArray<u32>` of length `num_rows + 1`.
    /// FoR / RunEnd / etc. apply naturally via the cascading compressor.
    pub codes_offsets: ArrayRef,
    /// Per-row uncompressed lengths: an integer `PrimitiveArray` of length `num_rows`.
    /// Used to size the canonical output buffer.
    pub uncompressed_lengths: ArrayRef,
    /// Optional validity child for the outer string column; `None` when the
    /// validity is not array-backed.
    pub validity: Option<ArrayRef>,
}
120
121/// Inner data for an OnPair-encoded array.
122///
123/// Holds only the dictionary blob (buffer 0). Every other piece —
124/// `dict_offsets`, the per-token `codes`, the per-row `codes_offsets`, the
125/// per-row `uncompressed_lengths`, and the optional validity child — is a
126/// Vortex slot child so it can be re-encoded by the cascading compressor.
127#[derive(Clone)]
128pub struct OnPairData {
129    /// The dictionary blob (buffer 0).
130    ///
131    /// INVARIANT: this buffer must be over-padded past its logical end
132    /// (`dict_offsets.last()`) by the decoder's fixed token read width,
133    /// `onpair::MAX_TOKEN_SIZE`. The over-copy decoder reads
134    /// every dictionary entry with one fixed-width load and then advances the
135    /// cursor by the token's true length, so the load for the final, shortest
136    /// token over-reads past the logical end of the dictionary. This is the
137    /// same over-read the decoder accounts for on the final few codes; the
138    /// trailing padding absorbs it so that any entry can be read in bounds.
139    /// `onpair_compress` establishes this padding (see `parts_to_children`);
140    /// the over-copy decoder lives in the `onpair` crate.
141    dict_bytes: BufferHandle,
142    bits: u32,
143    len: usize,
144}
145
146impl OnPairData {
147    pub fn new(dict_bytes: BufferHandle, bits: u32, len: usize) -> Self {
148        Self {
149            dict_bytes,
150            bits,
151            len,
152        }
153    }
154
155    pub fn len(&self) -> usize {
156        self.len
157    }
158
159    pub fn is_empty(&self) -> bool {
160        self.len == 0
161    }
162
163    pub fn bits(&self) -> u32 {
164        self.bits
165    }
166
167    pub fn dict_bytes(&self) -> &ByteBuffer {
168        self.dict_bytes.as_host()
169    }
170
171    pub fn dict_bytes_handle(&self) -> &BufferHandle {
172        &self.dict_bytes
173    }
174}
175
176impl Display for OnPairData {
177    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
178        write!(
179            f,
180            "len: {}, bits: {}, dict_bytes_len: {}",
181            self.len,
182            self.bits,
183            self.dict_bytes.len()
184        )
185    }
186}
187
188impl Debug for OnPairData {
189    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
190        f.debug_struct("OnPairData")
191            .field("len", &self.len)
192            .field("bits", &self.bits)
193            .field("dict_bytes_len", &self.dict_bytes.len())
194            .finish()
195    }
196}
197
198impl ArrayHash for OnPairData {
199    fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
200        self.dict_bytes.as_host().array_hash(state, precision);
201        state.write_u32(self.bits);
202    }
203}
204
205impl ArrayEq for OnPairData {
206    fn array_eq(&self, other: &Self, precision: Precision) -> bool {
207        self.bits == other.bits
208            && self
209                .dict_bytes
210                .as_host()
211                .array_eq(other.dict_bytes.as_host(), precision)
212    }
213}
214
/// VTable marker type for the OnPair encoding; it carries no data of its own.
#[derive(Debug, Clone)]
pub struct OnPair;
218
219impl OnPair {
220    /// Build an [`OnPairArray`] from already-materialised parts.
221    #[expect(clippy::too_many_arguments, reason = "every child is a real input")]
222    pub fn try_new(
223        dtype: DType,
224        dict_bytes: BufferHandle,
225        dict_offsets: ArrayRef,
226        codes: ArrayRef,
227        codes_offsets: ArrayRef,
228        uncompressed_lengths: ArrayRef,
229        validity: Validity,
230        bits: u32,
231    ) -> VortexResult<OnPairArray> {
232        validate_parts(
233            &dtype,
234            &dict_offsets,
235            &codes,
236            &codes_offsets,
237            &uncompressed_lengths,
238            bits,
239        )?;
240        let len = uncompressed_lengths.len();
241        let data = OnPairData::new(dict_bytes, bits, len);
242        let slots = OnPairSlots {
243            dict_offsets,
244            codes,
245            codes_offsets,
246            uncompressed_lengths,
247            validity: validity_to_child(&validity, len),
248        }
249        .into_slots();
250        Ok(unsafe {
251            Array::from_parts_unchecked(ArrayParts::new(OnPair, dtype, len, data).with_slots(slots))
252        })
253    }
254
255    #[expect(clippy::too_many_arguments, reason = "every child is a real input")]
256    pub(crate) unsafe fn new_unchecked(
257        dtype: DType,
258        dict_bytes: BufferHandle,
259        dict_offsets: ArrayRef,
260        codes: ArrayRef,
261        codes_offsets: ArrayRef,
262        uncompressed_lengths: ArrayRef,
263        validity: Validity,
264        bits: u32,
265    ) -> OnPairArray {
266        let len = uncompressed_lengths.len();
267        let data = OnPairData::new(dict_bytes, bits, len);
268        let slots = OnPairSlots {
269            dict_offsets,
270            codes,
271            codes_offsets,
272            uncompressed_lengths,
273            validity: validity_to_child(&validity, len),
274        }
275        .into_slots();
276        unsafe {
277            Array::from_parts_unchecked(ArrayParts::new(OnPair, dtype, len, data).with_slots(slots))
278        }
279    }
280}
281
282fn validate_parts(
283    dtype: &DType,
284    dict_offsets: &ArrayRef,
285    codes: &ArrayRef,
286    codes_offsets: &ArrayRef,
287    uncompressed_lengths: &ArrayRef,
288    bits: u32,
289) -> VortexResult<()> {
290    vortex_ensure!(
291        matches!(dtype, DType::Binary(_) | DType::Utf8(_)),
292        "OnPair arrays must be Binary or Utf8, found {dtype}"
293    );
294    vortex_ensure!((9..=16).contains(&bits), "bits {bits} out of range [9, 16]");
295
296    if !dict_offsets.dtype().is_int() || dict_offsets.dtype().is_nullable() {
297        vortex_bail!(InvalidArgument: "dict_offsets must be non-nullable integer");
298    }
299    if !codes.dtype().is_int() || codes.dtype().is_nullable() {
300        vortex_bail!(InvalidArgument: "codes must be non-nullable integer");
301    }
302    if !codes_offsets.dtype().is_int() || codes_offsets.dtype().is_nullable() {
303        vortex_bail!(InvalidArgument: "codes_offsets must be non-nullable integer");
304    }
305    if !uncompressed_lengths.dtype().is_int() || uncompressed_lengths.dtype().is_nullable() {
306        vortex_bail!(InvalidArgument: "uncompressed_lengths must be non-nullable integer");
307    }
308    if codes_offsets.len() != uncompressed_lengths.len() + 1 {
309        vortex_bail!(InvalidArgument:
310            "codes_offsets.len ({}) != uncompressed_lengths.len + 1 ({})",
311            codes_offsets.len(),
312            uncompressed_lengths.len() + 1
313        );
314    }
315    Ok(())
316}
317
318impl VTable for OnPair {
319    type TypedArrayData = OnPairData;
320    type OperationsVTable = Self;
321    type ValidityVTable = Self;
322
323    fn id(&self) -> ArrayId {
324        static ID: CachedId = CachedId::new("vortex.onpair");
325        *ID
326    }
327
328    fn validate(
329        &self,
330        data: &Self::TypedArrayData,
331        dtype: &DType,
332        len: usize,
333        slots: &[Option<ArrayRef>],
334    ) -> VortexResult<()> {
335        let s = OnPairSlotsView::from_slots(slots);
336        validate_parts(
337            dtype,
338            s.dict_offsets,
339            s.codes,
340            s.codes_offsets,
341            s.uncompressed_lengths,
342            data.bits,
343        )?;
344        if s.uncompressed_lengths.len() != len {
345            vortex_bail!(InvalidArgument: "uncompressed_lengths must have same len as outer array");
346        }
347        if data.len != len {
348            vortex_bail!(InvalidArgument: "OnPairData len {} != outer len {}", data.len, len);
349        }
350        Ok(())
351    }
352
353    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
354        1
355    }
356
357    fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
358        match idx {
359            0 => array.dict_bytes_handle().clone(),
360            _ => vortex_panic!("OnPairArray buffer index {idx} out of bounds"),
361        }
362    }
363
364    fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
365        match idx {
366            0 => Some("dict_bytes".to_string()),
367            _ => vortex_panic!("OnPairArray buffer_name index {idx} out of bounds"),
368        }
369    }
370
371    fn serialize(
372        array: ArrayView<'_, Self>,
373        _session: &VortexSession,
374    ) -> VortexResult<Option<Vec<u8>>> {
375        let dict_size = u32::try_from(array.dict_offsets().len().saturating_sub(1))
376            .map_err(|_| vortex_err!("OnPair dict_size exceeds u32"))?;
377        let total_tokens = array.codes().len() as u64;
378        Ok(Some(
379            OnPairMetadata {
380                uncompressed_lengths_ptype: array.uncompressed_lengths().dtype().as_ptype().into(),
381                bits: array.bits(),
382                dict_size,
383                total_tokens,
384                dict_offsets_ptype: array.dict_offsets().dtype().as_ptype().into(),
385                codes_ptype: array.codes().dtype().as_ptype().into(),
386                codes_offsets_ptype: array.codes_offsets().dtype().as_ptype().into(),
387            }
388            .encode_to_vec(),
389        ))
390    }
391
392    fn deserialize(
393        &self,
394        dtype: &DType,
395        len: usize,
396        metadata: &[u8],
397        buffers: &[BufferHandle],
398        children: &dyn ArrayChildren,
399        _session: &VortexSession,
400    ) -> VortexResult<ArrayParts<Self>> {
401        if buffers.len() != 1 {
402            vortex_bail!(InvalidArgument: "Expected 1 buffer, got {}", buffers.len());
403        }
404        let metadata = OnPairMetadata::decode(metadata)?;
405        let uncompressed_ptype = metadata.get_uncompressed_lengths_ptype()?;
406
407        // Slot children. We pass `usize::MAX` for slots whose length we
408        // don't know up front (`dict_offsets` and `codes`). `codes_offsets`
409        // has known length `len + 1`.
410        let dict_offsets_len = metadata.dict_size as usize + 1;
411        let total_tokens = usize::try_from(metadata.total_tokens)
412            .map_err(|_| vortex_err!("total_tokens {} overflows usize", metadata.total_tokens))?;
413        // The cascading compressor may have narrowed any of these integer
414        // children to a tighter ptype; the recorded ptype tells the framework
415        // exactly which dtype to materialise as.
416        let dict_offsets_ptype = PType::try_from(metadata.dict_offsets_ptype).map_err(|_| {
417            vortex_err!("invalid dict_offsets_ptype {}", metadata.dict_offsets_ptype)
418        })?;
419        let codes_ptype = PType::try_from(metadata.codes_ptype)
420            .map_err(|_| vortex_err!("invalid codes_ptype {}", metadata.codes_ptype))?;
421        let codes_offsets_ptype = PType::try_from(metadata.codes_offsets_ptype).map_err(|_| {
422            vortex_err!(
423                "invalid codes_offsets_ptype {}",
424                metadata.codes_offsets_ptype
425            )
426        })?;
427        let dict_offsets = children.get(
428            0,
429            &DType::Primitive(dict_offsets_ptype, Nullability::NonNullable),
430            dict_offsets_len,
431        )?;
432        let codes = children.get(
433            1,
434            &DType::Primitive(codes_ptype, Nullability::NonNullable),
435            total_tokens,
436        )?;
437        let codes_offsets = children.get(
438            2,
439            &DType::Primitive(codes_offsets_ptype, Nullability::NonNullable),
440            len + 1,
441        )?;
442        let uncompressed_lengths = children.get(
443            3,
444            &DType::Primitive(uncompressed_ptype, Nullability::NonNullable),
445            len,
446        )?;
447        let validity = match children.len() {
448            4 => Validity::from(dtype.nullability()),
449            5 => Validity::Array(children.get(4, &Validity::DTYPE, len)?),
450            other => vortex_bail!(InvalidArgument: "Expected 4 or 5 children, got {other}"),
451        };
452
453        let data = OnPairData::new(buffers[0].clone(), metadata.bits, len);
454        let slots = OnPairSlots {
455            dict_offsets,
456            codes,
457            codes_offsets,
458            uncompressed_lengths,
459            validity: validity_to_child(&validity, len),
460        }
461        .into_slots();
462        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
463    }
464
465    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
466        OnPairSlots::NAMES[idx].to_string()
467    }
468
469    fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
470        canonicalize_onpair(array.as_view(), ctx).map(ExecutionResult::done)
471    }
472
473    fn append_to_builder(
474        array: ArrayView<'_, Self>,
475        builder: &mut dyn ArrayBuilder,
476        ctx: &mut ExecutionCtx,
477    ) -> VortexResult<()> {
478        let Some(builder) = builder.as_any_mut().downcast_mut::<VarBinViewBuilder>() else {
479            builder.extend_from_array(
480                &array
481                    .array()
482                    .clone()
483                    .execute::<Canonical>(ctx)?
484                    .into_array(),
485            );
486            return Ok(());
487        };
488
489        let next_buffer_index = builder.completed_block_count() + u32::from(builder.in_progress());
490        let (buffers, views) = onpair_decode_views(array, next_buffer_index, ctx)?;
491        builder.push_buffer_and_adjusted_views(
492            &buffers,
493            &views,
494            array
495                .array()
496                .validity()?
497                .execute_mask(array.array().len(), ctx)?,
498        );
499        Ok(())
500    }
501
502    fn execute_parent(
503        array: ArrayView<'_, Self>,
504        parent: &ArrayRef,
505        child_idx: usize,
506        ctx: &mut ExecutionCtx,
507    ) -> VortexResult<Option<ArrayRef>> {
508        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
509    }
510
511    fn reduce_parent(
512        array: ArrayView<'_, Self>,
513        parent: &ArrayRef,
514        child_idx: usize,
515    ) -> VortexResult<Option<ArrayRef>> {
516        RULES.evaluate(array, parent, child_idx)
517    }
518}
519
520impl ValidityVTable<OnPair> for OnPair {
521    fn validity(array: ArrayView<'_, OnPair>) -> VortexResult<Validity> {
522        Ok(child_to_validity(
523            array.slots()[OnPairSlots::VALIDITY].as_ref(),
524            array.dtype().nullability(),
525        ))
526    }
527}
528
529/// Convenience methods on top of the macro-generated [`OnPairArraySlotsExt`].
530pub trait OnPairArrayExt: OnPairArraySlotsExt {
531    fn array_validity(&self) -> Validity {
532        child_to_validity(
533            self.as_ref().slots()[OnPairSlots::VALIDITY].as_ref(),
534            self.as_ref().dtype().nullability(),
535        )
536    }
537}
538
539impl<T: OnPairArraySlotsExt> OnPairArrayExt for T {}