// xet_data/file_reconstruction/data_writer/unordered_writer.rs

use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};

use bytes::Bytes;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
use tokio::task::JoinSet;
use xet_client::cas_types::FileRange;
use xet_runtime::utils::adjustable_semaphore::AdjustableSemaphorePermit;

use super::super::data_writer::{DataFuture, DataWriter};
use super::super::run_state::RunState;
use super::super::{FileReconstructionError, Result};

/// A completed term ready for consumption. Contains the byte range indicating
/// where this data belongs in the output file, the actual data bytes, and an
/// optional semaphore permit for backpressure control.
pub(crate) struct CompletedTerm {
    pub byte_range: FileRange,
    pub data: Bytes,
    pub permit: Option<AdjustableSemaphorePermit>,
}

/// Atomic progress counters shared between the writer, its spawned tasks,
/// and the consumer stream. Wrapped in an `Arc` so each party can read/update
/// counters without holding a reference to the full `UnorderedWriter`.
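///
/// Note on orderings: `terms_in_progress` is decremented with `Release` after
/// a term has been handed to the channel, pairing with the `Acquire` load in
/// `terms_in_progress()`, while `bytes_in_progress` is maintained with
/// `Relaxed` ordering throughout and should be treated as an approximate
/// gauge. A consumer might poll it like this (illustrative sketch; `progress`
/// is the `Arc` returned by `new_streaming`):
///
/// ```ignore
/// let pending = progress.bytes_in_progress();
/// println!("{pending} bytes still in flight");
/// ```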
pub(crate) struct UnorderedWriterProgress {
    pub terms_in_progress: AtomicU64,
    pub bytes_in_progress: AtomicU64,
}

impl UnorderedWriterProgress {
    pub fn terms_in_progress(&self) -> u64 {
        self.terms_in_progress.load(Ordering::Acquire)
    }

    pub fn bytes_in_progress(&self) -> u64 {
        self.bytes_in_progress.load(Ordering::Relaxed)
    }
}

/// Writer that delivers completed data terms in arbitrary order.
///
/// Each call to [`set_next_term_data_source`](DataWriter::set_next_term_data_source)
/// spawns a task (tracked via a [`JoinSet`]) that resolves the data future and
/// sends the result through an [`mpsc`](tokio::sync::mpsc) channel. The consumer
/// (typically an [`UnorderedDownloadStream`](super::unordered_download_stream::UnorderedDownloadStream))
/// reads from the receiver end and gets items in whatever order tasks complete.
///
/// The consumer stream holds only `Arc<UnorderedWriterProgress>`, not the writer
/// itself, so the writer's channel sender is dropped naturally when the
/// reconstruction task finishes and consumes the writer via
/// [`finish()`](DataWriter::finish).
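///
/// A minimal usage sketch (`range`, `future`, and the `output.write_at` sink
/// are hypothetical stand-ins for the caller's real sources and destination):
///
/// ```ignore
/// let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);
///
/// // Producer side: hand each term's future to the writer, then finish.
/// writer.set_next_term_data_source(range, None, future).await?;
/// let total_bytes = writer.finish().await?;
///
/// // Consumer side: terms arrive in completion order; place each by offset.
/// while let Some(term) = rx.recv().await {
///     let term = term?;
///     output.write_at(term.byte_range.start, &term.data)?;
/// }
/// ```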
pub struct UnorderedWriter {
    result_tx: UnboundedSender<Result<CompletedTerm>>,
    run_state: Arc<RunState>,
    progress: Arc<UnorderedWriterProgress>,
    task_set: JoinSet<Result<u64>>,
    total_bytes_sent: u64,
    finished: bool,
}

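// If the writer is dropped without `finish()` having completed (for example,
// because the reconstruction task failed or was aborted), cancel the run so
// in-flight tasks and the consumer stream shut down promptly.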
impl Drop for UnorderedWriter {
    fn drop(&mut self) {
        if !self.finished {
            self.run_state.cancel();
        }
    }
}

#[async_trait::async_trait]
impl DataWriter for UnorderedWriter {
    async fn set_next_term_data_source(
        &mut self,
        byte_range: FileRange,
        permit: Option<AdjustableSemaphorePermit>,
        data_future: DataFuture,
    ) -> Result<()> {
        self.run_state.check_error()?;

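        // Opportunistically reap tasks that have already completed so their
        // byte counts accumulate (and any task error surfaces) without waiting
        // for `finish()`.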
        while let Some(result) = self.task_set.try_join_next() {
            self.total_bytes_sent +=
                result.map_err(|e| FileReconstructionError::InternalError(format!("Task join error: {e}")))??;
        }

        if self.finished {
            return Err(FileReconstructionError::InternalWriterError("Writer has already finished".to_string()));
        }

        let expected_size = byte_range.end - byte_range.start;
        self.progress.terms_in_progress.fetch_add(1, Ordering::Relaxed);
        self.progress.bytes_in_progress.fetch_add(expected_size, Ordering::Relaxed);

        let result_tx = self.result_tx.clone();
        let run_state = self.run_state.clone();
        let progress = self.progress.clone();

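        // Each term resolves on its own spawned task: await the data, validate
        // its size, then push the result to the consumer over the channel.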
        self.task_set.spawn(async move {
            let result = async {
                run_state.check_error()?;

                let data = data_future.await?;

                if data.len() as u64 != expected_size {
                    return Err(FileReconstructionError::InternalWriterError(format!(
                        "Data size mismatch: expected {} bytes, got {} bytes",
                        expected_size,
                        data.len()
                    )));
                }

                Ok(CompletedTerm {
                    byte_range,
                    data,
                    permit,
                })
            }
            .await;

            if let Err(ref e) = result {
                run_state.set_error(e.clone());
            }

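            // Capture the byte count before `result` is moved into the channel.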
            let completed_bytes = result.as_ref().map(|t| t.data.len() as u64).unwrap_or(0);

            let _ = result_tx.send(result);

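            // Decrement counters only after the term has been handed off; the
            // `Release` on `terms_in_progress` pairs with the `Acquire` load in
            // the getter.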
            progress.bytes_in_progress.fetch_sub(expected_size, Ordering::Relaxed);
            progress.terms_in_progress.fetch_sub(1, Ordering::Release);

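            // A zero byte count means either an empty term or a failure;
            // surface any recorded error so `finish()` sees it when joining.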
            if completed_bytes > 0 {
                Ok(completed_bytes)
            } else {
                run_state.check_error()?;
                Ok(0)
            }
        });

        Ok(())
    }

    async fn finish(mut self: Box<Self>) -> Result<u64> {
        self.run_state.check_error()?;

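        // Await every remaining task; each returns the number of bytes it
        // delivered, and any task error propagates out of `finish()`.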
        while let Some(result) = self.task_set.join_next().await {
            self.total_bytes_sent +=
                result.map_err(|e| FileReconstructionError::InternalError(format!("Task join error: {e}")))??;
        }

        self.finished = true;
        Ok(self.total_bytes_sent)
    }
}

impl UnorderedWriter {
    /// Creates an unordered writer for streaming use. Returns the writer (to be
    /// passed to the reconstruction task as `Box<dyn DataWriter>`), the receiver
    /// end of the channel, and the shared progress counters for the consumer.
    ///
    /// The consumer stream should hold only the `Arc<UnorderedWriterProgress>`,
    /// **not** the writer itself. This way the channel sender is dropped
    /// naturally when the reconstruction task finishes (consuming the writer
    /// via `finish()`), closing the channel without explicit lifetime management.
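    ///
    /// Typical wiring (illustrative sketch; `spawn_reconstruction_task` and
    /// the stream constructor shown are hypothetical):
    ///
    /// ```ignore
    /// let (writer, rx, progress) = UnorderedWriter::new_streaming(run_state);
    /// spawn_reconstruction_task(writer);
    /// let stream = UnorderedDownloadStream::new(rx, progress);
    /// ```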
    pub(crate) fn new_streaming(
        run_state: Arc<RunState>,
    ) -> (Box<dyn DataWriter>, UnboundedReceiver<Result<CompletedTerm>>, Arc<UnorderedWriterProgress>) {
        let (tx, rx) = unbounded_channel();

        let progress = Arc::new(UnorderedWriterProgress {
            terms_in_progress: AtomicU64::new(0),
            bytes_in_progress: AtomicU64::new(0),
        });

        let writer = Box::new(UnorderedWriter {
            result_tx: tx,
            run_state,
            progress: progress.clone(),
            task_set: JoinSet::new(),
            total_bytes_sent: 0,
            finished: false,
        });

        (writer, rx, progress)
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use xet_runtime::utils::adjustable_semaphore::AdjustableSemaphore;

    use super::*;

    fn immediate_future(data: Bytes) -> DataFuture {
        Box::pin(async move { Ok(data) })
    }

    fn delayed_future(data: Bytes, delay: Duration) -> DataFuture {
        Box::pin(async move {
            tokio::time::sleep(delay).await;
            Ok(data)
        })
    }

    /// Drains all results from the receiver, returning data sorted by offset.
    /// The writer must already have been consumed via `finish()`, which drops
    /// the channel sender so that `recv()` returns `None` once every spawned
    /// task's result has been delivered.
    async fn drain_sorted(rx: &mut UnboundedReceiver<Result<CompletedTerm>>) -> Result<Vec<(u64, Bytes)>> {
        let mut items = Vec::new();
        while let Some(result) = rx.recv().await {
            let term = result?;
            items.push((term.byte_range.start, term.data));
            drop(term.permit);
        }
        items.sort_by_key(|(offset, _)| *offset);
        Ok(items)
    }

    #[tokio::test]
    async fn test_basic_unordered_writes() {
        let run_state = RunState::new_for_test();
        let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);

        writer
            .set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
            .await
            .unwrap();
        writer
            .set_next_term_data_source(FileRange::new(5, 6), None, immediate_future(Bytes::from(" ")))
            .await
            .unwrap();
        writer
            .set_next_term_data_source(FileRange::new(6, 11), None, immediate_future(Bytes::from("World")))
            .await
            .unwrap();

        let total = writer.finish().await.unwrap();
        assert_eq!(total, 11);

        let items = drain_sorted(&mut rx).await.unwrap();
        let assembled: Vec<u8> = items.into_iter().flat_map(|(_, data)| data.to_vec()).collect();
        assert_eq!(&assembled, b"Hello World");
    }

    #[tokio::test]
    async fn test_delayed_futures_complete_out_of_order() {
        let run_state = RunState::new_for_test();
        let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);

        writer
            .set_next_term_data_source(
                FileRange::new(0, 5),
                None,
                delayed_future(Bytes::from("Hello"), Duration::from_millis(80)),
            )
            .await
            .unwrap();
        writer
            .set_next_term_data_source(
                FileRange::new(5, 6),
                None,
                delayed_future(Bytes::from(" "), Duration::from_millis(40)),
            )
            .await
            .unwrap();
        writer
            .set_next_term_data_source(FileRange::new(6, 11), None, immediate_future(Bytes::from("World")))
            .await
            .unwrap();

        let total = writer.finish().await.unwrap();
        assert_eq!(total, 11);

        let items = drain_sorted(&mut rx).await.unwrap();
        let assembled: Vec<u8> = items.into_iter().flat_map(|(_, data)| data.to_vec()).collect();
        assert_eq!(&assembled, b"Hello World");
    }

    #[tokio::test]
    async fn test_size_mismatch_error() {
        let run_state = RunState::new_for_test();
        let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);

        writer
            .set_next_term_data_source(FileRange::new(0, 10), None, immediate_future(Bytes::from("Hello")))
            .await
            .unwrap();

        let result = writer.finish().await;
        assert!(result.is_err());

        let result = rx.recv().await.unwrap();
        assert!(result.is_err());
        assert!(matches!(result, Err(FileReconstructionError::InternalWriterError(_))));
    }

    #[tokio::test]
    async fn test_future_error_propagates() {
        let run_state = RunState::new_for_test();
        let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);

        let failing_future: DataFuture =
            Box::pin(async { Err(FileReconstructionError::InternalError("Simulated error".to_string())) });

        writer
            .set_next_term_data_source(FileRange::new(0, 5), None, failing_future)
            .await
            .unwrap();

        let result = writer.finish().await;
        assert!(result.is_err());

        let result = rx.recv().await.unwrap();
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_semaphore_permit_released_after_consumption() {
        let run_state = RunState::new_for_test();
        let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);
        let semaphore = AdjustableSemaphore::new(2, (0, 2));

        let permit1 = semaphore.acquire().await.unwrap();
        let permit2 = semaphore.acquire().await.unwrap();
        assert_eq!(semaphore.available_permits(), 0);

        writer
            .set_next_term_data_source(FileRange::new(0, 5), Some(permit1), immediate_future(Bytes::from("Hello")))
            .await
            .unwrap();
        writer
            .set_next_term_data_source(FileRange::new(5, 6), Some(permit2), immediate_future(Bytes::from(" ")))
            .await
            .unwrap();

        writer.finish().await.unwrap();

        let items = drain_sorted(&mut rx).await.unwrap();
        drop(items);

        assert_eq!(semaphore.available_permits(), 2);
    }

    #[tokio::test]
    async fn test_counter_accuracy() {
        let run_state = RunState::new_for_test();
        let (mut writer, mut rx, progress) = UnorderedWriter::new_streaming(run_state);

        writer
            .set_next_term_data_source(
                FileRange::new(0, 5),
                None,
                delayed_future(Bytes::from("Hello"), Duration::from_millis(50)),
            )
            .await
            .unwrap();
        writer
            .set_next_term_data_source(
                FileRange::new(5, 11),
                None,
                delayed_future(Bytes::from(" World"), Duration::from_millis(50)),
            )
            .await
            .unwrap();

        let total = writer.finish().await.unwrap();
        assert_eq!(total, 11);

        let _items = drain_sorted(&mut rx).await.unwrap();

        assert_eq!(progress.bytes_in_progress(), 0);
        assert_eq!(progress.terms_in_progress(), 0);
    }

    #[tokio::test]
    async fn test_finish_returns_total_bytes() {
        let run_state = RunState::new_for_test();
        let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);

        writer
            .set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
            .await
            .unwrap();
        writer
            .set_next_term_data_source(FileRange::new(5, 11), None, immediate_future(Bytes::from(" World")))
            .await
            .unwrap();

        let total = writer.finish().await.unwrap();
        assert_eq!(total, 11);

        let _items = drain_sorted(&mut rx).await.unwrap();
    }

    #[tokio::test]
    async fn test_error_propagation_prevents_subsequent_writes() {
        let run_state = RunState::new_for_test();
        let (mut writer, mut _rx, _progress) = UnorderedWriter::new_streaming(run_state.clone());

        let failing_future: DataFuture =
            Box::pin(async { Err(FileReconstructionError::InternalError("fail".to_string())) });

        writer
            .set_next_term_data_source(FileRange::new(0, 5), None, failing_future)
            .await
            .unwrap();

        let wait_for_error = tokio::time::timeout(Duration::from_secs(1), async {
            loop {
                if run_state.check_error().is_err() {
                    break;
                }
                tokio::task::yield_now().await;
            }
        })
        .await;
        assert!(wait_for_error.is_ok());

        let result = writer
            .set_next_term_data_source(FileRange::new(5, 10), None, immediate_future(Bytes::from("World")))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn stress_test_many_concurrent_terms() {
        let run_state = RunState::new_for_test();
        let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);

        let num_terms: usize = 100;
        let mut expected: Vec<(u64, Vec<u8>)> = Vec::new();
        let mut offset = 0u64;

        for i in 0..num_terms {
            let size = 100 + (i % 50) * 10;
            let data: Vec<u8> = (0..size).map(|j| ((i * 7 + j * 13) % 256) as u8).collect();
            let bytes = Bytes::from(data.clone());
            expected.push((offset, data));

            let delay = Duration::from_micros((i % 10) as u64 * 100);
            writer
                .set_next_term_data_source(
                    FileRange::new(offset, offset + size as u64),
                    None,
                    delayed_future(bytes, delay),
                )
                .await
                .unwrap();

            offset += size as u64;
        }

        let total = writer.finish().await.unwrap();
        assert_eq!(total, offset);

        let items = drain_sorted(&mut rx).await.unwrap();
        assert_eq!(items.len(), num_terms);

        for ((exp_offset, exp_data), (act_offset, act_data)) in expected.iter().zip(items.iter()) {
            assert_eq!(*exp_offset, *act_offset);
            assert_eq!(exp_data.as_slice(), act_data.as_ref());
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn stress_test_rapid_finish_after_writes() {
        for _ in 0..50 {
            let run_state = RunState::new_for_test();
            let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);

            for i in 0..10u64 {
                let data = Bytes::from(vec![i as u8; 100]);
                writer
                    .set_next_term_data_source(FileRange::new(i * 100, (i + 1) * 100), None, immediate_future(data))
                    .await
                    .unwrap();
            }

            let total = writer.finish().await.unwrap();
            assert_eq!(total, 1000);

            let items = drain_sorted(&mut rx).await.unwrap();
            assert_eq!(items.len(), 10);

            let total_bytes: usize = items.iter().map(|(_, data)| data.len()).sum();
            assert_eq!(total_bytes, 1000);
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn stress_test_mixed_immediate_and_delayed() {
        for _ in 0..20 {
            let run_state = RunState::new_for_test();
            let (mut writer, mut rx, progress) = UnorderedWriter::new_streaming(run_state);

            let mut offset = 0u64;
            let mut total_size = 0u64;
            let num_terms = 30usize;

            for i in 0..num_terms {
                let size = ((i + 1) * 50) as u64;
                let data = Bytes::from(vec![(i % 256) as u8; size as usize]);
                total_size += size;

                let future = if i % 3 == 0 {
                    delayed_future(data, Duration::from_millis((i % 5) as u64))
                } else {
                    immediate_future(data)
                };

                writer
                    .set_next_term_data_source(FileRange::new(offset, offset + size), None, future)
                    .await
                    .unwrap();
                offset += size;
            }

            let total = writer.finish().await.unwrap();
            assert_eq!(total, total_size);

            let items = drain_sorted(&mut rx).await.unwrap();
            assert_eq!(items.len(), num_terms);

            let received_bytes: u64 = items.iter().map(|(_, data)| data.len() as u64).sum();
            assert_eq!(received_bytes, total_size);
            assert_eq!(progress.terms_in_progress(), 0);
        }
    }
}