yta_rs/
util.rs

1use std::{path::Path, sync::Arc};
2
3use reqwest_cookie_store::CookieStoreMutex;
4use reqwest_middleware::ClientWithMiddleware;
5use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
6use tokio::{fs::File, io::AsyncWriteExt, try_join};
7
8use crate::dash::Representation;
9
10pub struct HttpClient {
11    pub client: ClientWithMiddleware,
12    pub cookies: Arc<CookieStoreMutex>,
13}
14
15#[derive(thiserror::Error, Debug)]
16pub enum DownloadError {
17    #[error("reqwest error: {0}")]
18    ReqwestError(#[from] reqwest::Error),
19    #[error("reqwest middleware error: {0}")]
20    ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
21    #[error("io error: {0}")]
22    IoError(#[from] std::io::Error),
23}
24
25impl HttpClient {
26    pub fn new() -> reqwest::Result<HttpClient> {
27        let cookies = Arc::new(CookieStoreMutex::default());
28        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
29
30        let client = reqwest::Client::builder()
31            .cookie_provider(cookies.clone())
32            .build()?;
33
34        let client = reqwest_middleware::ClientBuilder::new(client)
35            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
36            .build();
37
38        Ok(HttpClient { client, cookies })
39    }
40
41    pub async fn download_file(&self, url: &str, path: &str) -> Result<usize, DownloadError> {
42        let temp_path = format!("{}.tmp", path);
43        let mut file = File::create(&temp_path).await?;
44        let mut resp = self.client.get(url).send().await?;
45        let mut size = 0;
46
47        while let Some(chunk) = resp.chunk().await? {
48            file.write_all(&chunk).await?;
49            size += chunk.len();
50        }
51
52        file.flush().await?;
53        std::fs::rename(temp_path, path)?;
54
55        Ok(size)
56    }
57
58    pub async fn fetch_text(&self, url: &str) -> Result<String, DownloadError> {
59        self.client
60            .get(url)
61            .send()
62            .await?
63            .text()
64            .await
65            .map_err(|e| e.into())
66    }
67}
68
69pub async fn download_av_segment(
70    client: &HttpClient,
71    outdir: &Path,
72    audio: &Representation,
73    video: &Representation,
74    seq: i64,
75) -> Result<(String, String, usize), DownloadError> {
76    let (url_audio, url_video) = (audio.get_url(seq), video.get_url(seq));
77    let (fname_audio, fname_video) = (
78        format!("seq_{:.6}.a{}.mp4", seq, audio.id),
79        format!("seq_{:.6}.v{}.mp4", seq, video.id),
80    );
81
82    let dl_audio = async {
83        let path_audio = outdir.join(&fname_audio);
84        if let Ok(res) = tokio::fs::try_exists(&path_audio).await {
85            if res {
86                return Ok(0);
87            }
88        }
89
90        client
91            .download_file(&url_audio, &path_audio.to_string_lossy())
92            .await
93    };
94    let dl_video = async {
95        let path_video = outdir.join(&fname_video);
96        if let Ok(res) = tokio::fs::try_exists(&path_video).await {
97            if res {
98                return Ok(0);
99            }
100        }
101
102        client
103            .download_file(&url_video, &path_video.to_string_lossy())
104            .await
105    };
106    let (sz_audio, sz_video) = try_join!(dl_audio, dl_video)?;
107
108    Ok((fname_audio, fname_video, sz_audio + sz_video))
109}
110
/// Formats a byte count with binary (1024-based) units and two decimals.
///
/// A value moves to the next unit once it reaches 1024 of the current one, so
/// `1024` renders as `"1.00 KiB"` and `1536` as `"1.50 KiB"`. TiB is the largest
/// unit; larger values are shown as a TiB count above 1024.
pub fn format_bytes(bytes: u64) -> String {
    const UNITS: [&str; 5] = ["B", "KiB", "MiB", "GiB", "TiB"];

    // Lossy for huge values, but this is only for display.
    let mut value = bytes as f64;
    let mut unit = 0;
    // `>=` so exact powers of 1024 are promoted (1024 B -> 1.00 KiB).
    while value >= 1024.0 && unit + 1 < UNITS.len() {
        value /= 1024.0;
        unit += 1;
    }

    format!("{:.2} {}", value, UNITS[unit])
}