skua_voice/input/sources/
http.rs

1use crate::input::{
2    AsyncAdapterStream,
3    AsyncMediaSource,
4    AudioStream,
5    AudioStreamError,
6    Compose,
7    Input,
8};
9use async_trait::async_trait;
10use futures::TryStreamExt;
11use pin_project::pin_project;
12use reqwest::{
13    header::{HeaderMap, ACCEPT_RANGES, CONTENT_LENGTH, CONTENT_TYPE, RANGE, RETRY_AFTER},
14    Client,
15};
16use std::{
17    io::{Error as IoError, ErrorKind as IoErrorKind, Result as IoResult, SeekFrom},
18    pin::Pin,
19    task::{Context, Poll},
20    time::Duration,
21};
22use symphonia_core::{io::MediaSource, probe::Hint};
23use tokio::io::{AsyncRead, AsyncSeek, ReadBuf};
24use tokio_util::io::StreamReader;
25
26/// A lazily instantiated HTTP request.
27#[derive(Clone, Debug)]
28pub struct HttpRequest {
29    /// A reqwest client instance used to send the HTTP GET request.
30    pub client: Client,
31    /// The target URL of the required resource.
32    pub request: String,
33    /// HTTP header fields to add to any created requests.
34    pub headers: HeaderMap,
35    /// Content length, used as an upper bound in range requests if known.
36    ///
37    /// This is only needed for certain domains who expect to see a value like
38    /// `range: bytes=0-1023` instead of the simpler `range: bytes=0-` (such as
39    /// Youtube).
40    pub content_length: Option<u64>,
41}
42
43impl HttpRequest {
44    #[must_use]
45    /// Create a lazy HTTP request.
46    pub fn new(client: Client, request: String) -> Self {
47        Self::new_with_headers(client, request, HeaderMap::default())
48    }
49
50    #[must_use]
51    /// Create a lazy HTTP request.
52    pub fn new_with_headers(client: Client, request: String, headers: HeaderMap) -> Self {
53        HttpRequest {
54            client,
55            request,
56            headers,
57            content_length: None,
58        }
59    }
60
61    async fn create_stream(
62        &mut self,
63        offset: Option<u64>,
64    ) -> Result<(HttpStream, Option<Hint>), AudioStreamError> {
65        let mut resp = self.client.get(&self.request).headers(self.headers.clone());
66
67        match (offset, self.content_length) {
68            (Some(offset), None) => {
69                resp = resp.header(RANGE, format!("bytes={offset}-"));
70            },
71            (offset, Some(max)) => {
72                resp = resp.header(
73                    RANGE,
74                    format!("bytes={}-{}", offset.unwrap_or(0), max.saturating_sub(1)),
75                );
76            },
77            _ => {},
78        }
79
80        let resp = resp
81            .send()
82            .await
83            .map_err(|e| AudioStreamError::Fail(Box::new(e)))?;
84
85        if !resp.status().is_success() {
86            let msg: Box<dyn std::error::Error + Send + Sync + 'static> =
87                format!("failed with http status code: {}", resp.status()).into();
88            return Err(AudioStreamError::Fail(msg));
89        }
90
91        if let Some(t) = resp.headers().get(RETRY_AFTER) {
92            t.to_str()
93                .map_err(|_| {
94                    let msg: Box<dyn std::error::Error + Send + Sync + 'static> =
95                        "Retry-after field contained non-ASCII data.".into();
96                    AudioStreamError::Fail(msg)
97                })
98                .and_then(|str_text| {
99                    str_text.parse().map_err(|_| {
100                        let msg: Box<dyn std::error::Error + Send + Sync + 'static> =
101                            "Retry-after field was non-numeric.".into();
102                        AudioStreamError::Fail(msg)
103                    })
104                })
105                .and_then(|t| Err(AudioStreamError::RetryIn(Duration::from_secs(t))))
106        } else {
107            let headers = resp.headers();
108
109            let hint = headers
110                .get(CONTENT_TYPE)
111                .and_then(|val| val.to_str().ok())
112                .map(|val| {
113                    let mut out = Hint::default();
114                    out.mime_type(val);
115                    out
116                });
117
118            let len = headers
119                .get(CONTENT_LENGTH)
120                .and_then(|val| val.to_str().ok())
121                .and_then(|val| val.parse().ok());
122
123            let resume = headers
124                .get(ACCEPT_RANGES)
125                .and_then(|a| a.to_str().ok())
126                .and_then(|a| {
127                    if a == "bytes" {
128                        Some(self.clone())
129                    } else {
130                        None
131                    }
132                });
133
134            let stream = Box::new(StreamReader::new(
135                resp.bytes_stream()
136                    .map_err(|e| IoError::new(IoErrorKind::Other, e)),
137            ));
138
139            let input = HttpStream {
140                stream,
141                len,
142                resume,
143            };
144
145            Ok((input, hint))
146        }
147    }
148}
149
150#[pin_project]
151struct HttpStream {
152    #[pin]
153    stream: Box<dyn AsyncRead + Send + Sync + Unpin>,
154    len: Option<u64>,
155    resume: Option<HttpRequest>,
156}
157
158impl AsyncRead for HttpStream {
159    fn poll_read(
160        self: Pin<&mut Self>,
161        cx: &mut Context<'_>,
162        buf: &mut ReadBuf<'_>,
163    ) -> Poll<IoResult<()>> {
164        AsyncRead::poll_read(self.project().stream, cx, buf)
165    }
166}
167
168impl AsyncSeek for HttpStream {
169    fn start_seek(self: Pin<&mut Self>, _position: SeekFrom) -> IoResult<()> {
170        Err(IoErrorKind::Unsupported.into())
171    }
172
173    fn poll_complete(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<u64>> {
174        unreachable!()
175    }
176}
177
178#[async_trait]
179impl AsyncMediaSource for HttpStream {
180    fn is_seekable(&self) -> bool {
181        false
182    }
183
184    async fn byte_len(&self) -> Option<u64> {
185        self.len
186    }
187
188    async fn try_resume(
189        &mut self,
190        offset: u64,
191    ) -> Result<Box<dyn AsyncMediaSource>, AudioStreamError> {
192        if let Some(resume) = &mut self.resume {
193            resume
194                .create_stream(Some(offset))
195                .await
196                .map(|a| Box::new(a.0) as Box<dyn AsyncMediaSource>)
197        } else {
198            Err(AudioStreamError::Unsupported)
199        }
200    }
201}
202
203#[async_trait]
204impl Compose for HttpRequest {
205    fn create(&mut self) -> Result<AudioStream<Box<dyn MediaSource>>, AudioStreamError> {
206        Err(AudioStreamError::Unsupported)
207    }
208
209    async fn create_async(
210        &mut self,
211    ) -> Result<AudioStream<Box<dyn MediaSource>>, AudioStreamError> {
212        self.create_stream(None).await.map(|(input, hint)| {
213            let stream = AsyncAdapterStream::new(Box::new(input), 64 * 1024);
214
215            AudioStream {
216                input: Box::new(stream) as Box<dyn MediaSource>,
217                hint,
218            }
219        })
220    }
221
222    fn should_create_async(&self) -> bool {
223        true
224    }
225}
226
227impl From<HttpRequest> for Input {
228    fn from(val: HttpRequest) -> Self {
229        Input::Lazy(Box::new(val))
230    }
231}
232
#[cfg(test)]
mod tests {
    use reqwest::Client;

    use super::*;
    use crate::{
        constants::test_data::{HTTP_OPUS_TARGET, HTTP_TARGET, HTTP_WEBM_TARGET},
        input::input_tests::*,
    };

    /// Builds a fresh lazy request for `url` backed by a new default client.
    fn lazy_request(url: &str) -> HttpRequest {
        HttpRequest::new(Client::new(), url.into())
    }

    #[tokio::test]
    #[ntest::timeout(10_000)]
    async fn http_track_plays() {
        track_plays_mixed(|| lazy_request(HTTP_TARGET)).await;
    }

    #[tokio::test]
    #[ntest::timeout(10_000)]
    async fn http_forward_seek_correct() {
        forward_seek_correct(|| lazy_request(HTTP_TARGET)).await;
    }

    #[tokio::test]
    #[ntest::timeout(10_000)]
    async fn http_backward_seek_correct() {
        backward_seek_correct(|| lazy_request(HTTP_TARGET)).await;
    }

    // NOTE: this covers youtube audio in a non-copyright-violating way, since
    // those depend on an HttpRequest internally anyhow.
    #[tokio::test]
    #[ntest::timeout(10_000)]
    async fn http_opus_track_plays() {
        track_plays_passthrough(|| lazy_request(HTTP_OPUS_TARGET)).await;
    }

    #[tokio::test]
    #[ntest::timeout(10_000)]
    async fn http_opus_forward_seek_correct() {
        forward_seek_correct(|| lazy_request(HTTP_OPUS_TARGET)).await;
    }

    #[tokio::test]
    #[ntest::timeout(10_000)]
    async fn http_opus_backward_seek_correct() {
        backward_seek_correct(|| lazy_request(HTTP_OPUS_TARGET)).await;
    }

    #[tokio::test]
    #[ntest::timeout(10_000)]
    async fn http_webm_track_plays() {
        track_plays_passthrough(|| lazy_request(HTTP_WEBM_TARGET)).await;
    }

    #[tokio::test]
    #[ntest::timeout(10_000)]
    async fn http_webm_forward_seek_correct() {
        forward_seek_correct(|| lazy_request(HTTP_WEBM_TARGET)).await;
    }

    #[tokio::test]
    #[ntest::timeout(10_000)]
    async fn http_webm_backward_seek_correct() {
        backward_seek_correct(|| lazy_request(HTTP_WEBM_TARGET)).await;
    }
}