1use crate::{Error, HttpClient, Result};
8use reqwest::Response;
9use serde::{Deserialize, Serialize};
10use std::path::{Path, PathBuf};
11use tokio::fs::{File, OpenOptions};
12use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
13use tracing::{debug, info, warn};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct DownloadProgress {
18 pub total_size: Option<u64>,
20 pub bytes_downloaded: u64,
22 pub file_hash: String,
24 pub cdn_host: String,
26 pub cdn_path: String,
28 pub target_file: PathBuf,
30 pub progress_file: PathBuf,
32 pub is_complete: bool,
34 pub last_updated: u64,
36}
37
38#[derive(Debug)]
40pub struct ResumableDownload {
41 client: HttpClient,
42 progress: DownloadProgress,
43}
44
45impl DownloadProgress {
46 pub fn new(
48 file_hash: String,
49 cdn_host: String,
50 cdn_path: String,
51 target_file: PathBuf,
52 ) -> Self {
53 let progress_file = target_file.with_extension("download");
54
55 Self {
56 total_size: None,
57 bytes_downloaded: 0,
58 file_hash,
59 cdn_host,
60 cdn_path,
61 target_file,
62 progress_file,
63 is_complete: false,
64 last_updated: current_timestamp(),
65 }
66 }
67
68 pub async fn load_from_file(progress_file: &Path) -> Result<Self> {
70 let content = tokio::fs::read_to_string(progress_file).await?;
71 let mut progress: DownloadProgress = serde_json::from_str(&content)?;
72 progress.last_updated = current_timestamp();
73 Ok(progress)
74 }
75
76 pub async fn save_to_file(&self) -> Result<()> {
78 let content = serde_json::to_string_pretty(self)?;
79 tokio::fs::write(&self.progress_file, content).await?;
80 debug!("Saved download progress to {:?}", self.progress_file);
81 Ok(())
82 }
83
84 pub async fn verify_existing_file(&self) -> Result<bool> {
86 if let Ok(metadata) = tokio::fs::metadata(&self.target_file).await {
87 let file_size = metadata.len();
88
89 if let Some(total) = self.total_size {
91 return Ok(file_size == total);
92 }
93
94 Ok(file_size >= self.bytes_downloaded)
96 } else {
97 Ok(false)
98 }
99 }
100
101 pub fn completion_percentage(&self) -> Option<f64> {
103 self.total_size.map(|total| {
104 if total == 0 {
105 100.0
106 } else {
107 (self.bytes_downloaded as f64 / total as f64) * 100.0
108 }
109 })
110 }
111
112 pub fn progress_string(&self) -> String {
114 match (self.total_size, self.completion_percentage()) {
115 (Some(total), Some(percent)) => {
116 format!(
117 "{}/{} bytes ({:.1}%)",
118 format_bytes(self.bytes_downloaded),
119 format_bytes(total),
120 percent
121 )
122 }
123 (Some(total), None) => {
124 format!(
125 "{}/{} bytes",
126 format_bytes(self.bytes_downloaded),
127 format_bytes(total)
128 )
129 }
130 (None, _) => {
131 format!("{} bytes", format_bytes(self.bytes_downloaded))
132 }
133 }
134 }
135}
136
137impl ResumableDownload {
138 pub fn new(client: HttpClient, progress: DownloadProgress) -> Self {
140 Self { client, progress }
141 }
142
143 pub async fn start_or_resume(&mut self) -> Result<()> {
145 let can_resume = if self.progress.bytes_downloaded > 0 {
147 self.progress.verify_existing_file().await.unwrap_or(false)
148 } else {
149 false
150 };
151
152 if can_resume {
153 info!(
154 "Resuming download from {} bytes for {}",
155 self.progress.bytes_downloaded, self.progress.file_hash
156 );
157 } else {
158 info!("Starting new download for {}", self.progress.file_hash);
159 self.progress.bytes_downloaded = 0;
160 }
161
162 self.progress.save_to_file().await?;
164
165 self.download_with_resume().await
167 }
168
169 async fn download_with_resume(&mut self) -> Result<()> {
171 let mut file = OpenOptions::new()
173 .create(true)
174 .write(true)
175 .read(true)
176 .truncate(false)
177 .open(&self.progress.target_file)
178 .await?;
179
180 if self.progress.bytes_downloaded > 0 {
182 file.seek(SeekFrom::Start(self.progress.bytes_downloaded))
183 .await?;
184 }
185
186 let range = (self.progress.bytes_downloaded, None);
188 let response = self
189 .client
190 .download_file_range(
191 &self.progress.cdn_host,
192 &self.progress.cdn_path,
193 &self.progress.file_hash,
194 range,
195 )
196 .await?;
197
198 if self.progress.total_size.is_none() {
200 self.progress.total_size =
201 extract_total_size(&response, self.progress.bytes_downloaded);
202 }
203
204 match response.status() {
206 reqwest::StatusCode::PARTIAL_CONTENT => {
207 debug!(
208 "Server supports range requests, resuming from byte {}",
209 self.progress.bytes_downloaded
210 );
211 }
212 reqwest::StatusCode::OK => {
213 if self.progress.bytes_downloaded > 0 {
214 warn!(
215 "Server doesn't support range requests, restarting download from beginning"
216 );
217 file.seek(SeekFrom::Start(0)).await?;
218 file.set_len(0).await?;
219 self.progress.bytes_downloaded = 0;
220 }
221 }
222 _status => {
223 return Err(Error::InvalidResponse);
224 }
225 }
226
227 self.stream_response_to_file(response, &mut file).await?;
229
230 self.progress.is_complete = true;
232 self.progress.save_to_file().await?;
233
234 info!("Download completed: {}", self.progress.progress_string());
235 Ok(())
236 }
237
238 async fn stream_response_to_file(&mut self, response: Response, file: &mut File) -> Result<()> {
240 let mut stream = response.bytes_stream();
241 let mut bytes_written_since_save = 0u64;
242 const SAVE_INTERVAL: u64 = 1024 * 1024; use futures_util::StreamExt;
245
246 while let Some(chunk_result) = stream.next().await {
247 let chunk = chunk_result.map_err(Error::Http)?;
248
249 file.write_all(&chunk).await?;
251
252 let chunk_size = chunk.len() as u64;
254 self.progress.bytes_downloaded += chunk_size;
255 bytes_written_since_save += chunk_size;
256
257 if bytes_written_since_save >= SAVE_INTERVAL {
259 file.flush().await?;
260 self.progress.last_updated = current_timestamp();
261 self.progress.save_to_file().await?;
262 bytes_written_since_save = 0;
263
264 debug!("Progress: {}", self.progress.progress_string());
265 }
266 }
267
268 file.flush().await?;
270 self.progress.last_updated = current_timestamp();
271
272 Ok(())
273 }
274
275 pub fn progress(&self) -> &DownloadProgress {
277 &self.progress
278 }
279
280 pub async fn cancel(&self) -> Result<()> {
282 if self.progress.progress_file.exists() {
283 tokio::fs::remove_file(&self.progress.progress_file).await?;
284 debug!("Removed progress file {:?}", self.progress.progress_file);
285 }
286 Ok(())
287 }
288
289 pub async fn cleanup_completed(&self) -> Result<()> {
291 if self.progress.is_complete && self.progress.progress_file.exists() {
292 tokio::fs::remove_file(&self.progress.progress_file).await?;
293 debug!("Cleaned up progress file for completed download");
294 }
295 Ok(())
296 }
297}
298
299fn extract_total_size(response: &Response, bytes_already_downloaded: u64) -> Option<u64> {
301 if let Some(content_range) = response.headers().get("content-range") {
303 if let Ok(range_str) = content_range.to_str() {
304 if let Some(total_str) = range_str.split('/').nth(1) {
306 if let Ok(total) = total_str.parse::<u64>() {
307 return Some(total);
308 }
309 }
310 }
311 }
312
313 if let Some(content_length) = response.headers().get("content-length") {
315 if let Ok(length_str) = content_length.to_str() {
316 if let Ok(length) = length_str.parse::<u64>() {
317 return Some(length + bytes_already_downloaded);
319 }
320 }
321 }
322
323 None
324}
325
/// Formats a byte count with binary (1024-based) units.
///
/// Values below 1024 are printed exactly (`"512 B"`). Larger values are
/// scaled and shown with two decimals (`"1.50 KB"`). Scaling stops at TB, so
/// anything bigger stays in TB.
fn format_bytes(bytes: u64) -> String {
    const UNITS: &[&str] = &["B", "KB", "MB", "GB", "TB"];

    if bytes < 1024 {
        return format!("{} {}", bytes, UNITS[0]);
    }

    let mut scaled = bytes as f64;
    let mut unit = 0usize;
    while scaled >= 1024.0 && unit + 1 < UNITS.len() {
        scaled /= 1024.0;
        unit += 1;
    }

    format!("{:.2} {}", scaled, UNITS[unit])
}
343
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// A clock set before 1970 yields `0` instead of panicking.
fn current_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
351
352pub async fn find_resumable_downloads(dir: &Path) -> Result<Vec<DownloadProgress>> {
354 let mut downloads = Vec::new();
355
356 if !dir.exists() {
357 return Ok(downloads);
358 }
359
360 let mut entries = tokio::fs::read_dir(dir).await?;
361
362 while let Some(entry) = entries.next_entry().await? {
363 let path = entry.path();
364
365 if path.extension().and_then(|s| s.to_str()) == Some("download") {
366 match DownloadProgress::load_from_file(&path).await {
367 Ok(progress) => {
368 if !progress.is_complete {
369 downloads.push(progress);
370 }
371 }
372 Err(e) => {
373 warn!("Failed to load download progress from {:?}: {}", path, e);
374 }
375 }
376 }
377 }
378
379 Ok(downloads)
380}
381
382pub async fn cleanup_old_progress_files(dir: &Path, max_age_hours: u64) -> Result<usize> {
384 let max_age_secs = max_age_hours * 3600;
385 let current_time = current_timestamp();
386 let mut cleaned_count = 0;
387
388 if !dir.exists() {
389 return Ok(0);
390 }
391
392 let mut entries = tokio::fs::read_dir(dir).await?;
393
394 while let Some(entry) = entries.next_entry().await? {
395 let path = entry.path();
396
397 if path.extension().and_then(|s| s.to_str()) == Some("download") {
398 match DownloadProgress::load_from_file(&path).await {
399 Ok(progress) => {
400 let age = current_time.saturating_sub(progress.last_updated);
401
402 if progress.is_complete
403 && age > max_age_secs
404 && tokio::fs::remove_file(&path).await.is_ok()
405 {
406 cleaned_count += 1;
407 debug!("Cleaned up old progress file: {:?}", path);
408 }
409 }
410 Err(_) => {
411 if let Ok(metadata) = tokio::fs::metadata(&path).await {
414 if let Ok(modified) = metadata.modified() {
415 let file_age = std::time::SystemTime::now()
416 .duration_since(modified)
417 .unwrap_or_default()
418 .as_secs();
419
420 if file_age > max_age_secs
421 && tokio::fs::remove_file(&path).await.is_ok()
422 {
423 cleaned_count += 1;
424 debug!("Cleaned up corrupted progress file: {:?}", path);
425 }
426 }
427 }
428 }
429 }
430 }
431 }
432
433 Ok(cleaned_count)
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_format_bytes() {
        assert_eq!(format_bytes(0), "0 B");
        assert_eq!(format_bytes(512), "512 B");
        assert_eq!(format_bytes(1023), "1023 B");
        assert_eq!(format_bytes(1024), "1.00 KB");
        assert_eq!(format_bytes(1536), "1.50 KB");
        assert_eq!(format_bytes(1048576), "1.00 MB");
        assert_eq!(format_bytes(1073741824), "1.00 GB");
        // Units stop at TB; larger values keep scaling within TB.
        assert_eq!(format_bytes(1024u64.pow(5)), "1024.00 TB");
    }

    #[test]
    fn test_completion_percentage() {
        let mut progress = DownloadProgress::new(
            "testhash".to_string(),
            "cdn.test.com".to_string(),
            "/data".to_string(),
            PathBuf::from("/tmp/test.dat"),
        );

        // Unknown total: no percentage.
        assert!(progress.completion_percentage().is_none());

        progress.total_size = Some(1000);
        progress.bytes_downloaded = 250;
        assert_eq!(progress.completion_percentage(), Some(25.0));

        progress.bytes_downloaded = 1000;
        assert_eq!(progress.completion_percentage(), Some(100.0));

        // An empty remote file counts as fully downloaded.
        progress.total_size = Some(0);
        progress.bytes_downloaded = 0;
        assert_eq!(progress.completion_percentage(), Some(100.0));
    }

    #[tokio::test]
    async fn test_progress_persistence() {
        let temp_dir = TempDir::new().unwrap();
        let target_file = temp_dir.path().join("test.dat");

        let mut progress = DownloadProgress::new(
            "testhash123".to_string(),
            "cdn.example.com".to_string(),
            "/data".to_string(),
            target_file,
        );

        progress.total_size = Some(2048);
        progress.bytes_downloaded = 1024;

        progress.save_to_file().await.unwrap();
        assert!(progress.progress_file.exists());

        let loaded_progress = DownloadProgress::load_from_file(&progress.progress_file)
            .await
            .unwrap();
        assert_eq!(loaded_progress.file_hash, "testhash123");
        assert_eq!(loaded_progress.total_size, Some(2048));
        assert_eq!(loaded_progress.bytes_downloaded, 1024);
        assert_eq!(loaded_progress.cdn_host, "cdn.example.com");
    }

    #[tokio::test]
    async fn test_verify_existing_file() {
        let temp_dir = TempDir::new().unwrap();
        let target_file = temp_dir.path().join("partial.dat");

        let mut progress = DownloadProgress::new(
            "testhash".to_string(),
            "cdn.test.com".to_string(),
            "/data".to_string(),
            target_file.clone(),
        );

        // A missing target file never verifies.
        assert!(!progress.verify_existing_file().await.unwrap());

        tokio::fs::write(&target_file, vec![0u8; 100]).await.unwrap();

        // Unknown total: the file must hold at least the recorded bytes.
        progress.bytes_downloaded = 100;
        assert!(progress.verify_existing_file().await.unwrap());
        progress.bytes_downloaded = 101;
        assert!(!progress.verify_existing_file().await.unwrap());

        // Known total: the file must be exactly that long.
        progress.total_size = Some(100);
        assert!(progress.verify_existing_file().await.unwrap());
        progress.total_size = Some(200);
        assert!(!progress.verify_existing_file().await.unwrap());
    }

    #[test]
    fn test_extract_total_size_from_content_range() {
        use reqwest::header::{HeaderMap, HeaderValue, CONTENT_LENGTH, CONTENT_RANGE};

        // Building a real `reqwest::Response` requires an HTTP round trip, so
        // this exercises the header formats `extract_total_size` parses.
        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_RANGE, HeaderValue::from_static("bytes 200-2047/2048"));
        headers.insert(CONTENT_LENGTH, HeaderValue::from_static("1848"));

        let total: Option<u64> = headers
            .get(CONTENT_RANGE)
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.split('/').nth(1))
            .and_then(|s| s.parse().ok());
        assert_eq!(total, Some(2048));

        let length: Option<u64> = headers
            .get(CONTENT_LENGTH)
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse().ok());
        assert_eq!(length, Some(1848));

        // The Content-Length fallback (length + resume offset) must agree with
        // the Content-Range total for an open-ended range.
        assert_eq!(length.map(|l| l + 200), total);
    }
}