1use std::sync::Arc;
2use std::sync::atomic::{AtomicU64, Ordering};
3
4use bytes::Bytes;
5use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
6use tokio::task::JoinSet;
7use xet_client::cas_types::FileRange;
8use xet_runtime::utils::adjustable_semaphore::AdjustableSemaphorePermit;
9
10use super::super::data_writer::{DataFuture, DataWriter};
11use super::super::run_state::RunState;
12use super::super::{FileReconstructionError, Result};
13
/// One fully-downloaded term, delivered over the unordered result channel in
/// completion order (not file-offset order).
pub(crate) struct CompletedTerm {
    /// The file byte range this term's data occupies.
    pub byte_range: FileRange,
    /// The term payload; its length always equals
    /// `byte_range.end - byte_range.start` (enforced by the writer task
    /// before the term is sent).
    pub data: Bytes,
    /// Download-concurrency permit, if one was held for this term; dropping
    /// the term releases it back to its semaphore.
    pub permit: Option<AdjustableSemaphorePermit>,
}
22
/// Live counters for work currently in flight inside an [`UnorderedWriter`].
/// Shared via `Arc` between the writer's spawned tasks and external
/// observers.
pub(crate) struct UnorderedWriterProgress {
    /// Number of terms that have been spawned but not yet completed.
    pub terms_in_progress: AtomicU64,
    /// Sum of the expected sizes (in bytes) of all in-flight terms.
    pub bytes_in_progress: AtomicU64,
}
30
impl UnorderedWriterProgress {
    /// Number of terms currently in flight.
    ///
    /// The `Acquire` load pairs with the `Release` decrement performed by each
    /// completing task, so once this reads 0 the effects of all finished
    /// tasks' counter updates are visible to the caller.
    pub fn terms_in_progress(&self) -> u64 {
        self.terms_in_progress.load(Ordering::Acquire)
    }

    /// Total expected bytes across all in-flight terms.
    ///
    /// `Relaxed` is sufficient: this is a monitoring-only value and implies no
    /// synchronization with the tasks that update it.
    pub fn bytes_in_progress(&self) -> u64 {
        self.bytes_in_progress.load(Ordering::Relaxed)
    }
}
40
/// A [`DataWriter`] that downloads terms concurrently and streams each one
/// out over an unbounded channel as soon as it completes, rather than
/// assembling them in file-offset order.
pub struct UnorderedWriter {
    // Sender half of the result stream; each spawned task pushes its
    // CompletedTerm (or error) here.
    result_tx: UnboundedSender<Result<CompletedTerm>>,
    // Shared run state used for cross-task error propagation and for
    // cancellation on early drop.
    run_state: Arc<RunState>,
    // Counters exposed to callers for monitoring in-flight work.
    progress: Arc<UnorderedWriterProgress>,
    // One task per queued term; each resolves to the byte count it produced.
    task_set: JoinSet<Result<u64>>,
    // Running total of bytes from tasks that have been joined so far.
    total_bytes_sent: u64,
    // Set only by a successful finish(); when false at drop time the run is
    // cancelled (see the Drop impl).
    finished: bool,
}
61
62impl Drop for UnorderedWriter {
63 fn drop(&mut self) {
64 if !self.finished {
65 self.run_state.cancel();
66 }
67 }
68}
69
#[async_trait::async_trait]
impl DataWriter for UnorderedWriter {
    /// Queue one term for download: spawns a task that awaits `data_future`,
    /// validates the payload length against `byte_range`, and streams the
    /// completed term through the result channel.
    ///
    /// Also opportunistically reaps any already-finished tasks, accumulating
    /// their byte counts into `total_bytes_sent` and surfacing their errors.
    async fn set_next_term_data_source(
        &mut self,
        byte_range: FileRange,
        permit: Option<AdjustableSemaphorePermit>,
        data_future: DataFuture,
    ) -> Result<()> {
        // Fail fast if an earlier term already put the run into an error state.
        self.run_state.check_error()?;

        // Non-blocking reap of completed tasks; the double `?` propagates both
        // join failures and errors returned by the task itself.
        while let Some(result) = self.task_set.try_join_next() {
            self.total_bytes_sent +=
                result.map_err(|e| FileReconstructionError::InternalError(format!("Task join error: {e}")))??;
        }

        if self.finished {
            return Err(FileReconstructionError::InternalWriterError("Writer has already finished".to_string()));
        }

        // NOTE(review): assumes byte_range.end >= byte_range.start; an
        // inverted range would underflow here — confirm callers guarantee it.
        let expected_size = byte_range.end - byte_range.start;
        self.progress.terms_in_progress.fetch_add(1, Ordering::Relaxed);
        self.progress.bytes_in_progress.fetch_add(expected_size, Ordering::Relaxed);

        let result_tx = self.result_tx.clone();
        let run_state = self.run_state.clone();
        let progress = self.progress.clone();

        self.task_set.spawn(async move {
            // Inner async block so `?` can build the single Result value that
            // is both recorded in run_state and sent downstream.
            let result = async {
                run_state.check_error()?;

                let data = data_future.await?;

                // The payload must exactly fill the declared range.
                if data.len() as u64 != expected_size {
                    return Err(FileReconstructionError::InternalWriterError(format!(
                        "Data size mismatch: expected {} bytes, got {} bytes",
                        expected_size,
                        data.len()
                    )));
                }

                Ok(CompletedTerm {
                    byte_range,
                    data,
                    permit,
                })
            }
            .await;

            // Record the first failure globally so other tasks and subsequent
            // writer calls can bail out early.
            if let Err(ref e) = result {
                run_state.set_error(e.clone());
            }

            let completed_bytes = result.as_ref().map(|t| t.data.len() as u64).unwrap_or(0);

            // The receiver may already be gone (caller dropped it); that is
            // not an error for the writer, so the send result is ignored.
            let _ = result_tx.send(result);

            // Decrement counters only after the term has been handed off. The
            // Release on terms_in_progress pairs with the Acquire load in
            // UnorderedWriterProgress::terms_in_progress().
            progress.bytes_in_progress.fetch_sub(expected_size, Ordering::Relaxed);
            progress.terms_in_progress.fetch_sub(1, Ordering::Release);

            if completed_bytes > 0 {
                Ok(completed_bytes)
            } else {
                // Zero bytes means this term failed (or was empty): re-check
                // the shared error state so the failure also propagates
                // through the JoinSet result.
                run_state.check_error()?;
                Ok(0)
            }
        });

        Ok(())
    }

    /// Wait for every outstanding task and return the total number of bytes
    /// successfully produced across all terms.
    ///
    /// On success this marks the writer finished so Drop does not cancel the
    /// run; an error return leaves `finished == false`, so Drop will cancel.
    async fn finish(mut self: Box<Self>) -> Result<u64> {
        self.run_state.check_error()?;

        while let Some(result) = self.task_set.join_next().await {
            self.total_bytes_sent +=
                result.map_err(|e| FileReconstructionError::InternalError(format!("Task join error: {e}")))??;
        }

        self.finished = true;
        Ok(self.total_bytes_sent)
    }
}
153
154impl UnorderedWriter {
155 pub(crate) fn new_streaming(
164 run_state: Arc<RunState>,
165 ) -> (Box<dyn DataWriter>, UnboundedReceiver<Result<CompletedTerm>>, Arc<UnorderedWriterProgress>) {
166 let (tx, rx) = unbounded_channel();
167
168 let progress = Arc::new(UnorderedWriterProgress {
169 terms_in_progress: AtomicU64::new(0),
170 bytes_in_progress: AtomicU64::new(0),
171 });
172
173 let writer = Box::new(UnorderedWriter {
174 result_tx: tx,
175 run_state,
176 progress: progress.clone(),
177 task_set: JoinSet::new(),
178 total_bytes_sent: 0,
179 finished: false,
180 });
181
182 (writer, rx, progress)
183 }
184}
185
186#[cfg(test)]
187mod tests {
188 use std::time::Duration;
189
190 use xet_runtime::utils::adjustable_semaphore::AdjustableSemaphore;
191
192 use super::*;
193
    /// A `DataFuture` that resolves immediately with `data`.
    fn immediate_future(data: Bytes) -> DataFuture {
        Box::pin(async move { Ok(data) })
    }
197
    /// A `DataFuture` that resolves with `data` after sleeping for `delay`.
    fn delayed_future(data: Bytes, delay: Duration) -> DataFuture {
        Box::pin(async move {
            tokio::time::sleep(delay).await;
            Ok(data)
        })
    }
204
205 async fn drain_sorted(rx: &mut UnboundedReceiver<Result<CompletedTerm>>) -> Result<Vec<(u64, Bytes)>> {
209 let mut items = Vec::new();
210 while let Some(result) = rx.recv().await {
211 let term = result?;
212 items.push((term.byte_range.start, term.data));
213 drop(term.permit);
214 }
215 items.sort_by_key(|(offset, _)| *offset);
216 Ok(items)
217 }
218
219 #[tokio::test]
220 async fn test_basic_unordered_writes() {
221 let run_state = RunState::new_for_test();
222 let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);
223
224 writer
225 .set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
226 .await
227 .unwrap();
228 writer
229 .set_next_term_data_source(FileRange::new(5, 6), None, immediate_future(Bytes::from(" ")))
230 .await
231 .unwrap();
232 writer
233 .set_next_term_data_source(FileRange::new(6, 11), None, immediate_future(Bytes::from("World")))
234 .await
235 .unwrap();
236
237 let total = writer.finish().await.unwrap();
238 assert_eq!(total, 11);
239
240 let items = drain_sorted(&mut rx).await.unwrap();
241 let assembled: Vec<u8> = items.into_iter().flat_map(|(_, data)| data.to_vec()).collect();
242 assert_eq!(&assembled, b"Hello World");
243 }
244
    /// Terms whose futures resolve out of submission order (80ms, 40ms,
    /// immediate) must still reassemble correctly after sorting by offset.
    #[tokio::test]
    async fn test_delayed_futures_complete_out_of_order() {
        let run_state = RunState::new_for_test();
        let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);

        writer
            .set_next_term_data_source(
                FileRange::new(0, 5),
                None,
                delayed_future(Bytes::from("Hello"), Duration::from_millis(80)),
            )
            .await
            .unwrap();
        writer
            .set_next_term_data_source(
                FileRange::new(5, 6),
                None,
                delayed_future(Bytes::from(" "), Duration::from_millis(40)),
            )
            .await
            .unwrap();
        writer
            .set_next_term_data_source(FileRange::new(6, 11), None, immediate_future(Bytes::from("World")))
            .await
            .unwrap();

        let total = writer.finish().await.unwrap();
        assert_eq!(total, 11);

        // drain_sorted re-orders by offset, so assembly is offset order even
        // though delivery was completion order.
        let items = drain_sorted(&mut rx).await.unwrap();
        let assembled: Vec<u8> = items.into_iter().flat_map(|(_, data)| data.to_vec()).collect();
        assert_eq!(&assembled, b"Hello World");
    }
278
279 #[tokio::test]
280 async fn test_size_mismatch_error() {
281 let run_state = RunState::new_for_test();
282 let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);
283
284 writer
285 .set_next_term_data_source(FileRange::new(0, 10), None, immediate_future(Bytes::from("Hello")))
286 .await
287 .unwrap();
288
289 let result = writer.finish().await;
290 assert!(result.is_err());
291
292 let result = rx.recv().await.unwrap();
293 assert!(result.is_err());
294 assert!(matches!(result, Err(FileReconstructionError::InternalWriterError(_))));
295 }
296
297 #[tokio::test]
298 async fn test_future_error_propagates() {
299 let run_state = RunState::new_for_test();
300 let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);
301
302 let failing_future: DataFuture =
303 Box::pin(async { Err(FileReconstructionError::InternalError("Simulated error".to_string())) });
304
305 writer
306 .set_next_term_data_source(FileRange::new(0, 5), None, failing_future)
307 .await
308 .unwrap();
309
310 let result = writer.finish().await;
311 assert!(result.is_err());
312
313 let result = rx.recv().await.unwrap();
314 assert!(result.is_err());
315 }
316
    /// Permits handed to the writer must flow through to the consumer and be
    /// released back to the semaphore once the terms are dropped.
    #[tokio::test]
    async fn test_semaphore_permit_released_after_consumption() {
        let run_state = RunState::new_for_test();
        let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);
        let semaphore = AdjustableSemaphore::new(2, (0, 2));

        // Exhaust the semaphore so we can observe the permits coming back.
        let permit1 = semaphore.acquire().await.unwrap();
        let permit2 = semaphore.acquire().await.unwrap();
        assert_eq!(semaphore.available_permits(), 0);

        writer
            .set_next_term_data_source(FileRange::new(0, 5), Some(permit1), immediate_future(Bytes::from("Hello")))
            .await
            .unwrap();
        writer
            .set_next_term_data_source(FileRange::new(5, 6), Some(permit2), immediate_future(Bytes::from(" ")))
            .await
            .unwrap();

        writer.finish().await.unwrap();

        // drain_sorted drops each term's permit as it consumes the channel.
        let items = drain_sorted(&mut rx).await.unwrap();
        drop(items);

        assert_eq!(semaphore.available_permits(), 2);
    }
343
    /// After all terms have completed, both progress counters must return to
    /// zero.
    #[tokio::test]
    async fn test_counter_accuracy() {
        let run_state = RunState::new_for_test();
        let (mut writer, mut rx, progress) = UnorderedWriter::new_streaming(run_state);

        writer
            .set_next_term_data_source(
                FileRange::new(0, 5),
                None,
                delayed_future(Bytes::from("Hello"), Duration::from_millis(50)),
            )
            .await
            .unwrap();
        writer
            .set_next_term_data_source(
                FileRange::new(5, 11),
                None,
                delayed_future(Bytes::from(" World"), Duration::from_millis(50)),
            )
            .await
            .unwrap();

        // finish() joins every task, so all counter decrements have happened
        // by the time it returns.
        let total = writer.finish().await.unwrap();
        assert_eq!(total, 11);

        let _items = drain_sorted(&mut rx).await.unwrap();

        assert_eq!(progress.bytes_in_progress(), 0);
        assert_eq!(progress.terms_in_progress(), 0);
    }
374
375 #[tokio::test]
376 async fn test_finish_returns_total_bytes() {
377 let run_state = RunState::new_for_test();
378 let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);
379
380 writer
381 .set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
382 .await
383 .unwrap();
384 writer
385 .set_next_term_data_source(FileRange::new(5, 11), None, immediate_future(Bytes::from(" World")))
386 .await
387 .unwrap();
388
389 let total = writer.finish().await.unwrap();
390 assert_eq!(total, 11);
391
392 let _items = drain_sorted(&mut rx).await.unwrap();
393 }
394
    /// Once a term fails and the error lands in the shared run state, later
    /// calls to set_next_term_data_source must be rejected.
    #[tokio::test]
    async fn test_error_propagation_prevents_subsequent_writes() {
        let run_state = RunState::new_for_test();
        let (mut writer, mut _rx, _progress) = UnorderedWriter::new_streaming(run_state.clone());

        let failing_future: DataFuture =
            Box::pin(async { Err(FileReconstructionError::InternalError("fail".to_string())) });

        writer
            .set_next_term_data_source(FileRange::new(0, 5), None, failing_future)
            .await
            .unwrap();

        // The failing task runs concurrently; poll (with a 1s cap) until its
        // error is recorded in the shared run state.
        let wait_for_error = tokio::time::timeout(Duration::from_secs(1), async {
            loop {
                if run_state.check_error().is_err() {
                    break;
                }
                tokio::task::yield_now().await;
            }
        })
        .await;
        assert!(wait_for_error.is_ok());

        // The writer checks run_state up front, so this write must fail.
        let result = writer
            .set_next_term_data_source(FileRange::new(5, 10), None, immediate_future(Bytes::from("World")))
            .await;
        assert!(result.is_err());
    }
424
    /// Stress: 100 variably-sized, variably-delayed terms on a multi-threaded
    /// runtime; every byte must arrive exactly once at the right offset.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn stress_test_many_concurrent_terms() {
        let run_state = RunState::new_for_test();
        let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);

        let num_terms: usize = 100;
        let mut expected: Vec<(u64, Vec<u8>)> = Vec::new();
        let mut offset = 0u64;

        for i in 0..num_terms {
            // Deterministic pseudo-random sizes (100..590) and byte patterns.
            let size = 100 + (i % 50) * 10;
            let data: Vec<u8> = (0..size).map(|j| ((i * 7 + j * 13) % 256) as u8).collect();
            let bytes = Bytes::from(data.clone());
            expected.push((offset, data));

            // Stagger completion times (0..900us) to force out-of-order delivery.
            let delay = Duration::from_micros((i % 10) as u64 * 100);
            writer
                .set_next_term_data_source(
                    FileRange::new(offset, offset + size as u64),
                    None,
                    delayed_future(bytes, delay),
                )
                .await
                .unwrap();

            offset += size as u64;
        }

        let total = writer.finish().await.unwrap();
        assert_eq!(total, offset);

        let items = drain_sorted(&mut rx).await.unwrap();
        assert_eq!(items.len(), num_terms);

        // Both lists are in ascending offset order, so zip compares term-by-term.
        for ((exp_offset, exp_data), (act_offset, act_data)) in expected.iter().zip(items.iter()) {
            assert_eq!(*exp_offset, *act_offset);
            assert_eq!(exp_data.as_slice(), act_data.as_ref());
        }
    }
464
    /// Stress: repeatedly create a writer, queue 10 immediate terms, and
    /// finish right away — exercises the race between spawning and joining.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn stress_test_rapid_finish_after_writes() {
        for _ in 0..50 {
            let run_state = RunState::new_for_test();
            let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);

            for i in 0..10u64 {
                let data = Bytes::from(vec![i as u8; 100]);
                writer
                    .set_next_term_data_source(FileRange::new(i * 100, (i + 1) * 100), None, immediate_future(data))
                    .await
                    .unwrap();
            }

            let total = writer.finish().await.unwrap();
            assert_eq!(total, 1000);

            let items = drain_sorted(&mut rx).await.unwrap();
            assert_eq!(items.len(), 10);

            let total_bytes: usize = items.iter().map(|(_, data)| data.len()).sum();
            assert_eq!(total_bytes, 1000);
        }
    }
489
    /// Stress: interleave immediate and delayed futures over several rounds;
    /// totals, term counts, and the in-progress counter must all reconcile.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn stress_test_mixed_immediate_and_delayed() {
        for _ in 0..20 {
            let run_state = RunState::new_for_test();
            let (mut writer, mut rx, progress) = UnorderedWriter::new_streaming(run_state);

            let mut offset = 0u64;
            let mut total_size = 0u64;
            let num_terms = 30usize;

            for i in 0..num_terms {
                // Term sizes grow linearly: 50, 100, ..., 1500 bytes.
                let size = ((i + 1) * 50) as u64;
                let data = Bytes::from(vec![(i % 256) as u8; size as usize]);
                total_size += size;

                // Every third term is delayed (0..4ms); the rest are immediate.
                let future = if i % 3 == 0 {
                    delayed_future(data, Duration::from_millis((i % 5) as u64))
                } else {
                    immediate_future(data)
                };

                writer
                    .set_next_term_data_source(FileRange::new(offset, offset + size), None, future)
                    .await
                    .unwrap();
                offset += size;
            }

            let total = writer.finish().await.unwrap();
            assert_eq!(total, total_size);

            let items = drain_sorted(&mut rx).await.unwrap();
            assert_eq!(items.len(), num_terms);

            let received_bytes: u64 = items.iter().map(|(_, data)| data.len() as u64).sum();
            assert_eq!(received_bytes, total_size);
            assert_eq!(progress.terms_in_progress(), 0);
        }
    }
529}