1use std::pin::Pin;
5use std::sync::Arc;
6use std::task::Context;
7use std::task::Poll;
8
9use async_stream::stream;
10use async_stream::try_stream;
11use async_trait::async_trait;
12use futures::FutureExt;
13use futures::Stream;
14use futures::StreamExt;
15use futures::TryStreamExt;
16use futures::future::BoxFuture;
17use futures::pin_mut;
18use futures::stream::BoxStream;
19use futures::stream::once;
20use futures::try_join;
21use vortex_array::ArrayContext;
22use vortex_array::ArrayRef;
23use vortex_array::VortexSessionExecute;
24use vortex_array::arrays::Dict;
25use vortex_array::builders::dict::DictConstraints;
26use vortex_array::builders::dict::DictEncoder;
27use vortex_array::builders::dict::dict_encoder;
28use vortex_array::dtype::DType;
29use vortex_array::dtype::Nullability;
30use vortex_array::dtype::PType;
31use vortex_btrblocks::BtrBlocksCompressor;
32use vortex_error::VortexError;
33use vortex_error::VortexExpect;
34use vortex_error::VortexResult;
35use vortex_error::vortex_err;
36use vortex_io::kanal_ext::KanalExt;
37use vortex_io::session::RuntimeSessionExt;
38use vortex_session::VortexSession;
39
40use crate::IntoLayout;
41use crate::LayoutRef;
42use crate::LayoutStrategy;
43use crate::OwnedLayoutChildren;
44use crate::layouts::chunked::ChunkedLayout;
45use crate::layouts::dict::DictLayout;
46use crate::segments::SegmentSinkRef;
47use crate::sequence::SendableSequentialStream;
48use crate::sequence::SequenceId;
49use crate::sequence::SequencePointer;
50use crate::sequence::SequentialStream;
51use crate::sequence::SequentialStreamAdapter;
52use crate::sequence::SequentialStreamExt;
53
/// Limits on how large a single dictionary may grow before it is flushed and a
/// new dictionary run is started.
#[derive(Clone)]
pub struct DictLayoutConstraints {
    /// Maximum number of bytes the dictionary values may occupy.
    pub max_bytes: usize,
    /// Maximum number of distinct dictionary entries; kept as a `u16` so the
    /// codes always fit in a 16-bit (or narrower) primitive.
    pub max_len: u16,
}
74
75impl From<DictLayoutConstraints> for DictConstraints {
76 fn from(value: DictLayoutConstraints) -> Self {
77 DictConstraints {
78 max_bytes: value.max_bytes,
79 max_len: value.max_len as usize,
80 }
81 }
82}
83
84impl Default for DictLayoutConstraints {
85 fn default() -> Self {
86 Self {
87 max_bytes: 1024 * 1024,
88 max_len: u16::MAX,
89 }
90 }
91}
92
/// Options controlling the dictionary layout strategy.
#[derive(Clone, Default)]
pub struct DictLayoutOptions {
    /// Size limits applied to each dictionary run.
    pub constraints: DictLayoutConstraints,
}
97
/// A [`LayoutStrategy`] that dictionary-encodes its input, writing the codes
/// and the values of each dictionary run through separate child strategies,
/// and deferring to `fallback` when dictionary encoding does not apply.
#[derive(Clone)]
pub struct DictStrategy {
    /// Strategy used to write each run's dictionary codes.
    codes: Arc<dyn LayoutStrategy>,
    /// Strategy used to write each run's dictionary values.
    values: Arc<dyn LayoutStrategy>,
    /// Strategy used when the input isn't worth dictionary-encoding.
    fallback: Arc<dyn LayoutStrategy>,
    options: DictLayoutOptions,
}
110
111impl DictStrategy {
112 pub fn new<Codes: LayoutStrategy, Values: LayoutStrategy, Fallback: LayoutStrategy>(
113 codes: Codes,
114 values: Values,
115 fallback: Fallback,
116 options: DictLayoutOptions,
117 ) -> Self {
118 Self {
119 codes: Arc::new(codes),
120 values: Arc::new(values),
121 fallback: Arc::new(fallback),
122 options,
123 }
124 }
125}
126
#[async_trait]
impl LayoutStrategy for DictStrategy {
    /// Writes `stream` as a sequence of dictionary-encoded child layouts,
    /// falling back to `self.fallback` when the dtype is unsupported or a
    /// sample of the data does not compress to a dictionary.
    async fn write_stream(
        &self,
        ctx: ArrayContext,
        segment_sink: SegmentSinkRef,
        stream: SendableSequentialStream,
        mut eof: SequencePointer,
        session: &VortexSession,
    ) -> VortexResult<LayoutRef> {
        // Dictionary encoding only applies to primitive/utf8/binary dtypes.
        if !dict_layout_supported(stream.dtype()) {
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, session)
                .await;
        }

        let options = self.options.clone();
        let dtype = stream.dtype().clone();

        // Peek at the first chunk (pushing it back afterwards) to sample
        // whether this data is dictionary-friendly.
        let (stream, first_chunk) = peek_first_chunk(stream).await?;
        let stream = SequentialStreamAdapter::new(dtype.clone(), stream).sendable();

        // Heuristic: if the sampled chunk doesn't compress to a Dict array,
        // assume dictionary encoding won't pay off for the whole stream.
        let should_fallback = match first_chunk {
            None => true,
            Some(chunk) => {
                let mut exec_ctx = session.create_execution_ctx();
                let compressed = BtrBlocksCompressor::default().compress(&chunk, &mut exec_ctx)?;
                !compressed.is::<Dict>()
            }
        };
        if should_fallback {
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, session)
                .await;
        }

        // Split the input into labeled codes/values chunks...
        let dict_stream = dict_encode_stream(stream, options.constraints.into());

        // ...and regroup them into per-dictionary "runs": a codes stream plus
        // a future resolving to that run's dictionary values.
        let runs = DictionaryTransformer::new(dict_stream);

        let handle = session.handle();
        let dtype2 = dtype.clone();
        // For each run, spawn the codes and values child writes and yield a
        // future that joins both resulting layouts.
        let child_layouts = stream! {
            pin_mut!(runs);

            while let Some((codes_stream, values_fut)) = runs.next().await {
                let codes = Arc::clone(&self.codes);
                let codes_eof = eof.split_off();
                let ctx2 = ctx.clone();
                let segment_sink2 = Arc::clone(&segment_sink);
                let session2 = session.clone();
                let codes_fut = handle.spawn_nested(move |h| async move {
                    let session2 = session2.with_handle(h);
                    codes.write_stream(
                        ctx2,
                        segment_sink2,
                        codes_stream.sendable(),
                        codes_eof,
                        &session2,
                    ).await
                });

                let values = Arc::clone(&self.values);
                let values_eof = eof.split_off();
                let ctx2 = ctx.clone();
                let segment_sink2 = Arc::clone(&segment_sink);
                let dtype2 = dtype2.clone();
                let session2 = session.clone();
                let values_layout = handle.spawn_nested(move |h| async move {
                    let session2 = session2.with_handle(h);
                    values.write_stream(
                        ctx2,
                        segment_sink2,
                        // A run's values arrive as a single array, so wrap the
                        // one-shot future into a single-element stream.
                        SequentialStreamAdapter::new(dtype2, once(values_fut)).sendable(),
                        values_eof,
                        &session2,
                    ).await
                });

                yield async move {
                    try_join!(codes_fut, values_layout)
                }.boxed();
            }
        };

        // Drive all run futures concurrently (effectively unbounded) and
        // collect one DictLayout per run.
        let mut child_layouts = child_layouts
            .buffered(usize::MAX)
            .map(|result| {
                let (codes_layout, values_layout) = result?;
                Ok::<_, VortexError>(DictLayout::new(values_layout, codes_layout).into_layout())
            })
            .try_collect::<Vec<_>>()
            .await?;

        // A single run needs no chunking wrapper around it.
        if child_layouts.len() == 1 {
            return Ok(child_layouts.remove(0));
        }

        // Multiple runs are stitched together under a chunked layout.
        let row_count = child_layouts.iter().map(|child| child.row_count()).sum();
        Ok(ChunkedLayout::new(
            row_count,
            dtype,
            OwnedLayoutChildren::layout_children(child_layouts),
        )
        .into_layout())
    }

    /// Upper bound on bytes buffered across all three child strategies.
    fn buffered_bytes(&self) -> u64 {
        self.codes.buffered_bytes() + self.values.buffered_bytes() + self.fallback.buffered_bytes()
    }
}
249
/// A chunk emitted by [`dict_encode_stream`]: either a codes array belonging
/// to the current dictionary run, or the values array that terminates a run.
enum DictionaryChunk {
    Codes {
        seq_id: SequenceId,
        codes: ArrayRef,
        /// Primitive type of the codes array, carried alongside so the codes
        /// stream's dtype can be built before any array is inspected.
        codes_ptype: PType,
    },
    Values((SequenceId, ArrayRef)),
}

/// Fallible stream of labeled dictionary chunks.
type DictionaryStream = BoxStream<'static, VortexResult<DictionaryChunk>>;
260
/// Converts a sequential stream of arrays into a stream of labeled
/// [`DictionaryChunk`]s: codes chunks interleaved with the dictionary values
/// that terminate each run.
fn dict_encode_stream(
    input: SendableSequentialStream,
    constraints: DictConstraints,
) -> DictionaryStream {
    Box::pin(try_stream! {
        let mut state = DictStreamState {
            encoder: None,
            constraints,
        };

        // Peekable so we can detect the final input chunk and drain the
        // encoder's in-progress dictionary values after it.
        let input = input.peekable();
        pin_mut!(input);

        while let Some(item) = input.next().await {
            let (sequence_id, chunk) = item?;

            match input.as_mut().peek().await {
                Some(_) => {
                    // More input follows: emit whatever chunks this array produced.
                    let mut labeler = DictChunkLabeler::new(sequence_id);
                    let chunks = state.encode(&mut labeler, chunk)?;
                    drop(labeler);
                    for dict_chunk in chunks {
                        yield dict_chunk;
                    }
                }
                None => {
                    // Last chunk: also flush the open dictionary's values so the
                    // final run is properly terminated.
                    let mut labeler = DictChunkLabeler::new(sequence_id);
                    let encoded = state.encode(&mut labeler, chunk)?;
                    let drained = state.drain_values(&mut labeler);
                    drop(labeler);
                    for dict_chunk in encoded.into_iter().chain(drained.into_iter()) {
                        yield dict_chunk;
                    }
                }
            }
        }
    })
}
303
/// Mutable encoding state threaded through [`dict_encode_stream`].
struct DictStreamState {
    /// The in-progress dictionary encoder, if a run is currently open.
    encoder: Option<Box<dyn DictEncoder>>,
    /// Limits that trigger a dictionary flush and a new run.
    constraints: DictConstraints,
}
308
impl DictStreamState {
    /// Dictionary-encodes `chunk`, possibly across several encoder generations.
    ///
    /// When the active encoder hits its constraints mid-chunk, the finished
    /// dictionary's values are emitted and encoding restarts on the unencoded
    /// remainder with a fresh encoder.
    fn encode(
        &mut self,
        labeler: &mut DictChunkLabeler,
        chunk: ArrayRef,
    ) -> VortexResult<Vec<DictionaryChunk>> {
        let mut res = Vec::new();
        let mut to_be_encoded = Some(chunk);
        while let Some(remaining) = to_be_encoded.take() {
            match self.encoder.take() {
                // No active encoder: start a new dictionary for this data.
                None => match start_encoding(&self.constraints, &remaining)? {
                    EncodingState::Continue((encoder, encoded)) => {
                        let ptype = encoder.codes_ptype();
                        res.push(labeler.codes(encoded, ptype));
                        self.encoder = Some(encoder);
                    }
                    EncodingState::Done((values, encoded, unencoded)) => {
                        // The fresh encoder filled up immediately; derive the
                        // codes ptype from the encoded array's dtype instead.
                        let ptype = PType::try_from(encoded.dtype())
                            .vortex_expect("codes should be primitive");
                        res.push(labeler.codes(encoded, ptype));
                        res.push(labeler.values(values));
                        to_be_encoded = Some(unencoded);
                    }
                },
                // Continue filling the active dictionary.
                Some(encoder) => {
                    // Capture the ptype before the encoder is consumed below.
                    let ptype = encoder.codes_ptype();
                    match encode_chunk(encoder, &remaining)? {
                        EncodingState::Continue((encoder, encoded)) => {
                            res.push(labeler.codes(encoded, ptype));
                            self.encoder = Some(encoder);
                        }
                        EncodingState::Done((values, encoded, unencoded)) => {
                            res.push(labeler.codes(encoded, ptype));
                            res.push(labeler.values(values));
                            to_be_encoded = Some(unencoded);
                        }
                    }
                }
            }
        }
        Ok(res)
    }

    /// Emits the values of the in-progress dictionary, if any, resetting the
    /// encoder in the process.
    fn drain_values(&mut self, labeler: &mut DictChunkLabeler) -> Vec<DictionaryChunk> {
        match self.encoder.as_mut() {
            None => Vec::new(),
            Some(encoder) => vec![labeler.values(encoder.reset())],
        }
    }
}
360
/// Assigns descending sequence ids to the chunks derived from one input chunk,
/// so downstream consumers can restore their relative order.
struct DictChunkLabeler {
    sequence_pointer: SequencePointer,
}
364
365impl DictChunkLabeler {
366 fn new(starting_id: SequenceId) -> Self {
367 let sequence_pointer = starting_id.descend();
368 Self { sequence_pointer }
369 }
370
371 fn codes(&mut self, chunk: ArrayRef, ptype: PType) -> DictionaryChunk {
372 DictionaryChunk::Codes {
373 seq_id: self.sequence_pointer.advance(),
374 codes: chunk,
375 codes_ptype: ptype,
376 }
377 }
378
379 fn values(&mut self, chunk: ArrayRef) -> DictionaryChunk {
380 DictionaryChunk::Values((self.sequence_pointer.advance(), chunk))
381 }
382}
383
/// A fallible (sequence id, array) element of a sequential stream.
type SequencedChunk = VortexResult<(SequenceId, ArrayRef)>;

/// Regroups a [`DictionaryStream`] into dictionary "runs".
///
/// Each yielded item is a run's codes exposed as a stream, plus a future that
/// resolves to the run's dictionary values once the upstream emits them.
struct DictionaryTransformer {
    input: DictionaryStream,
    /// Sender feeding the currently-open run's codes stream, if any.
    active_codes_tx: Option<kanal::AsyncSender<SequencedChunk>>,
    /// Oneshot fulfilling the currently-open run's values future, if any.
    active_values_tx: Option<oneshot::Sender<SequencedChunk>>,
    /// An in-flight send into the bounded codes channel that returned Pending.
    pending_send: Option<BoxFuture<'static, Result<(), kanal::SendError>>>,
}
392
393impl DictionaryTransformer {
394 fn new(input: DictionaryStream) -> Self {
395 Self {
396 input,
397 active_codes_tx: None,
398 active_values_tx: None,
399 pending_send: None,
400 }
401 }
402}
403
impl Stream for DictionaryTransformer {
    /// One dictionary run: its codes stream plus a future resolving to the
    /// run's values array.
    type Item = (SendableSequentialStream, BoxFuture<'static, SequencedChunk>);

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // Finish any in-flight send into the bounded codes channel before
            // pulling more input, preserving backpressure on the consumer.
            if let Some(mut send_fut) = self.pending_send.take() {
                match send_fut.poll_unpin(cx) {
                    Poll::Ready(Ok(())) => {
                        // Delivered; fall through and poll the input.
                    }
                    Poll::Ready(Err(_)) => {
                        // Codes receiver dropped: abandon the current run and
                        // fail its values future so nothing waits forever.
                        self.active_codes_tx = None;
                        if let Some(values_tx) = self.active_values_tx.take() {
                            drop(values_tx.send(Err(vortex_err!("values receiver dropped"))));
                        }
                    }
                    Poll::Pending => {
                        // Send not complete yet; park until the channel wakes us.
                        self.pending_send = Some(send_fut);
                        return Poll::Pending;
                    }
                }
            }

            match self.input.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(DictionaryChunk::Codes {
                    seq_id,
                    codes,
                    codes_ptype,
                }))) => {
                    if self.active_codes_tx.is_none() {
                        // First codes chunk of a new run: open a bounded codes
                        // channel and a oneshot for the run's values.
                        let (codes_tx, codes_rx) = kanal::bounded_async::<SequencedChunk>(1);
                        let (values_tx, values_rx) = oneshot::channel();

                        self.active_codes_tx = Some(codes_tx.clone());
                        self.active_values_tx = Some(values_tx);

                        // Codes dtype comes from the ptype reported upstream;
                        // codes are always non-nullable primitives here.
                        let codes_dtype = DType::Primitive(codes_ptype, Nullability::NonNullable);

                        // Queue the first chunk; it is actually delivered at the
                        // top of the next poll iteration.
                        self.pending_send =
                            Some(Box::pin(
                                async move { codes_tx.send(Ok((seq_id, codes))).await },
                            ));

                        let codes_stream = SequentialStreamAdapter::new(
                            codes_dtype,
                            codes_rx.into_stream().boxed(),
                        )
                        .sendable();

                        let values_future = async move {
                            values_rx
                                .await
                                .map_err(|e| vortex_err!("values sender dropped: {}", e))
                                .flatten()
                        }
                        .boxed();

                        return Poll::Ready(Some((codes_stream, values_future)));
                    }

                    // A run is already open: forward this chunk to its codes channel.
                    if let Some(tx) = &self.active_codes_tx {
                        let tx = tx.clone();
                        self.pending_send =
                            Some(Box::pin(async move { tx.send(Ok((seq_id, codes))).await }));
                    }
                }
                Poll::Ready(Some(Ok(DictionaryChunk::Values(values)))) => {
                    // A values chunk closes the current run; dropping the codes
                    // sender ends the run's codes stream.
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Ok(values)));
                    }
                    self.active_codes_tx = None;
                }
                Poll::Ready(Some(Err(e))) => {
                    // Upstream error: fail the open run's values future and end.
                    // NOTE(review): when no run is open, `e` is discarded here and
                    // the stream simply ends — confirm callers surface upstream
                    // errors through another path.
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Err(e)));
                    }
                    self.active_codes_tx = None;
                    return Poll::Ready(None);
                }
                Poll::Ready(None) => {
                    // Input exhausted. A still-open run means upstream never
                    // emitted its terminating values chunk.
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Err(vortex_err!("Incomplete dictionary group"))));
                    }
                    self.active_codes_tx = None;
                    return Poll::Ready(None);
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
507
508async fn peek_first_chunk(
509 mut stream: BoxStream<'static, SequencedChunk>,
510) -> VortexResult<(BoxStream<'static, SequencedChunk>, Option<ArrayRef>)> {
511 match stream.next().await {
512 None => Ok((stream.boxed(), None)),
513 Some(Err(e)) => Err(e),
514 Some(Ok((sequence_id, chunk))) => {
515 let chunk_clone = chunk.clone();
516 let reconstructed_stream =
517 once(async move { Ok((sequence_id, chunk_clone)) }).chain(stream);
518 Ok((reconstructed_stream.boxed(), Some(chunk)))
519 }
520 }
521}
522
523pub fn dict_layout_supported(dtype: &DType) -> bool {
524 matches!(
525 dtype,
526 DType::Primitive(..) | DType::Utf8(_) | DType::Binary(_)
527 )
528}
529
/// Serialized metadata for a dictionary layout: records the primitive type of
/// the codes child so readers can reconstruct its dtype.
#[derive(prost::Message)]
pub struct DictLayoutMetadata {
    // Stored as i32 per prost's enumeration encoding; read/write through the
    // generated accessors rather than touching the raw field.
    #[prost(enumeration = "PType", tag = "1")]
    codes_ptype: i32,
}
536
537impl DictLayoutMetadata {
538 pub fn new(codes_ptype: PType) -> Self {
539 let mut metadata = Self::default();
540 metadata.set_codes_ptype(codes_ptype);
541 metadata
542 }
543}
544
/// Outcome of feeding one chunk to a dictionary encoder.
enum EncodingState {
    /// The whole chunk was encoded: (still-open encoder, codes array).
    Continue((Box<dyn DictEncoder>, ArrayRef)),
    /// The encoder filled up: (dictionary values, codes for the encoded
    /// prefix, unencoded remainder of the chunk).
    Done((ArrayRef, ArrayRef, ArrayRef)),
}
550
551fn start_encoding(constraints: &DictConstraints, chunk: &ArrayRef) -> VortexResult<EncodingState> {
552 let encoder = dict_encoder(chunk, constraints);
553 encode_chunk(encoder, chunk)
554}
555
556fn encode_chunk(
557 mut encoder: Box<dyn DictEncoder>,
558 chunk: &ArrayRef,
559) -> VortexResult<EncodingState> {
560 let encoded = encoder.encode(chunk);
561 match remainder(chunk, encoded.len())? {
562 None => Ok(EncodingState::Continue((encoder, encoded))),
563 Some(unencoded) => Ok(EncodingState::Done((encoder.reset(), encoded, unencoded))),
564 }
565}
566
567fn remainder(array: &ArrayRef, encoded_len: usize) -> VortexResult<Option<ArrayRef>> {
568 if encoded_len < array.len() {
569 Ok(Some(array.slice(encoded_len..array.len())?))
570 } else {
571 Ok(None)
572 }
573}
574
#[cfg(test)]
mod tests {
    use futures::StreamExt;
    use vortex_array::IntoArray;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::builders::dict::DictConstraints;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability::NonNullable;
    use vortex_array::dtype::PType;

    use super::DictionaryTransformer;
    use super::dict_encode_stream;
    use crate::sequence::SequenceId;
    use crate::sequence::SequentialStream;
    use crate::sequence::SequentialStreamAdapter;
    use crate::sequence::SequentialStreamExt;

    /// With few distinct values (2 here, well under 256), the encoder should
    /// pick u8 codes, and the transformer's codes stream dtype should say so.
    #[tokio::test]
    async fn test_dict_transformer_uses_u8_for_small_dictionaries() {
        let constraints = DictConstraints {
            max_bytes: 1024 * 1024,
            max_len: 100,
        };

        // Only two distinct strings -> tiny dictionary.
        let arr = VarBinArray::from(vec!["hello", "world", "hello", "world"]).into_array();

        let mut pointer = SequenceId::root();
        let input_stream = SequentialStreamAdapter::new(
            arr.dtype().clone(),
            futures::stream::once(async move { Ok((pointer.advance(), arr)) }),
        )
        .sendable();

        let dict_stream = dict_encode_stream(input_stream, constraints);

        let mut transformer = DictionaryTransformer::new(dict_stream);

        let (codes_stream, _values_fut) = transformer
            .next()
            .await
            .expect("expected at least one dictionary run");

        assert_eq!(
            codes_stream.dtype(),
            &DType::Primitive(PType::U8, NonNullable),
            "codes stream should use U8 dtype for small dictionaries, not U16"
        );
    }

    /// With more than 255 distinct values, u8 codes no longer suffice and the
    /// encoder should widen to u16.
    #[tokio::test]
    async fn test_dict_transformer_uses_u16_for_large_dictionaries() {
        let constraints = DictConstraints {
            max_bytes: 1024 * 1024,
            max_len: 1000,
        };

        // 300 distinct strings -> dictionary exceeds the u8 code range.
        let values: Vec<String> = (0..300).map(|i| format!("value_{i}")).collect();
        let arr =
            VarBinArray::from(values.iter().map(|s| s.as_str()).collect::<Vec<_>>()).into_array();

        let mut pointer = SequenceId::root();
        let input_stream = SequentialStreamAdapter::new(
            arr.dtype().clone(),
            futures::stream::once(async move { Ok((pointer.advance(), arr)) }),
        )
        .sendable();

        let dict_stream = dict_encode_stream(input_stream, constraints);

        let mut transformer = DictionaryTransformer::new(dict_stream);

        let (codes_stream, _values_fut) = transformer
            .next()
            .await
            .expect("expected at least one dictionary run");

        assert_eq!(
            codes_stream.dtype(),
            &DType::Primitive(PType::U16, NonNullable),
            "codes stream should use U16 dtype for dictionaries with >255 entries"
        );
    }
}