// steamroom_client/download.rs — depot download pipeline.
1use crate::event::DownloadEvent;
2use bytes::Bytes;
3use std::future::Future;
4use std::io::Write;
5use std::path::Path;
6use std::path::PathBuf;
7use std::sync::OnceLock;
8use std::time::Duration;
9use steamroom::cdn::CdnClient;
10use steamroom::cdn::pool::CdnServerPool;
11use steamroom::depot::ChunkId;
12use steamroom::depot::DepotId;
13use steamroom::depot::DepotKey;
14use steamroom::depot::chunk;
15use steamroom::depot::manifest::DepotManifest;
16use steamroom::depot::manifest::ManifestFile;
17use steamroom::enums::DepotFileFlags;
18use steamroom::error::Error as SteamError;
19use tokio::sync::mpsc;
20
/// Boxed, thread-safe error type shared across the download pipeline.
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
23/// Trait for fetching raw encrypted chunk bytes. Implement this to provide
24/// a custom data source (CDN, local cache, LAN peer, etc.).
pub trait ChunkFetcher: Send + Sync {
    /// Fetch the raw bytes of `chunk_id` from depot `depot_id`.
    ///
    /// The returned bytes are passed as-is to `chunk::process_chunk`, which
    /// performs decryption/decompression — implementations must not decode.
    fn fetch_chunk(
        &self,
        depot_id: DepotId,
        chunk_id: &ChunkId,
    ) -> impl Future<Output = Result<Bytes, BoxError>> + Send;
}
32
33/// CDN-backed chunk fetcher with server pool rotation and rate-limit handling.
#[non_exhaustive]
pub struct CdnChunkFetcher {
    /// Client used to issue `download_chunk` requests.
    pub cdn: CdnClient,
    /// Server pool; tracks per-server success/failure and cooldowns.
    pub pool: CdnServerPool,
    /// Optional auth token passed with every chunk request.
    pub cdn_auth_token: Option<String>,
}
40
41impl CdnChunkFetcher {
42    pub fn new(cdn: CdnClient, pool: CdnServerPool, cdn_auth_token: Option<String>) -> Self {
43        Self {
44            cdn,
45            pool,
46            cdn_auth_token,
47        }
48    }
49}
50
51impl ChunkFetcher for CdnChunkFetcher {
52    async fn fetch_chunk(&self, depot_id: DepotId, chunk_id: &ChunkId) -> Result<Bytes, BoxError> {
53        let (server, wait) = self.pool.pick();
54        if !wait.is_zero() {
55            tracing::warn!(
56                server = %server.host,
57                wait_secs = wait.as_secs_f32(),
58                "all CDN servers in cooldown, waiting"
59            );
60            tokio::time::sleep(wait).await;
61        }
62        match self
63            .cdn
64            .download_chunk(server, depot_id, chunk_id, self.cdn_auth_token.as_deref())
65            .await
66        {
67            Ok(data) => {
68                self.pool.report_success(server);
69                Ok(data)
70            }
71            Err(SteamError::CdnStatus {
72                status,
73                retry_after,
74            }) => {
75                let ra = retry_after.map(Duration::from_secs);
76                if status == reqwest::StatusCode::TOO_MANY_REQUESTS
77                    || status == reqwest::StatusCode::SERVICE_UNAVAILABLE
78                {
79                    tracing::warn!(
80                        server = %server.host,
81                        status = status.as_u16(),
82                        retry_after = retry_after.unwrap_or(0),
83                        "CDN rate limited, backing off"
84                    );
85                } else {
86                    tracing::debug!(
87                        server = %server.host,
88                        status = status.as_u16(),
89                        "CDN error"
90                    );
91                }
92                self.pool.report_failure(server, ra);
93                Err(Box::new(SteamError::CdnStatus {
94                    status,
95                    retry_after,
96                }))
97            }
98            Err(e) => {
99                tracing::debug!(server = %server.host, error = %e, "CDN request failed");
100                self.pool.report_failure(server, None);
101                Err(Box::new(e))
102            }
103        }
104    }
105}
106
/// Retry policy applied to each chunk download.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct RetryConfig {
    /// Total number of fetch attempts per chunk (first try included).
    pub max_attempts: u32,
    /// Delay before the first retry; the chunk pipeline doubles it after
    /// each failed attempt (capped at 30 s), unless a Retry-After overrides it.
    pub initial_delay: Duration,
}
114
115impl Default for RetryConfig {
116    fn default() -> Self {
117        Self {
118            max_attempts: 5,
119            initial_delay: Duration::from_millis(500),
120        }
121    }
122}
123
124/// Controls which manifest files are included in a download.
125///
126/// ```
127/// use steamroom_client::download::FileFilter;
128///
129/// // Match only .dll files
130/// let filter = FileFilter::Regex(regex::Regex::new(r"\.dll$").unwrap());
131/// assert!(filter.matches("bin/server.dll"));
132/// assert!(!filter.matches("bin/server.exe"));
133///
134/// // Parse a filelist with mixed literal and regex entries
135/// let filter = FileFilter::from_filelist(&[
136///     "game/bin/server.dll".into(),
137///     "regex:^maps/.*\\.vpk$".into(),
138/// ]).unwrap();
139/// assert!(filter.matches("game/bin/server.dll"));
140/// assert!(filter.matches("maps/de_dust2.vpk"));
141/// ```
pub enum FileFilter {
    /// No filtering — every file matches.
    None,
    /// A list of entries; a file matches if any single entry matches.
    Combined(Vec<FileFilterEntry>),
    /// A single regex tested against the manifest filename.
    Regex(regex::Regex),
}
147
/// One entry of a [`FileFilter::Combined`] filter.
pub enum FileFilterEntry {
    /// Exact path match (case-insensitive; see [`FileFilter::matches`]).
    Literal(String),
    /// Regex tested against the raw manifest filename.
    Regex(regex::Regex),
}
152
153impl FileFilter {
154    /// Convert the filter back into filelist string format.
155    /// Regex entries are prefixed with `regex:`.
156    pub fn to_filelist(&self) -> Vec<String> {
157        match self {
158            Self::None => vec![],
159            Self::Combined(entries) => entries
160                .iter()
161                .map(|e| match e {
162                    FileFilterEntry::Literal(s) => s.clone(),
163                    FileFilterEntry::Regex(re) => format!("regex:{}", re.as_str()),
164                })
165                .collect(),
166            Self::Regex(re) => vec![format!("regex:{}", re.as_str())],
167        }
168    }
169
170    /// Parse a filelist where lines can be literal paths or `regex:pattern` entries.
171    /// This is compatible with the filelist format used by DepotDownloader.
172    pub fn from_filelist(lines: &[String]) -> Result<Self, regex::Error> {
173        let mut entries = Vec::with_capacity(lines.len());
174        for line in lines {
175            if let Some(pattern) = line.strip_prefix("regex:") {
176                entries.push(FileFilterEntry::Regex(regex::Regex::new(pattern)?));
177            } else {
178                entries.push(FileFilterEntry::Literal(line.clone()));
179            }
180        }
181        Ok(Self::Combined(entries))
182    }
183
184    /// Returns true if `filename` passes the filter.
185    /// Literal comparisons are case-insensitive and normalize path separators.
186    pub fn matches(&self, filename: &str) -> bool {
187        match self {
188            Self::None => true,
189            Self::Combined(entries) => entries.iter().any(|entry| match entry {
190                FileFilterEntry::Literal(lit) => {
191                    filename.eq_ignore_ascii_case(lit)
192                        || filename.replace('\\', "/").eq_ignore_ascii_case(lit)
193                }
194                FileFilterEntry::Regex(re) => re.is_match(filename),
195            }),
196            Self::Regex(re) => re.is_match(filename),
197        }
198    }
199}
200
#[cfg(feature = "serde")]
impl serde::Serialize for FileFilter {
    /// Serialize as the filelist line format produced by [`FileFilter::to_filelist`]
    /// (a sequence of strings; regex entries carry the `regex:` prefix).
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        self.to_filelist().serialize(serializer)
    }
}
207
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for FileFilter {
    /// Deserialize from a sequence of filelist lines. An empty sequence becomes
    /// [`FileFilter::None`]; an invalid `regex:` entry is a deserialization error.
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let lines: Vec<String> = Vec::deserialize(deserializer)?;
        if lines.is_empty() {
            return Ok(Self::None);
        }
        Self::from_filelist(&lines).map_err(serde::de::Error::custom)
    }
}
218
219/// A download job for a single depot. Handles chunk fetching, decryption,
220/// decompression, file assembly, resume, and delta removal of stale files.
221///
222/// Create via [`DepotJob::builder()`].
pub struct DepotJob {
    /// Depot being downloaded.
    depot_id: DepotId,
    /// Key used when processing (decrypting) fetched chunks.
    depot_key: DepotKey,
    /// Root directory files are installed into.
    install_dir: PathBuf,
    /// Cap on concurrent chunk fetches (semaphore size, channel capacity).
    max_downloads: usize,
    /// When true, files whose size/SHA already match the manifest are skipped.
    verify: bool,
    /// Which manifest files to include.
    file_filter: FileFilter,
    /// Per-chunk retry policy.
    retry: RetryConfig,
    /// Optional sink for progress events; `None` disables event emission.
    event_tx: Option<mpsc::UnboundedSender<DownloadEvent>>,
    /// Previous manifest's filenames, enabling delta removal of stale files.
    old_manifest_files: Option<Vec<String>>,
}
234
235impl DepotJob {
    /// Returns a [`DepotJobBuilder`] with every field unset/defaulted.
    pub fn builder() -> DepotJobBuilder {
        DepotJobBuilder::default()
    }
239
240    fn emit(&self, event: DownloadEvent) {
241        if let Some(ref tx) = self.event_tx {
242            let _ = tx.send(event);
243        }
244    }
245
    /// Download every manifest file that passes the filter into `install_dir`,
    /// emitting progress events along the way.
    ///
    /// Per file: directories are created, empty files touched, symlink entries
    /// skipped, up-to-date files skipped (only when `verify` is set), and the
    /// rest downloaded into a staging area then renamed into place. Afterwards,
    /// files listed in `old_manifest_files` but absent from `manifest` are
    /// deleted (delta removal).
    ///
    /// Returns aggregate [`DownloadStats`]; the first I/O or fetch error aborts
    /// the whole download.
    pub async fn download<F: ChunkFetcher + 'static>(
        &self,
        manifest: &DepotManifest,
        fetcher: std::sync::Arc<F>,
    ) -> Result<DownloadStats, BoxError> {
        // Totals reported in DownloadStarted count only files passing the filter.
        let (total_bytes, total_files) =
            manifest
                .files
                .iter()
                .fold((0u64, 0u64), |(bytes, count), f| {
                    if self.file_filter.matches(&f.filename) {
                        (bytes + f.size, count + 1)
                    } else {
                        (bytes, count)
                    }
                });
        let mut stats = DownloadStats::default();

        self.emit(DownloadEvent::DownloadStarted {
            total_bytes,
            total_files,
        });

        // Shared across all files: caps concurrent chunk fetches for the depot.
        let sem = std::sync::Arc::new(tokio::sync::Semaphore::new(self.max_downloads));

        for file in &manifest.files {
            let filename = &file.filename;

            if !self.file_filter.matches(filename) {
                self.emit(DownloadEvent::FileSkipped {
                    filename: filename.clone(),
                });
                stats.files_skipped += 1;
                continue;
            }

            let file_path = self.install_dir.join(filename);
            let flags = DepotFileFlags(file.flags);

            if flags.is_directory() {
                std::fs::create_dir_all(&file_path)?;
                continue;
            }

            // Zero-byte entry with no chunks: just materialize an empty file.
            if file.size == 0 && file.chunks.is_empty() {
                if let Some(parent) = file_path.parent() {
                    std::fs::create_dir_all(parent)?;
                }
                std::fs::write(&file_path, [])?;
                stats.files_completed += 1;
                continue;
            }

            if file.link_target.is_some() {
                // Symlinks — skip for now
                continue;
            }

            if let Some(parent) = file_path.parent() {
                std::fs::create_dir_all(parent)?;
            }

            // Check if file already matches the manifest (skip if up-to-date).
            // Skipped bytes still count toward bytes_downloaded so that
            // DepotProgress reflects overall completion, not just new traffic.
            let expected_size = file.size;
            if self.verify && file_matches(&file_path, expected_size, file.sha_content.as_ref()) {
                self.emit(DownloadEvent::FileSkipped {
                    filename: filename.to_string(),
                });
                stats.files_skipped += 1;
                stats.bytes_downloaded += expected_size;
                continue;
            }

            self.emit(DownloadEvent::FileStarted {
                filename: filename.to_string(),
            });

            // Download to staging, then move to final path.
            // NOTE(review): flattening separators to '_' can collide for names
            // like "a/b" vs "a_b" — confirm this is acceptable for real depots.
            let staging_dir = self.install_dir.join(".depotdownloader").join("staging");
            std::fs::create_dir_all(&staging_dir)?;
            let staging_path = staging_dir.join(filename.replace(['/', '\\'], "_"));

            let file_size = self
                .download_file_chunks_with_resume(file, &fetcher, &sem, &staging_path)
                .await?;

            // Move staging file into place. On Windows, rename fails if the
            // target exists and is read-only, so remove it first.
            if file_path.exists() {
                // Clear read-only attribute on Windows before removing
                #[cfg(windows)]
                {
                    let mut perms = std::fs::metadata(&file_path)?.permissions();
                    #[allow(clippy::permissions_set_readonly_false)]
                    if perms.readonly() {
                        perms.set_readonly(false);
                        let _ = std::fs::set_permissions(&file_path, perms);
                    }
                }
                std::fs::remove_file(&file_path)?;
            }
            std::fs::rename(&staging_path, &file_path)?;
            stats.bytes_downloaded += file_size;
            stats.files_completed += 1;

            self.emit(DownloadEvent::FileCompleted {
                filename: filename.to_string(),
            });
            self.emit(DownloadEvent::DepotProgress {
                completed_bytes: stats.bytes_downloaded,
                total_bytes,
            });
        }

        // Remove files from the old manifest that are absent in the new one.
        // Removal is best-effort: failures are silently skipped and the file
        // simply isn't counted/reported as removed.
        if let Some(ref old_files) = self.old_manifest_files {
            let new_files: std::collections::HashSet<&str> =
                manifest.files.iter().map(|f| f.filename.as_str()).collect();

            for old_name in old_files {
                if new_files.contains(old_name.as_str()) {
                    continue;
                }
                // NOTE(review): old names are separator-normalized here but the
                // membership check above compares raw strings — confirm old and
                // new manifests use the same separator convention.
                let old_path = self.install_dir.join(old_name.replace('\\', "/"));
                if old_path.exists() {
                    let is_dir = old_path.is_dir();
                    let removed = if is_dir {
                        std::fs::remove_dir(&old_path).is_ok()
                    } else {
                        std::fs::remove_file(&old_path).is_ok()
                    };
                    if removed {
                        self.emit(DownloadEvent::FileRemoved {
                            filename: old_name.clone(),
                        });
                        stats.files_removed += 1;
                    }
                }
            }
        }

        Ok(stats)
    }
389
390    /// Pipelined chunk download: network fetch and decrypt/decompress overlap.
391    ///
392    /// Stage 1 (async IO, bounded by semaphore): fetch raw bytes from CDN
393    /// Stage 2 (blocking thread pool): decrypt + decompress + checksum verify
394    ///
395    /// Fetchers push raw bytes into a bounded channel. A processor task drains
396    /// the channel and dispatches each chunk to spawn_blocking. Results land in
397    /// ordered slots. The bounded channel provides backpressure: if the CPU pool
398    /// falls behind, fetchers block on send instead of buffering unbounded memory.
    async fn download_file_chunks<F: ChunkFetcher + 'static>(
        &self,
        file: &ManifestFile,
        fetcher: &std::sync::Arc<F>,
        sem: &std::sync::Arc<tokio::sync::Semaphore>,
    ) -> Result<Vec<u8>, BoxError> {
        let n = file.chunks.len();
        if n == 0 {
            return Ok(Vec::new());
        }

        // Bounded channel: fetch stage → process stage.
        // Capacity = max_downloads so we buffer at most that many fetched-but-unprocessed chunks.
        // Tuple: (chunk index, raw bytes, expected uncompressed size, checksum
        // — passed through verbatim to chunk::process_chunk).
        let (fetch_tx, mut fetch_rx) =
            tokio::sync::mpsc::channel::<(usize, Bytes, u32, u32)>(self.max_downloads);

        // One write-once slot per chunk so processed results can land out of order.
        let slots: std::sync::Arc<Vec<OnceLock<Vec<u8>>>> =
            std::sync::Arc::new((0..n).map(|_| OnceLock::new()).collect());

        // Stage 1: spawn fetcher tasks
        let mut fetch_handles = Vec::with_capacity(n);
        for (i, chunk_meta) in file.chunks.iter().enumerate() {
            let chunk_id = chunk_meta.id.clone();
            let expected_size = chunk_meta.uncompressed_size;
            let checksum = chunk_meta.checksum;
            let depot_id = self.depot_id;
            let retry = self.retry.clone();
            let event_tx = self.event_tx.clone();
            let sem = sem.clone();
            let fetcher = fetcher.clone();
            let fetch_tx = fetch_tx.clone();

            fetch_handles.push(tokio::spawn(async move {
                // Depot-wide concurrency cap; held until this task finishes.
                let _permit = sem
                    .acquire()
                    .await
                    .map_err(|e| -> BoxError { Box::new(e) })?;

                // Retry loop: delay doubles after each failure (capped at 30 s);
                // retry_delay_for_error may override with a Retry-After hint.
                // Note: if max_attempts is 0 the loop never runs and the
                // "never attempted" error is returned.
                let mut delay = retry.initial_delay;
                let mut result = Err::<Bytes, BoxError>("never attempted".into());
                for attempt in 0..retry.max_attempts {
                    match fetcher.fetch_chunk(depot_id, &chunk_id).await {
                        Ok(data) => {
                            result = Ok(data);
                            break;
                        }
                        Err(e) if attempt + 1 < retry.max_attempts => {
                            let wait = retry_delay_for_error(&e, delay);
                            if let Some(ref tx) = event_tx {
                                let _ = tx.send(DownloadEvent::ChunkFailed {
                                    error: e.to_string(),
                                });
                            }
                            tokio::time::sleep(wait).await;
                            delay = (wait * 2).min(Duration::from_secs(30));
                        }
                        Err(e) => {
                            // Final attempt failed: propagate the error.
                            result = Err(e);
                            break;
                        }
                    }
                }

                // Backpressure: if the process stage is full, this send blocks.
                // NOTE(review): the semaphore permit is NOT released while
                // blocked here — `_permit` lives until the task ends — so a slow
                // CPU stage also throttles new fetches. The previous comment
                // claimed the permit was released at this point; it is not.
                fetch_tx
                    .send((i, result?, expected_size, checksum))
                    .await
                    .map_err(|_| -> BoxError { "process channel closed".into() })?;
                Ok::<(), BoxError>(())
            }));
        }
        drop(fetch_tx); // close so process loop terminates when all fetchers done

        // Stage 2: drain fetch results → spawn_blocking for decrypt+decompress
        let slots_ref = slots.clone();
        let depot_key = self.depot_key.clone();
        let event_tx = self.event_tx.clone();

        let process_handle = tokio::spawn(async move {
            let mut block_handles = Vec::new();

            while let Some((i, raw, expected_size, checksum)) = fetch_rx.recv().await {
                let key = depot_key.clone();
                let slots = slots_ref.clone();
                let tx = event_tx.clone();

                // CPU-bound work goes to the blocking pool so the async
                // executor stays responsive.
                block_handles.push(tokio::task::spawn_blocking(move || {
                    let processed = chunk::process_chunk(&raw, &key, expected_size, checksum)?;
                    if let Some(ref tx) = tx {
                        let _ = tx.send(DownloadEvent::ChunkCompleted {
                            bytes: processed.len() as u64,
                        });
                    }
                    let _ = slots[i].set(processed);
                    Ok::<(), BoxError>(())
                }));
            }

            for h in block_handles {
                h.await??;
            }
            Ok::<(), BoxError>(())
        });

        // Wait for both stages
        for h in fetch_handles {
            h.await??;
        }
        process_handle.await??;

        // Assemble in order
        let slots = std::sync::Arc::try_unwrap(slots).map_err(|_| "slots arc still shared")?;
        // size hint only — Vec grows if absent, no correctness impact
        let mut file_data = Vec::with_capacity(file.size as usize);
        for slot in slots {
            file_data
                .extend_from_slice(&slot.into_inner().ok_or("chunk slot empty after pipeline")?);
        }
        Ok(file_data)
    }
520
521    /// Downloads remaining chunks to the staging file. Returns total file size in bytes.
522    async fn download_file_chunks_with_resume<F: ChunkFetcher + 'static>(
523        &self,
524        file: &ManifestFile,
525        fetcher: &std::sync::Arc<F>,
526        sem: &std::sync::Arc<tokio::sync::Semaphore>,
527        staging_path: &Path,
528    ) -> Result<u64, BoxError> {
529        let existing_bytes = std::fs::metadata(staging_path)
530            .map(|m| m.len())
531            .unwrap_or(0);
532
533        // Count complete chunks already staged
534        let mut staged_offset: u64 = 0;
535        let mut skip_count = 0;
536        if existing_bytes > 0 {
537            for chunk_meta in &file.chunks {
538                let chunk_size = chunk_meta.uncompressed_size as u64;
539                if staged_offset + chunk_size <= existing_bytes {
540                    staged_offset += chunk_size;
541                    skip_count += 1;
542                } else {
543                    break;
544                }
545            }
546        }
547
548        if skip_count == file.chunks.len() {
549            return Ok(staged_offset);
550        }
551
552        if skip_count > 0 {
553            tracing::debug!(
554                "resuming {}: skipping {skip_count}/{} chunks ({staged_offset} bytes staged)",
555                &file.filename,
556                file.chunks.len(),
557            );
558        } else {
559            let _ = std::fs::remove_file(staging_path);
560        }
561
562        // Build a trimmed file with only remaining chunks, pipeline-download them
563        let mut remaining = ManifestFile::new(file.filename.clone(), file.size - staged_offset);
564        remaining.flags = file.flags;
565        remaining.sha_content = file.sha_content;
566        remaining.chunks = file.chunks[skip_count..].to_vec();
567
568        let new_data = self.download_file_chunks(&remaining, fetcher, sem).await?;
569        let new_len = new_data.len() as u64;
570
571        // Append to staging for crash safety
572        {
573            let mut f = std::fs::OpenOptions::new()
574                .create(true)
575                .append(true)
576                .open(staging_path)?;
577            f.write_all(&new_data)?;
578        }
579
580        Ok(staged_offset + new_len)
581    }
582}
583
584fn file_matches(path: &Path, expected_size: u64, sha_content: Option<&[u8; 20]>) -> bool {
585    let meta = match std::fs::metadata(path) {
586        Ok(m) => m,
587        Err(_) => return false,
588    };
589    if meta.len() != expected_size {
590        return false;
591    }
592    if let Some(expected_sha) = sha_content {
593        if let Ok(data) = std::fs::read(path) {
594            let actual = steamroom::util::checksum::Sha1Hash::compute(&data);
595            return actual.0 == *expected_sha;
596        }
597        return false;
598    }
599    // No SHA to verify — size match is good enough
600    true
601}
602
/// Aggregate counters for one depot download.
#[derive(Default, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct DownloadStats {
    /// Files fully written to disk (including empty files).
    pub files_completed: u64,
    /// Files skipped by the filter or already up-to-date on disk.
    pub files_skipped: u64,
    /// Stale files deleted during delta removal.
    pub files_removed: u64,
    /// Bytes staged/written, plus the sizes of verified up-to-date files.
    pub bytes_downloaded: u64,
}
612
/// Builder for [`DepotJob`]. `depot_id`, `depot_key`, and `install_dir` are
/// required; every other field has a default (see `build`).
#[derive(Default)]
pub struct DepotJobBuilder {
    depot_id: Option<DepotId>,
    depot_key: Option<DepotKey>,
    install_dir: Option<PathBuf>,
    // Defaults to 16 when unset.
    max_downloads: Option<usize>,
    verify: bool,
    // Defaults to FileFilter::None (include everything).
    file_filter: Option<FileFilter>,
    // Defaults to RetryConfig::default().
    retry: Option<RetryConfig>,
    event_tx: Option<mpsc::UnboundedSender<DownloadEvent>>,
    old_manifest_files: Option<Vec<String>>,
}
625
626impl DepotJobBuilder {
627    pub fn depot_id(mut self, id: DepotId) -> Self {
628        self.depot_id = Some(id);
629        self
630    }
631
632    pub fn depot_key(mut self, key: DepotKey) -> Self {
633        self.depot_key = Some(key);
634        self
635    }
636
637    pub fn install_dir(mut self, dir: PathBuf) -> Self {
638        self.install_dir = Some(dir);
639        self
640    }
641
642    pub fn max_downloads(mut self, n: usize) -> Self {
643        self.max_downloads = Some(n);
644        self
645    }
646
647    pub fn file_filter(mut self, f: FileFilter) -> Self {
648        self.file_filter = Some(f);
649        self
650    }
651
652    pub fn verify(mut self, v: bool) -> Self {
653        self.verify = v;
654        self
655    }
656
657    pub fn retry(mut self, config: RetryConfig) -> Self {
658        self.retry = Some(config);
659        self
660    }
661
662    pub fn event_sender(mut self, tx: mpsc::UnboundedSender<DownloadEvent>) -> Self {
663        self.event_tx = Some(tx);
664        self
665    }
666
667    pub fn old_manifest_files(mut self, files: Vec<String>) -> Self {
668        self.old_manifest_files = Some(files);
669        self
670    }
671
672    pub fn build(self) -> Result<DepotJob, BoxError> {
673        Ok(DepotJob {
674            depot_id: self.depot_id.ok_or("depot_id required")?,
675            depot_key: self.depot_key.ok_or("depot_key required")?,
676            install_dir: self.install_dir.ok_or("install_dir required")?,
677            max_downloads: self.max_downloads.unwrap_or(16),
678            verify: self.verify,
679            file_filter: self.file_filter.unwrap_or(FileFilter::None),
680            retry: self.retry.unwrap_or_default(),
681            event_tx: self.event_tx,
682            old_manifest_files: self.old_manifest_files,
683        })
684    }
685}
686
687/// Compute retry delay, respecting `Retry-After` from 429/503 responses.
688fn retry_delay_for_error(err: &BoxError, default: Duration) -> Duration {
689    if let Some(SteamError::CdnStatus {
690        status,
691        retry_after,
692    }) = err.downcast_ref::<SteamError>()
693        && (*status == reqwest::StatusCode::TOO_MANY_REQUESTS
694            || *status == reqwest::StatusCode::SERVICE_UNAVAILABLE)
695    {
696        if let Some(secs) = retry_after {
697            return Duration::from_secs((*secs).min(60));
698        }
699        return default.max(Duration::from_secs(5));
700    }
701    default
702}