tidepool_version_manager/
downloader.rs

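//! Built-in downloader module.
//!
//! Downloader integrated into version-manager that provides file downloads, including
//! multi-connection ranged (chunked) downloads. Resuming interrupted downloads is not
//! supported: every attempt starts over in a freshly truncated temporary file.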
5
6use indicatif::{ProgressBar, ProgressStyle};
7use log::{debug, info, warn};
8use reqwest::Client;
9use std::path::Path;
10use std::sync::Arc;
11use std::time::Duration;
12use thiserror::Error;
13use tokio::fs::File;
14use tokio::io::AsyncWriteExt;
15use tokio::sync::Mutex;
16
17/// 下载器错误类型
18#[derive(Error, Debug)]
19pub enum DownloadError {
20    #[error("网络错误: {0}")]
21    Network(#[from] reqwest::Error),
22
23    #[error("IO错误: {0}")]
24    Io(#[from] std::io::Error),
25
26    #[error("无法获取文件大小")]
27    FileSizeUnavailable,
28
29    #[error("服务器不支持范围请求")]
30    RangeNotSupported,
31
32    #[error("分片下载失败: {0}")]
33    ChunkDownloadFailed(String),
34
35    #[error("其他错误: {0}")]
36    Other(String),
37}
38
39/// 下载结果类型
40pub type DownloadResult<T> = Result<T, DownloadError>;
41
42/// 进度报告器
43#[derive(Debug, Clone)]
44pub struct ProgressReporter {
45    /// 主进度条
46    progress_bar: ProgressBar,
47}
48
49impl ProgressReporter {
50    /// 创建新的进度报告器
51    ///
52    /// # Panics
53    ///
54    /// Panics if the progress bar template is invalid (which should not happen with the predefined template)
55    #[must_use]
56    pub fn new(total_size: u64) -> Self {
57        let progress_bar = ProgressBar::new(total_size);
58        progress_bar.set_style(
59            ProgressStyle::default_bar()
60                .template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")
61                .unwrap()
62                .progress_chars("#>-"),
63        );
64
65        Self { progress_bar }
66    }
67
68    /// 开始下载
69    pub fn start(&self) {
70        self.progress_bar.tick();
71    }
72
73    /// 更新进度
74    pub fn update(&self, bytes_downloaded: u64) {
75        self.progress_bar.set_position(bytes_downloaded);
76    }
77
78    /// 增加进度
79    pub fn increment(&self, bytes: u64) {
80        self.progress_bar.inc(bytes);
81    }
82
83    /// 完成下载
84    pub fn finish(&self) {
85        self.progress_bar.finish_with_message("下载完成");
86    }
87
88    /// 设置长度
89    pub fn set_length(&self, length: u64) {
90        self.progress_bar.set_length(length);
91    }
92    /// 获取长度
93    #[must_use]
94    pub fn length(&self) -> Option<u64> {
95        Some(self.progress_bar.length().unwrap_or(0))
96    }
97
98    /// 设置消息
99    pub fn set_message(&self, message: &str) {
100        self.progress_bar.set_message(message.to_string());
101    }
102}
103
/// Tunable settings for the downloader.
///
/// Timeouts are expressed in whole seconds and the retry delay in milliseconds.
#[derive(Debug, Clone)]
pub struct DownloadConfig {
    /// `User-Agent` header value; `None` leaves the HTTP client's default in place.
    pub user_agent: Option<String>,
    /// Overall request timeout, in seconds.
    pub timeout_seconds: u64,
    /// Connection-establishment timeout, in seconds.
    pub connect_timeout_seconds: u64,
    /// Number of parallel connections used for chunked downloads.
    pub concurrent_connections: usize,
    /// Lower bound on the size of each chunk, in bytes.
    pub min_chunk_size: u64,
    /// Maximum number of attempts per download (or per chunk) before giving up.
    pub max_retries: usize,
    /// Pause between attempts, in milliseconds.
    pub retry_delay_ms: u64,
    /// Whether ranged, multi-connection downloads may be used at all.
    pub enable_chunked_download: bool,
}

impl Default for DownloadConfig {
    /// The defaults are:
    /// - 120 s request timeout and 30 s connect timeout
    /// - 4 connections with chunks of at least 1 MiB
    /// - 3 attempts spaced 1 s apart
    /// - chunked downloads enabled
    fn default() -> Self {
        const MIB: u64 = 1 << 20;

        let user_agent = Some(String::from("tidepool-version-manager/0.1.0"));

        Self {
            user_agent,
            timeout_seconds: 120,
            connect_timeout_seconds: 30,
            concurrent_connections: 4,
            min_chunk_size: MIB,
            max_retries: 3,
            retry_delay_ms: 1_000,
            enable_chunked_download: true,
        }
    }
}
139
140/// 简化的下载器
141#[derive(Debug, Clone)]
142pub struct Downloader {
143    /// HTTP客户端
144    client: Client,
145    /// 下载配置(保留以供将来扩展)
146    #[allow(dead_code)]
147    config: DownloadConfig,
148}
149
150impl Downloader {
151    /// 创建新的下载器
152    #[must_use]
153    pub fn new() -> Self {
154        Self::with_config(DownloadConfig::default())
155    }
156    /// 使用指定配置创建下载器
157    ///
158    /// # Panics
159    ///
160    /// Panics if the HTTP client cannot be created (should not happen with valid configuration)
161    #[must_use]
162    pub fn with_config(config: DownloadConfig) -> Self {
163        let mut client_builder = Client::builder()
164            .timeout(Duration::from_secs(config.timeout_seconds))
165            .connect_timeout(Duration::from_secs(config.connect_timeout_seconds))
166            .tcp_keepalive(Duration::from_secs(60))
167            .pool_idle_timeout(Duration::from_secs(90))
168            .pool_max_idle_per_host(config.concurrent_connections);
169
170        if let Some(user_agent) = &config.user_agent {
171            client_builder = client_builder.user_agent(user_agent);
172        }
173
174        let client = client_builder.build().expect("Failed to create HTTP client");
175
176        Self { client, config }
177    }
178    /// 下载文件(支持分片下载和多线程)
179    ///
180    /// # Errors
181    ///
182    /// Returns an error if:
183    /// - Network request fails
184    /// - File I/O operations fail
185    /// - Download validation fails
186    /// - Server does not support range requests when chunked download is attempted
187    pub async fn download<P: AsRef<Path>>(
188        &self,
189        url: &str,
190        output_path: P,
191        progress_reporter: Option<ProgressReporter>,
192    ) -> DownloadResult<()> {
193        let output_path = output_path.as_ref();
194
195        debug!("开始下载: {} -> {}", url, output_path.display());
196
197        // 获取文件大小和检查是否支持范围请求
198        let (file_size, supports_ranges) = self.get_file_info(url).await?;
199
200        // 创建或更新进度报告器
201        let reporter = if let Some(reporter) = progress_reporter {
202            reporter.set_length(file_size);
203            Some(reporter)
204        } else {
205            Some(ProgressReporter::new(file_size))
206        };
207
208        if let Some(ref reporter) = reporter {
209            reporter.start();
210        }
211
212        // 决定使用分片下载还是单线程下载
213        let use_chunked = self.config.enable_chunked_download
214            && supports_ranges
215            && file_size > self.config.min_chunk_size
216            && self.config.concurrent_connections > 1;
217
218        if use_chunked {
219            info!("Using chunked download mode, file size: {file_size} bytes");
220            self.download_chunked(url, output_path, file_size, reporter).await
221        } else {
222            info!("Using single-threaded download mode, file size: {file_size} bytes");
223            self.download_single(url, output_path, reporter).await
224        }
225    }
226
227    /// 获取文件信息(大小和是否支持范围请求)
228    async fn get_file_info(&self, url: &str) -> DownloadResult<(u64, bool)> {
229        debug!("Getting file info: {url}");
230
231        let response = self.client.head(url).send().await?;
232
233        if !response.status().is_success() {
234            return Err(DownloadError::Other(format!(
235                "Server returned error status: {}",
236                response.status()
237            )));
238        }
239
240        let file_size = response
241            .headers()
242            .get("content-length")
243            .and_then(|v| v.to_str().ok())
244            .and_then(|s| s.parse::<u64>().ok())
245            .ok_or(DownloadError::FileSizeUnavailable)?;
246        let supports_ranges = response
247            .headers()
248            .get("accept-ranges")
249            .is_some_and(|v| v.to_str().unwrap_or("").to_lowercase() == "bytes");
250
251        debug!("File size: {file_size} bytes, supports ranges: {supports_ranges}");
252        Ok((file_size, supports_ranges))
253    }
254
255    /// 单线程下载
256    async fn download_single<P: AsRef<Path>>(
257        &self,
258        url: &str,
259        output_path: P,
260        progress_reporter: Option<ProgressReporter>,
261    ) -> DownloadResult<()> {
262        for attempt in 1..=self.config.max_retries {
263            match self.try_download_single(url, &output_path, progress_reporter.as_ref()).await {
264                Ok(()) => {
265                    if let Some(ref reporter) = progress_reporter {
266                        reporter.finish();
267                    }
268                    info!("单线程下载完成: {}", output_path.as_ref().display());
269                    return Ok(());
270                }
271                Err(e) => {
272                    warn!("下载尝试 {}/{} 失败: {}", attempt, self.config.max_retries, e);
273                    if attempt < self.config.max_retries {
274                        tokio::time::sleep(Duration::from_millis(self.config.retry_delay_ms)).await;
275                    } else {
276                        return Err(e);
277                    }
278                }
279            }
280        }
281        unreachable!()
282    }
283    /// 单次下载尝试
284    async fn try_download_single<P: AsRef<Path>>(
285        &self,
286        url: &str,
287        output_path: P,
288        progress_reporter: Option<&ProgressReporter>,
289    ) -> DownloadResult<()> {
290        use futures::stream::StreamExt;
291
292        let output_path = output_path.as_ref();
293
294        // 创建临时文件路径,添加 .tmp 后缀
295        let temp_path = output_path.with_extension(match output_path.extension() {
296            Some(ext) => format!("{}.tmp", ext.to_string_lossy()),
297            None => "tmp".to_string(),
298        });
299
300        debug!("下载到临时文件: {}", temp_path.display());
301
302        let response = self.client.get(url).send().await?;
303
304        if !response.status().is_success() {
305            return Err(DownloadError::Other(format!(
306                "Server returned error status: {}",
307                response.status()
308            )));
309        }
310
311        // 确保父目录存在
312        if let Some(parent) = temp_path.parent() {
313            tokio::fs::create_dir_all(parent).await?;
314        }
315
316        let mut file = File::create(&temp_path).await?;
317        let mut downloaded: u64 = 0;
318        let mut stream = response.bytes_stream();
319
320        // 下载过程中如果出错,确保清理临时文件
321        let download_result = async {
322            while let Some(chunk_result) = stream.next().await {
323                let chunk = chunk_result?;
324                file.write_all(&chunk).await?;
325
326                downloaded += chunk.len() as u64;
327                if let Some(reporter) = progress_reporter {
328                    reporter.update(downloaded);
329                }
330            }
331
332            file.flush().await?;
333            file.sync_all().await?; // 确保数据写入磁盘
334            Ok::<(), DownloadError>(())
335        }
336        .await;
337
338        // 处理下载结果
339        match download_result {
340            Ok(()) => {
341                // 下载成功,将临时文件重命名为目标文件
342                debug!(
343                    "下载完成,重命名文件: {} -> {}",
344                    temp_path.display(),
345                    output_path.display()
346                );
347                tokio::fs::rename(&temp_path, output_path).await?;
348                info!("文件下载并重命名成功: {}", output_path.display());
349                Ok(())
350            }
351            Err(e) => {
352                // 下载失败,清理临时文件
353                warn!("下载失败,清理临时文件: {}", temp_path.display());
354                let _ = tokio::fs::remove_file(&temp_path).await; // 忽略删除错误
355                Err(e)
356            }
357        }
358    }
359    /// 分片下载
360    #[allow(clippy::too_many_lines)]
361    async fn download_chunked<P: AsRef<Path>>(
362        &self,
363        url: &str,
364        output_path: P,
365        file_size: u64,
366        progress_reporter: Option<ProgressReporter>,
367    ) -> DownloadResult<()> {
368        use std::cmp::min;
369        use std::sync::Arc;
370        use tokio::fs::OpenOptions;
371        use tokio::sync::Mutex;
372
373        let output_path = output_path.as_ref();
374
375        // 创建临时文件路径,添加 .tmp 后缀
376        let temp_path = output_path.with_extension(match output_path.extension() {
377            Some(ext) => format!("{}.tmp", ext.to_string_lossy()),
378            None => "tmp".to_string(),
379        });
380
381        debug!("分片下载到临时文件: {}", temp_path.display());
382
383        // 确保父目录存在
384        if let Some(parent) = temp_path.parent() {
385            tokio::fs::create_dir_all(parent).await?;
386        }
387
388        // 创建临时文件
389        let file = Arc::new(Mutex::new(
390            OpenOptions::new().create(true).write(true).truncate(true).open(&temp_path).await?,
391        ));
392
393        // 预分配文件空间
394        {
395            #[allow(unused_mut)]
396            let mut file_guard = file.lock().await;
397            file_guard.set_len(file_size).await?;
398        }
399
400        // 计算分片大小和数量
401        let chunk_size = std::cmp::max(
402            file_size / self.config.concurrent_connections as u64,
403            self.config.min_chunk_size,
404        );
405
406        let mut chunks = Vec::new();
407        let mut start = 0;
408
409        while start < file_size {
410            let end = min(start + chunk_size - 1, file_size - 1);
411            chunks.push((start, end));
412            start = end + 1;
413        }
414
415        info!("分片下载: {} 个分片,每片大约 {} 字节", chunks.len(), chunk_size);
416
417        // 共享进度计数器
418        let progress_counter = Arc::new(Mutex::new(0u64));
419
420        // 启动下载任务
421        let mut handles = Vec::new();
422        let semaphore = Arc::new(tokio::sync::Semaphore::new(self.config.concurrent_connections));
423
424        for (chunk_start, chunk_end) in chunks {
425            let client = self.client.clone();
426            let url = url.to_string();
427            let file = Arc::clone(&file);
428            let progress_counter = Arc::clone(&progress_counter);
429            let progress_reporter = progress_reporter.clone();
430            let semaphore = Arc::clone(&semaphore);
431            let max_retries = self.config.max_retries;
432            let retry_delay = self.config.retry_delay_ms;
433
434            let handle = tokio::spawn(async move {
435                let _permit = semaphore.acquire().await.unwrap();
436
437                for attempt in 1..=max_retries {
438                    match Self::download_chunk(
439                        &client,
440                        &url,
441                        chunk_start,
442                        chunk_end,
443                        Arc::clone(&file),
444                        Arc::clone(&progress_counter),
445                        progress_reporter.as_ref(),
446                    )
447                    .await
448                    {
449                        Ok(()) => return Ok(()),
450                        Err(e) => {
451                            warn!(
452                                "分片 {chunk_start}-{chunk_end} 下载尝试 {attempt}/{max_retries} 失败: {e}"
453                            );
454                            if attempt < max_retries {
455                                tokio::time::sleep(Duration::from_millis(retry_delay)).await;
456                            } else {
457                                return Err(e);
458                            }
459                        }
460                    }
461                }
462                unreachable!()
463            });
464
465            handles.push(handle);
466        }
467
468        // 等待所有分片下载完成
469        let download_result = async {
470            for handle in handles {
471                handle.await.map_err(|e| DownloadError::Other(format!("任务执行错误: {e}")))??;
472            }
473            Ok::<(), DownloadError>(())
474        }
475        .await;
476
477        // 处理下载结果
478        match download_result {
479            Ok(()) => {
480                // 确保文件数据写入磁盘
481                {
482                    let mut file_guard = file.lock().await;
483                    file_guard.flush().await?;
484                    file_guard.sync_all().await?;
485                }
486
487                if let Some(ref reporter) = progress_reporter {
488                    reporter.finish();
489                }
490
491                // 下载成功,将临时文件重命名为目标文件
492                debug!(
493                    "分片下载完成,重命名文件: {} -> {}",
494                    temp_path.display(),
495                    output_path.display()
496                );
497                tokio::fs::rename(&temp_path, output_path).await?;
498                info!("分片文件下载并重命名成功: {}", output_path.display());
499                Ok(())
500            }
501            Err(e) => {
502                // 下载失败,清理临时文件
503                warn!("分片下载失败,清理临时文件: {}", temp_path.display());
504                let _ = tokio::fs::remove_file(&temp_path).await; // 忽略删除错误
505                Err(e)
506            }
507        }
508    }
509    /// 下载单个分片
510    async fn download_chunk(
511        client: &Client,
512        url: &str,
513        start: u64,
514        end: u64,
515        file: Arc<Mutex<File>>,
516        progress_counter: Arc<Mutex<u64>>,
517        progress_reporter: Option<&ProgressReporter>,
518    ) -> DownloadResult<()> {
519        use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
520
521        debug!("下载分片: {start}-{end}");
522
523        let range_header = format!("bytes={start}-{end}");
524        let response = client.get(url).header("Range", range_header).send().await?;
525
526        if !response.status().is_success() && response.status().as_u16() != 206 {
527            return Err(DownloadError::ChunkDownloadFailed(format!(
528                "分片下载失败,状态码: {}",
529                response.status()
530            )));
531        }
532
533        let chunk_data = response.bytes().await?;
534
535        // 写入文件
536        {
537            let mut file_guard = file.lock().await;
538            file_guard.seek(SeekFrom::Start(start)).await?;
539            file_guard.write_all(&chunk_data).await?;
540            file_guard.flush().await?;
541        }
542
543        // 更新进度
544        {
545            let mut counter = progress_counter.lock().await;
546            *counter += chunk_data.len() as u64;
547            if let Some(reporter) = progress_reporter {
548                reporter.update(*counter);
549            }
550        }
551
552        debug!("分片 {start}-{end} 下载完成");
553        Ok(())
554    }
555
556    /// 获取文件大小
557    /// 获取文件大小
558    ///
559    /// # Errors
560    ///
561    /// Returns an error if:
562    /// - Network request fails
563    /// - Server does not provide content-length header
564    /// - Content-length value is invalid
565    pub async fn get_file_size(&self, url: &str) -> DownloadResult<u64> {
566        let (file_size, _) = self.get_file_info(url).await?;
567        Ok(file_size)
568    }
569}
570
571impl Default for Downloader {
572    fn default() -> Self {
573        Self::new()
574    }
575}