1use crate::client::Client;
2use crate::mediaconn::{MEDIA_AUTH_REFRESH_RETRY_ATTEMPTS, MediaConn, is_media_auth_error};
3use anyhow::{Result, anyhow};
4use std::io::{Seek, SeekFrom, Write};
5
6pub use wacore::download::{
7 DownloadUtils, Downloadable, MediaDecryption, MediaDecryptionError, MediaType,
8};
9
10impl From<&MediaConn> for wacore::download::MediaConnection {
11 fn from(conn: &MediaConn) -> Self {
12 wacore::download::MediaConnection {
13 hosts: conn
14 .hosts
15 .iter()
16 .map(|h| wacore::download::MediaHost {
17 hostname: h.hostname.clone(),
18 })
19 .collect(),
20 auth: conn.auth.clone(),
21 }
22 }
23}
24
25struct DownloadParams {
27 direct_path: String,
28 media_key: Option<Vec<u8>>,
29 file_sha256: Vec<u8>,
30 file_enc_sha256: Option<Vec<u8>>,
31 file_length: u64,
32 media_type: MediaType,
33}
34
35impl Downloadable for DownloadParams {
36 fn direct_path(&self) -> Option<&str> {
37 Some(&self.direct_path)
38 }
39 fn media_key(&self) -> Option<&[u8]> {
40 self.media_key.as_deref()
41 }
42 fn file_enc_sha256(&self) -> Option<&[u8]> {
43 self.file_enc_sha256.as_deref()
44 }
45 fn file_sha256(&self) -> Option<&[u8]> {
46 Some(&self.file_sha256)
47 }
48 fn file_length(&self) -> Option<u64> {
49 Some(self.file_length)
50 }
51 fn app_info(&self) -> MediaType {
52 self.media_type
53 }
54}
55
56#[derive(Debug)]
57enum DownloadRequestError {
58 Auth(anyhow::Error),
59 NotFound(anyhow::Error),
62 Other(anyhow::Error),
63}
64
65impl DownloadRequestError {
66 fn auth(status_code: u16) -> Self {
67 Self::Auth(anyhow!("Download failed with status: {}", status_code))
68 }
69
70 fn not_found(status_code: u16) -> Self {
71 Self::NotFound(anyhow!(
72 "Download media not found/expired with status: {}",
73 status_code
74 ))
75 }
76
77 fn other(err: impl Into<anyhow::Error>) -> Self {
78 Self::Other(err.into())
79 }
80
81 fn is_auth(&self) -> bool {
82 matches!(self, Self::Auth(_))
83 }
84
85 fn is_not_found(&self) -> bool {
87 matches!(self, Self::NotFound(_))
88 }
89
90 fn into_anyhow(self) -> anyhow::Error {
91 match self {
92 Self::Auth(err) | Self::NotFound(err) | Self::Other(err) => err,
93 }
94 }
95}
96
97async fn download_media_with_retry<
98 PrepareRequests,
99 PrepareRequestsFut,
100 InvalidateMediaConn,
101 InvalidateMediaConnFut,
102 ExecuteRequest,
103 ExecuteRequestFut,
104>(
105 mut prepare_requests: PrepareRequests,
106 mut invalidate_media_conn: InvalidateMediaConn,
107 mut execute_request: ExecuteRequest,
108) -> Result<Vec<u8>>
109where
110 PrepareRequests: FnMut(bool) -> PrepareRequestsFut,
111 PrepareRequestsFut:
112 std::future::Future<Output = Result<Vec<wacore::download::DownloadRequest>>>,
113 InvalidateMediaConn: FnMut() -> InvalidateMediaConnFut,
114 InvalidateMediaConnFut: std::future::Future<Output = ()>,
115 ExecuteRequest: FnMut(wacore::download::DownloadRequest) -> ExecuteRequestFut,
116 ExecuteRequestFut:
117 std::future::Future<Output = std::result::Result<Vec<u8>, DownloadRequestError>>,
118{
119 let mut force_refresh = false;
120 let mut last_err: Option<anyhow::Error> = None;
121
122 for attempt in 0..=MEDIA_AUTH_REFRESH_RETRY_ATTEMPTS {
123 let requests = prepare_requests(force_refresh).await?;
124 let mut retry_with_fresh_auth = false;
125
126 for request in requests {
127 match execute_request(request.clone()).await {
128 Ok(data) => return Ok(data),
129 Err(err) if (err.is_auth() || err.is_not_found()) && attempt == 0 => {
130 invalidate_media_conn().await;
133 force_refresh = true;
134 retry_with_fresh_auth = true;
135 break;
136 }
137 Err(err) if err.is_auth() || err.is_not_found() => return Err(err.into_anyhow()),
138 Err(err) => {
139 let err = err.into_anyhow();
140 log::warn!(
141 "Failed to download from URL {}: {:?}. Trying next host.",
142 request.url,
143 err
144 );
145 last_err = Some(err);
146 }
147 }
148 }
149
150 if !retry_with_fresh_auth {
151 break;
152 }
153 }
154
155 match last_err {
156 Some(err) => Err(err),
157 None => Err(anyhow!("Failed to download from all available media hosts")),
158 }
159}
160
161async fn download_to_writer_with_retry<
162 W,
163 PrepareRequests,
164 PrepareRequestsFut,
165 InvalidateMediaConn,
166 InvalidateMediaConnFut,
167 ExecuteRequest,
168 ExecuteRequestFut,
169>(
170 mut writer: W,
171 mut prepare_requests: PrepareRequests,
172 mut invalidate_media_conn: InvalidateMediaConn,
173 mut execute_request: ExecuteRequest,
174) -> Result<W>
175where
176 W: Write + Seek + Send + 'static,
177 PrepareRequests: FnMut(bool) -> PrepareRequestsFut,
178 PrepareRequestsFut:
179 std::future::Future<Output = Result<Vec<wacore::download::DownloadRequest>>>,
180 InvalidateMediaConn: FnMut() -> InvalidateMediaConnFut,
181 InvalidateMediaConnFut: std::future::Future<Output = ()>,
182 ExecuteRequest: FnMut(wacore::download::DownloadRequest, W) -> ExecuteRequestFut,
183 ExecuteRequestFut:
184 std::future::Future<Output = Result<(W, std::result::Result<(), DownloadRequestError>)>>,
185{
186 let mut force_refresh = false;
187 let mut last_err: Option<anyhow::Error> = None;
188
189 for attempt in 0..=MEDIA_AUTH_REFRESH_RETRY_ATTEMPTS {
190 let requests = prepare_requests(force_refresh).await?;
191 let mut retry_with_fresh_auth = false;
192
193 for request in requests {
194 let (next_writer, result) = execute_request(request.clone(), writer).await?;
195 writer = next_writer;
196
197 match result {
198 Ok(()) => return Ok(writer),
199 Err(err) if (err.is_auth() || err.is_not_found()) && attempt == 0 => {
200 invalidate_media_conn().await;
201 force_refresh = true;
202 retry_with_fresh_auth = true;
203 break;
204 }
205 Err(err) if err.is_auth() || err.is_not_found() => return Err(err.into_anyhow()),
206 Err(err) => {
207 let err = err.into_anyhow();
208 log::warn!(
209 "Failed to stream-download from URL {}: {:?}. Trying next host.",
210 request.url,
211 err
212 );
213 last_err = Some(err);
214 }
215 }
216 }
217
218 if !retry_with_fresh_auth {
219 break;
220 }
221 }
222
223 match last_err {
224 Some(err) => Err(err),
225 None => Err(anyhow!("Failed to download from all available media hosts")),
226 }
227}
228
229impl Client {
230 pub async fn download(&self, downloadable: &dyn Downloadable) -> Result<Vec<u8>> {
231 download_media_with_retry(
232 |force| self.prepare_requests(downloadable, force),
233 || async { self.invalidate_media_conn().await },
234 |request| async move { self.download_with_request(&request).await },
235 )
236 .await
237 }
238
239 pub async fn download_to_file<W: Write + Seek + Send + Unpin>(
240 &self,
241 downloadable: &dyn Downloadable,
242 mut writer: W,
243 ) -> Result<()> {
244 let data = self.download(downloadable).await?;
245 writer.seek(SeekFrom::Start(0))?;
246 writer.write_all(&data)?;
247 Ok(())
248 }
249
250 pub async fn download_from_params(
252 &self,
253 direct_path: &str,
254 media_key: &[u8],
255 file_sha256: &[u8],
256 file_enc_sha256: &[u8],
257 file_length: u64,
258 media_type: MediaType,
259 ) -> Result<Vec<u8>> {
260 let params = DownloadParams {
261 direct_path: direct_path.to_string(),
262 media_key: Some(media_key.to_vec()),
263 file_sha256: file_sha256.to_vec(),
264 file_enc_sha256: Some(file_enc_sha256.to_vec()),
265 file_length,
266 media_type,
267 };
268 self.download(¶ms).await
269 }
270
271 async fn prepare_requests(
272 &self,
273 downloadable: &dyn Downloadable,
274 force_refresh: bool,
275 ) -> Result<Vec<wacore::download::DownloadRequest>> {
276 let media_conn = self.refresh_media_conn(force_refresh).await?;
277 let core_media_conn = wacore::download::MediaConnection::from(&media_conn);
278 DownloadUtils::prepare_download_requests(downloadable, &core_media_conn)
279 }
280
281 async fn download_with_request(
282 &self,
283 request: &wacore::download::DownloadRequest,
284 ) -> std::result::Result<Vec<u8>, DownloadRequestError> {
285 let url = request.url.clone();
286 let decryption = request.decryption.clone();
287 let http_request = crate::http::HttpRequest::get(url);
288 let response = self
289 .http_client
290 .execute(http_request)
291 .await
292 .map_err(DownloadRequestError::other)?;
293
294 if response.status_code >= 300 {
295 return Err(if is_media_auth_error(response.status_code) {
296 DownloadRequestError::auth(response.status_code)
297 } else if matches!(response.status_code, 404 | 410) {
298 DownloadRequestError::not_found(response.status_code)
299 } else {
300 DownloadRequestError::other(anyhow!(
301 "Download failed with status: {}",
302 response.status_code
303 ))
304 });
305 }
306
307 match decryption {
308 MediaDecryption::Encrypted {
309 media_key,
310 media_type,
311 } => wacore::runtime::blocking(&*self.runtime, move || {
312 DownloadUtils::decrypt_stream(&response.body[..], &media_key, media_type)
313 })
314 .await
315 .map_err(DownloadRequestError::other),
316 MediaDecryption::Plaintext { file_sha256 } => {
317 let body = response.body;
318 wacore::runtime::blocking(&*self.runtime, move || {
319 DownloadUtils::validate_plaintext_sha256(&body, &file_sha256)?;
320 Ok::<Vec<u8>, anyhow::Error>(body)
321 })
322 .await
323 .map_err(DownloadRequestError::other)
324 }
325 }
326 }
327
328 pub async fn download_to_writer<W: Write + Seek + Send + 'static>(
335 &self,
336 downloadable: &dyn Downloadable,
337 writer: W,
338 ) -> Result<W> {
339 download_to_writer_with_retry(
340 writer,
341 |force| self.prepare_requests(downloadable, force),
342 || async { self.invalidate_media_conn().await },
343 |request, writer| async move { self.streaming_download_and_decrypt(&request, writer).await },
344 )
345 .await
346 }
347
348 #[allow(clippy::too_many_arguments)]
351 pub async fn download_from_params_to_writer<W: Write + Seek + Send + 'static>(
352 &self,
353 direct_path: &str,
354 media_key: &[u8],
355 file_sha256: &[u8],
356 file_enc_sha256: &[u8],
357 file_length: u64,
358 media_type: MediaType,
359 writer: W,
360 ) -> Result<W> {
361 let params = DownloadParams {
362 direct_path: direct_path.to_string(),
363 media_key: Some(media_key.to_vec()),
364 file_sha256: file_sha256.to_vec(),
365 file_enc_sha256: Some(file_enc_sha256.to_vec()),
366 file_length,
367 media_type,
368 };
369 self.download_to_writer(¶ms, writer).await
370 }
371
372 async fn streaming_download_and_decrypt<W: Write + Seek + Send + 'static>(
375 &self,
376 request: &wacore::download::DownloadRequest,
377 writer: W,
378 ) -> Result<(W, std::result::Result<(), DownloadRequestError>)> {
379 let http_client = self.http_client.clone();
380 let url = request.url.clone();
381 let decryption = request.decryption.clone();
382
383 wacore::runtime::blocking(&*self.runtime, move || {
384 let mut writer = writer;
385
386 if let Err(e) = writer.seek(SeekFrom::Start(0)) {
388 return Ok((writer, Err(DownloadRequestError::other(e))));
389 }
390
391 let result = (|| -> std::result::Result<(), DownloadRequestError> {
392 let http_request = crate::http::HttpRequest::get(url);
393 let resp = http_client
394 .execute_streaming(http_request)
395 .map_err(DownloadRequestError::other)?;
396
397 if resp.status_code >= 300 {
398 return Err(if is_media_auth_error(resp.status_code) {
399 DownloadRequestError::auth(resp.status_code)
400 } else if matches!(resp.status_code, 404 | 410) {
401 DownloadRequestError::not_found(resp.status_code)
402 } else {
403 DownloadRequestError::other(anyhow!(
404 "Download failed with status: {}",
405 resp.status_code
406 ))
407 });
408 }
409
410 match &decryption {
411 MediaDecryption::Encrypted {
412 media_key,
413 media_type,
414 } => {
415 DownloadUtils::decrypt_stream_to_writer(
416 resp.body,
417 media_key,
418 *media_type,
419 &mut writer,
420 )
421 .map_err(DownloadRequestError::other)?;
422 }
423 MediaDecryption::Plaintext { file_sha256 } => {
424 DownloadUtils::copy_and_validate_plaintext_to_writer(
425 resp.body,
426 file_sha256,
427 &mut writer,
428 )
429 .map_err(DownloadRequestError::other)?;
430 }
431 }
432 writer
433 .seek(SeekFrom::Start(0))
434 .map_err(DownloadRequestError::other)?;
435 Ok(())
436 })();
437
438 Ok((writer, result))
439 })
440 .await
441 }
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mediaconn::{MediaConn, MediaConnHost};
    use async_lock::Mutex;
    use std::io::Cursor;
    use std::sync::Arc;
    use wacore::time::Instant;

    /// Minimal `Downloadable` for unencrypted media.
    ///
    /// It exposes only a direct path and the plaintext SHA-256. It has no media
    /// key, no encrypted hash, and no length.
    struct PlaintextDownloadable {
        direct_path: String,
        file_sha256: Vec<u8>,
    }

    impl Downloadable for PlaintextDownloadable {
        fn direct_path(&self) -> Option<&str> {
            Some(&self.direct_path)
        }

        fn media_key(&self) -> Option<&[u8]> {
            None
        }

        fn file_enc_sha256(&self) -> Option<&[u8]> {
            None
        }

        fn file_sha256(&self) -> Option<&[u8]> {
            Some(&self.file_sha256)
        }

        fn file_length(&self) -> Option<u64> {
            None
        }

        fn app_info(&self) -> MediaType {
            MediaType::Image
        }
    }

    /// Builds a `MediaConn` with the given auth token and host names.
    ///
    /// It uses a `ttl` of 60, no `auth_ttl`, and `fetched_at` set to now.
    fn media_conn(auth: &str, hosts: &[&str]) -> MediaConn {
        MediaConn {
            auth: auth.to_string(),
            ttl: 60,
            auth_ttl: None,
            hosts: hosts
                .iter()
                .map(|hostname| MediaConnHost::new((*hostname).to_string()))
                .collect(),
            fetched_at: Instant::now(),
        }
    }

    /// Returns the `file_sha256` that `wacore::upload::encrypt_media` computes
    /// for `data`, used here as the expected plaintext hash.
    fn plaintext_sha256(data: &[u8]) -> Vec<u8> {
        wacore::upload::encrypt_media(data, MediaType::Image)
            .expect("hash derivation should succeed")
            .file_sha256
            .to_vec()
    }

    /// Round-trips data through `encrypt_media` and
    /// `DownloadUtils::verify_and_decrypt`, then checks the original bytes are
    /// recovered.
    #[test]
    fn process_downloaded_media_ok() {
        let data = b"Hello media test";
        let enc = wacore::upload::encrypt_media(data, MediaType::Image)
            .expect("encryption should succeed");
        let mut cursor = Cursor::new(Vec::<u8>::new());
        let plaintext = DownloadUtils::verify_and_decrypt(
            &enc.data_to_upload,
            &enc.media_key,
            MediaType::Image,
        )
        .expect("decryption should succeed");
        cursor.write_all(&plaintext).expect("write should succeed");
        assert_eq!(cursor.into_inner(), data);
    }

    /// Flips one bit in the last byte of the encrypted upload and expects
    /// `verify_and_decrypt` to reject it with `InvalidMac`.
    #[test]
    fn process_downloaded_media_bad_mac() {
        let data = b"Tamper";
        let mut enc = wacore::upload::encrypt_media(data, MediaType::Image)
            .expect("encryption should succeed");
        let last = enc.data_to_upload.len() - 1;
        enc.data_to_upload[last] ^= 0x01;

        let err = DownloadUtils::verify_and_decrypt(
            &enc.data_to_upload,
            &enc.media_key,
            MediaType::Image,
        )
        .unwrap_err();

        assert!(
            matches!(&err, wacore::download::MediaDecryptionError::InvalidMac),
            "Expected InvalidMac, got: {}",
            err
        );
    }

    /// Tests the in-memory retry path.
    ///
    /// Requests built from the first connection carry `stale-auth` and fail
    /// with a 401 auth error. The helper must invalidate once, re-prepare with
    /// `force = true`, and succeed against the refreshed `fresh-auth`
    /// connection.
    #[tokio::test]
    async fn download_retries_with_forced_media_conn_refresh_after_auth_error() {
        let body = b"download me".to_vec();
        let downloadable = PlaintextDownloadable {
            direct_path: "/v/t62.7118-24/123".to_string(),
            file_sha256: plaintext_sha256(&body),
        };
        let first_conn = media_conn("stale-auth", &["cdn1.example.com"]);
        let refreshed_conn = media_conn("fresh-auth", &["cdn2.example.com"]);
        let refresh_calls = Arc::new(Mutex::new(Vec::new()));
        let invalidations = Arc::new(Mutex::new(0usize));
        let seen_urls = Arc::new(Mutex::new(Vec::new()));

        let downloaded = download_media_with_retry(
            // prepare_requests: record each `force` flag and build requests
            // from the stale connection (false) or the refreshed one (true).
            {
                let refresh_calls = Arc::clone(&refresh_calls);
                let downloadable = &downloadable;
                move |force| {
                    let refresh_calls = Arc::clone(&refresh_calls);
                    let first_conn = first_conn.clone();
                    let refreshed_conn = refreshed_conn.clone();
                    async move {
                        refresh_calls.lock().await.push(force);
                        let media_conn = if force { refreshed_conn } else { first_conn };
                        DownloadUtils::prepare_download_requests(
                            downloadable,
                            &wacore::download::MediaConnection::from(&media_conn),
                        )
                    }
                }
            },
            // invalidate_media_conn: count invocations.
            {
                let invalidations = Arc::clone(&invalidations);
                move || {
                    let invalidations = Arc::clone(&invalidations);
                    async move {
                        *invalidations.lock().await += 1;
                    }
                }
            },
            // execute_request: record the URL; reject stale auth, serve the
            // body otherwise.
            {
                let seen_urls = Arc::clone(&seen_urls);
                let body = body.clone();
                move |request| {
                    let seen_urls = Arc::clone(&seen_urls);
                    let body = body.clone();
                    let url = request.url.clone();
                    async move {
                        seen_urls.lock().await.push(url.clone());
                        if url.contains("stale-auth") {
                            Err(DownloadRequestError::auth(401))
                        } else {
                            Ok(body)
                        }
                    }
                }
            },
        )
        .await
        .expect("download should succeed after refreshing media auth");

        assert_eq!(downloaded, body);
        assert_eq!(*refresh_calls.lock().await, vec![false, true]);
        assert_eq!(*invalidations.lock().await, 1);

        let seen_urls = seen_urls.lock().await.clone();
        assert_eq!(seen_urls.len(), 2);
        assert!(seen_urls[0].contains("auth=stale-auth"));
        assert!(seen_urls[1].contains("auth=fresh-auth"));
    }

    /// Streaming counterpart of the test above.
    ///
    /// The stale-auth request fails with a 403 auth error after rewinding the
    /// writer. The helper must then refresh once and write the body through
    /// the writer it handed back.
    #[tokio::test]
    async fn download_to_writer_retries_with_forced_media_conn_refresh_after_auth_error() {
        let body = b"stream me".to_vec();
        let downloadable = PlaintextDownloadable {
            direct_path: "/v/t62.7118-24/stream".to_string(),
            file_sha256: plaintext_sha256(&body),
        };
        let first_conn = media_conn("stale-auth", &["cdn1.example.com"]);
        let refreshed_conn = media_conn("fresh-auth", &["cdn2.example.com"]);
        let refresh_calls = Arc::new(Mutex::new(Vec::new()));
        let invalidations = Arc::new(Mutex::new(0usize));
        let seen_urls = Arc::new(Mutex::new(Vec::new()));

        let writer = download_to_writer_with_retry(
            Cursor::new(Vec::<u8>::new()),
            // prepare_requests: record each `force` flag and build requests
            // from the stale connection (false) or the refreshed one (true).
            {
                let refresh_calls = Arc::clone(&refresh_calls);
                let downloadable = &downloadable;
                move |force| {
                    let refresh_calls = Arc::clone(&refresh_calls);
                    let first_conn = first_conn.clone();
                    let refreshed_conn = refreshed_conn.clone();
                    async move {
                        refresh_calls.lock().await.push(force);
                        let media_conn = if force { refreshed_conn } else { first_conn };
                        DownloadUtils::prepare_download_requests(
                            downloadable,
                            &wacore::download::MediaConnection::from(&media_conn),
                        )
                    }
                }
            },
            // invalidate_media_conn: count invocations.
            {
                let invalidations = Arc::clone(&invalidations);
                move || {
                    let invalidations = Arc::clone(&invalidations);
                    async move {
                        *invalidations.lock().await += 1;
                    }
                }
            },
            // execute_request: record the URL, then either report an auth
            // failure or write the body and rewind. The writer is returned in
            // both cases.
            {
                let seen_urls = Arc::clone(&seen_urls);
                let body = body.clone();
                move |request, mut writer| {
                    let seen_urls = Arc::clone(&seen_urls);
                    let body = body.clone();
                    let url = request.url.clone();
                    async move {
                        seen_urls.lock().await.push(url.clone());
                        writer.seek(SeekFrom::Start(0))?;
                        if url.contains("stale-auth") {
                            Ok((writer, Err(DownloadRequestError::auth(403))))
                        } else {
                            writer.write_all(&body)?;
                            writer.seek(SeekFrom::Start(0))?;
                            Ok((writer, Ok(())))
                        }
                    }
                }
            },
        )
        .await
        .expect("streaming download should succeed after refreshing media auth");

        assert_eq!(writer.into_inner(), body);
        assert_eq!(*refresh_calls.lock().await, vec![false, true]);
        assert_eq!(*invalidations.lock().await, 1);

        let seen_urls = seen_urls.lock().await.clone();
        assert_eq!(seen_urls.len(), 2);
        assert!(seen_urls[0].contains("auth=stale-auth"));
        assert!(seen_urls[1].contains("auth=fresh-auth"));
    }
}