uranium_rs/downloaders/
gen_downloader.rs1use std::fs::create_dir_all;
2use std::sync::Arc;
3use std::{
4 collections::VecDeque,
5 path::{Path, PathBuf},
6};
7
8use futures::{future::join_all, StreamExt};
9use log::{error, info};
10use reqwest::Response;
11use sha1::Digest;
12use tokio::sync::{OwnedSemaphorePermit, Semaphore};
13use tokio::{io::AsyncWriteExt, task::JoinHandle};
14
15use crate::error::Result;
16use crate::{code_functions::N_THREADS, error::UraniumError};
17
18#[allow(async_fn_in_trait)]
26pub trait FileDownloader {
27 fn new(files: Vec<DownloadableObject>) -> Self;
29
30 async fn progress(&mut self) -> Result<DownloadState>;
56
57 async fn complete(&mut self) -> Result<()> {
71 loop {
72 match self.progress().await {
73 Err(e) => return Err(e),
74 Ok(DownloadState::Completed) => return Ok(()),
75 Ok(_) => {}
76 }
77 }
78 }
79
80 fn requests_left(&self) -> usize;
85
86 fn len(&self) -> usize;
88
89 fn add_object(&mut self, obj: DownloadableObject);
95
96 fn add_objects<T>(&mut self, objs: T)
102 where
103 T: IntoIterator<Item = DownloadableObject>,
104 {
105 objs.into_iter()
106 .for_each(|f| self.add_object(f));
107 }
108
109 fn is_empty(&self) -> bool {
111 self.len() == 0
112 }
113}
114
/// Phase reported by a [`FileDownloader`] after each call to `progress`.
///
/// The enum is a plain, field-less tag, so it is cheap to copy and can be
/// compared directly (`state == DownloadState::Completed`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DownloadState {
    /// A new batch of requests has just been dispatched.
    MakingRequests,
    /// At least one download task was still pending when the step ended.
    Downloading,
    /// Nothing is left to request or to wait for.
    Completed,
}
122
/// Expected digest of a downloaded file, tagged by algorithm.
#[derive(Debug, Clone)]
pub enum HashType {
    /// SHA-1 digest written as a hexadecimal string; it is compared against
    /// the hex encoding of the digest computed from the file contents.
    Sha1(String),
}
129
130#[derive(Debug, Clone)]
142pub struct DownloadableObject {
143 pub url: String,
144 pub path: PathBuf,
145 pub hash: Option<HashType>,
146}
147
148impl DownloadableObject {
149 pub fn new(url: &str, path: &Path, hash: Option<HashType>) -> Self {
150 Self {
151 url: url.to_owned(),
152 path: path.to_owned(),
153 hash,
154 }
155 }
156
157 pub fn name(&self) -> Option<&str> {
158 self.path
159 .file_name()
160 .and_then(|f| f.to_str())
161 }
162}
163
164pub struct Downloader {
170 files: Vec<DownloadableObject>,
171 requester: reqwest::Client,
172 start: usize,
173 s: Arc<Semaphore>,
174 tasks: VecDeque<JoinHandle<Result<()>>>,
175}
176
177impl FileDownloader for Downloader {
178 fn new(files: Vec<DownloadableObject>) -> Self {
179 let n_files = files.len();
180 info!("{n_files} files to download");
181 info!("{} available permits", N_THREADS());
182
183 let client = reqwest::ClientBuilder::new()
184 .build()
185 .expect("Error while creating the Downloader client, please report this error.");
186
187 Downloader {
188 files,
189 requester: client,
190 start: 0,
191 s: Arc::new(Semaphore::new(N_THREADS())),
192 tasks: VecDeque::with_capacity(n_files),
193 }
194 }
195
196 async fn progress(&mut self) -> Result<DownloadState> {
197 while self.start != self.files.len() && self.s.available_permits() > 0 {
198 self.make_requests().await?;
199 }
200
201 if !self.tasks.is_empty() {
202 let mut guard = true;
203 let mut i = 0;
204 while guard {
205 guard = false;
206 if self
210 .tasks
211 .get(i)
212 .unwrap()
213 .is_finished()
214 {
215 let task = self.tasks.remove(i).unwrap();
216 guard = true;
217 match task.await? {
218 Err(UraniumError::FilesDontMatch(objects)) => {
219 error!("Trying again {} files", objects.len());
220 self.files.extend(objects);
221 }
222 Err(e) => Err(e)?,
223 Ok(_) => {}
224 }
225 break;
226 }
227
228 i = (i + 1) % self.tasks.len();
229 }
230
231 if guard {
232 return Ok(DownloadState::Downloading);
233 }
234
235 if !self.tasks.is_empty() {
237 info!("Waiting the first one...");
238 match self
242 .tasks
243 .pop_front()
244 .unwrap()
245 .await?
246 {
247 Err(UraniumError::FilesDontMatch(objects)) => self.files.extend(objects),
248 Err(e) => Err(e)?,
249 _ => {}
250 };
251 return Ok(DownloadState::Downloading);
252 }
253 }
254 Ok(DownloadState::Completed)
255 }
256
257 fn requests_left(&self) -> usize {
259 self.files.len() - self.start + self.tasks.len()
260 }
261
262 fn len(&self) -> usize {
265 self.files.len()
266 }
267
268 fn add_object(&mut self, obj: DownloadableObject) {
270 self.files.push(obj);
271 }
272
273 fn add_objects<T>(&mut self, objs: T)
274 where
275 T: IntoIterator<Item = DownloadableObject>,
276 {
277 self.files.extend(objs);
278 }
279}
280
281impl Downloader {
282 async fn acquire_semaphore(&self) -> Result<OwnedSemaphorePermit> {
284 self.s
285 .clone()
286 .acquire_owned()
287 .await
288 .map_err(|e| UraniumError::other(&format!("Failed to acquire semaphore: {e}")))
289 }
290
291 async fn get_next_chunk(&mut self) -> Vec<DownloadableObject> {
292 const DEFAULT_CHUNK_SIZE: usize = 16;
293
294 let remaining = self.files.len() - self.start;
295 if remaining == 0 {
296 return vec![];
297 }
298
299 let chunk_size = DEFAULT_CHUNK_SIZE.min(remaining);
300 let end = self.start + chunk_size;
301
302 let mut objects = vec![];
303
304 loop {
305 if objects.len() >= DEFAULT_CHUNK_SIZE {
306 break;
307 }
308
309 if self.start == end {
310 break;
311 }
312
313 let obj = &self.files[self.start];
314
315 if let Ok(true) = verify_file_hash(&obj.path, &obj.hash).await {
317 info!("Skipping {:?}, already exists", obj.path);
318 } else {
319 objects.push(obj.clone());
320 }
321 self.start += 1;
322 }
323 objects
324 }
325
326 async fn make_requests(&mut self) -> Result<DownloadState> {
327 let chunk = self.get_next_chunk().await;
328 if chunk.is_empty() {
329 return Ok(DownloadState::Completed);
330 }
331
332 let sem = self
333 .acquire_semaphore()
334 .await?;
335 let client = self.requester.clone();
336 let task = tokio::spawn(async move { download_and_write(chunk, client, sem).await });
337
338 info!("Pushing new task {}", self.start);
339 self.tasks.push_back(task);
340 Ok(DownloadState::MakingRequests)
341 }
342}
343
344async fn download_and_write(
345 objects: Vec<DownloadableObject>,
346 requester: reqwest::Client,
347 _sem: OwnedSemaphorePermit,
348) -> Result<()> {
349 let x = objects
350 .into_iter()
351 .map(|obj| async {
352 let response = match requester
353 .get(&obj.url)
354 .send()
355 .await
356 {
357 Ok(r) => r,
358 Err(e) => return Err(UraniumError::from(e)),
359 };
360
361 download_single_file(response, obj).await
362 });
363
364 let errors: Vec<DownloadableObject> = join_all(x)
365 .await
366 .into_iter()
367 .flat_map(|e| match e {
368 Err(UraniumError::FileNotMatch(obj)) => Some(obj),
369 Err(error) => {
370 error!("Error with the response: {}", error);
371 None
372 }
373 _ => None,
374 })
375 .collect();
376
377 if !errors.is_empty() {
378 return Err(UraniumError::FilesDontMatch(errors));
379 }
380
381 info!("Chunk wrote successfully!");
382 Ok(())
383}
384
385async fn verify_file_hash(path: &Path, expected_hash: &Option<HashType>) -> Result<bool> {
387 if !path.exists() {
388 return Ok(false);
389 }
390
391 let Some(HashType::Sha1(expected)) = expected_hash else {
392 return Ok(false);
393 };
394
395 let content = tokio::fs::read(path).await?;
396 let mut hasher = sha1::Sha1::new();
397 hasher.update(&content);
398 let actual = hex::encode(hasher.finalize());
399
400 Ok(&actual == expected)
401}
402
403async fn download_single_file(response: Response, obj: DownloadableObject) -> Result<()> {
404 if !response.status().is_success() {
405 return Err(UraniumError::other(&format!(
406 "Error with response, status {}",
407 response.status()
408 )));
409 }
410
411 let content_length = response
412 .content_length()
413 .map(|e| e as usize)
414 .unwrap_or_default();
415
416 let mut bytes_stream = response.bytes_stream();
417
418 if !obj.path.exists() {
419 create_dir_all(
420 obj.path
421 .parent()
422 .ok_or(UraniumError::OtherWithReason(format!(
423 "Cant create {:?} path",
424 obj.path
425 )))?,
426 )?;
427 }
428
429 let mut file = tokio::io::BufWriter::with_capacity(
430 1024 * 512,
431 tokio::fs::OpenOptions::new()
432 .create(true)
433 .write(true)
434 .truncate(true)
435 .open(&obj.path)
436 .await?,
437 );
438
439 let mut total = 0;
440 let mut hasher = sha1::Sha1::new();
441
442 while let Some(item) = bytes_stream.next().await {
443 let chunk = item?;
444 match file.write_all(&chunk).await {
445 Err(e) => {
446 error!("Can not write in {:?}: {}", obj.path, e);
447 return Err(e.into());
448 }
449 Ok(_) => total += chunk.len(),
450 };
451 hasher.update(chunk);
452 }
453 file.flush().await?;
454 let actual = hex::encode(hasher.finalize());
455
456 if total == content_length
457 && obj
458 .hash
459 .as_ref()
460 .is_none_or(|x| match x {
461 HashType::Sha1(h) => h == &actual,
462 })
463 {
464 Ok(())
465 } else {
466 error!("{:?}'s hash doesn't match!", &obj.path);
467 Err(UraniumError::FileNotMatch(obj))
468 }
469}