// vortex_layout/layouts/dict/writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::Context;
7use std::task::Poll;
8
9use async_stream::stream;
10use async_stream::try_stream;
11use async_trait::async_trait;
12use futures::FutureExt;
13use futures::Stream;
14use futures::StreamExt;
15use futures::TryStreamExt;
16use futures::future::BoxFuture;
17use futures::pin_mut;
18use futures::stream::BoxStream;
19use futures::stream::once;
20use futures::try_join;
21use vortex_array::ArrayContext;
22use vortex_array::ArrayRef;
23use vortex_array::VortexSessionExecute;
24use vortex_array::arrays::Dict;
25use vortex_array::builders::dict::DictConstraints;
26use vortex_array::builders::dict::DictEncoder;
27use vortex_array::builders::dict::dict_encoder;
28use vortex_array::dtype::DType;
29use vortex_array::dtype::Nullability;
30use vortex_array::dtype::PType;
31use vortex_btrblocks::BtrBlocksCompressor;
32use vortex_error::VortexError;
33use vortex_error::VortexExpect;
34use vortex_error::VortexResult;
35use vortex_error::vortex_err;
36use vortex_io::kanal_ext::KanalExt;
37use vortex_io::session::RuntimeSessionExt;
38use vortex_session::VortexSession;
39
40use crate::IntoLayout;
41use crate::LayoutRef;
42use crate::LayoutStrategy;
43use crate::OwnedLayoutChildren;
44use crate::layouts::chunked::ChunkedLayout;
45use crate::layouts::dict::DictLayout;
46use crate::segments::SegmentSinkRef;
47use crate::sequence::SendableSequentialStream;
48use crate::sequence::SequenceId;
49use crate::sequence::SequencePointer;
50use crate::sequence::SequentialStream;
51use crate::sequence::SequentialStreamAdapter;
52use crate::sequence::SequentialStreamExt;
53
/// Constraints for dictionary layout encoding.
///
/// Note that [`max_len`](Self::max_len) is limited to `u16` (65,535 entries) by design. Since
/// layout chunks are typically ~8k elements, having more than 64k unique values in a dictionary
/// means dictionary encoding provides little compression benefit. If a column has very high
/// cardinality, the fallback encoding strategy should be used instead.
#[derive(Clone)]
pub struct DictLayoutConstraints {
    /// Maximum size of the dictionary in bytes. When exceeded, the current dictionary is
    /// flushed and a new one is started.
    pub max_bytes: usize,
    /// Maximum dictionary length. Limited to `u16` because dictionaries with more than 64k unique
    /// values provide diminishing compression returns given typical chunk sizes (~8k elements).
    ///
    /// The codes dtype is determined upfront from this constraint:
    /// - [`PType::U8`] when max_len <= 255
    /// - [`PType::U16`] when max_len > 255
    ///
    /// Vortex encoders must always produce unsigned integer codes; signed codes are only accepted for external compatibility.
    pub max_len: u16,
}
74
75impl From<DictLayoutConstraints> for DictConstraints {
76    fn from(value: DictLayoutConstraints) -> Self {
77        DictConstraints {
78            max_bytes: value.max_bytes,
79            max_len: value.max_len as usize,
80        }
81    }
82}
83
84impl Default for DictLayoutConstraints {
85    fn default() -> Self {
86        Self {
87            max_bytes: 1024 * 1024,
88            max_len: u16::MAX,
89        }
90    }
91}
92
/// Options controlling the [`DictStrategy`] layout writer.
#[derive(Clone, Default)]
pub struct DictLayoutOptions {
    /// Size limits that, when exceeded, close the current dictionary and start a new one.
    pub constraints: DictLayoutConstraints,
}
97
/// A layout strategy that encodes chunk into values and codes, if found
/// appropriate by the btrblocks compressor. Current implementation only
/// checks the first chunk to decide whether to apply dict layout and
/// encodes chunks into dictionaries. When the dict constraints are hit, a
/// new dictionary is created.
#[derive(Clone)]
pub struct DictStrategy {
    // Strategy used to write each dictionary's codes child layout.
    codes: Arc<dyn LayoutStrategy>,
    // Strategy used to write each dictionary's values child layout.
    values: Arc<dyn LayoutStrategy>,
    // Strategy used when the dtype is unsupported or the first chunk does not
    // dictionary-compress.
    fallback: Arc<dyn LayoutStrategy>,
    // Dictionary size constraints; see `DictLayoutOptions`.
    options: DictLayoutOptions,
}
110
111impl DictStrategy {
112    pub fn new<Codes: LayoutStrategy, Values: LayoutStrategy, Fallback: LayoutStrategy>(
113        codes: Codes,
114        values: Values,
115        fallback: Fallback,
116        options: DictLayoutOptions,
117    ) -> Self {
118        Self {
119            codes: Arc::new(codes),
120            values: Arc::new(values),
121            fallback: Arc::new(fallback),
122            options,
123        }
124    }
125}
126
#[async_trait]
impl LayoutStrategy for DictStrategy {
    /// Writes `stream` as a dictionary layout when the first chunk looks
    /// dictionary-compressible; otherwise delegates the whole stream to the
    /// fallback strategy.
    async fn write_stream(
        &self,
        ctx: ArrayContext,
        segment_sink: SegmentSinkRef,
        stream: SendableSequentialStream,
        mut eof: SequencePointer,
        session: &VortexSession,
    ) -> VortexResult<LayoutRef> {
        // Fallback if dtype is not supported
        if !dict_layout_supported(stream.dtype()) {
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, session)
                .await;
        }

        let options = self.options.clone();
        let dtype = stream.dtype().clone();

        // 0. decide if chunks are eligible for dict encoding
        let (stream, first_chunk) = peek_first_chunk(stream).await?;
        let stream = SequentialStreamAdapter::new(dtype.clone(), stream).sendable();

        let should_fallback = match first_chunk {
            None => true, // empty stream
            Some(chunk) => {
                // Sampling heuristic: only the first chunk is test-compressed.
                // If btrblocks does not pick a Dict encoding for it, assume the
                // rest of the column is not dictionary-friendly either.
                let mut exec_ctx = session.create_execution_ctx();
                let compressed = BtrBlocksCompressor::default().compress(&chunk, &mut exec_ctx)?;
                !compressed.is::<Dict>()
            }
        };
        if should_fallback {
            // first chunk did not compress to dict, or did not exist. Skip dict layout
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, session)
                .await;
        }

        // 1. from a chunk stream, create a stream that yields codes
        // followed by a single value chunk when dict constraints are hit.
        // (a1, a2) -> (code(c1), code(c2), values(v1), code(c3), ...)
        let dict_stream = dict_encode_stream(stream, options.constraints.into());

        // Wrap up the dict stream to yield pairs of (codes_stream, values_future).
        // Each of these pairs becomes a child dict layout.
        let runs = DictionaryTransformer::new(dict_stream);

        let handle = session.handle();
        let dtype2 = dtype.clone();
        let child_layouts = stream! {
            pin_mut!(runs);

            while let Some((codes_stream, values_fut)) = runs.next().await {
                // Write the codes child on a nested task so it can make
                // progress concurrently with the values child below.
                let codes = Arc::clone(&self.codes);
                let codes_eof = eof.split_off();
                let ctx2 = ctx.clone();
                let segment_sink2 = Arc::clone(&segment_sink);
                let session2 = session.clone();
                let codes_fut = handle.spawn_nested(move |h| async move {
                    let session2 = session2.with_handle(h);
                    codes.write_stream(
                        ctx2,
                        segment_sink2,
                        codes_stream.sendable(),
                        codes_eof,
                        &session2,
                    ).await
                });

                // The values child is a single chunk; wrap the resolving future
                // as a one-element stream so the generic write_stream interface
                // can be reused.
                let values = Arc::clone(&self.values);
                let values_eof = eof.split_off();
                let ctx2 = ctx.clone();
                let segment_sink2 = Arc::clone(&segment_sink);
                let dtype2 = dtype2.clone();
                let session2 = session.clone();
                let values_layout = handle.spawn_nested(move |h| async move {
                    let session2 = session2.with_handle(h);
                    values.write_stream(
                        ctx2,
                        segment_sink2,
                        SequentialStreamAdapter::new(dtype2, once(values_fut)).sendable(),
                        values_eof,
                        &session2,
                    ).await
                });

                // Yield a future resolving to this run's (codes, values) layouts.
                yield async move {
                    try_join!(codes_fut, values_layout)
                }.boxed();
            }
        };

        // Drive all per-run futures concurrently (effectively unbounded) and
        // collect the child layouts in their original order.
        let mut child_layouts = child_layouts
            .buffered(usize::MAX)
            .map(|result| {
                let (codes_layout, values_layout) = result?;
                // All values are referenced when created via dictionary encoding
                Ok::<_, VortexError>(DictLayout::new(values_layout, codes_layout).into_layout())
            })
            .try_collect::<Vec<_>>()
            .await?;

        // A single dictionary run needs no chunked wrapper.
        if child_layouts.len() == 1 {
            return Ok(child_layouts.remove(0));
        }

        // Multiple runs are stitched together as a chunked layout.
        let row_count = child_layouts.iter().map(|child| child.row_count()).sum();
        Ok(ChunkedLayout::new(
            row_count,
            dtype,
            OwnedLayoutChildren::layout_children(child_layouts),
        )
        .into_layout())
    }

    /// Sum of the buffered-byte estimates of the three child strategies.
    fn buffered_bytes(&self) -> u64 {
        self.codes.buffered_bytes() + self.values.buffered_bytes() + self.fallback.buffered_bytes()
    }
}
249
/// One item of the dictionary-encoding stream: either a chunk of codes for the
/// currently-open dictionary, or the values array that closes it.
enum DictionaryChunk {
    /// Encoded codes for the dictionary currently being built.
    Codes {
        seq_id: SequenceId,
        codes: ArrayRef,
        /// Unsigned integer ptype of `codes`, carried alongside so the codes
        /// stream can be typed before the chunk is consumed downstream.
        codes_ptype: PType,
    },
    /// The values array of a finished dictionary; ends the current run.
    Values((SequenceId, ArrayRef)),
}

/// Boxed stream of dictionary codes/values chunks.
type DictionaryStream = BoxStream<'static, VortexResult<DictionaryChunk>>;
260
/// Dictionary-encodes a stream of chunks.
///
/// Emits the codes for each input chunk as it arrives. Whenever `constraints`
/// are exceeded mid-chunk, the finished dictionary's values are emitted as a
/// [`DictionaryChunk::Values`] item and a fresh dictionary is started. The
/// last open dictionary is drained when the input ends.
fn dict_encode_stream(
    input: SendableSequentialStream,
    constraints: DictConstraints,
) -> DictionaryStream {
    Box::pin(try_stream! {
        let mut state = DictStreamState {
            encoder: None,
            constraints,
        };

        // Peekable so the final element can be detected and the encoder drained.
        let input = input.peekable();
        pin_mut!(input);

        while let Some(item) = input.next().await {
            let (sequence_id, chunk) = item?;

            // labeler potentially creates sub sequences, we must
            // create it on both arms to avoid having a SequencePointer
            // between await points
            match input.as_mut().peek().await {
                Some(_) => {
                    let mut labeler = DictChunkLabeler::new(sequence_id);
                    let chunks = state.encode(&mut labeler, chunk)?;
                    drop(labeler);
                    for dict_chunk in chunks {
                        yield dict_chunk;
                    }
                }
                None => {
                    // this is the last element, encode and drain chunks
                    let mut labeler = DictChunkLabeler::new(sequence_id);
                    let encoded = state.encode(&mut labeler, chunk)?;
                    let drained = state.drain_values(&mut labeler);
                    drop(labeler);
                    for dict_chunk in encoded.into_iter().chain(drained.into_iter()) {
                        yield dict_chunk;
                    }
                }
            }
        }
    })
}
303
/// Mutable state threaded through [`dict_encode_stream`].
struct DictStreamState {
    /// The currently-open dictionary encoder; `None` before any chunk has
    /// been encoded.
    encoder: Option<Box<dyn DictEncoder>>,
    /// Limits that, when hit, close the current dictionary.
    constraints: DictConstraints,
}
308
309impl DictStreamState {
310    fn encode(
311        &mut self,
312        labeler: &mut DictChunkLabeler,
313        chunk: ArrayRef,
314    ) -> VortexResult<Vec<DictionaryChunk>> {
315        let mut res = Vec::new();
316        let mut to_be_encoded = Some(chunk);
317        while let Some(remaining) = to_be_encoded.take() {
318            match self.encoder.take() {
319                None => match start_encoding(&self.constraints, &remaining)? {
320                    EncodingState::Continue((encoder, encoded)) => {
321                        let ptype = encoder.codes_ptype();
322                        res.push(labeler.codes(encoded, ptype));
323                        self.encoder = Some(encoder);
324                    }
325                    EncodingState::Done((values, encoded, unencoded)) => {
326                        // Encoder was created and consumed within start_encoding
327                        let ptype = PType::try_from(encoded.dtype())
328                            .vortex_expect("codes should be primitive");
329                        res.push(labeler.codes(encoded, ptype));
330                        res.push(labeler.values(values));
331                        to_be_encoded = Some(unencoded);
332                    }
333                },
334                Some(encoder) => {
335                    let ptype = encoder.codes_ptype();
336                    match encode_chunk(encoder, &remaining)? {
337                        EncodingState::Continue((encoder, encoded)) => {
338                            res.push(labeler.codes(encoded, ptype));
339                            self.encoder = Some(encoder);
340                        }
341                        EncodingState::Done((values, encoded, unencoded)) => {
342                            res.push(labeler.codes(encoded, ptype));
343                            res.push(labeler.values(values));
344                            to_be_encoded = Some(unencoded);
345                        }
346                    }
347                }
348            }
349        }
350        Ok(res)
351    }
352
353    fn drain_values(&mut self, labeler: &mut DictChunkLabeler) -> Vec<DictionaryChunk> {
354        match self.encoder.as_mut() {
355            None => Vec::new(),
356            Some(encoder) => vec![labeler.values(encoder.reset())],
357        }
358    }
359}
360
361struct DictChunkLabeler {
362    sequence_pointer: SequencePointer,
363}
364
365impl DictChunkLabeler {
366    fn new(starting_id: SequenceId) -> Self {
367        let sequence_pointer = starting_id.descend();
368        Self { sequence_pointer }
369    }
370
371    fn codes(&mut self, chunk: ArrayRef, ptype: PType) -> DictionaryChunk {
372        DictionaryChunk::Codes {
373            seq_id: self.sequence_pointer.advance(),
374            codes: chunk,
375            codes_ptype: ptype,
376        }
377    }
378
379    fn values(&mut self, chunk: ArrayRef) -> DictionaryChunk {
380        DictionaryChunk::Values((self.sequence_pointer.advance(), chunk))
381    }
382}
383
/// A fallible (sequence id, chunk) pair — the item type of sequential streams.
type SequencedChunk = VortexResult<(SequenceId, ArrayRef)>;

/// Adapts a [`DictionaryStream`] into a stream of dictionary "runs": for each
/// dictionary, a stream of its codes chunks plus a future resolving to its
/// values chunk.
struct DictionaryTransformer {
    input: DictionaryStream,
    /// Sender feeding the codes stream of the in-flight run; `None` between runs.
    active_codes_tx: Option<kanal::AsyncSender<SequencedChunk>>,
    /// One-shot sender that completes the in-flight run's values future.
    active_values_tx: Option<oneshot::Sender<SequencedChunk>>,
    /// An in-progress send into the bounded codes channel, kept across polls so
    /// backpressure from the codes consumer suspends this stream.
    pending_send: Option<BoxFuture<'static, Result<(), kanal::SendError>>>,
}

impl DictionaryTransformer {
    /// Wraps `input` with no run in flight yet.
    fn new(input: DictionaryStream) -> Self {
        Self {
            input,
            active_codes_tx: None,
            active_values_tx: None,
            pending_send: None,
        }
    }
}
403
impl Stream for DictionaryTransformer {
    /// One dictionary run: the stream of its codes chunks and a future
    /// resolving to its values chunk.
    type Item = (SendableSequentialStream, BoxFuture<'static, SequencedChunk>);

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // First, try to complete any pending send
            if let Some(mut send_fut) = self.pending_send.take() {
                match send_fut.poll_unpin(cx) {
                    Poll::Ready(Ok(())) => {
                        // Send completed, continue processing
                    }
                    Poll::Ready(Err(_)) => {
                        // Receiver dropped, close this group
                        self.active_codes_tx = None;
                        if let Some(values_tx) = self.active_values_tx.take() {
                            drop(values_tx.send(Err(vortex_err!("values receiver dropped"))));
                        }
                    }
                    Poll::Pending => {
                        // Still pending, save it and return
                        self.pending_send = Some(send_fut);
                        return Poll::Pending;
                    }
                }
            }

            // Only poll the input once any pending send has resolved, so chunk
            // order into the codes channel is preserved.
            match self.input.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(DictionaryChunk::Codes {
                    seq_id,
                    codes,
                    codes_ptype,
                }))) => {
                    if self.active_codes_tx.is_none() {
                        // Start a new group. Capacity 1 gives backpressure:
                        // at most one codes chunk is buffered ahead of the consumer.
                        let (codes_tx, codes_rx) = kanal::bounded_async::<SequencedChunk>(1);
                        let (values_tx, values_rx) = oneshot::channel();

                        self.active_codes_tx = Some(codes_tx.clone());
                        self.active_values_tx = Some(values_tx);

                        // Use passed codes_ptype instead of getting from array
                        let codes_dtype = DType::Primitive(codes_ptype, Nullability::NonNullable);

                        // Send first codes.
                        self.pending_send =
                            Some(Box::pin(
                                async move { codes_tx.send(Ok((seq_id, codes))).await },
                            ));

                        // Create output streams.
                        let codes_stream = SequentialStreamAdapter::new(
                            codes_dtype,
                            codes_rx.into_stream().boxed(),
                        )
                        .sendable();

                        let values_future = async move {
                            values_rx
                                .await
                                .map_err(|e| vortex_err!("values sender dropped: {}", e))
                                .flatten()
                        }
                        .boxed();

                        return Poll::Ready(Some((codes_stream, values_future)));
                    }

                    // Continue streaming codes to existing group
                    if let Some(tx) = &self.active_codes_tx {
                        let tx = tx.clone();
                        self.pending_send =
                            Some(Box::pin(async move { tx.send(Ok((seq_id, codes))).await }));
                    }
                }
                Poll::Ready(Some(Ok(DictionaryChunk::Values(values)))) => {
                    // Complete the current group
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Ok(values)));
                    }
                    self.active_codes_tx = None; // Close codes stream
                }
                Poll::Ready(Some(Err(e))) => {
                    // Send error to active channels if any
                    // NOTE(review): the error is forwarded only to the values
                    // future; the codes stream just ends — presumably the
                    // codes consumer learns of the failure elsewhere. Confirm.
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Err(e)));
                    }
                    self.active_codes_tx = None;
                    // And terminate the stream
                    return Poll::Ready(None);
                }
                Poll::Ready(None) => {
                    // Handle any incomplete group
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Err(vortex_err!("Incomplete dictionary group"))));
                    }
                    self.active_codes_tx = None;
                    return Poll::Ready(None);
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
507
508async fn peek_first_chunk(
509    mut stream: BoxStream<'static, SequencedChunk>,
510) -> VortexResult<(BoxStream<'static, SequencedChunk>, Option<ArrayRef>)> {
511    match stream.next().await {
512        None => Ok((stream.boxed(), None)),
513        Some(Err(e)) => Err(e),
514        Some(Ok((sequence_id, chunk))) => {
515            let chunk_clone = chunk.clone();
516            let reconstructed_stream =
517                once(async move { Ok((sequence_id, chunk_clone)) }).chain(stream);
518            Ok((reconstructed_stream.boxed(), Some(chunk)))
519        }
520    }
521}
522
523pub fn dict_layout_supported(dtype: &DType) -> bool {
524    matches!(
525        dtype,
526        DType::Primitive(..) | DType::Utf8(_) | DType::Binary(_)
527    )
528}
529
/// Serialized metadata for a dict layout: records the ptype of the codes
/// child so readers can type the codes without inspecting the data.
#[derive(prost::Message)]
pub struct DictLayoutMetadata {
    #[prost(enumeration = "PType", tag = "1")]
    // i32 is required for proto, use the generated getter to read this field.
    codes_ptype: i32,
}
536
537impl DictLayoutMetadata {
538    pub fn new(codes_ptype: PType) -> Self {
539        let mut metadata = Self::default();
540        metadata.set_codes_ptype(codes_ptype);
541        metadata
542    }
543}
544
/// Outcome of feeding one chunk to a dictionary encoder.
enum EncodingState {
    /// The whole chunk fit: (still-open encoder, encoded codes).
    Continue((Box<dyn DictEncoder>, ArrayRef)),
    // (values, encoded, unencoded)
    /// The dictionary hit its constraints mid-chunk and was closed.
    Done((ArrayRef, ArrayRef, ArrayRef)),
}
550
551fn start_encoding(constraints: &DictConstraints, chunk: &ArrayRef) -> VortexResult<EncodingState> {
552    let encoder = dict_encoder(chunk, constraints);
553    encode_chunk(encoder, chunk)
554}
555
556fn encode_chunk(
557    mut encoder: Box<dyn DictEncoder>,
558    chunk: &ArrayRef,
559) -> VortexResult<EncodingState> {
560    let encoded = encoder.encode(chunk);
561    match remainder(chunk, encoded.len())? {
562        None => Ok(EncodingState::Continue((encoder, encoded))),
563        Some(unencoded) => Ok(EncodingState::Done((encoder.reset(), encoded, unencoded))),
564    }
565}
566
567fn remainder(array: &ArrayRef, encoded_len: usize) -> VortexResult<Option<ArrayRef>> {
568    if encoded_len < array.len() {
569        Ok(Some(array.slice(encoded_len..array.len())?))
570    } else {
571        Ok(None)
572    }
573}
574
#[cfg(test)]
mod tests {
    use futures::StreamExt;
    use vortex_array::IntoArray;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::builders::dict::DictConstraints;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability::NonNullable;
    use vortex_array::dtype::PType;

    use super::DictionaryTransformer;
    use super::dict_encode_stream;
    use crate::sequence::SequenceId;
    use crate::sequence::SequentialStream;
    use crate::sequence::SequentialStreamAdapter;
    use crate::sequence::SequentialStreamExt;

    /// Regression test for a bug where the codes stream dtype was hardcoded to U16 instead of
    /// using the actual codes dtype from the array. When `max_len <= 255`, the dict encoder
    /// produces U8 codes, but the stream was incorrectly typed as U16, causing a dtype mismatch
    /// assertion failure in [`SequentialStreamAdapter`].
    #[tokio::test]
    async fn test_dict_transformer_uses_u8_for_small_dictionaries() {
        // Use max_len = 100 to force U8 codes (since 100 <= 255).
        let constraints = DictConstraints {
            max_bytes: 1024 * 1024,
            max_len: 100,
        };

        // Create a simple string array with a few unique values.
        let arr = VarBinArray::from(vec!["hello", "world", "hello", "world"]).into_array();

        // Wrap into a sequential stream.
        let mut pointer = SequenceId::root();
        let input_stream = SequentialStreamAdapter::new(
            arr.dtype().clone(),
            futures::stream::once(async move { Ok((pointer.advance(), arr)) }),
        )
        .sendable();

        // Encode into dict chunks.
        let dict_stream = dict_encode_stream(input_stream, constraints);

        // Transform into codes/values streams.
        let mut transformer = DictionaryTransformer::new(dict_stream);

        // Get the first (and only) run.
        let (codes_stream, _values_fut) = transformer
            .next()
            .await
            .expect("expected at least one dictionary run");

        // The key assertion: codes stream dtype should be U8, not U16.
        assert_eq!(
            codes_stream.dtype(),
            &DType::Primitive(PType::U8, NonNullable),
            "codes stream should use U8 dtype for small dictionaries, not U16"
        );
    }

    /// Test that the codes stream uses U16 dtype when the dictionary has more than 255 entries.
    #[tokio::test]
    async fn test_dict_transformer_uses_u16_for_large_dictionaries() {
        // Use max_len = 1000 to allow U16 codes (since 1000 > 255).
        let constraints = DictConstraints {
            max_bytes: 1024 * 1024,
            max_len: 1000,
        };

        // Create an array with more than 255 distinct values to force U16 codes.
        let values: Vec<String> = (0..300).map(|i| format!("value_{i}")).collect();
        let arr =
            VarBinArray::from(values.iter().map(|s| s.as_str()).collect::<Vec<_>>()).into_array();

        // Wrap into a sequential stream.
        let mut pointer = SequenceId::root();
        let input_stream = SequentialStreamAdapter::new(
            arr.dtype().clone(),
            futures::stream::once(async move { Ok((pointer.advance(), arr)) }),
        )
        .sendable();

        // Encode into dict chunks.
        let dict_stream = dict_encode_stream(input_stream, constraints);

        // Transform into codes/values streams.
        let mut transformer = DictionaryTransformer::new(dict_stream);

        // Get the first (and only) run.
        let (codes_stream, _values_fut) = transformer
            .next()
            .await
            .expect("expected at least one dictionary run");

        // Codes stream dtype should be U16 since we have more than 255 distinct values.
        assert_eq!(
            codes_stream.dtype(),
            &DType::Primitive(PType::U16, NonNullable),
            "codes stream should use U16 dtype for dictionaries with >255 entries"
        );
    }
}