stream_download/lib.rs
1#![deny(missing_docs)]
2#![forbid(clippy::unwrap_used)]
3#![cfg_attr(docsrs, feature(doc_auto_cfg))]
4#![doc = include_str!("../README.md")]
5
6use std::fmt::Debug;
7use std::future::{self, Future};
8use std::io::{self, Read, Seek, SeekFrom};
9
10use educe::Educe;
11pub use settings::*;
12use source::handle::SourceHandle;
13use source::{DecodeError, Source, SourceStream};
14use storage::StorageProvider;
15use tokio_util::sync::CancellationToken;
16use tracing::{debug, error, instrument, trace};
17
18#[cfg(feature = "async-read")]
19pub mod async_read;
20#[cfg(feature = "http")]
21pub mod http;
22#[cfg(feature = "process")]
23pub mod process;
24#[cfg(feature = "registry")]
25pub mod registry;
26mod settings;
27pub mod source;
28pub mod storage;
29
/// A handle that can be used to interact with the stream remotely.
#[derive(Debug, Clone)]
pub struct StreamHandle {
    // Token awaited by `wait_for_completion`.
    finished: CancellationToken,
}
35
36impl StreamHandle {
37 /// Wait for the stream download task to complete.
38 ///
39 /// This method can be useful when using a [`ProcessStream`][process::ProcessStream] if you want
40 /// to ensure the subprocess has exited cleanly before continuing.
41 pub async fn wait_for_completion(self) {
42 self.finished.cancelled().await;
43 }
44}
45
/// Represents content streamed from a remote source.
/// This struct implements [read](https://doc.rust-lang.org/stable/std/io/trait.Read.html)
/// and [seek](https://doc.rust-lang.org/stable/std/io/trait.Seek.html)
/// so it can be used as a generic source for libraries and applications that operate on these
/// traits. On creation, an async task is spawned that will immediately start to download the remote
/// content.
///
/// Any read attempts that request part of the stream that hasn't been downloaded yet will block
/// until the requested portion is reached. Any seek attempts that meet the same criteria will
/// result in an additional request to restart the stream download from the seek point.
///
/// If the stream download hasn't completed when this struct is dropped, the task will be cancelled
/// (provided the `cancel_on_drop` setting is enabled).
///
/// If the stream stalls for any reason, the download task will attempt to automatically reconnect.
/// This reconnect interval can be controlled via [`Settings::retry_timeout`].
/// Server-side failures are not automatically handled and should be retried by the supplied
/// [`SourceStream`] implementation if desired.
#[derive(Debug)]
pub struct StreamDownload<P: StorageProvider> {
    // Reader half produced by the storage provider; all reads and seeks are served from it.
    output_reader: P::Reader,
    // Handle to the download source: used to query downloaded ranges, wait for positions,
    // request seeks and check whether the download failed.
    handle: SourceHandle,
    // Cancelled by `cancel_download` and by the download task itself once it finishes.
    download_task_cancellation_token: CancellationToken,
    // Copied from the settings; when true, dropping this struct cancels the download.
    cancel_on_drop: bool,
    // Content length reported by the source stream at creation time, if any.
    content_length: Option<u64>,
    // Maximum capacity reported by the storage provider; `None` means no limit is enforced.
    storage_capacity: Option<usize>,
}
72
73impl<P: StorageProvider> StreamDownload<P> {
74 #[cfg(feature = "reqwest")]
75 /// Creates a new [`StreamDownload`] that accesses an HTTP resource at the given URL.
76 ///
77 /// # Example
78 ///
79 /// ```no_run
80 /// use std::error::Error;
81 /// use std::io::{self, Read};
82 /// use std::result::Result;
83 ///
84 /// use stream_download::source::DecodeError;
85 /// use stream_download::storage::temp::TempStorageProvider;
86 /// use stream_download::{Settings, StreamDownload};
87 ///
88 /// #[tokio::main]
89 /// async fn main() -> Result<(), Box<dyn Error>> {
90 /// let mut reader = match StreamDownload::new_http(
91 /// "https://some-cool-url.com/some-file.mp3".parse()?,
92 /// TempStorageProvider::default(),
93 /// Settings::default(),
94 /// )
95 /// .await
96 /// {
97 /// Ok(reader) => reader,
98 /// Err(e) => return Err(e.decode_error().await)?,
99 /// };
100 ///
101 /// tokio::task::spawn_blocking(move || {
102 /// let mut buf = Vec::new();
103 /// reader.read_to_end(&mut buf)?;
104 /// Ok::<_, io::Error>(())
105 /// })
106 /// .await??;
107 /// Ok(())
108 /// }
109 /// ```
110 pub async fn new_http(
111 url: ::reqwest::Url,
112 storage_provider: P,
113 settings: Settings<http::HttpStream<::reqwest::Client>>,
114 ) -> Result<Self, StreamInitializationError<http::HttpStream<::reqwest::Client>>> {
115 Self::new(url, storage_provider, settings).await
116 }
117
118 /// Creates a new [`StreamDownload`] that accesses an HTTP resource at the given URL.
119 /// It uses the [`reqwest_middleware::ClientWithMiddleware`] client instead of the default
120 /// [`reqwest`] client. Any global middleware set by [`Settings::add_default_middleware`] will
121 /// be automatically applied.
122 ///
123 /// # Example
124 ///
125 /// ```no_run
126 /// use std::error::Error;
127 /// use std::io::{self, Read};
128 /// use std::result::Result;
129 ///
130 /// use reqwest_retry::RetryTransientMiddleware;
131 /// use reqwest_retry::policies::ExponentialBackoff;
132 /// use stream_download::source::DecodeError;
133 /// use stream_download::storage::temp::TempStorageProvider;
134 /// use stream_download::{Settings, StreamDownload};
135 ///
136 /// #[tokio::main]
137 /// async fn main() -> Result<(), Box<dyn Error>> {
138 /// let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
139 /// Settings::add_default_middleware(RetryTransientMiddleware::new_with_policy(retry_policy));
140 ///
141 /// let mut reader = match StreamDownload::new_http_with_middleware(
142 /// "https://some-cool-url.com/some-file.mp3".parse()?,
143 /// TempStorageProvider::default(),
144 /// Settings::default(),
145 /// )
146 /// .await
147 /// {
148 /// Ok(reader) => reader,
149 /// Err(e) => return Err(e.decode_error().await)?,
150 /// };
151 ///
152 /// tokio::task::spawn_blocking(move || {
153 /// let mut buf = Vec::new();
154 /// reader.read_to_end(&mut buf)?;
155 /// Ok::<_, io::Error>(())
156 /// })
157 /// .await??;
158 /// Ok(())
159 /// }
160 /// ```
161 #[cfg(feature = "reqwest-middleware")]
162 pub async fn new_http_with_middleware(
163 url: ::reqwest::Url,
164 storage_provider: P,
165 settings: Settings<http::HttpStream<::reqwest_middleware::ClientWithMiddleware>>,
166 ) -> Result<
167 Self,
168 StreamInitializationError<http::HttpStream<::reqwest_middleware::ClientWithMiddleware>>,
169 > {
170 Self::new(url, storage_provider, settings).await
171 }
172
173 /// Creates a new [`StreamDownload`] that uses an [`AsyncRead`][tokio::io::AsyncRead] resource.
174 ///
175 /// # Example reading from `stdin`
176 ///
177 /// ```no_run
178 /// use std::error::Error;
179 /// use std::io::{self, Read};
180 /// use std::result::Result;
181 ///
182 /// use stream_download::async_read::AsyncReadStreamParams;
183 /// use stream_download::storage::temp::TempStorageProvider;
184 /// use stream_download::{Settings, StreamDownload};
185 ///
186 /// #[tokio::main]
187 /// async fn main() -> Result<(), Box<dyn Error>> {
188 /// let mut reader = StreamDownload::new_async_read(
189 /// AsyncReadStreamParams::new(tokio::io::stdin()),
190 /// TempStorageProvider::new(),
191 /// Settings::default(),
192 /// )
193 /// .await?;
194 ///
195 /// tokio::task::spawn_blocking(move || {
196 /// let mut buf = Vec::new();
197 /// reader.read_to_end(&mut buf)?;
198 /// Ok::<_, io::Error>(())
199 /// })
200 /// .await??;
201 /// Ok(())
202 /// }
203 /// ```
204 #[cfg(feature = "async-read")]
205 pub async fn new_async_read<T>(
206 params: async_read::AsyncReadStreamParams<T>,
207 storage_provider: P,
208 settings: Settings<async_read::AsyncReadStream<T>>,
209 ) -> Result<Self, StreamInitializationError<async_read::AsyncReadStream<T>>>
210 where
211 T: tokio::io::AsyncRead + Send + Sync + Unpin + 'static,
212 {
213 Self::new(params, storage_provider, settings).await
214 }
215
216 /// Creates a new [`StreamDownload`] that uses a [`Command`][process::Command] as input.
217 ///
218 /// # Example
219 ///
220 /// ```no_run
221 /// use std::error::Error;
222 /// use std::io::{self, Read};
223 /// use std::result::Result;
224 ///
225 /// use stream_download::process::{Command, ProcessStreamParams};
226 /// use stream_download::storage::temp::TempStorageProvider;
227 /// use stream_download::{Settings, StreamDownload};
228 ///
229 /// #[tokio::main]
230 /// async fn main() -> Result<(), Box<dyn Error>> {
231 /// let mut reader = StreamDownload::new_process(
232 /// ProcessStreamParams::new(Command::new("cat").args(["./assets/music.mp3"]))?,
233 /// TempStorageProvider::new(),
234 /// Settings::default(),
235 /// )
236 /// .await?;
237 ///
238 /// tokio::task::spawn_blocking(move || {
239 /// let mut buf = Vec::new();
240 /// reader.read_to_end(&mut buf)?;
241 /// Ok::<_, io::Error>(())
242 /// })
243 /// .await??;
244 /// Ok(())
245 /// }
246 /// ```
247 #[cfg(feature = "process")]
248 pub async fn new_process(
249 params: process::ProcessStreamParams,
250 storage_provider: P,
251 settings: Settings<process::ProcessStream>,
252 ) -> Result<Self, StreamInitializationError<process::ProcessStream>> {
253 Self::new(params, storage_provider, settings).await
254 }
255
256 /// Creates a new [`StreamDownload`] that accesses a remote resource at the given URL.
257 ///
258 /// # Example
259 ///
260 /// ```no_run
261 /// use std::error::Error;
262 /// use std::io::{self, Read};
263 /// use std::result::Result;
264 ///
265 /// use reqwest::Client;
266 /// use stream_download::http::HttpStream;
267 /// use stream_download::storage::temp::TempStorageProvider;
268 /// use stream_download::{Settings, StreamDownload};
269 ///
270 /// use crate::stream_download::source::DecodeError;
271 ///
272 /// #[tokio::main]
273 /// async fn main() -> Result<(), Box<dyn Error>> {
274 /// let mut reader = match StreamDownload::new::<HttpStream<Client>>(
275 /// "https://some-cool-url.com/some-file.mp3".parse()?,
276 /// TempStorageProvider::default(),
277 /// Settings::default(),
278 /// )
279 /// .await
280 /// {
281 /// Ok(reader) => reader,
282 /// Err(e) => return Err(e.decode_error().await)?,
283 /// };
284 ///
285 /// tokio::task::spawn_blocking(move || {
286 /// let mut buf = Vec::new();
287 /// reader.read_to_end(&mut buf)?;
288 /// Ok::<_, io::Error>(())
289 /// })
290 /// .await??;
291 /// Ok(())
292 /// }
293 /// ```
294 pub async fn new<S>(
295 params: S::Params,
296 storage_provider: P,
297 settings: Settings<S>,
298 ) -> Result<Self, StreamInitializationError<S>>
299 where
300 S: SourceStream,
301 S::Error: Debug + Send,
302 {
303 Self::from_create_stream(move || S::create(params), storage_provider, settings).await
304 }
305
306 /// Creates a new [`StreamDownload`] from a [`SourceStream`].
307 ///
308 /// # Example
309 ///
310 /// ```no_run
311 /// use std::error::Error;
312 /// use std::io::Read;
313 /// use std::result::Result;
314 ///
315 /// use reqwest::Client;
316 /// use stream_download::http::HttpStream;
317 /// use stream_download::storage::temp::TempStorageProvider;
318 /// use stream_download::{Settings, StreamDownload};
319 ///
320 /// use crate::stream_download::source::DecodeError;
321 ///
322 /// #[tokio::main]
323 /// async fn main() -> Result<(), Box<dyn Error>> {
324 /// let stream = HttpStream::new(
325 /// Client::new(),
326 /// "https://some-cool-url.com/some-file.mp3".parse()?,
327 /// )
328 /// .await?;
329 ///
330 /// let mut reader = match StreamDownload::from_stream(
331 /// stream,
332 /// TempStorageProvider::default(),
333 /// Settings::default(),
334 /// )
335 /// .await
336 /// {
337 /// Ok(reader) => reader,
338 /// Err(e) => Err(e.decode_error().await)?,
339 /// };
340 /// Ok(())
341 /// }
342 /// ```
343 pub async fn from_stream<S>(
344 stream: S,
345 storage_provider: P,
346 settings: Settings<S>,
347 ) -> Result<Self, StreamInitializationError<S>>
348 where
349 S: SourceStream,
350 S::Error: Debug + Send,
351 {
352 Self::from_create_stream(
353 move || future::ready(Ok(stream)),
354 storage_provider,
355 settings,
356 )
357 .await
358 }
359
360 /// Cancels the background task that's downloading the stream content.
361 /// This has no effect if the download is already completed.
362 pub fn cancel_download(&self) {
363 self.download_task_cancellation_token.cancel();
364 }
365
366 /// Returns the [`CancellationToken`] for the download task.
367 /// This can be used to cancel the download task before it completes.
368 pub fn cancellation_token(&self) -> CancellationToken {
369 self.download_task_cancellation_token.clone()
370 }
371
372 /// Returns a [`StreamHandle`] that can be used to interact with
373 /// the stream remotely.
374 pub fn handle(&self) -> StreamHandle {
375 StreamHandle {
376 finished: self.download_task_cancellation_token.clone(),
377 }
378 }
379
    /// Returns the content length of the stream, if available.
    ///
    /// This is the value the source stream reported when the download was created.
    pub fn content_length(&self) -> Option<u64> {
        self.content_length
    }
384
385 async fn from_create_stream<S, F, Fut>(
386 create_stream: F,
387 storage_provider: P,
388 settings: Settings<S>,
389 ) -> Result<Self, StreamInitializationError<S>>
390 where
391 S: SourceStream<Error: Debug + Send>,
392 F: FnOnce() -> Fut + Send + 'static,
393 Fut: Future<Output = Result<S, S::StreamCreationError>> + Send,
394 {
395 let stream = create_stream()
396 .await
397 .map_err(StreamInitializationError::StreamCreationFailure)?;
398 let content_length = stream.content_length();
399 let storage_capacity = storage_provider.max_capacity();
400 let (reader, writer) = storage_provider
401 .into_reader_writer(content_length)
402 .map_err(StreamInitializationError::StorageCreationFailure)?;
403 let cancellation_token = CancellationToken::new();
404 let cancel_on_drop = settings.cancel_on_drop;
405 let mut source = Source::new(writer, content_length, settings, cancellation_token.clone());
406 let handle = source.source_handle();
407
408 tokio::spawn({
409 let cancellation_token = cancellation_token.clone();
410 async move {
411 source.download(stream).await;
412 cancellation_token.cancel();
413 debug!("download task finished");
414 }
415 });
416
417 Ok(Self {
418 output_reader: reader,
419 handle,
420 download_task_cancellation_token: cancellation_token,
421 cancel_on_drop,
422 content_length,
423 storage_capacity,
424 })
425 }
426
427 fn get_absolute_seek_position(&mut self, relative_position: SeekFrom) -> io::Result<u64> {
428 Ok(match relative_position {
429 SeekFrom::Start(position) => {
430 debug!(seek_position = position, "seeking from start");
431 position
432 }
433 SeekFrom::Current(position) => {
434 debug!(seek_position = position, "seeking from current position");
435 (self.output_reader.stream_position()? as i64 + position) as u64
436 }
437 SeekFrom::End(position) => {
438 debug!(seek_position = position, "seeking from end");
439 if let Some(length) = self.handle.content_length() {
440 (length as i64 + position) as u64
441 } else {
442 return Err(io::Error::new(
443 io::ErrorKind::Unsupported,
444 "cannot seek from end when content length is unknown",
445 ));
446 }
447 }
448 })
449 }
450
451 fn handle_read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
452 let res = self.output_reader.read(buf).inspect(|l| {
453 trace!(read_length = format!("{l:?}"), "returning read");
454 });
455 self.handle.notify_read();
456 res
457 }
458
459 fn normalize_requested_position(&self, requested_position: u64) -> u64 {
460 if let Some(content_length) = self.content_length {
461 // ensure we don't request a position beyond the end of the stream
462 requested_position.min(content_length)
463 } else {
464 requested_position
465 }
466 }
467
468 fn check_for_failure(&self) -> io::Result<()> {
469 if self.handle.is_failed() {
470 Err(io::Error::other("stream failed to download"))
471 } else {
472 Ok(())
473 }
474 }
475
476 fn check_for_excessive_read(&self, buf_len: usize) -> io::Result<()> {
477 // Ensure the buffer fits within the storage capacity.
478 // We could get around this from erroring by breaking this into multiple smaller reads, but
479 // if you're using a bounded storage type, that's probably not what you want.
480 let capacity = self.storage_capacity.unwrap_or(usize::MAX);
481 if buf_len > capacity {
482 Err(io::Error::new(
483 io::ErrorKind::InvalidInput,
484 format!("buffer size {buf_len} exceeds the max capacity of {capacity}",),
485 ))
486 } else {
487 Ok(())
488 }
489 }
490
491 fn check_for_excessive_seek(&mut self, absolute_seek_position: u64) -> io::Result<()> {
492 // Ensure the seek position is within the available storage capacity.
493 // We could get around this by issuing a few read requests until the seek position is within
494 // bounds, but if you're using a bounded storage type, that's probably not what you want.
495 if let Some(max_capacity) = self.storage_capacity {
496 let max_possible_seek_position = self
497 .output_reader
498 .stream_position()?
499 .saturating_add(max_capacity as u64);
500 if absolute_seek_position
501 > self
502 .output_reader
503 .stream_position()?
504 .saturating_add(max_capacity as u64)
505 {
506 return Err(io::Error::new(
507 io::ErrorKind::InvalidInput,
508 format!(
509 "seek position {absolute_seek_position} exceeds maximum of \
510 {max_possible_seek_position}"
511 ),
512 ));
513 }
514 }
515 Ok(())
516 }
517}
518
/// Error returned when initializing a stream.
///
/// Initialization can fail either while creating the source stream or while setting up the
/// storage that backs it.
#[derive(thiserror::Error, Educe)]
#[educe(Debug)]
pub enum StreamInitializationError<S: SourceStream> {
    /// Storage creation failure.
    ///
    /// Carries the I/O error produced while turning the storage provider into its reader and
    /// writer.
    #[error("Storage creation failure: {0}")]
    StorageCreationFailure(io::Error),
    /// Stream creation failure.
    ///
    /// Carries the source stream's own creation error.
    #[error("Stream creation failure: {0}")]
    StreamCreationFailure(<S as SourceStream>::StreamCreationError),
}
530
531impl<S: SourceStream> DecodeError for StreamInitializationError<S> {
532 async fn decode_error(self) -> String {
533 match self {
534 this @ Self::StorageCreationFailure(_) => this.to_string(),
535 Self::StreamCreationFailure(e) => e.decode_error().await,
536 }
537 }
538}
539
540impl<P: StorageProvider> Drop for StreamDownload<P> {
541 fn drop(&mut self) {
542 if self.cancel_on_drop {
543 self.cancel_download();
544 }
545 }
546}
547
548impl<P: StorageProvider> Read for StreamDownload<P> {
549 #[instrument(skip_all, fields(len=buf.len()))]
550 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
551 self.check_for_failure()?;
552 self.check_for_excessive_read(buf.len())?;
553
554 trace!(buffer_length = buf.len(), "read requested");
555 let stream_position = self.output_reader.stream_position()?;
556 let requested_position =
557 self.normalize_requested_position(stream_position + buf.len() as u64);
558 trace!(
559 current_position = stream_position,
560 requested_position = requested_position
561 );
562
563 if let Some(closest_set) = self.handle.get_downloaded_at_position(stream_position) {
564 trace!(
565 downloaded_range = format!("{closest_set:?}"),
566 "current position already downloaded"
567 );
568 if closest_set.end >= requested_position {
569 trace!("requested position already downloaded");
570 return self.handle_read(buf);
571 }
572 debug!("requested position not yet downloaded");
573 } else {
574 debug!("stream position not yet downloaded");
575 }
576
577 self.handle.wait_for_position(requested_position);
578 self.check_for_failure()?;
579 debug!(
580 current_position = stream_position,
581 requested_position = requested_position,
582 output_stream_position = self.output_reader.stream_position()?,
583 "reached requested position"
584 );
585
586 self.handle_read(buf)
587 }
588}
589
590impl<P: StorageProvider> Seek for StreamDownload<P> {
591 #[instrument(skip(self))]
592 fn seek(&mut self, relative_position: SeekFrom) -> io::Result<u64> {
593 self.check_for_failure()?;
594
595 let absolute_seek_position = self.get_absolute_seek_position(relative_position)?;
596 let absolute_seek_position = self.normalize_requested_position(absolute_seek_position);
597 self.check_for_excessive_seek(absolute_seek_position)?;
598
599 debug!(absolute_seek_position, "absolute seek position");
600 if let Some(closest_set) = self
601 .handle
602 .get_downloaded_at_position(absolute_seek_position)
603 {
604 debug!(
605 downloaded_range = format!("{closest_set:?}"),
606 "seek position already downloaded"
607 );
608 return self
609 .output_reader
610 .seek(SeekFrom::Start(absolute_seek_position))
611 .inspect_err(|p| debug!(position = format!("{p:?}"), "returning seek position"));
612 }
613
614 self.handle.seek(absolute_seek_position);
615 self.check_for_failure()?;
616 debug!("reached seek position");
617
618 self.output_reader
619 .seek(SeekFrom::Start(absolute_seek_position))
620 .inspect_err(|p| debug!(position = format!("{p:?}"), "returning seek position"))
621 }
622}
623
/// Extension trait for adding context to a failed [`io::Result`].
pub(crate) trait WrapIoResult {
    /// Prefixes the error message with `msg` (as `"{msg}: {error}"`), keeping the original
    /// error kind. Success values are returned untouched.
    fn wrap_err(self, msg: &str) -> Self;
}

impl<T> WrapIoResult for io::Result<T> {
    fn wrap_err(self, msg: &str) -> Self {
        // Rebuild the error with the same kind so callers can still match on it.
        self.map_err(|source| io::Error::new(source.kind(), format!("{msg}: {source}")))
    }
}