taskcluster_download/
object.rs

1use crate::factory::{AsyncWriterFactory, CursorWriterFactory, FileWriterFactory};
2use crate::geturl::{get_url, FetchMetadata, RetriableResult};
3use crate::hashing::HasherAsyncWriterFactory;
4use crate::service::ObjectService;
5use anyhow::{anyhow, bail, Context, Result};
6use serde_json::json;
7use std::collections::HashMap;
8use taskcluster::chrono::{DateTime, Utc};
9use taskcluster::retry::{Backoff, Retry};
10use taskcluster::Object;
11use tokio::fs::File;
12
/// The subset of hashes supported by hashing{read,write}stream which are
/// "accepted" as per the object service's schemas.
// `'static` is implied for references in a `const` item, so spelling it out is
// redundant (clippy::redundant_static_lifetimes).
const ACCEPTABLE_HASHES: &[&str] = &["sha256", "sha512"];
16
17/// Download an object to a [Vec<u8>] and return that.  If the object is unexpectedly
18/// large, this may exhaust system memory and panic.  Returns (data, content_type)
19pub async fn download_to_vec(
20    name: &str,
21    retry: &Retry,
22    object_service: &Object,
23) -> Result<(Vec<u8>, String)> {
24    let mut factory = CursorWriterFactory::new();
25    let content_type = download_impl(name, retry, object_service, &mut factory).await?;
26    Ok((factory.into_inner(), content_type))
27}
28
29/// Download an object into the given buffer and return the slice of that buffer containing the
30/// object.  If the object is larger than the buffer, then resulting error can be downcast to
31/// [std::io::Error] with kind `WriteZero` and the somewhat cryptic message "write zero byte into
32/// writer".  Returns (slice, content_type)
33pub async fn download_to_buf<'a>(
34    name: &str,
35    retry: &Retry,
36    object_service: &Object,
37    buf: &'a mut [u8],
38) -> Result<(&'a [u8], String)> {
39    let mut factory = CursorWriterFactory::for_buf(buf);
40    let content_type = download_impl(name, retry, object_service, &mut factory).await?;
41    let size = factory.size();
42    Ok((&buf[..size], content_type))
43}
44
45/// Download an object into the given File.  The file must be open in write mode and must be
46/// clone-able (that is, [File::try_clone()] must succeed) in order to support retried downloads.
47/// The File is returned with all write operations complete but with unspecified position.
48/// Returns (file, content_type).
49pub async fn download_to_file(
50    name: &str,
51    retry: &Retry,
52    object_service: &Object,
53    file: File,
54) -> Result<(File, String)> {
55    let mut factory = FileWriterFactory::new(file);
56    let content_type = download_impl(name, retry, object_service, &mut factory).await?;
57    Ok((factory.into_inner().await?, content_type))
58}
59
60/// Download an object using an [AsyncWriterFactory].  This is useful for advanced cases where one
61/// of the convenience functions is not adequate.  Returns the object's content type.
62pub async fn download_with_factory<AWF: AsyncWriterFactory>(
63    name: &str,
64    retry: &Retry,
65    object_service: &Object,
66    writer_factory: &mut AWF,
67) -> Result<String> {
68    let content_type = download_impl(name, retry, object_service, writer_factory).await?;
69    Ok(content_type)
70}
71
72/// Internal implementation of downloads, using the ObjectService trait to allow
73/// injecting a fake dependency.  Returns the object's content-type.
74pub(crate) async fn download_impl<O: ObjectService, AWF: AsyncWriterFactory>(
75    name: &str,
76    retry: &Retry,
77    object_service: &O,
78    writer_factory: &mut AWF,
79) -> Result<String> {
80    let response = object_service
81        .startDownload(
82            name,
83            &json!({
84                "acceptDownloadMethods": {
85                    "getUrl": true,
86                },
87            }),
88        )
89        .await?;
90
91    let method = response
92        .get("method")
93        .map(|o| o.as_str())
94        .flatten()
95        .ok_or_else(|| anyhow!("invalid response from startDownload"))?;
96
97    match method {
98        "getUrl" => {
99            Ok(geturl_download(response, name, object_service, retry, writer_factory).await?)
100        }
101        _ => bail!("unknown method {} in response from startDownload", method),
102    }
103}
104
/// Download an object via the `getUrl` method: fetch the signed URL from the
/// startDownload response, retrying retriable failures with backoff and requesting
/// a fresh URL whenever the current one has expired.  Returns the object's
/// content-type after the downloaded data's hashes have been verified.
async fn geturl_download<O: ObjectService, AWF: AsyncWriterFactory>(
    mut response_json: serde_json::Value,
    name: &str,
    object_service: &O,
    retry: &Retry,
    writer_factory: &mut AWF,
) -> Result<String> {
    // tracking for whether the URL in start_download_response has been used
    // at least once, to avoid looping infinitely.
    let mut response_used = false;
    // The fields of the startDownload response that the getUrl method uses.
    #[derive(serde::Deserialize)]
    struct GetUrlStartDownloadResponse {
        url: String,
        hashes: HashMap<String, String>,
        expires: DateTime<Utc>,
    }

    let mut start_download_response: GetUrlStartDownloadResponse =
        serde_json::from_value(response_json)?;

    // wrap the given writer factory with one that will hash
    let mut writer_factory = HasherAsyncWriterFactory::new(writer_factory);

    let mut backoff = Backoff::new(retry);
    let mut attempts = 0;

    // The loop `break`s only with a success value (wrapped in Ok so the loop
    // expression has a Result type); failures return from the function directly.
    // The trailing `?` after the loop unwraps the success value.
    let fetchmeta = loop {
        // If the signed URL has been tried at least once and has since expired,
        // call startDownload again to get a fresh one before retrying.
        if response_used && start_download_response.expires <= Utc::now() {
            response_json = object_service
                .startDownload(
                    name,
                    &json!({
                        "acceptDownloadMethods": {
                            "getUrl": true,
                        },
                    }),
                )
                .await?;
            start_download_response = serde_json::from_value(response_json)?;
        }

        response_used = true;
        attempts += 1;
        // A fresh writer (and fresh hash state) for each attempt, so partial
        // data from a failed attempt is discarded.
        let mut writer = writer_factory.get_writer().await?;
        match get_url(start_download_response.url.as_ref(), writer.as_mut()).await {
            RetriableResult::Ok(fetchmeta) => break Ok::<FetchMetadata, anyhow::Error>(fetchmeta),
            RetriableResult::Retriable(err) => match backoff.next_backoff() {
                Some(duration) => {
                    tokio::time::sleep(duration).await;
                    continue;
                }
                None => {
                    return Err(err).context(format!("Download failed after {} attempts", attempts))
                }
            },
            RetriableResult::Permanent(err) => {
                return Err(err);
            }
        }
    }?;

    // Verify the hashes after a successful download.  Note that verification failure will
    // not result in a retry.
    verify_hashes(start_download_response.hashes, writer_factory.hashes())?;

    Ok(fetchmeta.content_type)
}
172
173/// Validate that the observed hashes match the expected hashes: all hashes with known algorithms
174/// present in both maps are valid, and at least one "accepted" hash algorithm is present.
175///
176/// If the validation fails, this returns an appropriate error.
177fn verify_hashes(
178    exp_hashes: HashMap<String, String>,
179    observed_hashes: HashMap<String, String>,
180) -> Result<()> {
181    let mut some_valid_acceptable_hash = false;
182
183    for (alg, ov) in &observed_hashes {
184        if let Some(ev) = exp_hashes.get(alg) {
185            if ov != ev {
186                bail!("Object hashes for {} differ", alg);
187            }
188            if ACCEPTABLE_HASHES.iter().any(|acc_alg| alg == acc_alg) {
189                some_valid_acceptable_hash = true;
190            }
191        }
192    }
193
194    if !some_valid_acceptable_hash {
195        bail!("No acceptable hashes found in object metadata");
196    }
197    Ok(())
198}
199
#[cfg(test)]
mod test {
    use super::*;
    use crate::test_helpers::{FakeDataServer, FakeObjectService, Logger};
    use serde_json::json;
    use std::io::SeekFrom;
    use taskcluster::chrono::{Duration, Utc};
    use tempfile::tempfile;
    use tokio::io::{AsyncReadExt, AsyncSeekExt};

    /// A plain successful download: data is written, its sha256 verifies, and
    /// startDownload is called exactly once.
    #[tokio::test]
    async fn download_success() -> Result<()> {
        let server = FakeDataServer::new(false, &[200]);
        let logger = Logger::default();
        let object_service = FakeObjectService {
            logger: logger.clone(),
            response: json!({
                "method": "getUrl",
                "url": server.data_url(),
                "hashes": {
                    // sha256 of b"hello, world"
                    "sha256":"09ca7e4eaa6e8ae9c7d261167129184883644d07dfba7cbfbc4c8a2e08360d5b",
                },
                "expires": Utc::now() + Duration::hours(2),
            }),
        };

        let mut factory = CursorWriterFactory::new();
        let content_type = download_impl(
            "some/object",
            &Retry::default(),
            &object_service,
            &mut factory,
        )
        .await?;

        logger.assert(vec![format!(
            "startDownload some/object {}",
            json!({"getUrl": true})
        )]);

        assert_eq!(&content_type, "text/plain");

        let data = factory.into_inner();
        assert_eq!(&data, b"hello, world");

        drop(object_service); // ..and with it, server, which refs data

        Ok(())
    }

    /// 500 responses are retriable; with enough retries configured the download
    /// eventually succeeds.
    #[tokio::test]
    async fn download_with_retries_for_500s_success() -> Result<()> {
        let server = FakeDataServer::new(false, &[500, 500, 200]);
        let logger = Logger::default();
        let object_service = FakeObjectService {
            logger: logger.clone(),
            response: json!({
                "method": "getUrl",
                "url": server.data_url(),
                "hashes": {
                    "sha256":"09ca7e4eaa6e8ae9c7d261167129184883644d07dfba7cbfbc4c8a2e08360d5b",
                },
                "expires": Utc::now() + Duration::hours(2),
            }),
        };
        let retry = Retry {
            retries: 2,
            ..Retry::default()
        };

        let mut factory = CursorWriterFactory::new();
        download_impl("some/object", &retry, &object_service, &mut factory).await?;

        logger.assert(vec![format!(
            "startDownload some/object {}",
            json!({"getUrl": true})
        )]);

        let data = factory.into_inner();
        assert_eq!(&data, b"hello, world");

        drop(object_service); // ..and with it, server, which refs data

        Ok(())
    }

    /// 400 responses are permanent failures: no retry occurs and no data is
    /// written.
    #[tokio::test]
    async fn download_with_failure_for_400s() -> Result<()> {
        let server = FakeDataServer::new(false, &[400, 200]);
        let logger = Logger::default();
        let object_service = FakeObjectService {
            logger: logger.clone(),
            response: json!({
                "method": "getUrl",
                "url": server.data_url(),
                "hashes": {},
                "expires": Utc::now() + Duration::hours(2),
            }),
        };
        let retry = Retry::default();

        let mut factory = CursorWriterFactory::new();
        assert!(
            download_impl("some/object", &retry, &object_service, &mut factory)
                .await
                .is_err()
        );

        logger.assert(vec![format!(
            "startDownload some/object {}",
            json!({"getUrl": true})
        )]);

        let data = factory.into_inner();
        assert_eq!(&data, b"");

        drop(object_service); // ..and with it, server, which refs data

        Ok(())
    }

    /// When the retry budget is exhausted before a success, the download fails
    /// and no data is kept.
    #[tokio::test]
    async fn download_with_retries_for_500s_failure() -> Result<()> {
        let server = FakeDataServer::new(false, &[500, 500, 500, 200]);
        let logger = Logger::default();
        let object_service = FakeObjectService {
            logger: logger.clone(),
            response: json!({
                "method": "getUrl",
                "url": server.data_url(),
                "hashes": {},
                "expires": Utc::now() + Duration::hours(2),
            }),
        };
        let retry = Retry {
            retries: 2, // but, need 3 to succeed!
            ..Retry::default()
        };

        let mut factory = CursorWriterFactory::new();
        assert!(
            download_impl("some/object", &retry, &object_service, &mut factory)
                .await
                .is_err()
        );

        logger.assert(vec![format!(
            "startDownload some/object {}",
            json!({"getUrl": true})
        )]);

        let data = factory.into_inner();
        assert_eq!(&data, b"");

        drop(object_service);

        Ok(())
    }

    /// When the signed URL has expired after a failed attempt, a retry must
    /// call startDownload again for a fresh URL.
    #[tokio::test]
    async fn download_calls_start_download_when_expired() -> Result<()> {
        let server = FakeDataServer::new(false, &[500, 200]);
        let logger = Logger::default();
        let object_service = FakeObjectService {
            logger: logger.clone(),
            response: json!({
                "method": "getUrl",
                "url": server.data_url(),
                "hashes": {},
                "expires": Utc::now(), // download_impl will try the url once
            }),
        };
        let retry = Retry {
            retries: 2, // but, need 3 to succeed!
            ..Retry::default()
        };

        let mut factory = CursorWriterFactory::new();
        assert!(
            download_impl("some/object", &retry, &object_service, &mut factory)
                .await
                .is_err()
        );

        logger.assert(vec![
            // calls startDownload twice
            format!("startDownload some/object {}", json!({"getUrl": true})),
            format!("startDownload some/object {}", json!({"getUrl": true})),
        ]);

        let data = factory.into_inner();
        assert_eq!(&data, b"hello, world");

        drop(object_service);

        Ok(())
    }

    /// Downloading to a file writes the data to disk; read it back to confirm.
    #[tokio::test]
    async fn download_to_file() -> Result<()> {
        let server = FakeDataServer::new(false, &[200]);
        let logger = Logger::default();
        let object_service = FakeObjectService {
            logger: logger.clone(),
            response: json!({
                "method": "getUrl",
                "url": server.data_url(),
                "hashes": {
                    "sha256":"09ca7e4eaa6e8ae9c7d261167129184883644d07dfba7cbfbc4c8a2e08360d5b",
                },
                "expires": Utc::now() + Duration::hours(2),
            }),
        };

        let mut factory = FileWriterFactory::new(tempfile()?.into());
        download_impl(
            "some/object",
            &Retry::default(),
            &object_service,
            &mut factory,
        )
        .await?;

        logger.assert(vec![format!(
            "startDownload some/object {}",
            json!({"getUrl": true})
        )]);

        let mut file = factory.into_inner().await?;
        let mut res = Vec::new();
        file.seek(SeekFrom::Start(0)).await?;
        file.read_to_end(&mut res).await?;
        assert_eq!(&res, b"hello, world");

        drop(object_service); // ..and with it, server, which refs data

        Ok(())
    }

    /// Build a `HashMap<String, String>` from `"key": value` pairs, with or
    /// without a trailing comma.
    macro_rules! strmap {
		($( $key:literal : $val:expr ),*) => {
			{
				let mut m: HashMap::<String, String> = HashMap::new();
				$(
				m.insert($key.into(), $val.into());
				)*
				m
			}
		};
		($( $key:literal : $val:expr ),* ,) => {
            // Strip the trailing comma and delegate to the arm above.  (The
            // previous expansion, `strmap!($( $key : $val ,)*)`, kept the
            // trailing comma and so matched this same arm again, recursing
            // until the macro recursion limit.)
            strmap!($( $key : $val ),*)
        };
	}

    #[test]
    fn verify_hashes_valid() {
        // Matching sha256/sha512 pass; algorithms present on only one side
        // are ignored.
        assert!(verify_hashes(
            strmap!("sha256": "abc", "sha512": "def", "md5": "ignored"),
            strmap!("sha256": "abc", "sha512": "def", "sha1024": "ignored")
        )
        .is_ok());
    }

    #[test]
    fn verify_hashes_not_acceptable() {
        // A matching hash with an unacceptable algorithm is not sufficient.
        assert!(verify_hashes(strmap!("md5": "abc"), strmap!("md5": "abc")).is_err());
    }

    #[test]
    fn verify_hashes_not_matching() {
        // A mismatched digest for a shared algorithm is an error.
        assert!(verify_hashes(strmap!("sha512": "abc"), strmap!("sha512": "def")).is_err());
    }

    #[test]
    fn verify_hashes_not_acceptable_not_matching() {
        // A matching unacceptable hash does not excuse a mismatched
        // acceptable one.
        assert!(verify_hashes(
            strmap!("md5": "good", "sha512": "abc"),
            strmap!("md5": "good", "sha512": "def")
        )
        .is_err());
    }
}
481}