stream_download/source/
mod.rs

//! Provides the [`SourceStream`] trait which abstracts over the transport used to
//! stream remote content.

use std::convert::Infallible;
use std::error::Error;
use std::fmt::Debug;
use std::future;
use std::io::{self, SeekFrom};
use std::time::{Duration, Instant};

use bytes::{BufMut, Bytes, BytesMut};
use futures_util::{Future, Stream, StreamExt, TryStream};
use handle::{
    DownloadStatus, Downloaded, NotifyRead, PositionReached, RequestedPosition, SourceHandle,
};
use tokio::sync::mpsc;
use tokio::task::yield_now;
use tokio::time::timeout;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, instrument, trace, warn};

use crate::storage::StorageWriter;
use crate::{ProgressFn, ReconnectFn, Settings, StreamPhase, StreamState};

pub(crate) mod handle;

/// Enum representing the final outcome of the stream.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum StreamOutcome {
    /// The stream completed naturally.
    Completed,
    /// The stream was cancelled by the user.
    CancelledByUser,
}

/// Represents a remote resource that can be streamed over the network. Streaming
/// over HTTP is provided by the [`HttpStream`](crate::http::HttpStream)
/// implementation when the `http` feature is enabled.
///
/// Implementations must also implement the
/// [Stream](https://docs.rs/futures/latest/futures/stream/trait.Stream.html) trait.
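///
/// # Example
///
/// A minimal sketch of a custom implementation for a hypothetical `MyStream` type.
/// Transport details and the required `Stream` implementation are omitted:
///
/// ```ignore
/// use std::convert::Infallible;
/// use std::io;
///
/// struct MyStream {
///     // connection state goes here
/// }
///
/// // `MyStream` must also implement `Stream<Item = Result<Bytes, E>>`.
/// impl SourceStream for MyStream {
///     type Params = String; // e.g. a URL
///     type StreamCreationError = Infallible;
///
///     async fn create(params: Self::Params) -> Result<Self, Self::StreamCreationError> {
///         // open the connection to `params` here
///         Ok(Self {})
///     }
///
///     fn content_length(&self) -> Option<u64> {
///         None // unknown or infinite length
///     }
///
///     async fn seek_range(&mut self, start: u64, end: Option<u64>) -> io::Result<()> {
///         // request bytes in `start..end` from the server
///         Ok(())
///     }
///
///     async fn reconnect(&mut self, current_position: u64) -> io::Result<()> {
///         // re-open the connection and resume at `current_position`
///         Ok(())
///     }
///
///     fn supports_seek(&self) -> bool {
///         false
///     }
/// }
/// ```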
pub trait SourceStream:
    TryStream<Ok = Bytes>
    + Stream<Item = Result<Self::Ok, Self::Error>>
    + Unpin
    + Send
    + Sync
    + Sized
    + 'static
{
    /// Parameters used to create the stream.
    type Params: Send;

    /// Error type returned when creating the stream fails.
    type StreamCreationError: DecodeError + Send;

    /// Creates an instance of the stream.
    fn create(
        params: Self::Params,
    ) -> impl Future<Output = Result<Self, Self::StreamCreationError>> + Send;

    /// Returns the size of the remote resource in bytes. The result should be `None`
    /// if the stream is infinite or doesn't have a known length.
    fn content_length(&self) -> Option<u64>;

    /// Seeks to a specific position in the stream. This method is only called if the
    /// requested range has not been downloaded, so this method should jump to the
    /// requested position in the stream as quickly as possible.
    ///
    /// The start value should be inclusive and the end value should be exclusive.
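    /// For example, `seek_range(0, Some(1024))` requests the first 1024 bytes
    /// (`0..1024`), and `seek_range(1024, None)` requests everything from byte 1024
    /// onward.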
    fn seek_range(
        &mut self,
        start: u64,
        end: Option<u64>,
    ) -> impl Future<Output = io::Result<()>> + Send;

    /// Attempts to reconnect to the server when a failure occurs.
    fn reconnect(&mut self, current_position: u64) -> impl Future<Output = io::Result<()>> + Send;

    /// Returns whether seeking is supported in the stream.
    /// If this method returns `false`, [`SourceStream::seek_range`] will never be invoked.
    fn supports_seek(&self) -> bool;

    /// Called when the stream finishes downloading. The default implementation returns
    /// `result` unchanged.
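    ///
    /// A minimal sketch of an override that logs the outcome (assuming a `tracing`
    /// dependency is available):
    ///
    /// ```ignore
    /// async fn on_finish(
    ///     &mut self,
    ///     result: io::Result<()>,
    ///     outcome: StreamOutcome,
    /// ) -> io::Result<()> {
    ///     // clean up any connection state and log how the stream ended
    ///     tracing::debug!("stream finished: {outcome:?}");
    ///     result
    /// }
    /// ```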
    fn on_finish(
        &mut self,
        result: io::Result<()>,
        #[expect(unused)] outcome: StreamOutcome,
    ) -> impl Future<Output = io::Result<()>> + Send {
        future::ready(result)
    }
}

/// Trait for decoding extra error information asynchronously.
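///
/// The default implementation of [`DecodeError::decode_error`] returns the error's
/// `Display` output. A minimal sketch of an override for a hypothetical error type that
/// carries a captured response body:
///
/// ```ignore
/// use std::fmt;
///
/// #[derive(Debug)]
/// struct MyError {
///     status: u16,
///     body: Option<String>,
/// }
///
/// impl fmt::Display for MyError {
///     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
///         write!(f, "request failed with status {}", self.status)
///     }
/// }
///
/// impl std::error::Error for MyError {}
///
/// impl DecodeError for MyError {
///     async fn decode_error(self) -> String {
///         // include the captured body when one is available
///         match self.body {
///             Some(body) => body,
///             None => self.to_string(),
///         }
///     }
/// }
/// ```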
pub trait DecodeError: Error + Send + Sized {
    /// Decodes extra error information.
    fn decode_error(self) -> impl Future<Output = String> + Send {
        future::ready(self.to_string())
    }
}

impl DecodeError for Infallible {
    async fn decode_error(self) -> String {
        // This will never get called since it's infallible
        String::new()
    }
}

#[derive(PartialEq, Eq)]
enum DownloadAction {
    Continue,
    Complete,
}

pub(crate) struct Source<S: SourceStream, W: StorageWriter> {
    writer: W,
    downloaded: Downloaded,
    download_status: DownloadStatus,
    requested_position: RequestedPosition,
    position_reached: PositionReached,
    notify_read: NotifyRead,
    content_length: Option<u64>,
    seek_tx: mpsc::Sender<u64>,
    seek_rx: mpsc::Receiver<u64>,
    prefetch_bytes: u64,
    batch_write_size: usize,
    retry_timeout: Duration,
    on_progress: Option<ProgressFn<S>>,
    on_reconnect: Option<ReconnectFn<S>>,
    prefetch_complete: bool,
    prefetch_start_position: u64,
    remaining_bytes: Option<Bytes>,
    cancellation_token: CancellationToken,
}

impl<S, W> Source<S, W>
where
    S: SourceStream<Error: Debug>,
    W: StorageWriter,
{
    pub(crate) fn new(
        writer: W,
        content_length: Option<u64>,
        settings: Settings<S>,
        cancellation_token: CancellationToken,
    ) -> Self {
        // buffer size of 1 is fine here because we wait for the position to update after we send
        // each request
        let (seek_tx, seek_rx) = mpsc::channel(1);
        Self {
            writer,
            downloaded: Downloaded::default(),
            download_status: DownloadStatus::default(),
            requested_position: RequestedPosition::default(),
            position_reached: PositionReached::default(),
            notify_read: NotifyRead::default(),
            seek_tx,
            seek_rx,
            content_length,
            prefetch_complete: settings.prefetch_bytes == 0,
            prefetch_bytes: settings.prefetch_bytes,
            batch_write_size: settings.batch_write_size,
            retry_timeout: settings.retry_timeout,
            on_progress: settings.on_progress,
            on_reconnect: settings.on_reconnect,
            prefetch_start_position: 0,
            remaining_bytes: None,
            cancellation_token,
        }
    }

    #[instrument(skip_all)]
    pub(crate) async fn download(&mut self, mut stream: S) {
        let res = self.download_inner(&mut stream).await;
        let (res, stream_res) = match res {
            Ok(StreamOutcome::Completed) => (Ok(()), StreamOutcome::Completed),
            Ok(StreamOutcome::CancelledByUser) => (
                Err(io::Error::new(
                    io::ErrorKind::Interrupted,
                    "stream cancelled by user",
                )),
                StreamOutcome::CancelledByUser,
            ),
            Err(e) => (Err(e), StreamOutcome::Completed),
        };
        let res = stream.on_finish(res, stream_res).await;
        if let Err(e) = res {
            if stream_res == StreamOutcome::Completed {
                error!("download failed: {e:?}");
            }
            self.download_status.set_failed();
        }
        self.signal_download_complete();
    }

    async fn download_inner(&mut self, stream: &mut S) -> io::Result<StreamOutcome> {
        debug!("starting file download");
        let download_start = std::time::Instant::now();

        loop {
            // Some streams may get stuck if the connection has a hiccup while waiting for the next
            // chunk. Forcing the client to abort and retry may help in these cases.
            let next_chunk = timeout(self.retry_timeout, stream.next());
            tokio::select! {
                position = self.seek_rx.recv() => {
                    // seek_tx can't be dropped here since we keep a reference in this struct
                    self.handle_seek(stream, position.expect("seek_tx dropped")).await?;
                },
                bytes = next_chunk => {
                    let Ok(bytes) = bytes else {
                        self.handle_reconnect(stream).await?;
                        continue;
                    };
                    if self
                        .handle_bytes(stream, bytes, download_start)
                        .await?
                        == DownloadAction::Complete
                    {
                        debug!(
                            download_duration = format!("{:?}", download_start.elapsed()),
                            "stream finished downloading"
                        );
                        break;
                    }
                }
                () = self.cancellation_token.cancelled() => {
                    debug!("received cancellation request, stopping download task");
                    return Ok(StreamOutcome::CancelledByUser);
                }
            };
        }
        self.report_download_complete(stream, download_start)?;
        Ok(StreamOutcome::Completed)
    }

    async fn handle_seek(&mut self, stream: &mut S, position: u64) -> io::Result<()> {
        if self.should_seek(stream, position)? {
            debug!("seek position not yet downloaded");
            let current_stream_position = self.writer.stream_position()?;
            if self.prefetch_complete {
                debug!("re-starting prefetch");
                self.prefetch_start_position = position;
                self.prefetch_complete = false;
            } else {
                debug!("seeking during prefetch, ending prefetch early");
                self.downloaded
                    .add(self.prefetch_start_position..current_stream_position);
                self.prefetch_complete = true;
            }
            if let Some(content_length) = self.content_length {
                // Get the minimum possible start position to ensure we capture the entire range
                let min_start_position = current_stream_position.min(position);
                debug!(
                    start = min_start_position,
                    end = content_length,
                    "checking for seek range",
                );
                let gap = self
                    .downloaded
                    .next_gap(min_start_position..content_length)
                    .expect("already checked for a gap");
                // Gap start may be too low if we're seeking forward, so check it against the
                // position
                let seek_start = gap.start.max(position);
                debug!(seek_start, seek_end = gap.end, "requesting seek range");
                self.seek(stream, seek_start, Some(gap.end)).await?;
            } else {
                self.seek(stream, position, None).await?;
            }
        }
        Ok(())
    }

    async fn handle_reconnect(&mut self, stream: &mut S) -> io::Result<()> {
        warn!("timed out reading next chunk, retrying");
        let pos = self.writer.stream_position()?;
        // We already know there's a network issue if we're attempting a reconnect.
        // A retry policy on the client may cause an exponential backoff to be triggered here, so
        // we'll cap the reconnect time to prevent additional delays between reconnect attempts.
        let reconnect_pos = tokio::time::timeout(self.retry_timeout, stream.reconnect(pos)).await;
        if reconnect_pos
            .inspect_err(|e| warn!("error attempting to reconnect: {e:?}"))
            .is_ok()
        {
            if let Some(on_reconnect) = &mut self.on_reconnect {
                on_reconnect(stream, &self.cancellation_token);
            }
        }
        Ok(())
    }

    async fn handle_prefetch(
        &mut self,
        stream: &mut S,
        bytes: Option<Bytes>,
        start_position: u64,
        download_start: Instant,
    ) -> io::Result<DownloadAction> {
        let Some(bytes) = bytes else {
            self.prefetch_complete = true;
            debug!("file shorter than prefetch length, download finished");
            self.writer.flush()?;
            let position = self.writer.stream_position()?;
            self.downloaded.add(start_position..position);

            return self.finish_or_find_next_gap(stream).await;
        };
        let written = self.write_batched(&bytes).await?;
        self.writer.flush()?;

        let stream_position = self.writer.stream_position()?;
        let partial_write = written < bytes.len();

        // End prefetch early if we weren't able to write the entire contents
        if partial_write {
            debug!(
                written,
                bytes_len = bytes.len(),
                "failed to write all bytes during prefetch"
            );
            self.remaining_bytes = Some(bytes.slice(written..));
        }
        if (stream_position >= start_position + self.prefetch_bytes) || partial_write {
            self.downloaded.add(start_position..stream_position);
            debug!("prefetch complete");
            self.prefetch_complete = true;
        }

        self.report_prefetch_progress(stream, stream_position, download_start, written);
        Ok(DownloadAction::Continue)
    }

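    /// Looks for the next missing chunk when the stream supports seeking and the content
    /// length is known; if one is found, seeks to it and continues downloading. Otherwise,
    /// flushes the writer and signals that the download is complete.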
    async fn finish_or_find_next_gap(&mut self, stream: &mut S) -> io::Result<DownloadAction> {
        if stream.supports_seek() {
            if let Some(content_length) = self.content_length {
                let gap = self.downloaded.next_gap(0..content_length);
                if let Some(gap) = gap {
                    debug!(
                        missing = format!("{gap:?}"),
                        "downloading missing stream chunk"
                    );
                    self.seek(stream, gap.start, Some(gap.end)).await?;
                    return Ok(DownloadAction::Continue);
                }
            }
        }
        self.writer.flush()?;
        self.signal_download_complete();
        Ok(DownloadAction::Complete)
    }

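    /// Writes `bytes` to storage in batches of at most `batch_write_size`, yielding to
    /// the runtime between batches. Returns the number of bytes written, which may be
    /// less than `bytes.len()` if the writer can't accept more data.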
    async fn write_batched(&mut self, bytes: &[u8]) -> io::Result<usize> {
        let mut written = 0;
        loop {
            let write_size = self.batch_write_size.min(bytes[written..].len());
            let batch_written = self.writer.write(&bytes[written..written + write_size])?;
            if batch_written == 0 {
                return Ok(written);
            }
            written += batch_written;
            // yield between writes to ensure we don't spend too long on writes
            // without an await point
            yield_now().await;
        }
    }

    async fn handle_bytes(
        &mut self,
        stream: &mut S,
        bytes: Option<Result<Bytes, S::Error>>,
        download_start: Instant,
    ) -> io::Result<DownloadAction> {
        let bytes = match bytes.transpose() {
            Ok(bytes) => bytes,
            Err(e) => {
                error!("Error fetching chunk from stream: {e:?}");
                return Ok(DownloadAction::Continue);
            }
        };

        if !self.prefetch_complete {
            return self
                .handle_prefetch(stream, bytes, self.prefetch_start_position, download_start)
                .await;
        }

        let bytes = match (self.remaining_bytes.take(), bytes) {
            (Some(remaining), Some(bytes)) => {
                let mut combined = BytesMut::new();
                combined.put(remaining);
                combined.put(bytes);
                combined.freeze()
            }
            (Some(remaining), None) => remaining,
            (None, Some(bytes)) => bytes,
            (None, None) => {
                return self.finish_or_find_next_gap(stream).await;
            }
        };
        let bytes_len = bytes.len();
        let new_position = self.write(bytes).await?;
        self.report_downloading_progress(stream, new_position, download_start, bytes_len)?;

        Ok(DownloadAction::Continue)
    }

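    /// Writes the entire buffer to storage, waiting for the reader to free up space when
    /// the writer can't accept more data, and notifies any waiting reader once a requested
    /// position has been written. Returns the new stream position.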
    async fn write(&mut self, bytes: Bytes) -> io::Result<u64> {
        let mut written = 0;
        let position = self.writer.stream_position()?;
        let mut new_position = position;
        // Keep writing until we process the whole buffer.
        // If the reader is falling behind, this may take several attempts.
        while written < bytes.len() {
            self.notify_read.request();
            let new_written = self.write_batched(&bytes[written..]).await?;
            trace!(written, new_written, len = bytes.len(), "wrote data");

            if new_written > 0 {
                self.writer.flush()?;
                written += new_written;
            }
            new_position = self.writer.stream_position()?;
            if new_position > position {
                self.downloaded.add(position..new_position);
            }

            if let Some(requested) = self.requested_position.get() {
                debug!(
                    requested_position = requested,
                    current_position = new_position,
                    "received requested position"
                );

                if new_position >= requested {
                    debug!("notifying position reached");
                    self.requested_position.clear();
                    self.position_reached.notify_position_reached();
                }
            }
            if new_written == 0 {
                // We're not able to write any data, so we need to wait for space to be available
                debug!("waiting for next read");
                self.notify_read.wait_for_read().await;
                debug!("read finished");
            }

            trace!(
                previous_position = position,
                new_position,
                chunk_size = bytes.len(),
                "received response chunk"
            );
        }
        Ok(new_position)
    }

    fn should_seek(&mut self, stream: &S, position: u64) -> io::Result<bool> {
        if !stream.supports_seek() {
            warn!("Attempting to seek, but it's unsupported. Waiting for stream to catch up.");
            return Ok(false);
        }
        Ok(if let Some(range) = self.downloaded.get(position) {
            !range.contains(&self.writer.stream_position()?)
        } else {
            true
        })
    }

    async fn seek(&mut self, stream: &mut S, start: u64, end: Option<u64>) -> io::Result<()> {
        stream.seek_range(start, end).await?;
        self.writer.seek(SeekFrom::Start(start))?;
        Ok(())
    }

    fn signal_download_complete(&self) {
        self.position_reached.notify_stream_done();
    }

    fn report_progress(&mut self, stream: &S, info: StreamState) {
        if let Some(on_progress) = self.on_progress.as_mut() {
            on_progress(stream, info, &self.cancellation_token);
        }
    }

    fn report_prefetch_progress(
        &mut self,
        stream: &S,
        stream_position: u64,
        download_start: Instant,
        chunk_size: usize,
    ) {
        self.report_progress(
            stream,
            StreamState {
                current_position: stream_position,
                current_chunk: (0..stream_position),
                elapsed: download_start.elapsed(),
                phase: StreamPhase::Prefetching {
                    target: self.prefetch_bytes,
                    chunk_size,
                },
            },
        );
    }

    fn report_downloading_progress(
        &mut self,
        stream: &S,
        new_position: u64,
        download_start: Instant,
        chunk_size: usize,
    ) -> io::Result<()> {
        let pos = self.writer.stream_position()?;
        self.report_progress(
            stream,
            StreamState {
                current_position: pos,
                current_chunk: self
                    .downloaded
                    .get(new_position - 1)
                    .expect("position already downloaded"),
                elapsed: download_start.elapsed(),
                phase: StreamPhase::Downloading { chunk_size },
            },
        );
        Ok(())
    }

    fn report_download_complete(&mut self, stream: &S, download_start: Instant) -> io::Result<()> {
        let pos = self.writer.stream_position()?;
        self.report_progress(
            stream,
            StreamState {
                current_position: pos,
                elapsed: download_start.elapsed(),
                // ensure no subtraction overflow
                current_chunk: self.downloaded.get(pos.max(1) - 1).unwrap_or_default(),
                phase: StreamPhase::Complete,
            },
        );
        Ok(())
    }

    pub(crate) fn source_handle(&self) -> SourceHandle {
        SourceHandle {
            downloaded: self.downloaded.clone(),
            download_status: self.download_status.clone(),
            requested_position: self.requested_position.clone(),
            notify_read: self.notify_read.clone(),
            position_reached: self.position_reached.clone(),
            seek_tx: self.seek_tx.clone(),
            content_length: self.content_length,
        }
    }
}