use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;

use async_stream::stream;
use async_stream::try_stream;
use async_trait::async_trait;
use futures::FutureExt;
use futures::Stream;
use futures::StreamExt;
use futures::TryStreamExt;
use futures::future::BoxFuture;
use futures::pin_mut;
use futures::stream::BoxStream;
use futures::stream::once;
use futures::try_join;
use vortex_array::Array;
use vortex_array::ArrayContext;
use vortex_array::ArrayRef;
use vortex_array::arrays::DictVTable;
use vortex_array::builders::dict::DictConstraints;
use vortex_array::builders::dict::DictEncoder;
use vortex_array::builders::dict::dict_encoder;
use vortex_btrblocks::BtrBlocksCompressor;
use vortex_dtype::DType;
use vortex_dtype::Nullability;
use vortex_dtype::PType;
use vortex_error::VortexError;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_err;
use vortex_io::kanal_ext::KanalExt;
use vortex_io::runtime::Handle;

use crate::IntoLayout;
use crate::LayoutRef;
use crate::LayoutStrategy;
use crate::OwnedLayoutChildren;
use crate::layouts::chunked::ChunkedLayout;
use crate::layouts::dict::DictLayout;
use crate::segments::SegmentSinkRef;
use crate::sequence::SendableSequentialStream;
use crate::sequence::SequenceId;
use crate::sequence::SequencePointer;
use crate::sequence::SequentialStream;
use crate::sequence::SequentialStreamAdapter;
use crate::sequence::SequentialStreamExt;
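/// Limits on the size of each dictionary built while writing a [`DictLayout`]: `max_bytes`
/// caps the in-memory size of the dictionary values and `max_len` caps the number of distinct
/// values. When a limit is reached, the writer flushes the current dictionary and starts a
/// new one for the remaining data.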
#[derive(Clone)]
pub struct DictLayoutConstraints {
    pub max_bytes: usize,
    pub max_len: u16,
}

impl From<DictLayoutConstraints> for DictConstraints {
    fn from(value: DictLayoutConstraints) -> Self {
        DictConstraints {
            max_bytes: value.max_bytes,
            max_len: value.max_len as usize,
        }
    }
}

impl Default for DictLayoutConstraints {
    fn default() -> Self {
        Self {
            max_bytes: 1024 * 1024,
            max_len: u16::MAX,
        }
    }
}

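/// Options controlling how the dictionary layout strategy writes its dictionaries.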
#[derive(Clone, Default)]
pub struct DictLayoutOptions {
    pub constraints: DictLayoutConstraints,
}

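/// A [`LayoutStrategy`] that dictionary-encodes the incoming stream, writing dictionary codes
/// with the `codes` strategy and dictionary values with the `values` strategy. Streams whose
/// dtype cannot be dictionary encoded, or whose first chunk the BtrBlocks compressor would not
/// dictionary-encode, are delegated to the `fallback` strategy instead.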
#[derive(Clone)]
pub struct DictStrategy {
    codes: Arc<dyn LayoutStrategy>,
    values: Arc<dyn LayoutStrategy>,
    fallback: Arc<dyn LayoutStrategy>,
    options: DictLayoutOptions,
}

impl DictStrategy {
    pub fn new<Codes: LayoutStrategy, Values: LayoutStrategy, Fallback: LayoutStrategy>(
        codes: Codes,
        values: Values,
        fallback: Fallback,
        options: DictLayoutOptions,
    ) -> Self {
        Self {
            codes: Arc::new(codes),
            values: Arc::new(values),
            fallback: Arc::new(fallback),
            options,
        }
    }
}

#[async_trait]
impl LayoutStrategy for DictStrategy {
    async fn write_stream(
        &self,
        ctx: ArrayContext,
        segment_sink: SegmentSinkRef,
        stream: SendableSequentialStream,
        mut eof: SequencePointer,
        handle: Handle,
    ) -> VortexResult<LayoutRef> {
        // Dictionary encoding only applies to primitive, UTF-8 and binary streams.
        if !dict_layout_supported(stream.dtype()) {
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, handle)
                .await;
        }

        let options = self.options.clone();
        let dtype = stream.dtype().clone();

        // Peek at the first chunk and trial-compress it. If the compressor would not choose
        // dictionary encoding for it (or the stream is empty), delegate to the fallback.
        let (stream, first_chunk) = peek_first_chunk(stream).await?;
        let stream = SequentialStreamAdapter::new(dtype.clone(), stream).sendable();

        let should_fallback = match first_chunk {
            None => true,
            Some(chunk) => {
                let compressed = BtrBlocksCompressor::default().compress(&chunk)?;
                !compressed.is::<DictVTable>()
            }
        };
        if should_fallback {
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, handle)
                .await;
        }

        let dict_stream = dict_encode_stream(stream, options.constraints.into());

        let runs = DictionaryTransformer::new(dict_stream);

        let dtype2 = dtype.clone();
        // For each dictionary run, write the codes stream and the values array as two child
        // layouts, spawning both writes so they can make progress concurrently.
        let child_layouts = stream! {
            pin_mut!(runs);

            while let Some((codes_stream, values_fut)) = runs.next().await {
                let codes = self.codes.clone();
                let codes_eof = eof.split_off();
                let ctx2 = ctx.clone();
                let segment_sink2 = segment_sink.clone();
                let codes_fut = handle.spawn_nested(move |h| async move {
                    codes.write_stream(
                        ctx2,
                        segment_sink2,
                        codes_stream.sendable(),
                        codes_eof,
                        h,
                    ).await
                });

                let values = self.values.clone();
                let values_eof = eof.split_off();
                let ctx2 = ctx.clone();
                let segment_sink2 = segment_sink.clone();
                let dtype2 = dtype2.clone();
                let values_layout = handle.spawn_nested(move |h| async move {
                    values.write_stream(
                        ctx2,
                        segment_sink2,
                        SequentialStreamAdapter::new(dtype2, once(values_fut)).sendable(),
                        values_eof,
                        h,
                    ).await
                });

                yield async move {
                    try_join!(codes_fut, values_layout)
                }.boxed();
            }
        };

        let mut child_layouts = child_layouts
            .buffered(usize::MAX)
            .map(|result| {
                let (codes_layout, values_layout) = result?;
                Ok::<_, VortexError>(DictLayout::new(values_layout, codes_layout).into_layout())
            })
            .try_collect::<Vec<_>>()
            .await?;

        // A single dictionary run doesn't need a chunked wrapper.
        if child_layouts.len() == 1 {
            return Ok(child_layouts.remove(0));
        }

        let row_count = child_layouts.iter().map(|child| child.row_count()).sum();
        Ok(ChunkedLayout::new(
            row_count,
            dtype,
            OwnedLayoutChildren::layout_children(child_layouts),
        )
        .into_layout())
    }

    fn buffered_bytes(&self) -> u64 {
        self.codes.buffered_bytes() + self.values.buffered_bytes() + self.fallback.buffered_bytes()
    }
}

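/// A single item produced by [`dict_encode_stream`]: either a chunk of dictionary codes for the
/// current run, or the dictionary values that terminate that run.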
enum DictionaryChunk {
    Codes {
        seq_id: SequenceId,
        codes: ArrayRef,
        codes_ptype: PType,
    },
    Values((SequenceId, ArrayRef)),
}

type DictionaryStream = BoxStream<'static, VortexResult<DictionaryChunk>>;

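/// Dictionary-encode a sequential stream of chunks. Codes are emitted as they are produced;
/// whenever a dictionary reaches its constraints (or the input ends) the accumulated values
/// are emitted and a fresh dictionary is started for the remaining data.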
fn dict_encode_stream(
    input: SendableSequentialStream,
    constraints: DictConstraints,
) -> DictionaryStream {
    Box::pin(try_stream! {
        let mut state = DictStreamState {
            encoder: None,
            constraints,
        };

        let input = input.peekable();
        pin_mut!(input);

        while let Some(item) = input.next().await {
            let (sequence_id, chunk) = item?;

            // Peek ahead so we know whether this is the final chunk: the last chunk must also
            // flush whatever dictionary values are still buffered in the encoder.
            match input.as_mut().peek().await {
                Some(_) => {
                    let mut labeler = DictChunkLabeler::new(sequence_id);
                    let chunks = state.encode(&mut labeler, chunk)?;
                    drop(labeler);
                    for dict_chunk in chunks {
                        yield dict_chunk;
                    }
                }
                None => {
                    // Final chunk: encode it, then drain the remaining dictionary values.
                    let mut labeler = DictChunkLabeler::new(sequence_id);
                    let encoded = state.encode(&mut labeler, chunk)?;
                    let drained = state.drain_values(&mut labeler);
                    drop(labeler);
                    for dict_chunk in encoded.into_iter().chain(drained.into_iter()) {
                        yield dict_chunk;
                    }
                }
            }
        }
    })
}

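/// Encoder state carried across input chunks: the currently open dictionary encoder, if any,
/// and the constraints used to start new dictionaries.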
struct DictStreamState {
    encoder: Option<Box<dyn DictEncoder>>,
    constraints: DictConstraints,
}

impl DictStreamState {
    fn encode(
        &mut self,
        labeler: &mut DictChunkLabeler,
        chunk: ArrayRef,
    ) -> VortexResult<Vec<DictionaryChunk>> {
        let mut res = Vec::new();
        let mut to_be_encoded = Some(chunk);
        // Keep encoding until the whole chunk is consumed. Whenever the current dictionary
        // fills up, emit its values and start a new dictionary for the unencoded remainder.
        while let Some(remaining) = to_be_encoded.take() {
            match self.encoder.take() {
                None => match start_encoding(&self.constraints, &remaining)? {
                    EncodingState::Continue((encoder, encoded)) => {
                        let ptype = encoder.codes_ptype();
                        res.push(labeler.codes(encoded, ptype));
                        self.encoder = Some(encoder);
                    }
                    EncodingState::Done((values, encoded, unencoded)) => {
                        let ptype = PType::try_from(encoded.dtype())
                            .vortex_expect("codes should be primitive");
                        res.push(labeler.codes(encoded, ptype));
                        res.push(labeler.values(values));
                        to_be_encoded = Some(unencoded);
                    }
                },
                Some(encoder) => {
                    let ptype = encoder.codes_ptype();
                    match encode_chunk(encoder, &remaining)? {
                        EncodingState::Continue((encoder, encoded)) => {
                            res.push(labeler.codes(encoded, ptype));
                            self.encoder = Some(encoder);
                        }
                        EncodingState::Done((values, encoded, unencoded)) => {
                            res.push(labeler.codes(encoded, ptype));
                            res.push(labeler.values(values));
                            to_be_encoded = Some(unencoded);
                        }
                    }
                }
            }
        }
        Ok(res)
    }

    fn drain_values(&mut self, labeler: &mut DictChunkLabeler) -> Vec<DictionaryChunk> {
        match self.encoder.as_mut() {
            None => Vec::new(),
            Some(encoder) => vec![labeler.values(encoder.reset())],
        }
    }
}

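/// Assigns sequence ids to the dictionary chunks derived from a single input chunk by
/// descending below that chunk's [`SequenceId`] and advancing once per emitted chunk.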
struct DictChunkLabeler {
    sequence_pointer: SequencePointer,
}

impl DictChunkLabeler {
    fn new(starting_id: SequenceId) -> Self {
        let sequence_pointer = starting_id.descend();
        Self { sequence_pointer }
    }

    fn codes(&mut self, chunk: ArrayRef, ptype: PType) -> DictionaryChunk {
        DictionaryChunk::Codes {
            seq_id: self.sequence_pointer.advance(),
            codes: chunk,
            codes_ptype: ptype,
        }
    }

    fn values(&mut self, chunk: ArrayRef) -> DictionaryChunk {
        DictionaryChunk::Values((self.sequence_pointer.advance(), chunk))
    }
}

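/// A fallible chunk paired with the sequence id it should be written under.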
type SequencedChunk = VortexResult<(SequenceId, ArrayRef)>;

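/// Adapts a [`DictionaryStream`] into one item per dictionary run: a stream of code chunks plus
/// a future resolving to that run's dictionary values. Code chunks are forwarded into a bounded
/// channel as the inner stream is polled, and the values future completes when the run's
/// `Values` chunk (or an error) arrives.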
struct DictionaryTransformer {
    input: DictionaryStream,
    active_codes_tx: Option<kanal::AsyncSender<SequencedChunk>>,
    active_values_tx: Option<oneshot::Sender<SequencedChunk>>,
    pending_send: Option<BoxFuture<'static, Result<(), kanal::SendError>>>,
}

impl DictionaryTransformer {
    fn new(input: DictionaryStream) -> Self {
        Self {
            input,
            active_codes_tx: None,
            active_values_tx: None,
            pending_send: None,
        }
    }
}

impl Stream for DictionaryTransformer {
    type Item = (SendableSequentialStream, BoxFuture<'static, SequencedChunk>);

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // Drive any in-flight send of a codes chunk before pulling more from the input.
            if let Some(mut send_fut) = self.pending_send.take() {
                match send_fut.poll_unpin(cx) {
                    Poll::Ready(Ok(())) => {
                        // Chunk delivered; continue polling the input.
                    }
                    Poll::Ready(Err(_)) => {
                        // The codes receiver was dropped; abandon the current run and fail its
                        // values future so the consumer sees an error.
                        self.active_codes_tx = None;
                        if let Some(values_tx) = self.active_values_tx.take() {
                            drop(values_tx.send(Err(vortex_err!("values receiver dropped"))));
                        }
                    }
                    Poll::Pending => {
                        self.pending_send = Some(send_fut);
                        return Poll::Pending;
                    }
                }
            }

            match self.input.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(DictionaryChunk::Codes {
                    seq_id,
                    codes,
                    codes_ptype,
                }))) => {
                    // First codes chunk of a new run: open a codes channel and a oneshot for
                    // the run's values, and hand both ends to the caller.
                    if self.active_codes_tx.is_none() {
                        let (codes_tx, codes_rx) = kanal::bounded_async::<SequencedChunk>(1);
                        let (values_tx, values_rx) = oneshot::channel();

                        self.active_codes_tx = Some(codes_tx.clone());
                        self.active_values_tx = Some(values_tx);

                        let codes_dtype = DType::Primitive(codes_ptype, Nullability::NonNullable);

                        self.pending_send =
                            Some(Box::pin(
                                async move { codes_tx.send(Ok((seq_id, codes))).await },
                            ));

                        let codes_stream = SequentialStreamAdapter::new(
                            codes_dtype,
                            codes_rx.into_stream().boxed(),
                        )
                        .sendable();

                        let values_future = async move {
                            values_rx
                                .await
                                .map_err(|e| vortex_err!("values sender dropped: {}", e))
                                .flatten()
                        }
                        .boxed();

                        return Poll::Ready(Some((codes_stream, values_future)));
                    }

                    // Subsequent codes chunk for the run already in progress: queue it for the
                    // existing codes channel.
                    if let Some(tx) = &self.active_codes_tx {
                        let tx = tx.clone();
                        self.pending_send =
                            Some(Box::pin(async move { tx.send(Ok((seq_id, codes))).await }));
                    }
                }
                Poll::Ready(Some(Ok(DictionaryChunk::Values(values)))) => {
                    // The run is complete: resolve the values future and close the codes
                    // channel by dropping the sender.
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Ok(values)));
                    }
                    self.active_codes_tx = None;
                }
                Poll::Ready(Some(Err(e))) => {
                    // Forward the error to the in-flight run, then terminate the stream.
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Err(e)));
                    }
                    self.active_codes_tx = None;
                    return Poll::Ready(None);
                }
                Poll::Ready(None) => {
                    // If the input ended mid-run, fail the dangling values future.
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Err(vortex_err!("Incomplete dictionary group"))));
                    }
                    self.active_codes_tx = None;
                    return Poll::Ready(None);
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}

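/// Pulls the first chunk off a stream and returns it alongside a stream equivalent to the
/// original, with the peeked chunk pushed back onto the front.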
async fn peek_first_chunk(
    mut stream: BoxStream<'static, SequencedChunk>,
) -> VortexResult<(BoxStream<'static, SequencedChunk>, Option<ArrayRef>)> {
    match stream.next().await {
        None => Ok((stream.boxed(), None)),
        Some(Err(e)) => Err(e),
        Some(Ok((sequence_id, chunk))) => {
            let chunk_clone = chunk.clone();
            let reconstructed_stream =
                once(async move { Ok((sequence_id, chunk_clone)) }).chain(stream);
            Ok((reconstructed_stream.boxed(), Some(chunk)))
        }
    }
}

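/// Returns true if `dtype` can be dictionary encoded by this layout strategy.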
pub fn dict_layout_supported(dtype: &DType) -> bool {
    matches!(
        dtype,
        DType::Primitive(..) | DType::Utf8(_) | DType::Binary(_)
    )
}

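/// Protobuf metadata for a dictionary layout, recording the primitive type used for the codes
/// child.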
#[derive(prost::Message)]
pub struct DictLayoutMetadata {
    #[prost(enumeration = "PType", tag = "1")]
    codes_ptype: i32,
}

impl DictLayoutMetadata {
    pub fn new(codes_ptype: PType) -> Self {
        let mut metadata = Self::default();
        metadata.set_codes_ptype(codes_ptype);
        metadata
    }
}

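/// Result of feeding a chunk to a dictionary encoder: either the encoder can keep accepting
/// data (`Continue`, with the codes produced so far), or the dictionary is full (`Done`, with
/// the dictionary values, the codes that were produced, and the still-unencoded remainder).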
enum EncodingState {
    Continue((Box<dyn DictEncoder>, ArrayRef)),
    Done((ArrayRef, ArrayRef, ArrayRef)),
}

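/// Starts a fresh dictionary encoder for `chunk` and encodes as much of it as fits.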
fn start_encoding(constraints: &DictConstraints, chunk: &dyn Array) -> VortexResult<EncodingState> {
    let encoder = dict_encoder(chunk, constraints);
    encode_chunk(encoder, chunk)
}

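/// Encodes `chunk` with an existing encoder, detecting whether the encoder stopped early
/// because the dictionary reached its constraints.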
fn encode_chunk(
    mut encoder: Box<dyn DictEncoder>,
    chunk: &dyn Array,
) -> VortexResult<EncodingState> {
    let encoded = encoder.encode(chunk);
    match remainder(chunk, encoded.len())? {
        None => Ok(EncodingState::Continue((encoder, encoded))),
        Some(unencoded) => Ok(EncodingState::Done((encoder.reset(), encoded, unencoded))),
    }
}

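/// Returns the slice of `array` that was not consumed by the encoder, if any.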
fn remainder(array: &dyn Array, encoded_len: usize) -> VortexResult<Option<ArrayRef>> {
    if encoded_len < array.len() {
        Ok(Some(array.slice(encoded_len..array.len())?))
    } else {
        Ok(None)
    }
}

#[cfg(test)]
mod tests {
    use futures::StreamExt;
    use vortex_array::IntoArray;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::builders::dict::DictConstraints;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType;

    use super::DictionaryTransformer;
    use super::dict_encode_stream;
    use crate::sequence::SequenceId;
    use crate::sequence::SequentialStream;
    use crate::sequence::SequentialStreamAdapter;
    use crate::sequence::SequentialStreamExt;

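    // A two-value dictionary fits well within the constraints, so the codes should use the
    // narrowest width.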
    #[tokio::test]
    async fn test_dict_transformer_uses_u8_for_small_dictionaries() {
        let constraints = DictConstraints {
            max_bytes: 1024 * 1024,
            max_len: 100,
        };

        let arr = VarBinArray::from(vec!["hello", "world", "hello", "world"]).into_array();

        let mut pointer = SequenceId::root();
        let input_stream = SequentialStreamAdapter::new(
            arr.dtype().clone(),
            futures::stream::once(async move { Ok((pointer.advance(), arr)) }),
        )
        .sendable();

        let dict_stream = dict_encode_stream(input_stream, constraints);

        let mut transformer = DictionaryTransformer::new(dict_stream);

        let (codes_stream, _values_fut) = transformer
            .next()
            .await
            .expect("expected at least one dictionary run");

        assert_eq!(
            codes_stream.dtype(),
            &DType::Primitive(PType::U8, NonNullable),
            "codes stream should use U8 dtype for small dictionaries, not U16"
        );
    }

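    // 300 distinct values exceed the U8 code range, so the codes must widen to U16.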
    #[tokio::test]
    async fn test_dict_transformer_uses_u16_for_large_dictionaries() {
        let constraints = DictConstraints {
            max_bytes: 1024 * 1024,
            max_len: 1000,
        };

        let values: Vec<String> = (0..300).map(|i| format!("value_{i}")).collect();
        let arr =
            VarBinArray::from(values.iter().map(|s| s.as_str()).collect::<Vec<_>>()).into_array();

        let mut pointer = SequenceId::root();
        let input_stream = SequentialStreamAdapter::new(
            arr.dtype().clone(),
            futures::stream::once(async move { Ok((pointer.advance(), arr)) }),
        )
        .sendable();

        let dict_stream = dict_encode_stream(input_stream, constraints);

        let mut transformer = DictionaryTransformer::new(dict_stream);

        let (codes_stream, _values_fut) = transformer
            .next()
            .await
            .expect("expected at least one dictionary run");

        assert_eq!(
            codes_stream.dtype(),
            &DType::Primitive(PType::U16, NonNullable),
            "codes stream should use U16 dtype for dictionaries with >255 entries"
        );
    }
}