// vortex_layout/layouts/repartition.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::collections::VecDeque;
5use std::sync::Arc;
6
7use async_stream::try_stream;
8use async_trait::async_trait;
9use futures::StreamExt as _;
10use futures::pin_mut;
11use vortex_array::ArrayContext;
12use vortex_array::ArrayRef;
13use vortex_array::Canonical;
14use vortex_array::IntoArray;
15use vortex_array::VortexSessionExecute;
16use vortex_array::arrays::ChunkedArray;
17use vortex_array::dtype::DType;
18use vortex_error::VortexExpect;
19use vortex_error::VortexResult;
20use vortex_session::VortexSession;
21
22use crate::LayoutRef;
23use crate::LayoutStrategy;
24use crate::segments::SegmentSinkRef;
25use crate::sequence::SendableSequentialStream;
26use crate::sequence::SequencePointer;
27use crate::sequence::SequentialStreamAdapter;
28use crate::sequence::SequentialStreamExt;
29
/// Options controlling how a repartitioning writer buffers and re-chunks incoming arrays.
#[derive(Clone)]
pub struct RepartitionWriterOptions {
    /// The minimum uncompressed size in bytes for a block.
    pub block_size_minimum: u64,
    /// The multiple of the number of rows in each block.
    pub block_len_multiple: usize,
    /// Optional target uncompressed size in bytes for a block.
    ///
    /// The repartition writer attempts to produce partitions with this uncompressed size. This is
    /// only a best effort attempt: the partitions may be arbitrarily larger or smaller. Reasons for
    /// this include:
    ///
    /// 1. The size of one element may not perfectly divide the target size, resulting in blocks
    ///    that are either too large or too small.
    ///
    /// 2. Variable length types are expensive to pack due to the need to read each element length.
    ///
    /// 3. View types are expensive to pack due to each view sharing an arbitrary slice of data.
    pub block_size_target: Option<u64>,
    /// Whether each incoming chunk is canonicalized before being repartitioned.
    pub canonicalize: bool,
}
51
52impl RepartitionWriterOptions {
53    /// Compute the effective block length for a given [`DType`].
54    ///
55    /// For fixed-width types where [`DType::element_size`] is known and large enough that
56    /// `element_size * block_len_multiple` would exceed `block_size_target`, this reduces the
57    /// block length so each block stays close to the target byte size.
58    fn effective_block_len(&self, dtype: &DType) -> usize {
59        let Some(block_size_target) = self.block_size_target else {
60            return self.block_len_multiple;
61        };
62        match dtype.element_size() {
63            Some(elem_size) if elem_size > 0 => {
64                // `div_ceil` ensures we overshoot the block_size_target; therefore preventing
65                // `write_stream` from combining adjacent 0.9 MiB chunks into one 1.8 MiB chunk.
66                let max_rows = usize::try_from(block_size_target.div_ceil(elem_size as u64))
67                    .unwrap_or(usize::MAX);
68                self.block_len_multiple.min(max_rows).max(1)
69            }
70            _ => self.block_len_multiple,
71        }
72    }
73}
74
/// Repartition a stream of arrays into blocks.
///
/// Each emitted block (except the last) is at least `block_size_minimum` bytes and contains a
/// multiple of `block_len_multiple` rows.
#[derive(Clone)]
pub struct RepartitionStrategy {
    /// The downstream strategy that writes the repartitioned stream.
    child: Arc<dyn LayoutStrategy>,
    /// Block sizing configuration; see [`RepartitionWriterOptions`].
    options: RepartitionWriterOptions,
}
84
85impl RepartitionStrategy {
86    pub fn new<S: LayoutStrategy>(child: S, options: RepartitionWriterOptions) -> Self {
87        Self {
88            child: Arc::new(child),
89            options,
90        }
91    }
92}
93
#[async_trait]
impl LayoutStrategy for RepartitionStrategy {
    /// Re-chunk `stream` into blocks of the effective block length (at least
    /// `block_size_minimum` bytes each, except the final flush, which may be smaller),
    /// then delegate the repartitioned stream to the child strategy.
    async fn write_stream(
        &self,
        ctx: ArrayContext,
        segment_sink: SegmentSinkRef,
        stream: SendableSequentialStream,
        eof: SequencePointer,
        session: &VortexSession,
    ) -> VortexResult<LayoutRef> {
        // TODO(os): spawn stream below like:
        // canon_stream = stream.map(async {to_canonical}).map(spawn).buffered(parallelism)
        let dtype = stream.dtype().clone();
        // Optionally canonicalize each incoming chunk before it is buffered.
        let stream = if self.options.canonicalize {
            let canonicalize_session = session.clone();
            SequentialStreamAdapter::new(
                dtype.clone(),
                stream.map(move |chunk| {
                    let (sequence_id, chunk) = chunk?;
                    let mut ctx = canonicalize_session.create_execution_ctx();
                    let canonical = chunk.execute::<Canonical>(&mut ctx)?.into_array();
                    VortexResult::Ok((sequence_id, canonical))
                }),
            )
            .sendable()
        } else {
            stream
        };

        let dtype_clone = dtype.clone();
        let options = self.options.clone();

        // For fixed-width types with large per-element sizes, reduce the block_len_multiple
        // so that each block targets block_size_target bytes rather than producing oversized
        // segments.
        let block_len = options.effective_block_len(&dtype);
        let block_size_minimum = options.block_size_minimum;
        let repartition_session = session.clone();

        let repartitioned_stream = try_stream! {
            // Peekable so we can detect end-of-stream and flush the partial tail block.
            let canonical_stream = stream.peekable();
            pin_mut!(canonical_stream);

            let mut ctx = repartition_session.create_execution_ctx();
            let mut chunks = ChunksBuffer::new(block_size_minimum, block_len);
            while let Some(chunk) = canonical_stream.as_mut().next().await {
                let (sequence_id, chunk) = chunk?;
                // NOTE(review): `descend()` presumably opens a child sequence scope so the
                // blocks carved out of this input chunk get ordered sub-ids — confirm
                // against the SequenceId docs.
                let mut sequence_pointer = sequence_id.descend();
                let mut offset = 0;
                // Slice the input chunk into block_len-sized pieces and buffer them.
                while offset < chunk.len() {
                    let end = (offset + block_len).min(chunk.len());
                    let sliced = chunk.slice(offset..end)?;
                    chunks.push_back(sliced);
                    offset = end;

                    // Once both the byte and row thresholds are met, emit every whole
                    // block buffered so far as a single canonicalized array.
                    if chunks.have_enough() {
                        let output_chunks = chunks.collect_exact_blocks()?;
                        assert!(!output_chunks.is_empty());
                        let chunked =
                            ChunkedArray::try_new(output_chunks, dtype_clone.clone())?;
                        if !chunked.is_empty() {
                            let canonical = chunked.into_array().execute::<Canonical>(&mut ctx)?.into_array();
                            yield (
                                sequence_pointer.advance(),
                                canonical,
                            )
                        }
                    }
                }
                // After the last input chunk, flush whatever remains buffered — this final
                // block may be smaller than block_len / block_size_minimum.
                if canonical_stream.as_mut().peek().await.is_none() {
                    let to_flush = ChunkedArray::try_new(
                        chunks.data.drain(..).map(|(arr, _)| arr).collect(),
                        dtype_clone.clone(),
                    )?;
                    if !to_flush.is_empty() {
                        let canonical = to_flush.into_array().execute::<Canonical>(&mut ctx)?.into_array();
                        yield (
                            sequence_pointer.advance(),
                            canonical,
                        )
                    }
                }
            }
        };

        // Hand the repartitioned stream to the child strategy for the actual write.
        self.child
            .write_stream(
                ctx,
                segment_sink,
                SequentialStreamAdapter::new(dtype, repartitioned_stream).sendable(),
                eof,
                session,
            )
            .await
    }

    fn buffered_bytes(&self) -> u64 {
        // TODO(os): we should probably add the buffered bytes from this strategy on top,
        // it is currently better to not add it at all because these buffered arrays are
        // potentially sliced and uncompressed. They would overestimate the actual bytes
        // that will end up in the file when flushed.
        self.child.buffered_bytes()
    }
}
198
/// A FIFO buffer of array chunks with running row- and byte-count accounting, used to
/// accumulate input until at least one full block can be emitted.
struct ChunksBuffer {
    /// Each entry stores the chunk and the `nbytes()` snapshot taken at push time.
    /// This avoids accounting mismatches when interior-mutable arrays (e.g. `SharedArray`)
    /// change their reported size after being pushed.
    data: VecDeque<(ArrayRef, u64)>,
    /// Total number of rows across all buffered chunks.
    row_count: usize,
    /// Sum of the per-chunk `nbytes()` snapshots currently buffered.
    nbytes: u64,
    /// Minimum buffered bytes before `have_enough` returns true.
    block_size_minimum: u64,
    /// Row multiple for emitted blocks (the effective block length).
    block_len_multiple: usize,
}
209
210impl ChunksBuffer {
211    fn new(block_size_minimum: u64, block_len_multiple: usize) -> Self {
212        Self {
213            data: Default::default(),
214            row_count: 0,
215            nbytes: 0,
216            block_size_minimum,
217            block_len_multiple,
218        }
219    }
220
221    fn have_enough(&self) -> bool {
222        self.nbytes >= self.block_size_minimum && self.row_count >= self.block_len_multiple
223    }
224
225    fn collect_exact_blocks(&mut self) -> VortexResult<Vec<ArrayRef>> {
226        let nblocks = self.row_count / self.block_len_multiple;
227        let mut res = Vec::with_capacity(self.data.len());
228        let mut remaining = nblocks * self.block_len_multiple;
229        while remaining > 0 {
230            let (chunk, _) = self
231                .pop_front()
232                .vortex_expect("must have at least one chunk");
233            let len = chunk.len();
234
235            if len > remaining {
236                let left = chunk.slice(0..remaining)?;
237                let right = chunk.slice(remaining..len)?;
238                self.push_front(right);
239                res.push(left);
240                remaining = 0;
241            } else {
242                res.push(chunk);
243                remaining -= len;
244            }
245        }
246        Ok(res)
247    }
248
249    fn push_back(&mut self, chunk: ArrayRef) {
250        let nb = chunk.nbytes();
251        self.row_count += chunk.len();
252        self.nbytes += nb;
253        self.data.push_back((chunk, nb));
254    }
255
256    fn push_front(&mut self, chunk: ArrayRef) {
257        let nb = chunk.nbytes();
258        self.row_count += chunk.len();
259        self.nbytes += nb;
260        self.data.push_front((chunk, nb));
261    }
262
263    fn pop_front(&mut self) -> Option<(ArrayRef, u64)> {
264        let res = self.data.pop_front();
265        if let Some((chunk, nb)) = res.as_ref() {
266            self.row_count -= chunk.len();
267            self.nbytes -= nb;
268        }
269        res
270    }
271}
272
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::ArrayContext;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::FixedSizeListArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::arrays::SharedArray;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability::NonNullable;
    use vortex_array::dtype::PType;
    use vortex_array::validity::Validity;
    use vortex_error::VortexResult;
    use vortex_io::runtime::single::block_on;
    use vortex_io::session::RuntimeSessionExt;

    use super::*;
    use crate::LayoutStrategy;
    use crate::layouts::chunked::writer::ChunkedLayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::segments::TestSegments;
    use crate::sequence::SequenceId;
    use crate::sequence::SequentialArrayStreamExt;
    use crate::test::SESSION;

    const ONE_MEG: u64 = 1 << 20;

    #[test]
    fn effective_block_len_small_elements() {
        // f64 = 8 bytes/element. 8192 * 8 = 64 KiB << 1 MiB, so no reduction.
        let dtype = DType::Primitive(PType::F64, NonNullable);
        let options = RepartitionWriterOptions {
            block_size_minimum: 0,
            block_len_multiple: 8192,
            block_size_target: Some(ONE_MEG),
            canonicalize: false,
        };
        assert_eq!(options.effective_block_len(&dtype), 8192);
    }

    #[test]
    fn effective_block_len_large_elements() {
        // FixedSizeList(f64, 1000) = 8000 bytes/element.
        // div_ceil(1 MiB, 8000) = 132, so effective block len = min(8192, 132) = 132.
        let dtype = DType::FixedSizeList(
            Arc::new(DType::Primitive(PType::F64, NonNullable)),
            1000,
            NonNullable,
        );
        let options = RepartitionWriterOptions {
            block_size_minimum: 0,
            block_len_multiple: 8192,
            block_size_target: Some(ONE_MEG),
            canonicalize: false,
        };
        assert_eq!(options.effective_block_len(&dtype), 132);
    }

    #[test]
    fn effective_block_len_variable_width() {
        // Utf8 has no known element_size, so block_len_multiple is unchanged.
        let dtype = DType::Utf8(NonNullable);
        let options = RepartitionWriterOptions {
            block_size_minimum: 0,
            block_len_multiple: 8192,
            block_size_target: Some(ONE_MEG),
            canonicalize: false,
        };
        assert_eq!(options.effective_block_len(&dtype), 8192);
    }

    #[test]
    fn effective_block_len_very_large_elements() {
        // FixedSizeList(f64, 1_000_000) = 8_000_000 bytes/element.
        // div_ceil(1 MiB, 8_000_000) = 1, so effective block len = min(8192, 1) = 1.
        let dtype = DType::FixedSizeList(
            Arc::new(DType::Primitive(PType::F64, NonNullable)),
            1_000_000,
            NonNullable,
        );
        let options = RepartitionWriterOptions {
            block_size_minimum: 0,
            block_len_multiple: 8192,
            block_size_target: Some(ONE_MEG),
            canonicalize: false,
        };
        assert_eq!(options.effective_block_len(&dtype), 1);
    }

    #[test]
    fn repartition_large_element_type_produces_small_blocks() -> VortexResult<()> {
        // Create a FixedSizeList(f64, 1000) array with 1000 lists.
        // Each list is 8000 bytes, so 1000 lists = 8 MiB total.
        // With block_size_target = 1 MiB, effective block_len = 132.
        // We expect the repartition to produce blocks of 132 rows each.
        let list_size: u32 = 1000;
        let num_lists: usize = 1000;
        let total_elements = list_size as usize * num_lists;

        let elements = PrimitiveArray::from_iter((0..total_elements).map(|i| i as f64));
        let fsl = FixedSizeListArray::new(
            elements.into_array(),
            list_size,
            Validity::NonNullable,
            num_lists,
        );

        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());
        let (ptr, eof) = SequenceId::root().split();

        let child = ChunkedLayoutStrategy::new(FlatLayoutStrategy::default());
        let strategy = RepartitionStrategy::new(
            child,
            RepartitionWriterOptions {
                block_size_minimum: 0,
                block_len_multiple: 8192,
                block_size_target: Some(ONE_MEG),
                canonicalize: false,
            },
        );

        let stream = fsl.into_array().to_array_stream().sequenced(ptr);
        let layout = block_on(|handle| async move {
            let session = SESSION.clone().with_handle(handle);
            strategy
                .write_stream(
                    ctx,
                    Arc::<TestSegments>::clone(&segments),
                    stream,
                    eof,
                    &session,
                )
                .await
        })?;

        // The layout should be a ChunkedLayout with multiple children.
        // With 1000 rows and effective block_len = 132:
        //   - 7 full blocks of 132 rows = 924 rows
        //   - 1 remainder block of 76 rows
        //   - Total: 8 blocks, 1000 rows
        assert_eq!(layout.row_count(), num_lists as u64);

        // All non-last children should have 132 rows.
        let nchildren = layout.nchildren();
        assert!(nchildren > 1, "expected multiple chunks, got {nchildren}");

        for i in 0..nchildren - 1 {
            let child = layout.child(i)?;
            assert_eq!(
                child.row_count(),
                132,
                "chunk {i} has {} rows, expected 132",
                child.row_count()
            );
        }

        // Last child gets the remainder.
        let last = layout.child(nchildren - 1)?;
        assert_eq!(last.row_count(), 1000 - 132 * (nchildren as u64 - 1));

        Ok(())
    }

    #[test]
    fn repartition_small_element_type_unchanged() -> VortexResult<()> {
        // For f64 (8 bytes/element), effective block_len stays at 8192.
        // With 10000 elements and block_size_minimum=0, we get one block of 8192
        // and one remainder of 1808.
        let num_elements: usize = 10000;
        let elements = PrimitiveArray::from_iter((0..num_elements).map(|i| i as f64));

        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());
        let (ptr, eof) = SequenceId::root().split();

        let child = ChunkedLayoutStrategy::new(FlatLayoutStrategy::default());
        let strategy = RepartitionStrategy::new(
            child,
            RepartitionWriterOptions {
                block_size_minimum: 0,
                block_len_multiple: 8192,
                block_size_target: Some(ONE_MEG),
                canonicalize: false,
            },
        );

        let stream = elements.into_array().to_array_stream().sequenced(ptr);
        let layout = block_on(|handle| async move {
            let session = SESSION.clone().with_handle(handle);
            strategy
                .write_stream(
                    ctx,
                    Arc::<TestSegments>::clone(&segments),
                    stream,
                    eof,
                    &session,
                )
                .await
        })?;

        assert_eq!(layout.row_count(), num_elements as u64);
        assert_eq!(layout.nchildren(), 2);
        assert_eq!(layout.child(0)?.row_count(), 8192);
        assert_eq!(layout.child(1)?.row_count(), 1808);

        Ok(())
    }

    /// Regression test: `SharedArray` slices sharing an `Arc<Mutex<SharedState>>` can
    /// transition from Source to Cached when any one of them is canonicalized. This caused
    /// `pop_front` to panic with `attempt to subtract with overflow` because the buffer's
    /// running `nbytes` total was accumulated with the smaller Source-era values while
    /// `pop_front` subtracted the larger Cached-era values.
    #[test]
    fn chunks_buffer_pop_front_no_panic_after_shared_execution() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let n = 20_000usize;
        let block_len = 10_000usize;

        let constant = ConstantArray::new(42i64, n);
        let shared = SharedArray::new(constant.into_array());
        let shared_handle = shared.clone();
        let arr = shared.into_array();

        let s1 = arr.slice(0..block_len)?;
        let s2 = arr.slice(block_len..n)?;

        let mut buf = ChunksBuffer::new(0, block_len);
        buf.push_back(s1);
        buf.push_back(s2);

        let _output = buf.pop_front().unwrap();

        // Transition SharedState from Source to Cached for ALL slices sharing this Arc.
        use vortex_array::arrays::shared::SharedArrayExt;
        let _canonical =
            shared_handle.get_or_compute(|source| source.clone().execute::<Canonical>(&mut ctx))?;

        // Before the fix this panicked with "attempt to subtract with overflow".
        let _s2 = buf.pop_front().unwrap();
        assert_eq!(buf.nbytes, 0);
        assert_eq!(buf.row_count, 0);

        Ok(())
    }
}