// uranium_rs/downloaders/gen_downloader.rs

1use std::fs::create_dir_all;
2use std::sync::Arc;
3use std::{
4    collections::VecDeque,
5    path::{Path, PathBuf},
6};
7
8use futures::{future::join_all, StreamExt};
9use log::{error, info};
10use reqwest::Response;
11use sha1::Digest;
12use tokio::sync::{OwnedSemaphorePermit, Semaphore};
13use tokio::{io::AsyncWriteExt, task::JoinHandle};
14
15use crate::error::Result;
16use crate::{code_functions::N_THREADS, error::UraniumError};
17
18/// Download files asynchronously.
19///
20/// This trait allows the user to make their own `FileDownloader` and use it
21/// with the different downloader such us:
22/// - `MinecraftDownloader`
23/// - `RinthDownloader`
24/// - `CurseDownloader`
25#[allow(async_fn_in_trait)]
26pub trait FileDownloader {
27    /// Builds a new struct from a vec of `DownlodableObject`s.
28    fn new(files: Vec<DownloadableObject>) -> Self;
29
30    /// This method is responsible for managing the progress of downloads and
31    /// tasks in the Uranium library.
32    ///
33    /// It returns the current `DownloadState`, which represents the state of
34    /// the download process.
35    ///
36    /// If there are pending `DownlodableObject` and the number of active tasks
37    /// is less than the maximum allowed threads, this method will make
38    /// additional requests to fetch data.
39    ///
40    /// If there are active tasks, it will check their status and handle
41    /// completed tasks accordingly.
42    ///
43    /// # Errors
44    ///
45    /// This method can return an error of type `UraniumError` in the following
46    /// cases:
47    ///
48    /// - If there is an error while making requests or processing tasks.
49    ///
50    /// # Returns
51    ///
52    /// This method returns a `Result<DownloadState, UraniumError>`, where
53    /// `DownloadState` represents the current state of the download process,
54    /// and `UraniumError` is the error type that occurs in case of failure.
55    async fn progress(&mut self) -> Result<DownloadState>;
56
57    /// This method calls `Self::progress()` repeatedly until it returns
58    /// `DownloadState::Completed`.
59    ///
60    /// # Errors
61    ///
62    /// This method can return an error of type `UraniumError` in the following
63    /// cases:
64    ///
65    /// - If there is an error while making requests or processing tasks.
66    ///
67    /// # Returns
68    ///
69    /// This method returns a `Result<(), UraniumError>`.
70    async fn complete(&mut self) -> Result<()> {
71        loop {
72            match self.progress().await {
73                Err(e) => return Err(e),
74                Ok(DownloadState::Completed) => return Ok(()),
75                Ok(_) => {}
76            }
77        }
78    }
79
80    /// Return how many requests are left.
81    ///
82    /// This method is important when it comes to know the % of the
83    /// already downloaded files.
84    fn requests_left(&self) -> usize;
85
86    /// Return how many requests the downloader has.
87    fn len(&self) -> usize;
88
89    /// Adds a single `DownloadableObject` to the downloader's queue.
90    ///
91    /// This method allows you to dynamically add new download tasks to an
92    /// existing downloader instance. The object will be queued for download
93    /// and processed according to the downloader's scheduling logic.
94    fn add_object(&mut self, obj: DownloadableObject);
95
96    /// Adds multiple `DownloadableObject`s to the downloader's queue.
97    ///
98    /// This is a convenience method that accepts any iterator of
99    /// `DownloadableObject`s and adds them all to the download queue.
100    /// Internally, it calls `add_object` for each item in the iterator.
101    fn add_objects<T>(&mut self, objs: T)
102    where
103        T: IntoIterator<Item = DownloadableObject>,
104    {
105        objs.into_iter()
106            .for_each(|f| self.add_object(f));
107    }
108
109    /// Returns `true` if the downloader has no downloadable objects.
110    fn is_empty(&self) -> bool {
111        self.len() == 0
112    }
113}
114
/// Indicates the state of the downloader.
///
/// This is the value reported by `FileDownloader::progress`.
#[derive(Debug)]
pub enum DownloadState {
    /// New download requests are being scheduled.
    MakingRequests,
    /// A download task was handled and more work may remain.
    Downloading,
    /// There is nothing left to do.
    Completed,
}
122
// TODO: support more hash algorithms (the original note read "Sha5",
// presumably meaning SHA-512).
/// Hash algorithm, and expected hex digest, used to verify a file.
#[derive(Debug, Clone)]
pub enum HashType {
    /// Hex-encoded SHA-1 digest.
    Sha1(String),
}
129
130/// Simple struct with the necessary data to download a file
131///
132/// Fields:
133/// - url : http://somerandomurl.com
134/// - name: my_filename.whatever
135/// - path: /path/to/something/mods
136///
137/// The join between path and name MUST result in the final path e.g:
138///
139/// `name`: MyMinecraftMod.jar <br>
140/// `path`: /home/sergio/.minecraft/Fabric1.18/mods/
141#[derive(Debug, Clone)]
142pub struct DownloadableObject {
143    pub url: String,
144    pub path: PathBuf,
145    pub hash: Option<HashType>,
146}
147
148impl DownloadableObject {
149    pub fn new(url: &str, path: &Path, hash: Option<HashType>) -> Self {
150        Self {
151            url: url.to_owned(),
152            path: path.to_owned(),
153            hash,
154        }
155    }
156
157    pub fn name(&self) -> Option<&str> {
158        self.path
159            .file_name()
160            .and_then(|f| f.to_str())
161    }
162}
163
164/// Basic downloader
165///
166/// `Downloader` is a basic implementation of `FileDownloader` trait.
167///
168/// It uses `reqwest::Client` for the HTTP requests.
169pub struct Downloader {
170    files: Vec<DownloadableObject>,
171    requester: reqwest::Client,
172    start: usize,
173    s: Arc<Semaphore>,
174    tasks: VecDeque<JoinHandle<Result<()>>>,
175}
176
177impl FileDownloader for Downloader {
178    fn new(files: Vec<DownloadableObject>) -> Self {
179        let n_files = files.len();
180        info!("{n_files} files to download");
181        info!("{} available permits", N_THREADS());
182
183        let client = reqwest::ClientBuilder::new()
184            .build()
185            .expect("Error while creating the Downloader client, please report this error.");
186
187        Downloader {
188            files,
189            requester: client,
190            start: 0,
191            s: Arc::new(Semaphore::new(N_THREADS())),
192            tasks: VecDeque::with_capacity(n_files),
193        }
194    }
195
196    async fn progress(&mut self) -> Result<DownloadState> {
197        while self.start != self.files.len() && self.s.available_permits() > 0 {
198            self.make_requests().await?;
199        }
200
201        if !self.tasks.is_empty() {
202            let mut guard = true;
203            let mut i = 0;
204            while guard {
205                guard = false;
206                // SAFETY: There is no way this unwraps fails since we are
207                // iterating over the len of the queue and no other thread
208                // is modifying the queue, also the queue is not empty.
209                if self
210                    .tasks
211                    .get(i)
212                    .unwrap()
213                    .is_finished()
214                {
215                    let task = self.tasks.remove(i).unwrap();
216                    guard = true;
217                    match task.await? {
218                        Err(UraniumError::FilesDontMatch(objects)) => {
219                            error!("Trying again {} files", objects.len());
220                            self.files.extend(objects);
221                        }
222                        Err(e) => Err(e)?,
223                        Ok(_) => {}
224                    }
225                    break;
226                }
227
228                i = (i + 1) % self.tasks.len();
229            }
230
231            if guard {
232                return Ok(DownloadState::Downloading);
233            }
234
235            // In case no task is finished yet, we wait for the first one
236            if !self.tasks.is_empty() {
237                info!("Waiting the first one...");
238                // let _ = join_all(&mut self.tasks).await;
239                // self.tasks.clear();
240                // UNWRAP SAFETY: Can't be empty since we are checking.
241                match self
242                    .tasks
243                    .pop_front()
244                    .unwrap()
245                    .await?
246                {
247                    Err(UraniumError::FilesDontMatch(objects)) => self.files.extend(objects),
248                    Err(e) => Err(e)?,
249                    _ => {}
250                };
251                return Ok(DownloadState::Downloading);
252            }
253        }
254        Ok(DownloadState::Completed)
255    }
256
257    /// Returns how many requests are left.
258    fn requests_left(&self) -> usize {
259        self.files.len() - self.start + self.tasks.len()
260    }
261
262    /// Returns how many files are in the files vector. The already downloaded
263    /// files are also taking into account.
264    fn len(&self) -> usize {
265        self.files.len()
266    }
267
268    /// Add an object to the files vector.
269    fn add_object(&mut self, obj: DownloadableObject) {
270        self.files.push(obj);
271    }
272
273    fn add_objects<T>(&mut self, objs: T)
274    where
275        T: IntoIterator<Item = DownloadableObject>,
276    {
277        self.files.extend(objs);
278    }
279}
280
281impl Downloader {
282    /// Improved semaphore acquisition with proper error handling
283    async fn acquire_semaphore(&self) -> Result<OwnedSemaphorePermit> {
284        self.s
285            .clone()
286            .acquire_owned()
287            .await
288            .map_err(|e| UraniumError::other(&format!("Failed to acquire semaphore: {e}")))
289    }
290
291    async fn get_next_chunk(&mut self) -> Vec<DownloadableObject> {
292        const DEFAULT_CHUNK_SIZE: usize = 16;
293
294        let remaining = self.files.len() - self.start;
295        if remaining == 0 {
296            return vec![];
297        }
298
299        let chunk_size = DEFAULT_CHUNK_SIZE.min(remaining);
300        let end = self.start + chunk_size;
301
302        let mut objects = vec![];
303
304        loop {
305            if objects.len() >= DEFAULT_CHUNK_SIZE {
306                break;
307            }
308
309            if self.start == end {
310                break;
311            }
312
313            let obj = &self.files[self.start];
314
315            // Check if the file already exists so we can skit it.
316            if let Ok(true) = verify_file_hash(&obj.path, &obj.hash).await {
317                info!("Skipping {:?}, already exists", obj.path);
318            } else {
319                objects.push(obj.clone());
320            }
321            self.start += 1;
322        }
323        objects
324    }
325
326    async fn make_requests(&mut self) -> Result<DownloadState> {
327        let chunk = self.get_next_chunk().await;
328        if chunk.is_empty() {
329            return Ok(DownloadState::Completed);
330        }
331
332        let sem = self
333            .acquire_semaphore()
334            .await?;
335        let client = self.requester.clone();
336        let task = tokio::spawn(async move { download_and_write(chunk, client, sem).await });
337
338        info!("Pushing new task {}", self.start);
339        self.tasks.push_back(task);
340        Ok(DownloadState::MakingRequests)
341    }
342}
343
344async fn download_and_write(
345    objects: Vec<DownloadableObject>,
346    requester: reqwest::Client,
347    _sem: OwnedSemaphorePermit,
348) -> Result<()> {
349    let x = objects
350        .into_iter()
351        .map(|obj| async {
352            let response = match requester
353                .get(&obj.url)
354                .send()
355                .await
356            {
357                Ok(r) => r,
358                Err(e) => return Err(UraniumError::from(e)),
359            };
360
361            download_single_file(response, obj).await
362        });
363
364    let errors: Vec<DownloadableObject> = join_all(x)
365        .await
366        .into_iter()
367        .flat_map(|e| match e {
368            Err(UraniumError::FileNotMatch(obj)) => Some(obj),
369            Err(error) => {
370                error!("Error with the response: {}", error);
371                None
372            }
373            _ => None,
374        })
375        .collect();
376
377    if !errors.is_empty() {
378        return Err(UraniumError::FilesDontMatch(errors));
379    }
380
381    info!("Chunk wrote successfully!");
382    Ok(())
383}
384
385/// Verifies if a file matches its expected hash
386async fn verify_file_hash(path: &Path, expected_hash: &Option<HashType>) -> Result<bool> {
387    if !path.exists() {
388        return Ok(false);
389    }
390
391    let Some(HashType::Sha1(expected)) = expected_hash else {
392        return Ok(false);
393    };
394
395    let content = tokio::fs::read(path).await?;
396    let mut hasher = sha1::Sha1::new();
397    hasher.update(&content);
398    let actual = hex::encode(hasher.finalize());
399
400    Ok(&actual == expected)
401}
402
403async fn download_single_file(response: Response, obj: DownloadableObject) -> Result<()> {
404    if !response.status().is_success() {
405        return Err(UraniumError::other(&format!(
406            "Error with response, status {}",
407            response.status()
408        )));
409    }
410
411    let content_length = response
412        .content_length()
413        .map(|e| e as usize)
414        .unwrap_or_default();
415
416    let mut bytes_stream = response.bytes_stream();
417
418    if !obj.path.exists() {
419        create_dir_all(
420            obj.path
421                .parent()
422                .ok_or(UraniumError::OtherWithReason(format!(
423                    "Cant create {:?} path",
424                    obj.path
425                )))?,
426        )?;
427    }
428
429    let mut file = tokio::io::BufWriter::with_capacity(
430        1024 * 512,
431        tokio::fs::OpenOptions::new()
432            .create(true)
433            .write(true)
434            .truncate(true)
435            .open(&obj.path)
436            .await?,
437    );
438
439    let mut total = 0;
440    let mut hasher = sha1::Sha1::new();
441
442    while let Some(item) = bytes_stream.next().await {
443        let chunk = item?;
444        match file.write_all(&chunk).await {
445            Err(e) => {
446                error!("Can not write in {:?}: {}", obj.path, e);
447                return Err(e.into());
448            }
449            Ok(_) => total += chunk.len(),
450        };
451        hasher.update(chunk);
452    }
453    file.flush().await?;
454    let actual = hex::encode(hasher.finalize());
455
456    if total == content_length
457        && obj
458            .hash
459            .as_ref()
460            .is_none_or(|x| match x {
461                HashType::Sha1(h) => h == &actual,
462            })
463    {
464        Ok(())
465    } else {
466        error!("{:?}'s hash doesn't match!", &obj.path);
467        Err(UraniumError::FileNotMatch(obj))
468    }
469}