vortex_layout/layouts/dict/writer.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;

use async_stream::stream;
use async_stream::try_stream;
use async_trait::async_trait;
use futures::FutureExt;
use futures::Stream;
use futures::StreamExt;
use futures::TryStreamExt;
use futures::future::BoxFuture;
use futures::pin_mut;
use futures::stream::BoxStream;
use futures::stream::once;
use futures::try_join;
use vortex_array::Array;
use vortex_array::ArrayContext;
use vortex_array::ArrayRef;
use vortex_array::arrays::DictVTable;
use vortex_array::builders::dict::DictConstraints;
use vortex_array::builders::dict::DictEncoder;
use vortex_array::builders::dict::dict_encoder;
use vortex_btrblocks::BtrBlocksCompressor;
use vortex_dtype::DType;
use vortex_dtype::Nullability::NonNullable;
use vortex_dtype::PType;
use vortex_error::VortexError;
use vortex_error::VortexResult;
use vortex_error::vortex_err;
use vortex_io::kanal_ext::KanalExt;
use vortex_io::runtime::Handle;

use crate::IntoLayout;
use crate::LayoutRef;
use crate::LayoutStrategy;
use crate::OwnedLayoutChildren;
use crate::layouts::chunked::ChunkedLayout;
use crate::layouts::dict::DictLayout;
use crate::segments::SegmentSinkRef;
use crate::sequence::SendableSequentialStream;
use crate::sequence::SequenceId;
use crate::sequence::SequencePointer;
use crate::sequence::SequentialStream;
use crate::sequence::SequentialStreamAdapter;
use crate::sequence::SequentialStreamExt;

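/// Per-dictionary limits used by [`DictStrategy`]: once a dictionary exceeds
/// either bound, it is flushed and a new dictionary is started.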
#[derive(Clone)]
pub struct DictLayoutConstraints {
    pub max_bytes: usize,
    // The dict layout currently only supports u16 codes.
    pub max_len: u16,
}

impl From<DictLayoutConstraints> for DictConstraints {
    fn from(value: DictLayoutConstraints) -> Self {
        DictConstraints {
            max_bytes: value.max_bytes,
            max_len: value.max_len as usize,
        }
    }
}

impl Default for DictLayoutConstraints {
    fn default() -> Self {
        Self {
            max_bytes: 1024 * 1024,
            max_len: u16::MAX,
        }
    }
}

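/// Options for [`DictStrategy`].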
#[derive(Clone, Default)]
pub struct DictLayoutOptions {
    pub constraints: DictLayoutConstraints,
}

/// A layout strategy that encodes chunks into values and codes when the
/// btrblocks compressor finds dictionary encoding appropriate. The current
/// implementation only inspects the first chunk to decide whether to apply the
/// dict layout, then dictionary-encodes all chunks. Whenever the dict
/// constraints are hit, a new dictionary is started.
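///
/// A minimal construction sketch (hypothetical: `FlatStrategy` stands in for
/// whatever concrete [`LayoutStrategy`] implementations the caller wires in
/// for the codes, values, and fallback children; it is not an API of this
/// module):
///
/// ```ignore
/// let strategy = DictStrategy::new(
///     FlatStrategy::default(),      // strategy for the codes child
///     FlatStrategy::default(),      // strategy for the values child
///     FlatStrategy::default(),      // fallback when dict encoding is skipped
///     DictLayoutOptions::default(),
/// );
/// ```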
#[derive(Clone)]
pub struct DictStrategy {
    codes: Arc<dyn LayoutStrategy>,
    values: Arc<dyn LayoutStrategy>,
    fallback: Arc<dyn LayoutStrategy>,
    options: DictLayoutOptions,
}

impl DictStrategy {
    pub fn new<Codes: LayoutStrategy, Values: LayoutStrategy, Fallback: LayoutStrategy>(
        codes: Codes,
        values: Values,
        fallback: Fallback,
        options: DictLayoutOptions,
    ) -> Self {
        Self {
            codes: Arc::new(codes),
            values: Arc::new(values),
            fallback: Arc::new(fallback),
            options,
        }
    }
}

#[async_trait]
impl LayoutStrategy for DictStrategy {
    async fn write_stream(
        &self,
        ctx: ArrayContext,
        segment_sink: SegmentSinkRef,
        stream: SendableSequentialStream,
        mut eof: SequencePointer,
        handle: Handle,
    ) -> VortexResult<LayoutRef> {
        // Fall back if the dtype is not supported.
        if !dict_layout_supported(stream.dtype()) {
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, handle)
                .await;
        }

        let options = self.options.clone();
        let dtype = stream.dtype().clone();

        // 0. decide if chunks are eligible for dict encoding
        let (stream, first_chunk) = peek_first_chunk(stream).await?;
        let stream = SequentialStreamAdapter::new(dtype.clone(), stream).sendable();

        let should_fallback = match first_chunk {
            None => true, // empty stream
            Some(chunk) => {
                let compressed = BtrBlocksCompressor::default().compress(&chunk)?;
                !compressed.is::<DictVTable>()
            }
        };
        if should_fallback {
            // The first chunk did not dictionary-compress, or the stream was empty; skip the dict layout.
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, handle)
                .await;
        }

        // 1. from a chunk stream, create a stream that yields codes
        // followed by a single value chunk when dict constraints are hit.
        // (a1, a2) -> (code(c1), code(c2), values(v1), code(c3), ...)
        let dict_stream = dict_encode_stream(stream, options.constraints.into());

        // Wrap up the dict stream to yield pairs of (codes_stream, values_future).
        // Each of these pairs becomes a child dict layout.
        let runs = DictionaryTransformer::new(dict_stream);

        let dtype2 = dtype.clone();
        let child_layouts = stream! {
            pin_mut!(runs);

            while let Some((codes_stream, values_fut)) = runs.next().await {
                let codes = self.codes.clone();
                let codes_eof = eof.split_off();
                let ctx2 = ctx.clone();
                let segment_sink2 = segment_sink.clone();
                let codes_fut = handle.spawn_nested(move |h| async move {
                    codes.write_stream(
                        ctx2,
                        segment_sink2,
                        codes_stream.sendable(),
                        codes_eof,
                        h,
                    ).await
                });

                let values = self.values.clone();
                let values_eof = eof.split_off();
                let ctx2 = ctx.clone();
                let segment_sink2 = segment_sink.clone();
                let dtype2 = dtype2.clone();
                let values_layout = handle.spawn_nested(move |h| async move {
                    values.write_stream(
                        ctx2,
                        segment_sink2,
                        SequentialStreamAdapter::new(dtype2, once(values_fut)).sendable(),
                        values_eof,
                        h,
                    ).await
                });

                yield async move {
                    try_join!(codes_fut, values_layout)
                }.boxed();
            }
        };

        let mut child_layouts = child_layouts
            .buffered(usize::MAX)
            .map(|result| {
                let (codes_layout, values_layout) = result?;
                // All values are referenced when created via dictionary encoding
                Ok::<_, VortexError>(DictLayout::new(values_layout, codes_layout).into_layout())
            })
            .try_collect::<Vec<_>>()
            .await?;

        if child_layouts.len() == 1 {
            return Ok(child_layouts.remove(0));
        }

        let row_count = child_layouts.iter().map(|child| child.row_count()).sum();
        Ok(ChunkedLayout::new(
            row_count,
            dtype,
            OwnedLayoutChildren::layout_children(child_layouts),
        )
        .into_layout())
    }

    fn buffered_bytes(&self) -> u64 {
        self.codes.buffered_bytes() + self.values.buffered_bytes() + self.fallback.buffered_bytes()
    }
}

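/// An item produced by [`dict_encode_stream`]: either a chunk of codes for the
/// currently open dictionary, or the values chunk that closes it.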
enum DictionaryChunk {
    Codes((SequenceId, ArrayRef)),
    Values((SequenceId, ArrayRef)),
}

type DictionaryStream = BoxStream<'static, VortexResult<DictionaryChunk>>;

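/// Dictionary-encode a sequential stream of chunks, yielding `Codes` chunks
/// while the current dictionary is within `constraints` and a `Values` chunk
/// whenever the dictionary fills up (plus a final drain at end of stream).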
fn dict_encode_stream(
    input: SendableSequentialStream,
    constraints: DictConstraints,
) -> DictionaryStream {
    Box::pin(try_stream! {
        let mut state = DictStreamState {
            encoder: None,
            constraints,
        };

        let input = input.peekable();
        pin_mut!(input);

        while let Some(item) = input.next().await {
            let (sequence_id, chunk) = item?;

            // The labeler may create sub-sequences, so we construct it on both
            // arms to avoid holding a SequencePointer across await points.
            match input.as_mut().peek().await {
                Some(_) => {
                    let mut labeler = DictChunkLabeler::new(sequence_id);
                    let chunks = state.encode(&mut labeler, chunk);
                    drop(labeler);
                    for dict_chunk in chunks {
                        yield dict_chunk;
                    }
                }
                None => {
                    // This is the last element; encode it and drain any remaining values.
                    let mut labeler = DictChunkLabeler::new(sequence_id);
                    let encoded = state.encode(&mut labeler, chunk);
                    let drained = state.drain_values(&mut labeler);
                    drop(labeler);
                    for dict_chunk in encoded.into_iter().chain(drained.into_iter()) {
                        yield dict_chunk;
                    }
                }
            }
        }
    })
}

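/// Encoding state carried across input chunks: the currently open dictionary
/// encoder (if any) and the constraints used to start new dictionaries.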
struct DictStreamState {
    encoder: Option<Box<dyn DictEncoder>>,
    constraints: DictConstraints,
}

impl DictStreamState {
    fn encode(&mut self, labeler: &mut DictChunkLabeler, chunk: ArrayRef) -> Vec<DictionaryChunk> {
        let mut res = Vec::new();
        let mut to_be_encoded = Some(chunk);
        while let Some(remaining) = to_be_encoded.take() {
            match self.encoder.take() {
                None => match start_encoding(&self.constraints, &remaining) {
                    EncodingState::Continue((encoder, encoded)) => {
                        res.push(labeler.codes(encoded));
                        self.encoder = Some(encoder);
                    }
                    EncodingState::Done((values, encoded, unencoded)) => {
                        res.push(labeler.codes(encoded));
                        res.push(labeler.values(values));
                        to_be_encoded = Some(unencoded);
                    }
                },
                Some(encoder) => match encode_chunk(encoder, &remaining) {
                    EncodingState::Continue((encoder, encoded)) => {
                        res.push(labeler.codes(encoded));
                        self.encoder = Some(encoder);
                    }
                    EncodingState::Done((values, encoded, unencoded)) => {
                        res.push(labeler.codes(encoded));
                        res.push(labeler.values(values));
                        to_be_encoded = Some(unencoded);
                    }
                },
            }
        }
        res
    }

    fn drain_values(&mut self, labeler: &mut DictChunkLabeler) -> Vec<DictionaryChunk> {
        match self.encoder.as_mut() {
            None => Vec::new(),
            Some(encoder) => vec![labeler.values(encoder.reset())],
        }
    }
}

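/// Assigns sequence ids to the codes and values chunks derived from a single
/// input chunk by descending into a sub-sequence of that chunk's id.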
struct DictChunkLabeler {
    sequence_pointer: SequencePointer,
}

impl DictChunkLabeler {
    fn new(starting_id: SequenceId) -> Self {
        let sequence_pointer = starting_id.descend();
        Self { sequence_pointer }
    }

    fn codes(&mut self, chunk: ArrayRef) -> DictionaryChunk {
        DictionaryChunk::Codes((self.sequence_pointer.advance(), chunk))
    }

    fn values(&mut self, chunk: ArrayRef) -> DictionaryChunk {
        DictionaryChunk::Values((self.sequence_pointer.advance(), chunk))
    }
}

type SequencedChunk = VortexResult<(SequenceId, ArrayRef)>;

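/// Regroups a [`DictionaryStream`] into `(codes_stream, values_future)` pairs,
/// one per dictionary: codes chunks are forwarded over a bounded channel until
/// the matching values chunk arrives, which resolves the future and closes the
/// group. Each pair is written as one child [`DictLayout`].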
struct DictionaryTransformer {
    input: DictionaryStream,
    active_codes_tx: Option<kanal::AsyncSender<SequencedChunk>>,
    active_values_tx: Option<oneshot::Sender<SequencedChunk>>,
    pending_send: Option<BoxFuture<'static, Result<(), kanal::SendError>>>,
}

impl DictionaryTransformer {
    fn new(input: DictionaryStream) -> Self {
        Self {
            input,
            active_codes_tx: None,
            active_values_tx: None,
            pending_send: None,
        }
    }
}

impl Stream for DictionaryTransformer {
    type Item = (SendableSequentialStream, BoxFuture<'static, SequencedChunk>);

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // First, try to complete any pending send
            if let Some(mut send_fut) = self.pending_send.take() {
                match send_fut.poll_unpin(cx) {
                    Poll::Ready(Ok(())) => {
                        // Send completed, continue processing
                    }
                    Poll::Ready(Err(_)) => {
                        // Receiver dropped, close this group
                        self.active_codes_tx = None;
                        if let Some(values_tx) = self.active_values_tx.take() {
                            drop(values_tx.send(Err(vortex_err!("values receiver dropped"))));
                        }
                    }
                    Poll::Pending => {
                        // Still pending, save it and return
                        self.pending_send = Some(send_fut);
                        return Poll::Pending;
                    }
                }
            }

            match self.input.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(DictionaryChunk::Codes(codes)))) => {
                    if self.active_codes_tx.is_none() {
                        // Start a new group
                        let (codes_tx, codes_rx) = kanal::bounded_async::<SequencedChunk>(1);
                        let (values_tx, values_rx) = oneshot::channel();

                        self.active_codes_tx = Some(codes_tx.clone());
                        self.active_values_tx = Some(values_tx);

                        // Send first codes
                        self.pending_send =
                            Some(Box::pin(async move { codes_tx.send(Ok(codes)).await }));

                        // Create output streams
                        let codes_stream = SequentialStreamAdapter::new(
                            DType::Primitive(PType::U16, NonNullable),
                            codes_rx.into_stream().boxed(),
                        )
                        .sendable();

                        let values_future = async move {
                            values_rx
                                .await
                                .map_err(|e| vortex_err!("values sender dropped: {}", e))
                                .flatten()
                        }
                        .boxed();

                        return Poll::Ready(Some((codes_stream, values_future)));
                    } else {
                        // Continue streaming codes to existing group
                        if let Some(tx) = &self.active_codes_tx {
                            let tx = tx.clone();
                            self.pending_send =
                                Some(Box::pin(async move { tx.send(Ok(codes)).await }));
                        }
                    }
                }
                Poll::Ready(Some(Ok(DictionaryChunk::Values(values)))) => {
                    // Complete the current group
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Ok(values)));
                    }
                    self.active_codes_tx = None; // Close codes stream
                }
                Poll::Ready(Some(Err(e))) => {
                    // Send error to active channels if any
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Err(e)));
                    }
                    self.active_codes_tx = None;
                    // And terminate the stream
                    return Poll::Ready(None);
                }
                Poll::Ready(None) => {
                    // Handle any incomplete group
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Err(vortex_err!("Incomplete dictionary group"))));
                    }
                    self.active_codes_tx = None;
                    return Poll::Ready(None);
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}

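/// Take the first chunk off `stream` so the caller can inspect it, returning a
/// stream with that chunk pushed back on the front alongside a clone of it.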
async fn peek_first_chunk(
    mut stream: BoxStream<'static, SequencedChunk>,
) -> VortexResult<(BoxStream<'static, SequencedChunk>, Option<ArrayRef>)> {
    match stream.next().await {
        None => Ok((stream.boxed(), None)),
        Some(Err(e)) => Err(e),
        Some(Ok((sequence_id, chunk))) => {
            let chunk_clone = chunk.clone();
            let reconstructed_stream =
                once(async move { Ok((sequence_id, chunk_clone)) }).chain(stream);
            Ok((reconstructed_stream.boxed(), Some(chunk)))
        }
    }
}

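/// Returns true if the dict layout supports `dtype`: primitives, Utf8, and Binary.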
pub fn dict_layout_supported(dtype: &DType) -> bool {
    matches!(
        dtype,
        DType::Primitive(..) | DType::Utf8(_) | DType::Binary(_)
    )
}

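/// Metadata stored with a dict layout, recording the primitive type of the codes child.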
#[derive(prost::Message)]
pub struct DictLayoutMetadata {
    #[prost(enumeration = "PType", tag = "1")]
    // prost requires an i32 here; use the generated getter to read this field as a PType.
    codes_ptype: i32,
}

impl DictLayoutMetadata {
    pub fn new(codes_ptype: PType) -> Self {
        let mut metadata = Self::default();
        metadata.set_codes_ptype(codes_ptype);
        metadata
    }
}

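/// Result of encoding one chunk: either the dictionary still has room
/// (`Continue`), or it hit its constraints and was flushed (`Done`), leaving an
/// unencoded remainder to re-encode against a fresh dictionary.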
enum EncodingState {
    Continue((Box<dyn DictEncoder>, ArrayRef)),
    // (values, encoded, unencoded)
    Done((ArrayRef, ArrayRef, ArrayRef)),
}

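/// Build a new dictionary encoder for `chunk` and encode it.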
fn start_encoding(constraints: &DictConstraints, chunk: &dyn Array) -> EncodingState {
    let encoder = dict_encoder(chunk, constraints);
    encode_chunk(encoder, chunk)
}

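/// Encode `chunk` with `encoder`, flushing the dictionary values if the encoder
/// stopped before consuming the whole chunk.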
fn encode_chunk(mut encoder: Box<dyn DictEncoder>, chunk: &dyn Array) -> EncodingState {
    let encoded = encoder.encode(chunk);
    match remainder(chunk, encoded.len()) {
        None => EncodingState::Continue((encoder, encoded)),
        Some(unencoded) => EncodingState::Done((encoder.reset(), encoded, unencoded)),
    }
}

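/// The unencoded suffix of `array`, if `encoded_len` did not cover it entirely.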
fn remainder(array: &dyn Array, encoded_len: usize) -> Option<ArrayRef> {
    (encoded_len < array.len()).then(|| array.slice(encoded_len..array.len()))
}