1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::io::Write;
4use std::ops::{Bound, Range, RangeBounds};
5use std::path::{Path, PathBuf};
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::{Arc, Mutex};
8
9use tokio::task::JoinHandle;
10use tracing::instrument;
11use xet_client::cas_client::Client;
12use xet_client::cas_types::FileRange;
13use xet_client::chunk_cache::ChunkCache;
14use xet_runtime::core::{XetRuntime, xet_config};
15
16use super::XetFileInfo;
17use super::configurations::TranslatorConfig;
18use super::remote_client_interface::create_remote_client;
19use crate::error::{DataError, Result};
20use crate::file_reconstruction::{DownloadStream, FileReconstructor, UnorderedDownloadStream};
21use crate::progress_tracking::{GroupProgress, ItemProgressUpdater, UniqueID};
22
23pub struct FileDownloadSession {
28 client: Arc<dyn Client>,
29 chunk_cache: Option<Arc<dyn ChunkCache>>,
30 progress: Arc<GroupProgress>,
31 active_stream_abort_callbacks: Mutex<HashMap<UniqueID, Box<dyn Fn() + Send + Sync>>>,
32 finalized: AtomicBool,
33}
34
35impl FileDownloadSession {
36 pub async fn new(config: Arc<TranslatorConfig>, chunk_cache: Option<Arc<dyn ChunkCache>>) -> Result<Arc<Self>> {
37 let session_id = config
38 .session
39 .session_id
40 .as_ref()
41 .map(Cow::Borrowed)
42 .unwrap_or_else(|| Cow::Owned(UniqueID::new().to_string()));
43
44 let client = create_remote_client(&config, &session_id, false).await?;
45 let progress = GroupProgress::with_speed_config(
46 xet_config().data.progress_update_speed_sampling_window,
47 xet_config().data.progress_update_speed_min_observations,
48 );
49
50 Ok(Arc::new(Self {
51 client,
52 chunk_cache,
53 progress,
54 active_stream_abort_callbacks: Mutex::new(HashMap::new()),
55 finalized: AtomicBool::new(false),
56 }))
57 }
58
59 pub fn from_client(client: Arc<dyn Client>, chunk_cache: Option<Arc<dyn ChunkCache>>) -> Arc<Self> {
64 let progress = GroupProgress::new();
65 Arc::new(Self {
66 client,
67 chunk_cache,
68 progress,
69 active_stream_abort_callbacks: Mutex::new(HashMap::new()),
70 finalized: AtomicBool::new(false),
71 })
72 }
73
    /// Returns the group-level progress report produced by this session's
    /// progress tracker.
    pub fn report(&self) -> crate::progress_tracking::GroupProgressReport {
        self.progress.report()
    }
77
    /// Returns the progress report for the item registered under `id`, or
    /// `None` when the progress tracker has no report for that id.
    pub fn item_report(&self, id: UniqueID) -> Option<crate::progress_tracking::ItemProgressReport> {
        self.progress.item_report(id)
    }
81
    /// Returns the per-item progress reports of this session's progress
    /// tracker, keyed by item id.
    pub fn item_reports(&self) -> HashMap<UniqueID, crate::progress_tracking::ItemProgressReport> {
        self.progress.item_reports()
    }
85
86 fn register_stream_abort_callback(&self, id: UniqueID, callback: Box<dyn Fn() + Send + Sync>) {
87 self.active_stream_abort_callbacks.lock().unwrap().insert(id, callback);
88 }
89
90 pub fn unregister_stream_abort_callback(&self, id: UniqueID) {
91 self.active_stream_abort_callbacks.lock().unwrap().remove(&id);
92 }
93
94 pub fn abort_active_streams(&self) {
95 let callbacks = self.active_stream_abort_callbacks.lock().unwrap();
96 for callback in callbacks.values() {
97 callback();
98 }
99 }
100
101 pub async fn download_file_background(
106 self: &Arc<Self>,
107 file_info: XetFileInfo,
108 write_path: PathBuf,
109 ) -> Result<(UniqueID, JoinHandle<Result<u64>>)> {
110 self.check_not_finalized()?;
111 let id = UniqueID::new();
112 let session = self.clone();
113 let rt = XetRuntime::current();
114 let semaphore = rt.common().file_download_semaphore.clone();
115 let handle = rt.spawn(async move {
116 let _permit = semaphore.acquire().await?;
117 session.download_file_with_id(&file_info, &write_path, id).await
118 });
119 Ok((id, handle))
120 }
121
122 #[instrument(skip_all, name = "FileDownloadSession::download_file", fields(hash = file_info.hash()))]
124 pub async fn download_file(&self, file_info: &XetFileInfo, write_path: &Path) -> Result<(UniqueID, u64)> {
125 self.check_not_finalized()?;
126 let id = UniqueID::new();
127 let n_bytes = self.download_file_with_id(file_info, write_path, id).await?;
128 Ok((id, n_bytes))
129 }
130
131 async fn download_file_with_id(&self, file_info: &XetFileInfo, write_path: &Path, id: UniqueID) -> Result<u64> {
132 let name = Arc::from(write_path.to_string_lossy().as_ref());
133 let progress_updater = self.progress.new_item(id, name);
134 let reconstructor = self.setup_reconstructor(file_info, None, Some(progress_updater))?;
135 let n_bytes = reconstructor.reconstruct_to_file(write_path, None, true).await?;
136 if let Some(expected_size) = file_info.file_size()
139 && n_bytes != expected_size
140 {
141 return Err(DataError::SizeMismatch {
142 expected: expected_size,
143 actual: n_bytes,
144 });
145 }
146 Ok(n_bytes)
147 }
148
149 #[instrument(skip_all, name = "FileDownloadSession::download_to_writer",
157 fields(hash = file_info.hash(), range_start = tracing::field::Empty, range_end = tracing::field::Empty))]
158 pub async fn download_to_writer<W: Write + Send + 'static>(
159 &self,
160 file_info: &XetFileInfo,
161 source_range: impl RangeBounds<u64>,
162 writer: W,
163 ) -> Result<(UniqueID, u64)> {
164 self.check_not_finalized()?;
165 let range = range_bounds_to_file_range(&source_range)?;
166 if let Some(ref r) = range {
167 let span = tracing::Span::current();
168 span.record("range_start", r.start);
169 span.record("range_end", r.end);
170 }
171 let id = UniqueID::new();
172 let name = Arc::from("");
173 let progress_updater = self.progress.new_item(id, name);
174 let reconstructor = self.setup_reconstructor(file_info, range, Some(progress_updater))?;
175 let n_bytes = reconstructor.reconstruct_to_writer(writer).await?;
176
177 let expected_size = match range {
178 Some(r) if r.end < u64::MAX => Some(r.end - r.start),
179 None => file_info.file_size(),
180 _ => None,
181 };
182 if let Some(expected) = expected_size
183 && n_bytes != expected
184 {
185 return Err(DataError::SizeMismatch {
186 expected,
187 actual: n_bytes,
188 });
189 }
190
191 Ok((id, n_bytes))
192 }
193
194 #[instrument(skip_all, name = "FileDownloadSession::download_stream", fields(hash = file_info.hash()))]
207 pub async fn download_stream(
208 &self,
209 file_info: &XetFileInfo,
210 source_range: Option<Range<u64>>,
211 ) -> Result<(UniqueID, DownloadStream)> {
212 self.check_not_finalized()?;
213 let id = UniqueID::new();
214 let progress_updater = self.progress.new_item(id, "stream");
215 let range = source_range.map(|r| FileRange::new(r.start, r.end));
216 let reconstructor = self.setup_reconstructor(file_info, range, Some(progress_updater))?;
217 let stream = reconstructor.reconstruct_to_stream();
218 self.register_stream_abort_callback(id, stream.abort_callback());
219 Ok((id, stream))
220 }
221
222 #[instrument(skip_all, name = "FileDownloadSession::download_unordered_stream", fields(hash = file_info.hash()))]
235 pub async fn download_unordered_stream(
236 &self,
237 file_info: &XetFileInfo,
238 source_range: Option<Range<u64>>,
239 ) -> Result<(UniqueID, UnorderedDownloadStream)> {
240 self.check_not_finalized()?;
241 let id = UniqueID::new();
242 let progress_updater = self.progress.new_item(id, "unordered_stream");
243 let range = source_range.map(|r| FileRange::new(r.start, r.end));
244 let reconstructor = self.setup_reconstructor(file_info, range, Some(progress_updater))?;
245 let stream = reconstructor.reconstruct_to_unordered_stream();
246 self.register_stream_abort_callback(id, stream.abort_callback());
247 Ok((id, stream))
248 }
249
250 #[instrument(skip_all, name = "FileDownloadSession::download_stream_range", fields(hash = file_info.hash()))]
256 pub async fn download_stream_range(
257 &self,
258 file_info: &XetFileInfo,
259 range: impl RangeBounds<u64>,
260 ) -> Result<(UniqueID, DownloadStream)> {
261 self.check_not_finalized()?;
262 let file_range = range_bounds_to_file_range(&range)?;
263 let id = UniqueID::new();
264 let progress_updater = self.progress.new_item(id, "stream");
265 let reconstructor = self.setup_reconstructor(file_info, file_range, Some(progress_updater))?;
266 let stream = reconstructor.reconstruct_to_stream();
267 self.register_stream_abort_callback(id, stream.abort_callback());
268 Ok((id, stream))
269 }
270 fn check_not_finalized(&self) -> Result<()> {
271 if self.finalized.load(Ordering::Acquire) {
272 return Err(DataError::InvalidOperation("FileDownloadSession already finalized".to_string()));
273 }
274 Ok(())
275 }
276
277 pub async fn finalize(&self) -> Result<()> {
279 if self.finalized.swap(true, Ordering::AcqRel) {
280 return Err(DataError::InvalidOperation("FileDownloadSession already finalized".to_string()));
281 }
282 #[cfg(debug_assertions)]
283 self.progress.assert_complete();
284 Ok(())
285 }
286
287 fn setup_reconstructor(
288 &self,
289 file_info: &XetFileInfo,
290 range: Option<FileRange>,
291 progress_updater: Option<Arc<ItemProgressUpdater>>,
292 ) -> Result<FileReconstructor> {
293 let file_id = file_info.merkle_hash()?;
294
295 let mut reconstructor = FileReconstructor::new(&self.client, file_id);
296
297 match range {
298 Some(range) if range.end < u64::MAX => {
299 let size = range.end - range.start;
301 if let Some(ref updater) = progress_updater {
302 updater.update_item_size(size, true);
303 }
304 reconstructor = reconstructor.with_byte_range(range);
305 },
306 Some(range) => {
307 reconstructor = reconstructor.with_byte_range(range);
311 },
312 None if file_info.file_size().is_some() => {
313 if let Some(ref updater) = progress_updater {
317 updater.update_item_size(file_info.file_size().unwrap(), true);
318 }
319 },
320 None => {
321 },
325 }
326
327 if let Some(updater) = progress_updater {
328 reconstructor = reconstructor.with_progress_updater(updater);
329 }
330
331 if let Some(ref cache) = self.chunk_cache {
332 reconstructor = reconstructor.with_chunk_cache(cache.clone());
333 }
334
335 Ok(reconstructor)
336 }
337}
338
339fn range_bounds_to_file_range(range: &impl RangeBounds<u64>) -> Result<Option<FileRange>> {
347 let start = match range.start_bound() {
348 Bound::Included(&s) => s,
349 Bound::Excluded(&s) => s.saturating_add(1),
350 Bound::Unbounded => 0,
351 };
352 let end = match range.end_bound() {
353 Bound::Included(&e) => e.saturating_add(1),
354 Bound::Excluded(&e) => e,
355 Bound::Unbounded => u64::MAX,
356 };
357 if start > end {
358 return Err(DataError::InvalidOperation(format!("Invalid range: start ({start}) > end ({end})")));
359 }
360 if start == 0 && end == u64::MAX {
361 Ok(None)
362 } else {
363 Ok(Some(FileRange::new(start, end)))
364 }
365}
366
367#[cfg(test)]
368mod tests {
369 use std::fs::{read, write};
370 use std::io::{Seek, SeekFrom};
371 use std::sync::{Arc, OnceLock};
372
373 use tempfile::tempdir;
374 use xet_runtime::core::XetRuntime;
375
376 use super::*;
377 use crate::processing::configurations::TranslatorConfig;
378 use crate::processing::file_cleaner::Sha256Policy;
379 use crate::processing::{FileUploadSession, XetFileInfo};
380
381 fn get_threadpool() -> Arc<XetRuntime> {
382 static THREADPOOL: OnceLock<Arc<XetRuntime>> = OnceLock::new();
383 THREADPOOL
384 .get_or_init(|| XetRuntime::new().expect("Error starting multithreaded runtime."))
385 .clone()
386 }
387
388 async fn upload_data(cas_path: &Path, data: &[u8]) -> XetFileInfo {
389 let upload_session = FileUploadSession::new(TranslatorConfig::local_config(cas_path).unwrap().into())
390 .await
391 .unwrap();
392
393 let (_id, mut cleaner) = upload_session
394 .start_clean(Some("test".into()), Some(data.len() as u64), Sha256Policy::Compute)
395 .unwrap();
396 cleaner.add_data(data).await.unwrap();
397 let (xfi, _metrics) = cleaner.finish().await.unwrap();
398 upload_session.finalize().await.unwrap();
399 xfi
400 }
401
402 #[test]
403 fn test_download_file() {
404 let runtime = get_threadpool();
405 runtime
406 .clone()
407 .bridge_sync(async {
408 let temp = tempdir().unwrap();
409 let cas_path = temp.path().join("cas");
410 let original_data = b"Hello, download session!";
411
412 let xfi = upload_data(&cas_path, original_data).await;
413
414 let config = TranslatorConfig::local_config(&cas_path).unwrap();
415 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
416
417 let out_path = temp.path().join("output.txt");
418 let (_id, n_bytes) = session.download_file(&xfi, &out_path).await.unwrap();
419
420 assert_eq!(n_bytes, original_data.len() as u64);
421 assert_eq!(read(&out_path).unwrap(), original_data);
422 })
423 .unwrap();
424 }
425
426 #[test]
427 fn test_download_file_creates_parent_dirs() {
428 let runtime = get_threadpool();
429 runtime
430 .clone()
431 .bridge_sync(async {
432 let temp = tempdir().unwrap();
433 let cas_path = temp.path().join("cas");
434 let original_data = b"nested directory test";
435
436 let xfi = upload_data(&cas_path, original_data).await;
437
438 let config = TranslatorConfig::local_config(&cas_path).unwrap();
439 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
440
441 let out_path = temp.path().join("deep").join("nested").join("dir").join("output.txt");
442 assert!(!out_path.parent().unwrap().exists());
443
444 session.download_file(&xfi, &out_path).await.unwrap();
445
446 assert_eq!(read(&out_path).unwrap(), original_data);
447 })
448 .unwrap();
449 }
450
451 #[test]
452 fn test_download_to_writer() {
453 let runtime = get_threadpool();
454 runtime
455 .clone()
456 .bridge_sync(async {
457 let temp = tempdir().unwrap();
458 let cas_path = temp.path().join("cas");
459 let original_data = b"0123456789abcdef";
460
461 let xfi = upload_data(&cas_path, original_data).await;
462
463 let config = TranslatorConfig::local_config(&cas_path).unwrap();
464 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
465
466 let out_path = temp.path().join("partial_writer.txt");
467 write(&out_path, vec![0u8; original_data.len()]).unwrap();
468
469 let mut file = std::fs::OpenOptions::new().write(true).open(&out_path).unwrap();
470 file.seek(SeekFrom::Start(4)).unwrap();
471
472 let (_id, n_bytes) = session.download_to_writer(&xfi, 4..12, file).await.unwrap();
473
474 assert_eq!(n_bytes, 8);
475 let result = read(&out_path).unwrap();
476 assert_eq!(&result[4..12], &original_data[4..12]);
477 })
478 .unwrap();
479 }
480
481 #[test]
482 fn test_download_to_writer_parallel_partitioned_file() {
483 let runtime = get_threadpool();
484 runtime
485 .clone()
486 .bridge_sync(async {
487 let temp = tempdir().unwrap();
488 let cas_path = temp.path().join("cas");
489 let original_data = b"abcdefghijklmnopqrstuvwxyz0123456789";
490
491 let xfi = upload_data(&cas_path, original_data).await;
492 let config = TranslatorConfig::local_config(&cas_path).unwrap();
493 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
494
495 let out_path = temp.path().join("partitioned.txt");
496 write(&out_path, vec![0u8; original_data.len()]).unwrap();
497
498 let n_parts = 5u64;
499 let total = original_data.len() as u64;
500 let mut tasks = Vec::new();
501
502 for idx in 0..n_parts {
503 let start = (idx * total) / n_parts;
504 let end = ((idx + 1) * total) / n_parts;
505 if start == end {
506 continue;
507 }
508
509 let session = session.clone();
510 let xfi = xfi.clone();
511 let out_path = out_path.clone();
512 tasks.push(tokio::spawn(async move {
513 let mut writer = std::fs::OpenOptions::new().write(true).open(out_path).unwrap();
514 writer.seek(SeekFrom::Start(start)).unwrap();
515 session.download_to_writer(&xfi, start..end, writer).await
516 }));
517 }
518
519 for task in tasks {
520 task.await.unwrap().unwrap();
521 }
522
523 let result = read(&out_path).unwrap();
524 assert_eq!(result, original_data);
525 })
526 .unwrap();
527 }
528
529 #[test]
530 fn test_download_multiple_files_concurrent() {
531 let runtime = get_threadpool();
532 runtime
533 .clone()
534 .bridge_sync(async {
535 let temp = tempdir().unwrap();
536 let cas_path = temp.path().join("cas");
537
538 let data_a = b"File A content for concurrent test";
539 let data_b = b"File B content for concurrent test - different";
540
541 let xfi_a = upload_data(&cas_path, data_a).await;
542 let xfi_b = upload_data(&cas_path, data_b).await;
543
544 let config = TranslatorConfig::local_config(&cas_path).unwrap();
545 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
546
547 let out_a = temp.path().join("out_a.txt");
548 let out_b = temp.path().join("out_b.txt");
549
550 let session_a = session.clone();
551 let xfi_a_clone = xfi_a.clone();
552 let out_a_clone = out_a.clone();
553 let task_a = tokio::spawn(async move { session_a.download_file(&xfi_a_clone, &out_a_clone).await });
554
555 let session_b = session.clone();
556 let xfi_b_clone = xfi_b.clone();
557 let out_b_clone = out_b.clone();
558 let task_b = tokio::spawn(async move { session_b.download_file(&xfi_b_clone, &out_b_clone).await });
559
560 task_a.await.unwrap().unwrap();
561 task_b.await.unwrap().unwrap();
562
563 assert_eq!(read(&out_a).unwrap(), data_a);
564 assert_eq!(read(&out_b).unwrap(), data_b);
565 })
566 .unwrap();
567 }
568
569 #[test]
572 fn test_download_stream_async() {
573 let runtime = get_threadpool();
574 runtime
575 .clone()
576 .bridge_sync(async {
577 let temp = tempdir().unwrap();
578 let cas_path = temp.path().join("cas");
579 let original_data = b"Hello, streaming download!";
580
581 let xfi = upload_data(&cas_path, original_data).await;
582
583 let config = TranslatorConfig::local_config(&cas_path).unwrap();
584 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
585
586 let (_id, mut stream) = session.download_stream(&xfi, None).await.unwrap();
587
588 let mut collected = Vec::new();
589 while let Some(chunk) = stream.next().await.unwrap() {
590 collected.extend_from_slice(&chunk);
591 }
592
593 assert_eq!(collected, original_data);
594 })
595 .unwrap();
596 }
597
598 #[test]
599 fn test_download_stream_blocking() {
600 let runtime = get_threadpool();
601 runtime
602 .clone()
603 .bridge_sync(async {
604 let temp = tempdir().unwrap();
605 let cas_path = temp.path().join("cas");
606 let original_data = b"Blocking stream test data";
607
608 let xfi = upload_data(&cas_path, original_data).await;
609
610 let config = TranslatorConfig::local_config(&cas_path).unwrap();
611 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
612
613 let (_id, stream) = session.download_stream(&xfi, None).await.unwrap();
614
615 let collected = tokio::task::spawn_blocking(move || {
616 let mut stream = stream;
617 let mut buf = Vec::new();
618 while let Some(chunk) = stream.blocking_next().unwrap() {
619 buf.extend_from_slice(&chunk);
620 }
621 buf
622 })
623 .await
624 .unwrap();
625
626 assert_eq!(collected, original_data);
627 })
628 .unwrap();
629 }
630
631 #[test]
632 fn test_download_stream_returns_none_after_finish() {
633 let runtime = get_threadpool();
634 runtime
635 .clone()
636 .bridge_sync(async {
637 let temp = tempdir().unwrap();
638 let cas_path = temp.path().join("cas");
639 let original_data = b"Extra none calls";
640
641 let xfi = upload_data(&cas_path, original_data).await;
642
643 let config = TranslatorConfig::local_config(&cas_path).unwrap();
644 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
645
646 let (_id, mut stream) = session.download_stream(&xfi, None).await.unwrap();
647
648 while stream.next().await.unwrap().is_some() {}
649
650 assert!(stream.next().await.unwrap().is_none());
652 assert!(stream.next().await.unwrap().is_none());
653 })
654 .unwrap();
655 }
656
657 #[test]
658 fn test_download_stream_multiple_concurrent() {
659 let runtime = get_threadpool();
660 runtime
661 .clone()
662 .bridge_sync(async {
663 let temp = tempdir().unwrap();
664 let cas_path = temp.path().join("cas");
665
666 let data_a = b"Stream A for concurrent download";
667 let data_b = b"Stream B for concurrent download - different";
668
669 let xfi_a = upload_data(&cas_path, data_a).await;
670 let xfi_b = upload_data(&cas_path, data_b).await;
671
672 let config = TranslatorConfig::local_config(&cas_path).unwrap();
673 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
674
675 let (_id_a, mut stream_a) = session.download_stream(&xfi_a, None).await.unwrap();
676 let (_id_b, mut stream_b) = session.download_stream(&xfi_b, None).await.unwrap();
677
678 let task_a = tokio::spawn(async move {
679 let mut buf = Vec::new();
680 while let Some(chunk) = stream_a.next().await.unwrap() {
681 buf.extend_from_slice(&chunk);
682 }
683 buf
684 });
685
686 let task_b = tokio::spawn(async move {
687 let mut buf = Vec::new();
688 while let Some(chunk) = stream_b.next().await.unwrap() {
689 buf.extend_from_slice(&chunk);
690 }
691 buf
692 });
693
694 let result_a = task_a.await.unwrap();
695 let result_b = task_b.await.unwrap();
696
697 assert_eq!(result_a, data_a);
698 assert_eq!(result_b, data_b);
699 })
700 .unwrap();
701 }
702
703 #[test]
704 fn test_drop_stream_without_reading() {
705 let runtime = get_threadpool();
706 runtime
707 .clone()
708 .bridge_sync(async {
709 let temp = tempdir().unwrap();
710 let cas_path = temp.path().join("cas");
711 let original_data = b"Drop-without-reading cleanup test";
712
713 let xfi = upload_data(&cas_path, original_data).await;
714
715 let config = TranslatorConfig::local_config(&cas_path).unwrap();
716 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
717
718 let (_id, stream) = session.download_stream(&xfi, None).await.unwrap();
719 drop(stream);
720 tokio::task::yield_now().await;
721
722 let out_path = temp.path().join("after_drop.txt");
723 session.download_file(&xfi, &out_path).await.unwrap();
724 assert_eq!(read(&out_path).unwrap(), original_data);
725 })
726 .unwrap();
727 }
728
729 #[test]
730 fn test_drop_stream_multiple_cycles_then_download() {
731 let runtime = get_threadpool();
732 runtime
733 .clone()
734 .bridge_sync(async {
735 let temp = tempdir().unwrap();
736 let cas_path = temp.path().join("cas");
737 let original_data = b"Multi-cycle drop cleanup test";
738
739 let xfi = upload_data(&cas_path, original_data).await;
740
741 let config = TranslatorConfig::local_config(&cas_path).unwrap();
742 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
743
744 for i in 0..5u32 {
745 let (_id, mut stream) = session.download_stream(&xfi, None).await.unwrap();
746 if i % 3 == 0 {
747 let _ = stream.next().await;
748 }
749 drop(stream);
750 tokio::task::yield_now().await;
751 }
752
753 let out_path = temp.path().join("after_cycles.txt");
754 session.download_file(&xfi, &out_path).await.unwrap();
755 assert_eq!(read(&out_path).unwrap(), original_data);
756 })
757 .unwrap();
758 }
759
760 #[test]
761 fn test_drop_stream_blocking_mid_read_then_download() {
762 let runtime = get_threadpool();
763 runtime
764 .clone()
765 .bridge_sync(async {
766 let temp = tempdir().unwrap();
767 let cas_path = temp.path().join("cas");
768 let original_data = b"Blocking drop cleanup test data";
769
770 let xfi = upload_data(&cas_path, original_data).await;
771
772 let config = TranslatorConfig::local_config(&cas_path).unwrap();
773 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
774
775 let (_id, stream) = session.download_stream(&xfi, None).await.unwrap();
776
777 tokio::task::spawn_blocking(move || {
778 let mut stream = stream;
779 let _chunk = stream.blocking_next().unwrap();
780 })
781 .await
782 .unwrap();
783
784 tokio::task::yield_now().await;
785
786 let out_path = temp.path().join("after_blocking_drop.txt");
787 session.download_file(&xfi, &out_path).await.unwrap();
788 assert_eq!(read(&out_path).unwrap(), original_data);
789 })
790 .unwrap();
791 }
792
793 #[test]
794 fn test_cancel_stream_before_start_returns_none() {
795 let runtime = get_threadpool();
796 runtime
797 .clone()
798 .bridge_sync(async {
799 let temp = tempdir().unwrap();
800 let cas_path = temp.path().join("cas");
801 let original_data = b"Cancel-before-start stream test";
802
803 let xfi = upload_data(&cas_path, original_data).await;
804
805 let config = TranslatorConfig::local_config(&cas_path).unwrap();
806 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
807
808 let (_id, mut stream) = session.download_stream(&xfi, None).await.unwrap();
809 stream.cancel();
810 assert!(stream.next().await.unwrap().is_none());
811 assert!(stream.next().await.unwrap().is_none());
812 })
813 .unwrap();
814 }
815
816 #[test]
817 fn test_cancel_stream_after_first_chunk_returns_none() {
818 let runtime = get_threadpool();
819 runtime
820 .clone()
821 .bridge_sync(async {
822 let temp = tempdir().unwrap();
823 let cas_path = temp.path().join("cas");
824 let original_data = b"Cancel-after-first-chunk stream test data";
825
826 let xfi = upload_data(&cas_path, original_data).await;
827
828 let config = TranslatorConfig::local_config(&cas_path).unwrap();
829 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
830
831 let (_id, mut stream) = session.download_stream(&xfi, None).await.unwrap();
832 let _ = stream.next().await.unwrap();
833 stream.cancel();
834 assert!(stream.next().await.unwrap().is_none());
835 assert!(stream.next().await.unwrap().is_none());
836
837 let out_path = temp.path().join("after_cancel.txt");
838 session.download_file(&xfi, &out_path).await.unwrap();
839 assert_eq!(read(&out_path).unwrap(), original_data);
840 })
841 .unwrap();
842 }
843
844 #[test]
847 fn test_download_to_writer_range_from() {
848 let runtime = get_threadpool();
849 runtime
850 .clone()
851 .external_run_async_task(async {
852 let temp = tempdir().unwrap();
853 let cas_path = temp.path().join("cas");
854 let original_data = b"0123456789abcdef";
855
856 let xfi = upload_data(&cas_path, original_data).await;
857
858 let config = TranslatorConfig::local_config(&cas_path).unwrap();
859 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
860
861 let out_path = temp.path().join("range_from.bin");
862 let file = std::fs::File::create(&out_path).unwrap();
863 let (_id, n_bytes) = session.download_to_writer(&xfi, 4.., file).await.unwrap();
864
865 assert_eq!(n_bytes, 12);
866 assert_eq!(read(&out_path).unwrap(), &original_data[4..]);
867 })
868 .unwrap();
869 }
870
871 #[test]
872 fn test_download_to_writer_range_to() {
873 let runtime = get_threadpool();
874 runtime
875 .clone()
876 .external_run_async_task(async {
877 let temp = tempdir().unwrap();
878 let cas_path = temp.path().join("cas");
879 let original_data = b"0123456789abcdef";
880
881 let xfi = upload_data(&cas_path, original_data).await;
882
883 let config = TranslatorConfig::local_config(&cas_path).unwrap();
884 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
885
886 let out_path = temp.path().join("range_to.bin");
887 let file = std::fs::File::create(&out_path).unwrap();
888 let (_id, n_bytes) = session.download_to_writer(&xfi, ..8, file).await.unwrap();
889
890 assert_eq!(n_bytes, 8);
891 assert_eq!(read(&out_path).unwrap(), &original_data[..8]);
892 })
893 .unwrap();
894 }
895
896 #[test]
897 fn test_download_to_writer_full_range() {
898 let runtime = get_threadpool();
899 runtime
900 .clone()
901 .external_run_async_task(async {
902 let temp = tempdir().unwrap();
903 let cas_path = temp.path().join("cas");
904 let original_data = b"0123456789abcdef";
905
906 let xfi = upload_data(&cas_path, original_data).await;
907
908 let config = TranslatorConfig::local_config(&cas_path).unwrap();
909 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
910
911 let out_path = temp.path().join("full_range.bin");
912 let file = std::fs::File::create(&out_path).unwrap();
913 let (_id, n_bytes) = session.download_to_writer(&xfi, .., file).await.unwrap();
914
915 assert_eq!(n_bytes, original_data.len() as u64);
916 assert_eq!(read(&out_path).unwrap(), original_data);
917 })
918 .unwrap();
919 }
920
921 #[test]
922 fn test_download_to_writer_range_inclusive() {
923 let runtime = get_threadpool();
924 runtime
925 .clone()
926 .external_run_async_task(async {
927 let temp = tempdir().unwrap();
928 let cas_path = temp.path().join("cas");
929 let original_data = b"0123456789abcdef";
930
931 let xfi = upload_data(&cas_path, original_data).await;
932
933 let config = TranslatorConfig::local_config(&cas_path).unwrap();
934 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
935
936 let out_path = temp.path().join("range_incl.bin");
937 let file = std::fs::File::create(&out_path).unwrap();
938 let (_id, n_bytes) = session.download_to_writer(&xfi, 2..=5, file).await.unwrap();
939
940 assert_eq!(n_bytes, 4);
941 assert_eq!(read(&out_path).unwrap(), &original_data[2..=5]);
942 })
943 .unwrap();
944 }
945
946 #[test]
949 fn test_download_stream_range_bounded() {
950 let runtime = get_threadpool();
951 runtime
952 .clone()
953 .external_run_async_task(async {
954 let temp = tempdir().unwrap();
955 let cas_path = temp.path().join("cas");
956 let original_data = b"0123456789abcdef";
957
958 let xfi = upload_data(&cas_path, original_data).await;
959
960 let config = TranslatorConfig::local_config(&cas_path).unwrap();
961 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
962
963 let (_id, mut stream) = session.download_stream_range(&xfi, 4..12).await.unwrap();
964
965 let mut collected = Vec::new();
966 while let Some(chunk) = stream.next().await.unwrap() {
967 collected.extend_from_slice(&chunk);
968 }
969
970 assert_eq!(collected, &original_data[4..12]);
971 })
972 .unwrap();
973 }
974
975 #[test]
976 fn test_download_stream_range_from() {
977 let runtime = get_threadpool();
978 runtime
979 .clone()
980 .external_run_async_task(async {
981 let temp = tempdir().unwrap();
982 let cas_path = temp.path().join("cas");
983 let original_data = b"0123456789abcdef";
984
985 let xfi = upload_data(&cas_path, original_data).await;
986
987 let config = TranslatorConfig::local_config(&cas_path).unwrap();
988 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
989
990 let (_id, mut stream) = session.download_stream_range(&xfi, 10..).await.unwrap();
991
992 let mut collected = Vec::new();
993 while let Some(chunk) = stream.next().await.unwrap() {
994 collected.extend_from_slice(&chunk);
995 }
996
997 assert_eq!(collected, &original_data[10..]);
998 })
999 .unwrap();
1000 }
1001
1002 #[test]
1003 fn test_download_stream_range_to() {
1004 let runtime = get_threadpool();
1005 runtime
1006 .clone()
1007 .external_run_async_task(async {
1008 let temp = tempdir().unwrap();
1009 let cas_path = temp.path().join("cas");
1010 let original_data = b"0123456789abcdef";
1011
1012 let xfi = upload_data(&cas_path, original_data).await;
1013
1014 let config = TranslatorConfig::local_config(&cas_path).unwrap();
1015 let session = FileDownloadSession::new(config.into(), None).await.unwrap();
1016
1017 let (_id, mut stream) = session.download_stream_range(&xfi, ..6).await.unwrap();
1018
1019 let mut collected = Vec::new();
1020 while let Some(chunk) = stream.next().await.unwrap() {
1021 collected.extend_from_slice(&chunk);
1022 }
1023
1024 assert_eq!(collected, &original_data[..6]);
1025 })
1026 .unwrap();
1027 }
1028
/// Downloading to a file from file info built with `new_hash_only` must still
/// write the full content and report its length.
#[test]
fn test_download_file_unknown_size() {
    let pool = get_threadpool();
    let outcome = pool.clone().external_run_async_task(async {
        let payload = b"File with unknown size test";

        let scratch = tempdir().unwrap();
        let cas_dir = scratch.path().join("cas");

        // Only the hash of the uploaded file is carried over into the info used for download.
        let uploaded = upload_data(&cas_dir, payload).await;
        let hash_only_info = XetFileInfo::new_hash_only(uploaded.hash().to_string());

        let config = TranslatorConfig::local_config(&cas_dir).unwrap();
        let session = FileDownloadSession::new(config.into(), None).await.unwrap();

        let destination = scratch.path().join("output_unknown.txt");
        let (_download_id, written) = session.download_file(&hash_only_info, &destination).await.unwrap();

        // Both the reported byte count and the bytes on disk must match the payload.
        assert_eq!(written, payload.len() as u64);
        assert_eq!(read(&destination).unwrap(), payload);
    });
    outcome.unwrap();
}
1055
/// Streaming from file info built with `new_hash_only` (and no explicit range)
/// must yield the complete uploaded payload.
#[test]
fn test_download_stream_unknown_size() {
    let pool = get_threadpool();
    let outcome = pool.clone().external_run_async_task(async {
        let payload = b"Stream with unknown size test";

        let scratch = tempdir().unwrap();
        let cas_dir = scratch.path().join("cas");

        // Only the hash of the uploaded file is carried over into the info used for download.
        let uploaded = upload_data(&cas_dir, payload).await;
        let hash_only_info = XetFileInfo::new_hash_only(uploaded.hash().to_string());

        let config = TranslatorConfig::local_config(&cas_dir).unwrap();
        let session = FileDownloadSession::new(config.into(), None).await.unwrap();

        let (_stream_id, mut stream) = session.download_stream(&hash_only_info, None).await.unwrap();

        // Drain the stream until it reports the end with `None`.
        let mut received = Vec::new();
        while let Some(piece) = stream.next().await.unwrap() {
            received.extend_from_slice(&piece);
        }

        assert_eq!(received, payload);
    });
    outcome.unwrap();
}
1083
/// A file info that claims 999 bytes for a 23-byte payload must make
/// `download_file` fail with `DataError::SizeMismatch { expected: 999, .. }`.
///
/// NOTE(review): compiled only when debug assertions are disabled — presumably a
/// debug assertion fires on this path before the error is returned; confirm.
#[cfg(not(debug_assertions))]
#[test]
fn test_download_file_size_mismatch_error() {
    let pool = get_threadpool();
    let outcome = pool.clone().external_run_async_task(async {
        let payload = b"Size mismatch test data";

        let scratch = tempdir().unwrap();
        let cas_dir = scratch.path().join("cas");

        // Reuse the real hash but pair it with a deliberately wrong size.
        let uploaded = upload_data(&cas_dir, payload).await;
        let mismatched_info = XetFileInfo::new(uploaded.hash().to_string(), 999);

        let config = TranslatorConfig::local_config(&cas_dir).unwrap();
        let session = FileDownloadSession::new(config.into(), None).await.unwrap();

        let destination = scratch.path().join("output_mismatch.txt");
        let err = session.download_file(&mismatched_info, &destination).await.unwrap_err();

        let is_size_mismatch = matches!(err, DataError::SizeMismatch { expected: 999, .. });
        assert!(is_size_mismatch, "Expected SizeMismatch error, got: {err:?}");
    });
    outcome.unwrap();
}
1111
/// Checks how each standard range form is translated by
/// `range_bounds_to_file_range`.
#[test]
fn test_range_bounds_conversion() {
    use super::range_bounds_to_file_range as to_file_range;

    // The fully unbounded range maps to no file range at all.
    let unbounded = to_file_range(&(..)).unwrap();
    assert_eq!(unbounded, None);

    // A half-open range keeps both of its bounds unchanged.
    let half_open = to_file_range(&(0..100)).unwrap();
    assert_eq!(half_open, Some(FileRange::new(0, 100)));

    // A missing upper bound is represented by `u64::MAX`.
    let open_ended = to_file_range(&(5..)).unwrap();
    assert_eq!(open_ended, Some(FileRange::new(5, u64::MAX)));

    // A missing lower bound is represented by 0.
    let from_start = to_file_range(&(..50)).unwrap();
    assert_eq!(from_start, Some(FileRange::new(0, 50)));

    // An inclusive upper bound of 19 becomes an end value of 20.
    let inclusive = to_file_range(&(10..=19)).unwrap();
    assert_eq!(inclusive, Some(FileRange::new(10, 20)));
}
1124
/// A range whose start (10) lies after its end (5) must be rejected with an error.
#[test]
fn test_range_bounds_inverted_range_errors() {
    use super::range_bounds_to_file_range as to_file_range;

    let inverted = to_file_range(&(10..5));
    assert!(inverted.is_err());
}
1132
/// Writing the empty range `5..5` must succeed, report zero bytes, and leave
/// the output file empty.
#[test]
fn test_download_to_writer_empty_range() {
    let pool = get_threadpool();
    let outcome = pool.clone().external_run_async_task(async {
        let payload = b"0123456789abcdef";

        let scratch = tempdir().unwrap();
        let cas_dir = scratch.path().join("cas");
        let file_info = upload_data(&cas_dir, payload).await;

        let config = TranslatorConfig::local_config(&cas_dir).unwrap();
        let session = FileDownloadSession::new(config.into(), None).await.unwrap();

        // The writer is a freshly created file, so any written byte would show up on disk.
        let destination = scratch.path().join("empty_range.bin");
        let sink = std::fs::File::create(&destination).unwrap();
        let (_download_id, written) = session.download_to_writer(&file_info, 5..5, sink).await.unwrap();

        assert_eq!(written, 0);
        assert_eq!(read(&destination).unwrap(), &[] as &[u8]);
    });
    outcome.unwrap();
}
1157
/// Requesting the inverted range `10..5` through `download_to_writer` must
/// return an error.
#[test]
fn test_download_to_writer_inverted_range_errors() {
    let pool = get_threadpool();
    let outcome = pool.clone().external_run_async_task(async {
        let payload = b"0123456789abcdef";

        let scratch = tempdir().unwrap();
        let cas_dir = scratch.path().join("cas");
        let file_info = upload_data(&cas_dir, payload).await;

        let config = TranslatorConfig::local_config(&cas_dir).unwrap();
        let session = FileDownloadSession::new(config.into(), None).await.unwrap();

        let destination = scratch.path().join("inverted_range.bin");
        let sink = std::fs::File::create(&destination).unwrap();
        let attempt = session.download_to_writer(&file_info, 10..5, sink).await;

        assert!(attempt.is_err());
    });
    outcome.unwrap();
}
1181
/// A range starting at offset 100000 on a 16-byte file must make
/// `download_to_writer` return an error.
///
/// NOTE(review): compiled only when debug assertions are disabled — presumably a
/// debug assertion fires on this path before the error is returned; confirm.
#[cfg(not(debug_assertions))]
#[test]
fn test_download_to_writer_range_start_beyond_file_size_errors() {
    let pool = get_threadpool();
    let outcome = pool.clone().external_run_async_task(async {
        let payload = b"0123456789abcdef";

        let scratch = tempdir().unwrap();
        let cas_dir = scratch.path().join("cas");
        let file_info = upload_data(&cas_dir, payload).await;

        let config = TranslatorConfig::local_config(&cas_dir).unwrap();
        let session = FileDownloadSession::new(config.into(), None).await.unwrap();

        let destination = scratch.path().join("beyond_size.bin");
        let sink = std::fs::File::create(&destination).unwrap();
        let attempt = session.download_to_writer(&file_info, 100000.., sink).await;

        assert!(attempt.is_err());
    });
    outcome.unwrap();
}
1206}