Skip to main content

vortex_layout/layouts/dict/
writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::Context;
7use std::task::Poll;
8
9use async_stream::stream;
10use async_stream::try_stream;
11use async_trait::async_trait;
12use futures::FutureExt;
13use futures::Stream;
14use futures::StreamExt;
15use futures::TryStreamExt;
16use futures::future::BoxFuture;
17use futures::pin_mut;
18use futures::stream::BoxStream;
19use futures::stream::once;
20use futures::try_join;
21use vortex_array::ArrayContext;
22use vortex_array::ArrayRef;
23use vortex_array::arrays::Dict;
24use vortex_array::builders::dict::DictConstraints;
25use vortex_array::builders::dict::DictEncoder;
26use vortex_array::builders::dict::dict_encoder;
27use vortex_array::dtype::DType;
28use vortex_array::dtype::Nullability;
29use vortex_array::dtype::PType;
30use vortex_btrblocks::BtrBlocksCompressor;
31use vortex_error::VortexError;
32use vortex_error::VortexExpect;
33use vortex_error::VortexResult;
34use vortex_error::vortex_err;
35use vortex_io::kanal_ext::KanalExt;
36use vortex_io::session::RuntimeSessionExt;
37use vortex_session::VortexSession;
38
39use crate::IntoLayout;
40use crate::LayoutRef;
41use crate::LayoutStrategy;
42use crate::OwnedLayoutChildren;
43use crate::layouts::chunked::ChunkedLayout;
44use crate::layouts::dict::DictLayout;
45use crate::segments::SegmentSinkRef;
46use crate::sequence::SendableSequentialStream;
47use crate::sequence::SequenceId;
48use crate::sequence::SequencePointer;
49use crate::sequence::SequentialStream;
50use crate::sequence::SequentialStreamAdapter;
51use crate::sequence::SequentialStreamExt;
52
/// Constraints for dictionary layout encoding.
///
/// Note that [`max_len`](Self::max_len) is limited to `u16` (65,535 entries) by design. Since
/// layout chunks are typically ~8k elements, having more than 64k unique values in a dictionary
/// means dictionary encoding provides little compression benefit. If a column has very high
/// cardinality, the fallback encoding strategy should be used instead.
///
/// Converted into the encoder-level [`DictConstraints`] via `From` before use.
#[derive(Clone)]
pub struct DictLayoutConstraints {
    /// Maximum size of the dictionary in bytes.
    pub max_bytes: usize,
    /// Maximum dictionary length. Limited to `u16` because dictionaries with more than 64k unique
    /// values provide diminishing compression returns given typical chunk sizes (~8k elements).
    ///
    /// The codes dtype is determined upfront from this constraint:
    /// - [`PType::U8`] when max_len <= 255
    /// - [`PType::U16`] when max_len > 255
    ///
    /// Vortex encoders must always produce unsigned integer codes; signed codes are only accepted for external compatibility.
    pub max_len: u16,
}
73
74impl From<DictLayoutConstraints> for DictConstraints {
75    fn from(value: DictLayoutConstraints) -> Self {
76        DictConstraints {
77            max_bytes: value.max_bytes,
78            max_len: value.max_len as usize,
79        }
80    }
81}
82
83impl Default for DictLayoutConstraints {
84    fn default() -> Self {
85        Self {
86            max_bytes: 1024 * 1024,
87            max_len: u16::MAX,
88        }
89    }
90}
91
/// Options controlling [`DictStrategy`] behavior.
#[derive(Clone, Default)]
pub struct DictLayoutOptions {
    /// Size/length limits applied to every dictionary produced by the strategy.
    pub constraints: DictLayoutConstraints,
}
96
/// A layout strategy that encodes chunk into values and codes, if found
/// appropriate by the btrblocks compressor. Current implementation only
/// checks the first chunk to decide whether to apply dict layout and
/// encodes chunks into dictionaries. When the dict constraints are hit, a
/// new dictionary is created.
#[derive(Clone)]
pub struct DictStrategy {
    /// Strategy used to write each dictionary run's codes child layout.
    codes: Arc<dyn LayoutStrategy>,
    /// Strategy used to write each dictionary run's values (dictionary) child layout.
    values: Arc<dyn LayoutStrategy>,
    /// Strategy used when the dtype is unsupported or the first chunk does not dict-compress.
    fallback: Arc<dyn LayoutStrategy>,
    /// Dictionary size/length constraints.
    options: DictLayoutOptions,
}
109
110impl DictStrategy {
111    pub fn new<Codes: LayoutStrategy, Values: LayoutStrategy, Fallback: LayoutStrategy>(
112        codes: Codes,
113        values: Values,
114        fallback: Fallback,
115        options: DictLayoutOptions,
116    ) -> Self {
117        Self {
118            codes: Arc::new(codes),
119            values: Arc::new(values),
120            fallback: Arc::new(fallback),
121            options,
122        }
123    }
124}
125
#[async_trait]
impl LayoutStrategy for DictStrategy {
    /// Writes `stream` as one or more [`DictLayout`] children, delegating to the fallback
    /// strategy when the dtype is unsupported, the stream is empty, or the first chunk does not
    /// dictionary-compress under btrblocks.
    async fn write_stream(
        &self,
        ctx: ArrayContext,
        segment_sink: SegmentSinkRef,
        stream: SendableSequentialStream,
        mut eof: SequencePointer,
        session: &VortexSession,
    ) -> VortexResult<LayoutRef> {
        // Fallback if dtype is not supported
        if !dict_layout_supported(stream.dtype()) {
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, session)
                .await;
        }

        let options = self.options.clone();
        let dtype = stream.dtype().clone();

        // 0. decide if chunks are eligible for dict encoding
        let (stream, first_chunk) = peek_first_chunk(stream).await?;
        let stream = SequentialStreamAdapter::new(dtype.clone(), stream).sendable();

        let should_fallback = match first_chunk {
            None => true, // empty stream
            Some(chunk) => {
                // Probe: only the first chunk is compressed here, and its result decides the
                // encoding for the whole stream.
                let compressed = BtrBlocksCompressor::default().compress(&chunk)?;
                !compressed.is::<Dict>()
            }
        };
        if should_fallback {
            // first chunk did not compress to dict, or did not exist. Skip dict layout
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, session)
                .await;
        }

        // 1. from a chunk stream, create a stream that yields codes
        // followed by a single value chunk when dict constraints are hit.
        // (a1, a2) -> (code(c1), code(c2), values(v1), code(c3), ...)
        let dict_stream = dict_encode_stream(stream, options.constraints.into());

        // Wrap up the dict stream to yield pairs of (codes_stream, values_future).
        // Each of these pairs becomes a child dict layout.
        let runs = DictionaryTransformer::new(dict_stream);

        let handle = session.handle();
        let dtype2 = dtype.clone();
        // For every run, spawn the codes and values child writes on the session handle; the
        // stream yields one future per run resolving to (codes_layout, values_layout).
        let child_layouts = stream! {
            pin_mut!(runs);

            while let Some((codes_stream, values_fut)) = runs.next().await {
                let codes = Arc::clone(&self.codes);
                // NOTE(review): codes' EOF pointer is split off before values' — presumably the
                // relative order of the split_off calls is significant; confirm against
                // SequencePointer's semantics before reordering.
                let codes_eof = eof.split_off();
                let ctx2 = ctx.clone();
                let segment_sink2 = Arc::clone(&segment_sink);
                let session2 = session.clone();
                let codes_fut = handle.spawn_nested(move |h| async move {
                    let session2 = session2.with_handle(h);
                    codes.write_stream(
                        ctx2,
                        segment_sink2,
                        codes_stream.sendable(),
                        codes_eof,
                        &session2,
                    ).await
                });

                let values = Arc::clone(&self.values);
                let values_eof = eof.split_off();
                let ctx2 = ctx.clone();
                let segment_sink2 = Arc::clone(&segment_sink);
                let dtype2 = dtype2.clone();
                let session2 = session.clone();
                let values_layout = handle.spawn_nested(move |h| async move {
                    let session2 = session2.with_handle(h);
                    values.write_stream(
                        ctx2,
                        segment_sink2,
                        // The values child is a single-chunk stream fed by the run's future.
                        SequentialStreamAdapter::new(dtype2, once(values_fut)).sendable(),
                        values_eof,
                        &session2,
                    ).await
                });

                yield async move {
                    // Either child write failing fails the whole layout write.
                    try_join!(codes_fut, values_layout)
                }.boxed();
            }
        };

        // Drive all per-run futures concurrently (unbounded) and build one DictLayout per run.
        let mut child_layouts = child_layouts
            .buffered(usize::MAX)
            .map(|result| {
                let (codes_layout, values_layout) = result?;
                // All values are referenced when created via dictionary encoding
                Ok::<_, VortexError>(DictLayout::new(values_layout, codes_layout).into_layout())
            })
            .try_collect::<Vec<_>>()
            .await?;

        // A single run needs no chunked wrapper.
        if child_layouts.len() == 1 {
            return Ok(child_layouts.remove(0));
        }

        // Multiple runs are stitched together under a ChunkedLayout.
        let row_count = child_layouts.iter().map(|child| child.row_count()).sum();
        Ok(ChunkedLayout::new(
            row_count,
            dtype,
            OwnedLayoutChildren::layout_children(child_layouts),
        )
        .into_layout())
    }

    /// Sum of the bytes currently buffered by all three child strategies.
    fn buffered_bytes(&self) -> u64 {
        self.codes.buffered_bytes() + self.values.buffered_bytes() + self.fallback.buffered_bytes()
    }
}
247
/// Item of the intermediate stream produced by [`dict_encode_stream`]: a run of `Codes` chunks
/// that share one dictionary, terminated by the single `Values` chunk holding that dictionary.
enum DictionaryChunk {
    /// Dictionary codes for (part of) one input chunk.
    Codes {
        seq_id: SequenceId,
        codes: ArrayRef,
        /// Primitive type of the codes; carried explicitly so the codes stream dtype can be
        /// established before the dictionary values arrive.
        codes_ptype: PType,
    },
    /// Dictionary values referenced by the preceding run of `Codes` chunks.
    Values((SequenceId, ArrayRef)),
}

/// Fallible, boxed stream of [`DictionaryChunk`]s.
type DictionaryStream = BoxStream<'static, VortexResult<DictionaryChunk>>;
258
/// Converts a stream of array chunks into a [`DictionaryStream`]: codes chunks interleaved with
/// a values chunk whenever `constraints` are hit, plus a final values chunk that flushes the
/// still-open dictionary when the input ends.
fn dict_encode_stream(
    input: SendableSequentialStream,
    constraints: DictConstraints,
) -> DictionaryStream {
    Box::pin(try_stream! {
        let mut state = DictStreamState {
            encoder: None,
            constraints,
        };

        // Peeking lets us detect the last input chunk so the final dictionary can be drained in
        // the same loop iteration, rather than after the loop.
        let input = input.peekable();
        pin_mut!(input);

        while let Some(item) = input.next().await {
            let (sequence_id, chunk) = item?;

            // labeler potentially creates sub sequences, we must
            // create it on both arms to avoid having a SequencePointer
            // between await points
            match input.as_mut().peek().await {
                Some(_) => {
                    let mut labeler = DictChunkLabeler::new(sequence_id);
                    let chunks = state.encode(&mut labeler, chunk)?;
                    // Explicitly drop the labeler before yielding (an await point).
                    drop(labeler);
                    for dict_chunk in chunks {
                        yield dict_chunk;
                    }
                }
                None => {
                    // this is the last element, encode and drain chunks
                    let mut labeler = DictChunkLabeler::new(sequence_id);
                    let encoded = state.encode(&mut labeler, chunk)?;
                    let drained = state.drain_values(&mut labeler);
                    drop(labeler);
                    for dict_chunk in encoded.into_iter().chain(drained.into_iter()) {
                        yield dict_chunk;
                    }
                }
            }
        }
    })
}
301
/// Mutable state threaded through [`dict_encode_stream`]: the currently-open dictionary encoder
/// (if any) and the constraints every dictionary must respect.
struct DictStreamState {
    /// `None` before the first chunk and after a dictionary has been flushed.
    encoder: Option<Box<dyn DictEncoder>>,
    constraints: DictConstraints,
}
306
307impl DictStreamState {
308    fn encode(
309        &mut self,
310        labeler: &mut DictChunkLabeler,
311        chunk: ArrayRef,
312    ) -> VortexResult<Vec<DictionaryChunk>> {
313        let mut res = Vec::new();
314        let mut to_be_encoded = Some(chunk);
315        while let Some(remaining) = to_be_encoded.take() {
316            match self.encoder.take() {
317                None => match start_encoding(&self.constraints, &remaining)? {
318                    EncodingState::Continue((encoder, encoded)) => {
319                        let ptype = encoder.codes_ptype();
320                        res.push(labeler.codes(encoded, ptype));
321                        self.encoder = Some(encoder);
322                    }
323                    EncodingState::Done((values, encoded, unencoded)) => {
324                        // Encoder was created and consumed within start_encoding
325                        let ptype = PType::try_from(encoded.dtype())
326                            .vortex_expect("codes should be primitive");
327                        res.push(labeler.codes(encoded, ptype));
328                        res.push(labeler.values(values));
329                        to_be_encoded = Some(unencoded);
330                    }
331                },
332                Some(encoder) => {
333                    let ptype = encoder.codes_ptype();
334                    match encode_chunk(encoder, &remaining)? {
335                        EncodingState::Continue((encoder, encoded)) => {
336                            res.push(labeler.codes(encoded, ptype));
337                            self.encoder = Some(encoder);
338                        }
339                        EncodingState::Done((values, encoded, unencoded)) => {
340                            res.push(labeler.codes(encoded, ptype));
341                            res.push(labeler.values(values));
342                            to_be_encoded = Some(unencoded);
343                        }
344                    }
345                }
346            }
347        }
348        Ok(res)
349    }
350
351    fn drain_values(&mut self, labeler: &mut DictChunkLabeler) -> Vec<DictionaryChunk> {
352        match self.encoder.as_mut() {
353            None => Vec::new(),
354            Some(encoder) => vec![labeler.values(encoder.reset())],
355        }
356    }
357}
358
/// Assigns [`SequenceId`]s to the codes/values chunks derived from a single input chunk by
/// descending into a sub-sequence of that chunk's id.
struct DictChunkLabeler {
    sequence_pointer: SequencePointer,
}
362
363impl DictChunkLabeler {
364    fn new(starting_id: SequenceId) -> Self {
365        let sequence_pointer = starting_id.descend();
366        Self { sequence_pointer }
367    }
368
369    fn codes(&mut self, chunk: ArrayRef, ptype: PType) -> DictionaryChunk {
370        DictionaryChunk::Codes {
371            seq_id: self.sequence_pointer.advance(),
372            codes: chunk,
373            codes_ptype: ptype,
374        }
375    }
376
377    fn values(&mut self, chunk: ArrayRef) -> DictionaryChunk {
378        DictionaryChunk::Values((self.sequence_pointer.advance(), chunk))
379    }
380}
381
/// A sequence-tagged array chunk, the item type carried by the sequential streams here.
type SequencedChunk = VortexResult<(SequenceId, ArrayRef)>;

/// Regroups a flat [`DictionaryStream`] into per-dictionary runs: each yielded item is a stream
/// of codes chunks plus a future that resolves to that run's dictionary values.
struct DictionaryTransformer {
    input: DictionaryStream,
    /// Sender feeding the codes stream of the currently-open run, if any.
    active_codes_tx: Option<kanal::AsyncSender<SequencedChunk>>,
    /// One-shot sender that completes the currently-open run with its dictionary values.
    active_values_tx: Option<oneshot::Sender<SequencedChunk>>,
    /// An in-flight (back-pressured) send into the codes channel; it must complete before more
    /// input is pulled.
    pending_send: Option<BoxFuture<'static, Result<(), kanal::SendError>>>,
}
390
impl DictionaryTransformer {
    /// Wraps `input` with no run currently open and no send in flight.
    fn new(input: DictionaryStream) -> Self {
        Self {
            input,
            active_codes_tx: None,
            active_values_tx: None,
            pending_send: None,
        }
    }
}
401
impl Stream for DictionaryTransformer {
    /// One dictionary run: a stream of codes chunks and a future that resolves to the run's
    /// values chunk once the upstream emits it.
    type Item = (SendableSequentialStream, BoxFuture<'static, SequencedChunk>);

    /// State machine: a `Codes` chunk either opens a new run (yielding the pair immediately and
    /// queueing the first codes send) or feeds the open run's channel; a `Values` chunk closes
    /// the open run. Channel sends are back-pressured through `pending_send`, which must resolve
    /// before more input is pulled.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // First, try to complete any pending send
            if let Some(mut send_fut) = self.pending_send.take() {
                match send_fut.poll_unpin(cx) {
                    Poll::Ready(Ok(())) => {
                        // Send completed, continue processing
                    }
                    Poll::Ready(Err(_)) => {
                        // Receiver dropped, close this group
                        self.active_codes_tx = None;
                        if let Some(values_tx) = self.active_values_tx.take() {
                            drop(values_tx.send(Err(vortex_err!("values receiver dropped"))));
                        }
                    }
                    Poll::Pending => {
                        // Still pending, save it and return
                        self.pending_send = Some(send_fut);
                        return Poll::Pending;
                    }
                }
            }

            match self.input.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(DictionaryChunk::Codes {
                    seq_id,
                    codes,
                    codes_ptype,
                }))) => {
                    if self.active_codes_tx.is_none() {
                        // Start a new group
                        let (codes_tx, codes_rx) = kanal::bounded_async::<SequencedChunk>(1);
                        let (values_tx, values_rx) = oneshot::channel();

                        self.active_codes_tx = Some(codes_tx.clone());
                        self.active_values_tx = Some(values_tx);

                        // Use passed codes_ptype instead of getting from array
                        let codes_dtype = DType::Primitive(codes_ptype, Nullability::NonNullable);

                        // Send first codes.
                        self.pending_send =
                            Some(Box::pin(
                                async move { codes_tx.send(Ok((seq_id, codes))).await },
                            ));

                        // Create output streams.
                        let codes_stream = SequentialStreamAdapter::new(
                            codes_dtype,
                            codes_rx.into_stream().boxed(),
                        )
                        .sendable();

                        let values_future = async move {
                            values_rx
                                .await
                                .map_err(|e| vortex_err!("values sender dropped: {}", e))
                                .flatten()
                        }
                        .boxed();

                        return Poll::Ready(Some((codes_stream, values_future)));
                    }

                    // Continue streaming codes to existing group
                    if let Some(tx) = &self.active_codes_tx {
                        let tx = tx.clone();
                        self.pending_send =
                            Some(Box::pin(async move { tx.send(Ok((seq_id, codes))).await }));
                    }
                }
                Poll::Ready(Some(Ok(DictionaryChunk::Values(values)))) => {
                    // Complete the current group
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Ok(values)));
                    }
                    self.active_codes_tx = None; // Close codes stream
                }
                Poll::Ready(Some(Err(e))) => {
                    // Send error to active channels if any
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Err(e)));
                    }
                    self.active_codes_tx = None;
                    // NOTE(review): dropping active_codes_tx ends the codes stream cleanly, so an
                    // open run only observes the failure through its values future — confirm
                    // downstream treats that as fatal.
                    // And terminate the stream
                    return Poll::Ready(None);
                }
                Poll::Ready(None) => {
                    // Handle any incomplete group
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Err(vortex_err!("Incomplete dictionary group"))));
                    }
                    self.active_codes_tx = None;
                    return Poll::Ready(None);
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
505
506async fn peek_first_chunk(
507    mut stream: BoxStream<'static, SequencedChunk>,
508) -> VortexResult<(BoxStream<'static, SequencedChunk>, Option<ArrayRef>)> {
509    match stream.next().await {
510        None => Ok((stream.boxed(), None)),
511        Some(Err(e)) => Err(e),
512        Some(Ok((sequence_id, chunk))) => {
513            let chunk_clone = chunk.clone();
514            let reconstructed_stream =
515                once(async move { Ok((sequence_id, chunk_clone)) }).chain(stream);
516            Ok((reconstructed_stream.boxed(), Some(chunk)))
517        }
518    }
519}
520
/// Returns `true` when `dtype` can be stored in a dictionary layout.
///
/// Only primitive, UTF-8, and binary columns are dictionary-encodable here; any other dtype is
/// routed to the fallback strategy by [`DictStrategy`].
pub fn dict_layout_supported(dtype: &DType) -> bool {
    matches!(
        dtype,
        DType::Primitive(..) | DType::Utf8(_) | DType::Binary(_)
    )
}
527
/// Serialized (protobuf) metadata for a dict layout: records the primitive type of the codes
/// child so readers can interpret the codes without inspecting the data.
#[derive(prost::Message)]
pub struct DictLayoutMetadata {
    #[prost(enumeration = "PType", tag = "1")]
    // i32 is required for proto, use the generated getter to read this field.
    codes_ptype: i32,
}

impl DictLayoutMetadata {
    /// Builds metadata carrying `codes_ptype`, stored via the prost-generated setter so the
    /// enum-to-i32 mapping stays consistent with the generated getter used for decoding.
    pub fn new(codes_ptype: PType) -> Self {
        let mut metadata = Self::default();
        metadata.set_codes_ptype(codes_ptype);
        metadata
    }
}
542
/// Outcome of feeding one chunk to a dictionary encoder.
enum EncodingState {
    /// The whole chunk fit: the still-open encoder paired with the produced codes.
    Continue((Box<dyn DictEncoder>, ArrayRef)),
    /// Constraints were hit mid-chunk: (values, encoded, unencoded) — the flushed dictionary
    /// values, the codes for the encoded prefix, and the unencoded tail needing a new dictionary.
    Done((ArrayRef, ArrayRef, ArrayRef)),
}
548
549fn start_encoding(constraints: &DictConstraints, chunk: &ArrayRef) -> VortexResult<EncodingState> {
550    let encoder = dict_encoder(chunk, constraints);
551    encode_chunk(encoder, chunk)
552}
553
554fn encode_chunk(
555    mut encoder: Box<dyn DictEncoder>,
556    chunk: &ArrayRef,
557) -> VortexResult<EncodingState> {
558    let encoded = encoder.encode(chunk);
559    match remainder(chunk, encoded.len())? {
560        None => Ok(EncodingState::Continue((encoder, encoded))),
561        Some(unencoded) => Ok(EncodingState::Done((encoder.reset(), encoded, unencoded))),
562    }
563}
564
565fn remainder(array: &ArrayRef, encoded_len: usize) -> VortexResult<Option<ArrayRef>> {
566    if encoded_len < array.len() {
567        Ok(Some(array.slice(encoded_len..array.len())?))
568    } else {
569        Ok(None)
570    }
571}
572
#[cfg(test)]
mod tests {
    use futures::StreamExt;
    use vortex_array::IntoArray;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::builders::dict::DictConstraints;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability::NonNullable;
    use vortex_array::dtype::PType;

    use super::DictionaryTransformer;
    use super::dict_encode_stream;
    use crate::sequence::SequenceId;
    use crate::sequence::SequentialStream;
    use crate::sequence::SequentialStreamAdapter;
    use crate::sequence::SequentialStreamExt;

    /// Regression test for a bug where the codes stream dtype was hardcoded to U16 instead of
    /// using the actual codes dtype from the array. When `max_len <= 255`, the dict encoder
    /// produces U8 codes, but the stream was incorrectly typed as U16, causing a dtype mismatch
    /// assertion failure in [`SequentialStreamAdapter`].
    #[tokio::test]
    async fn test_dict_transformer_uses_u8_for_small_dictionaries() {
        // Use max_len = 100 to force U8 codes (since 100 <= 255).
        let constraints = DictConstraints {
            max_bytes: 1024 * 1024,
            max_len: 100,
        };

        // Create a simple string array with a few unique values.
        let arr = VarBinArray::from(vec!["hello", "world", "hello", "world"]).into_array();

        // Wrap into a sequential stream.
        let mut pointer = SequenceId::root();
        let input_stream = SequentialStreamAdapter::new(
            arr.dtype().clone(),
            futures::stream::once(async move { Ok((pointer.advance(), arr)) }),
        )
        .sendable();

        // Encode into dict chunks.
        let dict_stream = dict_encode_stream(input_stream, constraints);

        // Transform into codes/values streams.
        let mut transformer = DictionaryTransformer::new(dict_stream);

        // Get the first (and only) run.
        let (codes_stream, _values_fut) = transformer
            .next()
            .await
            .expect("expected at least one dictionary run");

        // The key assertion: codes stream dtype should be U8, not U16.
        assert_eq!(
            codes_stream.dtype(),
            &DType::Primitive(PType::U8, NonNullable),
            "codes stream should use U8 dtype for small dictionaries, not U16"
        );
    }

    /// Test that the codes stream uses U16 dtype when the dictionary has more than 255 entries.
    #[tokio::test]
    async fn test_dict_transformer_uses_u16_for_large_dictionaries() {
        // Use max_len = 1000 to allow U16 codes (since 1000 > 255).
        let constraints = DictConstraints {
            max_bytes: 1024 * 1024,
            max_len: 1000,
        };

        // Create an array with more than 255 distinct values to force U16 codes.
        let values: Vec<String> = (0..300).map(|i| format!("value_{i}")).collect();
        let arr =
            VarBinArray::from(values.iter().map(|s| s.as_str()).collect::<Vec<_>>()).into_array();

        // Wrap into a sequential stream.
        let mut pointer = SequenceId::root();
        let input_stream = SequentialStreamAdapter::new(
            arr.dtype().clone(),
            futures::stream::once(async move { Ok((pointer.advance(), arr)) }),
        )
        .sendable();

        // Encode into dict chunks.
        let dict_stream = dict_encode_stream(input_stream, constraints);

        // Transform into codes/values streams.
        let mut transformer = DictionaryTransformer::new(dict_stream);

        // Get the first (and only) run.
        let (codes_stream, _values_fut) = transformer
            .next()
            .await
            .expect("expected at least one dictionary run");

        // Codes stream dtype should be U16 since we have more than 255 distinct values.
        assert_eq!(
            codes_stream.dtype(),
            &DType::Primitive(PType::U16, NonNullable),
            "codes stream should use U16 dtype for dictionaries with >255 entries"
        );
    }
}