streaming_http_range_client/
lib.rs1mod error;
2mod http_range;
3#[cfg(not(target_arch = "wasm32"))]
4mod test_client;
5
6#[cfg(target_arch = "wasm32")]
7mod wasm_reader;
8
9use futures_util::TryStreamExt;
10use std::ops::{Range, RangeFrom};
11use std::pin::Pin;
12use std::task::{Context, Poll};
13#[cfg(not(target_arch = "wasm32"))]
14use tokio::io::ReadBuf;
15
16#[cfg(target_arch = "wasm32")]
17use futures_util::io as asyncio;
18
19#[cfg(not(target_arch = "wasm32"))]
20use tokio::io as asyncio;
21
22use asyncio::{AsyncRead, AsyncReadExt};
23
24#[macro_use]
25extern crate log;
26
27pub use error::{Error, Result};
28pub use http_range::HttpRange;
29
30use async_trait::async_trait;
31
32pub struct HttpClient {
53 client: Box<dyn ReaderSource>,
54 reader: Reader,
55 range: Option<HttpRange>,
56 pos: u64,
57 stats: ReqStats,
58}
59
60impl std::fmt::Debug for HttpClient {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 f.debug_struct("HttpClient")
63 .field("client", &self.client)
64 .field("range", &self.range)
66 .field("pos", &self.pos)
67 .field("stats", &self.stats)
68 .finish()
69 }
70}
71
/// Per-client request accounting; all counters start at zero.
#[derive(Default, Debug)]
struct ReqStats {
    /// Bytes fetched but discarded while fast-forwarding.
    wasted_bytes: u64,
    /// Bytes handed to callers through `poll_read`.
    used_bytes: u64,
    /// Number of HTTP requests issued.
    req_count: usize,
}
78
79impl HttpClient {
80 pub fn new(url: &str) -> Self {
83 Self {
84 client: Box::new(ReqwestClient::new(url)),
85 reader: empty(),
86 pos: 0,
87 range: None,
88 stats: ReqStats::default(),
89 }
90 }
91
92 pub async fn set_range(&mut self, range: Range<u64>) -> Result<()> {
94 assert!(!range.is_empty());
95 self.pos = range.start;
96 self.stats.req_count += 1;
97 trace!(
98 "set_range {range:?}, request #{req_count}",
99 req_count = self.stats.req_count
100 );
101 self.reader = self.client.get_byte_range(range.clone()).await?;
102 self.range = Some(HttpRange::Range(range));
103
104 Ok(())
105 }
106
107 pub async fn fast_forward(&mut self, to_pos: u64) -> Result<()> {
111 assert!(to_pos >= self.pos, "can't rewind");
112
113 let len = to_pos - self.pos;
114 if len == 0 {
115 return Ok(());
116 }
117 self.stats.wasted_bytes += len;
118
119 let mut ff_reader = empty();
120 std::mem::swap(&mut ff_reader, &mut self.reader);
121 let mut ff_reader = ff_reader.take(len);
122 asyncio::copy(&mut ff_reader, &mut asyncio::sink()).await?;
123 let reader = ff_reader.into_inner();
124 self.pos += len;
125 assert_eq!(self.pos, to_pos);
126
127 self.reader = reader;
128 Ok(())
129 }
130
131 pub async fn seek_to_range(&mut self, range: impl Into<HttpRange>) -> Result<()> {
133 let Some(HttpRange::Range(existing_range)) = &mut self.range else {
134 panic!("can only fast forward from double ended range");
136 };
137 let range = range.into();
138 trace!("seek_to_range: {range:?}");
139 assert!(range.start() >= self.pos, "can't rewind");
140 match range {
141 HttpRange::Range(range) => {
142 if range.start == self.pos {
143 if range.end <= existing_range.end {
144 trace!("Already at requested byte position and already have the requested data. No new request will be made.");
145 Ok(())
146 } else {
147 self.append_contiguous_range(range).await
148 }
149 } else if range.end <= existing_range.end {
150 trace!("Fast forwarding to the requested byte position but already have the requested data. No new request will be made.");
151 self.fast_forward(range.start).await
152 } else if range.start > existing_range.end {
153 self.set_range(range).await
154 } else {
155 assert!(range.start > self.pos);
156 assert!(
157 range.end > existing_range.end,
158 "failed: {range_end}, > {existing_range_end}",
159 range_end = range.end,
160 existing_range_end = existing_range.end
161 );
162 self.fast_forward(range.start).await?;
163 self.append_contiguous_range(range).await
164 }
165 }
166 HttpRange::RangeFrom(range) => {
167 if range.start == self.pos {
168 trace!("nothing to do");
169 Ok(())
170 } else {
171 self.extend_to_end().await?;
174 self.fast_forward(range.start).await
175 }
176 }
177 }
178 }
179
180 pub async fn extend_to_end(&mut self) -> Result<()> {
186 debug!("extending to end");
187 let Some(HttpRange::Range(prev_range)) = &self.range else {
188 panic!("must call set_range before you can extendToRange");
189 };
190
191 self.stats.req_count += 1;
192 trace!(
193 "extend_to_end from {prev_range:?}, request #{req_count}",
194 req_count = self.stats.req_count
195 );
196 let reader = self.client.get_byte_range_from(prev_range.end..).await?;
197
198 let mut tmp = empty();
199 std::mem::swap(&mut self.reader, &mut tmp);
200 self.reader = Box::pin(tmp.chain(reader));
201
202 let new_range = prev_range.start..;
203 self.range = Some(HttpRange::RangeFrom(new_range));
204
205 Ok(())
206 }
207
208 pub async fn append_contiguous_range(&mut self, extension: Range<u64>) -> Result<()> {
215 let Some(range) = &self.range else {
216 panic!("must call set_range before you can extend a range");
217 };
218
219 let HttpRange::Range(prev_range) = range else {
220 panic!("cannot extend an already open-ended range");
221 };
222
223 assert!(
224 prev_range.end >= extension.start,
225 "new range must be contiguous with old range"
226 );
227
228 if prev_range.end >= extension.end {
229 debug!(
230 "skipping extension {extension:?} which is within existing range: {prev_range:?}"
231 );
232 return Ok(());
233 }
234
235 self.stats.req_count += 1;
236 let uncovered_range = prev_range.end..extension.end;
237 trace!("append_contiguous_range {extension:?}, previously uncovered_range: {uncovered_range:?}. request #{req_count}", req_count=self.stats.req_count);
238 let reader = self.client.get_byte_range(uncovered_range.clone()).await?;
239
240 let mut tmp = empty();
241 std::mem::swap(&mut self.reader, &mut tmp);
242 self.reader = Box::pin(tmp.chain(reader));
243 let new_range = prev_range.start..extension.end;
244 self.range = Some(HttpRange::Range(new_range));
245
246 Ok(())
247 }
248
249 pub fn split_off(&mut self) -> Self {
251 let Some(range) = &mut self.range else {
252 panic!("must set_range before splitting off");
253 };
254
255 let after = range.split(self.pos);
256 assert_eq!(range.end(), Some(self.pos));
257
258 let mut old_reader = empty();
259 std::mem::swap(&mut self.reader, &mut old_reader);
260
261 Self {
262 client: self.client.boxed_clone(),
263 reader: old_reader,
264 pos: self.pos,
265 range: Some(after),
266 stats: ReqStats::default(),
267 }
268 }
269
270 pub fn contains(&self, range: &HttpRange) -> bool {
275 let Some(current_range) = &self.range else {
276 return false;
277 };
278 if current_range.start() >= range.start() {
279 warn!("rewinding?");
280 return false;
281 }
282 let Some(current_end) = current_range.end() else {
283 return true;
284 };
285
286 let Some(range_end) = range.end() else {
287 return false;
288 };
289
290 current_end >= range_end
291 }
292}
293
294impl Drop for HttpClient {
295 fn drop(&mut self) {
296 debug!("Finished using an HTTP client. used_bytes={used_bytes}, wasted_bytes={wasted_bytes}, req_count={req_count}", used_bytes=self.stats.used_bytes, wasted_bytes=self.stats.wasted_bytes, req_count=self.stats.req_count)
297 }
298}
299
#[cfg(target_arch = "wasm32")]
impl AsyncRead for HttpClient {
    /// Reads from the current request, advancing `pos` and `used_bytes` by the number of
    /// bytes returned.
    ///
    /// # Panics
    /// Panics if no range has been set yet.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        assert!(
            self.range.is_some(),
            "must call set_range (and await) before attempting read"
        );

        let poll = self.reader.as_mut().poll_read(cx, buf);
        // Pending and error results consume nothing.
        let consumed = match &poll {
            Poll::Ready(Ok(count)) => *count,
            _ => 0,
        };
        self.pos += consumed as u64;
        self.stats.used_bytes += consumed as u64;
        trace!(
            "read {length} bytes. New pos={pos}",
            length = consumed,
            pos = self.pos
        );

        poll
    }
}
324
325#[cfg(not(target_arch = "wasm32"))]
326impl AsyncRead for HttpClient {
327 fn poll_read(
328 mut self: Pin<&mut Self>,
329 cx: &mut Context<'_>,
330 buf: &mut ReadBuf<'_>,
331 ) -> Poll<std::io::Result<()>> {
332 assert!(
333 self.range.is_some(),
334 "must call set_range (and await) before attempting read"
335 );
336
337 let len_before = buf.filled().len();
338 let result = self.reader.as_mut().poll_read(cx, buf);
339
340 let distance = buf.filled().len() - len_before;
341 self.pos += distance as u64;
342 self.stats.used_bytes += distance as u64;
343 trace!("read {distance} bytes. New pos={pos}", pos = self.pos);
344
345 result
346 }
347}
348
349#[async_trait(?Send)]
350trait ReaderSource: Sync + Send + std::fmt::Debug {
351 async fn get_byte_range(&self, range: Range<u64>) -> Result<Reader>;
352
353 async fn get_byte_range_from(&self, range: RangeFrom<u64>) -> Result<Reader>;
354
355 fn boxed_clone(&self) -> Box<dyn ReaderSource>;
356}
357
358#[derive(Debug, Clone)]
359struct ReqwestClient {
360 client: reqwest::Client,
361 url: String,
362}
363
364impl ReqwestClient {
365 fn new(url: &str) -> Self {
366 Self {
367 client: reqwest::Client::new(),
368 url: url.to_string(),
369 }
370 }
371
372 async fn get_byte_range_with_header(&self, range_header: &str) -> Result<Reader> {
373 debug!("getting range: {range_header}");
374
375 let response = self
376 .client
377 .get(&self.url)
378 .header(reqwest::header::RANGE, range_header)
379 .send()
380 .await
381 .map_err(|e| Error::External(Box::new(e)))?;
382
383 let status = response.status();
384 match response.headers().get("Content-Length") {
385 Some(content_length) => debug!("content length: {content_length:?}"),
386 None => debug!("Response lacks a content length header"),
387 }
388
389 if !status.is_success() {
390 return Err(Error::HttpFailed {
391 status: status.as_u16(),
392 });
393 }
394
395 #[cfg(target_arch = "wasm32")]
396 {
397 let bytes_stream = response
398 .bytes_stream()
399 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
400
401 let reader = wasm_reader::WasmReader::new(Box::new(bytes_stream));
402 Ok(Box::pin(reader))
403 }
404 #[cfg(not(target_arch = "wasm32"))]
405 {
406 use tokio_util::io::StreamReader;
407 let bytes_stream = response
408 .bytes_stream()
409 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
410 Ok(Box::pin(StreamReader::new(bytes_stream)))
411 }
412 }
413}
414
415#[async_trait(?Send)]
416impl ReaderSource for ReqwestClient {
417 async fn get_byte_range(&self, range: Range<u64>) -> Result<Reader> {
418 let range_header = format!("bytes={}-{}", range.start, (range.end - 1));
419 self.get_byte_range_with_header(&range_header).await
420 }
421
422 async fn get_byte_range_from(&self, range: RangeFrom<u64>) -> Result<Reader> {
423 let range_header = format!("bytes={}-", range.start);
424 self.get_byte_range_with_header(&range_header).await
425 }
426
427 fn boxed_clone(&self) -> Box<dyn ReaderSource> {
428 Box::new(self.clone())
429 }
430}
431
432#[cfg(not(target_arch = "wasm32"))]
433type Reader = Pin<Box<dyn AsyncRead + Sync + Send>>;
434#[cfg(target_arch = "wasm32")]
435type Reader = Pin<Box<dyn AsyncRead>>;
436
437pub(crate) fn empty() -> Reader {
438 Box::pin(EmptyReader)
439}
440
441struct EmptyReader;
442#[cfg(not(target_arch = "wasm32"))]
443impl AsyncRead for EmptyReader {
444 fn poll_read(
445 self: Pin<&mut Self>,
446 _cx: &mut Context<'_>,
447 _buf: &mut ReadBuf<'_>,
448 ) -> Poll<std::io::Result<()>> {
449 Poll::Ready(Ok(()))
450 }
451}
452#[cfg(target_arch = "wasm32")]
453impl AsyncRead for EmptyReader {
454 fn poll_read(
455 self: Pin<&mut Self>,
456 _cx: &mut Context<'_>,
457 _buf: &mut [u8],
458 ) -> Poll<std::io::Result<usize>> {
459 Poll::Ready(Ok(0))
460 }
461}
462
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;

    #[tokio::test]
    async fn single_reader() {
        ensure_logging();

        let input = (0..4).collect::<Vec<u8>>();
        let mut reader = HttpClient::test_client(&input);
        reader.set_range(0..4).await.unwrap();

        let mut output = vec![];
        reader.read_to_end(&mut output).await.unwrap();
        assert_eq!(output, input);
    }

    #[tokio::test]
    async fn empty_reader() {
        ensure_logging();

        let input = (0..4).collect::<Vec<u8>>();
        let mut reader = HttpClient::test_client(&input);
        reader.set_range(0..4).await.unwrap();
        let mut output = vec![];
        reader.read_to_end(&mut output).await.unwrap();
        assert_eq!(output, input);

        let mut remainder = Vec::<u8>::new();
        reader.read_to_end(&mut remainder).await.unwrap();
        assert!(remainder.is_empty());
    }

    #[tokio::test]
    async fn extend_range() {
        ensure_logging();

        let input = (0..7).collect::<Vec<u8>>();
        let mut reader = HttpClient::test_client(&input);
        reader.set_range(0..3).await.unwrap();

        let mut output = vec![];
        reader.read_to_end(&mut output).await.unwrap();
        assert_eq!(output, vec![0, 1, 2]);

        reader.append_contiguous_range(3..6).await.unwrap();
        reader.read_to_end(&mut output).await.unwrap();
        assert_eq!(output, vec![0, 1, 2, 3, 4, 5]);
    }

    #[tokio::test]
    async fn read_le_u32() {
        let input: [u8; 4] = [140, 1, 0, 0];
        let mut reader = HttpClient::test_client(&input);
        reader.set_range(0..4).await.unwrap();
        let result = reader.read_u32_le().await.unwrap();
        assert_eq!(result, 396);
    }

    #[tokio::test]
    async fn split_off() {
        let input = (0..8).collect::<Vec<u8>>();
        let mut parent_reader = HttpClient::test_client(&input);
        parent_reader.set_range(0..7).await.unwrap();

        let mut output = [0; 4];
        parent_reader.read_exact(&mut output).await.unwrap();
        assert_eq!(output, [0, 1, 2, 3]);

        let mut child_reader = parent_reader.split_off();

        let mut remainder = vec![];
        parent_reader.read_to_end(&mut remainder).await.unwrap();
        assert!(remainder.is_empty());

        let mut output = [0; 4];
        child_reader.append_contiguous_range(7..8).await.unwrap();
        child_reader.read_exact(&mut output).await.unwrap();
        assert_eq!(output, [4, 5, 6, 7]);
    }

    #[tokio::test]
    async fn extend_to_end() {
        let input = (0..8).collect::<Vec<u8>>();
        let mut reader = HttpClient::test_client(&input);

        reader.set_range(4..5).await.unwrap();
        reader.extend_to_end().await.unwrap();

        let mut output = vec![];
        reader.read_to_end(&mut output).await.unwrap();

        assert_eq!(output, [4, 5, 6, 7])
    }

    #[tokio::test]
    async fn fast_forward() {
        let input = (0..8).collect::<Vec<u8>>();
        let mut reader = HttpClient::test_client(&input);

        reader.set_range(2..7).await.unwrap();
        reader.fast_forward(3).await.unwrap();
        let next = reader.read_u8().await.unwrap();
        assert_eq!(next, 3);
    }

    /// Fast-forwarding to a position already read past must panic with the rewind message,
    /// not for some unrelated reason.
    #[should_panic(expected = "can't rewind")]
    #[tokio::test]
    async fn fast_forward_cannot_rewind() {
        let input = (0..8).collect::<Vec<u8>>();
        let mut reader = HttpClient::test_client(&input);

        reader.set_range(2..7).await.unwrap();
        reader.fast_forward(3).await.unwrap();
        let next = reader.read_u8().await.unwrap();
        assert_eq!(next, 3);

        // Position is now 4; going back to 2 must panic.
        reader.fast_forward(2).await.unwrap();
    }

    #[tokio::test]
    async fn seek_within_existing_range() {
        let input = (0..8).collect::<Vec<u8>>();
        let mut reader = HttpClient::test_client(&input);

        reader.set_range(0..8).await.unwrap();
        assert_eq!(reader.read_u8().await.unwrap(), 0);

        reader.seek_to_range(HttpRange::Range(4..6)).await.unwrap();
        assert_eq!(reader.read_u8().await.unwrap(), 4);
    }

    #[tokio::test]
    async fn seek_appends_contiguous_range() {
        let input = (0..8).collect::<Vec<u8>>();
        let mut reader = HttpClient::test_client(&input);

        reader.set_range(0..4).await.unwrap();
        let mut head = [0; 4];
        reader.read_exact(&mut head).await.unwrap();
        assert_eq!(head, [0, 1, 2, 3]);

        reader.seek_to_range(HttpRange::Range(4..6)).await.unwrap();
        let mut tail = vec![];
        reader.read_to_end(&mut tail).await.unwrap();
        assert_eq!(tail, [4, 5]);
    }

    /// Initializes `env_logger` once for the whole test binary.
    fn ensure_logging() {
        static ONCE: std::sync::Once = std::sync::Once::new();
        ONCE.call_once(|| env_logger::builder().format_timestamp_millis().init());
    }
}