1use std::collections::VecDeque;
2use std::io::{IoSlice, Write};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5
6use bytes::Bytes;
7use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
8use tokio::sync::oneshot;
9use tokio::task::{JoinHandle, JoinSet};
10use xet_client::cas_types::FileRange;
11use xet_runtime::core::{XetRuntime, check_sigint_shutdown};
12use xet_runtime::utils::adjustable_semaphore::AdjustableSemaphorePermit;
13
14use super::super::data_writer::{DataFuture, DataWriter};
15use super::super::run_state::RunState;
16use super::super::{FileReconstructionError, Result};
17use crate::progress_tracking::ItemProgressUpdater;
18
/// Maximum number of `IoSlice` entries submitted in a single `write_vectored`
/// call — presumably chosen to stay well under typical OS iovec limits
/// (e.g. `IOV_MAX`); TODO confirm the rationale for 24.
const WRITEV_MAX_SLICE: usize = 24;
29
/// Items queued (in file order) from `SequentialWriter` to the consumer of the
/// retrieval channel (the blocking writer thread, or the caller in streaming mode).
pub(crate) enum SequentialRetrievalItem {
    /// One chunk of file data.
    Data {
        /// Fulfilled once the chunk's data future resolves.
        receiver: oneshot::Receiver<Bytes>,
        /// Optional concurrency permit, dropped after the chunk is consumed/written.
        permit: Option<AdjustableSemaphorePermit>,
    },
    /// Marks the end of the stream; no more `Data` items will follow.
    Finish,
}
40
/// A chunk that is ready to be written, paired with the permit to release once
/// the write completes.
type PendingWrite = (Bytes, Option<AdjustableSemaphorePermit>);
43
/// State for the blocking thread that drains `SequentialRetrievalItem`s and
/// writes their payloads, in order, to a synchronous `Write` sink.
struct SyncWriterThread {
    /// Ordered stream of data/finish items from the `SequentialWriter`.
    rx: UnboundedReceiver<SequentialRetrievalItem>,
    /// Running total of bytes written, shared with the owning `SequentialWriter`.
    bytes_written: Arc<AtomicU64>,
    /// Optional progress sink notified after each successful write.
    progress_updater: Option<Arc<ItemProgressUpdater>>,
    /// Shared run state used to check for / surface errors from other tasks.
    run_state: Arc<RunState>,
    /// An item pulled off the channel whose data was not yet ready during a
    /// non-blocking poll; retried on the next `next_write` call.
    pending: Option<SequentialRetrievalItem>,
    /// Set once the `Finish` item has been received.
    finished: bool,
}
54
impl SyncWriterThread {
    /// Bundle the channel and bookkeeping state for the blocking writer loop.
    fn new(
        rx: UnboundedReceiver<SequentialRetrievalItem>,
        bytes_written: Arc<AtomicU64>,
        progress_updater: Option<Arc<ItemProgressUpdater>>,
        run_state: Arc<RunState>,
    ) -> Self {
        Self {
            rx,
            bytes_written,
            progress_updater,
            run_state,
            pending: None,
            finished: false,
        }
    }

    /// Fetch the next chunk that is ready to be written.
    ///
    /// With `should_block == true`, waits for both the next channel item and
    /// its data. With `should_block == false`, returns `Ok(None)` as soon as
    /// either is not immediately available; a `Data` item whose payload is not
    /// yet ready is stashed back in `self.pending` so ordering is preserved.
    /// Receiving `Finish` sets `self.finished` and yields `Ok(None)`.
    #[inline]
    fn next_write(&mut self, should_block: bool) -> Result<Option<PendingWrite>> {
        if self.pending.is_none() {
            // Pull the next item off the channel. In non-blocking mode, "empty"
            // and "disconnected" are treated the same (pending stays None).
            self.pending = if should_block {
                self.rx.blocking_recv()
            } else {
                self.rx.try_recv().ok()
            };
        }

        match self.pending.take() {
            Some(SequentialRetrievalItem::Data { mut receiver, permit }) => {
                if should_block {
                    let data = match receiver.blocking_recv() {
                        Ok(data) => data,
                        Err(_) => {
                            // A dropped sender usually means the producing task
                            // failed; prefer surfacing that recorded error.
                            self.run_state.check_error()?;
                            return Err(FileReconstructionError::InternalWriterError(
                                "Data sender was dropped before sending data.".to_string(),
                            ));
                        },
                    };
                    Ok(Some((data, permit)))
                } else {
                    match receiver.try_recv() {
                        Ok(data) => Ok(Some((data, permit))),
                        Err(oneshot::error::TryRecvError::Empty) => {
                            // Data not ready yet; re-queue the item so it is
                            // retried first on the next call.
                            self.pending = Some(SequentialRetrievalItem::Data { receiver, permit });
                            Ok(None)
                        },
                        Err(oneshot::error::TryRecvError::Closed) => {
                            self.run_state.check_error()?;
                            Err(FileReconstructionError::InternalWriterError(
                                "Data sender was dropped before sending data.".to_string(),
                            ))
                        },
                    }
                }
            },
            Some(SequentialRetrievalItem::Finish) => {
                self.finished = true;
                Ok(None)
            },
            None => Ok(None),
        }
    }

    /// Drain the channel sequentially, writing each chunk with `write_all`.
    ///
    /// Blocks for each chunk in order; exits once `next_write` yields `None`
    /// (normally after `Finish`), then flushes the writer.
    fn run(mut self, mut writer: impl Write) -> Result<()> {
        while let Some((data, permit)) = self.next_write(true)? {
            let len = data.len() as u64;
            writer.write_all(&data)?;
            self.bytes_written.fetch_add(len, Ordering::Relaxed);
            if let Some(ref updater) = self.progress_updater {
                updater.report_bytes_written(len);
            }
            // Release the concurrency permit only after the bytes hit the sink.
            drop(permit);

            // NOTE(review): `finished` is only ever set on the Ok(None) path of
            // next_write, so this check looks defensive/unreachable — confirm.
            if self.finished {
                break;
            }

            check_sigint_shutdown()?;
        }

        debug_assert!(self.finished);

        writer.flush()?;
        Ok(())
    }

    /// Drain the channel using vectored writes: batch every chunk that is
    /// already available (capped at `WRITEV_MAX_SLICE` slices per syscall) and
    /// submit with `write_vectored`, handling short writes and `Interrupted`.
    fn run_vectorized(mut self, mut writer: impl Write) -> Result<()> {
        // Chunks received but not yet fully written; the front chunk may be
        // partially consumed (re-sliced) after a short write.
        let mut pending_writes: VecDeque<PendingWrite> = VecDeque::new();

        while !self.finished || !pending_writes.is_empty() {
            check_sigint_shutdown()?;

            if pending_writes.is_empty() {
                // Nothing buffered: block for the next chunk (or Finish).
                let Some(write) = self.next_write(true)? else {
                    break;
                };

                pending_writes.push_back(write);
            }

            // Opportunistically gather everything that is ready right now.
            while let Some(write) = self.next_write(false)? {
                pending_writes.push_back(write);
            }

            // Cap the iovec count handed to a single syscall.
            let io_slices: Vec<IoSlice<'_>> = pending_writes
                .iter()
                .take(WRITEV_MAX_SLICE)
                .map(|(data, _)| IoSlice::new(data))
                .collect();

            let written = match writer.write_vectored(&io_slices) {
                Ok(0) if !io_slices.is_empty() => {
                    // Zero progress with data pending would loop forever.
                    return Err(FileReconstructionError::IoError(Arc::new(std::io::Error::new(
                        std::io::ErrorKind::WriteZero,
                        "write_vectored returned 0 with non-empty buffers",
                    ))));
                },
                Ok(n) => n,
                // EINTR: retry the same batch.
                Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(FileReconstructionError::IoError(Arc::new(e))),
            };

            self.bytes_written.fetch_add(written as u64, Ordering::Relaxed);
            if let Some(ref updater) = self.progress_updater {
                updater.report_bytes_written(written as u64);
            }

            // Pop fully-written chunks (dropping their permits); on a partial
            // write, slice the front chunk forward to the unwritten remainder.
            let mut remaining = written;
            while remaining > 0 && !pending_writes.is_empty() {
                let front_len = pending_writes.front().unwrap().0.len();
                if remaining >= front_len {
                    remaining -= front_len;
                    pending_writes.pop_front();
                } else {
                    let front = pending_writes.front_mut().unwrap();
                    front.0 = front.0.slice(remaining..);
                    remaining = 0;
                }
            }
        }

        writer.flush()?;
        Ok(())
    }
}
219
/// A `DataWriter` that accepts strictly sequential byte ranges, resolves each
/// range's data future concurrently, and forwards the results in order through
/// an unbounded channel to a background sink.
pub struct SequentialWriter {
    /// Channel to the consumer of retrieval items (writer thread or caller).
    sender: UnboundedSender<SequentialRetrievalItem>,
    /// Expected `start` of the next byte range (== total bytes queued so far).
    next_position: u64,
    /// Handle to the blocking writer task; `None` in streaming mode.
    background_handle: Option<JoinHandle<()>>,
    /// Shared run state for error propagation and cancellation.
    run_state: Arc<RunState>,
    /// Total bytes written by the background thread, shared with it.
    bytes_written: Arc<AtomicU64>,
    /// In-flight tasks resolving data futures and fulfilling queued oneshots.
    active_tasks: JoinSet<Result<()>>,
    /// Set once `finish()` has been called.
    finished: bool,
}
232
233impl Drop for SequentialWriter {
234 fn drop(&mut self) {
235 if !self.finished {
236 self.run_state.cancel();
237 }
238 }
239}
240
#[async_trait::async_trait]
impl DataWriter for SequentialWriter {
    /// Register the data source for the next sequential byte range.
    ///
    /// `byte_range.start` must equal the current write position. A `Data` item
    /// (with its oneshot receiver) is queued to the background writer
    /// immediately — preserving file order — and a task is spawned to await
    /// `data_future` and fulfill the oneshot. Errors from previously spawned
    /// tasks are surfaced here.
    async fn set_next_term_data_source(
        &mut self,
        byte_range: FileRange,
        permit: Option<AdjustableSemaphorePermit>,
        data_future: DataFuture,
    ) -> Result<()> {
        self.run_state.check_error()?;

        // Reap already-completed retrieval tasks so their errors propagate promptly.
        while let Some(result) = self.active_tasks.try_join_next() {
            result.map_err(|e| FileReconstructionError::InternalError(format!("Task join error: {e}")))??;
        }

        if self.finished {
            return Err(FileReconstructionError::InternalWriterError("Writer has already finished".to_string()));
        }

        // Ranges must arrive back-to-back with no gaps or overlaps.
        if byte_range.start != self.next_position {
            return Err(FileReconstructionError::InternalWriterError(format!(
                "Byte range not sequential: expected start at {}, got {}",
                self.next_position, byte_range.start
            )));
        }

        let expected_size = byte_range.end - byte_range.start;
        self.next_position = byte_range.end;

        let (sender, receiver) = oneshot::channel();

        // Queue the receiver now so the writer consumes chunks in file order
        // even though the data futures may resolve out of order.
        if self.sender.send(SequentialRetrievalItem::Data { receiver, permit }).is_err() {
            self.run_state.check_error()?;
            return Err(FileReconstructionError::InternalWriterError("Background writer channel closed".to_string()));
        }

        let run_state = self.run_state.clone();
        let task = async move {
            let result = async {
                run_state.check_error()?;

                let data = data_future.await?;

                // The resolved data must exactly match the declared range size.
                if data.len() as u64 != expected_size {
                    return Err(FileReconstructionError::InternalWriterError(format!(
                        "Data size mismatch: expected {} bytes, got {} bytes",
                        expected_size,
                        data.len()
                    )));
                }

                if sender.send(data).is_err() {
                    run_state.check_error()?;
                    return Err(FileReconstructionError::InternalWriterError(
                        "Failed to send data: receiver dropped".to_string(),
                    ));
                }

                Ok(())
            }
            .await;

            // Record failures on the shared state so the writer thread and
            // other tasks observe them.
            if let Err(ref e) = result {
                run_state.set_error(e.clone());
            }
            result
        };

        self.active_tasks.spawn(task);

        Ok(())
    }

    /// Signal end-of-data, wait for all retrieval tasks and the background
    /// writer to complete, and return the number of bytes written.
    ///
    /// With a background thread, the actual byte count is verified against the
    /// total of all queued ranges; in streaming mode (no background handle)
    /// the expected count is returned without verification.
    async fn finish(mut self: Box<Self>) -> Result<u64> {
        self.run_state.check_error()?;

        if self.finished {
            return Err(FileReconstructionError::InternalWriterError("Writer has already finished".to_string()));
        }

        self.finished = true;

        if self.sender.send(SequentialRetrievalItem::Finish).is_err() {
            self.run_state.check_error()?;
            return Err(FileReconstructionError::InternalWriterError("Background writer channel closed".to_string()));
        }

        let expected_bytes = self.next_position;

        // Wait for every outstanding data-resolution task; `??` propagates both
        // join errors and task-level errors.
        while let Some(result) = self.active_tasks.join_next().await {
            result.map_err(|e| FileReconstructionError::InternalError(format!("Task join error: {e}")))??;
        }

        match self.background_handle.take() {
            Some(handle) => {
                handle.await.map_err(|e| {
                    FileReconstructionError::InternalWriterError(format!("Background writer task failed: {e}"))
                })?;

                // The writer thread reports failures via run_state, not the handle.
                self.run_state.check_error()?;

                let actual_bytes = self.bytes_written.load(Ordering::Relaxed);
                if actual_bytes != expected_bytes {
                    return Err(FileReconstructionError::InternalWriterError(format!(
                        "Bytes written mismatch: expected {} bytes, but wrote {} bytes",
                        expected_bytes, actual_bytes
                    )));
                }

                Ok(actual_bytes)
            },
            None => {
                // Streaming mode: the caller owns the sink, so there is no
                // byte count to verify here.
                Ok(expected_bytes)
            },
        }
    }
}
364
365impl SequentialWriter {
366 pub(crate) fn new_streaming(
372 run_state: Arc<RunState>,
373 ) -> (Box<dyn DataWriter>, UnboundedReceiver<SequentialRetrievalItem>) {
374 let (tx, rx) = unbounded_channel::<SequentialRetrievalItem>();
375
376 let writer = Self {
377 sender: tx,
378 next_position: 0,
379 background_handle: None,
380 run_state,
381 bytes_written: Arc::new(AtomicU64::new(0)),
382 active_tasks: JoinSet::new(),
383 finished: false,
384 };
385
386 (Box::new(writer), rx)
387 }
388
389 #[allow(clippy::new_ret_no_self)]
395 pub(crate) fn new<W: Write + Send + 'static>(
396 writer: W,
397 use_vectorized: bool,
398 run_state: Arc<RunState>,
399 ) -> Box<dyn DataWriter> {
400 let (tx, rx) = unbounded_channel::<SequentialRetrievalItem>();
401 let bytes_written = Arc::new(AtomicU64::new(0));
402
403 let run_state_clone = run_state.clone();
404 let run_state_thread = run_state.clone();
405 let bytes_written_clone = bytes_written.clone();
406 let progress_updater = run_state.progress_updater().cloned();
407
408 let handle = XetRuntime::current().spawn_blocking(move || {
409 let writer_thread = SyncWriterThread::new(rx, bytes_written_clone, progress_updater, run_state_thread);
410 let result = if use_vectorized {
411 writer_thread.run_vectorized(writer)
412 } else {
413 writer_thread.run(writer)
414 };
415 if let Err(err) = result {
416 run_state_clone.set_error(err);
417 }
418 });
419
420 Box::new(Self {
421 sender: tx,
422 next_position: 0,
423 background_handle: Some(handle),
424 run_state,
425 bytes_written,
426 active_tasks: JoinSet::new(),
427 finished: false,
428 })
429 }
430}
431
432#[cfg(test)]
433mod tests {
434 use std::io;
435 use std::time::Duration;
436
437 use xet_runtime::utils::adjustable_semaphore::AdjustableSemaphore;
438
439 use super::*;
440
441 struct SharedBuffer(Arc<std::sync::Mutex<Vec<u8>>>);
442
443 impl Write for SharedBuffer {
444 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
445 self.0.lock().unwrap().extend_from_slice(buf);
446 Ok(buf.len())
447 }
448 fn flush(&mut self) -> io::Result<()> {
449 Ok(())
450 }
451 }
452
453 #[derive(Clone, Default)]
455 struct TestWriterConfig {
456 max_write_size: Option<usize>,
458 max_vectored_write_size: Option<usize>,
460 hard_limit_vectored_write_slice: Option<usize>,
463 simulate_interrupts: bool,
465 interrupt_frequency: usize,
467 }
468
469 impl TestWriterConfig {
470 fn vectorized() -> Self {
471 Self::default()
472 }
473
474 fn vectorized_partial(max_size: usize) -> Self {
475 Self {
476 max_vectored_write_size: Some(max_size),
477 ..Default::default()
478 }
479 }
480
481 fn vectorized_hard_limit(max_slice: usize) -> Self {
482 Self {
483 hard_limit_vectored_write_slice: Some(max_slice),
484 ..Default::default()
485 }
486 }
487
488 fn partial(max_size: usize) -> Self {
489 Self {
490 max_write_size: Some(max_size),
491 ..Default::default()
492 }
493 }
494
495 fn vectorized_with_interrupts() -> Self {
496 Self {
497 simulate_interrupts: true,
498 interrupt_frequency: 2,
499 ..Default::default()
500 }
501 }
502 }
503
504 struct TestWriter {
510 buffer: Arc<std::sync::Mutex<Vec<u8>>>,
511 config: TestWriterConfig,
512 write_count: Arc<AtomicU64>,
513 vectored_write_count: Arc<AtomicU64>,
514 interrupt_counter: Arc<AtomicU64>,
515 }
516
517 impl TestWriter {
518 fn new(config: TestWriterConfig) -> Self {
519 Self {
520 buffer: Arc::new(std::sync::Mutex::new(Vec::new())),
521 config,
522 write_count: Arc::new(AtomicU64::new(0)),
523 vectored_write_count: Arc::new(AtomicU64::new(0)),
524 interrupt_counter: Arc::new(AtomicU64::new(0)),
525 }
526 }
527
528 fn should_interrupt(&self) -> bool {
529 if !self.config.simulate_interrupts {
530 return false;
531 }
532 let count = self.interrupt_counter.fetch_add(1, Ordering::Relaxed);
533 count % self.config.interrupt_frequency as u64 == 0
534 }
535 }
536
537 impl Write for TestWriter {
538 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
539 if self.should_interrupt() {
540 return Err(io::Error::new(io::ErrorKind::Interrupted, "simulated interrupt"));
541 }
542
543 self.write_count.fetch_add(1, Ordering::Relaxed);
544
545 let bytes_to_write = match self.config.max_write_size {
546 Some(max) => buf.len().min(max),
547 None => buf.len(),
548 };
549
550 self.buffer.lock().unwrap().extend_from_slice(&buf[..bytes_to_write]);
551 Ok(bytes_to_write)
552 }
553
554 fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
555 if self.should_interrupt() {
556 return Err(io::Error::new(io::ErrorKind::Interrupted, "simulated interrupt"));
557 }
558
559 if let Some(max_slice) = self.config.hard_limit_vectored_write_slice
560 && bufs.len() > max_slice
561 {
562 return Err(io::Error::new(io::ErrorKind::InvalidInput, "simulated iovcnt EINVAL"));
563 }
564
565 self.vectored_write_count.fetch_add(1, Ordering::Relaxed);
566
567 let total_len: usize = bufs.iter().map(|b| b.len()).sum();
568 let max_write = self.config.max_vectored_write_size.unwrap_or(total_len);
569 let bytes_to_write = total_len.min(max_write);
570
571 let mut remaining = bytes_to_write;
572 let mut buffer = self.buffer.lock().unwrap();
573
574 for buf in bufs {
575 if remaining == 0 {
576 break;
577 }
578 let to_write = buf.len().min(remaining);
579 buffer.extend_from_slice(&buf[..to_write]);
580 remaining -= to_write;
581 }
582
583 Ok(bytes_to_write)
584 }
585
586 fn flush(&mut self) -> io::Result<()> {
587 Ok(())
588 }
589 }
590
591 fn immediate_future(data: Bytes) -> DataFuture {
592 Box::pin(async move { Ok(data) })
593 }
594
595 #[tokio::test]
596 async fn test_sequential_writes() {
597 let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
598 let buffer_clone = buffer.clone();
599
600 let mut writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
601
602 writer
603 .set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
604 .await
605 .unwrap();
606 writer
607 .set_next_term_data_source(FileRange::new(5, 6), None, immediate_future(Bytes::from(" ")))
608 .await
609 .unwrap();
610 writer
611 .set_next_term_data_source(FileRange::new(6, 11), None, immediate_future(Bytes::from("World")))
612 .await
613 .unwrap();
614
615 writer.finish().await.unwrap();
616
617 let result = buffer.lock().unwrap();
618 assert_eq!(&*result, b"Hello World");
619 }
620
621 #[tokio::test]
622 async fn test_delayed_future() {
623 let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
624 let buffer_clone = buffer.clone();
625
626 let mut writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
627
628 let f0: DataFuture = Box::pin(async {
630 tokio::time::sleep(Duration::from_millis(50)).await;
631 Ok(Bytes::from("Hello"))
632 });
633 let f1: DataFuture = Box::pin(async {
634 tokio::time::sleep(Duration::from_millis(10)).await;
635 Ok(Bytes::from(" "))
636 });
637 let f2: DataFuture = Box::pin(async { Ok(Bytes::from("World")) });
638
639 writer.set_next_term_data_source(FileRange::new(0, 5), None, f0).await.unwrap();
640 writer.set_next_term_data_source(FileRange::new(5, 6), None, f1).await.unwrap();
641 writer.set_next_term_data_source(FileRange::new(6, 11), None, f2).await.unwrap();
642
643 writer.finish().await.unwrap();
644
645 let result = buffer.lock().unwrap();
646 assert_eq!(&*result, b"Hello World");
647 }
648
649 #[tokio::test]
650 async fn test_size_mismatch_error() {
651 let buffer = std::io::Cursor::new(Vec::new());
652 let mut writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
653
654 writer
655 .set_next_term_data_source(FileRange::new(0, 10), None, immediate_future(Bytes::from("Hello")))
656 .await
657 .unwrap();
658
659 let result = writer.finish().await;
660 assert!(result.is_err());
661 }
662
663 #[tokio::test]
664 async fn test_background_writer_error_propagates() {
665 struct FailingWriter;
666 impl Write for FailingWriter {
667 fn write(&mut self, _buf: &[u8]) -> io::Result<usize> {
668 Err(io::Error::new(io::ErrorKind::Other, "Simulated write failure"))
669 }
670 fn flush(&mut self) -> io::Result<()> {
671 Ok(())
672 }
673 }
674
675 let mut writer = SequentialWriter::new(Box::new(FailingWriter), false, RunState::new_for_test());
676
677 writer
678 .set_next_term_data_source(FileRange::new(0, 4), None, immediate_future(Bytes::from("Test")))
679 .await
680 .unwrap();
681
682 tokio::time::sleep(Duration::from_millis(200)).await;
683
684 let result = writer
685 .set_next_term_data_source(FileRange::new(4, 8), None, immediate_future(Bytes::from("More")))
686 .await;
687
688 assert!(result.is_err());
689 assert!(matches!(result, Err(FileReconstructionError::IoError(_))));
690 }
691
692 #[tokio::test]
693 async fn test_flush_error_propagates() {
694 struct FlushFailingWriter;
695 impl Write for FlushFailingWriter {
696 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
697 Ok(buf.len())
698 }
699 fn flush(&mut self) -> io::Result<()> {
700 Err(io::Error::new(io::ErrorKind::Other, "Simulated flush failure"))
701 }
702 }
703
704 let writer = SequentialWriter::new(Box::new(FlushFailingWriter), false, RunState::new_for_test());
705 let result = writer.finish().await;
706 assert!(result.is_err());
707 assert!(matches!(result, Err(FileReconstructionError::IoError(_))));
708 }
709
710 #[tokio::test]
711 async fn test_future_error_propagates() {
712 let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
713 let buffer_clone = buffer.clone();
714
715 let mut writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
716
717 let failing_future: DataFuture =
718 Box::pin(async { Err(FileReconstructionError::InternalError("Simulated future error".to_string())) });
719
720 writer
721 .set_next_term_data_source(FileRange::new(0, 5), None, failing_future)
722 .await
723 .unwrap();
724
725 let result = writer.finish().await;
726 assert!(result.is_err());
727 }
728
729 #[tokio::test]
730 async fn test_size_mismatch_too_small() {
731 let buffer = std::io::Cursor::new(Vec::new());
732 let mut writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
733
734 writer
735 .set_next_term_data_source(FileRange::new(0, 10), None, immediate_future(Bytes::from("Hi")))
736 .await
737 .unwrap();
738
739 let result = writer.finish().await;
740 assert!(result.is_err());
741 }
742
743 #[tokio::test]
744 async fn test_size_mismatch_too_large() {
745 let buffer = std::io::Cursor::new(Vec::new());
746 let mut writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
747
748 writer
749 .set_next_term_data_source(FileRange::new(0, 2), None, immediate_future(Bytes::from("Hello World")))
750 .await
751 .unwrap();
752
753 let result = writer.finish().await;
754 assert!(result.is_err());
755 }
756
757 #[tokio::test]
758 async fn test_bytes_written_tracking() {
759 let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
760 let buffer_clone = buffer.clone();
761
762 let mut writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
763
764 writer
765 .set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
766 .await
767 .unwrap();
768 writer
769 .set_next_term_data_source(FileRange::new(5, 11), None, immediate_future(Bytes::from(" World")))
770 .await
771 .unwrap();
772 writer
773 .set_next_term_data_source(FileRange::new(11, 16), None, immediate_future(Bytes::from("!!!!!")))
774 .await
775 .unwrap();
776
777 writer.finish().await.unwrap();
778
779 let result = buffer.lock().unwrap();
780 assert_eq!(&*result, b"Hello World!!!!!");
781 assert_eq!(result.len(), 16);
782 }
783
784 #[tokio::test]
785 async fn test_non_sequential_range_returns_error() {
786 let buffer = std::io::Cursor::new(Vec::new());
787 let mut writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
788
789 writer
790 .set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
791 .await
792 .unwrap();
793
794 let result = writer
795 .set_next_term_data_source(FileRange::new(10, 15), None, immediate_future(Bytes::from("World")))
796 .await;
797 assert!(result.is_err());
798 assert!(matches!(result, Err(FileReconstructionError::InternalWriterError(_))));
799 }
800
801 #[tokio::test]
802 async fn test_first_range_must_start_at_zero() {
803 let buffer = std::io::Cursor::new(Vec::new());
804 let mut writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
805
806 let result = writer
807 .set_next_term_data_source(FileRange::new(5, 10), None, immediate_future(Bytes::from("Hello")))
808 .await;
809 assert!(result.is_err());
810 assert!(matches!(result, Err(FileReconstructionError::InternalWriterError(_))));
811 }
812
813 #[tokio::test]
814 async fn test_semaphore_permit_released_after_write() {
815 let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
816 let buffer_clone = buffer.clone();
817 let semaphore = AdjustableSemaphore::new(2, (0, 2));
818
819 let mut writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
820
821 let permit1 = semaphore.acquire().await.unwrap();
822 let permit2 = semaphore.acquire().await.unwrap();
823
824 assert_eq!(semaphore.available_permits(), 0);
825
826 writer
827 .set_next_term_data_source(FileRange::new(0, 5), Some(permit1), immediate_future(Bytes::from("Hello")))
828 .await
829 .unwrap();
830
831 tokio::time::sleep(Duration::from_millis(50)).await;
832 assert_eq!(semaphore.available_permits(), 1);
833
834 writer
835 .set_next_term_data_source(FileRange::new(5, 6), Some(permit2), immediate_future(Bytes::from(" ")))
836 .await
837 .unwrap();
838
839 tokio::time::sleep(Duration::from_millis(50)).await;
840 assert_eq!(semaphore.available_permits(), 2);
841
842 writer.finish().await.unwrap();
843
844 let result = buffer.lock().unwrap();
845 assert_eq!(&*result, b"Hello ");
846 }
847
848 #[tokio::test]
851 async fn test_vectorized_basic_writes() {
852 let test_writer = TestWriter::new(TestWriterConfig::vectorized());
853 let buffer = test_writer.buffer.clone();
854 let vectored_count = test_writer.vectored_write_count.clone();
855
856 let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
857
858 writer
859 .set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
860 .await
861 .unwrap();
862 writer
863 .set_next_term_data_source(FileRange::new(5, 6), None, immediate_future(Bytes::from(" ")))
864 .await
865 .unwrap();
866 writer
867 .set_next_term_data_source(FileRange::new(6, 11), None, immediate_future(Bytes::from("World")))
868 .await
869 .unwrap();
870
871 writer.finish().await.unwrap();
872
873 let result = buffer.lock().unwrap();
874 assert_eq!(&*result, b"Hello World");
875 assert!(vectored_count.load(Ordering::Relaxed) > 0);
876 }
877
878 #[tokio::test]
879 async fn test_vectorized_partial_writes() {
880 let test_writer = TestWriter::new(TestWriterConfig::vectorized_partial(3));
881 let buffer = test_writer.buffer.clone();
882
883 let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
884
885 writer
886 .set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
887 .await
888 .unwrap();
889 writer
890 .set_next_term_data_source(FileRange::new(5, 6), None, immediate_future(Bytes::from(" ")))
891 .await
892 .unwrap();
893 writer
894 .set_next_term_data_source(FileRange::new(6, 11), None, immediate_future(Bytes::from("World")))
895 .await
896 .unwrap();
897 writer
898 .set_next_term_data_source(FileRange::new(11, 12), None, immediate_future(Bytes::from("!")))
899 .await
900 .unwrap();
901
902 writer.finish().await.unwrap();
903
904 let result = buffer.lock().unwrap();
905 assert_eq!(&*result, b"Hello World!");
906 }
907
908 #[tokio::test]
909 async fn test_vectorized_with_delays() {
910 let test_writer = TestWriter::new(TestWriterConfig::vectorized());
911 let buffer = test_writer.buffer.clone();
912
913 let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
914
915 let f0: DataFuture = Box::pin(async {
917 tokio::time::sleep(Duration::from_millis(30)).await;
918 Ok(Bytes::from("A"))
919 });
920 let f1: DataFuture = Box::pin(async {
921 tokio::time::sleep(Duration::from_millis(10)).await;
922 Ok(Bytes::from("B"))
923 });
924 let f2: DataFuture = Box::pin(async { Ok(Bytes::from("C")) });
925
926 writer.set_next_term_data_source(FileRange::new(0, 1), None, f0).await.unwrap();
927 writer.set_next_term_data_source(FileRange::new(1, 2), None, f1).await.unwrap();
928 writer.set_next_term_data_source(FileRange::new(2, 3), None, f2).await.unwrap();
929
930 writer.finish().await.unwrap();
931
932 let result = buffer.lock().unwrap();
933 assert_eq!(&*result, b"ABC");
934 }
935
936 #[tokio::test]
937 async fn test_vectorized_many_small_writes() {
938 let expected: Vec<u8> = (0..100u8).collect();
939 let test_writer = TestWriter::new(TestWriterConfig::vectorized());
940 let buffer = test_writer.buffer.clone();
941 let vectored_count = test_writer.vectored_write_count.clone();
942
943 let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
944
945 for i in 0..100u8 {
947 writer
948 .set_next_term_data_source(
949 FileRange::new(i as u64, i as u64 + 1),
950 None,
951 immediate_future(Bytes::from(vec![i])),
952 )
953 .await
954 .unwrap();
955 }
956
957 writer.finish().await.unwrap();
958
959 let result = buffer.lock().unwrap();
960 assert_eq!(&*result, &expected);
961
962 let vectored_calls = vectored_count.load(Ordering::Relaxed);
964 assert!(vectored_calls < 100);
965 }
966
967 #[tokio::test]
968 async fn test_vectorized_with_interrupts() {
969 let test_writer = TestWriter::new(TestWriterConfig::vectorized_with_interrupts());
970 let buffer = test_writer.buffer.clone();
971
972 let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
973
974 writer
975 .set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
976 .await
977 .unwrap();
978 writer
979 .set_next_term_data_source(FileRange::new(5, 6), None, immediate_future(Bytes::from(" ")))
980 .await
981 .unwrap();
982 writer
983 .set_next_term_data_source(FileRange::new(6, 11), None, immediate_future(Bytes::from("World")))
984 .await
985 .unwrap();
986
987 writer.finish().await.unwrap();
988
989 let result = buffer.lock().unwrap();
990 assert_eq!(&*result, b"Hello World");
991 }
992
993 #[tokio::test]
994 async fn test_vectorized_permit_release() {
995 let test_writer = TestWriter::new(TestWriterConfig::vectorized());
996 let buffer = test_writer.buffer.clone();
997 let semaphore = AdjustableSemaphore::new(2, (0, 2));
998
999 let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
1000
1001 let permit1 = semaphore.acquire().await.unwrap();
1002 let permit2 = semaphore.acquire().await.unwrap();
1003
1004 assert_eq!(semaphore.available_permits(), 0);
1005
1006 writer
1007 .set_next_term_data_source(FileRange::new(0, 5), Some(permit1), immediate_future(Bytes::from("Hello")))
1008 .await
1009 .unwrap();
1010
1011 tokio::time::sleep(Duration::from_millis(50)).await;
1012 assert_eq!(semaphore.available_permits(), 1);
1013
1014 writer
1015 .set_next_term_data_source(FileRange::new(5, 6), Some(permit2), immediate_future(Bytes::from(" ")))
1016 .await
1017 .unwrap();
1018
1019 tokio::time::sleep(Duration::from_millis(50)).await;
1020 assert_eq!(semaphore.available_permits(), 2);
1021
1022 writer.finish().await.unwrap();
1023
1024 let result = buffer.lock().unwrap();
1025 assert_eq!(&*result, b"Hello ");
1026 }
1027
1028 #[tokio::test]
1029 async fn test_vectorized_partial_permit_release() {
1030 let test_writer = TestWriter::new(TestWriterConfig::vectorized_partial(2));
1031 let buffer = test_writer.buffer.clone();
1032 let semaphore = AdjustableSemaphore::new(3, (0, 3));
1033
1034 let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
1035
1036 let permit1 = semaphore.acquire().await.unwrap();
1037 let permit2 = semaphore.acquire().await.unwrap();
1038 let permit3 = semaphore.acquire().await.unwrap();
1039
1040 assert_eq!(semaphore.available_permits(), 0);
1041
1042 writer
1043 .set_next_term_data_source(FileRange::new(0, 5), Some(permit1), immediate_future(Bytes::from("Hello")))
1044 .await
1045 .unwrap();
1046 writer
1047 .set_next_term_data_source(FileRange::new(5, 11), Some(permit2), immediate_future(Bytes::from(" World")))
1048 .await
1049 .unwrap();
1050 writer
1051 .set_next_term_data_source(FileRange::new(11, 12), Some(permit3), immediate_future(Bytes::from("!")))
1052 .await
1053 .unwrap();
1054
1055 writer.finish().await.unwrap();
1056
1057 assert_eq!(semaphore.available_permits(), 3);
1058
1059 let result = buffer.lock().unwrap();
1060 assert_eq!(&*result, b"Hello World!");
1061 }
1062
1063 #[tokio::test]
1064 async fn test_non_vectorized_basic_writes() {
1065 let test_writer = TestWriter::new(TestWriterConfig::default());
1066 let buffer = test_writer.buffer.clone();
1067 let write_count = test_writer.write_count.clone();
1068 let vectored_count = test_writer.vectored_write_count.clone();
1069
1070 let mut writer = SequentialWriter::new(Box::new(test_writer), false, RunState::new_for_test());
1071
1072 writer
1073 .set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
1074 .await
1075 .unwrap();
1076 writer
1077 .set_next_term_data_source(FileRange::new(5, 6), None, immediate_future(Bytes::from(" ")))
1078 .await
1079 .unwrap();
1080 writer
1081 .set_next_term_data_source(FileRange::new(6, 11), None, immediate_future(Bytes::from("World")))
1082 .await
1083 .unwrap();
1084
1085 writer.finish().await.unwrap();
1086
1087 let result = buffer.lock().unwrap();
1088 assert_eq!(&*result, b"Hello World");
1089 assert!(write_count.load(Ordering::Relaxed) > 0);
1090 assert_eq!(vectored_count.load(Ordering::Relaxed), 0);
1091 }
1092
1093 #[tokio::test]
1094 async fn test_non_vectorized_partial_writes() {
1095 let test_writer = TestWriter::new(TestWriterConfig::partial(3));
1096 let buffer = test_writer.buffer.clone();
1097
1098 let mut writer = SequentialWriter::new(Box::new(test_writer), false, RunState::new_for_test());
1099
1100 writer
1101 .set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
1102 .await
1103 .unwrap();
1104 writer
1105 .set_next_term_data_source(FileRange::new(5, 6), None, immediate_future(Bytes::from(" ")))
1106 .await
1107 .unwrap();
1108 writer
1109 .set_next_term_data_source(FileRange::new(6, 11), None, immediate_future(Bytes::from("World")))
1110 .await
1111 .unwrap();
1112 writer
1113 .set_next_term_data_source(FileRange::new(11, 12), None, immediate_future(Bytes::from("!")))
1114 .await
1115 .unwrap();
1116
1117 writer.finish().await.unwrap();
1118
1119 let result = buffer.lock().unwrap();
1120 assert_eq!(&*result, b"Hello World!");
1121 }
1122
1123 #[tokio::test]
1124 async fn test_vectorized_single_byte_partial() {
1125 let test_writer = TestWriter::new(TestWriterConfig::vectorized_partial(1));
1126 let buffer = test_writer.buffer.clone();
1127
1128 let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
1129
1130 writer
1131 .set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("ABCDE")))
1132 .await
1133 .unwrap();
1134 writer
1135 .set_next_term_data_source(FileRange::new(5, 10), None, immediate_future(Bytes::from("FGHIJ")))
1136 .await
1137 .unwrap();
1138
1139 writer.finish().await.unwrap();
1140
1141 let result = buffer.lock().unwrap();
1142 assert_eq!(&*result, b"ABCDEFGHIJ");
1143 }
1144
1145 #[tokio::test]
1146 async fn test_vectorized_large_data() {
1147 let expected: Vec<u8> = (0..10000).map(|i| (i % 256) as u8).collect();
1148 let test_writer = TestWriter::new(TestWriterConfig::vectorized());
1149 let buffer = test_writer.buffer.clone();
1150
1151 let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
1152
1153 for i in 0..10 {
1155 let start = i * 1000;
1156 let end = start + 1000;
1157 let chunk: Vec<u8> = (start..end).map(|j| (j % 256) as u8).collect();
1158 writer
1159 .set_next_term_data_source(
1160 FileRange::new(start as u64, end as u64),
1161 None,
1162 immediate_future(Bytes::from(chunk)),
1163 )
1164 .await
1165 .unwrap();
1166 }
1167
1168 writer.finish().await.unwrap();
1169
1170 let result = buffer.lock().unwrap();
1171 assert_eq!(&*result, &expected);
1172 }
1173
1174 #[tokio::test]
1175 async fn test_vectorized_large_data_partial() {
1176 let expected: Vec<u8> = (0..5000).map(|i| (i % 256) as u8).collect();
1177 let test_writer = TestWriter::new(TestWriterConfig::vectorized_partial(100));
1178 let buffer = test_writer.buffer.clone();
1179
1180 let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
1181
1182 for i in 0..10 {
1184 let start = i * 500;
1185 let end = start + 500;
1186 let chunk: Vec<u8> = (start..end).map(|j| (j % 256) as u8).collect();
1187 writer
1188 .set_next_term_data_source(
1189 FileRange::new(start as u64, end as u64),
1190 None,
1191 immediate_future(Bytes::from(chunk)),
1192 )
1193 .await
1194 .unwrap();
1195 }
1196
1197 writer.finish().await.unwrap();
1198
1199 let result = buffer.lock().unwrap();
1200 assert_eq!(&*result, &expected);
1201 }
1202
1203 #[tokio::test]
1204 async fn test_vectorized_exceeded_max_slice() {
1205 let test_writer = TestWriter::new(TestWriterConfig::vectorized_hard_limit(2)); let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test()); for i in 0..1000 {
1211 let start = i * 10;
1212 let end = start + 10;
1213 let chunk: Vec<u8> = (start..end).map(|j| (j % 256) as u8).collect();
1214 if writer
1215 .set_next_term_data_source(
1216 FileRange::new(start as u64, end as u64),
1217 None,
1218 immediate_future(Bytes::from(chunk)),
1219 )
1220 .await
1221 .is_err()
1222 {
1223 break;
1224 }
1225 }
1226
1227 let ret = writer.finish().await;
1228 assert!(ret.is_err());
1229 if let Err(FileReconstructionError::IoError(inner_err)) = ret {
1230 assert_eq!(inner_err.kind(), std::io::ErrorKind::InvalidInput);
1231 };
1232 }
1233
1234 #[tokio::test]
1235 async fn test_vectorized_controlled_max_slice() {
1236 let expected: Vec<u8> = (0..10000).map(|i| (i % 256) as u8).collect();
1237 let test_writer = TestWriter::new(TestWriterConfig::vectorized_hard_limit(40)); let buffer = test_writer.buffer.clone();
1239
1240 let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test()); for i in 0..1000 {
1244 let start = i * 10;
1245 let end = start + 10;
1246 let chunk: Vec<u8> = (start..end).map(|j| (j % 256) as u8).collect();
1247 writer
1248 .set_next_term_data_source(
1249 FileRange::new(start as u64, end as u64),
1250 None,
1251 immediate_future(Bytes::from(chunk)),
1252 )
1253 .await
1254 .unwrap();
1255 }
1256
1257 writer.finish().await.unwrap();
1258
1259 let result = buffer.lock().unwrap();
1260 assert_eq!(&*result, &expected);
1261 }
1262}