stream_download/
lib.rs

1#![deny(missing_docs)]
2#![forbid(clippy::unwrap_used)]
3#![cfg_attr(docsrs, feature(doc_auto_cfg))]
4#![doc = include_str!("../README.md")]
5
6use std::fmt::Debug;
7use std::future::{self, Future};
8use std::io::{self, Read, Seek, SeekFrom};
9
10use educe::Educe;
11pub use settings::*;
12use source::handle::SourceHandle;
13use source::{DecodeError, Source, SourceStream};
14use storage::StorageProvider;
15use tokio_util::sync::CancellationToken;
16use tracing::{debug, error, instrument, trace};
17
18#[cfg(feature = "async-read")]
19pub mod async_read;
20#[cfg(feature = "http")]
21pub mod http;
22#[cfg(feature = "process")]
23pub mod process;
24#[cfg(feature = "registry")]
25pub mod registry;
26mod settings;
27pub mod source;
28pub mod storage;
29
/// A handle that can be used to interact with the stream remotely.
///
/// Obtain one with [`StreamDownload::handle`].
#[derive(Debug, Clone)]
pub struct StreamHandle {
    // Clone of the download task's cancellation token. The task cancels it when the download
    // returns, and it is also cancelled when the download is cancelled explicitly.
    finished: CancellationToken,
}
35
36impl StreamHandle {
37    /// Wait for the stream download task to complete.
38    ///
39    /// This method can be useful when using a [`ProcessStream`][process::ProcessStream] if you want
40    /// to ensure the subprocess has exited cleanly before continuing.
41    pub async fn wait_for_completion(self) {
42        self.finished.cancelled().await;
43    }
44}
45
/// Represents content streamed from a remote source.
/// This struct implements [read](https://doc.rust-lang.org/stable/std/io/trait.Read.html)
/// and [seek](https://doc.rust-lang.org/stable/std/io/trait.Seek.html)
/// so it can be used as a generic source for libraries and applications that operate on these
/// traits. On creation, an async task is spawned that will immediately start to download the remote
/// content.
///
/// Any read attempts that request part of the stream that hasn't been downloaded yet will block
/// until the requested portion is reached. Any seek attempts that meet the same criteria will
/// result in an additional request to restart the stream download from the seek point.
///
/// If the stream download hasn't completed when this struct is dropped, the task will be cancelled
/// (unless cancel-on-drop has been disabled in the [`Settings`]).
///
/// If the stream stalls for any reason, the download task will attempt to automatically reconnect.
/// This reconnect interval can be controlled via [`Settings::retry_timeout`].
/// Server-side failures are not automatically handled and should be retried by the supplied
/// [`SourceStream`] implementation if desired.
#[derive(Debug)]
pub struct StreamDownload<P: StorageProvider> {
    // Reader half produced by the storage provider; backs the `Read` and `Seek` impls.
    output_reader: P::Reader,
    // Handle obtained from the `Source` that drives the download task.
    handle: SourceHandle,
    // Cancelling this token stops the download task; the task also cancels it when it finishes.
    download_task_cancellation_token: CancellationToken,
    // Copied from the settings; when `true`, dropping this struct cancels the download.
    cancel_on_drop: bool,
    // Content length reported by the source stream at creation time, if any.
    content_length: Option<u64>,
    // Maximum capacity reported by the storage provider, if any.
    storage_capacity: Option<usize>,
}
72
73impl<P: StorageProvider> StreamDownload<P> {
74    #[cfg(feature = "reqwest")]
75    /// Creates a new [`StreamDownload`] that accesses an HTTP resource at the given URL.
76    ///
77    /// # Example
78    ///
79    /// ```no_run
80    /// use std::error::Error;
81    /// use std::io::{self, Read};
82    /// use std::result::Result;
83    ///
84    /// use stream_download::source::DecodeError;
85    /// use stream_download::storage::temp::TempStorageProvider;
86    /// use stream_download::{Settings, StreamDownload};
87    ///
88    /// #[tokio::main]
89    /// async fn main() -> Result<(), Box<dyn Error>> {
90    ///     let mut reader = match StreamDownload::new_http(
91    ///         "https://some-cool-url.com/some-file.mp3".parse()?,
92    ///         TempStorageProvider::default(),
93    ///         Settings::default(),
94    ///     )
95    ///     .await
96    ///     {
97    ///         Ok(reader) => reader,
98    ///         Err(e) => return Err(e.decode_error().await)?,
99    ///     };
100    ///
101    ///     tokio::task::spawn_blocking(move || {
102    ///         let mut buf = Vec::new();
103    ///         reader.read_to_end(&mut buf)?;
104    ///         Ok::<_, io::Error>(())
105    ///     })
106    ///     .await??;
107    ///     Ok(())
108    /// }
109    /// ```
110    pub async fn new_http(
111        url: ::reqwest::Url,
112        storage_provider: P,
113        settings: Settings<http::HttpStream<::reqwest::Client>>,
114    ) -> Result<Self, StreamInitializationError<http::HttpStream<::reqwest::Client>>> {
115        Self::new(url, storage_provider, settings).await
116    }
117
118    /// Creates a new [`StreamDownload`] that accesses an HTTP resource at the given URL.
119    /// It uses the [`reqwest_middleware::ClientWithMiddleware`] client instead of the default
120    /// [`reqwest`] client. Any global middleware set by [`Settings::add_default_middleware`] will
121    /// be automatically applied.
122    ///
123    /// # Example
124    ///
125    /// ```no_run
126    /// use std::error::Error;
127    /// use std::io::{self, Read};
128    /// use std::result::Result;
129    ///
130    /// use reqwest_retry::RetryTransientMiddleware;
131    /// use reqwest_retry::policies::ExponentialBackoff;
132    /// use stream_download::source::DecodeError;
133    /// use stream_download::storage::temp::TempStorageProvider;
134    /// use stream_download::{Settings, StreamDownload};
135    ///
136    /// #[tokio::main]
137    /// async fn main() -> Result<(), Box<dyn Error>> {
138    ///     let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
139    ///     Settings::add_default_middleware(RetryTransientMiddleware::new_with_policy(retry_policy));
140    ///
141    ///     let mut reader = match StreamDownload::new_http_with_middleware(
142    ///         "https://some-cool-url.com/some-file.mp3".parse()?,
143    ///         TempStorageProvider::default(),
144    ///         Settings::default(),
145    ///     )
146    ///     .await
147    ///     {
148    ///         Ok(reader) => reader,
149    ///         Err(e) => return Err(e.decode_error().await)?,
150    ///     };
151    ///
152    ///     tokio::task::spawn_blocking(move || {
153    ///         let mut buf = Vec::new();
154    ///         reader.read_to_end(&mut buf)?;
155    ///         Ok::<_, io::Error>(())
156    ///     })
157    ///     .await??;
158    ///     Ok(())
159    /// }
160    /// ```
161    #[cfg(feature = "reqwest-middleware")]
162    pub async fn new_http_with_middleware(
163        url: ::reqwest::Url,
164        storage_provider: P,
165        settings: Settings<http::HttpStream<::reqwest_middleware::ClientWithMiddleware>>,
166    ) -> Result<
167        Self,
168        StreamInitializationError<http::HttpStream<::reqwest_middleware::ClientWithMiddleware>>,
169    > {
170        Self::new(url, storage_provider, settings).await
171    }
172
173    /// Creates a new [`StreamDownload`] that uses an [`AsyncRead`][tokio::io::AsyncRead] resource.
174    ///
175    /// # Example reading from `stdin`
176    ///
177    /// ```no_run
178    /// use std::error::Error;
179    /// use std::io::{self, Read};
180    /// use std::result::Result;
181    ///
182    /// use stream_download::async_read::AsyncReadStreamParams;
183    /// use stream_download::storage::temp::TempStorageProvider;
184    /// use stream_download::{Settings, StreamDownload};
185    ///
186    /// #[tokio::main]
187    /// async fn main() -> Result<(), Box<dyn Error>> {
188    ///     let mut reader = StreamDownload::new_async_read(
189    ///         AsyncReadStreamParams::new(tokio::io::stdin()),
190    ///         TempStorageProvider::new(),
191    ///         Settings::default(),
192    ///     )
193    ///     .await?;
194    ///
195    ///     tokio::task::spawn_blocking(move || {
196    ///         let mut buf = Vec::new();
197    ///         reader.read_to_end(&mut buf)?;
198    ///         Ok::<_, io::Error>(())
199    ///     })
200    ///     .await??;
201    ///     Ok(())
202    /// }
203    /// ```
204    #[cfg(feature = "async-read")]
205    pub async fn new_async_read<T>(
206        params: async_read::AsyncReadStreamParams<T>,
207        storage_provider: P,
208        settings: Settings<async_read::AsyncReadStream<T>>,
209    ) -> Result<Self, StreamInitializationError<async_read::AsyncReadStream<T>>>
210    where
211        T: tokio::io::AsyncRead + Send + Sync + Unpin + 'static,
212    {
213        Self::new(params, storage_provider, settings).await
214    }
215
216    /// Creates a new [`StreamDownload`] that uses a [`Command`][process::Command] as input.
217    ///
218    /// # Example
219    ///
220    /// ```no_run
221    /// use std::error::Error;
222    /// use std::io::{self, Read};
223    /// use std::result::Result;
224    ///
225    /// use stream_download::process::{Command, ProcessStreamParams};
226    /// use stream_download::storage::temp::TempStorageProvider;
227    /// use stream_download::{Settings, StreamDownload};
228    ///
229    /// #[tokio::main]
230    /// async fn main() -> Result<(), Box<dyn Error>> {
231    ///     let mut reader = StreamDownload::new_process(
232    ///         ProcessStreamParams::new(Command::new("cat").args(["./assets/music.mp3"]))?,
233    ///         TempStorageProvider::new(),
234    ///         Settings::default(),
235    ///     )
236    ///     .await?;
237    ///
238    ///     tokio::task::spawn_blocking(move || {
239    ///         let mut buf = Vec::new();
240    ///         reader.read_to_end(&mut buf)?;
241    ///         Ok::<_, io::Error>(())
242    ///     })
243    ///     .await??;
244    ///     Ok(())
245    /// }
246    /// ```
247    #[cfg(feature = "process")]
248    pub async fn new_process(
249        params: process::ProcessStreamParams,
250        storage_provider: P,
251        settings: Settings<process::ProcessStream>,
252    ) -> Result<Self, StreamInitializationError<process::ProcessStream>> {
253        Self::new(params, storage_provider, settings).await
254    }
255
256    /// Creates a new [`StreamDownload`] that accesses a remote resource at the given URL.
257    ///
258    /// # Example
259    ///
260    /// ```no_run
261    /// use std::error::Error;
262    /// use std::io::{self, Read};
263    /// use std::result::Result;
264    ///
265    /// use reqwest::Client;
266    /// use stream_download::http::HttpStream;
267    /// use stream_download::storage::temp::TempStorageProvider;
268    /// use stream_download::{Settings, StreamDownload};
269    ///
270    /// use crate::stream_download::source::DecodeError;
271    ///
272    /// #[tokio::main]
273    /// async fn main() -> Result<(), Box<dyn Error>> {
274    ///     let mut reader = match StreamDownload::new::<HttpStream<Client>>(
275    ///         "https://some-cool-url.com/some-file.mp3".parse()?,
276    ///         TempStorageProvider::default(),
277    ///         Settings::default(),
278    ///     )
279    ///     .await
280    ///     {
281    ///         Ok(reader) => reader,
282    ///         Err(e) => return Err(e.decode_error().await)?,
283    ///     };
284    ///
285    ///     tokio::task::spawn_blocking(move || {
286    ///         let mut buf = Vec::new();
287    ///         reader.read_to_end(&mut buf)?;
288    ///         Ok::<_, io::Error>(())
289    ///     })
290    ///     .await??;
291    ///     Ok(())
292    /// }
293    /// ```
294    pub async fn new<S>(
295        params: S::Params,
296        storage_provider: P,
297        settings: Settings<S>,
298    ) -> Result<Self, StreamInitializationError<S>>
299    where
300        S: SourceStream,
301        S::Error: Debug + Send,
302    {
303        Self::from_create_stream(move || S::create(params), storage_provider, settings).await
304    }
305
306    /// Creates a new [`StreamDownload`] from a [`SourceStream`].
307    ///
308    /// # Example
309    ///
310    /// ```no_run
311    /// use std::error::Error;
312    /// use std::io::Read;
313    /// use std::result::Result;
314    ///
315    /// use reqwest::Client;
316    /// use stream_download::http::HttpStream;
317    /// use stream_download::storage::temp::TempStorageProvider;
318    /// use stream_download::{Settings, StreamDownload};
319    ///
320    /// use crate::stream_download::source::DecodeError;
321    ///
322    /// #[tokio::main]
323    /// async fn main() -> Result<(), Box<dyn Error>> {
324    ///     let stream = HttpStream::new(
325    ///         Client::new(),
326    ///         "https://some-cool-url.com/some-file.mp3".parse()?,
327    ///     )
328    ///     .await?;
329    ///
330    ///     let mut reader = match StreamDownload::from_stream(
331    ///         stream,
332    ///         TempStorageProvider::default(),
333    ///         Settings::default(),
334    ///     )
335    ///     .await
336    ///     {
337    ///         Ok(reader) => reader,
338    ///         Err(e) => Err(e.decode_error().await)?,
339    ///     };
340    ///     Ok(())
341    /// }
342    /// ```
343    pub async fn from_stream<S>(
344        stream: S,
345        storage_provider: P,
346        settings: Settings<S>,
347    ) -> Result<Self, StreamInitializationError<S>>
348    where
349        S: SourceStream,
350        S::Error: Debug + Send,
351    {
352        Self::from_create_stream(
353            move || future::ready(Ok(stream)),
354            storage_provider,
355            settings,
356        )
357        .await
358    }
359
360    /// Cancels the background task that's downloading the stream content.
361    /// This has no effect if the download is already completed.
362    pub fn cancel_download(&self) {
363        self.download_task_cancellation_token.cancel();
364    }
365
366    /// Returns the [`CancellationToken`] for the download task.
367    /// This can be used to cancel the download task before it completes.
368    pub fn cancellation_token(&self) -> CancellationToken {
369        self.download_task_cancellation_token.clone()
370    }
371
372    /// Returns a [`StreamHandle`] that can be used to interact with
373    /// the stream remotely.
374    pub fn handle(&self) -> StreamHandle {
375        StreamHandle {
376            finished: self.download_task_cancellation_token.clone(),
377        }
378    }
379
380    /// Returns the content length of the stream, if available.
381    pub fn content_length(&self) -> Option<u64> {
382        self.content_length
383    }
384
385    async fn from_create_stream<S, F, Fut>(
386        create_stream: F,
387        storage_provider: P,
388        settings: Settings<S>,
389    ) -> Result<Self, StreamInitializationError<S>>
390    where
391        S: SourceStream<Error: Debug + Send>,
392        F: FnOnce() -> Fut + Send + 'static,
393        Fut: Future<Output = Result<S, S::StreamCreationError>> + Send,
394    {
395        let stream = create_stream()
396            .await
397            .map_err(StreamInitializationError::StreamCreationFailure)?;
398        let content_length = stream.content_length();
399        let storage_capacity = storage_provider.max_capacity();
400        let (reader, writer) = storage_provider
401            .into_reader_writer(content_length)
402            .map_err(StreamInitializationError::StorageCreationFailure)?;
403        let cancellation_token = CancellationToken::new();
404        let cancel_on_drop = settings.cancel_on_drop;
405        let mut source = Source::new(writer, content_length, settings, cancellation_token.clone());
406        let handle = source.source_handle();
407
408        tokio::spawn({
409            let cancellation_token = cancellation_token.clone();
410            async move {
411                source.download(stream).await;
412                cancellation_token.cancel();
413                debug!("download task finished");
414            }
415        });
416
417        Ok(Self {
418            output_reader: reader,
419            handle,
420            download_task_cancellation_token: cancellation_token,
421            cancel_on_drop,
422            content_length,
423            storage_capacity,
424        })
425    }
426
427    fn get_absolute_seek_position(&mut self, relative_position: SeekFrom) -> io::Result<u64> {
428        Ok(match relative_position {
429            SeekFrom::Start(position) => {
430                debug!(seek_position = position, "seeking from start");
431                position
432            }
433            SeekFrom::Current(position) => {
434                debug!(seek_position = position, "seeking from current position");
435                (self.output_reader.stream_position()? as i64 + position) as u64
436            }
437            SeekFrom::End(position) => {
438                debug!(seek_position = position, "seeking from end");
439                if let Some(length) = self.handle.content_length() {
440                    (length as i64 + position) as u64
441                } else {
442                    return Err(io::Error::new(
443                        io::ErrorKind::Unsupported,
444                        "cannot seek from end when content length is unknown",
445                    ));
446                }
447            }
448        })
449    }
450
451    fn handle_read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
452        let res = self.output_reader.read(buf).inspect(|l| {
453            trace!(read_length = format!("{l:?}"), "returning read");
454        });
455        self.handle.notify_read();
456        res
457    }
458
459    fn normalize_requested_position(&self, requested_position: u64) -> u64 {
460        if let Some(content_length) = self.content_length {
461            // ensure we don't request a position beyond the end of the stream
462            requested_position.min(content_length)
463        } else {
464            requested_position
465        }
466    }
467
468    fn check_for_failure(&self) -> io::Result<()> {
469        if self.handle.is_failed() {
470            Err(io::Error::other("stream failed to download"))
471        } else {
472            Ok(())
473        }
474    }
475
476    fn check_for_excessive_read(&self, buf_len: usize) -> io::Result<()> {
477        // Ensure the buffer fits within the storage capacity.
478        // We could get around this from erroring by breaking this into multiple smaller reads, but
479        // if you're using a bounded storage type, that's probably not what you want.
480        let capacity = self.storage_capacity.unwrap_or(usize::MAX);
481        if buf_len > capacity {
482            Err(io::Error::new(
483                io::ErrorKind::InvalidInput,
484                format!("buffer size {buf_len} exceeds the max capacity of {capacity}",),
485            ))
486        } else {
487            Ok(())
488        }
489    }
490
491    fn check_for_excessive_seek(&mut self, absolute_seek_position: u64) -> io::Result<()> {
492        // Ensure the seek position is within the available storage capacity.
493        // We could get around this by issuing a few read requests until the seek position is within
494        // bounds, but if you're using a bounded storage type, that's probably not what you want.
495        if let Some(max_capacity) = self.storage_capacity {
496            let max_possible_seek_position = self
497                .output_reader
498                .stream_position()?
499                .saturating_add(max_capacity as u64);
500            if absolute_seek_position
501                > self
502                    .output_reader
503                    .stream_position()?
504                    .saturating_add(max_capacity as u64)
505            {
506                return Err(io::Error::new(
507                    io::ErrorKind::InvalidInput,
508                    format!(
509                        "seek position {absolute_seek_position} exceeds maximum of \
510                         {max_possible_seek_position}"
511                    ),
512                ));
513            }
514        }
515        Ok(())
516    }
517}
518
/// Error returned when initializing a stream.
///
/// The type parameter `S` is the [`SourceStream`] whose creation was attempted; it determines the
/// payload type of [`StreamCreationFailure`][Self::StreamCreationFailure].
#[derive(thiserror::Error, Educe)]
#[educe(Debug)]
pub enum StreamInitializationError<S: SourceStream> {
    /// The storage provider failed to create its reader/writer pair.
    #[error("Storage creation failure: {0}")]
    StorageCreationFailure(io::Error),
    /// The source stream itself could not be created.
    #[error("Stream creation failure: {0}")]
    StreamCreationFailure(<S as SourceStream>::StreamCreationError),
}
530
531impl<S: SourceStream> DecodeError for StreamInitializationError<S> {
532    async fn decode_error(self) -> String {
533        match self {
534            this @ Self::StorageCreationFailure(_) => this.to_string(),
535            Self::StreamCreationFailure(e) => e.decode_error().await,
536        }
537    }
538}
539
540impl<P: StorageProvider> Drop for StreamDownload<P> {
541    fn drop(&mut self) {
542        if self.cancel_on_drop {
543            self.cancel_download();
544        }
545    }
546}
547
548impl<P: StorageProvider> Read for StreamDownload<P> {
549    #[instrument(skip_all, fields(len=buf.len()))]
550    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
551        self.check_for_failure()?;
552        self.check_for_excessive_read(buf.len())?;
553
554        trace!(buffer_length = buf.len(), "read requested");
555        let stream_position = self.output_reader.stream_position()?;
556        let requested_position =
557            self.normalize_requested_position(stream_position + buf.len() as u64);
558        trace!(
559            current_position = stream_position,
560            requested_position = requested_position
561        );
562
563        if let Some(closest_set) = self.handle.get_downloaded_at_position(stream_position) {
564            trace!(
565                downloaded_range = format!("{closest_set:?}"),
566                "current position already downloaded"
567            );
568            if closest_set.end >= requested_position {
569                trace!("requested position already downloaded");
570                return self.handle_read(buf);
571            }
572            debug!("requested position not yet downloaded");
573        } else {
574            debug!("stream position not yet downloaded");
575        }
576
577        self.handle.wait_for_position(requested_position);
578        self.check_for_failure()?;
579        debug!(
580            current_position = stream_position,
581            requested_position = requested_position,
582            output_stream_position = self.output_reader.stream_position()?,
583            "reached requested position"
584        );
585
586        self.handle_read(buf)
587    }
588}
589
590impl<P: StorageProvider> Seek for StreamDownload<P> {
591    #[instrument(skip(self))]
592    fn seek(&mut self, relative_position: SeekFrom) -> io::Result<u64> {
593        self.check_for_failure()?;
594
595        let absolute_seek_position = self.get_absolute_seek_position(relative_position)?;
596        let absolute_seek_position = self.normalize_requested_position(absolute_seek_position);
597        self.check_for_excessive_seek(absolute_seek_position)?;
598
599        debug!(absolute_seek_position, "absolute seek position");
600        if let Some(closest_set) = self
601            .handle
602            .get_downloaded_at_position(absolute_seek_position)
603        {
604            debug!(
605                downloaded_range = format!("{closest_set:?}"),
606                "seek position already downloaded"
607            );
608            return self
609                .output_reader
610                .seek(SeekFrom::Start(absolute_seek_position))
611                .inspect_err(|p| debug!(position = format!("{p:?}"), "returning seek position"));
612        }
613
614        self.handle.seek(absolute_seek_position);
615        self.check_for_failure()?;
616        debug!("reached seek position");
617
618        self.output_reader
619            .seek(SeekFrom::Start(absolute_seek_position))
620            .inspect_err(|p| debug!(position = format!("{p:?}"), "returning seek position"))
621    }
622}
623
/// Extension trait for prefixing an I/O error with a context message.
pub(crate) trait WrapIoResult {
    /// Returns `self` unchanged on success; on failure, returns a new error of the same kind
    /// whose message is `"{msg}: {original error}"`.
    fn wrap_err(self, msg: &str) -> Self;
}

impl<T> WrapIoResult for io::Result<T> {
    fn wrap_err(self, msg: &str) -> Self {
        self.map_err(|e| io::Error::new(e.kind(), format!("{msg}: {e}")))
    }
}
636}