stream_download/
lib.rs

1#![deny(missing_docs)]
2#![forbid(clippy::unwrap_used)]
3#![cfg_attr(docsrs, feature(doc_auto_cfg))]
4#![doc = include_str!("../README.md")]
5
6use std::fmt::Debug;
7use std::future::{self, Future};
8use std::io::{self, Read, Seek, SeekFrom};
9
10use educe::Educe;
11pub use settings::*;
12use source::handle::SourceHandle;
13use source::{DecodeError, Source, SourceStream};
14use storage::StorageProvider;
15use tokio_util::sync::CancellationToken;
16use tracing::{debug, error, instrument, trace};
17
18#[cfg(feature = "async-read")]
19pub mod async_read;
20#[cfg(feature = "http")]
21pub mod http;
22#[cfg(feature = "process")]
23pub mod process;
24#[cfg(feature = "registry")]
25pub mod registry;
26mod settings;
27pub mod source;
28pub mod storage;
29
30/// A handle that can be usd to interact with the stream remotely.
31#[derive(Debug, Clone)]
32pub struct StreamHandle {
33    finished: CancellationToken,
34}
35
36impl StreamHandle {
37    /// Wait for the stream download task to complete.
38    ///
39    /// This method can be useful when using a [`ProcessStream`][process::ProcessStream] if you want
40    /// to ensure the subprocess has exited cleanly before continuing.
41    pub async fn wait_for_completion(self) {
42        self.finished.cancelled().await;
43    }
44}
45
46/// Represents content streamed from a remote source.
47/// This struct implements [read](https://doc.rust-lang.org/stable/std/io/trait.Read.html)
48/// and [seek](https://doc.rust-lang.org/stable/std/io/trait.Seek.html)
49/// so it can be used as a generic source for libraries and applications that operate on these
50/// traits. On creation, an async task is spawned that will immediately start to download the remote
51/// content.
52///
53/// Any read attempts that request part of the stream that hasn't been downloaded yet will block
54/// until the requested portion is reached. Any seek attempts that meet the same criteria will
55/// result in additional request to restart the stream download from the seek point.
56///
57/// If the stream download hasn't completed when this struct is dropped, the task will be cancelled.
58///
59/// If the stream stalls for any reason, the download task will attempt to automatically reconnect.
60/// This reconnect interval can be controlled via [`Settings::retry_timeout`].
61/// Server-side failures are not automatically handled and should be retried by the supplied
62/// [`SourceStream`] implementation if desired.
63#[derive(Debug)]
64pub struct StreamDownload<P: StorageProvider> {
65    output_reader: P::Reader,
66    handle: SourceHandle,
67    download_task_cancellation_token: CancellationToken,
68    cancel_on_drop: bool,
69    content_length: Option<u64>,
70}
71
72impl<P: StorageProvider> StreamDownload<P> {
73    #[cfg(feature = "reqwest")]
74    /// Creates a new [`StreamDownload`] that accesses an HTTP resource at the given URL.
75    ///
76    /// # Example
77    ///
78    /// ```no_run
79    /// use std::error::Error;
80    /// use std::io::{self, Read};
81    /// use std::result::Result;
82    ///
83    /// use stream_download::source::DecodeError;
84    /// use stream_download::storage::temp::TempStorageProvider;
85    /// use stream_download::{Settings, StreamDownload};
86    ///
87    /// #[tokio::main]
88    /// async fn main() -> Result<(), Box<dyn Error>> {
89    ///     let mut reader = match StreamDownload::new_http(
90    ///         "https://some-cool-url.com/some-file.mp3".parse()?,
91    ///         TempStorageProvider::default(),
92    ///         Settings::default(),
93    ///     )
94    ///     .await
95    ///     {
96    ///         Ok(reader) => reader,
97    ///         Err(e) => return Err(e.decode_error().await)?,
98    ///     };
99    ///
100    ///     tokio::task::spawn_blocking(move || {
101    ///         let mut buf = Vec::new();
102    ///         reader.read_to_end(&mut buf)?;
103    ///         Ok::<_, io::Error>(())
104    ///     })
105    ///     .await??;
106    ///     Ok(())
107    /// }
108    /// ```
109    pub async fn new_http(
110        url: ::reqwest::Url,
111        storage_provider: P,
112        settings: Settings<http::HttpStream<::reqwest::Client>>,
113    ) -> Result<Self, StreamInitializationError<http::HttpStream<::reqwest::Client>>> {
114        Self::new(url, storage_provider, settings).await
115    }
116
117    /// Creates a new [`StreamDownload`] that accesses an HTTP resource at the given URL.
118    /// It uses the [`reqwest_middleware::ClientWithMiddleware`] client instead of the default
119    /// [`reqwest`] client. Any global middleware set by [`Settings::add_default_middleware`] will
120    /// be automatically applied.
121    ///
122    /// # Example
123    ///
124    /// ```no_run
125    /// use std::error::Error;
126    /// use std::io::{self, Read};
127    /// use std::result::Result;
128    ///
129    /// use reqwest_retry::RetryTransientMiddleware;
130    /// use reqwest_retry::policies::ExponentialBackoff;
131    /// use stream_download::source::DecodeError;
132    /// use stream_download::storage::temp::TempStorageProvider;
133    /// use stream_download::{Settings, StreamDownload};
134    ///
135    /// #[tokio::main]
136    /// async fn main() -> Result<(), Box<dyn Error>> {
137    ///     let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
138    ///     Settings::add_default_middleware(RetryTransientMiddleware::new_with_policy(retry_policy));
139    ///
140    ///     let mut reader = match StreamDownload::new_http_with_middleware(
141    ///         "https://some-cool-url.com/some-file.mp3".parse()?,
142    ///         TempStorageProvider::default(),
143    ///         Settings::default(),
144    ///     )
145    ///     .await
146    ///     {
147    ///         Ok(reader) => reader,
148    ///         Err(e) => return Err(e.decode_error().await)?,
149    ///     };
150    ///
151    ///     tokio::task::spawn_blocking(move || {
152    ///         let mut buf = Vec::new();
153    ///         reader.read_to_end(&mut buf)?;
154    ///         Ok::<_, io::Error>(())
155    ///     })
156    ///     .await??;
157    ///     Ok(())
158    /// }
159    /// ```
160    #[cfg(feature = "reqwest-middleware")]
161    pub async fn new_http_with_middleware(
162        url: ::reqwest::Url,
163        storage_provider: P,
164        settings: Settings<http::HttpStream<::reqwest_middleware::ClientWithMiddleware>>,
165    ) -> Result<
166        Self,
167        StreamInitializationError<http::HttpStream<::reqwest_middleware::ClientWithMiddleware>>,
168    > {
169        Self::new(url, storage_provider, settings).await
170    }
171
172    /// Creates a new [`StreamDownload`] that uses an [`AsyncRead`][tokio::io::AsyncRead] resource.
173    ///
174    /// # Example reading from `stdin`
175    ///
176    /// ```no_run
177    /// use std::error::Error;
178    /// use std::io::{self, Read};
179    /// use std::result::Result;
180    ///
181    /// use stream_download::async_read::AsyncReadStreamParams;
182    /// use stream_download::storage::temp::TempStorageProvider;
183    /// use stream_download::{Settings, StreamDownload};
184    ///
185    /// #[tokio::main]
186    /// async fn main() -> Result<(), Box<dyn Error>> {
187    ///     let mut reader = StreamDownload::new_async_read(
188    ///         AsyncReadStreamParams::new(tokio::io::stdin()),
189    ///         TempStorageProvider::new(),
190    ///         Settings::default(),
191    ///     )
192    ///     .await?;
193    ///
194    ///     tokio::task::spawn_blocking(move || {
195    ///         let mut buf = Vec::new();
196    ///         reader.read_to_end(&mut buf)?;
197    ///         Ok::<_, io::Error>(())
198    ///     })
199    ///     .await??;
200    ///     Ok(())
201    /// }
202    /// ```
203    #[cfg(feature = "async-read")]
204    pub async fn new_async_read<T>(
205        params: async_read::AsyncReadStreamParams<T>,
206        storage_provider: P,
207        settings: Settings<async_read::AsyncReadStream<T>>,
208    ) -> Result<Self, StreamInitializationError<async_read::AsyncReadStream<T>>>
209    where
210        T: tokio::io::AsyncRead + Send + Sync + Unpin + 'static,
211    {
212        Self::new(params, storage_provider, settings).await
213    }
214
215    /// Creates a new [`StreamDownload`] that uses a [`Command`][process::Command] as input.
216    ///
217    /// # Example
218    ///
219    /// ```no_run
220    /// use std::error::Error;
221    /// use std::io::{self, Read};
222    /// use std::result::Result;
223    ///
224    /// use stream_download::process::{Command, ProcessStreamParams};
225    /// use stream_download::storage::temp::TempStorageProvider;
226    /// use stream_download::{Settings, StreamDownload};
227    ///
228    /// #[tokio::main]
229    /// async fn main() -> Result<(), Box<dyn Error>> {
230    ///     let mut reader = StreamDownload::new_process(
231    ///         ProcessStreamParams::new(Command::new("cat").args(["./assets/music.mp3"]))?,
232    ///         TempStorageProvider::new(),
233    ///         Settings::default(),
234    ///     )
235    ///     .await?;
236    ///
237    ///     tokio::task::spawn_blocking(move || {
238    ///         let mut buf = Vec::new();
239    ///         reader.read_to_end(&mut buf)?;
240    ///         Ok::<_, io::Error>(())
241    ///     })
242    ///     .await??;
243    ///     Ok(())
244    /// }
245    /// ```
246    #[cfg(feature = "process")]
247    pub async fn new_process(
248        params: process::ProcessStreamParams,
249        storage_provider: P,
250        settings: Settings<process::ProcessStream>,
251    ) -> Result<Self, StreamInitializationError<process::ProcessStream>> {
252        Self::new(params, storage_provider, settings).await
253    }
254
255    /// Creates a new [`StreamDownload`] that accesses a remote resource at the given URL.
256    ///
257    /// # Example
258    ///
259    /// ```no_run
260    /// use std::error::Error;
261    /// use std::io::{self, Read};
262    /// use std::result::Result;
263    ///
264    /// use reqwest::Client;
265    /// use stream_download::http::HttpStream;
266    /// use stream_download::storage::temp::TempStorageProvider;
267    /// use stream_download::{Settings, StreamDownload};
268    ///
269    /// use crate::stream_download::source::DecodeError;
270    ///
271    /// #[tokio::main]
272    /// async fn main() -> Result<(), Box<dyn Error>> {
273    ///     let mut reader = match StreamDownload::new::<HttpStream<Client>>(
274    ///         "https://some-cool-url.com/some-file.mp3".parse()?,
275    ///         TempStorageProvider::default(),
276    ///         Settings::default(),
277    ///     )
278    ///     .await
279    ///     {
280    ///         Ok(reader) => reader,
281    ///         Err(e) => return Err(e.decode_error().await)?,
282    ///     };
283    ///
284    ///     tokio::task::spawn_blocking(move || {
285    ///         let mut buf = Vec::new();
286    ///         reader.read_to_end(&mut buf)?;
287    ///         Ok::<_, io::Error>(())
288    ///     })
289    ///     .await??;
290    ///     Ok(())
291    /// }
292    /// ```
293    pub async fn new<S>(
294        params: S::Params,
295        storage_provider: P,
296        settings: Settings<S>,
297    ) -> Result<Self, StreamInitializationError<S>>
298    where
299        S: SourceStream,
300        S::Error: Debug + Send,
301    {
302        Self::from_create_stream(move || S::create(params), storage_provider, settings).await
303    }
304
305    /// Creates a new [`StreamDownload`] from a [`SourceStream`].
306    ///
307    /// # Example
308    ///
309    /// ```no_run
310    /// use std::error::Error;
311    /// use std::io::Read;
312    /// use std::result::Result;
313    ///
314    /// use reqwest::Client;
315    /// use stream_download::http::HttpStream;
316    /// use stream_download::storage::temp::TempStorageProvider;
317    /// use stream_download::{Settings, StreamDownload};
318    ///
319    /// use crate::stream_download::source::DecodeError;
320    ///
321    /// #[tokio::main]
322    /// async fn main() -> Result<(), Box<dyn Error>> {
323    ///     let stream = HttpStream::new(
324    ///         Client::new(),
325    ///         "https://some-cool-url.com/some-file.mp3".parse()?,
326    ///     )
327    ///     .await?;
328    ///
329    ///     let mut reader = match StreamDownload::from_stream(
330    ///         stream,
331    ///         TempStorageProvider::default(),
332    ///         Settings::default(),
333    ///     )
334    ///     .await
335    ///     {
336    ///         Ok(reader) => reader,
337    ///         Err(e) => Err(e.decode_error().await)?,
338    ///     };
339    ///     Ok(())
340    /// }
341    /// ```
342    pub async fn from_stream<S>(
343        stream: S,
344        storage_provider: P,
345        settings: Settings<S>,
346    ) -> Result<Self, StreamInitializationError<S>>
347    where
348        S: SourceStream,
349        S::Error: Debug + Send,
350    {
351        Self::from_create_stream(
352            move || future::ready(Ok(stream)),
353            storage_provider,
354            settings,
355        )
356        .await
357    }
358
359    /// Cancels the background task that's downloading the stream content.
360    /// This has no effect if the download is already completed.
361    pub fn cancel_download(&self) {
362        self.download_task_cancellation_token.cancel();
363    }
364
365    /// Returns the [`CancellationToken`] for the download task.
366    /// This can be used to cancel the download task before it completes.
367    pub fn cancellation_token(&self) -> CancellationToken {
368        self.download_task_cancellation_token.clone()
369    }
370
371    /// Returns a [`StreamHandle`] that can be used to interact with
372    /// the stream remotely.
373    pub fn handle(&self) -> StreamHandle {
374        StreamHandle {
375            finished: self.download_task_cancellation_token.clone(),
376        }
377    }
378
379    /// Returns the content length of the stream, if available.
380    pub fn content_length(&self) -> Option<u64> {
381        self.content_length
382    }
383
384    async fn from_create_stream<S, F, Fut>(
385        create_stream: F,
386        storage_provider: P,
387        settings: Settings<S>,
388    ) -> Result<Self, StreamInitializationError<S>>
389    where
390        S: SourceStream<Error: Debug + Send>,
391        F: FnOnce() -> Fut + Send + 'static,
392        Fut: Future<Output = Result<S, S::StreamCreationError>> + Send,
393    {
394        let stream = create_stream()
395            .await
396            .map_err(StreamInitializationError::StreamCreationFailure)?;
397        let content_length = stream.content_length();
398        let (reader, writer) = storage_provider
399            .into_reader_writer(content_length)
400            .map_err(StreamInitializationError::StorageCreationFailure)?;
401        let cancellation_token = CancellationToken::new();
402        let cancel_on_drop = settings.cancel_on_drop;
403        let mut source = Source::new(writer, content_length, settings, cancellation_token.clone());
404        let handle = source.source_handle();
405
406        tokio::spawn({
407            let cancellation_token = cancellation_token.clone();
408            async move {
409                source.download(stream).await;
410                cancellation_token.cancel();
411                debug!("download task finished");
412            }
413        });
414
415        Ok(Self {
416            output_reader: reader,
417            handle,
418            download_task_cancellation_token: cancellation_token,
419            cancel_on_drop,
420            content_length,
421        })
422    }
423
424    fn get_absolute_seek_position(&mut self, relative_position: SeekFrom) -> io::Result<u64> {
425        Ok(match relative_position {
426            SeekFrom::Start(position) => {
427                debug!(seek_position = position, "seeking from start");
428                position
429            }
430            SeekFrom::Current(position) => {
431                debug!(seek_position = position, "seeking from current position");
432                (self.output_reader.stream_position()? as i64 + position) as u64
433            }
434            SeekFrom::End(position) => {
435                debug!(seek_position = position, "seeking from end");
436                if let Some(length) = self.handle.content_length() {
437                    (length as i64 + position) as u64
438                } else {
439                    return Err(io::Error::new(
440                        io::ErrorKind::Unsupported,
441                        "cannot seek from end when content length is unknown",
442                    ));
443                }
444            }
445        })
446    }
447
448    fn handle_read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
449        let res = self.output_reader.read(buf).inspect(|l| {
450            trace!(read_length = format!("{l:?}"), "returning read");
451        });
452        self.handle.notify_read();
453        res
454    }
455
456    fn normalize_requested_position(&self, requested_position: u64) -> u64 {
457        if let Some(content_length) = self.content_length {
458            // ensure we don't request a position beyond the end of the stream
459            requested_position.min(content_length)
460        } else {
461            requested_position
462        }
463    }
464
465    fn check_for_failure(&self) -> io::Result<()> {
466        if self.handle.is_failed() {
467            Err(io::Error::other("stream failed to download"))
468        } else {
469            Ok(())
470        }
471    }
472}
473
474/// Error returned when initializing a stream.
475#[derive(thiserror::Error, Educe)]
476#[educe(Debug)]
477pub enum StreamInitializationError<S: SourceStream> {
478    /// Storage creation failure.
479    #[error("Storage creation failure: {0}")]
480    StorageCreationFailure(io::Error),
481    /// Stream creation failure.
482    #[error("Stream creation failure: {0}")]
483    StreamCreationFailure(<S as SourceStream>::StreamCreationError),
484}
485
486impl<S: SourceStream> DecodeError for StreamInitializationError<S> {
487    async fn decode_error(self) -> String {
488        match self {
489            this @ Self::StorageCreationFailure(_) => this.to_string(),
490            Self::StreamCreationFailure(e) => e.decode_error().await,
491        }
492    }
493}
494
495impl<P: StorageProvider> Drop for StreamDownload<P> {
496    fn drop(&mut self) {
497        if self.cancel_on_drop {
498            self.cancel_download();
499        }
500    }
501}
502
503impl<P: StorageProvider> Read for StreamDownload<P> {
504    #[instrument(skip_all, fields(len=buf.len()))]
505    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
506        self.check_for_failure()?;
507
508        trace!(buffer_length = buf.len(), "read requested");
509        let stream_position = self.output_reader.stream_position()?;
510        let requested_position =
511            self.normalize_requested_position(stream_position + buf.len() as u64);
512        trace!(
513            current_position = stream_position,
514            requested_position = requested_position
515        );
516
517        if let Some(closest_set) = self.handle.get_downloaded_at_position(stream_position) {
518            trace!(
519                downloaded_range = format!("{closest_set:?}"),
520                "current position already downloaded"
521            );
522            if closest_set.end >= requested_position {
523                debug!("requested position already downloaded");
524                return self.handle_read(buf);
525            }
526            debug!("requested position not yet downloaded");
527        } else {
528            debug!("stream position not yet downloaded");
529        }
530
531        self.handle.wait_for_position(requested_position);
532        self.check_for_failure()?;
533        debug!(
534            current_position = stream_position,
535            requested_position = requested_position,
536            output_stream_position = self.output_reader.stream_position()?,
537            "reached requested position"
538        );
539
540        self.handle_read(buf)
541    }
542}
543
544impl<P: StorageProvider> Seek for StreamDownload<P> {
545    #[instrument(skip(self))]
546    fn seek(&mut self, relative_position: SeekFrom) -> io::Result<u64> {
547        self.check_for_failure()?;
548
549        let absolute_seek_position = self.get_absolute_seek_position(relative_position)?;
550        let absolute_seek_position = self.normalize_requested_position(absolute_seek_position);
551
552        debug!(absolute_seek_position, "absolute seek position");
553        if let Some(closest_set) = self
554            .handle
555            .get_downloaded_at_position(absolute_seek_position)
556        {
557            debug!(
558                downloaded_range = format!("{closest_set:?}"),
559                "seek position already downloaded"
560            );
561            return self
562                .output_reader
563                .seek(SeekFrom::Start(absolute_seek_position))
564                .inspect_err(|p| debug!(position = format!("{p:?}"), "returning seek position"));
565        }
566
567        self.handle.seek(absolute_seek_position);
568        self.check_for_failure()?;
569        debug!("reached seek position");
570
571        self.output_reader
572            .seek(SeekFrom::Start(absolute_seek_position))
573            .inspect_err(|p| debug!(position = format!("{p:?}"), "returning seek position"))
574    }
575}
576
/// Extension trait for adding context to the error of an [`io::Result`].
pub(crate) trait WrapIoResult {
    /// Prefixes the error message with `msg` (rendered as `"{msg}: {error}"`) while keeping the
    /// original [`io::ErrorKind`]. `Ok` values are returned unchanged.
    fn wrap_err(self, msg: &str) -> Self;
}

impl<T> WrapIoResult for io::Result<T> {
    fn wrap_err(self, msg: &str) -> Self {
        self.map_err(|err| io::Error::new(err.kind(), format!("{msg}: {err}")))
    }
}