1use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll, ready};
7
8use async_stream::try_stream;
9use async_trait::async_trait;
10use futures::channel::{mpsc, oneshot};
11use futures::stream::{BoxStream, once};
12use futures::{FutureExt, SinkExt, Stream, StreamExt, pin_mut, try_join};
13use vortex_array::{Array, ArrayContext, ArrayRef};
14use vortex_btrblocks::BtrBlocksCompressor;
15use vortex_dict::DictEncoding;
16use vortex_dict::builders::{DictConstraints, DictEncoder, dict_encoder};
17use vortex_dtype::{DType, PType};
18use vortex_error::{VortexExpect, VortexResult, VortexUnwrap, vortex_bail, vortex_err};
19
20use super::DictLayout;
21use crate::layouts::chunked::ChunkedLayout;
22use crate::segments::SequenceWriter;
23use crate::sequence::{SequenceId, SequencePointer};
24use crate::{
25 IntoLayout, LayoutRef, LayoutStrategy, OwnedLayoutChildren, SendableSequentialStream,
26 SequentialStreamAdapter, SequentialStreamExt, TaskExecutor, TaskExecutorExt as _,
27};
28
29#[derive(Clone)]
30pub struct DictLayoutOptions {
31 pub constraints: DictConstraints,
32 pub encoded_buffer_size: usize,
34}
35
36impl Default for DictLayoutOptions {
37 fn default() -> Self {
38 Self {
39 constraints: DictConstraints {
40 max_bytes: 1024 * 1024,
41 max_len: u16::MAX as usize,
42 },
43 encoded_buffer_size: 8,
44 }
45 }
46}
47
48#[derive(Clone)]
54pub struct DictStrategy<Codes, Values, Fallback> {
55 codes: Codes,
56 values: Values,
57 fallback: Fallback,
58 options: DictLayoutOptions,
59 executor: Arc<dyn TaskExecutor>,
60}
61
62impl<Codes, Values, Fallback> DictStrategy<Codes, Values, Fallback>
63where
64 Codes: LayoutStrategy,
65 Values: LayoutStrategy,
66 Fallback: LayoutStrategy,
67{
68 pub fn new(
69 codes: Codes,
70 values: Values,
71 fallback: Fallback,
72 options: DictLayoutOptions,
73 executor: Arc<dyn TaskExecutor>,
74 ) -> Self {
75 Self {
76 codes,
77 values,
78 fallback,
79 options,
80 executor,
81 }
82 }
83}
84
85#[async_trait]
86impl<Codes, Values, Fallback> LayoutStrategy for DictStrategy<Codes, Values, Fallback>
87where
88 Codes: LayoutStrategy,
89 Values: LayoutStrategy,
90 Fallback: LayoutStrategy,
91{
92 async fn write_stream(
93 &self,
94 ctx: &ArrayContext,
95 sequence_writer: SequenceWriter,
96 stream: SendableSequentialStream,
97 ) -> VortexResult<LayoutRef> {
98 if !dict_layout_supported(stream.dtype()) {
99 return self
100 .fallback
101 .write_stream(ctx, sequence_writer, stream)
102 .await;
103 }
104 let ctx = ctx.clone();
105 let options = self.options.clone();
106 let dtype = stream.dtype().clone();
107 let executor = self.executor.clone();
108 let (stream, first_chunk) = peek_first_chunk(stream).await?;
110 let stream = SequentialStreamAdapter::new(dtype.clone(), stream).sendable();
111
112 let should_fallback = match first_chunk {
113 None => true, Some(chunk) => {
115 let compressed = BtrBlocksCompressor.compress(&chunk)?;
116 !compressed.is_encoding(DictEncoding.id())
117 }
118 };
119 if should_fallback {
120 return self
122 .fallback
123 .write_stream(&ctx, sequence_writer.clone(), stream)
124 .await;
125 }
126
127 let mut dict_stream = dict_encode_stream(stream, options.constraints);
131
132 let (mut encoded_tx, encoded_rx) = mpsc::channel(options.encoded_buffer_size);
134 let encode_handle = executor.spawn({
135 async move {
136 while let Some(item) = dict_stream.next().await {
137 encoded_tx
138 .send(item)
139 .await
140 .map_err(|e| vortex_err!("rx dropped: {}", e))?;
141 }
142 Ok(())
143 }
144 .boxed()
145 });
146
147 let dtype_clone = dtype.clone();
150 let child_layouts_fut = async move {
151 let mut children = Vec::new();
152 let mut runs = DictEncodedRuns::new(Box::pin(encoded_rx));
153 while let Some((codes_stream, values_future)) = runs.next_run().await {
154 let (codes_stream, first_chunk) = peek_first_chunk(codes_stream.boxed()).await?;
155 let codes_dtype = match first_chunk {
156 None => break,
158 Some(chunk) => chunk.dtype().clone(),
159 };
160 let codes_layout = self
161 .codes
162 .write_stream(
163 &ctx,
164 sequence_writer.clone(),
165 SequentialStreamAdapter::new(codes_dtype, codes_stream).sendable(),
166 )
167 .await?;
168 let values_layout = self
169 .values
170 .write_stream(
171 &ctx,
172 sequence_writer.clone(),
173 SequentialStreamAdapter::new(dtype_clone.clone(), once(values_future))
174 .sendable(),
175 )
176 .await?;
177 children.push(DictLayout::new(values_layout, codes_layout).into_layout());
178 }
179 Ok(children)
180 };
181
182 let (mut children, _) = try_join!(child_layouts_fut, encode_handle)?;
184
185 if children.len() == 1 {
186 return Ok(children.remove(0));
187 }
188
189 let row_count = children.iter().map(|child| child.row_count()).sum();
190 Ok(ChunkedLayout::new(
191 row_count,
192 dtype,
193 OwnedLayoutChildren::layout_children(children),
194 )
195 .into_layout())
196 }
197}
198
199enum DictionaryChunk {
200 Codes((SequenceId, ArrayRef)),
201 Values((SequenceId, ArrayRef)),
202}
203
204type DictionaryStream = BoxStream<'static, VortexResult<DictionaryChunk>>;
205
206fn dict_encode_stream(
207 input: SendableSequentialStream,
208 constraints: DictConstraints,
209) -> DictionaryStream {
210 Box::pin(try_stream! {
211 let mut state = DictStreamState {
212 encoder: None,
213 constraints,
214 };
215 let input = input.peekable();
216 pin_mut!(input);
217 while let Some(item) = input.as_mut().next().await {
218 let (sequence_id, chunk) = item?;
219 match input.as_mut().peek().await {
223 Some(_) => {
224 let mut labeler = DictChunkLabeler::new(sequence_id);
225 let chunks = state.encode(&mut labeler, chunk);
226 drop(labeler);
227 for dict_chunk in chunks {
228 yield dict_chunk?;
229 }
230 }
231 None => {
232 let mut labeler = DictChunkLabeler::new(sequence_id);
234 let encoded = state.encode(&mut labeler, chunk);
235 let drained = state.drain_values(&mut labeler);
236 drop(labeler);
237 for dict_chunk in encoded.into_iter().chain(drained.into_iter()) {
238 yield dict_chunk?;
239 }
240 }
241 }
242 }
243 })
244}
245
246struct DictStreamState {
247 encoder: Option<Box<dyn DictEncoder>>,
248 constraints: DictConstraints,
249}
250
251impl DictStreamState {
252 fn encode(
253 &mut self,
254 labeler: &mut DictChunkLabeler,
255 chunk: ArrayRef,
256 ) -> Vec<VortexResult<DictionaryChunk>> {
257 self.try_encode(labeler, chunk)
258 .unwrap_or_else(|e| vec![Err(e)])
259 }
260
261 fn try_encode(
262 &mut self,
263 labeler: &mut DictChunkLabeler,
264 chunk: ArrayRef,
265 ) -> VortexResult<Vec<VortexResult<DictionaryChunk>>> {
266 let mut res = Vec::new();
267 let mut to_be_encoded = Some(chunk);
268 while let Some(remaining) = to_be_encoded.take() {
269 match self.encoder.take() {
270 None => match start_encoding(&self.constraints, &remaining)? {
271 EncodingState::Continue((encoder, encoded)) => {
272 res.push(Ok(labeler.codes(encoded)));
273 self.encoder = Some(encoder);
274 }
275 EncodingState::Done((values, encoded, unencoded)) => {
276 res.push(Ok(labeler.codes(encoded)));
277 res.push(Ok(labeler.values(values)));
278 to_be_encoded = Some(unencoded);
279 }
280 },
281 Some(encoder) => match encode_chunk(encoder, &remaining)? {
282 EncodingState::Continue((encoder, encoded)) => {
283 res.push(Ok(labeler.codes(encoded)));
284 self.encoder = Some(encoder);
285 }
286 EncodingState::Done((values, encoded, unencoded)) => {
287 res.push(Ok(labeler.codes(encoded)));
288 res.push(Ok(labeler.values(values)));
289 to_be_encoded = Some(unencoded);
290 }
291 },
292 }
293 }
294 Ok(res)
295 }
296
297 fn drain_values(
298 &mut self,
299 labeler: &mut DictChunkLabeler,
300 ) -> Vec<VortexResult<DictionaryChunk>> {
301 match self.encoder.as_mut() {
302 None => Vec::new(),
303 Some(encoder) => vec![encoder.values().map(|val| labeler.values(val))],
304 }
305 }
306}
307
308struct DictChunkLabeler {
309 sequence_pointer: SequencePointer,
310}
311
312impl DictChunkLabeler {
313 fn new(starting_id: SequenceId) -> Self {
314 let sequence_pointer = starting_id.descend();
315 Self { sequence_pointer }
316 }
317
318 fn codes(&mut self, chunk: ArrayRef) -> DictionaryChunk {
319 DictionaryChunk::Codes((self.sequence_pointer.advance(), chunk))
320 }
321
322 fn values(&mut self, chunk: ArrayRef) -> DictionaryChunk {
323 DictionaryChunk::Values((self.sequence_pointer.advance(), chunk))
324 }
325}
326
327type SequencedChunk = VortexResult<(SequenceId, ArrayRef)>;
328
329struct DictEncodedRuns {
330 input: Option<oneshot::Receiver<Option<DictionaryStream>>>,
331}
332
333impl DictEncodedRuns {
334 fn new(input: DictionaryStream) -> Self {
335 let (tx, rx) = oneshot::channel();
336 tx.send(Some(input))
337 .map_err(|_input| vortex_err!("just created rx"))
338 .vortex_unwrap();
339 Self { input: Some(rx) }
340 }
341
342 async fn next_run(
343 &mut self,
344 ) -> Option<(
345 DictEncodedRunStream,
346 impl Future<Output = SequencedChunk> + use<>,
347 )> {
348 let Ok(Some(input)) = self.input.take()?.await else {
350 return None;
352 };
353 let (input_tx, input_rx) = oneshot::channel();
354 self.input = Some(input_rx);
355
356 let (values_tx, values_rx) = oneshot::channel();
357 let values_future = async {
358 values_rx
359 .await
360 .unwrap_or_else(|_| vortex_bail!("sender dropped"))
361 };
362
363 let codes_stream = DictEncodedRunStream {
364 input: Some(input),
365 input_tx: Some(input_tx),
366 values_tx: Some(values_tx),
367 };
368
369 Some((codes_stream, values_future))
370 }
371}
372
373struct DictEncodedRunStream {
374 input: Option<DictionaryStream>,
375 input_tx: Option<oneshot::Sender<Option<DictionaryStream>>>,
376 values_tx: Option<oneshot::Sender<SequencedChunk>>,
377}
378
379impl Stream for DictEncodedRunStream {
380 type Item = SequencedChunk;
381
382 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
383 let poll_result = {
384 let Some(stream) = self.input.as_mut() else {
385 return Poll::Ready(None);
386 };
387 ready!(stream.poll_next_unpin(cx))
388 };
389
390 match poll_result {
391 Some(Ok(DictionaryChunk::Codes(item))) => Poll::Ready(Some(Ok(item))),
392 Some(Ok(DictionaryChunk::Values(item))) => {
393 self.send_values(item);
394 self.send_back_input_stream();
395 Poll::Ready(None)
396 }
397 Some(Err(e)) => Poll::Ready(Some(Err(e))),
398 None => {
399 self.send_back_input_stream();
400 Poll::Ready(None)
401 }
402 }
403 }
404}
405
406impl DictEncodedRunStream {
407 fn send_values(&mut self, item: (SequenceId, ArrayRef)) {
408 let _ = self
410 .values_tx
411 .take()
412 .vortex_expect("must not be polled after returning None")
413 .send(Ok(item));
414 }
415
416 fn send_back_input_stream(&mut self) {
417 let _ = self
419 .input_tx
420 .take()
421 .vortex_expect("input already sent")
422 .send(self.input.take());
423 }
424}
425
426impl Drop for DictEncodedRunStream {
427 fn drop(&mut self) {
428 if let Some(tx) = self.input_tx.take() {
429 let _ = tx.send(self.input.take());
430 }
431 }
432}
433
434async fn peek_first_chunk(
435 mut stream: BoxStream<'static, SequencedChunk>,
436) -> VortexResult<(BoxStream<'static, SequencedChunk>, Option<ArrayRef>)> {
437 match stream.next().await {
438 None => Ok((stream.boxed(), None)),
439 Some(Err(e)) => Err(e),
440 Some(Ok((sequence_id, chunk))) => {
441 let chunk_clone = chunk.clone();
442 let reconstructed_stream =
443 once(async move { Ok((sequence_id, chunk_clone)) }).chain(stream);
444 Ok((reconstructed_stream.boxed(), Some(chunk)))
445 }
446 }
447}
448
449pub fn dict_layout_supported(dtype: &DType) -> bool {
450 matches!(
451 dtype,
452 DType::Primitive(..) | DType::Utf8(_) | DType::Binary(_)
453 )
454}
455
456#[derive(prost::Message)]
457pub struct DictLayoutMetadata {
458 #[prost(enumeration = "PType", tag = "1")]
459 codes_ptype: i32,
461}
462
463impl DictLayoutMetadata {
464 pub fn new(codes_ptype: PType) -> Self {
465 let mut metadata = Self::default();
466 metadata.set_codes_ptype(codes_ptype);
467 metadata
468 }
469}
470
471enum EncodingState {
472 Continue((Box<dyn DictEncoder>, ArrayRef)),
473 Done((ArrayRef, ArrayRef, ArrayRef)),
475}
476
477fn start_encoding(constraints: &DictConstraints, chunk: &dyn Array) -> VortexResult<EncodingState> {
478 let encoder = dict_encoder(chunk, constraints)?;
479 encode_chunk(encoder, chunk)
480}
481
482fn encode_chunk(
483 mut encoder: Box<dyn DictEncoder>,
484 chunk: &dyn Array,
485) -> VortexResult<EncodingState> {
486 let encoded = encoder.encode(chunk)?;
487 Ok(match remainder(chunk, encoded.len()) {
488 None => EncodingState::Continue((encoder, encoded)),
489 Some(unencoded) => EncodingState::Done((encoder.values()?, encoded, unencoded)),
490 })
491}
492
493fn remainder(array: &dyn Array, encoded_len: usize) -> Option<ArrayRef> {
494 (encoded_len < array.len()).then(|| array.slice(encoded_len, array.len()))
495}