test_assets_ureq/
lib.rs

1/*!
2Download test assets, managing them outside of git
3
4This library downloads test assets using http(s),
5and ensures integrity by comparing those assets to a hash.
By managing the download separately, you can keep the assets
out of VCS so that they don't bloat your repository.
8
9Usage example:
10
11```rust, no_run
12#[test]
13fn some_awesome_test() {
14    let asset_defs = [
15        TestAssetDef {
16            filepath : format!("file_a.png"),
17            hash : format!("<sha256 here>"),
18            url : format!("https://url/to/a.png"),
19        },
20        TestAssetDef {
21            filepath : format!("subdir/file_b.png"),
22            hash : format!("<sha256 here>"),
23            url : format!("https://url/to/b.png"),
24        },
25    ];
    test_assets_ureq::dl_test_files(&asset_defs, "test-assets").unwrap();
27    // use your files here
28    // with path under test-assets/file_a.png and test-assets/subdir/file_b.png
29}
30```
31
32Optionally, a `toml` can also be used.
33
```toml
35[test_assets.test_00]
36filepath = "out.squashfs"
37hash = "976c1638d8c1ba8014de6c64b196cbd70a5acf031be10a8e7f649536193c8e78"
38url = "https://wcampbell.dev/squashfs/testing/test_00/out.squashfs"
39```
40```rust,no_run
41use test_assets_ureq::{TestAsset, dl_test_files_backoff};
42use std::fs;
43use std::time::Duration;
44
45let file_content = fs::read_to_string("test.toml").unwrap();
46let parsed: TestAsset = toml::de::from_str(&file_content).unwrap();
47let assets = parsed.values();
48dl_test_files_backoff(&assets, "test-assets", Duration::from_secs(1)).unwrap();
49```
50
51If you have run the test once, it will re-use the files
52instead of re-downloading them.
53*/
54
55mod hash_list;
56
57use backon::BlockingRetryable;
58use backon::ExponentialBuilder;
59use hash_list::HashList;
60use rayon::prelude::*;
61use serde::Deserialize;
62use sha2::digest::Digest;
63use sha2::Sha256;
64use std::collections::HashSet;
65use std::fs::{create_dir_all, File};
66use std::io::{self, Read, Write};
67use std::sync::{Arc, Mutex};
68use std::time::Duration;
69use ureq::Agent;
70
/// Deserialized form of an asset description document.
///
/// The document is expected to contain a `test_assets` table whose entries
/// each describe one [`TestAssetDef`].
#[derive(Debug, Deserialize)]
pub struct TestAsset {
    /// Asset definitions keyed by their entry name in the `test_assets` table.
    #[serde(rename = "test_assets")]
    pub assets: std::collections::BTreeMap<String, TestAssetDef>,
}
76
77impl TestAsset {
78    #[must_use]
79    pub fn values(&self) -> Vec<TestAssetDef> {
80        self.assets.values().cloned().collect()
81    }
82}
83
/// Definition for a test file
///
/// Describes where a single asset is fetched from, where it is stored and
/// which SHA-256 hash its contents must have.
#[derive(Debug, Deserialize, Clone)]
pub struct TestAssetDef {
    /// Path of the file on disk relative to the output directory. Can include subdirectories.
    pub filepath: String,
    /// Sha256 hash of the file's data in hexadecimal lowercase representation
    pub hash: String,
    /// The url the test file can be obtained from
    pub url: String,
}
94
95impl TestAssetDef {
96    /// Get the filename (last component of the filepath)
97    #[must_use]
98    pub fn filename(&self) -> &str {
99        std::path::Path::new(&self.filepath)
100            .file_name()
101            .and_then(|s| s.to_str())
102            .unwrap_or(&self.filepath)
103    }
104}
105
/// A type for a Sha256 hash value
///
/// Provides conversion functionality to hex representation and back.
/// The wrapped array holds the 32 raw digest bytes.
#[derive(PartialEq, Eq, Hash, Clone)]
pub struct Sha256Hash([u8; 32]);
111
112impl Sha256Hash {
113    #[must_use]
114    pub fn from_digest(sha: Sha256) -> Self {
115        let sha = sha.finalize();
116        let bytes = sha[..].try_into().unwrap();
117        Self(bytes)
118    }
119
120    /// Converts the hexadecimal string to a hash value
121    fn from_hex(s: &str) -> Result<Self, ()> {
122        let mut res = Self([0; 32]);
123        let mut idx = 0;
124        let mut iter = s.chars();
125        loop {
126            let upper = match iter.next().and_then(|c| c.to_digit(16)) {
127                Some(v) => v as u8,
128                None => return Err(()),
129            };
130            let lower = match iter.next().and_then(|c| c.to_digit(16)) {
131                Some(v) => v as u8,
132                None => return Err(()),
133            };
134            res.0[idx] = (upper << 4) | lower;
135            idx += 1;
136            if idx == 32 {
137                break;
138            }
139        }
140        Ok(res)
141    }
142    /// Converts the hash value to hexadecimal
143    #[must_use]
144    pub fn to_hex(&self) -> String {
145        let mut res = String::with_capacity(64);
146        for v in &self.0 {
147            use std::char::from_digit;
148            res.push(from_digit(u32::from(*v) >> 4, 16).unwrap());
149            res.push(from_digit(u32::from(*v) & 15, 16).unwrap());
150        }
151        res
152    }
153}
154
/// Errors that can occur while fetching or verifying test assets.
#[derive(Debug)]
pub enum TaError {
    /// An I/O operation failed.
    Io(io::Error),
    /// The asset could not be downloaded.
    DownloadFailed,
    /// The downloaded data did not have the expected hash.
    ///
    /// The first field is the hash that was found, the second the hash that
    /// was expected (both in hexadecimal).
    HashMismatch(String, String),
    /// The expected hash of an asset definition is not a valid SHA-256 hex string.
    BadHashFormat,
}

impl std::fmt::Display for TaError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(err) => write!(f, "I/O error: {err}"),
            Self::DownloadFailed => f.write_str("download failed"),
            Self::HashMismatch(found, expected) => {
                write!(f, "hash mismatch: found {found}, expected {expected}")
            }
            Self::BadHashFormat => f.write_str("bad hash format"),
        }
    }
}

impl std::error::Error for TaError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(err) => Some(err),
            Self::DownloadFailed | Self::HashMismatch(..) | Self::BadHashFormat => None,
        }
    }
}
162
163impl From<io::Error> for TaError {
164    fn from(err: io::Error) -> Self {
165        Self::Io(err)
166    }
167}
168
/// Result of a completed download.
enum DownloadOutcome {
    /// The download finished; carries the SHA-256 hash of the received data.
    WithHash(Sha256Hash),
}
172
/// Callbacks for download progress and status updates
///
/// All callbacks must be `Send + Sync` because downloads run in parallel and
/// the callbacks may be invoked from several threads.
pub struct ProgressCallbacks<'a> {
    /// Called with the asset's filepath when the file already on disk has the expected hash.
    pub sha_matched_fn: &'a (dyn Fn(&str) + Send + Sync),
    /// Called with the asset's filepath when the file on disk was hashed and
    /// did not have the expected hash.
    pub sha_not_matched_fn: &'a (dyn Fn(&str) + Send + Sync),
    /// Called with the asset's filepath after its download completed
    /// (before the downloaded hash is compared to the expected one).
    pub downloaded_fn: &'a (dyn Fn(&str) + Send + Sync),
    /// Called with the asset's filepath when its download failed; also receives
    /// the debug representation of a failed HTTP request.
    pub downloading_failed_fn: &'a (dyn Fn(&str) + Send + Sync),
    /// Called with a status message when no file needed to be downloaded.
    pub finished_fn: &'a (dyn Fn(&str) + Send + Sync),
    /// Called periodically during downloads with a human-readable message of
    /// the form `<downloaded> / <total> - <files in flight>`.
    pub progress_update_fn: &'a (dyn Fn(&str) + Send + Sync),
    /// Called after every download attempt with `(attempts finished, attempts total)`.
    pub download_progress_fn: &'a (dyn Fn(usize, usize) + Send + Sync),
}
183
/// Renders a byte count for display using binary units (1 KB = 1024 B).
///
/// Values of at least one KB/MB/GB are shown with two decimals in the
/// largest fitting unit; smaller values are shown as a plain byte count.
fn format_bytes(bytes: u64) -> String {
    // Largest unit first so the first threshold that fits wins.
    const UNITS: [(u64, &str); 3] = [(1 << 30, "GB"), (1 << 20, "MB"), (1 << 10, "KB")];

    match UNITS.iter().find(|(threshold, _)| bytes >= *threshold) {
        Some((threshold, unit)) => {
            format!("{:.2} {unit}", bytes as f64 / *threshold as f64)
        }
        None => format!("{bytes} B"),
    }
}
199
/// State and callbacks shared between concurrently running downloads.
struct DownloadContext<'a> {
    /// Running total of bytes received across all downloads.
    bytes_downloaded: &'a Arc<Mutex<u64>>,
    /// Total number of bytes to download, shown in progress messages.
    total_size: u64,
    /// Filepaths of the assets that are currently being downloaded.
    downloading: &'a Arc<Mutex<HashSet<String>>>,
    /// Receives the debug representation of a failed HTTP request.
    println_fn: &'a (dyn Fn(&str) + Send + Sync),
    /// Receives the periodic progress message.
    update_progress_fn: &'a (dyn Fn(&str) + Send + Sync),
}
207
208fn download_test_file(
209    agent: &mut Agent,
210    tfile: &TestAssetDef,
211    dir: &str,
212    context: &DownloadContext,
213) -> Result<DownloadOutcome, TaError> {
214    let resp = match agent.get(&tfile.url).call() {
215        Ok(resp) => resp,
216        Err(e) => {
217            (context.println_fn)(&format!("{e:?}"));
218            return Err(TaError::DownloadFailed);
219        }
220    };
221
222    let len: usize = resp.header("Content-Length").and_then(|s| s.parse().ok()).unwrap_or(0);
223
224    let mut bytes: Vec<u8> = Vec::with_capacity(len);
225    let mut reader = resp.into_reader().take(10_000_000_000);
226
227    let mut buffer = vec![0; 8192];
228    let mut bytes_since_update = 0u64;
229    loop {
230        let n = reader.read(&mut buffer)?;
231        if n == 0 {
232            break;
233        }
234        bytes.extend_from_slice(&buffer[..n]);
235
236        let mut downloaded = context.bytes_downloaded.lock().unwrap();
237        *downloaded += n as u64;
238        bytes_since_update += n as u64;
239
240        if bytes_since_update >= 262_144 {
241            bytes_since_update = 0;
242            let dl = context.downloading.lock().unwrap();
243            (context.update_progress_fn)(&format!(
244                "{} / {} - {}",
245                format_bytes(*downloaded),
246                format_bytes(context.total_size),
247                dl.iter().cloned().collect::<Vec<_>>().join(", ")
248            ));
249        }
250    }
251
252    let read_len = bytes.len();
253
254    if (bytes.len() != read_len) && (bytes.len() != len) {
255        return Err(TaError::DownloadFailed);
256    }
257
258    let filepath = format!("{}/{}", dir, tfile.filepath);
259    if let Some(parent) = std::path::Path::new(&filepath).parent() {
260        std::fs::create_dir_all(parent)?;
261    }
262    let file = File::create(&filepath)?;
263    let mut writer = io::BufWriter::new(file);
264    writer.write_all(&bytes)?;
265    writer.flush()?;
266
267    let mut hasher = Sha256::new();
268    hasher.update(&bytes);
269
270    Ok(DownloadOutcome::WithHash(Sha256Hash::from_digest(hasher)))
271}
272
273/// Downloads the test files into the passed directory with progress callbacks.
274pub fn dl_test_files_with_progress(
275    defs: &[TestAssetDef],
276    dir: &str,
277    callbacks: &ProgressCallbacks,
278) -> Result<(), TaError> {
279    use std::io::ErrorKind;
280
281    let hash_list_path = format!("{dir}/hash_list");
282    let hash_list = match HashList::from_file(&hash_list_path) {
283        Ok(l) => l,
284        Err(TaError::Io(ref e)) if e.kind() == ErrorKind::NotFound => HashList::new(),
285        e => {
286            e?;
287            unreachable!()
288        }
289    };
290    create_dir_all(dir)?;
291
292    let sha_matched_count = Arc::new(Mutex::new(0u64));
293
294    let files_to_download: Vec<_> = defs
295        .iter()
296        .filter(|tfile| {
297            let tfile_hash = match Sha256Hash::from_hex(&tfile.hash) {
298                Ok(h) => h,
299                Err(_) => {
300                    return true;
301                }
302            };
303
304            let filepath = format!("{}/{}", dir, tfile.filepath);
305
306            if hash_list.get_hash(&tfile.filepath) == Some(&tfile_hash) {
307                match File::open(&filepath) {
308                    Ok(mut file) => {
309                        let mut hasher = Sha256::new();
310                        let mut buffer = vec![0; 8192];
311                        loop {
312                            match file.read(&mut buffer) {
313                                Ok(0) => break,
314                                Ok(n) => hasher.update(&buffer[..n]),
315                                Err(_e) => {
316                                    return true; // Error reading, download it
317                                }
318                            }
319                        }
320                        let file_hash = Sha256Hash::from_digest(hasher);
321                        if file_hash == tfile_hash {
322                            *sha_matched_count.lock().unwrap() += 1;
323                            (callbacks.sha_matched_fn)(&tfile.filepath);
324                            return false;
325                        }
326                        (callbacks.sha_not_matched_fn)(&tfile.filepath);
327                    }
328                    Err(_e) => {}
329                }
330            }
331            true
332        })
333        .collect();
334
335    if files_to_download.is_empty() {
336        (callbacks.finished_fn)("All files SHA matched");
337        return Ok(());
338    }
339
340    let total_size: u64 = files_to_download
341        .iter()
342        .filter_map(|tfile| {
343            let agent = ureq::agent();
344            agent
345                .head(&tfile.url)
346                .call()
347                .ok()
348                .and_then(|resp| resp.header("Content-Length").map(|s| s.to_string()))
349                .and_then(|len| len.parse::<u64>().ok())
350        })
351        .sum();
352
353    let hash_list = Arc::new(Mutex::new(hash_list));
354    let downloading = Arc::new(Mutex::new(HashSet::new()));
355    let bytes_downloaded = Arc::new(Mutex::new(0u64));
356    let downloads_completed = Arc::new(Mutex::new(0usize));
357    let total_to_download = files_to_download.len();
358
359    let results: Vec<_> = files_to_download
360        .par_iter()
361        .map(|tfile| {
362            let mut agent = ureq::agent();
363            let tfile_hash =
364                Sha256Hash::from_hex(&tfile.hash).map_err(|_| TaError::BadHashFormat)?;
365
366            let mut dl = downloading.lock().unwrap();
367            dl.insert(tfile.filepath.clone());
368            drop(dl);
369
370            let println_fn = |msg: &str| {
371                (callbacks.downloading_failed_fn)(msg);
372            };
373
374            let update_progress_fn_local = |msg: &str| {
375                (callbacks.progress_update_fn)(msg);
376            };
377
378            let context = DownloadContext {
379                bytes_downloaded: &bytes_downloaded,
380                total_size,
381                downloading: &downloading,
382                println_fn: &println_fn,
383                update_progress_fn: &update_progress_fn_local,
384            };
385
386            let outcome = download_test_file(&mut agent, tfile, dir, &context);
387
388            let mut dl = downloading.lock().unwrap();
389            dl.remove(&tfile.filepath);
390            drop(dl);
391
392            let outcome = match outcome {
393                Ok(o) => {
394                    (callbacks.downloaded_fn)(&tfile.filepath);
395                    let mut completed = downloads_completed.lock().unwrap();
396                    *completed += 1;
397                    (callbacks.download_progress_fn)(*completed, total_to_download);
398                    Ok(o)
399                }
400                Err(e) => {
401                    (callbacks.downloading_failed_fn)(&tfile.filepath);
402                    let mut completed = downloads_completed.lock().unwrap();
403                    *completed += 1;
404                    (callbacks.download_progress_fn)(*completed, total_to_download);
405                    Err(e)
406                }
407            };
408
409            let outcome = outcome?;
410
411            match outcome {
412                DownloadOutcome::WithHash(ref hash) => {
413                    let mut hash_list = hash_list.lock().unwrap();
414                    hash_list.add_entry(&tfile.filepath, hash);
415                }
416            }
417
418            match outcome {
419                DownloadOutcome::WithHash(ref found_hash) => {
420                    if found_hash == &tfile_hash {
421                        Ok(())
422                    } else {
423                        Err(TaError::HashMismatch(found_hash.to_hex(), tfile.hash.clone()))
424                    }
425                }
426            }
427        })
428        .collect();
429
430    for result in results {
431        result?;
432    }
433
434    let hash_list = match Arc::try_unwrap(hash_list) {
435        Ok(mutex) => match mutex.into_inner() {
436            Ok(list) => list,
437            Err(_) => panic!("Failed to unlock Mutex"),
438        },
439        Err(_) => panic!("Failed to unwrap Arc"),
440    };
441    hash_list.to_file(&hash_list_path)?;
442    Ok(())
443}
444
445/// Download test-assets with backoff retries and progress callbacks
446pub fn dl_test_files_backoff_with_progress(
447    assets_defs: &[TestAssetDef],
448    test_path: &str,
449    max_delay: Duration,
450    callbacks: &ProgressCallbacks,
451) -> Result<(), TaError> {
452    let strategy = ExponentialBuilder::default().with_max_delay(max_delay);
453
454    (|| dl_test_files_with_progress(assets_defs, test_path, callbacks))
455        .retry(strategy)
456        .call()
457        .unwrap();
458
459    Ok(())
460}
461
462/// Download test files
463pub fn dl_test_files(defs: &[TestAssetDef], dir: &str) -> Result<(), TaError> {
464    let callbacks = ProgressCallbacks {
465        sha_matched_fn: &|_| {},
466        sha_not_matched_fn: &|_| {},
467        downloaded_fn: &|_| {},
468        downloading_failed_fn: &|_| {},
469        finished_fn: &|_| {},
470        progress_update_fn: &|_| {},
471        download_progress_fn: &|_, _| {},
472    };
473    dl_test_files_with_progress(defs, dir, &callbacks)
474}
475
476/// Download test files with backoff retries
477pub fn dl_test_files_backoff(
478    defs: &[TestAssetDef],
479    dir: &str,
480    max_delay: Duration,
481) -> Result<(), TaError> {
482    let callbacks = ProgressCallbacks {
483        sha_matched_fn: &|_| {},
484        sha_not_matched_fn: &|_| {},
485        downloaded_fn: &|_| {},
486        downloading_failed_fn: &|_| {},
487        finished_fn: &|_| {},
488        progress_update_fn: &|_| {},
489        download_progress_fn: &|_, _| {},
490    };
491    dl_test_files_backoff_with_progress(defs, dir, max_delay, &callbacks)
492}