1use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use async_stream::{stream, try_stream};
9use async_trait::async_trait;
10use futures::future::BoxFuture;
11use futures::stream::{BoxStream, once};
12use futures::{FutureExt, Stream, StreamExt, TryStreamExt, pin_mut, try_join};
13use vortex_array::arrays::DictEncoding;
14use vortex_array::builders::dict::{DictConstraints, DictEncoder, dict_encoder};
15use vortex_array::{Array, ArrayContext, ArrayRef};
16use vortex_btrblocks::BtrBlocksCompressor;
17use vortex_dtype::Nullability::NonNullable;
18use vortex_dtype::{DType, PType};
19use vortex_error::{VortexError, VortexResult, vortex_err};
20use vortex_io::kanal_ext::KanalExt;
21use vortex_io::runtime::Handle;
22
23use crate::layouts::chunked::ChunkedLayout;
24use crate::layouts::dict::DictLayout;
25use crate::segments::SegmentSinkRef;
26use crate::sequence::{
27 SendableSequentialStream, SequenceId, SequencePointer, SequentialStream,
28 SequentialStreamAdapter, SequentialStreamExt,
29};
30use crate::{IntoLayout, LayoutRef, LayoutStrategy, OwnedLayoutChildren};
31
/// Limits applied to each dictionary built by the dict layout writer.
///
/// Converted into [`DictConstraints`] (via `From`) before being handed to the
/// dictionary encoder.
#[derive(Clone)]
pub struct DictLayoutConstraints {
    /// Byte limit, forwarded unchanged as `DictConstraints::max_bytes`.
    pub max_bytes: usize,
    /// Length limit, forwarded as `DictConstraints::max_len` after widening to `usize`.
    ///
    /// NOTE(review): presumably `u16` because the codes stream is declared with
    /// `PType::U16` in `DictionaryTransformer` — confirm.
    pub max_len: u16,
}
38
39impl From<DictLayoutConstraints> for DictConstraints {
40 fn from(value: DictLayoutConstraints) -> Self {
41 DictConstraints {
42 max_bytes: value.max_bytes,
43 max_len: value.max_len as usize,
44 }
45 }
46}
47
48impl Default for DictLayoutConstraints {
49 fn default() -> Self {
50 Self {
51 max_bytes: 1024 * 1024,
52 max_len: u16::MAX,
53 }
54 }
55}
56
/// Options accepted by [`DictStrategy`].
#[derive(Clone, Default)]
pub struct DictLayoutOptions {
    /// Limits passed to the dictionary encoder for every dictionary it builds.
    pub constraints: DictLayoutConstraints,
}
61
/// A [`LayoutStrategy`] that writes a stream as dictionary layouts.
///
/// Delegates to `fallback` when the stream cannot or should not be
/// dictionary encoded.
#[derive(Clone)]
pub struct DictStrategy {
    /// Strategy used to write each group's stream of codes.
    codes: Arc<dyn LayoutStrategy>,
    /// Strategy used to write each group's dictionary values.
    values: Arc<dyn LayoutStrategy>,
    /// Strategy used when the dtype is unsupported, the stream is empty, or
    /// the first chunk does not compress to a dictionary encoding.
    fallback: Arc<dyn LayoutStrategy>,
    /// Dictionary limits applied while encoding.
    options: DictLayoutOptions,
}
74
75impl DictStrategy {
76 pub fn new<Codes: LayoutStrategy, Values: LayoutStrategy, Fallback: LayoutStrategy>(
77 codes: Codes,
78 values: Values,
79 fallback: Fallback,
80 options: DictLayoutOptions,
81 ) -> Self {
82 Self {
83 codes: Arc::new(codes),
84 values: Arc::new(values),
85 fallback: Arc::new(fallback),
86 options,
87 }
88 }
89}
90
#[async_trait]
impl LayoutStrategy for DictStrategy {
    /// Writes `stream` as one or more dictionary layouts, or delegates to the
    /// fallback strategy.
    ///
    /// The fallback is used when:
    /// * the stream's dtype is rejected by [`dict_layout_supported`],
    /// * the stream is empty, or
    /// * the first chunk, compressed with the default `BtrBlocksCompressor`,
    ///   does not come out dictionary encoded.
    ///
    /// Otherwise the stream is dictionary encoded and split into groups, each
    /// made of a codes stream and one values array. Every group becomes a
    /// `DictLayout`. A single group is returned directly; several are wrapped
    /// in a `ChunkedLayout`.
    ///
    /// # Errors
    ///
    /// Returns an error if reading the first chunk fails, if the trial
    /// compression fails, or if writing any codes or values child fails.
    async fn write_stream(
        &self,
        ctx: ArrayContext,
        segment_sink: SegmentSinkRef,
        stream: SendableSequentialStream,
        mut eof: SequencePointer,
        handle: Handle,
    ) -> VortexResult<LayoutRef> {
        // Unsupported dtype: hand the untouched stream to the fallback.
        if !dict_layout_supported(stream.dtype()) {
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, handle)
                .await;
        }

        let options = self.options.clone();
        let dtype = stream.dtype().clone();

        // Pull the first chunk to decide eligibility; `peek_first_chunk` hands back
        // a stream that still starts with that chunk.
        let (stream, first_chunk) = peek_first_chunk(stream).await?;
        let stream = SequentialStreamAdapter::new(dtype.clone(), stream).sendable();

        let should_fallback = match first_chunk {
            // Empty stream: nothing to dictionary encode.
            None => true,
            Some(chunk) => {
                // Trial-compress the first chunk; only proceed if the compressor
                // itself picked a dictionary encoding.
                let compressed = BtrBlocksCompressor::default().compress(&chunk)?;
                !compressed.is_encoding(DictEncoding.id())
            }
        };
        if should_fallback {
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, handle)
                .await;
        }

        // Turn the chunk stream into a stream of codes chunks interleaved with
        // values chunks (see `dict_encode_stream`).
        let dict_stream = dict_encode_stream(stream, options.constraints.into());

        // Regroup that stream into `(codes_stream, values_future)` pairs; each pair
        // becomes one child dict layout below.
        let runs = DictionaryTransformer::new(dict_stream);

        let dtype2 = dtype.clone();
        let child_layouts = stream! {
            pin_mut!(runs);

            while let Some((codes_stream, values_fut)) = runs.next().await {
                // Write this group's codes with the codes strategy, under its own
                // sequence pointer split off from `eof`.
                let codes = self.codes.clone();
                let codes_eof = eof.split_off();
                let ctx2 = ctx.clone();
                let segment_sink2 = segment_sink.clone();
                let codes_fut = handle.spawn_nested(move |h| async move {
                    codes.write_stream(
                        ctx2,
                        segment_sink2,
                        codes_stream.sendable(),
                        codes_eof,
                        h,
                    ).await
                });

                // Write this group's values with the values strategy. The values
                // arrive as a single future, wrapped into a one-item stream that
                // carries the original input dtype.
                let values = self.values.clone();
                let values_eof = eof.split_off();
                let ctx2 = ctx.clone();
                let segment_sink2 = segment_sink.clone();
                let dtype2 = dtype2.clone();
                let values_layout = handle.spawn_nested(move |h| async move {
                    values.write_stream(
                        ctx2,
                        segment_sink2,
                        SequentialStreamAdapter::new(dtype2, once(values_fut)).sendable(),
                        values_eof,
                        h,
                    ).await
                });

                // Yield a future that resolves once both halves of the group are
                // written, or fails as soon as either one fails.
                yield async move {
                    try_join!(codes_fut, values_layout)
                }.boxed();
            }
        };

        // `buffered(usize::MAX)` places no cap on how many group futures are in
        // flight; results come back in the order the groups were produced.
        let mut child_layouts = child_layouts
            .buffered(usize::MAX)
            .map(|result| {
                let (codes_layout, values_layout) = result?;
                Ok::<_, VortexError>(DictLayout::new(values_layout, codes_layout).into_layout())
            })
            .try_collect::<Vec<_>>()
            .await?;

        // A single group needs no chunked wrapper.
        if child_layouts.len() == 1 {
            return Ok(child_layouts.remove(0));
        }

        let row_count = child_layouts.iter().map(|child| child.row_count()).sum();
        Ok(ChunkedLayout::new(
            row_count,
            dtype,
            OwnedLayoutChildren::layout_children(child_layouts),
        )
        .into_layout())
    }

    /// Sum of the buffered bytes reported by the codes, values, and fallback strategies.
    fn buffered_bytes(&self) -> u64 {
        self.codes.buffered_bytes() + self.values.buffered_bytes() + self.fallback.buffered_bytes()
    }
}
206
/// One item of a dictionary-encoded stream: either codes or a set of values.
enum DictionaryChunk {
    /// An array of dictionary codes together with its sequence id.
    Codes((SequenceId, ArrayRef)),
    /// An array of dictionary values together with its sequence id.
    Values((SequenceId, ArrayRef)),
}
211
/// Boxed, `'static` stream of fallible [`DictionaryChunk`]s, as produced by `dict_encode_stream`.
type DictionaryStream = BoxStream<'static, VortexResult<DictionaryChunk>>;
213
214fn dict_encode_stream(
215 input: SendableSequentialStream,
216 constraints: DictConstraints,
217) -> DictionaryStream {
218 Box::pin(try_stream! {
219 let mut state = DictStreamState {
220 encoder: None,
221 constraints,
222 };
223
224 let input = input.peekable();
225 pin_mut!(input);
226
227 while let Some(item) = input.next().await {
228 let (sequence_id, chunk) = item?;
229
230 match input.as_mut().peek().await {
234 Some(_) => {
235 let mut labeler = DictChunkLabeler::new(sequence_id);
236 let chunks = state.encode(&mut labeler, chunk);
237 drop(labeler);
238 for dict_chunk in chunks {
239 yield dict_chunk;
240 }
241 }
242 None => {
243 let mut labeler = DictChunkLabeler::new(sequence_id);
245 let encoded = state.encode(&mut labeler, chunk);
246 let drained = state.drain_values(&mut labeler);
247 drop(labeler);
248 for dict_chunk in encoded.into_iter().chain(drained.into_iter()) {
249 yield dict_chunk;
250 }
251 }
252 }
253 }
254 })
255}
256
/// Encoder state carried across chunks by `dict_encode_stream`.
struct DictStreamState {
    /// The encoder currently accumulating a dictionary, or `None` when the next
    /// chunk must start a new one.
    encoder: Option<Box<dyn DictEncoder>>,
    /// Limits handed to every encoder that gets created.
    constraints: DictConstraints,
}
261
262impl DictStreamState {
263 fn encode(&mut self, labeler: &mut DictChunkLabeler, chunk: ArrayRef) -> Vec<DictionaryChunk> {
264 let mut res = Vec::new();
265 let mut to_be_encoded = Some(chunk);
266 while let Some(remaining) = to_be_encoded.take() {
267 match self.encoder.take() {
268 None => match start_encoding(&self.constraints, &remaining) {
269 EncodingState::Continue((encoder, encoded)) => {
270 res.push(labeler.codes(encoded));
271 self.encoder = Some(encoder);
272 }
273 EncodingState::Done((values, encoded, unencoded)) => {
274 res.push(labeler.codes(encoded));
275 res.push(labeler.values(values));
276 to_be_encoded = Some(unencoded);
277 }
278 },
279 Some(encoder) => match encode_chunk(encoder, &remaining) {
280 EncodingState::Continue((encoder, encoded)) => {
281 res.push(labeler.codes(encoded));
282 self.encoder = Some(encoder);
283 }
284 EncodingState::Done((values, encoded, unencoded)) => {
285 res.push(labeler.codes(encoded));
286 res.push(labeler.values(values));
287 to_be_encoded = Some(unencoded);
288 }
289 },
290 }
291 }
292 res
293 }
294
295 fn drain_values(&mut self, labeler: &mut DictChunkLabeler) -> Vec<DictionaryChunk> {
296 match self.encoder.as_mut() {
297 None => Vec::new(),
298 Some(encoder) => vec![labeler.values(encoder.reset())],
299 }
300 }
301}
302
/// Hands out sequence ids for the dictionary chunks produced from one input chunk.
struct DictChunkLabeler {
    /// Pointer obtained by descending from the input chunk's sequence id; advanced
    /// once per labelled chunk.
    sequence_pointer: SequencePointer,
}
306
307impl DictChunkLabeler {
308 fn new(starting_id: SequenceId) -> Self {
309 let sequence_pointer = starting_id.descend();
310 Self { sequence_pointer }
311 }
312
313 fn codes(&mut self, chunk: ArrayRef) -> DictionaryChunk {
314 DictionaryChunk::Codes((self.sequence_pointer.advance(), chunk))
315 }
316
317 fn values(&mut self, chunk: ArrayRef) -> DictionaryChunk {
318 DictionaryChunk::Values((self.sequence_pointer.advance(), chunk))
319 }
320}
321
/// A fallible `(SequenceId, ArrayRef)` pair: the item type of the chunk streams
/// and channels in this file.
type SequencedChunk = VortexResult<(SequenceId, ArrayRef)>;
323
/// Regroups a [`DictionaryStream`] into `(codes stream, values future)` pairs.
///
/// A group opens at the first codes chunk seen while no group is open, and
/// closes when a values chunk arrives.
struct DictionaryTransformer {
    /// Upstream codes/values chunks.
    input: DictionaryStream,
    /// Sender feeding the codes stream of the open group; `None` when no group is open.
    active_codes_tx: Option<kanal::AsyncSender<SequencedChunk>>,
    /// One-shot sender that delivers the open group's values, or an error.
    active_values_tx: Option<oneshot::Sender<SequencedChunk>>,
    /// A codes send that has been started but not yet completed; it is polled
    /// at the start of each `poll_next` iteration before more input is read.
    pending_send: Option<BoxFuture<'static, Result<(), kanal::SendError>>>,
}
330
331impl DictionaryTransformer {
332 fn new(input: DictionaryStream) -> Self {
333 Self {
334 input,
335 active_codes_tx: None,
336 active_values_tx: None,
337 pending_send: None,
338 }
339 }
340}
341
impl Stream for DictionaryTransformer {
    /// A codes stream for one group, paired with a future resolving to that group's values.
    type Item = (SendableSequentialStream, BoxFuture<'static, SequencedChunk>);

    /// Drives the regrouping state machine.
    ///
    /// Each loop iteration first finishes any stashed codes send, then pulls
    /// one item from `input`:
    /// * codes with no open group: open a group and return its pair,
    /// * codes with an open group: queue them onto that group's channel,
    /// * values: deliver them and close the group,
    /// * error or end of input: notify any open group and end the stream.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // Finish the in-flight codes send before reading more input.
            if let Some(mut send_fut) = self.pending_send.take() {
                match send_fut.poll_unpin(cx) {
                    Poll::Ready(Ok(())) => {
                        // Send completed; fall through to read more input.
                    }
                    Poll::Ready(Err(_)) => {
                        // The send failed (presumably the codes receiver is gone):
                        // close the group and report an error to the values side.
                        self.active_codes_tx = None;
                        if let Some(values_tx) = self.active_values_tx.take() {
                            drop(values_tx.send(Err(vortex_err!("values receiver dropped"))));
                        }
                    }
                    Poll::Pending => {
                        // Not done yet: put the future back and wait to be woken.
                        self.pending_send = Some(send_fut);
                        return Poll::Pending;
                    }
                }
            }

            match self.input.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(DictionaryChunk::Codes(codes)))) => {
                    if self.active_codes_tx.is_none() {
                        // No group open: start one with a capacity-1 codes channel
                        // and a one-shot channel for the values.
                        let (codes_tx, codes_rx) = kanal::bounded_async::<SequencedChunk>(1);
                        let (values_tx, values_rx) = oneshot::channel();

                        self.active_codes_tx = Some(codes_tx.clone());
                        self.active_values_tx = Some(values_tx);

                        // Stash the first send; it is driven at the top of the loop
                        // on the next call, after the pair has been returned.
                        self.pending_send =
                            Some(Box::pin(async move { codes_tx.send(Ok(codes)).await }));

                        // The codes stream is declared as non-nullable u16.
                        let codes_stream = SequentialStreamAdapter::new(
                            DType::Primitive(PType::U16, NonNullable),
                            codes_rx.into_stream().boxed(),
                        )
                        .sendable();

                        // A dropped values sender becomes an error; otherwise the
                        // delivered result is passed through.
                        let values_future = async move {
                            values_rx
                                .await
                                .map_err(|e| vortex_err!("values sender dropped: {}", e))
                                .flatten()
                        }
                        .boxed();

                        return Poll::Ready(Some((codes_stream, values_future)));
                    } else {
                        // Group already open: queue these codes on its channel.
                        if let Some(tx) = &self.active_codes_tx {
                            let tx = tx.clone();
                            self.pending_send =
                                Some(Box::pin(async move { tx.send(Ok(codes)).await }));
                        }
                    }
                }
                Poll::Ready(Some(Ok(DictionaryChunk::Values(values)))) => {
                    // Values complete the open group: deliver them, then drop the
                    // stored codes sender.
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Ok(values)));
                    }
                    self.active_codes_tx = None;
                }
                Poll::Ready(Some(Err(e))) => {
                    // Route the upstream error through the open group's values future.
                    // NOTE(review): when no group is open, `e` is discarded here and
                    // the stream simply ends.
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Err(e)));
                    }
                    self.active_codes_tx = None;
                    return Poll::Ready(None);
                }
                Poll::Ready(None) => {
                    // Input ended while a group was still waiting for values: fail it.
                    if let Some(values_tx) = self.active_values_tx.take() {
                        drop(values_tx.send(Err(vortex_err!("Incomplete dictionary group"))));
                    }
                    self.active_codes_tx = None;
                    return Poll::Ready(None);
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
436
437async fn peek_first_chunk(
438 mut stream: BoxStream<'static, SequencedChunk>,
439) -> VortexResult<(BoxStream<'static, SequencedChunk>, Option<ArrayRef>)> {
440 match stream.next().await {
441 None => Ok((stream.boxed(), None)),
442 Some(Err(e)) => Err(e),
443 Some(Ok((sequence_id, chunk))) => {
444 let chunk_clone = chunk.clone();
445 let reconstructed_stream =
446 once(async move { Ok((sequence_id, chunk_clone)) }).chain(stream);
447 Ok((reconstructed_stream.boxed(), Some(chunk)))
448 }
449 }
450}
451
452pub fn dict_layout_supported(dtype: &DType) -> bool {
453 matches!(
454 dtype,
455 DType::Primitive(..) | DType::Utf8(_) | DType::Binary(_)
456 )
457}
458
/// Protobuf metadata stored with a dict layout.
#[derive(prost::Message)]
pub struct DictLayoutMetadata {
    /// The `PType` of the codes, held as `i32` under a prost `enumeration`
    /// attribute. It is written through `set_codes_ptype` (see `new`).
    /// NOTE(review): the matching getter is presumably generated by the prost
    /// derive as well — confirm.
    #[prost(enumeration = "PType", tag = "1")]
    codes_ptype: i32,
}
465
466impl DictLayoutMetadata {
467 pub fn new(codes_ptype: PType) -> Self {
468 let mut metadata = Self::default();
469 metadata.set_codes_ptype(codes_ptype);
470 metadata
471 }
472}
473
/// Outcome of feeding one chunk to a dictionary encoder (see `encode_chunk`).
enum EncodingState {
    /// The whole chunk was encoded. Carries the encoder, still usable for the
    /// next chunk, and the encoded array.
    Continue((Box<dyn DictEncoder>, ArrayRef)),
    /// Only a prefix of the chunk was encoded. Carries, in order:
    /// `(values, encoded, unencoded)` — the values returned by resetting the
    /// encoder, the encoded prefix, and the slice of the chunk left unencoded.
    Done((ArrayRef, ArrayRef, ArrayRef)),
}
479
480fn start_encoding(constraints: &DictConstraints, chunk: &dyn Array) -> EncodingState {
481 let encoder = dict_encoder(chunk, constraints);
482 encode_chunk(encoder, chunk)
483}
484
485fn encode_chunk(mut encoder: Box<dyn DictEncoder>, chunk: &dyn Array) -> EncodingState {
486 let encoded = encoder.encode(chunk);
487 match remainder(chunk, encoded.len()) {
488 None => EncodingState::Continue((encoder, encoded)),
489 Some(unencoded) => EncodingState::Done((encoder.reset(), encoded, unencoded)),
490 }
491}
492
493fn remainder(array: &dyn Array, encoded_len: usize) -> Option<ArrayRef> {
494 (encoded_len < array.len()).then(|| array.slice(encoded_len..array.len()))
495}