Skip to main content

xet_data/processing/
file_download_session.rs

1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::io::Write;
4use std::ops::{Bound, Range, RangeBounds};
5use std::path::{Path, PathBuf};
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::{Arc, Mutex};
8
9use tokio::task::JoinHandle;
10use tracing::instrument;
11use xet_client::cas_client::Client;
12use xet_client::cas_types::FileRange;
13use xet_client::chunk_cache::ChunkCache;
14use xet_runtime::core::{XetRuntime, xet_config};
15
16use super::XetFileInfo;
17use super::configurations::TranslatorConfig;
18use super::remote_client_interface::create_remote_client;
19use crate::error::{DataError, Result};
20use crate::file_reconstruction::{DownloadStream, FileReconstructor, UnorderedDownloadStream};
21use crate::progress_tracking::{GroupProgress, ItemProgressUpdater, UniqueID};
22
/// Manages the downloading of files from CAS storage.
///
/// This struct parallels `FileUploadSession` for the download path. It holds the
/// CAS client and a shared progress group for all downloads in the session.
pub struct FileDownloadSession {
    /// CAS client handed to every `FileReconstructor` built by this session.
    client: Arc<dyn Client>,
    /// Optional chunk cache; when present it is attached to each reconstructor.
    chunk_cache: Option<Arc<dyn ChunkCache>>,
    /// Progress group in which every download started by this session registers an item.
    progress: Arc<GroupProgress>,
    /// Abort callbacks of the streams created by this session, keyed by download ID.
    /// Invoked by `abort_active_streams`; entries are removed through
    /// `unregister_stream_abort_callback`.
    active_stream_abort_callbacks: Mutex<HashMap<UniqueID, Box<dyn Fn() + Send + Sync>>>,
    /// Set by `finalize`; once set, the public download entry points return an error.
    finalized: AtomicBool,
}
34
35impl FileDownloadSession {
36    pub async fn new(config: Arc<TranslatorConfig>, chunk_cache: Option<Arc<dyn ChunkCache>>) -> Result<Arc<Self>> {
37        let session_id = config
38            .session
39            .session_id
40            .as_ref()
41            .map(Cow::Borrowed)
42            .unwrap_or_else(|| Cow::Owned(UniqueID::new().to_string()));
43
44        let client = create_remote_client(&config, &session_id, false).await?;
45        let progress = GroupProgress::with_speed_config(
46            xet_config().data.progress_update_speed_sampling_window,
47            xet_config().data.progress_update_speed_min_observations,
48        );
49
50        Ok(Arc::new(Self {
51            client,
52            chunk_cache,
53            progress,
54            active_stream_abort_callbacks: Mutex::new(HashMap::new()),
55            finalized: AtomicBool::new(false),
56        }))
57    }
58
59    /// Construct a download session from an existing CAS client.
60    ///
61    /// This path uses default progress speed settings. Use [`Self::new`] when the
62    /// session should inherit the configured speed parameters from `xet_config`.
63    pub fn from_client(client: Arc<dyn Client>, chunk_cache: Option<Arc<dyn ChunkCache>>) -> Arc<Self> {
64        let progress = GroupProgress::new();
65        Arc::new(Self {
66            client,
67            chunk_cache,
68            progress,
69            active_stream_abort_callbacks: Mutex::new(HashMap::new()),
70            finalized: AtomicBool::new(false),
71        })
72    }
73
    /// Returns the report produced by this session's shared progress group.
    pub fn report(&self) -> crate::progress_tracking::GroupProgressReport {
        self.progress.report()
    }
77
    /// Returns the progress group's report for the item registered under `id`,
    /// or `None` when the progress group has no report for it.
    pub fn item_report(&self, id: UniqueID) -> Option<crate::progress_tracking::ItemProgressReport> {
        self.progress.item_report(id)
    }
81
    /// Returns the per-item reports of the progress group, keyed by item ID.
    pub fn item_reports(&self) -> HashMap<UniqueID, crate::progress_tracking::ItemProgressReport> {
        self.progress.item_reports()
    }
85
86    fn register_stream_abort_callback(&self, id: UniqueID, callback: Box<dyn Fn() + Send + Sync>) {
87        self.active_stream_abort_callbacks.lock().unwrap().insert(id, callback);
88    }
89
90    pub fn unregister_stream_abort_callback(&self, id: UniqueID) {
91        self.active_stream_abort_callbacks.lock().unwrap().remove(&id);
92    }
93
94    pub fn abort_active_streams(&self) {
95        let callbacks = self.active_stream_abort_callbacks.lock().unwrap();
96        for callback in callbacks.values() {
97            callback();
98        }
99    }
100
101    /// Spawns a download task that writes `file_info` to `write_path`.
102    ///
103    /// Acquires a permit from the global download semaphore before starting.
104    /// Returns the tracking ID and the join handle for the spawned task.
105    pub async fn download_file_background(
106        self: &Arc<Self>,
107        file_info: XetFileInfo,
108        write_path: PathBuf,
109    ) -> Result<(UniqueID, JoinHandle<Result<u64>>)> {
110        self.check_not_finalized()?;
111        let id = UniqueID::new();
112        let session = self.clone();
113        let rt = XetRuntime::current();
114        let semaphore = rt.common().file_download_semaphore.clone();
115        let handle = rt.spawn(async move {
116            let _permit = semaphore.acquire().await?;
117            session.download_file_with_id(&file_info, &write_path, id).await
118        });
119        Ok((id, handle))
120    }
121
122    /// Downloads a complete file to the given path.
123    #[instrument(skip_all, name = "FileDownloadSession::download_file", fields(hash = file_info.hash()))]
124    pub async fn download_file(&self, file_info: &XetFileInfo, write_path: &Path) -> Result<(UniqueID, u64)> {
125        self.check_not_finalized()?;
126        let id = UniqueID::new();
127        let n_bytes = self.download_file_with_id(file_info, write_path, id).await?;
128        Ok((id, n_bytes))
129    }
130
131    async fn download_file_with_id(&self, file_info: &XetFileInfo, write_path: &Path, id: UniqueID) -> Result<u64> {
132        let name = Arc::from(write_path.to_string_lossy().as_ref());
133        let progress_updater = self.progress.new_item(id, name);
134        let reconstructor = self.setup_reconstructor(file_info, None, Some(progress_updater))?;
135        let n_bytes = reconstructor.reconstruct_to_file(write_path, None, true).await?;
136        // Caller is responsible for cleaning up the file on error (consistent
137        // with other error paths); see download_group.rs error handling.
138        if let Some(expected_size) = file_info.file_size()
139            && n_bytes != expected_size
140        {
141            return Err(DataError::SizeMismatch {
142                expected: expected_size,
143                actual: n_bytes,
144            });
145        }
146        Ok(n_bytes)
147    }
148
149    /// Downloads a byte range of a file and writes it to the provided writer.
150    ///
151    /// The provided `source_range` is interpreted against the original file; output
152    /// starts at the writer's current position. Accepts any `RangeBounds<u64>`:
153    /// `4..12`, `5..`, `..100`, or `..` (full file).
154    ///
155    /// This path does not acquire the session-level file download semaphore.
156    #[instrument(skip_all, name = "FileDownloadSession::download_to_writer",
157        fields(hash = file_info.hash(), range_start = tracing::field::Empty, range_end = tracing::field::Empty))]
158    pub async fn download_to_writer<W: Write + Send + 'static>(
159        &self,
160        file_info: &XetFileInfo,
161        source_range: impl RangeBounds<u64>,
162        writer: W,
163    ) -> Result<(UniqueID, u64)> {
164        self.check_not_finalized()?;
165        let range = range_bounds_to_file_range(&source_range)?;
166        if let Some(ref r) = range {
167            let span = tracing::Span::current();
168            span.record("range_start", r.start);
169            span.record("range_end", r.end);
170        }
171        let id = UniqueID::new();
172        let name = Arc::from("");
173        let progress_updater = self.progress.new_item(id, name);
174        let reconstructor = self.setup_reconstructor(file_info, range, Some(progress_updater))?;
175        let n_bytes = reconstructor.reconstruct_to_writer(writer).await?;
176
177        let expected_size = match range {
178            Some(r) if r.end < u64::MAX => Some(r.end - r.start),
179            None => file_info.file_size(),
180            _ => None,
181        };
182        if let Some(expected) = expected_size
183            && n_bytes != expected
184        {
185            return Err(DataError::SizeMismatch {
186                expected,
187                actual: n_bytes,
188            });
189        }
190
191        Ok((id, n_bytes))
192    }
193
194    /// Creates a streaming download of a file, optionally restricted to a
195    /// byte range.
196    ///
197    /// Returns a [`DownloadStream`] that yields data chunks as the file is
198    /// reconstructed. Reconstruction starts lazily on first
199    /// [`DownloadStream::next`] / [`DownloadStream::blocking_next`] call
200    /// (or when `start()` is called explicitly).
201    ///
202    /// If `source_range` is `Some`, only the specified byte range of the
203    /// file is reconstructed.
204    ///
205    /// This path does not acquire the session-level file download semaphore.
206    #[instrument(skip_all, name = "FileDownloadSession::download_stream", fields(hash = file_info.hash()))]
207    pub async fn download_stream(
208        &self,
209        file_info: &XetFileInfo,
210        source_range: Option<Range<u64>>,
211    ) -> Result<(UniqueID, DownloadStream)> {
212        self.check_not_finalized()?;
213        let id = UniqueID::new();
214        let progress_updater = self.progress.new_item(id, "stream");
215        let range = source_range.map(|r| FileRange::new(r.start, r.end));
216        let reconstructor = self.setup_reconstructor(file_info, range, Some(progress_updater))?;
217        let stream = reconstructor.reconstruct_to_stream();
218        self.register_stream_abort_callback(id, stream.abort_callback());
219        Ok((id, stream))
220    }
221
222    /// Creates an unordered streaming download of a file, optionally
223    /// restricted to a byte range.
224    ///
225    /// Returns an [`UnorderedDownloadStream`] that yields `(offset, Bytes)`
226    /// chunks in whatever order they complete. The total expected size is
227    /// set from the range length (or `file_info.file_size()` when no range
228    /// is given).
229    ///
230    /// If `source_range` is `Some`, only the specified byte range of the
231    /// file is reconstructed.
232    ///
233    /// This path does not acquire the session-level file download semaphore.
234    #[instrument(skip_all, name = "FileDownloadSession::download_unordered_stream", fields(hash = file_info.hash()))]
235    pub async fn download_unordered_stream(
236        &self,
237        file_info: &XetFileInfo,
238        source_range: Option<Range<u64>>,
239    ) -> Result<(UniqueID, UnorderedDownloadStream)> {
240        self.check_not_finalized()?;
241        let id = UniqueID::new();
242        let progress_updater = self.progress.new_item(id, "unordered_stream");
243        let range = source_range.map(|r| FileRange::new(r.start, r.end));
244        let reconstructor = self.setup_reconstructor(file_info, range, Some(progress_updater))?;
245        let stream = reconstructor.reconstruct_to_unordered_stream();
246        self.register_stream_abort_callback(id, stream.abort_callback());
247        Ok((id, stream))
248    }
249
250    /// Creates a streaming download of a byte range of a file.
251    ///
252    /// Accepts any `RangeBounds<u64>`: `4..12`, `5..`, `..100`, or `..` (full file).
253    ///
254    /// This path does not acquire the session-level file download semaphore.
255    #[instrument(skip_all, name = "FileDownloadSession::download_stream_range", fields(hash = file_info.hash()))]
256    pub async fn download_stream_range(
257        &self,
258        file_info: &XetFileInfo,
259        range: impl RangeBounds<u64>,
260    ) -> Result<(UniqueID, DownloadStream)> {
261        self.check_not_finalized()?;
262        let file_range = range_bounds_to_file_range(&range)?;
263        let id = UniqueID::new();
264        let progress_updater = self.progress.new_item(id, "stream");
265        let reconstructor = self.setup_reconstructor(file_info, file_range, Some(progress_updater))?;
266        let stream = reconstructor.reconstruct_to_stream();
267        self.register_stream_abort_callback(id, stream.abort_callback());
268        Ok((id, stream))
269    }
270    fn check_not_finalized(&self) -> Result<()> {
271        if self.finalized.load(Ordering::Acquire) {
272            return Err(DataError::InvalidOperation("FileDownloadSession already finalized".to_string()));
273        }
274        Ok(())
275    }
276
277    /// Finalizes the session; in debug builds, asserts all items are complete.
278    pub async fn finalize(&self) -> Result<()> {
279        if self.finalized.swap(true, Ordering::AcqRel) {
280            return Err(DataError::InvalidOperation("FileDownloadSession already finalized".to_string()));
281        }
282        #[cfg(debug_assertions)]
283        self.progress.assert_complete();
284        Ok(())
285    }
286
287    fn setup_reconstructor(
288        &self,
289        file_info: &XetFileInfo,
290        range: Option<FileRange>,
291        progress_updater: Option<Arc<ItemProgressUpdater>>,
292    ) -> Result<FileReconstructor> {
293        let file_id = file_info.merkle_hash()?;
294
295        let mut reconstructor = FileReconstructor::new(&self.client, file_id);
296
297        match range {
298            Some(range) if range.end < u64::MAX => {
299                // Fully bounded range: we know the exact download size upfront.
300                let size = range.end - range.start;
301                if let Some(ref updater) = progress_updater {
302                    updater.update_item_size(size, true);
303                }
304                reconstructor = reconstructor.with_byte_range(range);
305            },
306            Some(range) => {
307                // Open-ended range (end == u64::MAX): pass the range to set the
308                // start position, but let ReconstructionTermManager discover
309                // the actual end and finalize progress incrementally.
310                reconstructor = reconstructor.with_byte_range(range);
311            },
312            None if file_info.file_size().is_some() => {
313                // Full file with caller-provided size. Set progress upfront so
314                // UI consumers get percentage-based progress. SizeMismatch is
315                // validated after reconstruction in download_file_with_id.
316                if let Some(ref updater) = progress_updater {
317                    updater.update_item_size(file_info.file_size().unwrap(), true);
318                }
319            },
320            None => {
321                // Full file with unknown size: the reconstructor uses
322                // FileRange::full() internally and ReconstructionTermManager
323                // discovers the size incrementally.
324            },
325        }
326
327        if let Some(updater) = progress_updater {
328            reconstructor = reconstructor.with_progress_updater(updater);
329        }
330
331        if let Some(ref cache) = self.chunk_cache {
332            reconstructor = reconstructor.with_chunk_cache(cache.clone());
333        }
334
335        Ok(reconstructor)
336    }
337}
338
339/// Converts any `RangeBounds<u64>` into an `Option<FileRange>`.
340///
341/// Returns `None` for the unbounded range `..` (equivalent to full file),
342/// and `Some(FileRange)` otherwise. Open-ended ranges use `u64::MAX` as
343/// the end sentinel (matching `FileRange::full()`).
344///
345/// Returns an error for inverted ranges where `start > end`.
346fn range_bounds_to_file_range(range: &impl RangeBounds<u64>) -> Result<Option<FileRange>> {
347    let start = match range.start_bound() {
348        Bound::Included(&s) => s,
349        Bound::Excluded(&s) => s.saturating_add(1),
350        Bound::Unbounded => 0,
351    };
352    let end = match range.end_bound() {
353        Bound::Included(&e) => e.saturating_add(1),
354        Bound::Excluded(&e) => e,
355        Bound::Unbounded => u64::MAX,
356    };
357    if start > end {
358        return Err(DataError::InvalidOperation(format!("Invalid range: start ({start}) > end ({end})")));
359    }
360    if start == 0 && end == u64::MAX {
361        Ok(None)
362    } else {
363        Ok(Some(FileRange::new(start, end)))
364    }
365}
366
367#[cfg(test)]
368mod tests {
369    use std::fs::{read, write};
370    use std::io::{Seek, SeekFrom};
371    use std::sync::{Arc, OnceLock};
372
373    use tempfile::tempdir;
374    use xet_runtime::core::XetRuntime;
375
376    use super::*;
377    use crate::processing::configurations::TranslatorConfig;
378    use crate::processing::file_cleaner::Sha256Policy;
379    use crate::processing::{FileUploadSession, XetFileInfo};
380
381    fn get_threadpool() -> Arc<XetRuntime> {
382        static THREADPOOL: OnceLock<Arc<XetRuntime>> = OnceLock::new();
383        THREADPOOL
384            .get_or_init(|| XetRuntime::new().expect("Error starting multithreaded runtime."))
385            .clone()
386    }
387
388    async fn upload_data(cas_path: &Path, data: &[u8]) -> XetFileInfo {
389        let upload_session = FileUploadSession::new(TranslatorConfig::local_config(cas_path).unwrap().into())
390            .await
391            .unwrap();
392
393        let (_id, mut cleaner) = upload_session
394            .start_clean(Some("test".into()), Some(data.len() as u64), Sha256Policy::Compute)
395            .unwrap();
396        cleaner.add_data(data).await.unwrap();
397        let (xfi, _metrics) = cleaner.finish().await.unwrap();
398        upload_session.finalize().await.unwrap();
399        xfi
400    }
401
402    #[test]
403    fn test_download_file() {
404        let runtime = get_threadpool();
405        runtime
406            .clone()
407            .bridge_sync(async {
408                let temp = tempdir().unwrap();
409                let cas_path = temp.path().join("cas");
410                let original_data = b"Hello, download session!";
411
412                let xfi = upload_data(&cas_path, original_data).await;
413
414                let config = TranslatorConfig::local_config(&cas_path).unwrap();
415                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
416
417                let out_path = temp.path().join("output.txt");
418                let (_id, n_bytes) = session.download_file(&xfi, &out_path).await.unwrap();
419
420                assert_eq!(n_bytes, original_data.len() as u64);
421                assert_eq!(read(&out_path).unwrap(), original_data);
422            })
423            .unwrap();
424    }
425
426    #[test]
427    fn test_download_file_creates_parent_dirs() {
428        let runtime = get_threadpool();
429        runtime
430            .clone()
431            .bridge_sync(async {
432                let temp = tempdir().unwrap();
433                let cas_path = temp.path().join("cas");
434                let original_data = b"nested directory test";
435
436                let xfi = upload_data(&cas_path, original_data).await;
437
438                let config = TranslatorConfig::local_config(&cas_path).unwrap();
439                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
440
441                let out_path = temp.path().join("deep").join("nested").join("dir").join("output.txt");
442                assert!(!out_path.parent().unwrap().exists());
443
444                session.download_file(&xfi, &out_path).await.unwrap();
445
446                assert_eq!(read(&out_path).unwrap(), original_data);
447            })
448            .unwrap();
449    }
450
451    #[test]
452    fn test_download_to_writer() {
453        let runtime = get_threadpool();
454        runtime
455            .clone()
456            .bridge_sync(async {
457                let temp = tempdir().unwrap();
458                let cas_path = temp.path().join("cas");
459                let original_data = b"0123456789abcdef";
460
461                let xfi = upload_data(&cas_path, original_data).await;
462
463                let config = TranslatorConfig::local_config(&cas_path).unwrap();
464                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
465
466                let out_path = temp.path().join("partial_writer.txt");
467                write(&out_path, vec![0u8; original_data.len()]).unwrap();
468
469                let mut file = std::fs::OpenOptions::new().write(true).open(&out_path).unwrap();
470                file.seek(SeekFrom::Start(4)).unwrap();
471
472                let (_id, n_bytes) = session.download_to_writer(&xfi, 4..12, file).await.unwrap();
473
474                assert_eq!(n_bytes, 8);
475                let result = read(&out_path).unwrap();
476                assert_eq!(&result[4..12], &original_data[4..12]);
477            })
478            .unwrap();
479    }
480
481    #[test]
482    fn test_download_to_writer_parallel_partitioned_file() {
483        let runtime = get_threadpool();
484        runtime
485            .clone()
486            .bridge_sync(async {
487                let temp = tempdir().unwrap();
488                let cas_path = temp.path().join("cas");
489                let original_data = b"abcdefghijklmnopqrstuvwxyz0123456789";
490
491                let xfi = upload_data(&cas_path, original_data).await;
492                let config = TranslatorConfig::local_config(&cas_path).unwrap();
493                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
494
495                let out_path = temp.path().join("partitioned.txt");
496                write(&out_path, vec![0u8; original_data.len()]).unwrap();
497
498                let n_parts = 5u64;
499                let total = original_data.len() as u64;
500                let mut tasks = Vec::new();
501
502                for idx in 0..n_parts {
503                    let start = (idx * total) / n_parts;
504                    let end = ((idx + 1) * total) / n_parts;
505                    if start == end {
506                        continue;
507                    }
508
509                    let session = session.clone();
510                    let xfi = xfi.clone();
511                    let out_path = out_path.clone();
512                    tasks.push(tokio::spawn(async move {
513                        let mut writer = std::fs::OpenOptions::new().write(true).open(out_path).unwrap();
514                        writer.seek(SeekFrom::Start(start)).unwrap();
515                        session.download_to_writer(&xfi, start..end, writer).await
516                    }));
517                }
518
519                for task in tasks {
520                    task.await.unwrap().unwrap();
521                }
522
523                let result = read(&out_path).unwrap();
524                assert_eq!(result, original_data);
525            })
526            .unwrap();
527    }
528
529    #[test]
530    fn test_download_multiple_files_concurrent() {
531        let runtime = get_threadpool();
532        runtime
533            .clone()
534            .bridge_sync(async {
535                let temp = tempdir().unwrap();
536                let cas_path = temp.path().join("cas");
537
538                let data_a = b"File A content for concurrent test";
539                let data_b = b"File B content for concurrent test - different";
540
541                let xfi_a = upload_data(&cas_path, data_a).await;
542                let xfi_b = upload_data(&cas_path, data_b).await;
543
544                let config = TranslatorConfig::local_config(&cas_path).unwrap();
545                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
546
547                let out_a = temp.path().join("out_a.txt");
548                let out_b = temp.path().join("out_b.txt");
549
550                let session_a = session.clone();
551                let xfi_a_clone = xfi_a.clone();
552                let out_a_clone = out_a.clone();
553                let task_a = tokio::spawn(async move { session_a.download_file(&xfi_a_clone, &out_a_clone).await });
554
555                let session_b = session.clone();
556                let xfi_b_clone = xfi_b.clone();
557                let out_b_clone = out_b.clone();
558                let task_b = tokio::spawn(async move { session_b.download_file(&xfi_b_clone, &out_b_clone).await });
559
560                task_a.await.unwrap().unwrap();
561                task_b.await.unwrap().unwrap();
562
563                assert_eq!(read(&out_a).unwrap(), data_a);
564                assert_eq!(read(&out_b).unwrap(), data_b);
565            })
566            .unwrap();
567    }
568
569    // ==================== Download Stream Tests ====================
570
571    #[test]
572    fn test_download_stream_async() {
573        let runtime = get_threadpool();
574        runtime
575            .clone()
576            .bridge_sync(async {
577                let temp = tempdir().unwrap();
578                let cas_path = temp.path().join("cas");
579                let original_data = b"Hello, streaming download!";
580
581                let xfi = upload_data(&cas_path, original_data).await;
582
583                let config = TranslatorConfig::local_config(&cas_path).unwrap();
584                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
585
586                let (_id, mut stream) = session.download_stream(&xfi, None).await.unwrap();
587
588                let mut collected = Vec::new();
589                while let Some(chunk) = stream.next().await.unwrap() {
590                    collected.extend_from_slice(&chunk);
591                }
592
593                assert_eq!(collected, original_data);
594            })
595            .unwrap();
596    }
597
598    #[test]
599    fn test_download_stream_blocking() {
600        let runtime = get_threadpool();
601        runtime
602            .clone()
603            .bridge_sync(async {
604                let temp = tempdir().unwrap();
605                let cas_path = temp.path().join("cas");
606                let original_data = b"Blocking stream test data";
607
608                let xfi = upload_data(&cas_path, original_data).await;
609
610                let config = TranslatorConfig::local_config(&cas_path).unwrap();
611                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
612
613                let (_id, stream) = session.download_stream(&xfi, None).await.unwrap();
614
615                let collected = tokio::task::spawn_blocking(move || {
616                    let mut stream = stream;
617                    let mut buf = Vec::new();
618                    while let Some(chunk) = stream.blocking_next().unwrap() {
619                        buf.extend_from_slice(&chunk);
620                    }
621                    buf
622                })
623                .await
624                .unwrap();
625
626                assert_eq!(collected, original_data);
627            })
628            .unwrap();
629    }
630
631    #[test]
632    fn test_download_stream_returns_none_after_finish() {
633        let runtime = get_threadpool();
634        runtime
635            .clone()
636            .bridge_sync(async {
637                let temp = tempdir().unwrap();
638                let cas_path = temp.path().join("cas");
639                let original_data = b"Extra none calls";
640
641                let xfi = upload_data(&cas_path, original_data).await;
642
643                let config = TranslatorConfig::local_config(&cas_path).unwrap();
644                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
645
646                let (_id, mut stream) = session.download_stream(&xfi, None).await.unwrap();
647
648                while stream.next().await.unwrap().is_some() {}
649
650                // Subsequent calls should return Ok(None)
651                assert!(stream.next().await.unwrap().is_none());
652                assert!(stream.next().await.unwrap().is_none());
653            })
654            .unwrap();
655    }
656
657    #[test]
658    fn test_download_stream_multiple_concurrent() {
659        let runtime = get_threadpool();
660        runtime
661            .clone()
662            .bridge_sync(async {
663                let temp = tempdir().unwrap();
664                let cas_path = temp.path().join("cas");
665
666                let data_a = b"Stream A for concurrent download";
667                let data_b = b"Stream B for concurrent download - different";
668
669                let xfi_a = upload_data(&cas_path, data_a).await;
670                let xfi_b = upload_data(&cas_path, data_b).await;
671
672                let config = TranslatorConfig::local_config(&cas_path).unwrap();
673                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
674
675                let (_id_a, mut stream_a) = session.download_stream(&xfi_a, None).await.unwrap();
676                let (_id_b, mut stream_b) = session.download_stream(&xfi_b, None).await.unwrap();
677
678                let task_a = tokio::spawn(async move {
679                    let mut buf = Vec::new();
680                    while let Some(chunk) = stream_a.next().await.unwrap() {
681                        buf.extend_from_slice(&chunk);
682                    }
683                    buf
684                });
685
686                let task_b = tokio::spawn(async move {
687                    let mut buf = Vec::new();
688                    while let Some(chunk) = stream_b.next().await.unwrap() {
689                        buf.extend_from_slice(&chunk);
690                    }
691                    buf
692                });
693
694                let result_a = task_a.await.unwrap();
695                let result_b = task_b.await.unwrap();
696
697                assert_eq!(result_a, data_a);
698                assert_eq!(result_b, data_b);
699            })
700            .unwrap();
701    }
702
    /// Creates a stream, drops it without ever reading from it, and then checks
    /// that a regular `download_file` on the same session still succeeds and
    /// produces the original data.
    #[test]
    fn test_drop_stream_without_reading() {
        let runtime = get_threadpool();
        runtime
            .clone()
            .bridge_sync(async {
                let temp = tempdir().unwrap();
                let cas_path = temp.path().join("cas");
                let original_data = b"Drop-without-reading cleanup test";

                let xfi = upload_data(&cas_path, original_data).await;

                let config = TranslatorConfig::local_config(&cas_path).unwrap();
                let session = FileDownloadSession::new(config.into(), None).await.unwrap();

                // Drop the stream before any `next()` call, then yield once to the
                // scheduler before continuing.
                let (_id, stream) = session.download_stream(&xfi, None).await.unwrap();
                drop(stream);
                tokio::task::yield_now().await;

                // The session must remain usable after the drop.
                let out_path = temp.path().join("after_drop.txt");
                session.download_file(&xfi, &out_path).await.unwrap();
                assert_eq!(read(&out_path).unwrap(), original_data);
            })
            .unwrap();
    }
728
    /// Repeatedly creates and drops streams (some after one read, some without
    /// reading) and then checks that `download_file` on the same session still
    /// produces the original data.
    #[test]
    fn test_drop_stream_multiple_cycles_then_download() {
        let runtime = get_threadpool();
        runtime
            .clone()
            .bridge_sync(async {
                let temp = tempdir().unwrap();
                let cas_path = temp.path().join("cas");
                let original_data = b"Multi-cycle drop cleanup test";

                let xfi = upload_data(&cas_path, original_data).await;

                let config = TranslatorConfig::local_config(&cas_path).unwrap();
                let session = FileDownloadSession::new(config.into(), None).await.unwrap();

                for i in 0..5u32 {
                    let (_id, mut stream) = session.download_stream(&xfi, None).await.unwrap();
                    // Iterations 0 and 3 read once before dropping; the result of
                    // that read is deliberately ignored. The others drop unread.
                    if i % 3 == 0 {
                        let _ = stream.next().await;
                    }
                    drop(stream);
                    tokio::task::yield_now().await;
                }

                // The session must remain usable after all the drops.
                let out_path = temp.path().join("after_cycles.txt");
                session.download_file(&xfi, &out_path).await.unwrap();
                assert_eq!(read(&out_path).unwrap(), original_data);
            })
            .unwrap();
    }
759
760    #[test]
761    fn test_drop_stream_blocking_mid_read_then_download() {
762        let runtime = get_threadpool();
763        runtime
764            .clone()
765            .bridge_sync(async {
766                let temp = tempdir().unwrap();
767                let cas_path = temp.path().join("cas");
768                let original_data = b"Blocking drop cleanup test data";
769
770                let xfi = upload_data(&cas_path, original_data).await;
771
772                let config = TranslatorConfig::local_config(&cas_path).unwrap();
773                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
774
775                let (_id, stream) = session.download_stream(&xfi, None).await.unwrap();
776
777                tokio::task::spawn_blocking(move || {
778                    let mut stream = stream;
779                    let _chunk = stream.blocking_next().unwrap();
780                })
781                .await
782                .unwrap();
783
784                tokio::task::yield_now().await;
785
786                let out_path = temp.path().join("after_blocking_drop.txt");
787                session.download_file(&xfi, &out_path).await.unwrap();
788                assert_eq!(read(&out_path).unwrap(), original_data);
789            })
790            .unwrap();
791    }
792
793    #[test]
794    fn test_cancel_stream_before_start_returns_none() {
795        let runtime = get_threadpool();
796        runtime
797            .clone()
798            .bridge_sync(async {
799                let temp = tempdir().unwrap();
800                let cas_path = temp.path().join("cas");
801                let original_data = b"Cancel-before-start stream test";
802
803                let xfi = upload_data(&cas_path, original_data).await;
804
805                let config = TranslatorConfig::local_config(&cas_path).unwrap();
806                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
807
808                let (_id, mut stream) = session.download_stream(&xfi, None).await.unwrap();
809                stream.cancel();
810                assert!(stream.next().await.unwrap().is_none());
811                assert!(stream.next().await.unwrap().is_none());
812            })
813            .unwrap();
814    }
815
816    #[test]
817    fn test_cancel_stream_after_first_chunk_returns_none() {
818        let runtime = get_threadpool();
819        runtime
820            .clone()
821            .bridge_sync(async {
822                let temp = tempdir().unwrap();
823                let cas_path = temp.path().join("cas");
824                let original_data = b"Cancel-after-first-chunk stream test data";
825
826                let xfi = upload_data(&cas_path, original_data).await;
827
828                let config = TranslatorConfig::local_config(&cas_path).unwrap();
829                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
830
831                let (_id, mut stream) = session.download_stream(&xfi, None).await.unwrap();
832                let _ = stream.next().await.unwrap();
833                stream.cancel();
834                assert!(stream.next().await.unwrap().is_none());
835                assert!(stream.next().await.unwrap().is_none());
836
837                let out_path = temp.path().join("after_cancel.txt");
838                session.download_file(&xfi, &out_path).await.unwrap();
839                assert_eq!(read(&out_path).unwrap(), original_data);
840            })
841            .unwrap();
842    }
843
844    // ==================== Range Download Tests ====================
845
846    #[test]
847    fn test_download_to_writer_range_from() {
848        let runtime = get_threadpool();
849        runtime
850            .clone()
851            .external_run_async_task(async {
852                let temp = tempdir().unwrap();
853                let cas_path = temp.path().join("cas");
854                let original_data = b"0123456789abcdef";
855
856                let xfi = upload_data(&cas_path, original_data).await;
857
858                let config = TranslatorConfig::local_config(&cas_path).unwrap();
859                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
860
861                let out_path = temp.path().join("range_from.bin");
862                let file = std::fs::File::create(&out_path).unwrap();
863                let (_id, n_bytes) = session.download_to_writer(&xfi, 4.., file).await.unwrap();
864
865                assert_eq!(n_bytes, 12);
866                assert_eq!(read(&out_path).unwrap(), &original_data[4..]);
867            })
868            .unwrap();
869    }
870
871    #[test]
872    fn test_download_to_writer_range_to() {
873        let runtime = get_threadpool();
874        runtime
875            .clone()
876            .external_run_async_task(async {
877                let temp = tempdir().unwrap();
878                let cas_path = temp.path().join("cas");
879                let original_data = b"0123456789abcdef";
880
881                let xfi = upload_data(&cas_path, original_data).await;
882
883                let config = TranslatorConfig::local_config(&cas_path).unwrap();
884                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
885
886                let out_path = temp.path().join("range_to.bin");
887                let file = std::fs::File::create(&out_path).unwrap();
888                let (_id, n_bytes) = session.download_to_writer(&xfi, ..8, file).await.unwrap();
889
890                assert_eq!(n_bytes, 8);
891                assert_eq!(read(&out_path).unwrap(), &original_data[..8]);
892            })
893            .unwrap();
894    }
895
896    #[test]
897    fn test_download_to_writer_full_range() {
898        let runtime = get_threadpool();
899        runtime
900            .clone()
901            .external_run_async_task(async {
902                let temp = tempdir().unwrap();
903                let cas_path = temp.path().join("cas");
904                let original_data = b"0123456789abcdef";
905
906                let xfi = upload_data(&cas_path, original_data).await;
907
908                let config = TranslatorConfig::local_config(&cas_path).unwrap();
909                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
910
911                let out_path = temp.path().join("full_range.bin");
912                let file = std::fs::File::create(&out_path).unwrap();
913                let (_id, n_bytes) = session.download_to_writer(&xfi, .., file).await.unwrap();
914
915                assert_eq!(n_bytes, original_data.len() as u64);
916                assert_eq!(read(&out_path).unwrap(), original_data);
917            })
918            .unwrap();
919    }
920
921    #[test]
922    fn test_download_to_writer_range_inclusive() {
923        let runtime = get_threadpool();
924        runtime
925            .clone()
926            .external_run_async_task(async {
927                let temp = tempdir().unwrap();
928                let cas_path = temp.path().join("cas");
929                let original_data = b"0123456789abcdef";
930
931                let xfi = upload_data(&cas_path, original_data).await;
932
933                let config = TranslatorConfig::local_config(&cas_path).unwrap();
934                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
935
936                let out_path = temp.path().join("range_incl.bin");
937                let file = std::fs::File::create(&out_path).unwrap();
938                let (_id, n_bytes) = session.download_to_writer(&xfi, 2..=5, file).await.unwrap();
939
940                assert_eq!(n_bytes, 4);
941                assert_eq!(read(&out_path).unwrap(), &original_data[2..=5]);
942            })
943            .unwrap();
944    }
945
946    // ==================== Range Stream Tests ====================
947
948    #[test]
949    fn test_download_stream_range_bounded() {
950        let runtime = get_threadpool();
951        runtime
952            .clone()
953            .external_run_async_task(async {
954                let temp = tempdir().unwrap();
955                let cas_path = temp.path().join("cas");
956                let original_data = b"0123456789abcdef";
957
958                let xfi = upload_data(&cas_path, original_data).await;
959
960                let config = TranslatorConfig::local_config(&cas_path).unwrap();
961                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
962
963                let (_id, mut stream) = session.download_stream_range(&xfi, 4..12).await.unwrap();
964
965                let mut collected = Vec::new();
966                while let Some(chunk) = stream.next().await.unwrap() {
967                    collected.extend_from_slice(&chunk);
968                }
969
970                assert_eq!(collected, &original_data[4..12]);
971            })
972            .unwrap();
973    }
974
975    #[test]
976    fn test_download_stream_range_from() {
977        let runtime = get_threadpool();
978        runtime
979            .clone()
980            .external_run_async_task(async {
981                let temp = tempdir().unwrap();
982                let cas_path = temp.path().join("cas");
983                let original_data = b"0123456789abcdef";
984
985                let xfi = upload_data(&cas_path, original_data).await;
986
987                let config = TranslatorConfig::local_config(&cas_path).unwrap();
988                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
989
990                let (_id, mut stream) = session.download_stream_range(&xfi, 10..).await.unwrap();
991
992                let mut collected = Vec::new();
993                while let Some(chunk) = stream.next().await.unwrap() {
994                    collected.extend_from_slice(&chunk);
995                }
996
997                assert_eq!(collected, &original_data[10..]);
998            })
999            .unwrap();
1000    }
1001
1002    #[test]
1003    fn test_download_stream_range_to() {
1004        let runtime = get_threadpool();
1005        runtime
1006            .clone()
1007            .external_run_async_task(async {
1008                let temp = tempdir().unwrap();
1009                let cas_path = temp.path().join("cas");
1010                let original_data = b"0123456789abcdef";
1011
1012                let xfi = upload_data(&cas_path, original_data).await;
1013
1014                let config = TranslatorConfig::local_config(&cas_path).unwrap();
1015                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
1016
1017                let (_id, mut stream) = session.download_stream_range(&xfi, ..6).await.unwrap();
1018
1019                let mut collected = Vec::new();
1020                while let Some(chunk) = stream.next().await.unwrap() {
1021                    collected.extend_from_slice(&chunk);
1022                }
1023
1024                assert_eq!(collected, &original_data[..6]);
1025            })
1026            .unwrap();
1027    }
1028
1029    // ==================== Download with unknown file size ====================
1030
1031    #[test]
1032    fn test_download_file_unknown_size() {
1033        let runtime = get_threadpool();
1034        runtime
1035            .clone()
1036            .external_run_async_task(async {
1037                let temp = tempdir().unwrap();
1038                let cas_path = temp.path().join("cas");
1039                let original_data = b"File with unknown size test";
1040
1041                let xfi = upload_data(&cas_path, original_data).await;
1042                let xfi_no_size = XetFileInfo::new_hash_only(xfi.hash().to_string());
1043
1044                let config = TranslatorConfig::local_config(&cas_path).unwrap();
1045                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
1046
1047                let out_path = temp.path().join("output_unknown.txt");
1048                let (_id, n_bytes) = session.download_file(&xfi_no_size, &out_path).await.unwrap();
1049
1050                assert_eq!(n_bytes, original_data.len() as u64);
1051                assert_eq!(read(&out_path).unwrap(), original_data);
1052            })
1053            .unwrap();
1054    }
1055
1056    #[test]
1057    fn test_download_stream_unknown_size() {
1058        let runtime = get_threadpool();
1059        runtime
1060            .clone()
1061            .external_run_async_task(async {
1062                let temp = tempdir().unwrap();
1063                let cas_path = temp.path().join("cas");
1064                let original_data = b"Stream with unknown size test";
1065
1066                let xfi = upload_data(&cas_path, original_data).await;
1067                let xfi_no_size = XetFileInfo::new_hash_only(xfi.hash().to_string());
1068
1069                let config = TranslatorConfig::local_config(&cas_path).unwrap();
1070                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
1071
1072                let (_id, mut stream) = session.download_stream(&xfi_no_size, None).await.unwrap();
1073
1074                let mut collected = Vec::new();
1075                while let Some(chunk) = stream.next().await.unwrap() {
1076                    collected.extend_from_slice(&chunk);
1077                }
1078
1079                assert_eq!(collected, original_data);
1080            })
1081            .unwrap();
1082    }
1083
1084    #[cfg(not(debug_assertions))]
1085    #[test]
1086    fn test_download_file_size_mismatch_error() {
1087        let runtime = get_threadpool();
1088        runtime
1089            .clone()
1090            .external_run_async_task(async {
1091                let temp = tempdir().unwrap();
1092                let cas_path = temp.path().join("cas");
1093                let original_data = b"Size mismatch test data";
1094
1095                let xfi = upload_data(&cas_path, original_data).await;
1096                let wrong_size_xfi = XetFileInfo::new(xfi.hash().to_string(), 999);
1097
1098                let config = TranslatorConfig::local_config(&cas_path).unwrap();
1099                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
1100
1101                let out_path = temp.path().join("output_mismatch.txt");
1102                let err = session.download_file(&wrong_size_xfi, &out_path).await.unwrap_err();
1103
1104                assert!(
1105                    matches!(err, DataError::SizeMismatch { expected: 999, .. }),
1106                    "Expected SizeMismatch error, got: {err:?}"
1107                );
1108            })
1109            .unwrap();
1110    }
1111
1112    // ==================== range_bounds_to_file_range unit tests ====================
1113
1114    #[test]
1115    fn test_range_bounds_conversion() {
1116        use super::range_bounds_to_file_range;
1117
1118        assert_eq!(range_bounds_to_file_range(&(..)).unwrap(), None);
1119        assert_eq!(range_bounds_to_file_range(&(0..100)).unwrap(), Some(FileRange::new(0, 100)));
1120        assert_eq!(range_bounds_to_file_range(&(5..)).unwrap(), Some(FileRange::new(5, u64::MAX)));
1121        assert_eq!(range_bounds_to_file_range(&(..50)).unwrap(), Some(FileRange::new(0, 50)));
1122        assert_eq!(range_bounds_to_file_range(&(10..=19)).unwrap(), Some(FileRange::new(10, 20)));
1123    }
1124
1125    #[test]
1126    fn test_range_bounds_inverted_range_errors() {
1127        use super::range_bounds_to_file_range;
1128
1129        let result = range_bounds_to_file_range(&(10..5));
1130        assert!(result.is_err());
1131    }
1132
1133    #[test]
1134    fn test_download_to_writer_empty_range() {
1135        let runtime = get_threadpool();
1136        runtime
1137            .clone()
1138            .external_run_async_task(async {
1139                let temp = tempdir().unwrap();
1140                let cas_path = temp.path().join("cas");
1141                let original_data = b"0123456789abcdef";
1142
1143                let xfi = upload_data(&cas_path, original_data).await;
1144
1145                let config = TranslatorConfig::local_config(&cas_path).unwrap();
1146                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
1147
1148                let out_path = temp.path().join("empty_range.bin");
1149                let file = std::fs::File::create(&out_path).unwrap();
1150                let (_id, n_bytes) = session.download_to_writer(&xfi, 5..5, file).await.unwrap();
1151
1152                assert_eq!(n_bytes, 0);
1153                assert_eq!(read(&out_path).unwrap(), &[] as &[u8]);
1154            })
1155            .unwrap();
1156    }
1157
1158    #[test]
1159    fn test_download_to_writer_inverted_range_errors() {
1160        let runtime = get_threadpool();
1161        runtime
1162            .clone()
1163            .external_run_async_task(async {
1164                let temp = tempdir().unwrap();
1165                let cas_path = temp.path().join("cas");
1166                let original_data = b"0123456789abcdef";
1167
1168                let xfi = upload_data(&cas_path, original_data).await;
1169
1170                let config = TranslatorConfig::local_config(&cas_path).unwrap();
1171                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
1172
1173                let out_path = temp.path().join("inverted_range.bin");
1174                let file = std::fs::File::create(&out_path).unwrap();
1175                let result = session.download_to_writer(&xfi, 10..5, file).await;
1176
1177                assert!(result.is_err());
1178            })
1179            .unwrap();
1180    }
1181
1182    #[cfg(not(debug_assertions))]
1183    #[test]
1184    fn test_download_to_writer_range_start_beyond_file_size_errors() {
1185        let runtime = get_threadpool();
1186        runtime
1187            .clone()
1188            .external_run_async_task(async {
1189                let temp = tempdir().unwrap();
1190                let cas_path = temp.path().join("cas");
1191                let original_data = b"0123456789abcdef";
1192
1193                let xfi = upload_data(&cas_path, original_data).await;
1194
1195                let config = TranslatorConfig::local_config(&cas_path).unwrap();
1196                let session = FileDownloadSession::new(config.into(), None).await.unwrap();
1197
1198                let out_path = temp.path().join("beyond_size.bin");
1199                let file = std::fs::File::create(&out_path).unwrap();
1200                let result = session.download_to_writer(&xfi, 100000.., file).await;
1201
1202                assert!(result.is_err());
1203            })
1204            .unwrap();
1205    }
1206}