// vortex_layout/layouts/dict/writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use async_stream::{stream, try_stream};
9use async_trait::async_trait;
10use futures::future::BoxFuture;
11use futures::stream::{BoxStream, once};
12use futures::{FutureExt, Stream, StreamExt, TryStreamExt, pin_mut, try_join};
13use vortex_array::{Array, ArrayContext, ArrayRef};
14use vortex_btrblocks::BtrBlocksCompressor;
15use vortex_dict::DictEncoding;
16use vortex_dict::builders::{DictConstraints, DictEncoder, dict_encoder};
17use vortex_dtype::Nullability::NonNullable;
18use vortex_dtype::{DType, PType};
19use vortex_error::{VortexError, VortexResult, vortex_err};
20use vortex_io::kanal_ext::KanalExt;
21use vortex_io::runtime::Handle;
22
23use crate::layouts::chunked::ChunkedLayout;
24use crate::layouts::dict::DictLayout;
25use crate::segments::SegmentSinkRef;
26use crate::sequence::{
27    SendableSequentialStream, SequenceId, SequencePointer, SequentialStream,
28    SequentialStreamAdapter, SequentialStreamExt,
29};
30use crate::{IntoLayout, LayoutRef, LayoutStrategy, OwnedLayoutChildren};
31
/// Size limits for a single dictionary built while writing a dict layout.
/// When either limit is hit the current dictionary is finalized and a new one
/// is started (see `DictStreamState::try_encode`).
#[derive(Clone)]
pub struct DictLayoutConstraints {
    /// Maximum number of bytes of dictionary values per dictionary.
    pub max_bytes: usize,
    // Dict layout codes currently only support u16 codes
    pub max_len: u16,
}
38
39impl From<DictLayoutConstraints> for DictConstraints {
40    fn from(value: DictLayoutConstraints) -> Self {
41        DictConstraints {
42            max_bytes: value.max_bytes,
43            max_len: value.max_len as usize,
44        }
45    }
46}
47
48impl Default for DictLayoutConstraints {
49    fn default() -> Self {
50        Self {
51            max_bytes: 1024 * 1024,
52            max_len: u16::MAX,
53        }
54    }
55}
56
/// Options controlling [`DictStrategy`] behavior.
#[derive(Clone, Default)]
pub struct DictLayoutOptions {
    /// Constraints applied to each dictionary built while writing.
    pub constraints: DictLayoutConstraints,
}
61
/// A layout strategy that encodes chunks into values and codes, if found
/// appropriate by the btrblocks compressor. The current implementation only
/// checks the first chunk to decide whether to apply dict layout, and
/// encodes chunks into dictionaries. When the dict constraints are hit, a
/// new dictionary is created.
#[derive(Clone)]
pub struct DictStrategy {
    /// Strategy used to write each run's u16 codes child.
    codes: Arc<dyn LayoutStrategy>,
    /// Strategy used to write each run's dictionary-values child.
    values: Arc<dyn LayoutStrategy>,
    /// Strategy used when the dtype is unsupported or the first chunk does not
    /// compress to a dict encoding.
    fallback: Arc<dyn LayoutStrategy>,
    options: DictLayoutOptions,
}
74
75impl DictStrategy {
76    pub fn new<Codes: LayoutStrategy, Values: LayoutStrategy, Fallback: LayoutStrategy>(
77        codes: Codes,
78        values: Values,
79        fallback: Fallback,
80        options: DictLayoutOptions,
81    ) -> Self {
82        Self {
83            codes: Arc::new(codes),
84            values: Arc::new(values),
85            fallback: Arc::new(fallback),
86            options,
87        }
88    }
89}
90
#[async_trait]
impl LayoutStrategy for DictStrategy {
    /// Writes `stream` as one or more dict layouts when profitable, otherwise
    /// delegates to the fallback strategy.
    ///
    /// Eligibility is decided by compressing only the first chunk with
    /// btrblocks and checking whether it chose dict encoding; the first chunk
    /// is pushed back onto the stream afterwards. Each dictionary "run"
    /// (codes chunks plus one values chunk) becomes a child [`DictLayout`];
    /// multiple runs are wrapped in a [`ChunkedLayout`].
    async fn write_stream(
        &self,
        ctx: ArrayContext,
        segment_sink: SegmentSinkRef,
        stream: SendableSequentialStream,
        mut eof: SequencePointer,
        handle: Handle,
    ) -> VortexResult<LayoutRef> {
        // Fallback if dtype is not supported
        if !dict_layout_supported(stream.dtype()) {
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, handle)
                .await;
        }

        let options = self.options.clone();
        let dtype = stream.dtype().clone();

        // 0. decide if chunks are eligible for dict encoding
        let (stream, first_chunk) = peek_first_chunk(stream).await?;
        let stream = SequentialStreamAdapter::new(dtype.clone(), stream).sendable();

        let should_fallback = match first_chunk {
            None => true, // empty stream
            Some(chunk) => {
                // Sample compression of the first chunk only; assumes later
                // chunks compress similarly.
                let compressed = BtrBlocksCompressor::default().compress(&chunk)?;
                !compressed.is_encoding(DictEncoding.id())
            }
        };
        if should_fallback {
            // first chunk did not compress to dict, or did not exist. Skip dict layout
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, handle)
                .await;
        }

        // 1. from a chunk stream, create a stream that yields codes
        // followed by a single value chunk when dict constraints are hit.
        // (a1, a2) -> (code(c1), code(c2), values(v1), code(c3), ...)
        let dict_stream = dict_encode_stream(stream, options.constraints.into());

        // Wrap up the dict stream to yield pairs of (codes_stream, values_future).
        // Each of these pairs becomes a child dict layout.
        let runs = DictionaryTransformer::new(dict_stream);

        let dtype2 = dtype.clone();
        let child_layouts = stream! {
            pin_mut!(runs);

            while let Some((codes_stream, values_fut)) = runs.next().await {
                // Write the codes child on a nested task. eof.split_off() is
                // called in run order, so codes always precede their values in
                // the sequence.
                let codes = self.codes.clone();
                let codes_eof = eof.split_off();
                let ctx2 = ctx.clone();
                let segment_sink2 = segment_sink.clone();
                let codes_fut = handle.spawn_nested(move |h| async move {
                    codes.write_stream(
                        ctx2,
                        segment_sink2,
                        codes_stream.sendable(),
                        codes_eof,
                        h,
                    ).await
                });

                // Write the values child: a single-chunk stream resolved from
                // the values future, carrying the input's dtype.
                let values = self.values.clone();
                let values_eof = eof.split_off();
                let ctx2 = ctx.clone();
                let segment_sink2 = segment_sink.clone();
                let dtype2 = dtype2.clone();
                let values_layout = handle.spawn_nested(move |h| async move {
                    values.write_stream(
                        ctx2,
                        segment_sink2,
                        SequentialStreamAdapter::new(dtype2, once(values_fut)).sendable(),
                        values_eof,
                        h,
                    ).await
                });

                // Yield a future that resolves to this run's (codes, values)
                // layout pair; fails if either child write fails.
                yield async move {
                    try_join!(codes_fut, values_layout)
                }.boxed();
            }
        };

        // Drive all run futures concurrently (unbounded buffering) and collect
        // one DictLayout per run, preserving run order.
        let mut child_layouts = child_layouts
            .buffered(usize::MAX)
            .map(|result| {
                let (codes_layout, values_layout) = result?;
                Ok::<_, VortexError>(DictLayout::new(values_layout, codes_layout).into_layout())
            })
            .try_collect::<Vec<_>>()
            .await?;

        // A single run needs no chunked wrapper.
        if child_layouts.len() == 1 {
            return Ok(child_layouts.remove(0));
        }

        let row_count = child_layouts.iter().map(|child| child.row_count()).sum();
        Ok(ChunkedLayout::new(
            row_count,
            dtype,
            OwnedLayoutChildren::layout_children(child_layouts),
        )
        .into_layout())
    }

    /// Sum of bytes buffered by all three child strategies.
    fn buffered_bytes(&self) -> u64 {
        self.codes.buffered_bytes() + self.values.buffered_bytes() + self.fallback.buffered_bytes()
    }
}
206
/// A labeled chunk produced by [`dict_encode_stream`]: either a run of codes
/// or the dictionary values that terminate the current run.
enum DictionaryChunk {
    /// Codes referencing the dictionary currently being built.
    Codes((SequenceId, ArrayRef)),
    /// The dictionary values that close out the preceding run of codes.
    Values((SequenceId, ArrayRef)),
}

/// Fallible stream of labeled dictionary chunks.
type DictionaryStream = BoxStream<'static, VortexResult<DictionaryChunk>>;
213
/// Dict-encodes a chunk stream into a flat stream of codes chunks interleaved
/// with values chunks: whenever `constraints` are hit, the finished
/// dictionary's values are emitted and a fresh dictionary is started. The
/// final (possibly partial) dictionary is drained after the last input chunk.
fn dict_encode_stream(
    input: SendableSequentialStream,
    constraints: DictConstraints,
) -> DictionaryStream {
    Box::pin(try_stream! {
        let mut state = DictStreamState {
            encoder: None,
            constraints,
        };

        // Peeking lets us detect the last input chunk so the trailing
        // dictionary can be drained in the same loop iteration.
        let input = input.peekable();
        pin_mut!(input);

        while let Some(item) = input.next().await {
            let (sequence_id, chunk) = item?;

            // labeler potentially creates sub sequences, we must
            // create it on both arms to avoid having a SequencePointer
            // between await points
            match input.as_mut().peek().await {
                Some(_) => {
                    let mut labeler = DictChunkLabeler::new(sequence_id);
                    let chunks = state.encode(&mut labeler, chunk);
                    drop(labeler);
                    for dict_chunk in chunks {
                        yield dict_chunk?;
                    }
                }
                None => {
                    // this is the last element, encode and drain chunks
                    let mut labeler = DictChunkLabeler::new(sequence_id);
                    let encoded = state.encode(&mut labeler, chunk);
                    let drained = state.drain_values(&mut labeler);
                    drop(labeler);
                    for dict_chunk in encoded.into_iter().chain(drained.into_iter()) {
                        yield dict_chunk?;
                    }
                }
            }
        }
    })
}
256
/// Mutable encoding state carried across chunks of the input stream.
struct DictStreamState {
    /// The encoder accumulating the current dictionary; `None` before the
    /// first chunk and immediately after a dictionary is finalized.
    encoder: Option<Box<dyn DictEncoder>>,
    /// Limits at which the current dictionary is finalized.
    constraints: DictConstraints,
}
261
262impl DictStreamState {
263    fn encode(
264        &mut self,
265        labeler: &mut DictChunkLabeler,
266        chunk: ArrayRef,
267    ) -> Vec<VortexResult<DictionaryChunk>> {
268        self.try_encode(labeler, chunk)
269            .unwrap_or_else(|e| vec![Err(e)])
270    }
271
272    fn try_encode(
273        &mut self,
274        labeler: &mut DictChunkLabeler,
275        chunk: ArrayRef,
276    ) -> VortexResult<Vec<VortexResult<DictionaryChunk>>> {
277        let mut res = Vec::new();
278        let mut to_be_encoded = Some(chunk);
279        while let Some(remaining) = to_be_encoded.take() {
280            match self.encoder.take() {
281                None => match start_encoding(&self.constraints, &remaining)? {
282                    EncodingState::Continue((encoder, encoded)) => {
283                        res.push(Ok(labeler.codes(encoded)));
284                        self.encoder = Some(encoder);
285                    }
286                    EncodingState::Done((values, encoded, unencoded)) => {
287                        res.push(Ok(labeler.codes(encoded)));
288                        res.push(Ok(labeler.values(values)));
289                        to_be_encoded = Some(unencoded);
290                    }
291                },
292                Some(encoder) => match encode_chunk(encoder, &remaining)? {
293                    EncodingState::Continue((encoder, encoded)) => {
294                        res.push(Ok(labeler.codes(encoded)));
295                        self.encoder = Some(encoder);
296                    }
297                    EncodingState::Done((values, encoded, unencoded)) => {
298                        res.push(Ok(labeler.codes(encoded)));
299                        res.push(Ok(labeler.values(values)));
300                        to_be_encoded = Some(unencoded);
301                    }
302                },
303            }
304        }
305        Ok(res)
306    }
307
308    fn drain_values(
309        &mut self,
310        labeler: &mut DictChunkLabeler,
311    ) -> Vec<VortexResult<DictionaryChunk>> {
312        match self.encoder.as_mut() {
313            None => Vec::new(),
314            Some(encoder) => vec![encoder.values().map(|val| labeler.values(val))],
315        }
316    }
317}
318
/// Assigns sequence ids to the codes/values chunks derived from one input
/// chunk. All outputs descend from the input's id (via `descend()`), so
/// downstream ordering relative to other input chunks is preserved.
struct DictChunkLabeler {
    sequence_pointer: SequencePointer,
}
322
323impl DictChunkLabeler {
324    fn new(starting_id: SequenceId) -> Self {
325        let sequence_pointer = starting_id.descend();
326        Self { sequence_pointer }
327    }
328
329    fn codes(&mut self, chunk: ArrayRef) -> DictionaryChunk {
330        DictionaryChunk::Codes((self.sequence_pointer.advance(), chunk))
331    }
332
333    fn values(&mut self, chunk: ArrayRef) -> DictionaryChunk {
334        DictionaryChunk::Values((self.sequence_pointer.advance(), chunk))
335    }
336}
337
/// A sequence-tagged chunk as carried by sequential streams.
type SequencedChunk = VortexResult<(SequenceId, ArrayRef)>;

/// Re-groups a flat [`DictionaryStream`] into one item per dictionary run:
/// a stream of that run's codes chunks plus a future resolving to the run's
/// values chunk.
struct DictionaryTransformer {
    input: DictionaryStream,
    /// Sender feeding the codes stream of the run currently being emitted;
    /// `None` between runs.
    active_codes_tx: Option<kanal::AsyncSender<SequencedChunk>>,
    /// Sender resolving the values future of the current run.
    active_values_tx: Option<oneshot::Sender<SequencedChunk>>,
    /// An in-flight send on the bounded codes channel that returned `Pending`.
    pending_send: Option<BoxFuture<'static, Result<(), kanal::SendError>>>,
}
346
347impl DictionaryTransformer {
348    fn new(input: DictionaryStream) -> Self {
349        Self {
350            input,
351            active_codes_tx: None,
352            active_values_tx: None,
353            pending_send: None,
354        }
355    }
356}
357
impl Stream for DictionaryTransformer {
    /// One item per dictionary run: the run's codes as a stream plus a future
    /// resolving to the run's values chunk.
    type Item = (SendableSequentialStream, BoxFuture<'static, SequencedChunk>);

    // State machine: a `Codes` chunk either opens a new run (yielding the pair
    // to the caller) or is forwarded to the open run's bounded channel; a
    // `Values` chunk closes the run by resolving the values future.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // First, try to complete any pending send
            if let Some(mut send_fut) = self.pending_send.take() {
                match send_fut.poll_unpin(cx) {
                    Poll::Ready(Ok(())) => {
                        // Send completed, continue processing
                    }
                    Poll::Ready(Err(_)) => {
                        // Receiver dropped, close this group
                        self.active_codes_tx = None;
                        if let Some(values_tx) = self.active_values_tx.take() {
                            let _ = values_tx.send(Err(vortex_err!("values receiver dropped")));
                        }
                    }
                    Poll::Pending => {
                        // Still pending, save it and return
                        self.pending_send = Some(send_fut);
                        return Poll::Pending;
                    }
                }
            }

            match self.input.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(DictionaryChunk::Codes(codes)))) => {
                    if self.active_codes_tx.is_none() {
                        // Start a new group
                        // (capacity 1: producer stays at most one chunk ahead)
                        let (codes_tx, codes_rx) = kanal::bounded_async::<SequencedChunk>(1);
                        let (values_tx, values_rx) = oneshot::channel();

                        self.active_codes_tx = Some(codes_tx.clone());
                        self.active_values_tx = Some(values_tx);

                        // Send first codes
                        self.pending_send =
                            Some(Box::pin(async move { codes_tx.send(Ok(codes)).await }));

                        // Create output streams
                        let codes_stream = SequentialStreamAdapter::new(
                            DType::Primitive(PType::U16, NonNullable),
                            codes_rx.into_stream().boxed(),
                        )
                        .sendable();

                        let values_future = async move {
                            values_rx
                                .await
                                .map_err(|e| vortex_err!("values sender dropped: {}", e))
                                .flatten()
                        }
                        .boxed();

                        return Poll::Ready(Some((codes_stream, values_future)));
                    } else {
                        // Continue streaming codes to existing group
                        if let Some(tx) = &self.active_codes_tx {
                            let tx = tx.clone();
                            self.pending_send =
                                Some(Box::pin(async move { tx.send(Ok(codes)).await }));
                        }
                    }
                }
                Poll::Ready(Some(Ok(DictionaryChunk::Values(values)))) => {
                    // Complete the current group
                    if let Some(values_tx) = self.active_values_tx.take() {
                        let _ = values_tx.send(Ok(values));
                    }
                    self.active_codes_tx = None; // Close codes stream
                }
                Poll::Ready(Some(Err(e))) => {
                    // Send error to active channels if any
                    // NOTE(review): the error reaches the values future, but
                    // the codes receiver only observes end-of-stream — confirm
                    // downstream treats the truncated codes stream as failure.
                    if let Some(values_tx) = self.active_values_tx.take() {
                        let _ = values_tx.send(Err(e));
                    }
                    self.active_codes_tx = None;
                    // And terminate the stream
                    return Poll::Ready(None);
                }
                Poll::Ready(None) => {
                    // Handle any incomplete group
                    if let Some(values_tx) = self.active_values_tx.take() {
                        let _ = values_tx.send(Err(vortex_err!("Incomplete dictionary group")));
                    }
                    self.active_codes_tx = None;
                    return Poll::Ready(None);
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
452
453async fn peek_first_chunk(
454    mut stream: BoxStream<'static, SequencedChunk>,
455) -> VortexResult<(BoxStream<'static, SequencedChunk>, Option<ArrayRef>)> {
456    match stream.next().await {
457        None => Ok((stream.boxed(), None)),
458        Some(Err(e)) => Err(e),
459        Some(Ok((sequence_id, chunk))) => {
460            let chunk_clone = chunk.clone();
461            let reconstructed_stream =
462                once(async move { Ok((sequence_id, chunk_clone)) }).chain(stream);
463            Ok((reconstructed_stream.boxed(), Some(chunk)))
464        }
465    }
466}
467
468pub fn dict_layout_supported(dtype: &DType) -> bool {
469    matches!(
470        dtype,
471        DType::Primitive(..) | DType::Utf8(_) | DType::Binary(_)
472    )
473}
474
/// Protobuf metadata persisted alongside a dict layout, recording the
/// primitive type of the codes child.
#[derive(prost::Message)]
pub struct DictLayoutMetadata {
    #[prost(enumeration = "PType", tag = "1")]
    // i32 is required for proto, use the generated getter to read this field.
    codes_ptype: i32,
}
481
impl DictLayoutMetadata {
    /// Creates metadata recording the codes child's primitive type.
    pub fn new(codes_ptype: PType) -> Self {
        let mut metadata = Self::default();
        // Use the prost-generated setter so the enum is stored as i32.
        metadata.set_codes_ptype(codes_ptype);
        metadata
    }
}
489
/// Outcome of feeding one chunk into a dict encoder.
enum EncodingState {
    /// The encoder consumed the whole chunk and can accept more input:
    /// (encoder to resume, encoded codes).
    Continue((Box<dyn DictEncoder>, ArrayRef)),
    // (values, encoded, unencoded)
    /// Constraints were hit: the finished dictionary's values, the codes that
    /// were encoded, and the tail of the chunk still awaiting encoding.
    Done((ArrayRef, ArrayRef, ArrayRef)),
}
495
496fn start_encoding(constraints: &DictConstraints, chunk: &dyn Array) -> VortexResult<EncodingState> {
497    let encoder = dict_encoder(chunk, constraints)?;
498    encode_chunk(encoder, chunk)
499}
500
501fn encode_chunk(
502    mut encoder: Box<dyn DictEncoder>,
503    chunk: &dyn Array,
504) -> VortexResult<EncodingState> {
505    let encoded = encoder.encode(chunk)?;
506    Ok(match remainder(chunk, encoded.len()) {
507        None => EncodingState::Continue((encoder, encoded)),
508        Some(unencoded) => EncodingState::Done((encoder.values()?, encoded, unencoded)),
509    })
510}
511
512fn remainder(array: &dyn Array, encoded_len: usize) -> Option<ArrayRef> {
513    (encoded_len < array.len()).then(|| array.slice(encoded_len..array.len()))
514}