1use std::pin::Pin;
5use std::sync::Arc;
6use std::task::Context;
7use std::task::Poll;
8
9use async_stream::stream;
10use async_stream::try_stream;
11use async_trait::async_trait;
12use futures::FutureExt;
13use futures::Stream;
14use futures::StreamExt;
15use futures::TryStreamExt;
16use futures::future::BoxFuture;
17use futures::pin_mut;
18use futures::stream::BoxStream;
19use futures::stream::once;
20use futures::try_join;
21use vortex_array::ArrayContext;
22use vortex_array::ArrayRef;
23use vortex_array::arrays::Dict;
24use vortex_array::builders::dict::DictConstraints;
25use vortex_array::builders::dict::DictEncoder;
26use vortex_array::builders::dict::dict_encoder;
27use vortex_array::dtype::DType;
28use vortex_array::dtype::Nullability;
29use vortex_array::dtype::PType;
30use vortex_btrblocks::BtrBlocksCompressor;
31use vortex_error::VortexError;
32use vortex_error::VortexExpect;
33use vortex_error::VortexResult;
34use vortex_error::vortex_err;
35use vortex_io::kanal_ext::KanalExt;
36use vortex_io::session::RuntimeSessionExt;
37use vortex_session::VortexSession;
38
39use crate::IntoLayout;
40use crate::LayoutRef;
41use crate::LayoutStrategy;
42use crate::OwnedLayoutChildren;
43use crate::layouts::chunked::ChunkedLayout;
44use crate::layouts::dict::DictLayout;
45use crate::segments::SegmentSinkRef;
46use crate::sequence::SendableSequentialStream;
47use crate::sequence::SequenceId;
48use crate::sequence::SequencePointer;
49use crate::sequence::SequentialStream;
50use crate::sequence::SequentialStreamAdapter;
51use crate::sequence::SequentialStreamExt;
52
/// Limits applied to a dictionary built by the dict layout writer.
///
/// These are converted into the array-level `DictConstraints` before being handed to the
/// dictionary encoder. `max_len` is deliberately narrower here (`u16`) than in
/// `DictConstraints` (`usize`); the conversion widens it losslessly.
///
/// `Debug` is derived so the configuration can be logged and shown in assertion failures.
#[derive(Clone, Debug)]
pub struct DictLayoutConstraints {
    /// Forwarded unchanged as `DictConstraints::max_bytes`
    /// (presumably a cap on the dictionary size in bytes — confirm against the encoder).
    pub max_bytes: usize,
    /// Forwarded as `DictConstraints::max_len` after widening to `usize`
    /// (presumably a cap on the number of dictionary entries — confirm against the encoder).
    pub max_len: u16,
}
73
74impl From<DictLayoutConstraints> for DictConstraints {
75 fn from(value: DictLayoutConstraints) -> Self {
76 DictConstraints {
77 max_bytes: value.max_bytes,
78 max_len: value.max_len as usize,
79 }
80 }
81}
82
83impl Default for DictLayoutConstraints {
84 fn default() -> Self {
85 Self {
86 max_bytes: 1024 * 1024,
87 max_len: u16::MAX,
88 }
89 }
90}
91
92#[derive(Clone, Default)]
93pub struct DictLayoutOptions {
94 pub constraints: DictLayoutConstraints,
95}
96
97#[derive(Clone)]
103pub struct DictStrategy {
104 codes: Arc<dyn LayoutStrategy>,
105 values: Arc<dyn LayoutStrategy>,
106 fallback: Arc<dyn LayoutStrategy>,
107 options: DictLayoutOptions,
108}
109
110impl DictStrategy {
111 pub fn new<Codes: LayoutStrategy, Values: LayoutStrategy, Fallback: LayoutStrategy>(
112 codes: Codes,
113 values: Values,
114 fallback: Fallback,
115 options: DictLayoutOptions,
116 ) -> Self {
117 Self {
118 codes: Arc::new(codes),
119 values: Arc::new(values),
120 fallback: Arc::new(fallback),
121 options,
122 }
123 }
124}
125
126#[async_trait]
127impl LayoutStrategy for DictStrategy {
128 async fn write_stream(
129 &self,
130 ctx: ArrayContext,
131 segment_sink: SegmentSinkRef,
132 stream: SendableSequentialStream,
133 mut eof: SequencePointer,
134 session: &VortexSession,
135 ) -> VortexResult<LayoutRef> {
136 if !dict_layout_supported(stream.dtype()) {
138 return self
139 .fallback
140 .write_stream(ctx, segment_sink, stream, eof, session)
141 .await;
142 }
143
144 let options = self.options.clone();
145 let dtype = stream.dtype().clone();
146
147 let (stream, first_chunk) = peek_first_chunk(stream).await?;
149 let stream = SequentialStreamAdapter::new(dtype.clone(), stream).sendable();
150
151 let should_fallback = match first_chunk {
152 None => true, Some(chunk) => {
154 let compressed = BtrBlocksCompressor::default().compress(&chunk)?;
155 !compressed.is::<Dict>()
156 }
157 };
158 if should_fallback {
159 return self
161 .fallback
162 .write_stream(ctx, segment_sink, stream, eof, session)
163 .await;
164 }
165
166 let dict_stream = dict_encode_stream(stream, options.constraints.into());
170
171 let runs = DictionaryTransformer::new(dict_stream);
174
175 let handle = session.handle();
176 let dtype2 = dtype.clone();
177 let child_layouts = stream! {
178 pin_mut!(runs);
179
180 while let Some((codes_stream, values_fut)) = runs.next().await {
181 let codes = Arc::clone(&self.codes);
182 let codes_eof = eof.split_off();
183 let ctx2 = ctx.clone();
184 let segment_sink2 = Arc::clone(&segment_sink);
185 let session2 = session.clone();
186 let codes_fut = handle.spawn_nested(move |h| async move {
187 let session2 = session2.with_handle(h);
188 codes.write_stream(
189 ctx2,
190 segment_sink2,
191 codes_stream.sendable(),
192 codes_eof,
193 &session2,
194 ).await
195 });
196
197 let values = Arc::clone(&self.values);
198 let values_eof = eof.split_off();
199 let ctx2 = ctx.clone();
200 let segment_sink2 = Arc::clone(&segment_sink);
201 let dtype2 = dtype2.clone();
202 let session2 = session.clone();
203 let values_layout = handle.spawn_nested(move |h| async move {
204 let session2 = session2.with_handle(h);
205 values.write_stream(
206 ctx2,
207 segment_sink2,
208 SequentialStreamAdapter::new(dtype2, once(values_fut)).sendable(),
209 values_eof,
210 &session2,
211 ).await
212 });
213
214 yield async move {
215 try_join!(codes_fut, values_layout)
216 }.boxed();
217 }
218 };
219
220 let mut child_layouts = child_layouts
221 .buffered(usize::MAX)
222 .map(|result| {
223 let (codes_layout, values_layout) = result?;
224 Ok::<_, VortexError>(DictLayout::new(values_layout, codes_layout).into_layout())
226 })
227 .try_collect::<Vec<_>>()
228 .await?;
229
230 if child_layouts.len() == 1 {
231 return Ok(child_layouts.remove(0));
232 }
233
234 let row_count = child_layouts.iter().map(|child| child.row_count()).sum();
235 Ok(ChunkedLayout::new(
236 row_count,
237 dtype,
238 OwnedLayoutChildren::layout_children(child_layouts),
239 )
240 .into_layout())
241 }
242
243 fn buffered_bytes(&self) -> u64 {
244 self.codes.buffered_bytes() + self.values.buffered_bytes() + self.fallback.buffered_bytes()
245 }
246}
247
248enum DictionaryChunk {
249 Codes {
250 seq_id: SequenceId,
251 codes: ArrayRef,
252 codes_ptype: PType,
253 },
254 Values((SequenceId, ArrayRef)),
255}
256
257type DictionaryStream = BoxStream<'static, VortexResult<DictionaryChunk>>;
258
259fn dict_encode_stream(
260 input: SendableSequentialStream,
261 constraints: DictConstraints,
262) -> DictionaryStream {
263 Box::pin(try_stream! {
264 let mut state = DictStreamState {
265 encoder: None,
266 constraints,
267 };
268
269 let input = input.peekable();
270 pin_mut!(input);
271
272 while let Some(item) = input.next().await {
273 let (sequence_id, chunk) = item?;
274
275 match input.as_mut().peek().await {
279 Some(_) => {
280 let mut labeler = DictChunkLabeler::new(sequence_id);
281 let chunks = state.encode(&mut labeler, chunk)?;
282 drop(labeler);
283 for dict_chunk in chunks {
284 yield dict_chunk;
285 }
286 }
287 None => {
288 let mut labeler = DictChunkLabeler::new(sequence_id);
290 let encoded = state.encode(&mut labeler, chunk)?;
291 let drained = state.drain_values(&mut labeler);
292 drop(labeler);
293 for dict_chunk in encoded.into_iter().chain(drained.into_iter()) {
294 yield dict_chunk;
295 }
296 }
297 }
298 }
299 })
300}
301
302struct DictStreamState {
303 encoder: Option<Box<dyn DictEncoder>>,
304 constraints: DictConstraints,
305}
306
307impl DictStreamState {
308 fn encode(
309 &mut self,
310 labeler: &mut DictChunkLabeler,
311 chunk: ArrayRef,
312 ) -> VortexResult<Vec<DictionaryChunk>> {
313 let mut res = Vec::new();
314 let mut to_be_encoded = Some(chunk);
315 while let Some(remaining) = to_be_encoded.take() {
316 match self.encoder.take() {
317 None => match start_encoding(&self.constraints, &remaining)? {
318 EncodingState::Continue((encoder, encoded)) => {
319 let ptype = encoder.codes_ptype();
320 res.push(labeler.codes(encoded, ptype));
321 self.encoder = Some(encoder);
322 }
323 EncodingState::Done((values, encoded, unencoded)) => {
324 let ptype = PType::try_from(encoded.dtype())
326 .vortex_expect("codes should be primitive");
327 res.push(labeler.codes(encoded, ptype));
328 res.push(labeler.values(values));
329 to_be_encoded = Some(unencoded);
330 }
331 },
332 Some(encoder) => {
333 let ptype = encoder.codes_ptype();
334 match encode_chunk(encoder, &remaining)? {
335 EncodingState::Continue((encoder, encoded)) => {
336 res.push(labeler.codes(encoded, ptype));
337 self.encoder = Some(encoder);
338 }
339 EncodingState::Done((values, encoded, unencoded)) => {
340 res.push(labeler.codes(encoded, ptype));
341 res.push(labeler.values(values));
342 to_be_encoded = Some(unencoded);
343 }
344 }
345 }
346 }
347 }
348 Ok(res)
349 }
350
351 fn drain_values(&mut self, labeler: &mut DictChunkLabeler) -> Vec<DictionaryChunk> {
352 match self.encoder.as_mut() {
353 None => Vec::new(),
354 Some(encoder) => vec![labeler.values(encoder.reset())],
355 }
356 }
357}
358
359struct DictChunkLabeler {
360 sequence_pointer: SequencePointer,
361}
362
363impl DictChunkLabeler {
364 fn new(starting_id: SequenceId) -> Self {
365 let sequence_pointer = starting_id.descend();
366 Self { sequence_pointer }
367 }
368
369 fn codes(&mut self, chunk: ArrayRef, ptype: PType) -> DictionaryChunk {
370 DictionaryChunk::Codes {
371 seq_id: self.sequence_pointer.advance(),
372 codes: chunk,
373 codes_ptype: ptype,
374 }
375 }
376
377 fn values(&mut self, chunk: ArrayRef) -> DictionaryChunk {
378 DictionaryChunk::Values((self.sequence_pointer.advance(), chunk))
379 }
380}
381
382type SequencedChunk = VortexResult<(SequenceId, ArrayRef)>;
383
384struct DictionaryTransformer {
385 input: DictionaryStream,
386 active_codes_tx: Option<kanal::AsyncSender<SequencedChunk>>,
387 active_values_tx: Option<oneshot::Sender<SequencedChunk>>,
388 pending_send: Option<BoxFuture<'static, Result<(), kanal::SendError>>>,
389}
390
391impl DictionaryTransformer {
392 fn new(input: DictionaryStream) -> Self {
393 Self {
394 input,
395 active_codes_tx: None,
396 active_values_tx: None,
397 pending_send: None,
398 }
399 }
400}
401
402impl Stream for DictionaryTransformer {
403 type Item = (SendableSequentialStream, BoxFuture<'static, SequencedChunk>);
404
405 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
406 loop {
407 if let Some(mut send_fut) = self.pending_send.take() {
409 match send_fut.poll_unpin(cx) {
410 Poll::Ready(Ok(())) => {
411 }
413 Poll::Ready(Err(_)) => {
414 self.active_codes_tx = None;
416 if let Some(values_tx) = self.active_values_tx.take() {
417 drop(values_tx.send(Err(vortex_err!("values receiver dropped"))));
418 }
419 }
420 Poll::Pending => {
421 self.pending_send = Some(send_fut);
423 return Poll::Pending;
424 }
425 }
426 }
427
428 match self.input.poll_next_unpin(cx) {
429 Poll::Ready(Some(Ok(DictionaryChunk::Codes {
430 seq_id,
431 codes,
432 codes_ptype,
433 }))) => {
434 if self.active_codes_tx.is_none() {
435 let (codes_tx, codes_rx) = kanal::bounded_async::<SequencedChunk>(1);
437 let (values_tx, values_rx) = oneshot::channel();
438
439 self.active_codes_tx = Some(codes_tx.clone());
440 self.active_values_tx = Some(values_tx);
441
442 let codes_dtype = DType::Primitive(codes_ptype, Nullability::NonNullable);
444
445 self.pending_send =
447 Some(Box::pin(
448 async move { codes_tx.send(Ok((seq_id, codes))).await },
449 ));
450
451 let codes_stream = SequentialStreamAdapter::new(
453 codes_dtype,
454 codes_rx.into_stream().boxed(),
455 )
456 .sendable();
457
458 let values_future = async move {
459 values_rx
460 .await
461 .map_err(|e| vortex_err!("values sender dropped: {}", e))
462 .flatten()
463 }
464 .boxed();
465
466 return Poll::Ready(Some((codes_stream, values_future)));
467 }
468
469 if let Some(tx) = &self.active_codes_tx {
471 let tx = tx.clone();
472 self.pending_send =
473 Some(Box::pin(async move { tx.send(Ok((seq_id, codes))).await }));
474 }
475 }
476 Poll::Ready(Some(Ok(DictionaryChunk::Values(values)))) => {
477 if let Some(values_tx) = self.active_values_tx.take() {
479 drop(values_tx.send(Ok(values)));
480 }
481 self.active_codes_tx = None; }
483 Poll::Ready(Some(Err(e))) => {
484 if let Some(values_tx) = self.active_values_tx.take() {
486 drop(values_tx.send(Err(e)));
487 }
488 self.active_codes_tx = None;
489 return Poll::Ready(None);
491 }
492 Poll::Ready(None) => {
493 if let Some(values_tx) = self.active_values_tx.take() {
495 drop(values_tx.send(Err(vortex_err!("Incomplete dictionary group"))));
496 }
497 self.active_codes_tx = None;
498 return Poll::Ready(None);
499 }
500 Poll::Pending => return Poll::Pending,
501 }
502 }
503 }
504}
505
506async fn peek_first_chunk(
507 mut stream: BoxStream<'static, SequencedChunk>,
508) -> VortexResult<(BoxStream<'static, SequencedChunk>, Option<ArrayRef>)> {
509 match stream.next().await {
510 None => Ok((stream.boxed(), None)),
511 Some(Err(e)) => Err(e),
512 Some(Ok((sequence_id, chunk))) => {
513 let chunk_clone = chunk.clone();
514 let reconstructed_stream =
515 once(async move { Ok((sequence_id, chunk_clone)) }).chain(stream);
516 Ok((reconstructed_stream.boxed(), Some(chunk)))
517 }
518 }
519}
520
521pub fn dict_layout_supported(dtype: &DType) -> bool {
522 matches!(
523 dtype,
524 DType::Primitive(..) | DType::Utf8(_) | DType::Binary(_)
525 )
526}
527
528#[derive(prost::Message)]
529pub struct DictLayoutMetadata {
530 #[prost(enumeration = "PType", tag = "1")]
531 codes_ptype: i32,
533}
534
535impl DictLayoutMetadata {
536 pub fn new(codes_ptype: PType) -> Self {
537 let mut metadata = Self::default();
538 metadata.set_codes_ptype(codes_ptype);
539 metadata
540 }
541}
542
543enum EncodingState {
544 Continue((Box<dyn DictEncoder>, ArrayRef)),
545 Done((ArrayRef, ArrayRef, ArrayRef)),
547}
548
549fn start_encoding(constraints: &DictConstraints, chunk: &ArrayRef) -> VortexResult<EncodingState> {
550 let encoder = dict_encoder(chunk, constraints);
551 encode_chunk(encoder, chunk)
552}
553
554fn encode_chunk(
555 mut encoder: Box<dyn DictEncoder>,
556 chunk: &ArrayRef,
557) -> VortexResult<EncodingState> {
558 let encoded = encoder.encode(chunk);
559 match remainder(chunk, encoded.len())? {
560 None => Ok(EncodingState::Continue((encoder, encoded))),
561 Some(unencoded) => Ok(EncodingState::Done((encoder.reset(), encoded, unencoded))),
562 }
563}
564
565fn remainder(array: &ArrayRef, encoded_len: usize) -> VortexResult<Option<ArrayRef>> {
566 if encoded_len < array.len() {
567 Ok(Some(array.slice(encoded_len..array.len())?))
568 } else {
569 Ok(None)
570 }
571}
572
#[cfg(test)]
mod tests {
    use futures::StreamExt;
    use vortex_array::IntoArray;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::builders::dict::DictConstraints;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability::NonNullable;
    use vortex_array::dtype::PType;

    use super::DictionaryTransformer;
    use super::dict_encode_stream;
    use crate::sequence::SequenceId;
    use crate::sequence::SequentialStream;
    use crate::sequence::SequentialStreamAdapter;
    use crate::sequence::SequentialStreamExt;

    /// Dictionary-encodes `strings` as a single-chunk stream (1 MiB byte limit, `max_len`
    /// entries) and returns the dtype of the codes stream of the first run.
    async fn first_run_codes_dtype(strings: Vec<&str>, max_len: usize) -> DType {
        let constraints = DictConstraints {
            max_bytes: 1024 * 1024,
            max_len,
        };

        let arr = VarBinArray::from(strings).into_array();

        let mut pointer = SequenceId::root();
        let input_stream = SequentialStreamAdapter::new(
            arr.dtype().clone(),
            futures::stream::once(async move { Ok((pointer.advance(), arr)) }),
        )
        .sendable();

        let mut transformer =
            DictionaryTransformer::new(dict_encode_stream(input_stream, constraints));

        let (codes_stream, _values_fut) = transformer
            .next()
            .await
            .expect("expected at least one dictionary run");

        codes_stream.dtype().clone()
    }

    /// Two distinct values: the codes stream must be typed `U8`.
    #[tokio::test]
    async fn test_dict_transformer_uses_u8_for_small_dictionaries() {
        let codes_dtype =
            first_run_codes_dtype(vec!["hello", "world", "hello", "world"], 100).await;

        assert_eq!(
            &codes_dtype,
            &DType::Primitive(PType::U8, NonNullable),
            "codes stream should use U8 dtype for small dictionaries, not U16"
        );
    }

    /// 300 distinct values: the codes stream must be typed `U16`.
    #[tokio::test]
    async fn test_dict_transformer_uses_u16_for_large_dictionaries() {
        let values: Vec<String> = (0..300).map(|i| format!("value_{i}")).collect();
        let strings = values.iter().map(|s| s.as_str()).collect::<Vec<_>>();

        let codes_dtype = first_run_codes_dtype(strings, 1000).await;

        assert_eq!(
            &codes_dtype,
            &DType::Primitive(PType::U16, NonNullable),
            "codes stream should use U16 dtype for dictionaries with >255 entries"
        );
    }
}