1use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use async_stream::{stream, try_stream};
9use async_trait::async_trait;
10use futures::future::BoxFuture;
11use futures::stream::{BoxStream, once};
12use futures::{FutureExt, Stream, StreamExt, TryStreamExt, pin_mut, try_join};
13use vortex_array::{Array, ArrayContext, ArrayRef};
14use vortex_btrblocks::BtrBlocksCompressor;
15use vortex_dict::DictEncoding;
16use vortex_dict::builders::{DictConstraints, DictEncoder, dict_encoder};
17use vortex_dtype::Nullability::NonNullable;
18use vortex_dtype::{DType, PType};
19use vortex_error::{VortexError, VortexResult, vortex_err};
20use vortex_io::kanal_ext::KanalExt;
21use vortex_io::runtime::Handle;
22
23use crate::layouts::chunked::ChunkedLayout;
24use crate::layouts::dict::DictLayout;
25use crate::segments::SegmentSinkRef;
26use crate::sequence::{
27 SendableSequentialStream, SequenceId, SequencePointer, SequentialStream,
28 SequentialStreamAdapter, SequentialStreamExt,
29};
30use crate::{IntoLayout, LayoutRef, LayoutStrategy, OwnedLayoutChildren};
31
/// User-facing size limits for a single dictionary, converted into the encoder's
/// `DictConstraints` before encoding starts.
#[derive(Clone)]
pub struct DictLayoutConstraints {
    /// Forwarded unchanged as `DictConstraints::max_bytes` (presumably a byte budget for the
    /// dictionary values — confirm against the encoder).
    pub max_bytes: usize,
    /// Forwarded (widened to `usize`) as `DictConstraints::max_len`; bounded by `u16`.
    pub max_len: u16,
}
38
39impl From<DictLayoutConstraints> for DictConstraints {
40 fn from(value: DictLayoutConstraints) -> Self {
41 DictConstraints {
42 max_bytes: value.max_bytes,
43 max_len: value.max_len as usize,
44 }
45 }
46}
47
48impl Default for DictLayoutConstraints {
49 fn default() -> Self {
50 Self {
51 max_bytes: 1024 * 1024,
52 max_len: u16::MAX,
53 }
54 }
55}
56
57#[derive(Clone, Default)]
58pub struct DictLayoutOptions {
59 pub constraints: DictLayoutConstraints,
60}
61
62#[derive(Clone)]
68pub struct DictStrategy {
69 codes: Arc<dyn LayoutStrategy>,
70 values: Arc<dyn LayoutStrategy>,
71 fallback: Arc<dyn LayoutStrategy>,
72 options: DictLayoutOptions,
73}
74
75impl DictStrategy {
76 pub fn new<Codes: LayoutStrategy, Values: LayoutStrategy, Fallback: LayoutStrategy>(
77 codes: Codes,
78 values: Values,
79 fallback: Fallback,
80 options: DictLayoutOptions,
81 ) -> Self {
82 Self {
83 codes: Arc::new(codes),
84 values: Arc::new(values),
85 fallback: Arc::new(fallback),
86 options,
87 }
88 }
89}
90
91#[async_trait]
92impl LayoutStrategy for DictStrategy {
93 async fn write_stream(
94 &self,
95 ctx: ArrayContext,
96 segment_sink: SegmentSinkRef,
97 stream: SendableSequentialStream,
98 mut eof: SequencePointer,
99 handle: Handle,
100 ) -> VortexResult<LayoutRef> {
101 if !dict_layout_supported(stream.dtype()) {
103 return self
104 .fallback
105 .write_stream(ctx, segment_sink, stream, eof, handle)
106 .await;
107 }
108
109 let options = self.options.clone();
110 let dtype = stream.dtype().clone();
111
112 let (stream, first_chunk) = peek_first_chunk(stream).await?;
114 let stream = SequentialStreamAdapter::new(dtype.clone(), stream).sendable();
115
116 let should_fallback = match first_chunk {
117 None => true, Some(chunk) => {
119 let compressed = BtrBlocksCompressor::default().compress(&chunk)?;
120 !compressed.is_encoding(DictEncoding.id())
121 }
122 };
123 if should_fallback {
124 return self
126 .fallback
127 .write_stream(ctx, segment_sink, stream, eof, handle)
128 .await;
129 }
130
131 let dict_stream = dict_encode_stream(stream, options.constraints.into());
135
136 let runs = DictionaryTransformer::new(dict_stream);
139
140 let dtype2 = dtype.clone();
141 let child_layouts = stream! {
142 pin_mut!(runs);
143
144 while let Some((codes_stream, values_fut)) = runs.next().await {
145 let codes = self.codes.clone();
146 let codes_eof = eof.split_off();
147 let ctx2 = ctx.clone();
148 let segment_sink2 = segment_sink.clone();
149 let codes_fut = handle.spawn_nested(move |h| async move {
150 codes.write_stream(
151 ctx2,
152 segment_sink2,
153 codes_stream.sendable(),
154 codes_eof,
155 h,
156 ).await
157 });
158
159 let values = self.values.clone();
160 let values_eof = eof.split_off();
161 let ctx2 = ctx.clone();
162 let segment_sink2 = segment_sink.clone();
163 let dtype2 = dtype2.clone();
164 let values_layout = handle.spawn_nested(move |h| async move {
165 values.write_stream(
166 ctx2,
167 segment_sink2,
168 SequentialStreamAdapter::new(dtype2, once(values_fut)).sendable(),
169 values_eof,
170 h,
171 ).await
172 });
173
174 yield async move {
175 try_join!(codes_fut, values_layout)
176 }.boxed();
177 }
178 };
179
180 let mut child_layouts = child_layouts
181 .buffered(usize::MAX)
182 .map(|result| {
183 let (codes_layout, values_layout) = result?;
184 Ok::<_, VortexError>(DictLayout::new(values_layout, codes_layout).into_layout())
185 })
186 .try_collect::<Vec<_>>()
187 .await?;
188
189 if child_layouts.len() == 1 {
190 return Ok(child_layouts.remove(0));
191 }
192
193 let row_count = child_layouts.iter().map(|child| child.row_count()).sum();
194 Ok(ChunkedLayout::new(
195 row_count,
196 dtype,
197 OwnedLayoutChildren::layout_children(child_layouts),
198 )
199 .into_layout())
200 }
201
202 fn buffered_bytes(&self) -> u64 {
203 self.codes.buffered_bytes() + self.values.buffered_bytes() + self.fallback.buffered_bytes()
204 }
205}
206
207enum DictionaryChunk {
208 Codes((SequenceId, ArrayRef)),
209 Values((SequenceId, ArrayRef)),
210}
211
212type DictionaryStream = BoxStream<'static, VortexResult<DictionaryChunk>>;
213
214fn dict_encode_stream(
215 input: SendableSequentialStream,
216 constraints: DictConstraints,
217) -> DictionaryStream {
218 Box::pin(try_stream! {
219 let mut state = DictStreamState {
220 encoder: None,
221 constraints,
222 };
223
224 let input = input.peekable();
225 pin_mut!(input);
226
227 while let Some(item) = input.next().await {
228 let (sequence_id, chunk) = item?;
229
230 match input.as_mut().peek().await {
234 Some(_) => {
235 let mut labeler = DictChunkLabeler::new(sequence_id);
236 let chunks = state.encode(&mut labeler, chunk);
237 drop(labeler);
238 for dict_chunk in chunks {
239 yield dict_chunk?;
240 }
241 }
242 None => {
243 let mut labeler = DictChunkLabeler::new(sequence_id);
245 let encoded = state.encode(&mut labeler, chunk);
246 let drained = state.drain_values(&mut labeler);
247 drop(labeler);
248 for dict_chunk in encoded.into_iter().chain(drained.into_iter()) {
249 yield dict_chunk?;
250 }
251 }
252 }
253 }
254 })
255}
256
257struct DictStreamState {
258 encoder: Option<Box<dyn DictEncoder>>,
259 constraints: DictConstraints,
260}
261
262impl DictStreamState {
263 fn encode(
264 &mut self,
265 labeler: &mut DictChunkLabeler,
266 chunk: ArrayRef,
267 ) -> Vec<VortexResult<DictionaryChunk>> {
268 self.try_encode(labeler, chunk)
269 .unwrap_or_else(|e| vec![Err(e)])
270 }
271
272 fn try_encode(
273 &mut self,
274 labeler: &mut DictChunkLabeler,
275 chunk: ArrayRef,
276 ) -> VortexResult<Vec<VortexResult<DictionaryChunk>>> {
277 let mut res = Vec::new();
278 let mut to_be_encoded = Some(chunk);
279 while let Some(remaining) = to_be_encoded.take() {
280 match self.encoder.take() {
281 None => match start_encoding(&self.constraints, &remaining)? {
282 EncodingState::Continue((encoder, encoded)) => {
283 res.push(Ok(labeler.codes(encoded)));
284 self.encoder = Some(encoder);
285 }
286 EncodingState::Done((values, encoded, unencoded)) => {
287 res.push(Ok(labeler.codes(encoded)));
288 res.push(Ok(labeler.values(values)));
289 to_be_encoded = Some(unencoded);
290 }
291 },
292 Some(encoder) => match encode_chunk(encoder, &remaining)? {
293 EncodingState::Continue((encoder, encoded)) => {
294 res.push(Ok(labeler.codes(encoded)));
295 self.encoder = Some(encoder);
296 }
297 EncodingState::Done((values, encoded, unencoded)) => {
298 res.push(Ok(labeler.codes(encoded)));
299 res.push(Ok(labeler.values(values)));
300 to_be_encoded = Some(unencoded);
301 }
302 },
303 }
304 }
305 Ok(res)
306 }
307
308 fn drain_values(
309 &mut self,
310 labeler: &mut DictChunkLabeler,
311 ) -> Vec<VortexResult<DictionaryChunk>> {
312 match self.encoder.as_mut() {
313 None => Vec::new(),
314 Some(encoder) => vec![encoder.values().map(|val| labeler.values(val))],
315 }
316 }
317}
318
319struct DictChunkLabeler {
320 sequence_pointer: SequencePointer,
321}
322
323impl DictChunkLabeler {
324 fn new(starting_id: SequenceId) -> Self {
325 let sequence_pointer = starting_id.descend();
326 Self { sequence_pointer }
327 }
328
329 fn codes(&mut self, chunk: ArrayRef) -> DictionaryChunk {
330 DictionaryChunk::Codes((self.sequence_pointer.advance(), chunk))
331 }
332
333 fn values(&mut self, chunk: ArrayRef) -> DictionaryChunk {
334 DictionaryChunk::Values((self.sequence_pointer.advance(), chunk))
335 }
336}
337
338type SequencedChunk = VortexResult<(SequenceId, ArrayRef)>;
339
340struct DictionaryTransformer {
341 input: DictionaryStream,
342 active_codes_tx: Option<kanal::AsyncSender<SequencedChunk>>,
343 active_values_tx: Option<oneshot::Sender<SequencedChunk>>,
344 pending_send: Option<BoxFuture<'static, Result<(), kanal::SendError>>>,
345}
346
347impl DictionaryTransformer {
348 fn new(input: DictionaryStream) -> Self {
349 Self {
350 input,
351 active_codes_tx: None,
352 active_values_tx: None,
353 pending_send: None,
354 }
355 }
356}
357
358impl Stream for DictionaryTransformer {
359 type Item = (SendableSequentialStream, BoxFuture<'static, SequencedChunk>);
360
361 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
362 loop {
363 if let Some(mut send_fut) = self.pending_send.take() {
365 match send_fut.poll_unpin(cx) {
366 Poll::Ready(Ok(())) => {
367 }
369 Poll::Ready(Err(_)) => {
370 self.active_codes_tx = None;
372 if let Some(values_tx) = self.active_values_tx.take() {
373 let _ = values_tx.send(Err(vortex_err!("values receiver dropped")));
374 }
375 }
376 Poll::Pending => {
377 self.pending_send = Some(send_fut);
379 return Poll::Pending;
380 }
381 }
382 }
383
384 match self.input.poll_next_unpin(cx) {
385 Poll::Ready(Some(Ok(DictionaryChunk::Codes(codes)))) => {
386 if self.active_codes_tx.is_none() {
387 let (codes_tx, codes_rx) = kanal::bounded_async::<SequencedChunk>(1);
389 let (values_tx, values_rx) = oneshot::channel();
390
391 self.active_codes_tx = Some(codes_tx.clone());
392 self.active_values_tx = Some(values_tx);
393
394 self.pending_send =
396 Some(Box::pin(async move { codes_tx.send(Ok(codes)).await }));
397
398 let codes_stream = SequentialStreamAdapter::new(
400 DType::Primitive(PType::U16, NonNullable),
401 codes_rx.into_stream().boxed(),
402 )
403 .sendable();
404
405 let values_future = async move {
406 values_rx
407 .await
408 .map_err(|e| vortex_err!("values sender dropped: {}", e))
409 .flatten()
410 }
411 .boxed();
412
413 return Poll::Ready(Some((codes_stream, values_future)));
414 } else {
415 if let Some(tx) = &self.active_codes_tx {
417 let tx = tx.clone();
418 self.pending_send =
419 Some(Box::pin(async move { tx.send(Ok(codes)).await }));
420 }
421 }
422 }
423 Poll::Ready(Some(Ok(DictionaryChunk::Values(values)))) => {
424 if let Some(values_tx) = self.active_values_tx.take() {
426 let _ = values_tx.send(Ok(values));
427 }
428 self.active_codes_tx = None; }
430 Poll::Ready(Some(Err(e))) => {
431 if let Some(values_tx) = self.active_values_tx.take() {
433 let _ = values_tx.send(Err(e));
434 }
435 self.active_codes_tx = None;
436 return Poll::Ready(None);
438 }
439 Poll::Ready(None) => {
440 if let Some(values_tx) = self.active_values_tx.take() {
442 let _ = values_tx.send(Err(vortex_err!("Incomplete dictionary group")));
443 }
444 self.active_codes_tx = None;
445 return Poll::Ready(None);
446 }
447 Poll::Pending => return Poll::Pending,
448 }
449 }
450 }
451}
452
453async fn peek_first_chunk(
454 mut stream: BoxStream<'static, SequencedChunk>,
455) -> VortexResult<(BoxStream<'static, SequencedChunk>, Option<ArrayRef>)> {
456 match stream.next().await {
457 None => Ok((stream.boxed(), None)),
458 Some(Err(e)) => Err(e),
459 Some(Ok((sequence_id, chunk))) => {
460 let chunk_clone = chunk.clone();
461 let reconstructed_stream =
462 once(async move { Ok((sequence_id, chunk_clone)) }).chain(stream);
463 Ok((reconstructed_stream.boxed(), Some(chunk)))
464 }
465 }
466}
467
468pub fn dict_layout_supported(dtype: &DType) -> bool {
469 matches!(
470 dtype,
471 DType::Primitive(..) | DType::Utf8(_) | DType::Binary(_)
472 )
473}
474
475#[derive(prost::Message)]
476pub struct DictLayoutMetadata {
477 #[prost(enumeration = "PType", tag = "1")]
478 codes_ptype: i32,
480}
481
482impl DictLayoutMetadata {
483 pub fn new(codes_ptype: PType) -> Self {
484 let mut metadata = Self::default();
485 metadata.set_codes_ptype(codes_ptype);
486 metadata
487 }
488}
489
490enum EncodingState {
491 Continue((Box<dyn DictEncoder>, ArrayRef)),
492 Done((ArrayRef, ArrayRef, ArrayRef)),
494}
495
496fn start_encoding(constraints: &DictConstraints, chunk: &dyn Array) -> VortexResult<EncodingState> {
497 let encoder = dict_encoder(chunk, constraints)?;
498 encode_chunk(encoder, chunk)
499}
500
501fn encode_chunk(
502 mut encoder: Box<dyn DictEncoder>,
503 chunk: &dyn Array,
504) -> VortexResult<EncodingState> {
505 let encoded = encoder.encode(chunk)?;
506 Ok(match remainder(chunk, encoded.len()) {
507 None => EncodingState::Continue((encoder, encoded)),
508 Some(unencoded) => EncodingState::Done((encoder.values()?, encoded, unencoded)),
509 })
510}
511
512fn remainder(array: &dyn Array, encoded_len: usize) -> Option<ArrayRef> {
513 (encoded_len < array.len()).then(|| array.slice(encoded_len..array.len()))
514}