stream_download_opendal/
lib.rs

1#![deny(missing_docs)]
2#![forbid(clippy::unwrap_used)]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4#![doc = include_str!("../README.md")]
5
6use std::fmt::Debug;
7use std::future::Future;
8use std::io::{self};
9use std::num::NonZeroUsize;
10use std::task::Poll;
11
12use bytes::{Bytes, BytesMut};
13use futures_util::{Stream, ready};
14use opendal::{FuturesAsyncReader, Operator, Reader};
15use pin_project_lite::pin_project;
16use stream_download::source::{DecodeError, SourceStream};
17use stream_download::storage::StorageProvider;
18use stream_download::{Settings, StreamDownload, StreamInitializationError};
19use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt};
20use tokio_util::io::poll_read_buf;
21use tracing::instrument;
22
23/// Extension trait for adding `OpenDAL` support to a [`StreamDownload`] instance.
24pub trait StreamDownloadExt<P>
25where
26    Self: Sized,
27{
28    /// Creates a new [`StreamDownload`] that uses an `OpenDAL` resource.
29    /// See the [`opendal`] documentation for more details.
30    ///
31    /// # Example
32    ///
33    /// ```no_run
34    /// use std::error::Error;
35    /// use std::io::{self, Read};
36    /// use std::result::Result;
37    ///
38    /// use opendal::{Operator, services};
39    /// use stream_download::storage::temp::TempStorageProvider;
40    /// use stream_download::{Settings, StreamDownload};
41    /// use stream_download_opendal::{OpendalStreamParams, StreamDownloadExt};
42    ///
43    /// #[tokio::main]
44    /// async fn main() -> Result<(), Box<dyn Error>> {
45    ///     let mut builder = services::S3::default()
46    ///         .region("us-east-1")
47    ///         .access_key_id("test")
48    ///         .secret_access_key("test")
49    ///         .bucket("my-bucket");
50    ///     let operator = Operator::new(builder)?.finish();
51    ///
52    ///     let mut reader = StreamDownload::new_opendal(
53    ///         OpendalStreamParams::new(operator, "some-object-key"),
54    ///         TempStorageProvider::default(),
55    ///         Settings::default(),
56    ///     )
57    ///     .await?;
58    ///
59    ///     tokio::task::spawn_blocking(move || {
60    ///         let mut buf = Vec::new();
61    ///         reader.read_to_end(&mut buf)?;
62    ///         Ok::<_, io::Error>(())
63    ///     })
64    ///     .await??;
65    ///     Ok(())
66    /// }
67    /// ```
68    fn new_opendal(
69        params: OpendalStreamParams,
70        storage_provider: P,
71        settings: Settings<OpendalStream>,
72    ) -> impl Future<Output = Result<Self, StreamInitializationError<OpendalStream>>> + Send;
73}
74
75impl<P: StorageProvider> StreamDownloadExt<P> for StreamDownload<P> {
76    async fn new_opendal(
77        params: OpendalStreamParams,
78        storage_provider: P,
79        settings: Settings<OpendalStream>,
80    ) -> Result<Self, StreamInitializationError<OpendalStream>> {
81        Self::new(params, storage_provider, settings).await
82    }
83}
84
85/// Parameters for creating an `OpenDAL` stream.
86#[derive(Debug, Clone)]
87pub struct OpendalStreamParams {
88    operator: Operator,
89    path: String,
90    chunk_size: usize,
91}
92
93impl OpendalStreamParams {
94    /// Creates a new [`OpendalStreamParams`] instance.
95    pub fn new<S>(operator: Operator, path: S) -> Self
96    where
97        S: Into<String>,
98    {
99        Self {
100            operator,
101            path: path.into(),
102            chunk_size: 4096,
103        }
104    }
105
106    /// Sets the chunk size for the [`OpendalStream`].
107    /// The default value is 4096.
108    #[must_use]
109    pub fn chunk_size(mut self, chunk_size: NonZeroUsize) -> Self {
110        self.chunk_size = chunk_size.get();
111        self
112    }
113}
114
pin_project! {
    /// An `OpenDAL` implementation of the [`SourceStream`] trait.
    ///
    /// Created with [`OpendalStream::new`] (or [`SourceStream::create`]) and
    /// consumed as a [`Stream`] of [`Bytes`] chunks.
    pub struct OpendalStream {
        // Tokio-compatible reader that `poll_next` reads chunks from; replaced whenever
        // `seek_range` (and therefore `reconnect`) opens a new range.
        #[pin]
        async_reader: Compat<FuturesAsyncReader>,
        // Reader handle kept so `seek_range` can clone it and open a new ranged async reader.
        reader: Reader,
        // Read buffer; each filled portion is split off and frozen into a `Bytes` chunk.
        buf: BytesMut,
        // Number of bytes reserved in `buf` once it has no capacity left (the chunk size).
        capacity: usize,
        // Object size reported by `stat`, or `None` when the reported size was 0.
        content_length: Option<u64>,
        // Content type reported by `stat`, if any.
        content_type: Option<String>,
    }
}
127
128// Can't use educe here because of https://github.com/taiki-e/pin-project-lite/issues/3
129impl Debug for OpendalStream {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        f.debug_struct("OpendalStream")
132            .field("async_reader", &"<async_reader>")
133            .field("reader", &"<reader>")
134            .field("buf", &self.buf)
135            .field("capacity", &self.capacity)
136            .field("content_length", &self.content_length)
137            .field("content_type", &self.content_type)
138            .finish()
139    }
140}
141
142/// Error returned from `OpenDAL`
143#[derive(thiserror::Error, Debug)]
144#[error("{0}")]
145pub struct Error(#[from] opendal::Error);
146
147impl DecodeError for Error {}
148
149impl OpendalStream {
150    /// Creates a new [`OpendalStream`].
151    #[instrument]
152    pub async fn new(params: OpendalStreamParams) -> Result<Self, Error> {
153        let stat = params.operator.stat(&params.path).await?;
154
155        let content_length = stat.content_length();
156        let content_type = stat.content_type().map(ToString::to_string);
157
158        let reader = params.operator.reader(&params.path).await?;
159
160        let async_reader = reader.clone().into_futures_async_read(..).await?.compat();
161
162        Ok(Self {
163            async_reader,
164            reader,
165            buf: BytesMut::with_capacity(params.chunk_size),
166            capacity: params.chunk_size,
167            content_length: if content_length > 0 {
168                Some(content_length)
169            } else {
170                None
171            },
172            content_type,
173        })
174    }
175
176    /// Returns the content type of the stream, if it is known.
177    pub fn content_type(&self) -> Option<&str> {
178        self.content_type.as_deref()
179    }
180}
181
182impl SourceStream for OpendalStream {
183    type Params = OpendalStreamParams;
184
185    type StreamCreationError = Error;
186
187    async fn create(params: Self::Params) -> Result<Self, Self::StreamCreationError> {
188        Self::new(params).await
189    }
190
191    fn content_length(&self) -> Option<u64> {
192        self.content_length
193    }
194
195    async fn seek_range(&mut self, start: u64, end: Option<u64>) -> io::Result<()> {
196        let reader = self.reader.clone();
197        let async_reader = match end {
198            Some(end) => reader.into_futures_async_read(start..end).await,
199            None => reader.into_futures_async_read(start..).await,
200        };
201
202        self.async_reader = async_reader
203            .map_err(Into::into)
204            .wrap_err("error creating async reader")?
205            .compat();
206        Ok(())
207    }
208
209    async fn reconnect(&mut self, current_position: u64) -> io::Result<()> {
210        self.seek_range(current_position, None).await
211    }
212
213    fn supports_seek(&self) -> bool {
214        true
215    }
216}
217
218impl Stream for OpendalStream {
219    type Item = io::Result<Bytes>;
220
221    fn poll_next(
222        mut self: std::pin::Pin<&mut Self>,
223        cx: &mut std::task::Context<'_>,
224    ) -> std::task::Poll<Option<Self::Item>> {
225        let mut this = self.as_mut().project();
226
227        if this.buf.capacity() == 0 {
228            this.buf.reserve(*this.capacity);
229        }
230
231        match ready!(poll_read_buf(this.async_reader, cx, &mut this.buf)) {
232            Err(err) => Poll::Ready(Some(Err(err))),
233            Ok(0) => Poll::Ready(None),
234            Ok(_) => {
235                let chunk = this.buf.split();
236                Poll::Ready(Some(Ok(chunk.freeze())))
237            }
238        }
239    }
240}
241
/// Adds a context message to the error of an [`io::Result`].
pub(crate) trait WrapIoResult {
    /// Prefixes the error message with `msg`, keeping the original [`io::ErrorKind`].
    fn wrap_err(self, msg: &str) -> Self;
}

impl<T> WrapIoResult for io::Result<T> {
    fn wrap_err(self, msg: &str) -> Self {
        self.map_err(|source| {
            let kind = source.kind();
            io::Error::new(kind, format!("{msg}: {source}"))
        })
    }
}