1use std::{path::Path, sync::Arc};
2
3use reqwest_cookie_store::CookieStoreMutex;
4use reqwest_middleware::ClientWithMiddleware;
5use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
6use tokio::{fs::File, io::AsyncWriteExt, try_join};
7
8use crate::dash::Representation;
9
10pub struct HttpClient {
11 pub client: ClientWithMiddleware,
12 pub cookies: Arc<CookieStoreMutex>,
13}
14
15#[derive(thiserror::Error, Debug)]
16pub enum DownloadError {
17 #[error("reqwest error: {0}")]
18 ReqwestError(#[from] reqwest::Error),
19 #[error("reqwest middleware error: {0}")]
20 ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
21 #[error("io error: {0}")]
22 IoError(#[from] std::io::Error),
23}
24
25impl HttpClient {
26 pub fn new() -> reqwest::Result<HttpClient> {
27 let cookies = Arc::new(CookieStoreMutex::default());
28 let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
29
30 let client = reqwest::Client::builder()
31 .cookie_provider(cookies.clone())
32 .build()?;
33
34 let client = reqwest_middleware::ClientBuilder::new(client)
35 .with(RetryTransientMiddleware::new_with_policy(retry_policy))
36 .build();
37
38 Ok(HttpClient { client, cookies })
39 }
40
41 pub async fn download_file(&self, url: &str, path: &str) -> Result<usize, DownloadError> {
42 let temp_path = format!("{}.tmp", path);
43 let mut file = File::create(&temp_path).await?;
44 let mut resp = self.client.get(url).send().await?;
45 let mut size = 0;
46
47 while let Some(chunk) = resp.chunk().await? {
48 file.write_all(&chunk).await?;
49 size += chunk.len();
50 }
51
52 file.flush().await?;
53 std::fs::rename(temp_path, path)?;
54
55 Ok(size)
56 }
57
58 pub async fn fetch_text(&self, url: &str) -> Result<String, DownloadError> {
59 self.client
60 .get(url)
61 .send()
62 .await?
63 .text()
64 .await
65 .map_err(|e| e.into())
66 }
67}
68
69pub async fn download_av_segment(
70 client: &HttpClient,
71 outdir: &Path,
72 audio: &Representation,
73 video: &Representation,
74 seq: i64,
75) -> Result<(String, String, usize), DownloadError> {
76 let (url_audio, url_video) = (audio.get_url(seq), video.get_url(seq));
77 let (fname_audio, fname_video) = (
78 format!("seq_{:.6}.a{}.mp4", seq, audio.id),
79 format!("seq_{:.6}.v{}.mp4", seq, video.id),
80 );
81
82 let dl_audio = async {
83 let path_audio = outdir.join(&fname_audio);
84 if let Ok(res) = tokio::fs::try_exists(&path_audio).await {
85 if res {
86 return Ok(0);
87 }
88 }
89
90 client
91 .download_file(&url_audio, &path_audio.to_string_lossy())
92 .await
93 };
94 let dl_video = async {
95 let path_video = outdir.join(&fname_video);
96 if let Ok(res) = tokio::fs::try_exists(&path_video).await {
97 if res {
98 return Ok(0);
99 }
100 }
101
102 client
103 .download_file(&url_video, &path_video.to_string_lossy())
104 .await
105 };
106 let (sz_audio, sz_video) = try_join!(dl_audio, dl_video)?;
107
108 Ok((fname_audio, fname_video, sz_audio + sz_video))
109}
110
/// Formats a byte count as a human-readable string with binary (1024-based)
/// units and two decimal places, e.g. `1536` -> `"1.50 KiB"`.
///
/// A value is promoted to the next unit once it reaches 1024, so `1024` is
/// shown as `"1.00 KiB"`. The largest unit is TiB; anything larger is shown
/// as a multiple of TiB.
pub fn format_bytes(bytes: u64) -> String {
    const UNITS: [&str; 5] = ["B", "KiB", "MiB", "GiB", "TiB"];

    // Precision loss for counts above 2^53 is irrelevant at two displayed decimals.
    let mut value = bytes as f64;
    let mut unit = 0;

    while value >= 1024.0 && unit + 1 < UNITS.len() {
        value /= 1024.0;
        unit += 1;
    }

    format!("{:.2} {}", value, UNITS[unit])
}