stream_download/lib.rs
1#![deny(missing_docs)]
2#![forbid(clippy::unwrap_used)]
3#![cfg_attr(docsrs, feature(doc_auto_cfg))]
4#![doc = include_str!("../README.md")]
5
6use std::fmt::Debug;
7use std::future::{self, Future};
8use std::io::{self, Read, Seek, SeekFrom};
9
10use educe::Educe;
11pub use settings::*;
12use source::handle::SourceHandle;
13use source::{DecodeError, Source, SourceStream};
14use storage::StorageProvider;
15use tokio_util::sync::CancellationToken;
16use tracing::{debug, error, instrument, trace};
17
18#[cfg(feature = "async-read")]
19pub mod async_read;
20#[cfg(feature = "http")]
21pub mod http;
22#[cfg(feature = "process")]
23pub mod process;
24#[cfg(feature = "registry")]
25pub mod registry;
26mod settings;
27pub mod source;
28pub mod storage;
29
30/// A handle that can be usd to interact with the stream remotely.
31#[derive(Debug, Clone)]
32pub struct StreamHandle {
33 finished: CancellationToken,
34}
35
36impl StreamHandle {
37 /// Wait for the stream download task to complete.
38 ///
39 /// This method can be useful when using a [`ProcessStream`][process::ProcessStream] if you want
40 /// to ensure the subprocess has exited cleanly before continuing.
41 pub async fn wait_for_completion(self) {
42 self.finished.cancelled().await;
43 }
44}
45
46/// Represents content streamed from a remote source.
47/// This struct implements [read](https://doc.rust-lang.org/stable/std/io/trait.Read.html)
48/// and [seek](https://doc.rust-lang.org/stable/std/io/trait.Seek.html)
49/// so it can be used as a generic source for libraries and applications that operate on these
50/// traits. On creation, an async task is spawned that will immediately start to download the remote
51/// content.
52///
53/// Any read attempts that request part of the stream that hasn't been downloaded yet will block
54/// until the requested portion is reached. Any seek attempts that meet the same criteria will
55/// result in additional request to restart the stream download from the seek point.
56///
57/// If the stream download hasn't completed when this struct is dropped, the task will be cancelled.
58///
59/// If the stream stalls for any reason, the download task will attempt to automatically reconnect.
60/// This reconnect interval can be controlled via [`Settings::retry_timeout`].
61/// Server-side failures are not automatically handled and should be retried by the supplied
62/// [`SourceStream`] implementation if desired.
63#[derive(Debug)]
64pub struct StreamDownload<P: StorageProvider> {
65 output_reader: P::Reader,
66 handle: SourceHandle,
67 download_task_cancellation_token: CancellationToken,
68 cancel_on_drop: bool,
69 content_length: Option<u64>,
70}
71
72impl<P: StorageProvider> StreamDownload<P> {
73 #[cfg(feature = "reqwest")]
74 /// Creates a new [`StreamDownload`] that accesses an HTTP resource at the given URL.
75 ///
76 /// # Example
77 ///
78 /// ```no_run
79 /// use std::error::Error;
80 /// use std::io::{self, Read};
81 /// use std::result::Result;
82 ///
83 /// use stream_download::source::DecodeError;
84 /// use stream_download::storage::temp::TempStorageProvider;
85 /// use stream_download::{Settings, StreamDownload};
86 ///
87 /// #[tokio::main]
88 /// async fn main() -> Result<(), Box<dyn Error>> {
89 /// let mut reader = match StreamDownload::new_http(
90 /// "https://some-cool-url.com/some-file.mp3".parse()?,
91 /// TempStorageProvider::default(),
92 /// Settings::default(),
93 /// )
94 /// .await
95 /// {
96 /// Ok(reader) => reader,
97 /// Err(e) => return Err(e.decode_error().await)?,
98 /// };
99 ///
100 /// tokio::task::spawn_blocking(move || {
101 /// let mut buf = Vec::new();
102 /// reader.read_to_end(&mut buf)?;
103 /// Ok::<_, io::Error>(())
104 /// })
105 /// .await??;
106 /// Ok(())
107 /// }
108 /// ```
109 pub async fn new_http(
110 url: ::reqwest::Url,
111 storage_provider: P,
112 settings: Settings<http::HttpStream<::reqwest::Client>>,
113 ) -> Result<Self, StreamInitializationError<http::HttpStream<::reqwest::Client>>> {
114 Self::new(url, storage_provider, settings).await
115 }
116
117 /// Creates a new [`StreamDownload`] that accesses an HTTP resource at the given URL.
118 /// It uses the [`reqwest_middleware::ClientWithMiddleware`] client instead of the default
119 /// [`reqwest`] client. Any global middleware set by [`Settings::add_default_middleware`] will
120 /// be automatically applied.
121 ///
122 /// # Example
123 ///
124 /// ```no_run
125 /// use std::error::Error;
126 /// use std::io::{self, Read};
127 /// use std::result::Result;
128 ///
129 /// use reqwest_retry::RetryTransientMiddleware;
130 /// use reqwest_retry::policies::ExponentialBackoff;
131 /// use stream_download::source::DecodeError;
132 /// use stream_download::storage::temp::TempStorageProvider;
133 /// use stream_download::{Settings, StreamDownload};
134 ///
135 /// #[tokio::main]
136 /// async fn main() -> Result<(), Box<dyn Error>> {
137 /// let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
138 /// Settings::add_default_middleware(RetryTransientMiddleware::new_with_policy(retry_policy));
139 ///
140 /// let mut reader = match StreamDownload::new_http_with_middleware(
141 /// "https://some-cool-url.com/some-file.mp3".parse()?,
142 /// TempStorageProvider::default(),
143 /// Settings::default(),
144 /// )
145 /// .await
146 /// {
147 /// Ok(reader) => reader,
148 /// Err(e) => return Err(e.decode_error().await)?,
149 /// };
150 ///
151 /// tokio::task::spawn_blocking(move || {
152 /// let mut buf = Vec::new();
153 /// reader.read_to_end(&mut buf)?;
154 /// Ok::<_, io::Error>(())
155 /// })
156 /// .await??;
157 /// Ok(())
158 /// }
159 /// ```
160 #[cfg(feature = "reqwest-middleware")]
161 pub async fn new_http_with_middleware(
162 url: ::reqwest::Url,
163 storage_provider: P,
164 settings: Settings<http::HttpStream<::reqwest_middleware::ClientWithMiddleware>>,
165 ) -> Result<
166 Self,
167 StreamInitializationError<http::HttpStream<::reqwest_middleware::ClientWithMiddleware>>,
168 > {
169 Self::new(url, storage_provider, settings).await
170 }
171
172 /// Creates a new [`StreamDownload`] that uses an [`AsyncRead`][tokio::io::AsyncRead] resource.
173 ///
174 /// # Example reading from `stdin`
175 ///
176 /// ```no_run
177 /// use std::error::Error;
178 /// use std::io::{self, Read};
179 /// use std::result::Result;
180 ///
181 /// use stream_download::async_read::AsyncReadStreamParams;
182 /// use stream_download::storage::temp::TempStorageProvider;
183 /// use stream_download::{Settings, StreamDownload};
184 ///
185 /// #[tokio::main]
186 /// async fn main() -> Result<(), Box<dyn Error>> {
187 /// let mut reader = StreamDownload::new_async_read(
188 /// AsyncReadStreamParams::new(tokio::io::stdin()),
189 /// TempStorageProvider::new(),
190 /// Settings::default(),
191 /// )
192 /// .await?;
193 ///
194 /// tokio::task::spawn_blocking(move || {
195 /// let mut buf = Vec::new();
196 /// reader.read_to_end(&mut buf)?;
197 /// Ok::<_, io::Error>(())
198 /// })
199 /// .await??;
200 /// Ok(())
201 /// }
202 /// ```
203 #[cfg(feature = "async-read")]
204 pub async fn new_async_read<T>(
205 params: async_read::AsyncReadStreamParams<T>,
206 storage_provider: P,
207 settings: Settings<async_read::AsyncReadStream<T>>,
208 ) -> Result<Self, StreamInitializationError<async_read::AsyncReadStream<T>>>
209 where
210 T: tokio::io::AsyncRead + Send + Sync + Unpin + 'static,
211 {
212 Self::new(params, storage_provider, settings).await
213 }
214
215 /// Creates a new [`StreamDownload`] that uses a [`Command`][process::Command] as input.
216 ///
217 /// # Example
218 ///
219 /// ```no_run
220 /// use std::error::Error;
221 /// use std::io::{self, Read};
222 /// use std::result::Result;
223 ///
224 /// use stream_download::process::{Command, ProcessStreamParams};
225 /// use stream_download::storage::temp::TempStorageProvider;
226 /// use stream_download::{Settings, StreamDownload};
227 ///
228 /// #[tokio::main]
229 /// async fn main() -> Result<(), Box<dyn Error>> {
230 /// let mut reader = StreamDownload::new_process(
231 /// ProcessStreamParams::new(Command::new("cat").args(["./assets/music.mp3"]))?,
232 /// TempStorageProvider::new(),
233 /// Settings::default(),
234 /// )
235 /// .await?;
236 ///
237 /// tokio::task::spawn_blocking(move || {
238 /// let mut buf = Vec::new();
239 /// reader.read_to_end(&mut buf)?;
240 /// Ok::<_, io::Error>(())
241 /// })
242 /// .await??;
243 /// Ok(())
244 /// }
245 /// ```
246 #[cfg(feature = "process")]
247 pub async fn new_process(
248 params: process::ProcessStreamParams,
249 storage_provider: P,
250 settings: Settings<process::ProcessStream>,
251 ) -> Result<Self, StreamInitializationError<process::ProcessStream>> {
252 Self::new(params, storage_provider, settings).await
253 }
254
255 /// Creates a new [`StreamDownload`] that accesses a remote resource at the given URL.
256 ///
257 /// # Example
258 ///
259 /// ```no_run
260 /// use std::error::Error;
261 /// use std::io::{self, Read};
262 /// use std::result::Result;
263 ///
264 /// use reqwest::Client;
265 /// use stream_download::http::HttpStream;
266 /// use stream_download::storage::temp::TempStorageProvider;
267 /// use stream_download::{Settings, StreamDownload};
268 ///
269 /// use crate::stream_download::source::DecodeError;
270 ///
271 /// #[tokio::main]
272 /// async fn main() -> Result<(), Box<dyn Error>> {
273 /// let mut reader = match StreamDownload::new::<HttpStream<Client>>(
274 /// "https://some-cool-url.com/some-file.mp3".parse()?,
275 /// TempStorageProvider::default(),
276 /// Settings::default(),
277 /// )
278 /// .await
279 /// {
280 /// Ok(reader) => reader,
281 /// Err(e) => return Err(e.decode_error().await)?,
282 /// };
283 ///
284 /// tokio::task::spawn_blocking(move || {
285 /// let mut buf = Vec::new();
286 /// reader.read_to_end(&mut buf)?;
287 /// Ok::<_, io::Error>(())
288 /// })
289 /// .await??;
290 /// Ok(())
291 /// }
292 /// ```
293 pub async fn new<S>(
294 params: S::Params,
295 storage_provider: P,
296 settings: Settings<S>,
297 ) -> Result<Self, StreamInitializationError<S>>
298 where
299 S: SourceStream,
300 S::Error: Debug + Send,
301 {
302 Self::from_create_stream(move || S::create(params), storage_provider, settings).await
303 }
304
305 /// Creates a new [`StreamDownload`] from a [`SourceStream`].
306 ///
307 /// # Example
308 ///
309 /// ```no_run
310 /// use std::error::Error;
311 /// use std::io::Read;
312 /// use std::result::Result;
313 ///
314 /// use reqwest::Client;
315 /// use stream_download::http::HttpStream;
316 /// use stream_download::storage::temp::TempStorageProvider;
317 /// use stream_download::{Settings, StreamDownload};
318 ///
319 /// use crate::stream_download::source::DecodeError;
320 ///
321 /// #[tokio::main]
322 /// async fn main() -> Result<(), Box<dyn Error>> {
323 /// let stream = HttpStream::new(
324 /// Client::new(),
325 /// "https://some-cool-url.com/some-file.mp3".parse()?,
326 /// )
327 /// .await?;
328 ///
329 /// let mut reader = match StreamDownload::from_stream(
330 /// stream,
331 /// TempStorageProvider::default(),
332 /// Settings::default(),
333 /// )
334 /// .await
335 /// {
336 /// Ok(reader) => reader,
337 /// Err(e) => Err(e.decode_error().await)?,
338 /// };
339 /// Ok(())
340 /// }
341 /// ```
342 pub async fn from_stream<S>(
343 stream: S,
344 storage_provider: P,
345 settings: Settings<S>,
346 ) -> Result<Self, StreamInitializationError<S>>
347 where
348 S: SourceStream,
349 S::Error: Debug + Send,
350 {
351 Self::from_create_stream(
352 move || future::ready(Ok(stream)),
353 storage_provider,
354 settings,
355 )
356 .await
357 }
358
359 /// Cancels the background task that's downloading the stream content.
360 /// This has no effect if the download is already completed.
361 pub fn cancel_download(&self) {
362 self.download_task_cancellation_token.cancel();
363 }
364
365 /// Returns the [`CancellationToken`] for the download task.
366 /// This can be used to cancel the download task before it completes.
367 pub fn cancellation_token(&self) -> CancellationToken {
368 self.download_task_cancellation_token.clone()
369 }
370
371 /// Returns a [`StreamHandle`] that can be used to interact with
372 /// the stream remotely.
373 pub fn handle(&self) -> StreamHandle {
374 StreamHandle {
375 finished: self.download_task_cancellation_token.clone(),
376 }
377 }
378
379 /// Returns the content length of the stream, if available.
380 pub fn content_length(&self) -> Option<u64> {
381 self.content_length
382 }
383
384 async fn from_create_stream<S, F, Fut>(
385 create_stream: F,
386 storage_provider: P,
387 settings: Settings<S>,
388 ) -> Result<Self, StreamInitializationError<S>>
389 where
390 S: SourceStream<Error: Debug + Send>,
391 F: FnOnce() -> Fut + Send + 'static,
392 Fut: Future<Output = Result<S, S::StreamCreationError>> + Send,
393 {
394 let stream = create_stream()
395 .await
396 .map_err(StreamInitializationError::StreamCreationFailure)?;
397 let content_length = stream.content_length();
398 let (reader, writer) = storage_provider
399 .into_reader_writer(content_length)
400 .map_err(StreamInitializationError::StorageCreationFailure)?;
401 let cancellation_token = CancellationToken::new();
402 let cancel_on_drop = settings.cancel_on_drop;
403 let mut source = Source::new(writer, content_length, settings, cancellation_token.clone());
404 let handle = source.source_handle();
405
406 tokio::spawn({
407 let cancellation_token = cancellation_token.clone();
408 async move {
409 source.download(stream).await;
410 cancellation_token.cancel();
411 debug!("download task finished");
412 }
413 });
414
415 Ok(Self {
416 output_reader: reader,
417 handle,
418 download_task_cancellation_token: cancellation_token,
419 cancel_on_drop,
420 content_length,
421 })
422 }
423
424 fn get_absolute_seek_position(&mut self, relative_position: SeekFrom) -> io::Result<u64> {
425 Ok(match relative_position {
426 SeekFrom::Start(position) => {
427 debug!(seek_position = position, "seeking from start");
428 position
429 }
430 SeekFrom::Current(position) => {
431 debug!(seek_position = position, "seeking from current position");
432 (self.output_reader.stream_position()? as i64 + position) as u64
433 }
434 SeekFrom::End(position) => {
435 debug!(seek_position = position, "seeking from end");
436 if let Some(length) = self.handle.content_length() {
437 (length as i64 + position) as u64
438 } else {
439 return Err(io::Error::new(
440 io::ErrorKind::Unsupported,
441 "cannot seek from end when content length is unknown",
442 ));
443 }
444 }
445 })
446 }
447
448 fn handle_read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
449 let res = self.output_reader.read(buf).inspect(|l| {
450 trace!(read_length = format!("{l:?}"), "returning read");
451 });
452 self.handle.notify_read();
453 res
454 }
455
456 fn normalize_requested_position(&self, requested_position: u64) -> u64 {
457 if let Some(content_length) = self.content_length {
458 // ensure we don't request a position beyond the end of the stream
459 requested_position.min(content_length)
460 } else {
461 requested_position
462 }
463 }
464
465 fn check_for_failure(&self) -> io::Result<()> {
466 if self.handle.is_failed() {
467 Err(io::Error::other("stream failed to download"))
468 } else {
469 Ok(())
470 }
471 }
472}
473
474/// Error returned when initializing a stream.
475#[derive(thiserror::Error, Educe)]
476#[educe(Debug)]
477pub enum StreamInitializationError<S: SourceStream> {
478 /// Storage creation failure.
479 #[error("Storage creation failure: {0}")]
480 StorageCreationFailure(io::Error),
481 /// Stream creation failure.
482 #[error("Stream creation failure: {0}")]
483 StreamCreationFailure(<S as SourceStream>::StreamCreationError),
484}
485
486impl<S: SourceStream> DecodeError for StreamInitializationError<S> {
487 async fn decode_error(self) -> String {
488 match self {
489 this @ Self::StorageCreationFailure(_) => this.to_string(),
490 Self::StreamCreationFailure(e) => e.decode_error().await,
491 }
492 }
493}
494
495impl<P: StorageProvider> Drop for StreamDownload<P> {
496 fn drop(&mut self) {
497 if self.cancel_on_drop {
498 self.cancel_download();
499 }
500 }
501}
502
503impl<P: StorageProvider> Read for StreamDownload<P> {
504 #[instrument(skip_all, fields(len=buf.len()))]
505 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
506 self.check_for_failure()?;
507
508 trace!(buffer_length = buf.len(), "read requested");
509 let stream_position = self.output_reader.stream_position()?;
510 let requested_position =
511 self.normalize_requested_position(stream_position + buf.len() as u64);
512 trace!(
513 current_position = stream_position,
514 requested_position = requested_position
515 );
516
517 if let Some(closest_set) = self.handle.get_downloaded_at_position(stream_position) {
518 trace!(
519 downloaded_range = format!("{closest_set:?}"),
520 "current position already downloaded"
521 );
522 if closest_set.end >= requested_position {
523 debug!("requested position already downloaded");
524 return self.handle_read(buf);
525 }
526 debug!("requested position not yet downloaded");
527 } else {
528 debug!("stream position not yet downloaded");
529 }
530
531 self.handle.wait_for_position(requested_position);
532 self.check_for_failure()?;
533 debug!(
534 current_position = stream_position,
535 requested_position = requested_position,
536 output_stream_position = self.output_reader.stream_position()?,
537 "reached requested position"
538 );
539
540 self.handle_read(buf)
541 }
542}
543
544impl<P: StorageProvider> Seek for StreamDownload<P> {
545 #[instrument(skip(self))]
546 fn seek(&mut self, relative_position: SeekFrom) -> io::Result<u64> {
547 self.check_for_failure()?;
548
549 let absolute_seek_position = self.get_absolute_seek_position(relative_position)?;
550 let absolute_seek_position = self.normalize_requested_position(absolute_seek_position);
551
552 debug!(absolute_seek_position, "absolute seek position");
553 if let Some(closest_set) = self
554 .handle
555 .get_downloaded_at_position(absolute_seek_position)
556 {
557 debug!(
558 downloaded_range = format!("{closest_set:?}"),
559 "seek position already downloaded"
560 );
561 return self
562 .output_reader
563 .seek(SeekFrom::Start(absolute_seek_position))
564 .inspect_err(|p| debug!(position = format!("{p:?}"), "returning seek position"));
565 }
566
567 self.handle.seek(absolute_seek_position);
568 self.check_for_failure()?;
569 debug!("reached seek position");
570
571 self.output_reader
572 .seek(SeekFrom::Start(absolute_seek_position))
573 .inspect_err(|p| debug!(position = format!("{p:?}"), "returning seek position"))
574 }
575}
576
577pub(crate) trait WrapIoResult {
578 fn wrap_err(self, msg: &str) -> Self;
579}
580
581impl<T> WrapIoResult for io::Result<T> {
582 fn wrap_err(self, msg: &str) -> Self {
583 if let Err(e) = self {
584 Err(io::Error::new(e.kind(), format!("{msg}: {e}")))
585 } else {
586 self
587 }
588 }
589}