vortex_layout/layouts/repartition.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::collections::VecDeque;
use std::sync::Arc;

use async_stream::try_stream;
use async_trait::async_trait;
use futures::StreamExt as _;
use futures::pin_mut;
use vortex_array::Array;
use vortex_array::ArrayContext;
use vortex_array::ArrayRef;
use vortex_array::IntoArray;
use vortex_array::arrays::ChunkedArray;
use vortex_dtype::DType;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_io::runtime::Handle;

use crate::LayoutRef;
use crate::LayoutStrategy;
use crate::segments::SegmentSinkRef;
use crate::sequence::SendableSequentialStream;
use crate::sequence::SequencePointer;
use crate::sequence::SequentialStreamAdapter;
use crate::sequence::SequentialStreamExt;

#[derive(Clone)]
pub struct RepartitionWriterOptions {
    /// The minimum uncompressed size in bytes for a block.
    pub block_size_minimum: u64,
    /// Each block (except possibly the last) contains a multiple of this many rows.
    pub block_len_multiple: usize,
    /// Optional target uncompressed size in bytes for a block.
    ///
    /// The repartition writer attempts to produce partitions with this uncompressed size. This is
    /// only a best-effort attempt: the partitions may be arbitrarily larger or smaller. Reasons for
    /// this include:
    ///
    /// 1. The size of one element may not perfectly divide the target size, resulting in blocks
    ///    that are either too large or too small.
    ///
    /// 2. Variable-length types are expensive to pack due to the need to read each element length.
    ///
    /// 3. View types are expensive to pack due to each view sharing an arbitrary slice of data.
    pub block_size_target: Option<u64>,
    pub canonicalize: bool,
}
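// A minimal construction sketch (the values below are illustrative, not
// defaults taken from this crate):
//
//     let options = RepartitionWriterOptions {
//         block_size_minimum: 1 << 20,      // emit nothing smaller than ~1 MiB
//         block_len_multiple: 8192,         // row counts come in multiples of 8192
//         block_size_target: Some(8 << 20), // best-effort ~8 MiB uncompressed blocks
//         canonicalize: true,
//     };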

impl RepartitionWriterOptions {
    /// Compute the effective block length for a given [`DType`].
    ///
    /// For fixed-width types where [`DType::element_size`] is known and large enough that
    /// `element_size * block_len_multiple` would exceed `block_size_target`, this reduces the
    /// block length so each block stays close to the target byte size.
    fn effective_block_len(&self, dtype: &DType) -> usize {
        let Some(block_size_target) = self.block_size_target else {
            return self.block_len_multiple;
        };
        match dtype.element_size() {
            Some(elem_size) if elem_size > 0 => {
                // `div_ceil` rounds up so we overshoot `block_size_target`, preventing
                // `write_stream` from combining adjacent 0.9 MiB chunks into one 1.8 MiB chunk.
                let max_rows = usize::try_from(block_size_target.div_ceil(elem_size as u64))
                    .unwrap_or(usize::MAX);
                self.block_len_multiple.min(max_rows).max(1)
            }
            _ => self.block_len_multiple,
        }
    }
}
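// Worked example (the same numbers are exercised by the unit tests below): with
// `block_size_target` = 1 MiB and a fixed-width element of 8000 bytes, e.g.
// FixedSizeList(f64, 1000), max_rows = div_ceil(1_048_576, 8000) = 132, so the
// effective block length is min(8192, 132) = 132.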

/// Repartition a stream of arrays into blocks.
///
/// Each emitted block (except the last) is at least `block_size_minimum` bytes and contains a
/// multiple of `block_len_multiple` rows.
#[derive(Clone)]
pub struct RepartitionStrategy {
    child: Arc<dyn LayoutStrategy>,
    options: RepartitionWriterOptions,
}

impl RepartitionStrategy {
    pub fn new<S: LayoutStrategy>(child: S, options: RepartitionWriterOptions) -> Self {
        Self {
            child: Arc::new(child),
            options,
        }
    }
}
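// A minimal composition sketch (mirroring the tests below): repartitioning
// layers on top of a child strategy, which writes the blocks emitted here.
//
//     let strategy = RepartitionStrategy::new(
//         ChunkedLayoutStrategy::new(FlatLayoutStrategy::default()),
//         options,
//     );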

#[async_trait]
impl LayoutStrategy for RepartitionStrategy {
    async fn write_stream(
        &self,
        ctx: ArrayContext,
        segment_sink: SegmentSinkRef,
        stream: SendableSequentialStream,
        eof: SequencePointer,
        handle: Handle,
    ) -> VortexResult<LayoutRef> {
        // TODO(os): spawn stream below like:
        // canon_stream = stream.map(async {to_canonical}).map(spawn).buffered(parallelism)
        let dtype = stream.dtype().clone();
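        // Optionally canonicalize each incoming chunk up front so the slicing
        // and re-chunking below operate on canonical arrays.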
        let stream = if self.options.canonicalize {
            SequentialStreamAdapter::new(
                dtype.clone(),
                stream.map(|chunk| {
                    let (sequence_id, chunk) = chunk?;
                    VortexResult::Ok((sequence_id, chunk.to_canonical()?.into_array()))
                }),
            )
            .sendable()
        } else {
            stream
        };

        let dtype_clone = dtype.clone();
        let options = self.options.clone();

        // For fixed-width types with large per-element sizes, reduce the block_len_multiple
        // so that each block targets block_size_target bytes rather than producing oversized
        // segments.
        let block_len = options.effective_block_len(&dtype);
        let block_size_minimum = options.block_size_minimum;

        let repartitioned_stream = try_stream! {
            let canonical_stream = stream.peekable();
            pin_mut!(canonical_stream);

            let mut chunks = ChunksBuffer::new(block_size_minimum, block_len);
            while let Some(chunk) = canonical_stream.as_mut().next().await {
                let (sequence_id, chunk) = chunk?;
                let mut sequence_pointer = sequence_id.descend();
                let mut offset = 0;
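                // Slice the incoming chunk into pieces of at most `block_len`
                // rows and buffer each piece.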
                while offset < chunk.len() {
                    let end = (offset + block_len).min(chunk.len());
                    let sliced = chunk.slice(offset..end)?;
                    chunks.push_back(sliced);
                    offset = end;

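                    // Enough bytes and rows are buffered: drain a multiple of
                    // `block_len` rows and emit them as one canonical chunk.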
                    if chunks.have_enough() {
                        let output_chunks = chunks.collect_exact_blocks()?;
                        assert!(!output_chunks.is_empty());
                        let chunked =
                            ChunkedArray::try_new(output_chunks, dtype_clone.clone())?;
                        if !chunked.is_empty() {
                            yield (
                                sequence_pointer.advance(),
                                chunked.to_canonical()?.into_array(),
                            )
                        }
                    }
                }
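                // If no further chunks are coming, flush whatever remains
                // buffered; the final block may be smaller than `block_len` rows.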
                if canonical_stream.as_mut().peek().await.is_none() {
                    let to_flush = ChunkedArray::try_new(
                        chunks.data.drain(..).map(|(arr, _)| arr).collect(),
                        dtype_clone.clone(),
                    )?;
                    if !to_flush.is_empty() {
                        yield (
                            sequence_pointer.advance(),
                            to_flush.to_canonical()?.into_array(),
                        )
                    }
                }
            }
        };

        self.child
            .write_stream(
                ctx,
                segment_sink,
                SequentialStreamAdapter::new(dtype, repartitioned_stream).sendable(),
                eof,
                handle,
            )
            .await
    }

    fn buffered_bytes(&self) -> u64 {
        // TODO(os): we should probably add this strategy's own buffered bytes on top.
        // For now it is better not to add them at all: the buffered arrays are
        // potentially sliced and uncompressed, so they would overestimate the actual
        // bytes that end up in the file when flushed.
        self.child.buffered_bytes()
    }
}

struct ChunksBuffer {
    /// Each entry stores the chunk and the `nbytes()` snapshot taken at push time.
    /// This avoids accounting mismatches when interior-mutable arrays (e.g. `SharedArray`)
    /// change their reported size after being pushed.
    data: VecDeque<(ArrayRef, u64)>,
    row_count: usize,
    nbytes: u64,
    block_size_minimum: u64,
    block_len_multiple: usize,
}

impl ChunksBuffer {
    fn new(block_size_minimum: u64, block_len_multiple: usize) -> Self {
        Self {
            data: Default::default(),
            row_count: 0,
            nbytes: 0,
            block_size_minimum,
            block_len_multiple,
        }
    }

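    /// Returns true once the buffer holds at least `block_size_minimum` bytes
    /// and at least `block_len_multiple` rows.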
    fn have_enough(&self) -> bool {
        self.nbytes >= self.block_size_minimum && self.row_count >= self.block_len_multiple
    }

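    /// Drains whole blocks from the front of the buffer, returning chunks
    /// totaling exactly `(row_count / block_len_multiple) * block_len_multiple`
    /// rows. A chunk straddling the final block boundary is split and its tail
    /// pushed back onto the buffer.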
    fn collect_exact_blocks(&mut self) -> VortexResult<Vec<ArrayRef>> {
        let nblocks = self.row_count / self.block_len_multiple;
        let mut res = Vec::with_capacity(self.data.len());
        let mut remaining = nblocks * self.block_len_multiple;
        while remaining > 0 {
            let (chunk, _) = self
                .pop_front()
                .vortex_expect("must have at least one chunk");
            let len = chunk.len();

            if len > remaining {
                let left = chunk.slice(0..remaining)?;
                let right = chunk.slice(remaining..len)?;
                self.push_front(right);
                res.push(left);
                remaining = 0;
            } else {
                res.push(chunk);
                remaining -= len;
            }
        }
        Ok(res)
    }

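    /// Buffers a chunk along with the `nbytes()` snapshot taken now, so that
    /// `pop_front` subtracts the same value even if the array's reported size
    /// changes later (see the field comment on `data`).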
    fn push_back(&mut self, chunk: ArrayRef) {
        let nb = chunk.nbytes();
        self.row_count += chunk.len();
        self.nbytes += nb;
        self.data.push_back((chunk, nb));
    }

    fn push_front(&mut self, chunk: ArrayRef) {
        let nb = chunk.nbytes();
        self.row_count += chunk.len();
        self.nbytes += nb;
        self.data.push_front((chunk, nb));
    }

    fn pop_front(&mut self) -> Option<(ArrayRef, u64)> {
        let res = self.data.pop_front();
        if let Some((chunk, nb)) = res.as_ref() {
            self.row_count -= chunk.len();
            self.nbytes -= nb;
        }
        res
    }
}
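// Behavior sketch with small illustrative numbers: with `block_len_multiple` = 4
// and a zero byte minimum, pushing two 3-row chunks makes `have_enough()` true at
// 6 rows; `collect_exact_blocks()` then returns chunks totaling 4 rows (the first
// chunk plus a 1-row slice of the second) and leaves the remaining 2 rows buffered.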

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::Array;
    use vortex_array::ArrayContext;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::FixedSizeListArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::arrays::SharedArray;
    use vortex_array::validity::Validity;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType;
    use vortex_error::VortexResult;
    use vortex_io::runtime::single::block_on;

    use super::*;
    use crate::LayoutStrategy;
    use crate::layouts::chunked::writer::ChunkedLayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::segments::TestSegments;
    use crate::sequence::SequenceId;
    use crate::sequence::SequentialArrayStreamExt;

    const ONE_MEG: u64 = 1 << 20;

    #[test]
    fn effective_block_len_small_elements() {
        // f64 = 8 bytes/element. 8192 * 8 = 64 KiB << 1 MiB, so no reduction.
        let dtype = DType::Primitive(PType::F64, NonNullable);
        let options = RepartitionWriterOptions {
            block_size_minimum: 0,
            block_len_multiple: 8192,
            block_size_target: Some(ONE_MEG),
            canonicalize: false,
        };
        assert_eq!(options.effective_block_len(&dtype), 8192);
    }

    #[test]
    fn effective_block_len_large_elements() {
        // FixedSizeList(f64, 1000) = 8000 bytes/element.
        // div_ceil(1 MiB, 8000) = 132, so effective block len = min(8192, 132) = 132.
        let dtype = DType::FixedSizeList(
            Arc::new(DType::Primitive(PType::F64, NonNullable)),
            1000,
            NonNullable,
        );
        let options = RepartitionWriterOptions {
            block_size_minimum: 0,
            block_len_multiple: 8192,
            block_size_target: Some(ONE_MEG),
            canonicalize: false,
        };
        assert_eq!(options.effective_block_len(&dtype), 132);
    }

    #[test]
    fn effective_block_len_variable_width() {
        // Utf8 has no known element_size, so block_len_multiple is unchanged.
        let dtype = DType::Utf8(NonNullable);
        let options = RepartitionWriterOptions {
            block_size_minimum: 0,
            block_len_multiple: 8192,
            block_size_target: Some(ONE_MEG),
            canonicalize: false,
        };
        assert_eq!(options.effective_block_len(&dtype), 8192);
    }

    #[test]
    fn effective_block_len_very_large_elements() {
        // FixedSizeList(f64, 1_000_000) = 8_000_000 bytes/element.
        // div_ceil(1 MiB, 8_000_000) = 1, so effective block len = min(8192, 1) = 1.
        let dtype = DType::FixedSizeList(
            Arc::new(DType::Primitive(PType::F64, NonNullable)),
            1_000_000,
            NonNullable,
        );
        let options = RepartitionWriterOptions {
            block_size_minimum: 0,
            block_len_multiple: 8192,
            block_size_target: Some(ONE_MEG),
            canonicalize: false,
        };
        assert_eq!(options.effective_block_len(&dtype), 1);
    }

    #[test]
    fn repartition_large_element_type_produces_small_blocks() -> VortexResult<()> {
        // Create a FixedSizeList(f64, 1000) array with 1000 lists.
        // Each list is 8000 bytes, so 1000 lists = 8,000,000 bytes total.
        // With block_size_target = 1 MiB, effective block_len = 132.
        // We expect the repartition to produce blocks of 132 rows each.
        let list_size: u32 = 1000;
        let num_lists: usize = 1000;
        let total_elements = list_size as usize * num_lists;

        let elements = PrimitiveArray::from_iter((0..total_elements).map(|i| i as f64));
        let fsl = FixedSizeListArray::new(
            elements.into_array(),
            list_size,
            Validity::NonNullable,
            num_lists,
        );

        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());
        let (ptr, eof) = SequenceId::root().split();

        let child = ChunkedLayoutStrategy::new(FlatLayoutStrategy::default());
        let strategy = RepartitionStrategy::new(
            child,
            RepartitionWriterOptions {
                block_size_minimum: 0,
                block_len_multiple: 8192,
                block_size_target: Some(ONE_MEG),
                canonicalize: false,
            },
        );

        let stream = fsl.into_array().to_array_stream().sequenced(ptr);
        let layout =
            block_on(|handle| strategy.write_stream(ctx, segments.clone(), stream, eof, handle))?;

        // The layout should be a ChunkedLayout with multiple children.
        // With 1000 rows and effective block_len = 132:
        //   - 7 full blocks of 132 rows = 924 rows
        //   - 1 remainder block of 76 rows
        //   - Total: 8 blocks, 1000 rows
        assert_eq!(layout.row_count(), num_lists as u64);

        // All non-last children should have 132 rows.
        let nchildren = layout.nchildren();
        assert!(nchildren > 1, "expected multiple chunks, got {nchildren}");

        for i in 0..nchildren - 1 {
            let child = layout.child(i)?;
            assert_eq!(
                child.row_count(),
                132,
                "chunk {i} has {} rows, expected 132",
                child.row_count()
            );
        }

        // Last child gets the remainder.
        let last = layout.child(nchildren - 1)?;
        assert_eq!(last.row_count(), 1000 - 132 * (nchildren as u64 - 1));

        Ok(())
    }

    #[test]
    fn repartition_small_element_type_unchanged() -> VortexResult<()> {
        // For f64 (8 bytes/element), effective block_len stays at 8192.
        // With 10000 elements and block_size_minimum = 0, we get one block of 8192
        // and one remainder of 1808.
        let num_elements: usize = 10000;
        let elements = PrimitiveArray::from_iter((0..num_elements).map(|i| i as f64));

        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());
        let (ptr, eof) = SequenceId::root().split();

        let child = ChunkedLayoutStrategy::new(FlatLayoutStrategy::default());
        let strategy = RepartitionStrategy::new(
            child,
            RepartitionWriterOptions {
                block_size_minimum: 0,
                block_len_multiple: 8192,
                block_size_target: Some(ONE_MEG),
                canonicalize: false,
            },
        );

        let stream = elements.into_array().to_array_stream().sequenced(ptr);
        let layout =
            block_on(|handle| strategy.write_stream(ctx, segments.clone(), stream, eof, handle))?;

        assert_eq!(layout.row_count(), num_elements as u64);
        assert_eq!(layout.nchildren(), 2);
        assert_eq!(layout.child(0)?.row_count(), 8192);
        assert_eq!(layout.child(1)?.row_count(), 1808);

        Ok(())
    }

    /// Regression test: `SharedArray` slices sharing an `Arc<Mutex<SharedState>>` can
    /// transition from Source to Cached when any one of them is canonicalized. This caused
    /// `pop_front` to panic with `attempt to subtract with overflow` because the buffer's
    /// running `nbytes` total was accumulated with the smaller Source-era values while
    /// `pop_front` subtracted the larger Cached-era values.
    #[test]
    fn chunks_buffer_pop_front_no_panic_after_shared_execution() -> VortexResult<()> {
        let n = 20_000usize;
        let block_len = 10_000usize;

        let constant = ConstantArray::new(42i64, n);
        let shared = SharedArray::new(constant.into_array());
        let shared_handle = shared.clone();
        let arr = shared.into_array();

        let s1 = arr.slice(0..block_len)?;
        let s2 = arr.slice(block_len..n)?;

        let mut buf = ChunksBuffer::new(0, block_len);
        buf.push_back(s1);
        buf.push_back(s2);

        let _output = buf.pop_front().unwrap();

        // Transition SharedState from Source to Cached for ALL slices sharing this Arc.
        shared_handle.get_or_compute(|source| source.to_canonical())?;

        // Before the fix this panicked with "attempt to subtract with overflow".
        let _s2 = buf.pop_front().unwrap();
        assert_eq!(buf.nbytes, 0);
        assert_eq!(buf.row_count, 0);

        Ok(())
    }
}