vortex_layout/layouts/dict/writer/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll, ready};
7
8use async_stream::try_stream;
9use async_trait::async_trait;
10use futures::channel::{mpsc, oneshot};
11use futures::stream::{BoxStream, once};
12use futures::{FutureExt, SinkExt, Stream, StreamExt, pin_mut, try_join};
13use vortex_array::{Array, ArrayContext, ArrayRef};
14use vortex_btrblocks::BtrBlocksCompressor;
15use vortex_dict::DictEncoding;
16use vortex_dict::builders::{DictConstraints, DictEncoder, dict_encoder};
17use vortex_dtype::{DType, PType};
18use vortex_error::{VortexExpect, VortexResult, VortexUnwrap, vortex_bail, vortex_err};
19
20use super::DictLayout;
21use crate::layouts::chunked::ChunkedLayout;
22use crate::segments::SequenceWriter;
23use crate::sequence::{SequenceId, SequencePointer};
24use crate::{
25    IntoLayout, LayoutRef, LayoutStrategy, OwnedLayoutChildren, SendableSequentialStream,
26    SequentialStreamAdapter, SequentialStreamExt, TaskExecutor, TaskExecutorExt as _,
27};
28
/// Tuning options for [`DictStrategy`].
#[derive(Clone)]
pub struct DictLayoutOptions {
    /// Limits every dictionary encoder is created with; when an encoder stops
    /// before the end of a chunk, the current dictionary is closed and a new one
    /// is started for the remaining values.
    pub constraints: DictConstraints,
    /// Max number of encoded chunks to keep in memory.
    ///
    /// Used as the capacity of the bounded channel between the background
    /// dictionary-encoding task and the codes/values layout writers.
    pub encoded_buffer_size: usize,
}
35
36impl Default for DictLayoutOptions {
37    fn default() -> Self {
38        Self {
39            constraints: DictConstraints {
40                max_bytes: 1024 * 1024,
41                max_len: u16::MAX as usize,
42            },
43            encoded_buffer_size: 8,
44        }
45    }
46}
47
/// A layout strategy that splits chunks into dictionary codes and values when
/// the BtrBlocks compressor finds dictionary encoding appropriate.
///
/// The current implementation only compresses the first chunk to decide whether
/// to apply the dict layout at all. Unsupported dtypes, empty streams, and streams
/// whose first chunk does not compress to a dictionary are written with the
/// `fallback` strategy instead. Otherwise every chunk is dictionary-encoded, and
/// whenever the dict constraints are hit a new dictionary is started.
#[derive(Clone)]
pub struct DictStrategy<Codes, Values, Fallback> {
    /// Strategy used to write each run of dictionary codes.
    codes: Codes,
    /// Strategy used to write the values of each dictionary.
    values: Values,
    /// Strategy used when dictionary encoding is not applied.
    fallback: Fallback,
    options: DictLayoutOptions,
    /// Executor on which the dictionary-encoding task is spawned.
    executor: Arc<dyn TaskExecutor>,
}
61
62impl<Codes, Values, Fallback> DictStrategy<Codes, Values, Fallback>
63where
64    Codes: LayoutStrategy,
65    Values: LayoutStrategy,
66    Fallback: LayoutStrategy,
67{
68    pub fn new(
69        codes: Codes,
70        values: Values,
71        fallback: Fallback,
72        options: DictLayoutOptions,
73        executor: Arc<dyn TaskExecutor>,
74    ) -> Self {
75        Self {
76            codes,
77            values,
78            fallback,
79            options,
80            executor,
81        }
82    }
83}
84
#[async_trait]
impl<Codes, Values, Fallback> LayoutStrategy for DictStrategy<Codes, Values, Fallback>
where
    Codes: LayoutStrategy,
    Values: LayoutStrategy,
    Fallback: LayoutStrategy,
{
    /// Writes `stream` as dictionary-encoded layouts, or delegates to the fallback
    /// strategy when dictionary encoding is not applied.
    ///
    /// The fallback is used for dtypes rejected by [`dict_layout_supported`], for
    /// empty streams, and when the first chunk does not compress to a dictionary.
    /// Otherwise the result is a single [`DictLayout`] if only one dictionary was
    /// needed, or a [`ChunkedLayout`] with one `DictLayout` child per dictionary.
    ///
    /// # Errors
    ///
    /// Propagates errors from the input stream, compression of the first chunk,
    /// dictionary encoding, the child strategies and the spawned encoding task.
    async fn write_stream(
        &self,
        ctx: &ArrayContext,
        sequence_writer: SequenceWriter,
        stream: SendableSequentialStream,
    ) -> VortexResult<LayoutRef> {
        if !dict_layout_supported(stream.dtype()) {
            return self
                .fallback
                .write_stream(ctx, sequence_writer, stream)
                .await;
        }
        let ctx = ctx.clone();
        let options = self.options.clone();
        let dtype = stream.dtype().clone();
        let executor = self.executor.clone();
        // 0. decide if chunks are eligible for dict encoding
        // (the peeked chunk is put back in front of the returned stream)
        let (stream, first_chunk) = peek_first_chunk(stream).await?;
        let stream = SequentialStreamAdapter::new(dtype.clone(), stream).sendable();

        // Only the first chunk is sampled; the whole stream follows its verdict.
        let should_fallback = match first_chunk {
            None => true, // empty stream
            Some(chunk) => {
                let compressed = BtrBlocksCompressor.compress(&chunk)?;
                !compressed.is_encoding(DictEncoding.id())
            }
        };
        if should_fallback {
            // first chunk did not compress to dict, or did not exist. Skip dict layout
            return self
                .fallback
                .write_stream(&ctx, sequence_writer.clone(), stream)
                .await;
        }

        // 1. from a chunk stream, create a stream that yields codes
        // followed by a single value chunk when dict constraints are hit.
        // (a1, a2) -> (code(c1), code(c2), values(v1), code(c3), ...)
        let mut dict_stream = dict_encode_stream(stream, options.constraints);

        // 2.a spawn encoding codes
        // The bounded channel applies backpressure, so encoding only runs about
        // `encoded_buffer_size` chunks ahead of the layout writers below.
        let (mut encoded_tx, encoded_rx) = mpsc::channel(options.encoded_buffer_size);
        let encode_handle = executor.spawn({
            async move {
                // Items (including errors) are forwarded as-is; the task itself only
                // fails if the receiving side has been dropped.
                while let Some(item) = dict_stream.next().await {
                    encoded_tx
                        .send(item)
                        .await
                        .map_err(|e| vortex_err!("rx dropped: {}", e))?;
                }
                Ok(())
            }
            .boxed()
        });

        // 2.b get contiguous runs of codes from the dict stream and
        // create child dict layouts from them.
        let dtype_clone = dtype.clone();
        let child_layouts_fut = async move {
            let mut children = Vec::new();
            let mut runs = DictEncodedRuns::new(Box::pin(encoded_rx));
            while let Some((codes_stream, values_future)) = runs.next_run().await {
                // The codes dtype is taken from the first codes chunk of the run.
                let (codes_stream, first_chunk) = peek_first_chunk(codes_stream.boxed()).await?;
                let codes_dtype = match first_chunk {
                    // codes_stream is empty, this would happen if the parent stream end coincided with a dict run end
                    None => break,
                    Some(chunk) => chunk.dtype().clone(),
                };
                // Codes must be written before values: the values future only
                // resolves once the codes stream has reached the values chunk that
                // ends this run.
                let codes_layout = self
                    .codes
                    .write_stream(
                        &ctx,
                        sequence_writer.clone(),
                        SequentialStreamAdapter::new(codes_dtype, codes_stream).sendable(),
                    )
                    .await?;
                let values_layout = self
                    .values
                    .write_stream(
                        &ctx,
                        sequence_writer.clone(),
                        SequentialStreamAdapter::new(dtype_clone.clone(), once(values_future))
                            .sendable(),
                    )
                    .await?;
                children.push(DictLayout::new(values_layout, codes_layout).into_layout());
            }
            Ok(children)
        };

        // join dict encoding task
        let (mut children, _) = try_join!(child_layouts_fut, encode_handle)?;

        // A single dictionary needs no chunked wrapper.
        if children.len() == 1 {
            return Ok(children.remove(0));
        }

        let row_count = children.iter().map(|child| child.row_count()).sum();
        Ok(ChunkedLayout::new(
            row_count,
            dtype,
            OwnedLayoutChildren::layout_children(children),
        )
        .into_layout())
    }
}
198
/// One item of the stream produced by `dict_encode_stream`.
///
/// A run of `Codes` chunks is terminated by the `Values` chunk of the dictionary
/// those codes were encoded against.
enum DictionaryChunk {
    /// A chunk of codes for the dictionary currently being built.
    Codes((SequenceId, ArrayRef)),
    /// The values of a finished dictionary, closing the current run of codes.
    Values((SequenceId, ArrayRef)),
}
203
/// Boxed `Send` stream of [`DictionaryChunk`]s or errors.
type DictionaryStream = BoxStream<'static, VortexResult<DictionaryChunk>>;
205
/// Dictionary-encodes every chunk of `input`, producing codes chunks interleaved
/// with the values chunk of each dictionary as it is closed.
///
/// The values of the dictionary still open when `input` ends are emitted after
/// the codes of the last chunk, so on success every run of codes is terminated by
/// a values chunk. Output chunks get sequence ids derived from the id of the input
/// chunk they came from (via `SequenceId::descend`). Because of `try_stream!`, the
/// first error is yielded and ends the stream.
fn dict_encode_stream(
    input: SendableSequentialStream,
    constraints: DictConstraints,
) -> DictionaryStream {
    Box::pin(try_stream! {
        let mut state = DictStreamState {
            encoder: None,
            constraints,
        };
        // Peekable so we can tell whether the current chunk is the last one.
        let input = input.peekable();
        pin_mut!(input);
        while let Some(item) = input.as_mut().next().await {
            let (sequence_id, chunk) = item?;
            // labeler potentially creates sub sequences, we must
            // create it on both arms to avoid having a SequencePointer
            // between await points
            match input.as_mut().peek().await {
                Some(_) => {
                    let mut labeler = DictChunkLabeler::new(sequence_id);
                    let chunks = state.encode(&mut labeler, chunk);
                    // Release the SequencePointer before the `yield` await points.
                    drop(labeler);
                    for dict_chunk in chunks {
                        yield dict_chunk?;
                    }
                }
                None => {
                    // this is the last element, encode and drain chunks
                    let mut labeler = DictChunkLabeler::new(sequence_id);
                    let encoded = state.encode(&mut labeler, chunk);
                    let drained = state.drain_values(&mut labeler);
                    // Release the SequencePointer before the `yield` await points.
                    drop(labeler);
                    for dict_chunk in encoded.into_iter().chain(drained.into_iter()) {
                        yield dict_chunk?;
                    }
                }
            }
        }
    })
}
245
/// Dictionary-encoding state carried across the chunks of one input stream.
struct DictStreamState {
    /// Encoder of the currently open dictionary, if any; reused for the next chunk.
    encoder: Option<Box<dyn DictEncoder>>,
    /// Constraints every new dictionary encoder is created with.
    constraints: DictConstraints,
}
250
251impl DictStreamState {
252    fn encode(
253        &mut self,
254        labeler: &mut DictChunkLabeler,
255        chunk: ArrayRef,
256    ) -> Vec<VortexResult<DictionaryChunk>> {
257        self.try_encode(labeler, chunk)
258            .unwrap_or_else(|e| vec![Err(e)])
259    }
260
261    fn try_encode(
262        &mut self,
263        labeler: &mut DictChunkLabeler,
264        chunk: ArrayRef,
265    ) -> VortexResult<Vec<VortexResult<DictionaryChunk>>> {
266        let mut res = Vec::new();
267        let mut to_be_encoded = Some(chunk);
268        while let Some(remaining) = to_be_encoded.take() {
269            match self.encoder.take() {
270                None => match start_encoding(&self.constraints, &remaining)? {
271                    EncodingState::Continue((encoder, encoded)) => {
272                        res.push(Ok(labeler.codes(encoded)));
273                        self.encoder = Some(encoder);
274                    }
275                    EncodingState::Done((values, encoded, unencoded)) => {
276                        res.push(Ok(labeler.codes(encoded)));
277                        res.push(Ok(labeler.values(values)));
278                        to_be_encoded = Some(unencoded);
279                    }
280                },
281                Some(encoder) => match encode_chunk(encoder, &remaining)? {
282                    EncodingState::Continue((encoder, encoded)) => {
283                        res.push(Ok(labeler.codes(encoded)));
284                        self.encoder = Some(encoder);
285                    }
286                    EncodingState::Done((values, encoded, unencoded)) => {
287                        res.push(Ok(labeler.codes(encoded)));
288                        res.push(Ok(labeler.values(values)));
289                        to_be_encoded = Some(unencoded);
290                    }
291                },
292            }
293        }
294        Ok(res)
295    }
296
297    fn drain_values(
298        &mut self,
299        labeler: &mut DictChunkLabeler,
300    ) -> Vec<VortexResult<DictionaryChunk>> {
301        match self.encoder.as_mut() {
302            None => Vec::new(),
303            Some(encoder) => vec![encoder.values().map(|val| labeler.values(val))],
304        }
305    }
306}
307
/// Assigns sequence ids to the dictionary chunks derived from one input chunk.
struct DictChunkLabeler {
    /// Created by descending from the input chunk's id; advanced once per emitted chunk.
    sequence_pointer: SequencePointer,
}
311
312impl DictChunkLabeler {
313    fn new(starting_id: SequenceId) -> Self {
314        let sequence_pointer = starting_id.descend();
315        Self { sequence_pointer }
316    }
317
318    fn codes(&mut self, chunk: ArrayRef) -> DictionaryChunk {
319        DictionaryChunk::Codes((self.sequence_pointer.advance(), chunk))
320    }
321
322    fn values(&mut self, chunk: ArrayRef) -> DictionaryChunk {
323        DictionaryChunk::Values((self.sequence_pointer.advance(), chunk))
324    }
325}
326
/// A chunk tagged with its sequence id, or an error.
type SequencedChunk = VortexResult<(SequenceId, ArrayRef)>;
328
/// Splits a [`DictionaryStream`] into consecutive runs, one per dictionary.
///
/// The single input stream is handed from run to run through oneshot channels.
struct DictEncodedRuns {
    /// Receives the input stream back from the previous run; `None` once
    /// `next_run` has reported the end.
    input: Option<oneshot::Receiver<Option<DictionaryStream>>>,
}
332
impl DictEncodedRuns {
    /// Wraps `input` so it can be consumed one dictionary run at a time.
    fn new(input: DictionaryStream) -> Self {
        // Pre-load the hand-off channel so the first `next_run` receives the input
        // immediately; sending cannot fail because `rx` is still alive.
        let (tx, rx) = oneshot::channel();
        tx.send(Some(input))
            .map_err(|_input| vortex_err!("just created rx"))
            .vortex_unwrap();
        Self { input: Some(rx) }
    }

    /// Returns the next run: a stream of its codes chunks plus a future resolving
    /// to the values chunk that terminates it.
    ///
    /// This waits until the previous run's stream has finished or been dropped,
    /// because only then is the input handed back. Runs must therefore be consumed
    /// one after another. An exhausted input is still handed back, so after the
    /// last real run this returns one more run whose codes stream is empty.
    ///
    /// The values future resolves to an error ("sender dropped") if the run's
    /// stream ends or is dropped without encountering a values chunk.
    async fn next_run(
        &mut self,
    ) -> Option<(
        DictEncodedRunStream,
        impl Future<Output = SequencedChunk> + use<>,
    )> {
        // get input to send to the run stream.
        let Ok(Some(input)) = self.input.take()?.await else {
            // no input was handed back (channel cancelled or `None` sent)
            return None;
        };
        // Channel through which the new run will hand the input back.
        let (input_tx, input_rx) = oneshot::channel();
        self.input = Some(input_rx);

        // The run's stream sends the values chunk that ends the run through `values_tx`.
        let (values_tx, values_rx) = oneshot::channel();
        let values_future = async {
            values_rx
                .await
                .unwrap_or_else(|_| vortex_bail!("sender dropped"))
        };

        let codes_stream = DictEncodedRunStream {
            input: Some(input),
            input_tx: Some(input_tx),
            values_tx: Some(values_tx),
        };

        Some((codes_stream, values_future))
    }
}
372
/// Stream over the codes chunks of a single dictionary run.
///
/// It yields codes until it sees a values chunk, which it forwards to the run's
/// values future, or until the input ends. In either case it then hands the input
/// stream back through `input_tx`, so the next run continues where this one stopped.
struct DictEncodedRunStream {
    /// The shared dictionary stream, held by this run while it is active.
    input: Option<DictionaryStream>,
    /// Returns `input` to [`DictEncodedRuns`] when the run ends or is dropped.
    input_tx: Option<oneshot::Sender<Option<DictionaryStream>>>,
    /// Delivers the values chunk that terminates this run.
    values_tx: Option<oneshot::Sender<SequencedChunk>>,
}
378
379impl Stream for DictEncodedRunStream {
380    type Item = SequencedChunk;
381
382    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
383        let poll_result = {
384            let Some(stream) = self.input.as_mut() else {
385                return Poll::Ready(None);
386            };
387            ready!(stream.poll_next_unpin(cx))
388        };
389
390        match poll_result {
391            Some(Ok(DictionaryChunk::Codes(item))) => Poll::Ready(Some(Ok(item))),
392            Some(Ok(DictionaryChunk::Values(item))) => {
393                self.send_values(item);
394                self.send_back_input_stream();
395                Poll::Ready(None)
396            }
397            Some(Err(e)) => Poll::Ready(Some(Err(e))),
398            None => {
399                self.send_back_input_stream();
400                Poll::Ready(None)
401            }
402        }
403    }
404}
405
406impl DictEncodedRunStream {
407    fn send_values(&mut self, item: (SequenceId, ArrayRef)) {
408        // ignore receiver drops
409        let _ = self
410            .values_tx
411            .take()
412            .vortex_expect("must not be polled after returning None")
413            .send(Ok(item));
414    }
415
416    fn send_back_input_stream(&mut self) {
417        // ignore receiver drops
418        let _ = self
419            .input_tx
420            .take()
421            .vortex_expect("input already sent")
422            .send(self.input.take());
423    }
424}
425
426impl Drop for DictEncodedRunStream {
427    fn drop(&mut self) {
428        if let Some(tx) = self.input_tx.take() {
429            let _ = tx.send(self.input.take());
430        }
431    }
432}
433
434async fn peek_first_chunk(
435    mut stream: BoxStream<'static, SequencedChunk>,
436) -> VortexResult<(BoxStream<'static, SequencedChunk>, Option<ArrayRef>)> {
437    match stream.next().await {
438        None => Ok((stream.boxed(), None)),
439        Some(Err(e)) => Err(e),
440        Some(Ok((sequence_id, chunk))) => {
441            let chunk_clone = chunk.clone();
442            let reconstructed_stream =
443                once(async move { Ok((sequence_id, chunk_clone)) }).chain(stream);
444            Ok((reconstructed_stream.boxed(), Some(chunk)))
445        }
446    }
447}
448
449pub fn dict_layout_supported(dtype: &DType) -> bool {
450    matches!(
451        dtype,
452        DType::Primitive(..) | DType::Utf8(_) | DType::Binary(_)
453    )
454}
455
/// Protobuf (prost) metadata of a dict layout: the primitive type of its codes.
#[derive(prost::Message)]
pub struct DictLayoutMetadata {
    #[prost(enumeration = "PType", tag = "1")]
    // i32 is required for proto, use the generated getter to read this field.
    codes_ptype: i32,
}
462
463impl DictLayoutMetadata {
464    pub fn new(codes_ptype: PType) -> Self {
465        let mut metadata = Self::default();
466        metadata.set_codes_ptype(codes_ptype);
467        metadata
468    }
469}
470
/// Outcome of feeding one chunk to a dictionary encoder.
enum EncodingState {
    /// The whole chunk was encoded: the still-open encoder and the chunk's codes.
    Continue((Box<dyn DictEncoder>, ArrayRef)),
    /// The encoder stopped before the end of the chunk (presumably because the
    /// dict constraints were hit). The fields are the finished dictionary's values,
    /// the codes of the encoded prefix, and the unencoded suffix.
    // (values, encoded, unencoded)
    Done((ArrayRef, ArrayRef, ArrayRef)),
}
476
477fn start_encoding(constraints: &DictConstraints, chunk: &dyn Array) -> VortexResult<EncodingState> {
478    let encoder = dict_encoder(chunk, constraints)?;
479    encode_chunk(encoder, chunk)
480}
481
482fn encode_chunk(
483    mut encoder: Box<dyn DictEncoder>,
484    chunk: &dyn Array,
485) -> VortexResult<EncodingState> {
486    let encoded = encoder.encode(chunk)?;
487    Ok(match remainder(chunk, encoded.len()) {
488        None => EncodingState::Continue((encoder, encoded)),
489        Some(unencoded) => EncodingState::Done((encoder.values()?, encoded, unencoded)),
490    })
491}
492
493fn remainder(array: &dyn Array, encoded_len: usize) -> Option<ArrayRef> {
494    (encoded_len < array.len()).then(|| array.slice(encoded_len, array.len()))
495}