tidepool_version_manager/
downloader.rs1use indicatif::{ProgressBar, ProgressStyle};
7use log::{debug, info, warn};
8use reqwest::Client;
9use std::path::Path;
10use std::sync::Arc;
11use std::time::Duration;
12use thiserror::Error;
13use tokio::fs::File;
14use tokio::io::AsyncWriteExt;
15use tokio::sync::Mutex;
16
17#[derive(Error, Debug)]
19pub enum DownloadError {
20 #[error("网络错误: {0}")]
21 Network(#[from] reqwest::Error),
22
23 #[error("IO错误: {0}")]
24 Io(#[from] std::io::Error),
25
26 #[error("无法获取文件大小")]
27 FileSizeUnavailable,
28
29 #[error("服务器不支持范围请求")]
30 RangeNotSupported,
31
32 #[error("分片下载失败: {0}")]
33 ChunkDownloadFailed(String),
34
35 #[error("其他错误: {0}")]
36 Other(String),
37}
38
39pub type DownloadResult<T> = Result<T, DownloadError>;
41
42#[derive(Debug, Clone)]
44pub struct ProgressReporter {
45 progress_bar: ProgressBar,
47}
48
49impl ProgressReporter {
50 #[must_use]
56 pub fn new(total_size: u64) -> Self {
57 let progress_bar = ProgressBar::new(total_size);
58 progress_bar.set_style(
59 ProgressStyle::default_bar()
60 .template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")
61 .unwrap()
62 .progress_chars("#>-"),
63 );
64
65 Self { progress_bar }
66 }
67
68 pub fn start(&self) {
70 self.progress_bar.tick();
71 }
72
73 pub fn update(&self, bytes_downloaded: u64) {
75 self.progress_bar.set_position(bytes_downloaded);
76 }
77
78 pub fn increment(&self, bytes: u64) {
80 self.progress_bar.inc(bytes);
81 }
82
83 pub fn finish(&self) {
85 self.progress_bar.finish_with_message("下载完成");
86 }
87
88 pub fn set_length(&self, length: u64) {
90 self.progress_bar.set_length(length);
91 }
92 #[must_use]
94 pub fn length(&self) -> Option<u64> {
95 Some(self.progress_bar.length().unwrap_or(0))
96 }
97
98 pub fn set_message(&self, message: &str) {
100 self.progress_bar.set_message(message.to_string());
101 }
102}
103
/// Tuning knobs for [`Downloader`].
#[derive(Debug, Clone)]
pub struct DownloadConfig {
    /// `User-Agent` header to send; `None` means no custom user agent is configured.
    pub user_agent: Option<String>,
    /// Request timeout in seconds, handed to reqwest's `ClientBuilder::timeout`.
    pub timeout_seconds: u64,
    /// Connection-establishment timeout in seconds.
    pub connect_timeout_seconds: u64,
    /// Maximum simultaneous range requests in chunked mode (chunked mode needs more than 1).
    pub concurrent_connections: usize,
    /// Minimum size of a range in chunked mode; files no larger than this use a single stream.
    pub min_chunk_size: u64,
    /// Maximum number of attempts per download (per range in chunked mode).
    pub max_retries: usize,
    /// Pause between attempts, in milliseconds.
    pub retry_delay_ms: u64,
    /// Allows multi-connection downloads when the server advertises byte-range support.
    pub enable_chunked_download: bool,
}

impl Default for DownloadConfig {
    /// 120 s request timeout, 30 s connect timeout, 4 connections, 1 MiB minimum
    /// range, 3 attempts spaced 1 s apart, chunked downloads enabled.
    fn default() -> Self {
        const MIB: u64 = 1024 * 1024;

        Self {
            user_agent: Some(String::from("tidepool-version-manager/0.1.0")),
            timeout_seconds: 120,
            connect_timeout_seconds: 30,
            concurrent_connections: 4,
            min_chunk_size: MIB,
            max_retries: 3,
            retry_delay_ms: 1_000,
            enable_chunked_download: true,
        }
    }
}
139
140#[derive(Debug, Clone)]
142pub struct Downloader {
143 client: Client,
145 #[allow(dead_code)]
147 config: DownloadConfig,
148}
149
150impl Downloader {
151 #[must_use]
153 pub fn new() -> Self {
154 Self::with_config(DownloadConfig::default())
155 }
156 #[must_use]
162 pub fn with_config(config: DownloadConfig) -> Self {
163 let mut client_builder = Client::builder()
164 .timeout(Duration::from_secs(config.timeout_seconds))
165 .connect_timeout(Duration::from_secs(config.connect_timeout_seconds))
166 .tcp_keepalive(Duration::from_secs(60))
167 .pool_idle_timeout(Duration::from_secs(90))
168 .pool_max_idle_per_host(config.concurrent_connections);
169
170 if let Some(user_agent) = &config.user_agent {
171 client_builder = client_builder.user_agent(user_agent);
172 }
173
174 let client = client_builder.build().expect("Failed to create HTTP client");
175
176 Self { client, config }
177 }
178 pub async fn download<P: AsRef<Path>>(
188 &self,
189 url: &str,
190 output_path: P,
191 progress_reporter: Option<ProgressReporter>,
192 ) -> DownloadResult<()> {
193 let output_path = output_path.as_ref();
194
195 debug!("开始下载: {} -> {}", url, output_path.display());
196
197 let (file_size, supports_ranges) = self.get_file_info(url).await?;
199
200 let reporter = if let Some(reporter) = progress_reporter {
202 reporter.set_length(file_size);
203 Some(reporter)
204 } else {
205 Some(ProgressReporter::new(file_size))
206 };
207
208 if let Some(ref reporter) = reporter {
209 reporter.start();
210 }
211
212 let use_chunked = self.config.enable_chunked_download
214 && supports_ranges
215 && file_size > self.config.min_chunk_size
216 && self.config.concurrent_connections > 1;
217
218 if use_chunked {
219 info!("Using chunked download mode, file size: {file_size} bytes");
220 self.download_chunked(url, output_path, file_size, reporter).await
221 } else {
222 info!("Using single-threaded download mode, file size: {file_size} bytes");
223 self.download_single(url, output_path, reporter).await
224 }
225 }
226
227 async fn get_file_info(&self, url: &str) -> DownloadResult<(u64, bool)> {
229 debug!("Getting file info: {url}");
230
231 let response = self.client.head(url).send().await?;
232
233 if !response.status().is_success() {
234 return Err(DownloadError::Other(format!(
235 "Server returned error status: {}",
236 response.status()
237 )));
238 }
239
240 let file_size = response
241 .headers()
242 .get("content-length")
243 .and_then(|v| v.to_str().ok())
244 .and_then(|s| s.parse::<u64>().ok())
245 .ok_or(DownloadError::FileSizeUnavailable)?;
246 let supports_ranges = response
247 .headers()
248 .get("accept-ranges")
249 .is_some_and(|v| v.to_str().unwrap_or("").to_lowercase() == "bytes");
250
251 debug!("File size: {file_size} bytes, supports ranges: {supports_ranges}");
252 Ok((file_size, supports_ranges))
253 }
254
255 async fn download_single<P: AsRef<Path>>(
257 &self,
258 url: &str,
259 output_path: P,
260 progress_reporter: Option<ProgressReporter>,
261 ) -> DownloadResult<()> {
262 for attempt in 1..=self.config.max_retries {
263 match self.try_download_single(url, &output_path, progress_reporter.as_ref()).await {
264 Ok(()) => {
265 if let Some(ref reporter) = progress_reporter {
266 reporter.finish();
267 }
268 info!("单线程下载完成: {}", output_path.as_ref().display());
269 return Ok(());
270 }
271 Err(e) => {
272 warn!("下载尝试 {}/{} 失败: {}", attempt, self.config.max_retries, e);
273 if attempt < self.config.max_retries {
274 tokio::time::sleep(Duration::from_millis(self.config.retry_delay_ms)).await;
275 } else {
276 return Err(e);
277 }
278 }
279 }
280 }
281 unreachable!()
282 }
283 async fn try_download_single<P: AsRef<Path>>(
285 &self,
286 url: &str,
287 output_path: P,
288 progress_reporter: Option<&ProgressReporter>,
289 ) -> DownloadResult<()> {
290 use futures::stream::StreamExt;
291
292 let output_path = output_path.as_ref();
293
294 let temp_path = output_path.with_extension(match output_path.extension() {
296 Some(ext) => format!("{}.tmp", ext.to_string_lossy()),
297 None => "tmp".to_string(),
298 });
299
300 debug!("下载到临时文件: {}", temp_path.display());
301
302 let response = self.client.get(url).send().await?;
303
304 if !response.status().is_success() {
305 return Err(DownloadError::Other(format!(
306 "Server returned error status: {}",
307 response.status()
308 )));
309 }
310
311 if let Some(parent) = temp_path.parent() {
313 tokio::fs::create_dir_all(parent).await?;
314 }
315
316 let mut file = File::create(&temp_path).await?;
317 let mut downloaded: u64 = 0;
318 let mut stream = response.bytes_stream();
319
320 let download_result = async {
322 while let Some(chunk_result) = stream.next().await {
323 let chunk = chunk_result?;
324 file.write_all(&chunk).await?;
325
326 downloaded += chunk.len() as u64;
327 if let Some(reporter) = progress_reporter {
328 reporter.update(downloaded);
329 }
330 }
331
332 file.flush().await?;
333 file.sync_all().await?; Ok::<(), DownloadError>(())
335 }
336 .await;
337
338 match download_result {
340 Ok(()) => {
341 debug!(
343 "下载完成,重命名文件: {} -> {}",
344 temp_path.display(),
345 output_path.display()
346 );
347 tokio::fs::rename(&temp_path, output_path).await?;
348 info!("文件下载并重命名成功: {}", output_path.display());
349 Ok(())
350 }
351 Err(e) => {
352 warn!("下载失败,清理临时文件: {}", temp_path.display());
354 let _ = tokio::fs::remove_file(&temp_path).await; Err(e)
356 }
357 }
358 }
359 #[allow(clippy::too_many_lines)]
361 async fn download_chunked<P: AsRef<Path>>(
362 &self,
363 url: &str,
364 output_path: P,
365 file_size: u64,
366 progress_reporter: Option<ProgressReporter>,
367 ) -> DownloadResult<()> {
368 use std::cmp::min;
369 use std::sync::Arc;
370 use tokio::fs::OpenOptions;
371 use tokio::sync::Mutex;
372
373 let output_path = output_path.as_ref();
374
375 let temp_path = output_path.with_extension(match output_path.extension() {
377 Some(ext) => format!("{}.tmp", ext.to_string_lossy()),
378 None => "tmp".to_string(),
379 });
380
381 debug!("分片下载到临时文件: {}", temp_path.display());
382
383 if let Some(parent) = temp_path.parent() {
385 tokio::fs::create_dir_all(parent).await?;
386 }
387
388 let file = Arc::new(Mutex::new(
390 OpenOptions::new().create(true).write(true).truncate(true).open(&temp_path).await?,
391 ));
392
393 {
395 #[allow(unused_mut)]
396 let mut file_guard = file.lock().await;
397 file_guard.set_len(file_size).await?;
398 }
399
400 let chunk_size = std::cmp::max(
402 file_size / self.config.concurrent_connections as u64,
403 self.config.min_chunk_size,
404 );
405
406 let mut chunks = Vec::new();
407 let mut start = 0;
408
409 while start < file_size {
410 let end = min(start + chunk_size - 1, file_size - 1);
411 chunks.push((start, end));
412 start = end + 1;
413 }
414
415 info!("分片下载: {} 个分片,每片大约 {} 字节", chunks.len(), chunk_size);
416
417 let progress_counter = Arc::new(Mutex::new(0u64));
419
420 let mut handles = Vec::new();
422 let semaphore = Arc::new(tokio::sync::Semaphore::new(self.config.concurrent_connections));
423
424 for (chunk_start, chunk_end) in chunks {
425 let client = self.client.clone();
426 let url = url.to_string();
427 let file = Arc::clone(&file);
428 let progress_counter = Arc::clone(&progress_counter);
429 let progress_reporter = progress_reporter.clone();
430 let semaphore = Arc::clone(&semaphore);
431 let max_retries = self.config.max_retries;
432 let retry_delay = self.config.retry_delay_ms;
433
434 let handle = tokio::spawn(async move {
435 let _permit = semaphore.acquire().await.unwrap();
436
437 for attempt in 1..=max_retries {
438 match Self::download_chunk(
439 &client,
440 &url,
441 chunk_start,
442 chunk_end,
443 Arc::clone(&file),
444 Arc::clone(&progress_counter),
445 progress_reporter.as_ref(),
446 )
447 .await
448 {
449 Ok(()) => return Ok(()),
450 Err(e) => {
451 warn!(
452 "分片 {chunk_start}-{chunk_end} 下载尝试 {attempt}/{max_retries} 失败: {e}"
453 );
454 if attempt < max_retries {
455 tokio::time::sleep(Duration::from_millis(retry_delay)).await;
456 } else {
457 return Err(e);
458 }
459 }
460 }
461 }
462 unreachable!()
463 });
464
465 handles.push(handle);
466 }
467
468 let download_result = async {
470 for handle in handles {
471 handle.await.map_err(|e| DownloadError::Other(format!("任务执行错误: {e}")))??;
472 }
473 Ok::<(), DownloadError>(())
474 }
475 .await;
476
477 match download_result {
479 Ok(()) => {
480 {
482 let mut file_guard = file.lock().await;
483 file_guard.flush().await?;
484 file_guard.sync_all().await?;
485 }
486
487 if let Some(ref reporter) = progress_reporter {
488 reporter.finish();
489 }
490
491 debug!(
493 "分片下载完成,重命名文件: {} -> {}",
494 temp_path.display(),
495 output_path.display()
496 );
497 tokio::fs::rename(&temp_path, output_path).await?;
498 info!("分片文件下载并重命名成功: {}", output_path.display());
499 Ok(())
500 }
501 Err(e) => {
502 warn!("分片下载失败,清理临时文件: {}", temp_path.display());
504 let _ = tokio::fs::remove_file(&temp_path).await; Err(e)
506 }
507 }
508 }
509 async fn download_chunk(
511 client: &Client,
512 url: &str,
513 start: u64,
514 end: u64,
515 file: Arc<Mutex<File>>,
516 progress_counter: Arc<Mutex<u64>>,
517 progress_reporter: Option<&ProgressReporter>,
518 ) -> DownloadResult<()> {
519 use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
520
521 debug!("下载分片: {start}-{end}");
522
523 let range_header = format!("bytes={start}-{end}");
524 let response = client.get(url).header("Range", range_header).send().await?;
525
526 if !response.status().is_success() && response.status().as_u16() != 206 {
527 return Err(DownloadError::ChunkDownloadFailed(format!(
528 "分片下载失败,状态码: {}",
529 response.status()
530 )));
531 }
532
533 let chunk_data = response.bytes().await?;
534
535 {
537 let mut file_guard = file.lock().await;
538 file_guard.seek(SeekFrom::Start(start)).await?;
539 file_guard.write_all(&chunk_data).await?;
540 file_guard.flush().await?;
541 }
542
543 {
545 let mut counter = progress_counter.lock().await;
546 *counter += chunk_data.len() as u64;
547 if let Some(reporter) = progress_reporter {
548 reporter.update(*counter);
549 }
550 }
551
552 debug!("分片 {start}-{end} 下载完成");
553 Ok(())
554 }
555
556 pub async fn get_file_size(&self, url: &str) -> DownloadResult<u64> {
566 let (file_size, _) = self.get_file_info(url).await?;
567 Ok(file_size)
568 }
569}
570
571impl Default for Downloader {
572 fn default() -> Self {
573 Self::new()
574 }
575}