1use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll, ready};
7
8use arcref::ArcRef;
9use async_stream::try_stream;
10use futures::channel::{mpsc, oneshot};
11use futures::stream::{BoxStream, once};
12use futures::{FutureExt, SinkExt, Stream, StreamExt, pin_mut, try_join};
13use vortex_array::{Array, ArrayContext, ArrayRef};
14use vortex_btrblocks::BtrBlocksCompressor;
15use vortex_dict::DictEncoding;
16use vortex_dict::builders::{DictConstraints, DictEncoder, dict_encoder};
17use vortex_dtype::{DType, PType};
18use vortex_error::{VortexExpect, VortexResult, VortexUnwrap, vortex_bail, vortex_err};
19
20use super::DictLayout;
21use crate::layouts::chunked::ChunkedLayout;
22use crate::segments::SequenceWriter;
23use crate::sequence::{SequenceId, SequencePointer};
24use crate::{
25 IntoLayout, LayoutStrategy, OwnedLayoutChildren, SendableLayoutFuture,
26 SendableSequentialStream, SequentialStreamAdapter, SequentialStreamExt, TaskExecutor,
27 TaskExecutorExt as _,
28};
29
30#[derive(Clone)]
31pub struct DictLayoutOptions {
32 pub constraints: DictConstraints,
33 pub encoded_buffer_size: usize,
35}
36
37impl Default for DictLayoutOptions {
38 fn default() -> Self {
39 Self {
40 constraints: DictConstraints {
41 max_bytes: 1024 * 1024,
42 max_len: u16::MAX as usize,
43 },
44 encoded_buffer_size: 8,
45 }
46 }
47}
48
49pub struct DictStrategy {
55 codes: ArcRef<dyn LayoutStrategy>,
56 values: ArcRef<dyn LayoutStrategy>,
57 fallback: ArcRef<dyn LayoutStrategy>,
58 options: DictLayoutOptions,
59 executor: Arc<dyn TaskExecutor>,
60}
61
62impl DictStrategy {
63 pub fn new(
64 codes: ArcRef<dyn LayoutStrategy>,
65 values: ArcRef<dyn LayoutStrategy>,
66 fallback: ArcRef<dyn LayoutStrategy>,
67 options: DictLayoutOptions,
68 executor: Arc<dyn TaskExecutor>,
69 ) -> Self {
70 Self {
71 codes,
72 values,
73 fallback,
74 options,
75 executor,
76 }
77 }
78}
79
impl LayoutStrategy for DictStrategy {
    /// Writes `stream` as dictionary-encoded layouts.
    ///
    /// The stream is handed to the fallback strategy when its dtype is rejected by
    /// [`dict_layout_supported`], when it is empty, or when BtrBlocks does not choose a
    /// dictionary encoding for its first chunk.
    ///
    /// Otherwise the chunks are dictionary encoded in a task spawned on `executor`. Each
    /// resulting dictionary becomes a [`DictLayout`] whose codes and values are written with
    /// the `codes` and `values` strategies. A single dictionary layout is returned directly;
    /// several are wrapped in a [`ChunkedLayout`].
    fn write_stream(
        &self,
        ctx: &ArrayContext,
        sequence_writer: SequenceWriter,
        stream: SendableSequentialStream,
    ) -> SendableLayoutFuture {
        // Dtypes that cannot be dictionary encoded go straight to the fallback strategy.
        if !dict_layout_supported(stream.dtype()) {
            return self.fallback.write_stream(ctx, sequence_writer, stream);
        }
        // Clone everything the returned future has to own.
        let codes = self.codes.clone();
        let values = self.values.clone();
        let fallback = self.fallback.clone();
        let ctx = ctx.clone();
        let options = self.options.clone();
        let dtype = stream.dtype().clone();
        let executor = self.executor.clone();
        Box::pin(async move {
            // Look at the first chunk without losing it: the returned stream yields it again.
            let (stream, first_chunk) = peek_first_chunk(stream).await?;
            let stream = SequentialStreamAdapter::new(dtype.clone(), stream).sendable();

            // Heuristic: dictionary-encode only if BtrBlocks picks a dictionary encoding for
            // the first chunk. An empty stream always falls back.
            let should_fallback = match first_chunk {
                None => true,
                Some(chunk) => {
                    let compressed = BtrBlocksCompressor.compress(&chunk)?;
                    !compressed.is_encoding(DictEncoding.id())
                }
            };
            if should_fallback {
                return fallback
                    .write_stream(&ctx, sequence_writer.clone(), stream)
                    .await;
            }

            // Yields codes chunks, with a values chunk each time a dictionary is closed:
            // codes, codes, values, codes, ..., values.
            let mut dict_stream = dict_encode_stream(stream, options.constraints);

            // Run the encoding in a spawned task. Its items go through a bounded channel, so
            // the task can only run a limited distance ahead of the layout writers.
            let (mut encoded_tx, encoded_rx) = mpsc::channel(options.encoded_buffer_size);
            let encode_handle = executor.spawn({
                async move {
                    while let Some(item) = dict_stream.next().await {
                        encoded_tx
                            .send(item)
                            .await
                            .map_err(|e| vortex_err!("rx dropped: {}", e))?;
                    }
                    Ok(())
                }
                .boxed()
            });

            let dtype_clone = dtype.clone();
            // Consume the encoded items one dictionary ("run") at a time: write the run's
            // codes, then its values, and collect one `DictLayout` per run.
            let child_layouts_fut = async move {
                let mut children = Vec::new();
                let mut runs = DictEncodedRuns::new(Box::pin(encoded_rx));
                while let Some((codes_stream, values_future)) = runs.next_run().await {
                    // The codes dtype is taken from the first codes chunk. A run without any
                    // codes means the encoded input is exhausted.
                    let (codes_stream, first_chunk) =
                        peek_first_chunk(codes_stream.boxed()).await?;
                    let codes_dtype = match first_chunk {
                        None => break,
                        Some(chunk) => chunk.dtype().clone(),
                    };
                    let codes_layout = codes
                        .write_stream(
                            &ctx,
                            sequence_writer.clone(),
                            SequentialStreamAdapter::new(codes_dtype, codes_stream).sendable(),
                        )
                        .await?;
                    // `values_future` resolves when the codes stream reaches this run's values
                    // item. If the codes writer stopped before that, the run stream was dropped
                    // and the future fails with "sender dropped".
                    let values_layout = values
                        .write_stream(
                            &ctx,
                            sequence_writer.clone(),
                            SequentialStreamAdapter::new(dtype_clone.clone(), once(values_future))
                                .sendable(),
                        )
                        .await?;
                    children.push(DictLayout::new(values_layout, codes_layout).into_layout());
                }
                Ok(children)
            };

            // Wait for both the layout writers and the spawned encoding task; the first error
            // from either fails the whole write.
            let (mut children, _) = try_join!(child_layouts_fut, encode_handle)?;

            if children.len() == 1 {
                return Ok(children.remove(0));
            }

            let row_count = children.iter().map(|child| child.row_count()).sum();
            Ok(ChunkedLayout::new(
                row_count,
                dtype,
                OwnedLayoutChildren::layout_children(children),
            )
            .into_layout())
        })
    }
}
187
/// An item produced by `dict_encode_stream`.
enum DictionaryChunk {
    /// A chunk of dictionary codes, labelled with its sequence id.
    Codes((SequenceId, ArrayRef)),
    /// The values of a closed dictionary, emitted after that dictionary's codes.
    Values((SequenceId, ArrayRef)),
}

/// A boxed stream of dictionary items (or errors).
type DictionaryStream = BoxStream<'static, VortexResult<DictionaryChunk>>;
194
195fn dict_encode_stream(
196 input: SendableSequentialStream,
197 constraints: DictConstraints,
198) -> DictionaryStream {
199 Box::pin(try_stream! {
200 let mut state = DictStreamState {
201 encoder: None,
202 constraints,
203 };
204 let input = input.peekable();
205 pin_mut!(input);
206 while let Some(item) = input.as_mut().next().await {
207 let (sequence_id, chunk) = item?;
208 match input.as_mut().peek().await {
212 Some(_) => {
213 let mut labeler = DictChunkLabeler::new(sequence_id);
214 let chunks = state.encode(&mut labeler, chunk);
215 drop(labeler);
216 for dict_chunk in chunks {
217 yield dict_chunk?;
218 }
219 }
220 None => {
221 let mut labeler = DictChunkLabeler::new(sequence_id);
223 let encoded = state.encode(&mut labeler, chunk);
224 let drained = state.drain_values(&mut labeler);
225 drop(labeler);
226 for dict_chunk in encoded.into_iter().chain(drained.into_iter()) {
227 yield dict_chunk?;
228 }
229 }
230 }
231 }
232 })
233}
234
235struct DictStreamState {
236 encoder: Option<Box<dyn DictEncoder>>,
237 constraints: DictConstraints,
238}
239
240impl DictStreamState {
241 fn encode(
242 &mut self,
243 labeler: &mut DictChunkLabeler,
244 chunk: ArrayRef,
245 ) -> Vec<VortexResult<DictionaryChunk>> {
246 self.try_encode(labeler, chunk)
247 .unwrap_or_else(|e| vec![Err(e)])
248 }
249
250 fn try_encode(
251 &mut self,
252 labeler: &mut DictChunkLabeler,
253 chunk: ArrayRef,
254 ) -> VortexResult<Vec<VortexResult<DictionaryChunk>>> {
255 let mut res = Vec::new();
256 let mut to_be_encoded = Some(chunk);
257 while let Some(remaining) = to_be_encoded.take() {
258 match self.encoder.take() {
259 None => match start_encoding(&self.constraints, &remaining)? {
260 EncodingState::Continue((encoder, encoded)) => {
261 res.push(Ok(labeler.codes(encoded)));
262 self.encoder = Some(encoder);
263 }
264 EncodingState::Done((values, encoded, unencoded)) => {
265 res.push(Ok(labeler.codes(encoded)));
266 res.push(Ok(labeler.values(values)));
267 to_be_encoded = Some(unencoded);
268 }
269 },
270 Some(encoder) => match encode_chunk(encoder, &remaining)? {
271 EncodingState::Continue((encoder, encoded)) => {
272 res.push(Ok(labeler.codes(encoded)));
273 self.encoder = Some(encoder);
274 }
275 EncodingState::Done((values, encoded, unencoded)) => {
276 res.push(Ok(labeler.codes(encoded)));
277 res.push(Ok(labeler.values(values)));
278 to_be_encoded = Some(unencoded);
279 }
280 },
281 }
282 }
283 Ok(res)
284 }
285
286 fn drain_values(
287 &mut self,
288 labeler: &mut DictChunkLabeler,
289 ) -> Vec<VortexResult<DictionaryChunk>> {
290 match self.encoder.as_mut() {
291 None => Vec::new(),
292 Some(encoder) => vec![encoder.values().map(|val| labeler.values(val))],
293 }
294 }
295}
296
297struct DictChunkLabeler {
298 sequence_pointer: SequencePointer,
299}
300
301impl DictChunkLabeler {
302 fn new(starting_id: SequenceId) -> Self {
303 let sequence_pointer = starting_id.descend();
304 Self { sequence_pointer }
305 }
306
307 fn codes(&mut self, chunk: ArrayRef) -> DictionaryChunk {
308 DictionaryChunk::Codes((self.sequence_pointer.advance(), chunk))
309 }
310
311 fn values(&mut self, chunk: ArrayRef) -> DictionaryChunk {
312 DictionaryChunk::Values((self.sequence_pointer.advance(), chunk))
313 }
314}
315
316type SequencedChunk = VortexResult<(SequenceId, ArrayRef)>;
317
318struct DictEncodedRuns {
319 input: Option<oneshot::Receiver<Option<DictionaryStream>>>,
320}
321
322impl DictEncodedRuns {
323 fn new(input: DictionaryStream) -> Self {
324 let (tx, rx) = oneshot::channel();
325 tx.send(Some(input))
326 .map_err(|_input| vortex_err!("just created rx"))
327 .vortex_unwrap();
328 Self { input: Some(rx) }
329 }
330
331 async fn next_run(
332 &mut self,
333 ) -> Option<(
334 DictEncodedRunStream,
335 impl Future<Output = SequencedChunk> + use<>,
336 )> {
337 let Ok(Some(input)) = self.input.take()?.await else {
339 return None;
341 };
342 let (input_tx, input_rx) = oneshot::channel();
343 self.input = Some(input_rx);
344
345 let (values_tx, values_rx) = oneshot::channel();
346 let values_future = async {
347 values_rx
348 .await
349 .unwrap_or_else(|_| vortex_bail!("sender dropped"))
350 };
351
352 let codes_stream = DictEncodedRunStream {
353 input: Some(input),
354 input_tx: Some(input_tx),
355 values_tx: Some(values_tx),
356 };
357
358 Some((codes_stream, values_future))
359 }
360}
361
362struct DictEncodedRunStream {
363 input: Option<DictionaryStream>,
364 input_tx: Option<oneshot::Sender<Option<DictionaryStream>>>,
365 values_tx: Option<oneshot::Sender<SequencedChunk>>,
366}
367
368impl Stream for DictEncodedRunStream {
369 type Item = SequencedChunk;
370
371 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
372 let poll_result = {
373 let Some(stream) = self.input.as_mut() else {
374 return Poll::Ready(None);
375 };
376 ready!(stream.poll_next_unpin(cx))
377 };
378
379 match poll_result {
380 Some(Ok(DictionaryChunk::Codes(item))) => Poll::Ready(Some(Ok(item))),
381 Some(Ok(DictionaryChunk::Values(item))) => {
382 self.send_values(item);
383 self.send_back_input_stream();
384 Poll::Ready(None)
385 }
386 Some(Err(e)) => Poll::Ready(Some(Err(e))),
387 None => {
388 self.send_back_input_stream();
389 Poll::Ready(None)
390 }
391 }
392 }
393}
394
395impl DictEncodedRunStream {
396 fn send_values(&mut self, item: (SequenceId, ArrayRef)) {
397 let _ = self
399 .values_tx
400 .take()
401 .vortex_expect("must not be polled after returning None")
402 .send(Ok(item));
403 }
404
405 fn send_back_input_stream(&mut self) {
406 let _ = self
408 .input_tx
409 .take()
410 .vortex_expect("input already sent")
411 .send(self.input.take());
412 }
413}
414
415impl Drop for DictEncodedRunStream {
416 fn drop(&mut self) {
417 if let Some(tx) = self.input_tx.take() {
418 let _ = tx.send(self.input.take());
419 }
420 }
421}
422
423async fn peek_first_chunk(
424 mut stream: BoxStream<'static, SequencedChunk>,
425) -> VortexResult<(BoxStream<'static, SequencedChunk>, Option<ArrayRef>)> {
426 match stream.next().await {
427 None => Ok((stream.boxed(), None)),
428 Some(Err(e)) => Err(e),
429 Some(Ok((sequence_id, chunk))) => {
430 let chunk_clone = chunk.clone();
431 let reconstructed_stream =
432 once(async move { Ok((sequence_id, chunk_clone)) }).chain(stream);
433 Ok((reconstructed_stream.boxed(), Some(chunk)))
434 }
435 }
436}
437
438pub fn dict_layout_supported(dtype: &DType) -> bool {
439 matches!(
440 dtype,
441 DType::Primitive(..) | DType::Utf8(_) | DType::Binary(_)
442 )
443}
444
445#[derive(prost::Message)]
446pub struct DictLayoutMetadata {
447 #[prost(enumeration = "PType", tag = "1")]
448 codes_ptype: i32,
450}
451
452impl DictLayoutMetadata {
453 pub fn new(codes_ptype: PType) -> Self {
454 let mut metadata = Self::default();
455 metadata.set_codes_ptype(codes_ptype);
456 metadata
457 }
458}
459
/// Outcome of feeding one chunk to a dictionary encoder.
enum EncodingState {
    /// The whole chunk was encoded: `(encoder, codes)`. The encoder stays open for more rows.
    Continue((Box<dyn DictEncoder>, ArrayRef)),
    /// The encoder stopped before the end of the chunk: `(values, codes, unencoded_rows)`.
    /// The encoder has been consumed to produce `values`.
    Done((ArrayRef, ArrayRef, ArrayRef)),
}
465
466fn start_encoding(constraints: &DictConstraints, chunk: &dyn Array) -> VortexResult<EncodingState> {
467 let encoder = dict_encoder(chunk, constraints)?;
468 encode_chunk(encoder, chunk)
469}
470
471fn encode_chunk(
472 mut encoder: Box<dyn DictEncoder>,
473 chunk: &dyn Array,
474) -> VortexResult<EncodingState> {
475 let encoded = encoder.encode(chunk)?;
476 Ok(match remainder(chunk, encoded.len())? {
477 None => EncodingState::Continue((encoder, encoded)),
478 Some(unencoded) => EncodingState::Done((encoder.values()?, encoded, unencoded)),
479 })
480}
481
482fn remainder(array: &dyn Array, encoded_len: usize) -> VortexResult<Option<ArrayRef>> {
483 (encoded_len < array.len())
484 .then(|| array.slice(encoded_len, array.len()))
485 .transpose()
486}