vortex_layout/layouts/dict/
writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use async_stream::{stream, try_stream};
9use async_trait::async_trait;
10use futures::future::BoxFuture;
11use futures::stream::{BoxStream, once};
12use futures::{FutureExt, Stream, StreamExt, TryStreamExt, pin_mut, try_join};
13use vortex_array::arrays::DictEncoding;
14use vortex_array::builders::dict::{DictConstraints, DictEncoder, dict_encoder};
15use vortex_array::{Array, ArrayContext, ArrayRef};
16use vortex_btrblocks::BtrBlocksCompressor;
17use vortex_dtype::Nullability::NonNullable;
18use vortex_dtype::{DType, PType};
19use vortex_error::{VortexError, VortexResult, vortex_err};
20use vortex_io::kanal_ext::KanalExt;
21use vortex_io::runtime::Handle;
22
23use crate::layouts::chunked::ChunkedLayout;
24use crate::layouts::dict::DictLayout;
25use crate::segments::SegmentSinkRef;
26use crate::sequence::{
27    SendableSequentialStream, SequenceId, SequencePointer, SequentialStream,
28    SequentialStreamAdapter, SequentialStreamExt,
29};
30use crate::{IntoLayout, LayoutRef, LayoutStrategy, OwnedLayoutChildren};
31
/// Limits applied to a single dictionary produced by the dict layout writer.
///
/// Converted into [`DictConstraints`] (see the `From` impl in this file) before being handed
/// to the dictionary encoder.
#[derive(Clone)]
pub struct DictLayoutConstraints {
    /// Forwarded unchanged as `DictConstraints::max_bytes`.
    /// Presumably the byte budget of one dictionary -- confirm against `DictConstraints`.
    pub max_bytes: usize,
    /// Forwarded (widened to `usize`) as `DictConstraints::max_len`.
    // Dict layout codes currently only support u16 codes
    pub max_len: u16,
}
38
39impl From<DictLayoutConstraints> for DictConstraints {
40    fn from(value: DictLayoutConstraints) -> Self {
41        DictConstraints {
42            max_bytes: value.max_bytes,
43            max_len: value.max_len as usize,
44        }
45    }
46}
47
48impl Default for DictLayoutConstraints {
49    fn default() -> Self {
50        Self {
51            max_bytes: 1024 * 1024,
52            max_len: u16::MAX,
53        }
54    }
55}
56
/// Options accepted by [`DictStrategy`].
#[derive(Clone, Default)]
pub struct DictLayoutOptions {
    /// Limits handed to the dictionary encoder (after conversion into `DictConstraints`).
    pub constraints: DictLayoutConstraints,
}
61
/// A layout strategy that encodes chunks into dictionary values and codes when the
/// btrblocks compressor finds that appropriate.
///
/// The current implementation only inspects the first chunk to decide whether to apply the
/// dict layout; if it does, every chunk is dictionary-encoded. When the dict constraints are
/// hit, a new dictionary is started.
#[derive(Clone)]
pub struct DictStrategy {
    /// Strategy used to write each stream of codes.
    codes: Arc<dyn LayoutStrategy>,
    /// Strategy used to write each dictionary's values.
    values: Arc<dyn LayoutStrategy>,
    /// Strategy used when the dtype is unsupported, the stream is empty, or the first chunk
    /// does not compress to a dict encoding.
    fallback: Arc<dyn LayoutStrategy>,
    /// Dictionary limits applied while encoding.
    options: DictLayoutOptions,
}
74
75impl DictStrategy {
76    pub fn new<Codes: LayoutStrategy, Values: LayoutStrategy, Fallback: LayoutStrategy>(
77        codes: Codes,
78        values: Values,
79        fallback: Fallback,
80        options: DictLayoutOptions,
81    ) -> Self {
82        Self {
83            codes: Arc::new(codes),
84            values: Arc::new(values),
85            fallback: Arc::new(fallback),
86            options,
87        }
88    }
89}
90
#[async_trait]
impl LayoutStrategy for DictStrategy {
    /// Writes `stream` as one dict layout per dictionary, or delegates to the fallback
    /// strategy when the dict layout does not apply.
    ///
    /// The fallback strategy is used when the dtype is rejected by `dict_layout_supported`,
    /// when the stream is empty, or when the default `BtrBlocksCompressor` does not compress
    /// the first chunk to a dict encoding.
    ///
    /// Otherwise the stream is dictionary-encoded and split into groups of
    /// (codes stream, values chunk). Each group is written with the `codes` and `values`
    /// strategies and becomes a `DictLayout`. A single group is returned as-is; several
    /// groups are wrapped in a `ChunkedLayout`.
    ///
    /// # Errors
    ///
    /// Propagates errors from reading the first chunk, from the trial compression, and from
    /// any of the child layout writes.
    async fn write_stream(
        &self,
        ctx: ArrayContext,
        segment_sink: SegmentSinkRef,
        stream: SendableSequentialStream,
        mut eof: SequencePointer,
        handle: Handle,
    ) -> VortexResult<LayoutRef> {
        // Fallback if dtype is not supported
        if !dict_layout_supported(stream.dtype()) {
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, handle)
                .await;
        }

        let options = self.options.clone();
        let dtype = stream.dtype().clone();

        // 0. decide if chunks are eligible for dict encoding
        // The first chunk is pulled off for inspection and then put back at the head of the
        // stream, so whichever strategy ends up writing still sees every chunk.
        let (stream, first_chunk) = peek_first_chunk(stream).await?;
        let stream = SequentialStreamAdapter::new(dtype.clone(), stream).sendable();

        let should_fallback = match first_chunk {
            None => true, // empty stream
            Some(chunk) => {
                // Trial compression: only its resulting encoding is looked at, the
                // compressed array itself is discarded.
                let compressed = BtrBlocksCompressor::default().compress(&chunk)?;
                !compressed.is_encoding(DictEncoding.id())
            }
        };
        if should_fallback {
            // first chunk did not compress to dict, or did not exist. Skip dict layout
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, handle)
                .await;
        }

        // 1. from a chunk stream, create a stream that yields codes
        // followed by a single value chunk when dict constraints are hit.
        // (a1, a2) -> (code(c1), code(c2), values(v1), code(c3), ...)
        let dict_stream = dict_encode_stream(stream, options.constraints.into());

        // Wrap up the dict stream to yield pairs of (codes_stream, values_future).
        // Each of these pairs becomes a child dict layout.
        let runs = DictionaryTransformer::new(dict_stream);

        let dtype2 = dtype.clone();
        let child_layouts = stream! {
            pin_mut!(runs);

            while let Some((codes_stream, values_fut)) = runs.next().await {
                // Spawn the codes writer for this group as a nested task.
                // NOTE(review): `eof.split_off()` presumably reserves a sequence position for
                // this child's output -- confirm against `SequencePointer`.
                let codes = self.codes.clone();
                let codes_eof = eof.split_off();
                let ctx2 = ctx.clone();
                let segment_sink2 = segment_sink.clone();
                let codes_fut = handle.spawn_nested(move |h| async move {
                    codes.write_stream(
                        ctx2,
                        segment_sink2,
                        codes_stream.sendable(),
                        codes_eof,
                        h,
                    ).await
                });

                // Spawn the values writer; its input is a one-item stream carrying the
                // group's single values chunk, typed with the original input dtype.
                let values = self.values.clone();
                let values_eof = eof.split_off();
                let ctx2 = ctx.clone();
                let segment_sink2 = segment_sink.clone();
                let dtype2 = dtype2.clone();
                let values_layout = handle.spawn_nested(move |h| async move {
                    values.write_stream(
                        ctx2,
                        segment_sink2,
                        SequentialStreamAdapter::new(dtype2, once(values_fut)).sendable(),
                        values_eof,
                        h,
                    ).await
                });

                // One future per group, resolving to (codes layout, values layout) or the
                // first error of the two writes.
                yield async move {
                    try_join!(codes_fut, values_layout)
                }.boxed();
            }
        };

        // No cap on the number of group futures in flight; collection stops at the first
        // error.
        let mut child_layouts = child_layouts
            .buffered(usize::MAX)
            .map(|result| {
                let (codes_layout, values_layout) = result?;
                // Note the argument order: `DictLayout::new` takes values first, then codes.
                Ok::<_, VortexError>(DictLayout::new(values_layout, codes_layout).into_layout())
            })
            .try_collect::<Vec<_>>()
            .await?;

        // A single dictionary needs no chunked wrapper.
        if child_layouts.len() == 1 {
            return Ok(child_layouts.remove(0));
        }

        let row_count = child_layouts.iter().map(|child| child.row_count()).sum();
        Ok(ChunkedLayout::new(
            row_count,
            dtype,
            OwnedLayoutChildren::layout_children(child_layouts),
        )
        .into_layout())
    }

    /// Sum of the bytes reported as buffered by the codes, values and fallback strategies.
    fn buffered_bytes(&self) -> u64 {
        self.codes.buffered_bytes() + self.values.buffered_bytes() + self.fallback.buffered_bytes()
    }
}
206
/// One labelled piece of dictionary-encoded output, tagged with its sequence id.
enum DictionaryChunk {
    /// A chunk of dictionary codes.
    Codes((SequenceId, ArrayRef)),
    /// The values of a finished dictionary; emitted after the codes that refer to it.
    Values((SequenceId, ArrayRef)),
}
211
/// Boxed stream of fallible [`DictionaryChunk`]s, as produced by `dict_encode_stream`.
type DictionaryStream = BoxStream<'static, VortexResult<DictionaryChunk>>;
213
214fn dict_encode_stream(
215    input: SendableSequentialStream,
216    constraints: DictConstraints,
217) -> DictionaryStream {
218    Box::pin(try_stream! {
219        let mut state = DictStreamState {
220            encoder: None,
221            constraints,
222        };
223
224        let input = input.peekable();
225        pin_mut!(input);
226
227        while let Some(item) = input.next().await {
228            let (sequence_id, chunk) = item?;
229
230            // labeler potentially creates sub sequences, we must
231            // create it on both arms to avoid having a SequencePointer
232            // between await points
233            match input.as_mut().peek().await {
234                Some(_) => {
235                    let mut labeler = DictChunkLabeler::new(sequence_id);
236                    let chunks = state.encode(&mut labeler, chunk);
237                    drop(labeler);
238                    for dict_chunk in chunks {
239                        yield dict_chunk;
240                    }
241                }
242                None => {
243                    // this is the last element, encode and drain chunks
244                    let mut labeler = DictChunkLabeler::new(sequence_id);
245                    let encoded = state.encode(&mut labeler, chunk);
246                    let drained = state.drain_values(&mut labeler);
247                    drop(labeler);
248                    for dict_chunk in encoded.into_iter().chain(drained.into_iter()) {
249                        yield dict_chunk;
250                    }
251                }
252            }
253        }
254    })
255}
256
/// Encoder state carried from one input chunk to the next by `dict_encode_stream`.
struct DictStreamState {
    /// The dictionary encoder currently being filled; `None` until the first chunk is
    /// encoded and while a replacement is being created after the constraints were hit.
    encoder: Option<Box<dyn DictEncoder>>,
    /// Limits passed to every newly created encoder.
    constraints: DictConstraints,
}
261
262impl DictStreamState {
263    fn encode(&mut self, labeler: &mut DictChunkLabeler, chunk: ArrayRef) -> Vec<DictionaryChunk> {
264        let mut res = Vec::new();
265        let mut to_be_encoded = Some(chunk);
266        while let Some(remaining) = to_be_encoded.take() {
267            match self.encoder.take() {
268                None => match start_encoding(&self.constraints, &remaining) {
269                    EncodingState::Continue((encoder, encoded)) => {
270                        res.push(labeler.codes(encoded));
271                        self.encoder = Some(encoder);
272                    }
273                    EncodingState::Done((values, encoded, unencoded)) => {
274                        res.push(labeler.codes(encoded));
275                        res.push(labeler.values(values));
276                        to_be_encoded = Some(unencoded);
277                    }
278                },
279                Some(encoder) => match encode_chunk(encoder, &remaining) {
280                    EncodingState::Continue((encoder, encoded)) => {
281                        res.push(labeler.codes(encoded));
282                        self.encoder = Some(encoder);
283                    }
284                    EncodingState::Done((values, encoded, unencoded)) => {
285                        res.push(labeler.codes(encoded));
286                        res.push(labeler.values(values));
287                        to_be_encoded = Some(unencoded);
288                    }
289                },
290            }
291        }
292        res
293    }
294
295    fn drain_values(&mut self, labeler: &mut DictChunkLabeler) -> Vec<DictionaryChunk> {
296        match self.encoder.as_mut() {
297            None => Vec::new(),
298            Some(encoder) => vec![labeler.values(encoder.reset())],
299        }
300    }
301}
302
/// Hands out sequence ids for the dictionary chunks derived from one input chunk.
struct DictChunkLabeler {
    /// Pointer obtained by descending from the input chunk's sequence id; advanced once per
    /// labelled chunk.
    sequence_pointer: SequencePointer,
}
306
307impl DictChunkLabeler {
308    fn new(starting_id: SequenceId) -> Self {
309        let sequence_pointer = starting_id.descend();
310        Self { sequence_pointer }
311    }
312
313    fn codes(&mut self, chunk: ArrayRef) -> DictionaryChunk {
314        DictionaryChunk::Codes((self.sequence_pointer.advance(), chunk))
315    }
316
317    fn values(&mut self, chunk: ArrayRef) -> DictionaryChunk {
318        DictionaryChunk::Values((self.sequence_pointer.advance(), chunk))
319    }
320}
321
/// A fallible array chunk paired with its sequence id.
type SequencedChunk = VortexResult<(SequenceId, ArrayRef)>;
323
/// Stream adapter that regroups a flat [`DictionaryStream`] into one
/// (codes stream, values future) pair per dictionary.
struct DictionaryTransformer {
    /// The flat stream of labelled codes / values chunks being regrouped.
    input: DictionaryStream,
    /// Sender feeding the codes stream of the group currently open, if any.
    active_codes_tx: Option<kanal::AsyncSender<SequencedChunk>>,
    /// One-shot sender that resolves the values future of the group currently open, if any.
    active_values_tx: Option<oneshot::Sender<SequencedChunk>>,
    /// A codes send that has been started but has not completed yet.
    pending_send: Option<BoxFuture<'static, Result<(), kanal::SendError>>>,
}
330
331impl DictionaryTransformer {
332    fn new(input: DictionaryStream) -> Self {
333        Self {
334            input,
335            active_codes_tx: None,
336            active_values_tx: None,
337            pending_send: None,
338        }
339    }
340}
341
/// Yields one `(codes_stream, values_future)` pair per dictionary group.
///
/// A group opens at the first `Codes` chunk seen while no group is active and closes at the
/// next `Values` chunk. Codes are forwarded through a bounded channel of capacity 1; the
/// values chunk resolves the group's one-shot future.
impl Stream for DictionaryTransformer {
    type Item = (SendableSequentialStream, BoxFuture<'static, SequencedChunk>);

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // First, try to complete any pending send
            // (the input is not polled again until the previous codes chunk was accepted
            // by the channel).
            if let Some(mut send_fut) = self.pending_send.take() {
                match send_fut.poll_unpin(cx) {
                    Poll::Ready(Ok(())) => {
                        // Send completed, continue processing
                    }
                    Poll::Ready(Err(_)) => {
                        // Receiver dropped, close this group
                        // NOTE(review): the send that failed is a *codes* send (pending_send
                        // only ever holds codes sends), yet the message below says "values
                        // receiver dropped" -- confirm the intended wording.
                        self.active_codes_tx = None;
                        if let Some(values_tx) = self.active_values_tx.take() {
                            drop(values_tx.send(Err(vortex_err!("values receiver dropped"))));
                        }
                    }
                    Poll::Pending => {
                        // Still pending, save it and return
                        self.pending_send = Some(send_fut);
                        return Poll::Pending;
                    }
                }
            }

            match self.input.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(DictionaryChunk::Codes(codes)))) => {
                    if self.active_codes_tx.is_none() {
                        // Start a new group
                        let (codes_tx, codes_rx) = kanal::bounded_async::<SequencedChunk>(1);
                        let (values_tx, values_rx) = oneshot::channel();

                        self.active_codes_tx = Some(codes_tx.clone());
                        self.active_values_tx = Some(values_tx);

                        // Send first codes
                        // (only queued here; it is driven by the pending-send step at the
                        // top of the loop on the next poll).
                        self.pending_send =
                            Some(Box::pin(async move { codes_tx.send(Ok(codes)).await }));

                        // Create output streams
                        // The codes stream is always declared as non-nullable u16.
                        let codes_stream = SequentialStreamAdapter::new(
                            DType::Primitive(PType::U16, NonNullable),
                            codes_rx.into_stream().boxed(),
                        )
                        .sendable();

                        // Resolves to the group's values chunk, or to an error if the
                        // sender is dropped without sending.
                        let values_future = async move {
                            values_rx
                                .await
                                .map_err(|e| vortex_err!("values sender dropped: {}", e))
                                .flatten()
                        }
                        .boxed();

                        return Poll::Ready(Some((codes_stream, values_future)));
                    } else {
                        // Continue streaming codes to existing group
                        // (this `if let` always matches: this branch is only reached when
                        // `active_codes_tx` is `Some`).
                        if let Some(tx) = &self.active_codes_tx {
                            let tx = tx.clone();
                            self.pending_send =
                                Some(Box::pin(async move { tx.send(Ok(codes)).await }));
                        }
                    }
                }
                Poll::Ready(Some(Ok(DictionaryChunk::Values(values)))) => {
                    // Complete the current group
                    // (a failed send means the values receiver is gone; that is ignored).
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Ok(values)));
                    }
                    self.active_codes_tx = None; // Close codes stream
                }
                Poll::Ready(Some(Err(e))) => {
                    // Send error to active channels if any
                    // NOTE(review): when no group is open, `e` is dropped here and the
                    // stream just ends -- confirm callers cannot hit that case.
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Err(e)));
                    }
                    self.active_codes_tx = None;
                    // And terminate the stream
                    return Poll::Ready(None);
                }
                Poll::Ready(None) => {
                    // Handle any incomplete group
                    // (input ended while a group was still waiting for its values chunk).
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Err(vortex_err!("Incomplete dictionary group"))));
                    }
                    self.active_codes_tx = None;
                    return Poll::Ready(None);
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
436
437async fn peek_first_chunk(
438    mut stream: BoxStream<'static, SequencedChunk>,
439) -> VortexResult<(BoxStream<'static, SequencedChunk>, Option<ArrayRef>)> {
440    match stream.next().await {
441        None => Ok((stream.boxed(), None)),
442        Some(Err(e)) => Err(e),
443        Some(Ok((sequence_id, chunk))) => {
444            let chunk_clone = chunk.clone();
445            let reconstructed_stream =
446                once(async move { Ok((sequence_id, chunk_clone)) }).chain(stream);
447            Ok((reconstructed_stream.boxed(), Some(chunk)))
448        }
449    }
450}
451
452pub fn dict_layout_supported(dtype: &DType) -> bool {
453    matches!(
454        dtype,
455        DType::Primitive(..) | DType::Utf8(_) | DType::Binary(_)
456    )
457}
458
/// Protobuf-encoded metadata of a dict layout: the primitive type of its codes.
#[derive(prost::Message)]
pub struct DictLayoutMetadata {
    /// The codes' `PType`, stored as its protobuf enumeration value (field tag 1).
    #[prost(enumeration = "PType", tag = "1")]
    // i32 is required for proto, use the generated getter to read this field.
    codes_ptype: i32,
}
465
466impl DictLayoutMetadata {
467    pub fn new(codes_ptype: PType) -> Self {
468        let mut metadata = Self::default();
469        metadata.set_codes_ptype(codes_ptype);
470        metadata
471    }
472}
473
/// Outcome of feeding one chunk to a dictionary encoder.
enum EncodingState {
    /// The whole chunk was encoded: `(encoder, encoded)`. The encoder is handed back so it
    /// can keep being used.
    Continue((Box<dyn DictEncoder>, ArrayRef)),
    /// Only a prefix of the chunk was encoded and the encoder was reset.
    // (values, encoded, unencoded)
    Done((ArrayRef, ArrayRef, ArrayRef)),
}
479
480fn start_encoding(constraints: &DictConstraints, chunk: &dyn Array) -> EncodingState {
481    let encoder = dict_encoder(chunk, constraints);
482    encode_chunk(encoder, chunk)
483}
484
485fn encode_chunk(mut encoder: Box<dyn DictEncoder>, chunk: &dyn Array) -> EncodingState {
486    let encoded = encoder.encode(chunk);
487    match remainder(chunk, encoded.len()) {
488        None => EncodingState::Continue((encoder, encoded)),
489        Some(unencoded) => EncodingState::Done((encoder.reset(), encoded, unencoded)),
490    }
491}
492
493fn remainder(array: &dyn Array, encoded_len: usize) -> Option<ArrayRef> {
494    (encoded_len < array.len()).then(|| array.slice(encoded_len..array.len()))
495}