vortex_layout/layouts/dict/writer/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll, ready};
7
8use arcref::ArcRef;
9use async_stream::try_stream;
10use futures::channel::{mpsc, oneshot};
11use futures::stream::{BoxStream, once};
12use futures::{FutureExt, SinkExt, Stream, StreamExt, pin_mut, try_join};
13use vortex_array::{Array, ArrayContext, ArrayRef};
14use vortex_btrblocks::BtrBlocksCompressor;
15use vortex_dict::DictEncoding;
16use vortex_dict::builders::{DictConstraints, DictEncoder, dict_encoder};
17use vortex_dtype::{DType, PType};
18use vortex_error::{VortexExpect, VortexResult, VortexUnwrap, vortex_bail, vortex_err};
19
20use super::DictLayout;
21use crate::layouts::chunked::ChunkedLayout;
22use crate::segments::SequenceWriter;
23use crate::sequence::{SequenceId, SequencePointer};
24use crate::{
25    IntoLayout, LayoutStrategy, OwnedLayoutChildren, SendableLayoutFuture,
26    SendableSequentialStream, SequentialStreamAdapter, SequentialStreamExt, TaskExecutor,
27    TaskExecutorExt as _,
28};
29
30#[derive(Clone)]
31pub struct DictLayoutOptions {
32    pub constraints: DictConstraints,
33    /// Max number of encoded chunks to keep in memory.
34    pub encoded_buffer_size: usize,
35}
36
37impl Default for DictLayoutOptions {
38    fn default() -> Self {
39        Self {
40            constraints: DictConstraints {
41                max_bytes: 1024 * 1024,
42                max_len: u16::MAX as usize,
43            },
44            encoded_buffer_size: 8,
45        }
46    }
47}
48
49/// A layout strategy that encodes chunk into values and codes, if found
50/// appropriate by the btrblocks compressor. Current implementation only
51/// checks the first chunk to decide whether to apply dict layout and
52/// encodes chunks into dictionaries. When the dict constraints are hit, a
53/// new dictionary is created.
54pub struct DictStrategy {
55    codes: ArcRef<dyn LayoutStrategy>,
56    values: ArcRef<dyn LayoutStrategy>,
57    fallback: ArcRef<dyn LayoutStrategy>,
58    options: DictLayoutOptions,
59    executor: Arc<dyn TaskExecutor>,
60}
61
62impl DictStrategy {
63    pub fn new(
64        codes: ArcRef<dyn LayoutStrategy>,
65        values: ArcRef<dyn LayoutStrategy>,
66        fallback: ArcRef<dyn LayoutStrategy>,
67        options: DictLayoutOptions,
68        executor: Arc<dyn TaskExecutor>,
69    ) -> Self {
70        Self {
71            codes,
72            values,
73            fallback,
74            options,
75            executor,
76        }
77    }
78}
79
80impl LayoutStrategy for DictStrategy {
81    fn write_stream(
82        &self,
83        ctx: &ArrayContext,
84        sequence_writer: SequenceWriter,
85        stream: SendableSequentialStream,
86    ) -> SendableLayoutFuture {
87        if !dict_layout_supported(stream.dtype()) {
88            return self.fallback.write_stream(ctx, sequence_writer, stream);
89        }
90        let codes = self.codes.clone();
91        let values = self.values.clone();
92        let fallback = self.fallback.clone();
93        let ctx = ctx.clone();
94        let options = self.options.clone();
95        let dtype = stream.dtype().clone();
96        let executor = self.executor.clone();
97        Box::pin(async move {
98            // 0. decide if chunks are eligible for dict encoding
99            let (stream, first_chunk) = peek_first_chunk(stream).await?;
100            let stream = SequentialStreamAdapter::new(dtype.clone(), stream).sendable();
101
102            let should_fallback = match first_chunk {
103                None => true, // empty stream
104                Some(chunk) => {
105                    let compressed = BtrBlocksCompressor.compress(&chunk)?;
106                    !compressed.is_encoding(DictEncoding.id())
107                }
108            };
109            if should_fallback {
110                // first chunk did not compress to dict, or did not exist. Skip dict layout
111                return fallback
112                    .write_stream(&ctx, sequence_writer.clone(), stream)
113                    .await;
114            }
115
116            // 1. from a chunk stream, create a stream that yields codes
117            // followed by a single value chunk when dict constraints are hit.
118            // (a1, a2) -> (code(c1), code(c2), values(v1), code(c3), ...)
119            let mut dict_stream = dict_encode_stream(stream, options.constraints);
120
121            // 2.a spawn encoding codes
122            let (mut encoded_tx, encoded_rx) = mpsc::channel(options.encoded_buffer_size);
123            let encode_handle = executor.spawn({
124                async move {
125                    while let Some(item) = dict_stream.next().await {
126                        encoded_tx
127                            .send(item)
128                            .await
129                            .map_err(|e| vortex_err!("rx dropped: {}", e))?;
130                    }
131                    Ok(())
132                }
133                .boxed()
134            });
135
136            // 2.b get contiguous runs of codes from the dict stream and
137            // create child dict layouts from them.
138            let dtype_clone = dtype.clone();
139            let child_layouts_fut = async move {
140                let mut children = Vec::new();
141                let mut runs = DictEncodedRuns::new(Box::pin(encoded_rx));
142                while let Some((codes_stream, values_future)) = runs.next_run().await {
143                    let (codes_stream, first_chunk) =
144                        peek_first_chunk(codes_stream.boxed()).await?;
145                    let codes_dtype = match first_chunk {
146                        // codes_stream is empty, this would happen if the parent stream end coincided with a dict run end
147                        None => break,
148                        Some(chunk) => chunk.dtype().clone(),
149                    };
150                    let codes_layout = codes
151                        .write_stream(
152                            &ctx,
153                            sequence_writer.clone(),
154                            SequentialStreamAdapter::new(codes_dtype, codes_stream).sendable(),
155                        )
156                        .await?;
157                    let values_layout = values
158                        .write_stream(
159                            &ctx,
160                            sequence_writer.clone(),
161                            SequentialStreamAdapter::new(dtype_clone.clone(), once(values_future))
162                                .sendable(),
163                        )
164                        .await?;
165                    children.push(DictLayout::new(values_layout, codes_layout).into_layout());
166                }
167                Ok(children)
168            };
169
170            // join dict encoding task
171            let (mut children, _) = try_join!(child_layouts_fut, encode_handle)?;
172
173            if children.len() == 1 {
174                return Ok(children.remove(0));
175            }
176
177            let row_count = children.iter().map(|child| child.row_count()).sum();
178            Ok(ChunkedLayout::new(
179                row_count,
180                dtype,
181                OwnedLayoutChildren::layout_children(children),
182            )
183            .into_layout())
184        })
185    }
186}
187
188enum DictionaryChunk {
189    Codes((SequenceId, ArrayRef)),
190    Values((SequenceId, ArrayRef)),
191}
192
193type DictionaryStream = BoxStream<'static, VortexResult<DictionaryChunk>>;
194
195fn dict_encode_stream(
196    input: SendableSequentialStream,
197    constraints: DictConstraints,
198) -> DictionaryStream {
199    Box::pin(try_stream! {
200        let mut state = DictStreamState {
201            encoder: None,
202            constraints,
203        };
204        let input = input.peekable();
205        pin_mut!(input);
206        while let Some(item) = input.as_mut().next().await {
207            let (sequence_id, chunk) = item?;
208            // labeler potentially creates sub sequences, we must
209            // create it on both arms to avoid having a SequencePointer
210            // between await points
211            match input.as_mut().peek().await {
212                Some(_) => {
213                    let mut labeler = DictChunkLabeler::new(sequence_id);
214                    let chunks = state.encode(&mut labeler, chunk);
215                    drop(labeler);
216                    for dict_chunk in chunks {
217                        yield dict_chunk?;
218                    }
219                }
220                None => {
221                    // this is the last element, encode and drain chunks
222                    let mut labeler = DictChunkLabeler::new(sequence_id);
223                    let encoded = state.encode(&mut labeler, chunk);
224                    let drained = state.drain_values(&mut labeler);
225                    drop(labeler);
226                    for dict_chunk in encoded.into_iter().chain(drained.into_iter()) {
227                        yield dict_chunk?;
228                    }
229                }
230            }
231        }
232    })
233}
234
235struct DictStreamState {
236    encoder: Option<Box<dyn DictEncoder>>,
237    constraints: DictConstraints,
238}
239
240impl DictStreamState {
241    fn encode(
242        &mut self,
243        labeler: &mut DictChunkLabeler,
244        chunk: ArrayRef,
245    ) -> Vec<VortexResult<DictionaryChunk>> {
246        self.try_encode(labeler, chunk)
247            .unwrap_or_else(|e| vec![Err(e)])
248    }
249
250    fn try_encode(
251        &mut self,
252        labeler: &mut DictChunkLabeler,
253        chunk: ArrayRef,
254    ) -> VortexResult<Vec<VortexResult<DictionaryChunk>>> {
255        let mut res = Vec::new();
256        let mut to_be_encoded = Some(chunk);
257        while let Some(remaining) = to_be_encoded.take() {
258            match self.encoder.take() {
259                None => match start_encoding(&self.constraints, &remaining)? {
260                    EncodingState::Continue((encoder, encoded)) => {
261                        res.push(Ok(labeler.codes(encoded)));
262                        self.encoder = Some(encoder);
263                    }
264                    EncodingState::Done((values, encoded, unencoded)) => {
265                        res.push(Ok(labeler.codes(encoded)));
266                        res.push(Ok(labeler.values(values)));
267                        to_be_encoded = Some(unencoded);
268                    }
269                },
270                Some(encoder) => match encode_chunk(encoder, &remaining)? {
271                    EncodingState::Continue((encoder, encoded)) => {
272                        res.push(Ok(labeler.codes(encoded)));
273                        self.encoder = Some(encoder);
274                    }
275                    EncodingState::Done((values, encoded, unencoded)) => {
276                        res.push(Ok(labeler.codes(encoded)));
277                        res.push(Ok(labeler.values(values)));
278                        to_be_encoded = Some(unencoded);
279                    }
280                },
281            }
282        }
283        Ok(res)
284    }
285
286    fn drain_values(
287        &mut self,
288        labeler: &mut DictChunkLabeler,
289    ) -> Vec<VortexResult<DictionaryChunk>> {
290        match self.encoder.as_mut() {
291            None => Vec::new(),
292            Some(encoder) => vec![encoder.values().map(|val| labeler.values(val))],
293        }
294    }
295}
296
297struct DictChunkLabeler {
298    sequence_pointer: SequencePointer,
299}
300
301impl DictChunkLabeler {
302    fn new(starting_id: SequenceId) -> Self {
303        let sequence_pointer = starting_id.descend();
304        Self { sequence_pointer }
305    }
306
307    fn codes(&mut self, chunk: ArrayRef) -> DictionaryChunk {
308        DictionaryChunk::Codes((self.sequence_pointer.advance(), chunk))
309    }
310
311    fn values(&mut self, chunk: ArrayRef) -> DictionaryChunk {
312        DictionaryChunk::Values((self.sequence_pointer.advance(), chunk))
313    }
314}
315
316type SequencedChunk = VortexResult<(SequenceId, ArrayRef)>;
317
318struct DictEncodedRuns {
319    input: Option<oneshot::Receiver<Option<DictionaryStream>>>,
320}
321
322impl DictEncodedRuns {
323    fn new(input: DictionaryStream) -> Self {
324        let (tx, rx) = oneshot::channel();
325        tx.send(Some(input))
326            .map_err(|_input| vortex_err!("just created rx"))
327            .vortex_unwrap();
328        Self { input: Some(rx) }
329    }
330
331    async fn next_run(
332        &mut self,
333    ) -> Option<(
334        DictEncodedRunStream,
335        impl Future<Output = SequencedChunk> + use<>,
336    )> {
337        // get input to send to the run stream.
338        let Ok(Some(input)) = self.input.take()?.await else {
339            // input exhausted
340            return None;
341        };
342        let (input_tx, input_rx) = oneshot::channel();
343        self.input = Some(input_rx);
344
345        let (values_tx, values_rx) = oneshot::channel();
346        let values_future = async {
347            values_rx
348                .await
349                .unwrap_or_else(|_| vortex_bail!("sender dropped"))
350        };
351
352        let codes_stream = DictEncodedRunStream {
353            input: Some(input),
354            input_tx: Some(input_tx),
355            values_tx: Some(values_tx),
356        };
357
358        Some((codes_stream, values_future))
359    }
360}
361
362struct DictEncodedRunStream {
363    input: Option<DictionaryStream>,
364    input_tx: Option<oneshot::Sender<Option<DictionaryStream>>>,
365    values_tx: Option<oneshot::Sender<SequencedChunk>>,
366}
367
368impl Stream for DictEncodedRunStream {
369    type Item = SequencedChunk;
370
371    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
372        let poll_result = {
373            let Some(stream) = self.input.as_mut() else {
374                return Poll::Ready(None);
375            };
376            ready!(stream.poll_next_unpin(cx))
377        };
378
379        match poll_result {
380            Some(Ok(DictionaryChunk::Codes(item))) => Poll::Ready(Some(Ok(item))),
381            Some(Ok(DictionaryChunk::Values(item))) => {
382                self.send_values(item);
383                self.send_back_input_stream();
384                Poll::Ready(None)
385            }
386            Some(Err(e)) => Poll::Ready(Some(Err(e))),
387            None => {
388                self.send_back_input_stream();
389                Poll::Ready(None)
390            }
391        }
392    }
393}
394
395impl DictEncodedRunStream {
396    fn send_values(&mut self, item: (SequenceId, ArrayRef)) {
397        // ignore receiver drops
398        let _ = self
399            .values_tx
400            .take()
401            .vortex_expect("must not be polled after returning None")
402            .send(Ok(item));
403    }
404
405    fn send_back_input_stream(&mut self) {
406        // ignore receiver drops
407        let _ = self
408            .input_tx
409            .take()
410            .vortex_expect("input already sent")
411            .send(self.input.take());
412    }
413}
414
415impl Drop for DictEncodedRunStream {
416    fn drop(&mut self) {
417        if let Some(tx) = self.input_tx.take() {
418            let _ = tx.send(self.input.take());
419        }
420    }
421}
422
423async fn peek_first_chunk(
424    mut stream: BoxStream<'static, SequencedChunk>,
425) -> VortexResult<(BoxStream<'static, SequencedChunk>, Option<ArrayRef>)> {
426    match stream.next().await {
427        None => Ok((stream.boxed(), None)),
428        Some(Err(e)) => Err(e),
429        Some(Ok((sequence_id, chunk))) => {
430            let chunk_clone = chunk.clone();
431            let reconstructed_stream =
432                once(async move { Ok((sequence_id, chunk_clone)) }).chain(stream);
433            Ok((reconstructed_stream.boxed(), Some(chunk)))
434        }
435    }
436}
437
438pub fn dict_layout_supported(dtype: &DType) -> bool {
439    matches!(
440        dtype,
441        DType::Primitive(..) | DType::Utf8(_) | DType::Binary(_)
442    )
443}
444
445#[derive(prost::Message)]
446pub struct DictLayoutMetadata {
447    #[prost(enumeration = "PType", tag = "1")]
448    // i32 is required for proto, use the generated getter to read this field.
449    codes_ptype: i32,
450}
451
452impl DictLayoutMetadata {
453    pub fn new(codes_ptype: PType) -> Self {
454        let mut metadata = Self::default();
455        metadata.set_codes_ptype(codes_ptype);
456        metadata
457    }
458}
459
460enum EncodingState {
461    Continue((Box<dyn DictEncoder>, ArrayRef)),
462    // (values, encoded, unencoded)
463    Done((ArrayRef, ArrayRef, ArrayRef)),
464}
465
466fn start_encoding(constraints: &DictConstraints, chunk: &dyn Array) -> VortexResult<EncodingState> {
467    let encoder = dict_encoder(chunk, constraints)?;
468    encode_chunk(encoder, chunk)
469}
470
471fn encode_chunk(
472    mut encoder: Box<dyn DictEncoder>,
473    chunk: &dyn Array,
474) -> VortexResult<EncodingState> {
475    let encoded = encoder.encode(chunk)?;
476    Ok(match remainder(chunk, encoded.len())? {
477        None => EncodingState::Continue((encoder, encoded)),
478        Some(unencoded) => EncodingState::Done((encoder.values()?, encoded, unencoded)),
479    })
480}
481
482fn remainder(array: &dyn Array, encoded_len: usize) -> VortexResult<Option<ArrayRef>> {
483    (encoded_len < array.len())
484        .then(|| array.slice(encoded_len, array.len()))
485        .transpose()
486}