1use std::convert::Infallible;
5use std::error::Error;
6use std::fmt::Debug;
7use std::future;
8use std::io::{self, SeekFrom};
9use std::time::{Duration, Instant};
10
11use bytes::{BufMut, Bytes, BytesMut};
12use futures_util::{Future, Stream, StreamExt, TryStream};
13use handle::{
14 DownloadStatus, Downloaded, NotifyRead, PositionReached, RequestedPosition, SourceHandle,
15};
16use tokio::sync::mpsc;
17use tokio::task::yield_now;
18use tokio::time::timeout;
19use tokio_util::sync::CancellationToken;
20use tracing::{debug, error, instrument, trace, warn};
21
22use crate::storage::StorageWriter;
23use crate::{ProgressFn, ReconnectFn, Settings, StreamPhase, StreamState};
24
25pub(crate) mod handle;
26
/// How a download ended, as reported to [`SourceStream::on_finish`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamOutcome {
    /// The download loop ended without a cancellation request. This value is also reported
    /// alongside an `Err` result when the download stopped because of an error.
    Completed,
    /// The download stopped because the cancellation token was triggered.
    CancelledByUser,
}
35
36pub trait SourceStream:
43 TryStream<Ok = Bytes>
44 + Stream<Item = Result<Self::Ok, Self::Error>>
45 + Unpin
46 + Send
47 + Sync
48 + Sized
49 + 'static
50{
51 type Params: Send;
53
54 type StreamCreationError: DecodeError + Send;
56
57 fn create(
59 params: Self::Params,
60 ) -> impl Future<Output = Result<Self, Self::StreamCreationError>> + Send;
61
62 fn content_length(&self) -> Option<u64>;
65
66 fn seek_range(
72 &mut self,
73 start: u64,
74 end: Option<u64>,
75 ) -> impl Future<Output = io::Result<()>> + Send;
76
77 fn reconnect(&mut self, current_position: u64) -> impl Future<Output = io::Result<()>> + Send;
79
80 fn supports_seek(&self) -> bool;
83
84 fn on_finish(
86 &mut self,
87 result: io::Result<()>,
88 #[expect(unused)] outcome: StreamOutcome,
89 ) -> impl Future<Output = io::Result<()>> + Send {
90 future::ready(result)
91 }
92}
93
/// An error that can be turned into a human-readable message asynchronously.
pub trait DecodeError: Error + Send + Sized {
    /// Consumes the error and resolves to its message.
    ///
    /// The default implementation resolves immediately to the error's `Display` output.
    fn decode_error(self) -> impl Future<Output = String> + Send {
        let message = self.to_string();
        future::ready(message)
    }
}
101
102impl DecodeError for Infallible {
103 async fn decode_error(self) -> String {
104 String::new()
106 }
107}
108
/// Tells the download loop what to do after a chunk has been handled.
#[derive(Eq, PartialEq)]
enum DownloadAction {
    /// Keep reading from the stream.
    Continue,
    /// Nothing is left to download; leave the loop.
    Complete,
}
114
115pub(crate) struct Source<S: SourceStream, W: StorageWriter> {
116 writer: W,
117 downloaded: Downloaded,
118 download_status: DownloadStatus,
119 requested_position: RequestedPosition,
120 position_reached: PositionReached,
121 notify_read: NotifyRead,
122 content_length: Option<u64>,
123 seek_tx: mpsc::Sender<u64>,
124 seek_rx: mpsc::Receiver<u64>,
125 prefetch_bytes: u64,
126 batch_write_size: usize,
127 retry_timeout: Duration,
128 on_progress: Option<ProgressFn<S>>,
129 on_reconnect: Option<ReconnectFn<S>>,
130 prefetch_complete: bool,
131 prefetch_start_position: u64,
132 remaining_bytes: Option<Bytes>,
133 cancellation_token: CancellationToken,
134}
135
136impl<S, W> Source<S, W>
137where
138 S: SourceStream<Error: Debug>,
139 W: StorageWriter,
140{
141 pub(crate) fn new(
142 writer: W,
143 content_length: Option<u64>,
144 settings: Settings<S>,
145 cancellation_token: CancellationToken,
146 ) -> Self {
147 let (seek_tx, seek_rx) = mpsc::channel(1);
150 Self {
151 writer,
152 downloaded: Downloaded::default(),
153 download_status: DownloadStatus::default(),
154 requested_position: RequestedPosition::default(),
155 position_reached: PositionReached::default(),
156 notify_read: NotifyRead::default(),
157 seek_tx,
158 seek_rx,
159 content_length,
160 prefetch_complete: settings.prefetch_bytes == 0,
161 prefetch_bytes: settings.prefetch_bytes,
162 batch_write_size: settings.batch_write_size,
163 retry_timeout: settings.retry_timeout,
164 on_progress: settings.on_progress,
165 on_reconnect: settings.on_reconnect,
166 prefetch_start_position: 0,
167 remaining_bytes: None,
168 cancellation_token,
169 }
170 }
171
172 #[instrument(skip_all)]
173 pub(crate) async fn download(&mut self, mut stream: S) {
174 let res = self.download_inner(&mut stream).await;
175 let (res, stream_res) = match res {
176 Ok(StreamOutcome::Completed) => (Ok(()), StreamOutcome::Completed),
177 Ok(StreamOutcome::CancelledByUser) => (
178 Err(io::Error::new(
179 io::ErrorKind::Interrupted,
180 "stream cancelled by user",
181 )),
182 StreamOutcome::CancelledByUser,
183 ),
184 Err(e) => (Err(e), StreamOutcome::Completed),
185 };
186 let res = stream.on_finish(res, stream_res).await;
187 if let Err(e) = res {
188 if stream_res == StreamOutcome::Completed {
189 error!("download failed: {e:?}");
190 }
191 self.download_status.set_failed();
192 }
193 self.signal_download_complete();
194 }
195
196 async fn download_inner(&mut self, stream: &mut S) -> io::Result<StreamOutcome> {
197 debug!("starting file download");
198 let download_start = std::time::Instant::now();
199
200 loop {
201 let next_chunk = timeout(self.retry_timeout, stream.next());
204 tokio::select! {
205 position = self.seek_rx.recv() => {
206 self.handle_seek(stream, position.expect("seek_tx dropped")).await?;
208 },
209 bytes = next_chunk => {
210 let Ok(bytes) = bytes else {
211 self.handle_reconnect(stream).await?;
212 continue;
213 };
214 if self
215 .handle_bytes(stream, bytes, download_start)
216 .await?
217 == DownloadAction::Complete
218 {
219 debug!(
220 download_duration = format!("{:?}", download_start.elapsed()),
221 "stream finished downloading"
222 );
223 break;
224 }
225 }
226 () = self.cancellation_token.cancelled() => {
227 debug!("received cancellation request, stopping download task");
228 return Ok(StreamOutcome::CancelledByUser);
229 }
230 };
231 }
232 self.report_download_complete(stream, download_start)?;
233 Ok(StreamOutcome::Completed)
234 }
235
236 async fn handle_seek(&mut self, stream: &mut S, position: u64) -> io::Result<()> {
237 if self.should_seek(stream, position)? {
238 debug!("seek position not yet downloaded");
239 let current_stream_position = self.writer.stream_position()?;
240 if self.prefetch_complete {
241 debug!("re-starting prefetch");
242 self.prefetch_start_position = position;
243 self.prefetch_complete = false;
244 } else {
245 debug!("seeking during prefetch, ending prefetch early");
246 self.downloaded
247 .add(self.prefetch_start_position..current_stream_position);
248 self.prefetch_complete = true;
249 }
250 if let Some(content_length) = self.content_length {
251 let min_start_position = current_stream_position.min(position);
253 debug!(
254 start = min_start_position,
255 end = content_length,
256 "checking for seek range",
257 );
258 let gap = self
259 .downloaded
260 .next_gap(min_start_position..content_length)
261 .expect("already checked for a gap");
262 let seek_start = gap.start.max(position);
265 debug!(seek_start, seek_end = gap.end, "requesting seek range");
266 self.seek(stream, seek_start, Some(gap.end)).await?;
267 } else {
268 self.seek(stream, position, None).await?;
269 }
270 }
271 Ok(())
272 }
273
274 async fn handle_reconnect(&mut self, stream: &mut S) -> io::Result<()> {
275 warn!("timed out reading next chunk, retrying");
276 let pos = self.writer.stream_position()?;
277 let reconnect_pos = tokio::time::timeout(self.retry_timeout, stream.reconnect(pos)).await;
281 if reconnect_pos
282 .inspect_err(|e| warn!("error attempting to reconnect: {e:?}"))
283 .is_ok()
284 {
285 if let Some(on_reconnect) = &mut self.on_reconnect {
286 on_reconnect(stream, &self.cancellation_token);
287 }
288 }
289 Ok(())
290 }
291
292 async fn handle_prefetch(
293 &mut self,
294 stream: &mut S,
295 bytes: Option<Bytes>,
296 start_position: u64,
297 download_start: Instant,
298 ) -> io::Result<DownloadAction> {
299 let Some(bytes) = bytes else {
300 self.prefetch_complete = true;
301 debug!("file shorter than prefetch length, download finished");
302 self.writer.flush()?;
303 let position = self.writer.stream_position()?;
304 self.downloaded.add(start_position..position);
305
306 return self.finish_or_find_next_gap(stream).await;
307 };
308 let written = self.write_batched(&bytes).await?;
309 self.writer.flush()?;
310
311 let stream_position = self.writer.stream_position()?;
312 let partial_write = written < bytes.len();
313
314 if partial_write {
316 debug!(
317 written,
318 bytes_len = bytes.len(),
319 "failed to write all during prefetch"
320 );
321 self.remaining_bytes = Some(bytes.slice(written..));
322 }
323 if (stream_position >= start_position + self.prefetch_bytes) || partial_write {
324 self.downloaded.add(start_position..stream_position);
325 debug!("prefetch complete");
326 self.prefetch_complete = true;
327 }
328
329 self.report_prefetch_progress(stream, stream_position, download_start, written);
330 Ok(DownloadAction::Continue)
331 }
332
333 async fn finish_or_find_next_gap(&mut self, stream: &mut S) -> io::Result<DownloadAction> {
334 if stream.supports_seek() {
335 if let Some(content_length) = self.content_length {
336 let gap = self.downloaded.next_gap(0..content_length);
337 if let Some(gap) = gap {
338 debug!(
339 missing = format!("{gap:?}"),
340 "downloading missing stream chunk"
341 );
342 self.seek(stream, gap.start, Some(gap.end)).await?;
343 return Ok(DownloadAction::Continue);
344 }
345 }
346 }
347 self.writer.flush()?;
348 self.signal_download_complete();
349 Ok(DownloadAction::Complete)
350 }
351
352 async fn write_batched(&mut self, bytes: &[u8]) -> io::Result<usize> {
353 let mut written = 0;
354 loop {
355 let write_size = self.batch_write_size.min(bytes[written..].len());
356 let batch_written = self.writer.write(&bytes[written..written + write_size])?;
357 if batch_written == 0 {
358 return Ok(written);
359 }
360 written += batch_written;
361 yield_now().await;
364 }
365 }
366
367 async fn handle_bytes(
368 &mut self,
369 stream: &mut S,
370 bytes: Option<Result<Bytes, S::Error>>,
371 download_start: Instant,
372 ) -> io::Result<DownloadAction> {
373 let bytes = match bytes.transpose() {
374 Ok(bytes) => bytes,
375 Err(e) => {
376 error!("Error fetching chunk from stream: {e:?}");
377 return Ok(DownloadAction::Continue);
378 }
379 };
380
381 if !self.prefetch_complete {
382 return self
383 .handle_prefetch(stream, bytes, self.prefetch_start_position, download_start)
384 .await;
385 }
386
387 let bytes = match (self.remaining_bytes.take(), bytes) {
388 (Some(remaining), Some(bytes)) => {
389 let mut combined = BytesMut::new();
390 combined.put(remaining);
391 combined.put(bytes);
392 combined.freeze()
393 }
394 (Some(remaining), None) => remaining,
395 (None, Some(bytes)) => bytes,
396 (None, None) => {
397 return self.finish_or_find_next_gap(stream).await;
398 }
399 };
400 let bytes_len = bytes.len();
401 let new_position = self.write(bytes).await?;
402 self.report_downloading_progress(stream, new_position, download_start, bytes_len)?;
403
404 Ok(DownloadAction::Continue)
405 }
406
407 async fn write(&mut self, bytes: Bytes) -> io::Result<u64> {
408 let mut written = 0;
409 let position = self.writer.stream_position()?;
410 let mut new_position = position;
411 while written < bytes.len() {
414 self.notify_read.request();
415 let new_written = self.write_batched(&bytes[written..]).await?;
416 trace!(written, new_written, len = bytes.len(), "wrote data");
417
418 if new_written > 0 {
419 self.writer.flush()?;
420 written += new_written;
421 }
422 new_position = self.writer.stream_position()?;
423 if new_position > position {
424 self.downloaded.add(position..new_position);
425 }
426
427 if let Some(requested) = self.requested_position.get() {
428 debug!(
429 requested_position = requested,
430 current_position = new_position,
431 "received requested position"
432 );
433
434 if new_position >= requested {
435 debug!("notifying position reached");
436 self.requested_position.clear();
437 self.position_reached.notify_position_reached();
438 }
439 }
440 if new_written == 0 {
441 debug!("waiting for next read");
443 self.notify_read.wait_for_read().await;
444 debug!("read finished");
445 }
446
447 trace!(
448 previous_position = position,
449 new_position,
450 chunk_size = bytes.len(),
451 "received response chunk"
452 );
453 }
454 Ok(new_position)
455 }
456
457 fn should_seek(&mut self, stream: &S, position: u64) -> io::Result<bool> {
458 if !stream.supports_seek() {
459 warn!("Attempting to seek, but it's unsupported. Waiting for stream to catch up.");
460 return Ok(false);
461 }
462 Ok(if let Some(range) = self.downloaded.get(position) {
463 !range.contains(&self.writer.stream_position()?)
464 } else {
465 true
466 })
467 }
468
469 async fn seek(&mut self, stream: &mut S, start: u64, end: Option<u64>) -> io::Result<()> {
470 stream.seek_range(start, end).await?;
471 self.writer.seek(SeekFrom::Start(start))?;
472 Ok(())
473 }
474
475 fn signal_download_complete(&self) {
476 self.position_reached.notify_stream_done();
477 }
478
479 fn report_progress(&mut self, stream: &S, info: StreamState) {
480 if let Some(on_progress) = self.on_progress.as_mut() {
481 on_progress(stream, info, &self.cancellation_token);
482 }
483 }
484
485 fn report_prefetch_progress(
486 &mut self,
487 stream: &S,
488 stream_position: u64,
489 download_start: Instant,
490 chunk_size: usize,
491 ) {
492 self.report_progress(
493 stream,
494 StreamState {
495 current_position: stream_position,
496 current_chunk: (0..stream_position),
497 elapsed: download_start.elapsed(),
498 phase: StreamPhase::Prefetching {
499 target: self.prefetch_bytes,
500 chunk_size,
501 },
502 },
503 );
504 }
505
506 fn report_downloading_progress(
507 &mut self,
508 stream: &S,
509 new_position: u64,
510 download_start: Instant,
511 chunk_size: usize,
512 ) -> io::Result<()> {
513 let pos = self.writer.stream_position()?;
514 self.report_progress(
515 stream,
516 StreamState {
517 current_position: pos,
518 current_chunk: self
519 .downloaded
520 .get(new_position - 1)
521 .expect("position already downloaded"),
522 elapsed: download_start.elapsed(),
523 phase: StreamPhase::Downloading { chunk_size },
524 },
525 );
526 Ok(())
527 }
528
529 fn report_download_complete(&mut self, stream: &S, download_start: Instant) -> io::Result<()> {
530 let pos = self.writer.stream_position()?;
531 self.report_progress(
532 stream,
533 StreamState {
534 current_position: pos,
535 elapsed: download_start.elapsed(),
536 current_chunk: self.downloaded.get(pos.max(1) - 1).unwrap_or_default(),
538 phase: StreamPhase::Complete,
539 },
540 );
541 Ok(())
542 }
543
544 pub(crate) fn source_handle(&self) -> SourceHandle {
545 SourceHandle {
546 downloaded: self.downloaded.clone(),
547 download_status: self.download_status.clone(),
548 requested_position: self.requested_position.clone(),
549 notify_read: self.notify_read.clone(),
550 position_reached: self.position_reached.clone(),
551 seek_tx: self.seek_tx.clone(),
552 content_length: self.content_length,
553 }
554 }
555}