vortex_layout/layouts/dict/writer/mod.rs

1use bytes::Bytes;
2use vortex_array::arcref::ArcRef;
3use vortex_array::compute::slice;
4use vortex_array::vtable::EncodingVTable as _;
5use vortex_array::{Array, ArrayContext, ArrayRef, ProstMetadata, SerializeMetadata};
6use vortex_btrblocks::BtrBlocksCompressor;
7use vortex_dict::DictEncoding;
8use vortex_dict::builders::{DictConstraints, DictEncoder, dict_encoder};
9use vortex_dtype::proto::dtype as pb;
10use vortex_dtype::{DType, PType};
11use vortex_error::{VortexResult, vortex_bail, vortex_err};
12
13mod repeating;
14
15use crate::layouts::dict::DictLayout;
16use crate::{Layout, LayoutStrategy, LayoutVTableRef, LayoutWriter, LayoutWriterExt};
17
18#[derive(Clone)]
19pub struct DictLayoutOptions {
20    pub constraints: DictConstraints,
21}
22
23impl Default for DictLayoutOptions {
24    fn default() -> Self {
25        Self {
26            constraints: DictConstraints {
27                max_bytes: 1024 * 1024,
28                max_len: u16::MAX as usize,
29            },
30        }
31    }
32}
33
34/// A layout strategy that encodes chunk into values and codes, if found
35/// appropriate by the btrblocks compressor. Current implementation only
36/// checks the first chunk to decide whether to apply dict layout and
37/// encodes chunks into dictionaries. When the dict constraints are hit, a
38/// new dictionary is created.
39#[derive(Clone)]
40pub struct DictStrategy {
41    pub options: DictLayoutOptions,
42    pub codes: ArcRef<dyn LayoutStrategy>,
43    pub values: ArcRef<dyn LayoutStrategy>,
44    pub fallback: ArcRef<dyn LayoutStrategy>,
45}
46
47impl LayoutStrategy for DictStrategy {
48    fn new_writer(&self, ctx: &ArrayContext, dtype: &DType) -> VortexResult<Box<dyn LayoutWriter>> {
49        if !dict_layout_supported(dtype) {
50            return self.fallback.new_writer(ctx, dtype);
51        }
52        Ok(DelegatingDictLayoutWriter {
53            ctx: ctx.clone(),
54            strategy: self.clone(),
55            dtype: dtype.clone(),
56            writer: None,
57        }
58        .boxed())
59    }
60}
61
62pub fn dict_layout_supported(dtype: &DType) -> bool {
63    matches!(
64        dtype,
65        DType::Primitive(..) | DType::Utf8(_) | DType::Binary(_)
66    )
67}
68
69struct DelegatingDictLayoutWriter {
70    ctx: ArrayContext,
71    strategy: DictStrategy,
72    dtype: DType,
73    writer: Option<Box<dyn LayoutWriter>>,
74}
75
76impl LayoutWriter for DelegatingDictLayoutWriter {
77    fn push_chunk(
78        &mut self,
79        segment_writer: &mut dyn crate::segments::SegmentWriter,
80        chunk: ArrayRef,
81    ) -> VortexResult<()> {
82        match self.writer.as_mut() {
83            Some(writer) => writer.push_chunk(segment_writer, chunk),
84            None => {
85                let compressed = BtrBlocksCompressor.compress(&chunk)?;
86                let mut writer = if !compressed.is_encoding(DictEncoding.id()) {
87                    self.strategy.fallback.new_writer(&self.ctx, &self.dtype)?
88                } else {
89                    repeating::DictLayoutWriter::new(
90                        self.ctx.clone(),
91                        &self.dtype,
92                        self.strategy.clone(),
93                    )
94                    .boxed()
95                };
96                writer.push_chunk(segment_writer, chunk)?;
97                self.writer = Some(writer);
98                Ok(())
99            }
100        }
101    }
102
103    fn flush(
104        &mut self,
105        segment_writer: &mut dyn crate::segments::SegmentWriter,
106    ) -> VortexResult<()> {
107        match self.writer.as_mut() {
108            None => vortex_bail!("flush called before push_chunk"),
109            Some(writer) => writer.flush(segment_writer),
110        }
111    }
112
113    fn finish(
114        &mut self,
115        segment_writer: &mut dyn crate::segments::SegmentWriter,
116    ) -> VortexResult<Layout> {
117        match self.writer.as_mut() {
118            None => vortex_bail!("finish called before push_chunk"),
119            Some(writer) => writer.finish(segment_writer),
120        }
121    }
122}
123
124#[derive(prost::Message)]
125pub struct DictLayoutMetadata {
126    #[prost(enumeration = "pb::PType", tag = "1")]
127    // i32 is required for proto, use the generated getter to read this field.
128    codes_ptype: i32,
129}
130
131impl DictLayoutMetadata {
132    pub fn new(codes_ptype: PType) -> Self {
133        let mut metadata = Self::default();
134        metadata.set_codes_ptype(codes_ptype.into());
135        metadata
136    }
137}
138
139fn dict_layout(values: Layout, codes: Layout) -> VortexResult<Layout> {
140    let metadata = Bytes::copy_from_slice(
141        &ProstMetadata(DictLayoutMetadata::new(codes.dtype().try_into()?))
142            .serialize()
143            .ok_or_else(|| vortex_err!("could not serialize dict layout metadata"))?,
144    );
145    Ok(Layout::new_owned(
146        "dict".into(),
147        LayoutVTableRef::new_ref(&DictLayout),
148        values.dtype().clone(),
149        codes.row_count(),
150        vec![],
151        vec![values, codes],
152        Some(metadata),
153    ))
154}
155
156enum EncodingState {
157    Continue((Box<dyn DictEncoder>, ArrayRef)),
158    // (values, encoded, unencoded)
159    Done((ArrayRef, ArrayRef, ArrayRef)),
160}
161
162fn start_encoding(constraints: &DictConstraints, chunk: &dyn Array) -> VortexResult<EncodingState> {
163    let encoder = dict_encoder(chunk, constraints)?;
164    encode_chunk(encoder, chunk)
165}
166
167fn encode_chunk(
168    mut encoder: Box<dyn DictEncoder>,
169    chunk: &dyn Array,
170) -> VortexResult<EncodingState> {
171    let encoded = encoder.encode(chunk)?;
172    Ok(match remainder(chunk, encoded.len())? {
173        None => EncodingState::Continue((encoder, encoded)),
174        Some(unencoded) => EncodingState::Done((encoder.values()?, encoded, unencoded)),
175    })
176}
177
178fn remainder(array: &dyn Array, encoded_len: usize) -> VortexResult<Option<ArrayRef>> {
179    (encoded_len < array.len())
180        .then(|| slice(array, encoded_len, array.len()))
181        .transpose()
182}