1use crate::factory::{AsyncWriterFactory, CursorWriterFactory, FileWriterFactory};
2use crate::geturl::{get_url, FetchMetadata, RetriableResult};
3use crate::hashing::HasherAsyncWriterFactory;
4use crate::service::ObjectService;
5use anyhow::{anyhow, bail, Context, Result};
6use serde_json::json;
7use std::collections::HashMap;
8use taskcluster::chrono::{DateTime, Utc};
9use taskcluster::retry::{Backoff, Retry};
10use taskcluster::Object;
11use tokio::fs::File;
12
/// Hash algorithms considered strong enough to validate a downloaded
/// object's integrity; at least one of these must match in `verify_hashes`.
/// (`'static` is implied for `const` references, so it is omitted.)
const ACCEPTABLE_HASHES: &[&str] = &["sha256", "sha512"];
16
17pub async fn download_to_vec(
20 name: &str,
21 retry: &Retry,
22 object_service: &Object,
23) -> Result<(Vec<u8>, String)> {
24 let mut factory = CursorWriterFactory::new();
25 let content_type = download_impl(name, retry, object_service, &mut factory).await?;
26 Ok((factory.into_inner(), content_type))
27}
28
29pub async fn download_to_buf<'a>(
34 name: &str,
35 retry: &Retry,
36 object_service: &Object,
37 buf: &'a mut [u8],
38) -> Result<(&'a [u8], String)> {
39 let mut factory = CursorWriterFactory::for_buf(buf);
40 let content_type = download_impl(name, retry, object_service, &mut factory).await?;
41 let size = factory.size();
42 Ok((&buf[..size], content_type))
43}
44
45pub async fn download_to_file(
50 name: &str,
51 retry: &Retry,
52 object_service: &Object,
53 file: File,
54) -> Result<(File, String)> {
55 let mut factory = FileWriterFactory::new(file);
56 let content_type = download_impl(name, retry, object_service, &mut factory).await?;
57 Ok((factory.into_inner().await?, content_type))
58}
59
60pub async fn download_with_factory<AWF: AsyncWriterFactory>(
63 name: &str,
64 retry: &Retry,
65 object_service: &Object,
66 writer_factory: &mut AWF,
67) -> Result<String> {
68 let content_type = download_impl(name, retry, object_service, writer_factory).await?;
69 Ok(content_type)
70}
71
72pub(crate) async fn download_impl<O: ObjectService, AWF: AsyncWriterFactory>(
75 name: &str,
76 retry: &Retry,
77 object_service: &O,
78 writer_factory: &mut AWF,
79) -> Result<String> {
80 let response = object_service
81 .startDownload(
82 name,
83 &json!({
84 "acceptDownloadMethods": {
85 "getUrl": true,
86 },
87 }),
88 )
89 .await?;
90
91 let method = response
92 .get("method")
93 .map(|o| o.as_str())
94 .flatten()
95 .ok_or_else(|| anyhow!("invalid response from startDownload"))?;
96
97 match method {
98 "getUrl" => {
99 Ok(geturl_download(response, name, object_service, retry, writer_factory).await?)
100 }
101 _ => bail!("unknown method {} in response from startDownload", method),
102 }
103}
104
/// Perform a download using the `getUrl` method: fetch the signed URL from
/// the `startDownload` response, retrying retriable failures with backoff,
/// and asking the object service for a fresh response whenever the signed
/// URL has expired between attempts. Returns the content type on success.
async fn geturl_download<O: ObjectService, AWF: AsyncWriterFactory>(
    mut response_json: serde_json::Value,
    name: &str,
    object_service: &O,
    retry: &Retry,
    writer_factory: &mut AWF,
) -> Result<String> {
    // Set once the URL in the current response has been tried at least once;
    // only then is an expired response refreshed (the initial response is
    // always used as-is for the first attempt).
    let mut response_used = false;
    // The subset of the `getUrl` startDownload response this function needs.
    #[derive(serde::Deserialize)]
    struct GetUrlStartDownloadResponse {
        url: String,
        hashes: HashMap<String, String>,
        expires: DateTime<Utc>,
    }

    let mut start_download_response: GetUrlStartDownloadResponse =
        serde_json::from_value(response_json)?;

    // Shadow the caller's factory with a hashing wrapper so the downloaded
    // bytes are hashed as they are written; the observed hashes are checked
    // against the service-advertised hashes after the fetch succeeds.
    let mut writer_factory = HasherAsyncWriterFactory::new(writer_factory);

    let mut backoff = Backoff::new(retry);
    let mut attempts = 0;

    let fetchmeta = loop {
        // Refresh an expired signed URL before retrying.
        if response_used && start_download_response.expires <= Utc::now() {
            response_json = object_service
                .startDownload(
                    name,
                    &json!({
                        "acceptDownloadMethods": {
                            "getUrl": true,
                        },
                    }),
                )
                .await?;
            start_download_response = serde_json::from_value(response_json)?;
        }

        response_used = true;
        attempts += 1;
        // A fresh writer for each attempt discards any partial output from a
        // previous failed attempt (presumably also resetting the running
        // hashes — see `HasherAsyncWriterFactory`).
        let mut writer = writer_factory.get_writer().await?;
        match get_url(start_download_response.url.as_ref(), writer.as_mut()).await {
            // Success: break out of the loop with the fetch metadata. The
            // explicit `Ok::<..>` annotation types the `break` value so the
            // trailing `?` on the loop can apply.
            RetriableResult::Ok(fetchmeta) => break Ok::<FetchMetadata, anyhow::Error>(fetchmeta),
            RetriableResult::Retriable(err) => match backoff.next_backoff() {
                Some(duration) => {
                    tokio::time::sleep(duration).await;
                    continue;
                }
                // Backoff budget exhausted: give up, keeping the last error
                // as the cause.
                None => {
                    return Err(err).context(format!("Download failed after {} attempts", attempts))
                }
            },
            // Permanent failures (e.g. 4xx) are never retried.
            RetriableResult::Permanent(err) => {
                return Err(err);
            }
        }
    }?;

    // Reject the download if the hashes computed while writing do not match
    // those advertised by the object service.
    verify_hashes(start_download_response.hashes, writer_factory.hashes())?;

    Ok(fetchmeta.content_type)
}
172
173fn verify_hashes(
178 exp_hashes: HashMap<String, String>,
179 observed_hashes: HashMap<String, String>,
180) -> Result<()> {
181 let mut some_valid_acceptable_hash = false;
182
183 for (alg, ov) in &observed_hashes {
184 if let Some(ev) = exp_hashes.get(alg) {
185 if ov != ev {
186 bail!("Object hashes for {} differ", alg);
187 }
188 if ACCEPTABLE_HASHES.iter().any(|acc_alg| alg == acc_alg) {
189 some_valid_acceptable_hash = true;
190 }
191 }
192 }
193
194 if !some_valid_acceptable_hash {
195 bail!("No acceptable hashes found in object metadata");
196 }
197 Ok(())
198}
199
#[cfg(test)]
mod test {
    //! Tests drive `download_impl` against a `FakeObjectService` (which
    //! records each `startDownload` call in `logger`) and a `FakeDataServer`
    //! that answers with a scripted sequence of HTTP status codes.
    use super::*;
    use crate::test_helpers::{FakeDataServer, FakeObjectService, Logger};
    use serde_json::json;
    use std::io::SeekFrom;
    use taskcluster::chrono::{Duration, Utc};
    use tempfile::tempfile;
    use tokio::io::{AsyncReadExt, AsyncSeekExt};

    // Happy path: a single 200 response whose body matches the advertised
    // sha256, downloaded into an in-memory cursor.
    #[tokio::test]
    async fn download_success() -> Result<()> {
        let server = FakeDataServer::new(false, &[200]);
        let logger = Logger::default();
        let object_service = FakeObjectService {
            logger: logger.clone(),
            response: json!({
                "method": "getUrl",
                "url": server.data_url(),
                "hashes": {
                    "sha256":"09ca7e4eaa6e8ae9c7d261167129184883644d07dfba7cbfbc4c8a2e08360d5b",
                },
                "expires": Utc::now() + Duration::hours(2),
            }),
        };

        let mut factory = CursorWriterFactory::new();
        let content_type = download_impl(
            "some/object",
            &Retry::default(),
            &object_service,
            &mut factory,
        )
        .await?;

        // Exactly one startDownload call is expected.
        logger.assert(vec![format!(
            "startDownload some/object {}",
            json!({"getUrl": true})
        )]);

        assert_eq!(&content_type, "text/plain");

        let data = factory.into_inner();
        assert_eq!(&data, b"hello, world");

        drop(object_service);
        Ok(())
    }

    // Two 500s then a 200: a budget of 2 retries is exactly enough for the
    // download to succeed.
    #[tokio::test]
    async fn download_with_retries_for_500s_success() -> Result<()> {
        let server = FakeDataServer::new(false, &[500, 500, 200]);
        let logger = Logger::default();
        let object_service = FakeObjectService {
            logger: logger.clone(),
            response: json!({
                "method": "getUrl",
                "url": server.data_url(),
                "hashes": {
                    "sha256":"09ca7e4eaa6e8ae9c7d261167129184883644d07dfba7cbfbc4c8a2e08360d5b",
                },
                "expires": Utc::now() + Duration::hours(2),
            }),
        };
        let retry = Retry {
            retries: 2,
            ..Retry::default()
        };

        let mut factory = CursorWriterFactory::new();
        download_impl("some/object", &retry, &object_service, &mut factory).await?;

        logger.assert(vec![format!(
            "startDownload some/object {}",
            json!({"getUrl": true})
        )]);

        let data = factory.into_inner();
        assert_eq!(&data, b"hello, world");

        drop(object_service);
        Ok(())
    }

    // A 400 is a permanent error: no retry is attempted (the queued 200 is
    // never reached) and nothing is written to the output.
    #[tokio::test]
    async fn download_with_failure_for_400s() -> Result<()> {
        let server = FakeDataServer::new(false, &[400, 200]);
        let logger = Logger::default();
        let object_service = FakeObjectService {
            logger: logger.clone(),
            response: json!({
                "method": "getUrl",
                "url": server.data_url(),
                "hashes": {},
                "expires": Utc::now() + Duration::hours(2),
            }),
        };
        let retry = Retry::default();

        let mut factory = CursorWriterFactory::new();
        assert!(
            download_impl("some/object", &retry, &object_service, &mut factory)
                .await
                .is_err()
        );

        logger.assert(vec![format!(
            "startDownload some/object {}",
            json!({"getUrl": true})
        )]);

        let data = factory.into_inner();
        assert_eq!(&data, b"");

        drop(object_service);
        Ok(())
    }

    // Three 500s exhaust the 2-retry budget before the 200 is reached, so
    // the whole download fails.
    #[tokio::test]
    async fn download_with_retries_for_500s_failure() -> Result<()> {
        let server = FakeDataServer::new(false, &[500, 500, 500, 200]);
        let logger = Logger::default();
        let object_service = FakeObjectService {
            logger: logger.clone(),
            response: json!({
                "method": "getUrl",
                "url": server.data_url(),
                "hashes": {},
                "expires": Utc::now() + Duration::hours(2),
            }),
        };
        let retry = Retry {
            retries: 2,
            ..Retry::default()
        };

        let mut factory = CursorWriterFactory::new();
        assert!(
            download_impl("some/object", &retry, &object_service, &mut factory)
                .await
                .is_err()
        );

        logger.assert(vec![format!(
            "startDownload some/object {}",
            json!({"getUrl": true})
        )]);

        let data = factory.into_inner();
        assert_eq!(&data, b"");

        drop(object_service);

        Ok(())
    }

    // An already-expired startDownload response must be refreshed before a
    // retry, so two startDownload calls are logged.
    #[tokio::test]
    async fn download_calls_start_download_when_expired() -> Result<()> {
        let server = FakeDataServer::new(false, &[500, 200]);
        let logger = Logger::default();
        let object_service = FakeObjectService {
            logger: logger.clone(),
            response: json!({
                "method": "getUrl",
                "url": server.data_url(),
                "hashes": {},
                "expires": Utc::now(),
            }),
        };
        let retry = Retry {
            retries: 2,
            ..Retry::default()
        };

        let mut factory = CursorWriterFactory::new();
        // The overall download still fails (no acceptable hashes in the
        // response metadata), even though the data was fetched.
        assert!(
            download_impl("some/object", &retry, &object_service, &mut factory)
                .await
                .is_err()
        );

        logger.assert(vec![
            format!("startDownload some/object {}", json!({"getUrl": true})),
            format!("startDownload some/object {}", json!({"getUrl": true})),
        ]);

        let data = factory.into_inner();
        assert_eq!(&data, b"hello, world");

        drop(object_service);

        Ok(())
    }

    // Download into a temporary file via FileWriterFactory, then read the
    // file back to confirm the content.
    #[tokio::test]
    async fn download_to_file() -> Result<()> {
        let server = FakeDataServer::new(false, &[200]);
        let logger = Logger::default();
        let object_service = FakeObjectService {
            logger: logger.clone(),
            response: json!({
                "method": "getUrl",
                "url": server.data_url(),
                "hashes": {
                    "sha256":"09ca7e4eaa6e8ae9c7d261167129184883644d07dfba7cbfbc4c8a2e08360d5b",
                },
                "expires": Utc::now() + Duration::hours(2),
            }),
        };

        let mut factory = FileWriterFactory::new(tempfile()?.into());
        download_impl(
            "some/object",
            &Retry::default(),
            &object_service,
            &mut factory,
        )
        .await?;

        logger.assert(vec![format!(
            "startDownload some/object {}",
            json!({"getUrl": true})
        )]);

        let mut file = factory.into_inner().await?;
        let mut res = Vec::new();
        file.seek(SeekFrom::Start(0)).await?;
        file.read_to_end(&mut res).await?;
        assert_eq!(&res, b"hello, world");

        drop(object_service);
        Ok(())
    }

    // Build a HashMap<String, String> from `"key": value` pairs.
    macro_rules! strmap {
        ($( $key:literal : $val:expr ),*) => {
            {
                let mut m: HashMap::<String, String> = HashMap::new();
                $(
                    m.insert($key.into(), $val.into());
                )*
                m
            }
        };
        // Trailing-comma form delegates to the arm above. (The previous
        // expansion re-emitted the trailing comma, which would have matched
        // this same arm again and recursed forever had any call site used a
        // trailing comma.)
        ($( $key:literal : $val:expr ),* ,) => {
            strmap!($( $key : $val ),*)
        };
    }

    // Matching values for acceptable algorithms verify, even with extra
    // unmatched algorithms on either side.
    #[test]
    fn verify_hashes_valid() {
        assert!(verify_hashes(
            strmap!("sha256": "abc", "sha512": "def", "md5": "ignored"),
            strmap!("sha256": "abc", "sha512": "def", "sha1024": "ignored")
        )
        .is_ok());
    }

    // Agreement on only a non-acceptable algorithm (md5) is not enough.
    #[test]
    fn verify_hashes_not_acceptable() {
        assert!(verify_hashes(strmap!("md5": "abc"), strmap!("md5": "abc")).is_err());
    }

    // A mismatch on an acceptable algorithm is an error.
    #[test]
    fn verify_hashes_not_matching() {
        assert!(verify_hashes(strmap!("sha512": "abc"), strmap!("sha512": "def")).is_err());
    }

    // A mismatch fails even when another (non-acceptable) algorithm matches.
    #[test]
    fn verify_hashes_not_acceptable_not_matching() {
        assert!(verify_hashes(
            strmap!("md5": "good", "sha512": "abc"),
            strmap!("md5": "good", "sha512": "def")
        )
        .is_err());
    }
}