tiny_data/
download.rs

1use futures::future::join_all;
2use indicatif::ProgressBar;
3use reqwest::{IntoUrl, StatusCode};
4// use std::fs::File;
5// use std::io::prelude::*;
6use std::sync::{Arc, Mutex};
7use tokio::fs::File as AsyncFile;
8use tokio::io::AsyncWriteExt;
9
10pub struct Task {
11    pub downloader: Arc<Downloader>,
12    pub url: Option<String>, //so i can move it later
13}
14
15impl Task {
16    pub async fn download(&mut self, filename: String) -> Result<u8, Box<dyn std::error::Error>> {
17        let url = self.url.take().unwrap();
18        let downloader = self.downloader.clone();
19        let res = downloader.download(url, filename).await?;
20        Ok(res)
21    }
22}
23
24//shared state for each download task
25pub struct Downloader {
26    pub cur: Mutex<usize>,
27    pub progress_bar: Mutex<ProgressBar>,
28}
29
30impl Downloader {
31    pub fn new(progress_bar: Mutex<ProgressBar>) -> Self {
32        Self {
33            cur: Mutex::new(0),
34            progress_bar,
35        }
36    }
37
38    pub async fn download(
39        &self,
40        url: impl IntoUrl,
41        filename: String,
42    ) -> Result<u8, Box<dyn std::error::Error>> {
43        let res = reqwest::get(url).await?;
44        match res.status() {
45            StatusCode::OK => {
46                let bytes = res.bytes().await?;
47
48                // ad-hoc
49                if !bytes.starts_with(b"<!DOCTYPE html>") {
50                    let mut file = AsyncFile::create(filename).await?;
51                    file.write_all(&bytes).await?;
52                    // let mut file = File::create(filename)?;
53                    // file.write_all(&bytes)?;
54
55                    //interior mutability + async-safe lock access !
56                    {
57                        *self.cur.lock().unwrap() += 1;
58                        self.progress_bar.lock().unwrap().inc(1);
59                    }
60                }
61            }
62            _ => return Ok(0),
63        }
64
65        Ok(1)
66    }
67}
68
69pub struct DLManager {
70    pub target_size: usize,
71    pub downloader: Arc<Downloader>,
72}
73
74impl DLManager {
75    pub fn new(target_size: usize, progress_bar: ProgressBar) -> Self {
76        let downloader = Arc::new(Downloader::new(Mutex::new(progress_bar)));
77
78        DLManager {
79            target_size,
80            downloader,
81        }
82    }
83
84    //TODO: add upper limit on batch?
85    pub async fn download_batch<'a>(&mut self, batch: Vec<String>, dir: &'a str) -> u8 {
86        let cur = *self.downloader.cur.lock().unwrap();
87
88        // if cur == self.target_size { // used in old streaming-like approach
89        //     return true;
90        // }
91
92        // so we don't overflow on the quota
93        let how_many = usize::min(self.target_size - cur, batch.len());
94        // println!("taking {}", how_many);
95
96        let mut futures = vec![];
97        for (id, url) in batch.into_iter().take(how_many).enumerate() {
98            let id = id + cur;
99
100            let mut task = Task {
101                downloader: self.downloader.clone(),
102                url: Some(url.to_string()),
103            };
104
105            //now spawn the batch and await
106            let filename = format!("{}/{}.jpeg", dir, id);
107
108            futures.push(tokio::spawn(async move {
109                match task.download(filename).await {
110                    Ok(i) => i,
111                    _ => 0,
112                }
113            }));
114        }
115        let did_download = join_all(futures).await;
116        let total: u8 = did_download.into_iter().map(|res| res.unwrap()).sum();
117        total
118    }
119}