skua_voice/input/sources/
http.rs1use crate::input::{
2 AsyncAdapterStream,
3 AsyncMediaSource,
4 AudioStream,
5 AudioStreamError,
6 Compose,
7 Input,
8};
9use async_trait::async_trait;
10use futures::TryStreamExt;
11use pin_project::pin_project;
12use reqwest::{
13 header::{HeaderMap, ACCEPT_RANGES, CONTENT_LENGTH, CONTENT_TYPE, RANGE, RETRY_AFTER},
14 Client,
15};
16use std::{
17 io::{Error as IoError, ErrorKind as IoErrorKind, Result as IoResult, SeekFrom},
18 pin::Pin,
19 task::{Context, Poll},
20 time::Duration,
21};
22use symphonia_core::{io::MediaSource, probe::Hint};
23use tokio::io::{AsyncRead, AsyncSeek, ReadBuf};
24use tokio_util::io::StreamReader;
25
26#[derive(Clone, Debug)]
28pub struct HttpRequest {
29 pub client: Client,
31 pub request: String,
33 pub headers: HeaderMap,
35 pub content_length: Option<u64>,
41}
42
43impl HttpRequest {
44 #[must_use]
45 pub fn new(client: Client, request: String) -> Self {
47 Self::new_with_headers(client, request, HeaderMap::default())
48 }
49
50 #[must_use]
51 pub fn new_with_headers(client: Client, request: String, headers: HeaderMap) -> Self {
53 HttpRequest {
54 client,
55 request,
56 headers,
57 content_length: None,
58 }
59 }
60
61 async fn create_stream(
62 &mut self,
63 offset: Option<u64>,
64 ) -> Result<(HttpStream, Option<Hint>), AudioStreamError> {
65 let mut resp = self.client.get(&self.request).headers(self.headers.clone());
66
67 match (offset, self.content_length) {
68 (Some(offset), None) => {
69 resp = resp.header(RANGE, format!("bytes={offset}-"));
70 },
71 (offset, Some(max)) => {
72 resp = resp.header(
73 RANGE,
74 format!("bytes={}-{}", offset.unwrap_or(0), max.saturating_sub(1)),
75 );
76 },
77 _ => {},
78 }
79
80 let resp = resp
81 .send()
82 .await
83 .map_err(|e| AudioStreamError::Fail(Box::new(e)))?;
84
85 if !resp.status().is_success() {
86 let msg: Box<dyn std::error::Error + Send + Sync + 'static> =
87 format!("failed with http status code: {}", resp.status()).into();
88 return Err(AudioStreamError::Fail(msg));
89 }
90
91 if let Some(t) = resp.headers().get(RETRY_AFTER) {
92 t.to_str()
93 .map_err(|_| {
94 let msg: Box<dyn std::error::Error + Send + Sync + 'static> =
95 "Retry-after field contained non-ASCII data.".into();
96 AudioStreamError::Fail(msg)
97 })
98 .and_then(|str_text| {
99 str_text.parse().map_err(|_| {
100 let msg: Box<dyn std::error::Error + Send + Sync + 'static> =
101 "Retry-after field was non-numeric.".into();
102 AudioStreamError::Fail(msg)
103 })
104 })
105 .and_then(|t| Err(AudioStreamError::RetryIn(Duration::from_secs(t))))
106 } else {
107 let headers = resp.headers();
108
109 let hint = headers
110 .get(CONTENT_TYPE)
111 .and_then(|val| val.to_str().ok())
112 .map(|val| {
113 let mut out = Hint::default();
114 out.mime_type(val);
115 out
116 });
117
118 let len = headers
119 .get(CONTENT_LENGTH)
120 .and_then(|val| val.to_str().ok())
121 .and_then(|val| val.parse().ok());
122
123 let resume = headers
124 .get(ACCEPT_RANGES)
125 .and_then(|a| a.to_str().ok())
126 .and_then(|a| {
127 if a == "bytes" {
128 Some(self.clone())
129 } else {
130 None
131 }
132 });
133
134 let stream = Box::new(StreamReader::new(
135 resp.bytes_stream()
136 .map_err(|e| IoError::new(IoErrorKind::Other, e)),
137 ));
138
139 let input = HttpStream {
140 stream,
141 len,
142 resume,
143 };
144
145 Ok((input, hint))
146 }
147 }
148}
149
150#[pin_project]
151struct HttpStream {
152 #[pin]
153 stream: Box<dyn AsyncRead + Send + Sync + Unpin>,
154 len: Option<u64>,
155 resume: Option<HttpRequest>,
156}
157
158impl AsyncRead for HttpStream {
159 fn poll_read(
160 self: Pin<&mut Self>,
161 cx: &mut Context<'_>,
162 buf: &mut ReadBuf<'_>,
163 ) -> Poll<IoResult<()>> {
164 AsyncRead::poll_read(self.project().stream, cx, buf)
165 }
166}
167
168impl AsyncSeek for HttpStream {
169 fn start_seek(self: Pin<&mut Self>, _position: SeekFrom) -> IoResult<()> {
170 Err(IoErrorKind::Unsupported.into())
171 }
172
173 fn poll_complete(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<u64>> {
174 unreachable!()
175 }
176}
177
178#[async_trait]
179impl AsyncMediaSource for HttpStream {
180 fn is_seekable(&self) -> bool {
181 false
182 }
183
184 async fn byte_len(&self) -> Option<u64> {
185 self.len
186 }
187
188 async fn try_resume(
189 &mut self,
190 offset: u64,
191 ) -> Result<Box<dyn AsyncMediaSource>, AudioStreamError> {
192 if let Some(resume) = &mut self.resume {
193 resume
194 .create_stream(Some(offset))
195 .await
196 .map(|a| Box::new(a.0) as Box<dyn AsyncMediaSource>)
197 } else {
198 Err(AudioStreamError::Unsupported)
199 }
200 }
201}
202
203#[async_trait]
204impl Compose for HttpRequest {
205 fn create(&mut self) -> Result<AudioStream<Box<dyn MediaSource>>, AudioStreamError> {
206 Err(AudioStreamError::Unsupported)
207 }
208
209 async fn create_async(
210 &mut self,
211 ) -> Result<AudioStream<Box<dyn MediaSource>>, AudioStreamError> {
212 self.create_stream(None).await.map(|(input, hint)| {
213 let stream = AsyncAdapterStream::new(Box::new(input), 64 * 1024);
214
215 AudioStream {
216 input: Box::new(stream) as Box<dyn MediaSource>,
217 hint,
218 }
219 })
220 }
221
222 fn should_create_async(&self) -> bool {
223 true
224 }
225}
226
227impl From<HttpRequest> for Input {
228 fn from(val: HttpRequest) -> Self {
229 Input::Lazy(Box::new(val))
230 }
231}
232
#[cfg(test)]
mod tests {
    use reqwest::Client;

    use super::*;
    use crate::{
        constants::test_data::{HTTP_OPUS_TARGET, HTTP_TARGET, HTTP_WEBM_TARGET},
        input::input_tests::*,
    };

    /// Builds a request for `url` on a freshly constructed client.
    fn request_for(url: &str) -> HttpRequest {
        HttpRequest::new(Client::new(), url.into())
    }

    // Tests against `HTTP_TARGET`.

    #[tokio::test]
    #[ntest::timeout(10_000)]
    async fn http_track_plays() {
        track_plays_mixed(|| request_for(HTTP_TARGET)).await;
    }

    #[tokio::test]
    #[ntest::timeout(10_000)]
    async fn http_forward_seek_correct() {
        forward_seek_correct(|| request_for(HTTP_TARGET)).await;
    }

    #[tokio::test]
    #[ntest::timeout(10_000)]
    async fn http_backward_seek_correct() {
        backward_seek_correct(|| request_for(HTTP_TARGET)).await;
    }

    // Tests against `HTTP_OPUS_TARGET`.

    #[tokio::test]
    #[ntest::timeout(10_000)]
    async fn http_opus_track_plays() {
        track_plays_passthrough(|| request_for(HTTP_OPUS_TARGET)).await;
    }

    #[tokio::test]
    #[ntest::timeout(10_000)]
    async fn http_opus_forward_seek_correct() {
        forward_seek_correct(|| request_for(HTTP_OPUS_TARGET)).await;
    }

    #[tokio::test]
    #[ntest::timeout(10_000)]
    async fn http_opus_backward_seek_correct() {
        backward_seek_correct(|| request_for(HTTP_OPUS_TARGET)).await;
    }

    // Tests against `HTTP_WEBM_TARGET`.

    #[tokio::test]
    #[ntest::timeout(10_000)]
    async fn http_webm_track_plays() {
        track_plays_passthrough(|| request_for(HTTP_WEBM_TARGET)).await;
    }

    #[tokio::test]
    #[ntest::timeout(10_000)]
    async fn http_webm_forward_seek_correct() {
        forward_seek_correct(|| request_for(HTTP_WEBM_TARGET)).await;
    }

    #[tokio::test]
    #[ntest::timeout(10_000)]
    async fn http_webm_backward_seek_correct() {
        backward_seek_correct(|| request_for(HTTP_WEBM_TARGET)).await;
    }
}