stream_download/
async_read.rs

1//! A [`SourceStream`] adapter for any source that implements [`AsyncRead`].
2
3use std::convert::Infallible;
4use std::io;
5use std::pin::Pin;
6
7use bytes::Bytes;
8use futures_util::Stream;
9use tokio::io::AsyncRead;
10use tokio_util::io::ReaderStream;
11
12use crate::source::SourceStream;
13
/// Parameters for creating an [`AsyncReadStream`].
///
/// Bundles the reader to wrap together with an optional, caller-supplied content length.
#[derive(Debug)]
pub struct AsyncReadStreamParams<T> {
    /// The reader that will be wrapped.
    stream: T,
    /// Optional content length supplied by the caller; [`None`] when unknown.
    content_length: Option<u64>,
}

impl<T> AsyncReadStreamParams<T> {
    /// Creates a new [`AsyncReadStreamParams`] wrapping `stream`.
    ///
    /// The content length starts out unset ([`None`]); call
    /// [`content_length`](Self::content_length) to provide one.
    pub fn new(stream: T) -> Self {
        let content_length = None;
        Self {
            stream,
            content_length,
        }
    }

    /// Sets the content length of the stream.
    ///
    /// A generic [`AsyncRead`] source has no way of discovering its length on its own, so it
    /// must be supplied explicitly; otherwise it stays [`None`]. Passing [`None`] clears any
    /// previously set value.
    #[must_use]
    pub fn content_length<L>(mut self, content_length: L) -> Self
    where
        L: Into<Option<u64>>,
    {
        self.content_length = content_length.into();
        self
    }
}
44
45/// An implementation of the [`SourceStream`] trait for any stream implementing [`AsyncRead`].
46#[derive(Debug)]
47pub struct AsyncReadStream<T> {
48    stream: ReaderStream<T>,
49    content_length: Option<u64>,
50}
51
52impl<T> AsyncReadStream<T>
53where
54    T: AsyncRead + Send + Sync + Unpin + 'static,
55{
56    /// Creates a new [`AsyncReadStream`].
57    pub fn new<L>(stream: T, content_length: L) -> Self
58    where
59        L: Into<Option<u64>>,
60    {
61        Self {
62            stream: ReaderStream::new(stream),
63            content_length: content_length.into(),
64        }
65    }
66}
67
68impl<T> SourceStream for AsyncReadStream<T>
69where
70    T: AsyncRead + Send + Sync + Unpin + 'static,
71{
72    type Params = AsyncReadStreamParams<T>;
73
74    type StreamCreationError = Infallible;
75
76    async fn create(params: Self::Params) -> Result<Self, Self::StreamCreationError> {
77        Ok(Self::new(params.stream, params.content_length))
78    }
79
80    fn content_length(&self) -> Option<u64> {
81        self.content_length
82    }
83
84    fn supports_seek(&self) -> bool {
85        false
86    }
87
88    async fn seek_range(&mut self, _start: u64, _end: Option<u64>) -> io::Result<()> {
89        Err(io::Error::new(
90            io::ErrorKind::Unsupported,
91            "seek unsupported",
92        ))
93    }
94
95    async fn reconnect(&mut self, _current_position: u64) -> io::Result<()> {
96        Err(io::Error::new(
97            io::ErrorKind::Unsupported,
98            "reconnect unsupported",
99        ))
100    }
101}
102
103impl<T> Stream for AsyncReadStream<T>
104where
105    T: AsyncRead + Unpin,
106{
107    type Item = io::Result<Bytes>;
108
109    fn poll_next(
110        mut self: std::pin::Pin<&mut Self>,
111        cx: &mut std::task::Context<'_>,
112    ) -> std::task::Poll<Option<Self::Item>> {
113        Pin::new(&mut self.stream).poll_next(cx)
114    }
115}