1mod hash_list;
56
57use backon::BlockingRetryable;
58use backon::ExponentialBuilder;
59use hash_list::HashList;
60use rayon::prelude::*;
61use serde::Deserialize;
62use sha2::digest::Digest;
63use sha2::Sha256;
64use std::collections::HashSet;
65use std::fs::{create_dir_all, File};
66use std::io::{self, Read, Write};
67use std::sync::{Arc, Mutex};
68use std::time::Duration;
69use ureq::Agent;
70
/// Top-level description of a set of test assets, as deserialized through serde.
#[derive(Debug, Deserialize)]
pub struct TestAsset {
    /// Asset definitions, read from the `test_assets` key of the serialized input.
    ///
    /// NOTE(review): the meaning of the map keys is not visible in this file
    /// (presumably a per-asset name) — confirm against the input format.
    #[serde(rename = "test_assets")]
    pub assets: std::collections::BTreeMap<String, TestAssetDef>,
}
76
77impl TestAsset {
78 #[must_use]
79 pub fn values(&self) -> Vec<TestAssetDef> {
80 self.assets.values().cloned().collect()
81 }
82}
83
/// Definition of a single downloadable test asset.
#[derive(Debug, Deserialize, Clone)]
pub struct TestAssetDef {
    /// Path of the asset file; it is appended to the download directory when the asset is stored.
    pub filepath: String,
    /// Expected SHA-256 of the asset, written as a hexadecimal string.
    pub hash: String,
    /// URL the asset is fetched from.
    pub url: String,
}
94
95impl TestAssetDef {
96 #[must_use]
98 pub fn filename(&self) -> &str {
99 std::path::Path::new(&self.filepath)
100 .file_name()
101 .and_then(|s| s.to_str())
102 .unwrap_or(&self.filepath)
103 }
104}
105
/// A SHA-256 digest stored as its 32 raw bytes.
#[derive(PartialEq, Eq, Hash, Clone)]
pub struct Sha256Hash([u8; 32]);
111
112impl Sha256Hash {
113 #[must_use]
114 pub fn from_digest(sha: Sha256) -> Self {
115 let sha = sha.finalize();
116 let bytes = sha[..].try_into().unwrap();
117 Self(bytes)
118 }
119
120 fn from_hex(s: &str) -> Result<Self, ()> {
122 let mut res = Self([0; 32]);
123 let mut idx = 0;
124 let mut iter = s.chars();
125 loop {
126 let upper = match iter.next().and_then(|c| c.to_digit(16)) {
127 Some(v) => v as u8,
128 None => return Err(()),
129 };
130 let lower = match iter.next().and_then(|c| c.to_digit(16)) {
131 Some(v) => v as u8,
132 None => return Err(()),
133 };
134 res.0[idx] = (upper << 4) | lower;
135 idx += 1;
136 if idx == 32 {
137 break;
138 }
139 }
140 Ok(res)
141 }
142 #[must_use]
144 pub fn to_hex(&self) -> String {
145 let mut res = String::with_capacity(64);
146 for v in &self.0 {
147 use std::char::from_digit;
148 res.push(from_digit(u32::from(*v) >> 4, 16).unwrap());
149 res.push(from_digit(u32::from(*v) & 15, 16).unwrap());
150 }
151 res
152 }
153}
154
/// Errors reported while checking or downloading test assets.
#[derive(Debug)]
pub enum TaError {
    /// An underlying I/O operation failed.
    Io(io::Error),
    /// Downloading an asset failed.
    DownloadFailed,
    /// The downloaded data did not hash to the expected value.
    ///
    /// The first field is the hex hash that was found, the second the expected
    /// hash string from the asset definition.
    HashMismatch(String, String),
    /// The expected hash of an asset definition could not be parsed as hex.
    BadHashFormat,
}
162
163impl From<io::Error> for TaError {
164 fn from(err: io::Error) -> Self {
165 Self::Io(err)
166 }
167}
168
/// Result of a download that completed without error.
enum DownloadOutcome {
    /// The SHA-256 computed over the downloaded bytes.
    WithHash(Sha256Hash),
}
172
/// Callbacks through which the download functions report progress.
///
/// Every callback must be `Send + Sync` because downloads run on multiple threads.
pub struct ProgressCallbacks<'a> {
    /// Called with an asset's filepath when the file already on disk has the expected hash.
    pub sha_matched_fn: &'a (dyn Fn(&str) + Send + Sync),
    /// Called with an asset's filepath when the hash list entry matched but the
    /// file on disk hashes to a different value.
    pub sha_not_matched_fn: &'a (dyn Fn(&str) + Send + Sync),
    /// Called with an asset's filepath after it was downloaded without error.
    pub downloaded_fn: &'a (dyn Fn(&str) + Send + Sync),
    /// Called with the debug-formatted request error when the HTTP request fails,
    /// and with the asset's filepath when its download returns an error.
    pub downloading_failed_fn: &'a (dyn Fn(&str) + Send + Sync),
    /// Called with a summary message when no file needs to be downloaded.
    pub finished_fn: &'a (dyn Fn(&str) + Send + Sync),
    /// Called periodically during downloads with a message of the form
    /// `<downloaded> / <total> - <files currently downloading>`.
    pub progress_update_fn: &'a (dyn Fn(&str) + Send + Sync),
    /// Called after each download attempt with (attempts finished, total to download).
    pub download_progress_fn: &'a (dyn Fn(usize, usize) + Send + Sync),
}
183
/// Formats a byte count for display using binary units (1 KB = 1024 B).
///
/// Counts below 1 KB are printed exactly as `<n> B`; larger counts are printed
/// with two decimals in the largest unit (KB, MB or GB) that does not exceed them.
fn format_bytes(bytes: u64) -> String {
    // Largest unit first, so the first threshold reached is the best fit.
    const UNITS: [(u64, &str); 3] = [(1 << 30, "GB"), (1 << 20, "MB"), (1 << 10, "KB")];

    for &(threshold, unit) in UNITS.iter() {
        if bytes >= threshold {
            return format!("{:.2} {unit}", bytes as f64 / threshold as f64);
        }
    }
    format!("{bytes} B")
}
199
/// Shared state and reporting hooks handed to each individual download.
struct DownloadContext<'a> {
    /// Running total of bytes received, shared by all downloads.
    bytes_downloaded: &'a Arc<Mutex<u64>>,
    /// Total size shown in progress messages.
    total_size: u64,
    /// Filepaths of the assets that are being downloaded right now.
    downloading: &'a Arc<Mutex<HashSet<String>>>,
    /// Receives the debug-formatted error when the HTTP request fails.
    println_fn: &'a (dyn Fn(&str) + Send + Sync),
    /// Receives the periodic `<downloaded> / <total> - <files>` progress message.
    update_progress_fn: &'a (dyn Fn(&str) + Send + Sync),
}
207
208fn download_test_file(
209 agent: &mut Agent,
210 tfile: &TestAssetDef,
211 dir: &str,
212 context: &DownloadContext,
213) -> Result<DownloadOutcome, TaError> {
214 let resp = match agent.get(&tfile.url).call() {
215 Ok(resp) => resp,
216 Err(e) => {
217 (context.println_fn)(&format!("{e:?}"));
218 return Err(TaError::DownloadFailed);
219 }
220 };
221
222 let len: usize = resp.header("Content-Length").and_then(|s| s.parse().ok()).unwrap_or(0);
223
224 let mut bytes: Vec<u8> = Vec::with_capacity(len);
225 let mut reader = resp.into_reader().take(10_000_000_000);
226
227 let mut buffer = vec![0; 8192];
228 let mut bytes_since_update = 0u64;
229 loop {
230 let n = reader.read(&mut buffer)?;
231 if n == 0 {
232 break;
233 }
234 bytes.extend_from_slice(&buffer[..n]);
235
236 let mut downloaded = context.bytes_downloaded.lock().unwrap();
237 *downloaded += n as u64;
238 bytes_since_update += n as u64;
239
240 if bytes_since_update >= 262_144 {
241 bytes_since_update = 0;
242 let dl = context.downloading.lock().unwrap();
243 (context.update_progress_fn)(&format!(
244 "{} / {} - {}",
245 format_bytes(*downloaded),
246 format_bytes(context.total_size),
247 dl.iter().cloned().collect::<Vec<_>>().join(", ")
248 ));
249 }
250 }
251
252 let read_len = bytes.len();
253
254 if (bytes.len() != read_len) && (bytes.len() != len) {
255 return Err(TaError::DownloadFailed);
256 }
257
258 let filepath = format!("{}/{}", dir, tfile.filepath);
259 if let Some(parent) = std::path::Path::new(&filepath).parent() {
260 std::fs::create_dir_all(parent)?;
261 }
262 let file = File::create(&filepath)?;
263 let mut writer = io::BufWriter::new(file);
264 writer.write_all(&bytes)?;
265 writer.flush()?;
266
267 let mut hasher = Sha256::new();
268 hasher.update(&bytes);
269
270 Ok(DownloadOutcome::WithHash(Sha256Hash::from_digest(hasher)))
271}
272
273pub fn dl_test_files_with_progress(
275 defs: &[TestAssetDef],
276 dir: &str,
277 callbacks: &ProgressCallbacks,
278) -> Result<(), TaError> {
279 use std::io::ErrorKind;
280
281 let hash_list_path = format!("{dir}/hash_list");
282 let hash_list = match HashList::from_file(&hash_list_path) {
283 Ok(l) => l,
284 Err(TaError::Io(ref e)) if e.kind() == ErrorKind::NotFound => HashList::new(),
285 e => {
286 e?;
287 unreachable!()
288 }
289 };
290 create_dir_all(dir)?;
291
292 let sha_matched_count = Arc::new(Mutex::new(0u64));
293
294 let files_to_download: Vec<_> = defs
295 .iter()
296 .filter(|tfile| {
297 let tfile_hash = match Sha256Hash::from_hex(&tfile.hash) {
298 Ok(h) => h,
299 Err(_) => {
300 return true;
301 }
302 };
303
304 let filepath = format!("{}/{}", dir, tfile.filepath);
305
306 if hash_list.get_hash(&tfile.filepath) == Some(&tfile_hash) {
307 match File::open(&filepath) {
308 Ok(mut file) => {
309 let mut hasher = Sha256::new();
310 let mut buffer = vec![0; 8192];
311 loop {
312 match file.read(&mut buffer) {
313 Ok(0) => break,
314 Ok(n) => hasher.update(&buffer[..n]),
315 Err(_e) => {
316 return true; }
318 }
319 }
320 let file_hash = Sha256Hash::from_digest(hasher);
321 if file_hash == tfile_hash {
322 *sha_matched_count.lock().unwrap() += 1;
323 (callbacks.sha_matched_fn)(&tfile.filepath);
324 return false;
325 }
326 (callbacks.sha_not_matched_fn)(&tfile.filepath);
327 }
328 Err(_e) => {}
329 }
330 }
331 true
332 })
333 .collect();
334
335 if files_to_download.is_empty() {
336 (callbacks.finished_fn)("All files SHA matched");
337 return Ok(());
338 }
339
340 let total_size: u64 = files_to_download
341 .iter()
342 .filter_map(|tfile| {
343 let agent = ureq::agent();
344 agent
345 .head(&tfile.url)
346 .call()
347 .ok()
348 .and_then(|resp| resp.header("Content-Length").map(|s| s.to_string()))
349 .and_then(|len| len.parse::<u64>().ok())
350 })
351 .sum();
352
353 let hash_list = Arc::new(Mutex::new(hash_list));
354 let downloading = Arc::new(Mutex::new(HashSet::new()));
355 let bytes_downloaded = Arc::new(Mutex::new(0u64));
356 let downloads_completed = Arc::new(Mutex::new(0usize));
357 let total_to_download = files_to_download.len();
358
359 let results: Vec<_> = files_to_download
360 .par_iter()
361 .map(|tfile| {
362 let mut agent = ureq::agent();
363 let tfile_hash =
364 Sha256Hash::from_hex(&tfile.hash).map_err(|_| TaError::BadHashFormat)?;
365
366 let mut dl = downloading.lock().unwrap();
367 dl.insert(tfile.filepath.clone());
368 drop(dl);
369
370 let println_fn = |msg: &str| {
371 (callbacks.downloading_failed_fn)(msg);
372 };
373
374 let update_progress_fn_local = |msg: &str| {
375 (callbacks.progress_update_fn)(msg);
376 };
377
378 let context = DownloadContext {
379 bytes_downloaded: &bytes_downloaded,
380 total_size,
381 downloading: &downloading,
382 println_fn: &println_fn,
383 update_progress_fn: &update_progress_fn_local,
384 };
385
386 let outcome = download_test_file(&mut agent, tfile, dir, &context);
387
388 let mut dl = downloading.lock().unwrap();
389 dl.remove(&tfile.filepath);
390 drop(dl);
391
392 let outcome = match outcome {
393 Ok(o) => {
394 (callbacks.downloaded_fn)(&tfile.filepath);
395 let mut completed = downloads_completed.lock().unwrap();
396 *completed += 1;
397 (callbacks.download_progress_fn)(*completed, total_to_download);
398 Ok(o)
399 }
400 Err(e) => {
401 (callbacks.downloading_failed_fn)(&tfile.filepath);
402 let mut completed = downloads_completed.lock().unwrap();
403 *completed += 1;
404 (callbacks.download_progress_fn)(*completed, total_to_download);
405 Err(e)
406 }
407 };
408
409 let outcome = outcome?;
410
411 match outcome {
412 DownloadOutcome::WithHash(ref hash) => {
413 let mut hash_list = hash_list.lock().unwrap();
414 hash_list.add_entry(&tfile.filepath, hash);
415 }
416 }
417
418 match outcome {
419 DownloadOutcome::WithHash(ref found_hash) => {
420 if found_hash == &tfile_hash {
421 Ok(())
422 } else {
423 Err(TaError::HashMismatch(found_hash.to_hex(), tfile.hash.clone()))
424 }
425 }
426 }
427 })
428 .collect();
429
430 for result in results {
431 result?;
432 }
433
434 let hash_list = match Arc::try_unwrap(hash_list) {
435 Ok(mutex) => match mutex.into_inner() {
436 Ok(list) => list,
437 Err(_) => panic!("Failed to unlock Mutex"),
438 },
439 Err(_) => panic!("Failed to unwrap Arc"),
440 };
441 hash_list.to_file(&hash_list_path)?;
442 Ok(())
443}
444
445pub fn dl_test_files_backoff_with_progress(
447 assets_defs: &[TestAssetDef],
448 test_path: &str,
449 max_delay: Duration,
450 callbacks: &ProgressCallbacks,
451) -> Result<(), TaError> {
452 let strategy = ExponentialBuilder::default().with_max_delay(max_delay);
453
454 (|| dl_test_files_with_progress(assets_defs, test_path, callbacks))
455 .retry(strategy)
456 .call()
457 .unwrap();
458
459 Ok(())
460}
461
462pub fn dl_test_files(defs: &[TestAssetDef], dir: &str) -> Result<(), TaError> {
464 let callbacks = ProgressCallbacks {
465 sha_matched_fn: &|_| {},
466 sha_not_matched_fn: &|_| {},
467 downloaded_fn: &|_| {},
468 downloading_failed_fn: &|_| {},
469 finished_fn: &|_| {},
470 progress_update_fn: &|_| {},
471 download_progress_fn: &|_, _| {},
472 };
473 dl_test_files_with_progress(defs, dir, &callbacks)
474}
475
476pub fn dl_test_files_backoff(
478 defs: &[TestAssetDef],
479 dir: &str,
480 max_delay: Duration,
481) -> Result<(), TaError> {
482 let callbacks = ProgressCallbacks {
483 sha_matched_fn: &|_| {},
484 sha_not_matched_fn: &|_| {},
485 downloaded_fn: &|_| {},
486 downloading_failed_fn: &|_| {},
487 finished_fn: &|_| {},
488 progress_update_fn: &|_| {},
489 download_progress_fn: &|_, _| {},
490 };
491 dl_test_files_backoff_with_progress(defs, dir, max_delay, &callbacks)
492}