// vortex_layout/layouts/dict/writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::Context;
7use std::task::Poll;
8
9use async_stream::stream;
10use async_stream::try_stream;
11use async_trait::async_trait;
12use futures::FutureExt;
13use futures::Stream;
14use futures::StreamExt;
15use futures::TryStreamExt;
16use futures::future::BoxFuture;
17use futures::pin_mut;
18use futures::stream::BoxStream;
19use futures::stream::once;
20use futures::try_join;
21use vortex_array::Array;
22use vortex_array::ArrayContext;
23use vortex_array::ArrayRef;
24use vortex_array::arrays::DictVTable;
25use vortex_array::builders::dict::DictConstraints;
26use vortex_array::builders::dict::DictEncoder;
27use vortex_array::builders::dict::dict_encoder;
28use vortex_btrblocks::BtrBlocksCompressor;
29use vortex_dtype::DType;
30use vortex_dtype::Nullability;
31use vortex_dtype::PType;
32use vortex_error::VortexError;
33use vortex_error::VortexExpect;
34use vortex_error::VortexResult;
35use vortex_error::vortex_err;
36use vortex_io::kanal_ext::KanalExt;
37use vortex_io::runtime::Handle;
38
39use crate::IntoLayout;
40use crate::LayoutRef;
41use crate::LayoutStrategy;
42use crate::OwnedLayoutChildren;
43use crate::layouts::chunked::ChunkedLayout;
44use crate::layouts::dict::DictLayout;
45use crate::segments::SegmentSinkRef;
46use crate::sequence::SendableSequentialStream;
47use crate::sequence::SequenceId;
48use crate::sequence::SequencePointer;
49use crate::sequence::SequentialStream;
50use crate::sequence::SequentialStreamAdapter;
51use crate::sequence::SequentialStreamExt;
52
/// Constraints for dictionary layout encoding.
///
/// Note that [`max_len`](Self::max_len) is limited to `u16` (65,535 entries) by design. Since
/// layout chunks are typically ~8k elements, having more than 64k unique values in a dictionary
/// means dictionary encoding provides little compression benefit. If a column has very high
/// cardinality, the fallback encoding strategy should be used instead.
#[derive(Clone)]
pub struct DictLayoutConstraints {
    /// Maximum size of the dictionary in bytes.
    pub max_bytes: usize,
    /// Maximum dictionary length. Limited to `u16` because dictionaries with more than 64k unique
    /// values provide diminishing compression returns given typical chunk sizes (~8k elements).
    ///
    /// The codes dtype is determined upfront from this constraint:
    /// - [`PType::U8`] when max_len <= 255
    /// - [`PType::U16`] when max_len > 255
    ///
    /// Vortex encoders must always produce unsigned integer codes; signed codes are only accepted
    /// for external compatibility.
    pub max_len: u16,
}
73
74impl From<DictLayoutConstraints> for DictConstraints {
75    fn from(value: DictLayoutConstraints) -> Self {
76        DictConstraints {
77            max_bytes: value.max_bytes,
78            max_len: value.max_len as usize,
79        }
80    }
81}
82
83impl Default for DictLayoutConstraints {
84    fn default() -> Self {
85        Self {
86            max_bytes: 1024 * 1024,
87            max_len: u16::MAX,
88        }
89    }
90}
91
/// Options controlling [`DictStrategy`] behavior.
#[derive(Clone, Default)]
pub struct DictLayoutOptions {
    /// Size and length limits applied when building each dictionary.
    pub constraints: DictLayoutConstraints,
}
96
/// A layout strategy that encodes chunk into values and codes, if found
/// appropriate by the btrblocks compressor. Current implementation only
/// checks the first chunk to decide whether to apply dict layout and
/// encodes chunks into dictionaries. When the dict constraints are hit, a
/// new dictionary is created.
#[derive(Clone)]
pub struct DictStrategy {
    /// Strategy used to write each dictionary's codes child.
    codes: Arc<dyn LayoutStrategy>,
    /// Strategy used to write each dictionary's values child.
    values: Arc<dyn LayoutStrategy>,
    /// Strategy used when the dtype is unsupported or dict encoding is not beneficial.
    fallback: Arc<dyn LayoutStrategy>,
    /// Constraints controlling when a dictionary is closed and a new one started.
    options: DictLayoutOptions,
}
109
110impl DictStrategy {
111    pub fn new<Codes: LayoutStrategy, Values: LayoutStrategy, Fallback: LayoutStrategy>(
112        codes: Codes,
113        values: Values,
114        fallback: Fallback,
115        options: DictLayoutOptions,
116    ) -> Self {
117        Self {
118            codes: Arc::new(codes),
119            values: Arc::new(values),
120            fallback: Arc::new(fallback),
121            options,
122        }
123    }
124}
125
#[async_trait]
impl LayoutStrategy for DictStrategy {
    /// Writes `stream` as a sequence of dict layouts, or delegates to the fallback
    /// strategy when the dtype is unsupported or the first chunk does not compress
    /// to a dictionary.
    async fn write_stream(
        &self,
        ctx: ArrayContext,
        segment_sink: SegmentSinkRef,
        stream: SendableSequentialStream,
        mut eof: SequencePointer,
        handle: Handle,
    ) -> VortexResult<LayoutRef> {
        // Fallback if dtype is not supported
        if !dict_layout_supported(stream.dtype()) {
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, handle)
                .await;
        }

        let options = self.options.clone();
        let dtype = stream.dtype().clone();

        // 0. decide if chunks are eligible for dict encoding
        let (stream, first_chunk) = peek_first_chunk(stream).await?;
        let stream = SequentialStreamAdapter::new(dtype.clone(), stream).sendable();

        let should_fallback = match first_chunk {
            None => true, // empty stream
            Some(chunk) => {
                // Heuristic: only the first chunk is compressed; if btrblocks does not
                // pick dict encoding for it, skip dict layout for the whole column.
                let compressed = BtrBlocksCompressor::default().compress(&chunk)?;
                !compressed.is::<DictVTable>()
            }
        };
        if should_fallback {
            // first chunk did not compress to dict, or did not exist. Skip dict layout
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, handle)
                .await;
        }

        // 1. from a chunk stream, create a stream that yields codes
        // followed by a single value chunk when dict constraints are hit.
        // (a1, a2) -> (code(c1), code(c2), values(v1), code(c3), ...)
        let dict_stream = dict_encode_stream(stream, options.constraints.into());

        // Wrap up the dict stream to yield pairs of (codes_stream, values_future).
        // Each of these pairs becomes a child dict layout.
        let runs = DictionaryTransformer::new(dict_stream);

        let dtype2 = dtype.clone();
        let child_layouts = stream! {
            pin_mut!(runs);

            while let Some((codes_stream, values_fut)) = runs.next().await {
                // Write the codes child on a nested task. NOTE(review): each child gets
                // its own pointer via `eof.split_off()` — presumably this keeps their
                // sequence ranges disjoint; confirm against SequencePointer docs.
                let codes = self.codes.clone();
                let codes_eof = eof.split_off();
                let ctx2 = ctx.clone();
                let segment_sink2 = segment_sink.clone();
                let codes_fut = handle.spawn_nested(move |h| async move {
                    codes.write_stream(
                        ctx2,
                        segment_sink2,
                        codes_stream.sendable(),
                        codes_eof,
                        h,
                    ).await
                });

                // Write the values child the same way; its stream is a single chunk
                // resolved from `values_fut`.
                let values = self.values.clone();
                let values_eof = eof.split_off();
                let ctx2 = ctx.clone();
                let segment_sink2 = segment_sink.clone();
                let dtype2 = dtype2.clone();
                let values_layout = handle.spawn_nested(move |h| async move {
                    values.write_stream(
                        ctx2,
                        segment_sink2,
                        SequentialStreamAdapter::new(dtype2, once(values_fut)).sendable(),
                        values_eof,
                        h,
                    ).await
                });

                // Yield a future resolving to this run's (codes, values) layout pair.
                yield async move {
                    try_join!(codes_fut, values_layout)
                }.boxed();
            }
        };

        // Drive all run futures concurrently and collect the resulting dict layouts.
        let mut child_layouts = child_layouts
            .buffered(usize::MAX)
            .map(|result| {
                let (codes_layout, values_layout) = result?;
                // All values are referenced when created via dictionary encoding
                Ok::<_, VortexError>(DictLayout::new(values_layout, codes_layout).into_layout())
            })
            .try_collect::<Vec<_>>()
            .await?;

        // A single dictionary run needs no chunked wrapper.
        if child_layouts.len() == 1 {
            return Ok(child_layouts.remove(0));
        }

        let row_count = child_layouts.iter().map(|child| child.row_count()).sum();
        Ok(ChunkedLayout::new(
            row_count,
            dtype,
            OwnedLayoutChildren::layout_children(child_layouts),
        )
        .into_layout())
    }

    /// Sum of bytes buffered by all three child strategies.
    fn buffered_bytes(&self) -> u64 {
        self.codes.buffered_bytes() + self.values.buffered_bytes() + self.fallback.buffered_bytes()
    }
}
242
/// Items produced by [`dict_encode_stream`]: runs of code chunks, each run terminated by
/// a single values chunk that closes the dictionary.
enum DictionaryChunk {
    /// A chunk of dictionary codes, tagged with the ptype the encoder emits for codes.
    Codes {
        seq_id: SequenceId,
        codes: ArrayRef,
        codes_ptype: PType,
    },
    /// The dictionary values terminating the current run.
    Values((SequenceId, ArrayRef)),
}

/// Fallible stream of [`DictionaryChunk`]s.
type DictionaryStream = BoxStream<'static, VortexResult<DictionaryChunk>>;
253
/// Dictionary-encodes a chunk stream, yielding codes chunks followed by a values chunk
/// whenever `constraints` are exhausted: (a1, a2) -> (codes(a1), codes(a2), values(d1), ...).
fn dict_encode_stream(
    input: SendableSequentialStream,
    constraints: DictConstraints,
) -> DictionaryStream {
    Box::pin(try_stream! {
        let mut state = DictStreamState {
            encoder: None,
            constraints,
        };

        // Peekable so the final chunk can be detected and the encoder drained after it.
        let input = input.peekable();
        pin_mut!(input);

        while let Some(item) = input.next().await {
            let (sequence_id, chunk) = item?;

            // labeler potentially creates sub sequences, we must
            // create it on both arms to avoid having a SequencePointer
            // between await points
            match input.as_mut().peek().await {
                Some(_) => {
                    let mut labeler = DictChunkLabeler::new(sequence_id);
                    let chunks = state.encode(&mut labeler, chunk)?;
                    // Dropped before yielding: yield is an await point.
                    drop(labeler);
                    for dict_chunk in chunks {
                        yield dict_chunk;
                    }
                }
                None => {
                    // this is the last element, encode and drain chunks
                    let mut labeler = DictChunkLabeler::new(sequence_id);
                    let encoded = state.encode(&mut labeler, chunk)?;
                    let drained = state.drain_values(&mut labeler);
                    drop(labeler);
                    for dict_chunk in encoded.into_iter().chain(drained.into_iter()) {
                        yield dict_chunk;
                    }
                }
            }
        }
    })
}
296
/// Mutable encoding state carried across chunks of the input stream.
struct DictStreamState {
    /// The live encoder; `None` before the first chunk and right after a dictionary closes.
    encoder: Option<Box<dyn DictEncoder>>,
    /// Limits at which the current dictionary is flushed.
    constraints: DictConstraints,
}
301
302impl DictStreamState {
303    fn encode(
304        &mut self,
305        labeler: &mut DictChunkLabeler,
306        chunk: ArrayRef,
307    ) -> VortexResult<Vec<DictionaryChunk>> {
308        let mut res = Vec::new();
309        let mut to_be_encoded = Some(chunk);
310        while let Some(remaining) = to_be_encoded.take() {
311            match self.encoder.take() {
312                None => match start_encoding(&self.constraints, &remaining)? {
313                    EncodingState::Continue((encoder, encoded)) => {
314                        let ptype = encoder.codes_ptype();
315                        res.push(labeler.codes(encoded, ptype));
316                        self.encoder = Some(encoder);
317                    }
318                    EncodingState::Done((values, encoded, unencoded)) => {
319                        // Encoder was created and consumed within start_encoding
320                        let ptype = PType::try_from(encoded.dtype())
321                            .vortex_expect("codes should be primitive");
322                        res.push(labeler.codes(encoded, ptype));
323                        res.push(labeler.values(values));
324                        to_be_encoded = Some(unencoded);
325                    }
326                },
327                Some(encoder) => {
328                    let ptype = encoder.codes_ptype();
329                    match encode_chunk(encoder, &remaining)? {
330                        EncodingState::Continue((encoder, encoded)) => {
331                            res.push(labeler.codes(encoded, ptype));
332                            self.encoder = Some(encoder);
333                        }
334                        EncodingState::Done((values, encoded, unencoded)) => {
335                            res.push(labeler.codes(encoded, ptype));
336                            res.push(labeler.values(values));
337                            to_be_encoded = Some(unencoded);
338                        }
339                    }
340                }
341            }
342        }
343        Ok(res)
344    }
345
346    fn drain_values(&mut self, labeler: &mut DictChunkLabeler) -> Vec<DictionaryChunk> {
347        match self.encoder.as_mut() {
348            None => Vec::new(),
349            Some(encoder) => vec![labeler.values(encoder.reset())],
350        }
351    }
352}
353
354struct DictChunkLabeler {
355    sequence_pointer: SequencePointer,
356}
357
358impl DictChunkLabeler {
359    fn new(starting_id: SequenceId) -> Self {
360        let sequence_pointer = starting_id.descend();
361        Self { sequence_pointer }
362    }
363
364    fn codes(&mut self, chunk: ArrayRef, ptype: PType) -> DictionaryChunk {
365        DictionaryChunk::Codes {
366            seq_id: self.sequence_pointer.advance(),
367            codes: chunk,
368            codes_ptype: ptype,
369        }
370    }
371
372    fn values(&mut self, chunk: ArrayRef) -> DictionaryChunk {
373        DictionaryChunk::Values((self.sequence_pointer.advance(), chunk))
374    }
375}
376
/// A sequenced chunk result as it flows through sequential streams.
type SequencedChunk = VortexResult<(SequenceId, ArrayRef)>;

/// Adapts a [`DictionaryStream`] into per-dictionary "runs": for each run it yields a
/// codes stream plus a future resolving to that run's values chunk.
struct DictionaryTransformer {
    input: DictionaryStream,
    // Sender feeding the currently-open codes stream; `None` when no run is open.
    active_codes_tx: Option<kanal::AsyncSender<SequencedChunk>>,
    // One-shot sender completing the current run's values future.
    active_values_tx: Option<oneshot::Sender<SequencedChunk>>,
    // In-flight send into the bounded(1) codes channel; sends can block on backpressure.
    pending_send: Option<BoxFuture<'static, Result<(), kanal::SendError>>>,
}
385
386impl DictionaryTransformer {
387    fn new(input: DictionaryStream) -> Self {
388        Self {
389            input,
390            active_codes_tx: None,
391            active_values_tx: None,
392            pending_send: None,
393        }
394    }
395}
396
impl Stream for DictionaryTransformer {
    type Item = (SendableSequentialStream, BoxFuture<'static, SequencedChunk>);

    /// State machine: drains any pending channel send, then pulls dictionary chunks.
    /// A `Codes` chunk either opens a new run (yielding the run's stream/future pair)
    /// or feeds the open run; a `Values` chunk closes the run.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // First, try to complete any pending send
            if let Some(mut send_fut) = self.pending_send.take() {
                match send_fut.poll_unpin(cx) {
                    Poll::Ready(Ok(())) => {
                        // Send completed, continue processing
                    }
                    Poll::Ready(Err(_)) => {
                        // Receiver dropped, close this group
                        self.active_codes_tx = None;
                        if let Some(values_tx) = self.active_values_tx.take() {
                            drop(values_tx.send(Err(vortex_err!("values receiver dropped"))));
                        }
                    }
                    Poll::Pending => {
                        // Still pending, save it and return
                        self.pending_send = Some(send_fut);
                        return Poll::Pending;
                    }
                }
            }

            match self.input.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(DictionaryChunk::Codes {
                    seq_id,
                    codes,
                    codes_ptype,
                }))) => {
                    if self.active_codes_tx.is_none() {
                        // Start a new group
                        let (codes_tx, codes_rx) = kanal::bounded_async::<SequencedChunk>(1);
                        let (values_tx, values_rx) = oneshot::channel();

                        self.active_codes_tx = Some(codes_tx.clone());
                        self.active_values_tx = Some(values_tx);

                        // Use passed codes_ptype instead of getting from array
                        let codes_dtype = DType::Primitive(codes_ptype, Nullability::NonNullable);

                        // Send first codes.
                        self.pending_send =
                            Some(Box::pin(
                                async move { codes_tx.send(Ok((seq_id, codes))).await },
                            ));

                        // Create output streams.
                        let codes_stream = SequentialStreamAdapter::new(
                            codes_dtype,
                            codes_rx.into_stream().boxed(),
                        )
                        .sendable();

                        let values_future = async move {
                            values_rx
                                .await
                                .map_err(|e| vortex_err!("values sender dropped: {}", e))
                                .flatten()
                        }
                        .boxed();

                        return Poll::Ready(Some((codes_stream, values_future)));
                    }

                    // Continue streaming codes to existing group
                    if let Some(tx) = &self.active_codes_tx {
                        let tx = tx.clone();
                        self.pending_send =
                            Some(Box::pin(async move { tx.send(Ok((seq_id, codes))).await }));
                    }
                }
                Poll::Ready(Some(Ok(DictionaryChunk::Values(values)))) => {
                    // Complete the current group
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Ok(values)));
                    }
                    self.active_codes_tx = None; // Close codes stream
                }
                Poll::Ready(Some(Err(e))) => {
                    // Send error to active channels if any
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Err(e)));
                    }
                    self.active_codes_tx = None;
                    // And terminate the stream
                    return Poll::Ready(None);
                }
                Poll::Ready(None) => {
                    // Handle any incomplete group
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Err(vortex_err!("Incomplete dictionary group"))));
                    }
                    self.active_codes_tx = None;
                    return Poll::Ready(None);
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
500
501async fn peek_first_chunk(
502    mut stream: BoxStream<'static, SequencedChunk>,
503) -> VortexResult<(BoxStream<'static, SequencedChunk>, Option<ArrayRef>)> {
504    match stream.next().await {
505        None => Ok((stream.boxed(), None)),
506        Some(Err(e)) => Err(e),
507        Some(Ok((sequence_id, chunk))) => {
508            let chunk_clone = chunk.clone();
509            let reconstructed_stream =
510                once(async move { Ok((sequence_id, chunk_clone)) }).chain(stream);
511            Ok((reconstructed_stream.boxed(), Some(chunk)))
512        }
513    }
514}
515
/// Returns true for dtypes the dictionary layout can encode: primitives, UTF-8, and binary.
pub fn dict_layout_supported(dtype: &DType) -> bool {
    matches!(
        dtype,
        DType::Primitive(..) | DType::Utf8(_) | DType::Binary(_)
    )
}
522
/// Serialized metadata for a dict layout: records the ptype of the codes child.
#[derive(prost::Message)]
pub struct DictLayoutMetadata {
    #[prost(enumeration = "PType", tag = "1")]
    // i32 is required for proto, use the generated getter to read this field.
    codes_ptype: i32,
}
529
530impl DictLayoutMetadata {
531    pub fn new(codes_ptype: PType) -> Self {
532        let mut metadata = Self::default();
533        metadata.set_codes_ptype(codes_ptype);
534        metadata
535    }
536}
537
/// Outcome of feeding one chunk to a dict encoder.
enum EncodingState {
    /// Whole chunk encoded; the encoder stays alive with its dictionary intact.
    Continue((Box<dyn DictEncoder>, ArrayRef)),
    // (values, encoded, unencoded)
    /// Constraints hit: the dictionary values were flushed, a prefix of the chunk was
    /// encoded, and the rest must be re-encoded with a fresh encoder.
    Done((ArrayRef, ArrayRef, ArrayRef)),
}
543
544fn start_encoding(constraints: &DictConstraints, chunk: &dyn Array) -> VortexResult<EncodingState> {
545    let encoder = dict_encoder(chunk, constraints);
546    encode_chunk(encoder, chunk)
547}
548
549fn encode_chunk(
550    mut encoder: Box<dyn DictEncoder>,
551    chunk: &dyn Array,
552) -> VortexResult<EncodingState> {
553    let encoded = encoder.encode(chunk);
554    match remainder(chunk, encoded.len())? {
555        None => Ok(EncodingState::Continue((encoder, encoded))),
556        Some(unencoded) => Ok(EncodingState::Done((encoder.reset(), encoded, unencoded))),
557    }
558}
559
560fn remainder(array: &dyn Array, encoded_len: usize) -> VortexResult<Option<ArrayRef>> {
561    if encoded_len < array.len() {
562        Ok(Some(array.slice(encoded_len..array.len())?))
563    } else {
564        Ok(None)
565    }
566}
567
#[cfg(test)]
mod tests {
    use futures::StreamExt;
    use vortex_array::IntoArray;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::builders::dict::DictConstraints;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType;

    use super::DictionaryTransformer;
    use super::dict_encode_stream;
    use crate::sequence::SequenceId;
    use crate::sequence::SequentialStream;
    use crate::sequence::SequentialStreamAdapter;
    use crate::sequence::SequentialStreamExt;

    /// Regression test for a bug where the codes stream dtype was hardcoded to U16 instead of
    /// using the actual codes dtype from the array. When `max_len <= 255`, the dict encoder
    /// produces U8 codes, but the stream was incorrectly typed as U16, causing a dtype mismatch
    /// assertion failure in [`SequentialStreamAdapter`].
    #[tokio::test]
    async fn test_dict_transformer_uses_u8_for_small_dictionaries() {
        // Use max_len = 100 to force U8 codes (since 100 <= 255).
        let constraints = DictConstraints {
            max_bytes: 1024 * 1024,
            max_len: 100,
        };

        // Create a simple string array with a few unique values.
        let arr = VarBinArray::from(vec!["hello", "world", "hello", "world"]).into_array();

        // Wrap into a sequential stream.
        let mut pointer = SequenceId::root();
        let input_stream = SequentialStreamAdapter::new(
            arr.dtype().clone(),
            futures::stream::once(async move { Ok((pointer.advance(), arr)) }),
        )
        .sendable();

        // Encode into dict chunks.
        let dict_stream = dict_encode_stream(input_stream, constraints);

        // Transform into codes/values streams.
        let mut transformer = DictionaryTransformer::new(dict_stream);

        // Get the first (and only) run.
        let (codes_stream, _values_fut) = transformer
            .next()
            .await
            .expect("expected at least one dictionary run");

        // The key assertion: codes stream dtype should be U8, not U16.
        assert_eq!(
            codes_stream.dtype(),
            &DType::Primitive(PType::U8, NonNullable),
            "codes stream should use U8 dtype for small dictionaries, not U16"
        );
    }

    /// Test that the codes stream uses U16 dtype when the dictionary has more than 255 entries.
    #[tokio::test]
    async fn test_dict_transformer_uses_u16_for_large_dictionaries() {
        // Use max_len = 1000 to allow U16 codes (since 1000 > 255).
        let constraints = DictConstraints {
            max_bytes: 1024 * 1024,
            max_len: 1000,
        };

        // Create an array with more than 255 distinct values to force U16 codes.
        let values: Vec<String> = (0..300).map(|i| format!("value_{i}")).collect();
        let arr =
            VarBinArray::from(values.iter().map(|s| s.as_str()).collect::<Vec<_>>()).into_array();

        // Wrap into a sequential stream.
        let mut pointer = SequenceId::root();
        let input_stream = SequentialStreamAdapter::new(
            arr.dtype().clone(),
            futures::stream::once(async move { Ok((pointer.advance(), arr)) }),
        )
        .sendable();

        // Encode into dict chunks.
        let dict_stream = dict_encode_stream(input_stream, constraints);

        // Transform into codes/values streams.
        let mut transformer = DictionaryTransformer::new(dict_stream);

        // Get the first (and only) run.
        let (codes_stream, _values_fut) = transformer
            .next()
            .await
            .expect("expected at least one dictionary run");

        // Codes stream dtype should be U16 since we have more than 255 distinct values.
        assert_eq!(
            codes_stream.dtype(),
            &DType::Primitive(PType::U16, NonNullable),
            "codes stream should use U16 dtype for dictionaries with >255 entries"
        );
    }
}
669}