streaming_http_range_client/
lib.rs

1mod error;
2mod http_range;
3#[cfg(not(target_arch = "wasm32"))]
4mod test_client;
5
6#[cfg(target_arch = "wasm32")]
7mod wasm_reader;
8
9use futures_util::TryStreamExt;
10use std::ops::{Range, RangeFrom};
11use std::pin::Pin;
12use std::task::{Context, Poll};
13#[cfg(not(target_arch = "wasm32"))]
14use tokio::io::ReadBuf;
15
16#[cfg(target_arch = "wasm32")]
17use futures_util::io as asyncio;
18
19#[cfg(not(target_arch = "wasm32"))]
20use tokio::io as asyncio;
21
22use asyncio::{AsyncRead, AsyncReadExt};
23
24#[macro_use]
25extern crate log;
26
27pub use error::{Error, Result};
28pub use http_range::HttpRange;
29
30use async_trait::async_trait;
31
/// A stream centric HTTP client.
///
/// The client requests byte ranges from its source and exposes the received bytes through
/// `AsyncRead`. A range must be requested with `set_range` before reading; reading without a
/// range panics.
///
/// Note that the example below performs a real network request.
///
/// ```
/// use streaming_http_range_client::HttpClient;
/// use tokio::io::AsyncReadExt;
///
/// use tokio;
/// # tokio_test::block_on(async {
/// let mut new_client = HttpClient::new("https://georust.org");
/// new_client.set_range(2..14).await.unwrap();
///
/// let mut output = String::new();
/// new_client.read_to_string(&mut output).await.unwrap();
///
/// // This `expected_text` may need to be updated someday if someone updates the site.
/// let expected_text = "DOCTYPE html";
/// assert_eq!(expected_text, output)
/// # });
///
/// ```
pub struct HttpClient {
    /// Source that is asked for a reader whenever a new byte range is needed.
    client: Box<dyn ReaderSource>,
    /// Stream of requested bytes that have not been consumed yet.
    reader: Reader,
    /// The byte range requested so far; `None` until a range has been set.
    range: Option<HttpRange>,
    /// Absolute byte offset of the next byte that `reader` will yield.
    pos: u64,
    /// Request and byte counters, logged when the client is dropped.
    stats: ReqStats,
}
59
60impl std::fmt::Debug for HttpClient {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        f.debug_struct("HttpClient")
63            .field("client", &self.client)
64            // .field("reader", &self.reader)
65            .field("range", &self.range)
66            .field("pos", &self.pos)
67            .field("stats", &self.stats)
68            .finish()
69    }
70}
71
/// Usage counters for a client; written to the debug log when the client is dropped.
#[derive(Debug, Default)]
struct ReqStats {
    /// Bytes discarded by `fast_forward` rather than handed to the caller.
    wasted_bytes: u64,
    /// Bytes handed to the caller through `poll_read`.
    used_bytes: u64,
    /// Number of range requests attempted.
    req_count: usize,
}
78
79impl HttpClient {
80    /// Create a new client. To get data from the client, call [`set_range`] and start reading
81    /// from the client.
82    pub fn new(url: &str) -> Self {
83        Self {
84            client: Box::new(ReqwestClient::new(url)),
85            reader: empty(),
86            pos: 0,
87            range: None,
88            stats: ReqStats::default(),
89        }
90    }
91
92    //
93    pub async fn set_range(&mut self, range: Range<u64>) -> Result<()> {
94        assert!(!range.is_empty());
95        self.pos = range.start;
96        self.stats.req_count += 1;
97        trace!(
98            "set_range {range:?}, request #{req_count}",
99            req_count = self.stats.req_count
100        );
101        self.reader = self.client.get_byte_range(range.clone()).await?;
102        self.range = Some(HttpRange::Range(range));
103
104        Ok(())
105    }
106
107    /// Advance client to `to_pos`, discarding any intermediate data, without fetching further data.
108    ///
109    /// `to_pos` must be within the current range and must not occur before the current position.
110    pub async fn fast_forward(&mut self, to_pos: u64) -> Result<()> {
111        assert!(to_pos >= self.pos, "can't rewind");
112
113        let len = to_pos - self.pos;
114        if len == 0 {
115            return Ok(());
116        }
117        self.stats.wasted_bytes += len;
118
119        let mut ff_reader = empty();
120        std::mem::swap(&mut ff_reader, &mut self.reader);
121        let mut ff_reader = ff_reader.take(len);
122        asyncio::copy(&mut ff_reader, &mut asyncio::sink()).await?;
123        let reader = ff_reader.into_inner();
124        self.pos += len;
125        assert_eq!(self.pos, to_pos);
126
127        self.reader = reader;
128        Ok(())
129    }
130
131    /// Fast forwards to the beginning of range, fetching additional data if necessary.
132    pub async fn seek_to_range(&mut self, range: impl Into<HttpRange>) -> Result<()> {
133        let Some(HttpRange::Range(existing_range)) = &mut self.range else {
134            // TODO: Is there a reason *not* to support None or Some(HttpRange::RangeFrom)?
135            panic!("can only fast forward from double ended range");
136        };
137        let range = range.into();
138        trace!("seek_to_range: {range:?}");
139        assert!(range.start() >= self.pos, "can't rewind");
140        match range {
141            HttpRange::Range(range) => {
142                if range.start == self.pos {
143                    if range.end <= existing_range.end {
144                        trace!("Already at requested byte position and already have the requested data. No new request will be made.");
145                        Ok(())
146                    } else {
147                        self.append_contiguous_range(range).await
148                    }
149                } else if range.end <= existing_range.end {
150                    trace!("Fast forwarding to the requested byte position but already have the requested data. No new request will be made.");
151                    self.fast_forward(range.start).await
152                } else if range.start > existing_range.end {
153                    self.set_range(range).await
154                } else {
155                    assert!(range.start > self.pos);
156                    assert!(
157                        range.end > existing_range.end,
158                        "failed: {range_end}, > {existing_range_end}",
159                        range_end = range.end,
160                        existing_range_end = existing_range.end
161                    );
162                    self.fast_forward(range.start).await?;
163                    self.append_contiguous_range(range).await
164                }
165            }
166            HttpRange::RangeFrom(range) => {
167                if range.start == self.pos {
168                    trace!("nothing to do");
169                    Ok(())
170                } else {
171                    // TODO optimize for skipping over a lot of middle content
172                    // e.g. self.set_range(range)
173                    self.extend_to_end().await?;
174                    self.fast_forward(range.start).await
175                }
176            }
177        }
178    }
179
180    /// Fetch all the bytes to the end of the file.
181    ///
182    /// Panics when the current range is:
183    ///  - not set (see [`set_range`])
184    ///  - an open ended (RangeFrom) range, since it can't be extended
185    pub async fn extend_to_end(&mut self) -> Result<()> {
186        debug!("extending to end");
187        let Some(HttpRange::Range(prev_range)) = &self.range else {
188            panic!("must call set_range before you can extendToRange");
189        };
190
191        self.stats.req_count += 1;
192        trace!(
193            "extend_to_end from {prev_range:?}, request #{req_count}",
194            req_count = self.stats.req_count
195        );
196        let reader = self.client.get_byte_range_from(prev_range.end..).await?;
197
198        let mut tmp = empty();
199        std::mem::swap(&mut self.reader, &mut tmp);
200        self.reader = Box::pin(tmp.chain(reader));
201
202        let new_range = prev_range.start..;
203        self.range = Some(HttpRange::RangeFrom(new_range));
204
205        Ok(())
206    }
207
208    /// Append a contiguous extension to the current range.
209    ///
210    /// Panics when the current range is:
211    ///  - not set (see [`set_range`])
212    ///  - an open ended (RangeFrom) range, since it can't be extended
213    ///  - not contiguous with the extension
214    pub async fn append_contiguous_range(&mut self, extension: Range<u64>) -> Result<()> {
215        let Some(range) = &self.range else {
216            panic!("must call set_range before you can extend a range");
217        };
218
219        let HttpRange::Range(prev_range) = range else {
220            panic!("cannot extend an already open-ended range");
221        };
222
223        assert!(
224            prev_range.end >= extension.start,
225            "new range must be contiguous with old range"
226        );
227
228        if prev_range.end >= extension.end {
229            debug!(
230                "skipping extension {extension:?} which is within existing range: {prev_range:?}"
231            );
232            return Ok(());
233        }
234
235        self.stats.req_count += 1;
236        let uncovered_range = prev_range.end..extension.end;
237        trace!("append_contiguous_range {extension:?}, previously uncovered_range: {uncovered_range:?}. request #{req_count}", req_count=self.stats.req_count);
238        let reader = self.client.get_byte_range(uncovered_range.clone()).await?;
239
240        let mut tmp = empty();
241        std::mem::swap(&mut self.reader, &mut tmp);
242        self.reader = Box::pin(tmp.chain(reader));
243        let new_range = prev_range.start..extension.end;
244        self.range = Some(HttpRange::Range(new_range));
245
246        Ok(())
247    }
248
249    /// Move all the unread data from this client into a new instance.
250    pub fn split_off(&mut self) -> Self {
251        let Some(range) = &mut self.range else {
252            panic!("must set_range before splitting off");
253        };
254
255        let after = range.split(self.pos);
256        assert_eq!(range.end(), Some(self.pos));
257
258        let mut old_reader = empty();
259        std::mem::swap(&mut self.reader, &mut old_reader);
260
261        Self {
262            client: self.client.boxed_clone(),
263            reader: old_reader,
264            pos: self.pos,
265            range: Some(after),
266            stats: ReqStats::default(),
267        }
268    }
269
270    /// Does the client already encompass the given range?
271    ///
272    /// Caveat: If the client's current Range has an explicit end, we assume any given open ended
273    /// `(123..)` Range goes beyond the current Range.
274    pub fn contains(&self, range: &HttpRange) -> bool {
275        let Some(current_range) = &self.range else {
276            return false;
277        };
278        if current_range.start() >= range.start() {
279            warn!("rewinding?");
280            return false;
281        }
282        let Some(current_end) = current_range.end() else {
283            return true;
284        };
285
286        let Some(range_end) = range.end() else {
287            return false;
288        };
289
290        current_end >= range_end
291    }
292}
293
294impl Drop for HttpClient {
295    fn drop(&mut self) {
296        debug!("Finished using an HTTP client. used_bytes={used_bytes}, wasted_bytes={wasted_bytes}, req_count={req_count}", used_bytes=self.stats.used_bytes, wasted_bytes=self.stats.wasted_bytes, req_count=self.stats.req_count)
297    }
298}
299
#[cfg(target_arch = "wasm32")]
impl AsyncRead for HttpClient {
    /// Reads from the current reader and advances the position and the used-bytes counter by
    /// the number of bytes a successful read returned.
    ///
    /// Panics when no range has been set.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.get_mut();
        assert!(
            this.range.is_some(),
            "must call set_range (and await) before attempting read"
        );

        let outcome = this.reader.as_mut().poll_read(cx, buf);
        // Pending and failed polls did not deliver anything.
        let length = match &outcome {
            Poll::Ready(Ok(bytes_read)) => *bytes_read,
            Poll::Ready(Err(_)) | Poll::Pending => 0,
        };
        this.pos += length as u64;
        this.stats.used_bytes += length as u64;
        trace!("read {length} bytes. New pos={pos}", pos = this.pos);

        outcome
    }
}
324
325#[cfg(not(target_arch = "wasm32"))]
326impl AsyncRead for HttpClient {
327    fn poll_read(
328        mut self: Pin<&mut Self>,
329        cx: &mut Context<'_>,
330        buf: &mut ReadBuf<'_>,
331    ) -> Poll<std::io::Result<()>> {
332        assert!(
333            self.range.is_some(),
334            "must call set_range (and await) before attempting read"
335        );
336
337        let len_before = buf.filled().len();
338        let result = self.reader.as_mut().poll_read(cx, buf);
339
340        let distance = buf.filled().len() - len_before;
341        self.pos += distance as u64;
342        self.stats.used_bytes += distance as u64;
343        trace!("read {distance} bytes. New pos={pos}", pos = self.pos);
344
345        result
346    }
347}
348
/// A source of byte ranges that `HttpClient` reads from.
#[async_trait(?Send)]
trait ReaderSource: Sync + Send + std::fmt::Debug {
    /// Returns a reader for the bytes in the double ended `range`.
    async fn get_byte_range(&self, range: Range<u64>) -> Result<Reader>;

    /// Returns a reader for the bytes from `range.start` onwards.
    async fn get_byte_range_from(&self, range: RangeFrom<u64>) -> Result<Reader>;

    /// Returns a boxed copy of this source; used to create the client returned by `split_off`.
    fn boxed_clone(&self) -> Box<dyn ReaderSource>;
}
357
/// A `ReaderSource` that requests byte ranges of a single URL using `reqwest`.
#[derive(Debug, Clone)]
struct ReqwestClient {
    /// The `reqwest` client used to send the range requests.
    client: reqwest::Client,
    /// The URL every range request is sent to.
    url: String,
}
363
364impl ReqwestClient {
365    fn new(url: &str) -> Self {
366        Self {
367            client: reqwest::Client::new(),
368            url: url.to_string(),
369        }
370    }
371
372    async fn get_byte_range_with_header(&self, range_header: &str) -> Result<Reader> {
373        debug!("getting range: {range_header}");
374
375        let response = self
376            .client
377            .get(&self.url)
378            .header(reqwest::header::RANGE, range_header)
379            .send()
380            .await
381            .map_err(|e| Error::External(Box::new(e)))?;
382
383        let status = response.status();
384        match response.headers().get("Content-Length") {
385            Some(content_length) => debug!("content length: {content_length:?}"),
386            None => debug!("Response lacks a content length header"),
387        }
388
389        if !status.is_success() {
390            return Err(Error::HttpFailed {
391                status: status.as_u16(),
392            });
393        }
394
395        #[cfg(target_arch = "wasm32")]
396        {
397            let bytes_stream = response
398                .bytes_stream()
399                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
400
401            let reader = wasm_reader::WasmReader::new(Box::new(bytes_stream));
402            Ok(Box::pin(reader))
403        }
404        #[cfg(not(target_arch = "wasm32"))]
405        {
406            use tokio_util::io::StreamReader;
407            let bytes_stream = response
408                .bytes_stream()
409                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
410            Ok(Box::pin(StreamReader::new(bytes_stream)))
411        }
412    }
413}
414
415#[async_trait(?Send)]
416impl ReaderSource for ReqwestClient {
417    async fn get_byte_range(&self, range: Range<u64>) -> Result<Reader> {
418        let range_header = format!("bytes={}-{}", range.start, (range.end - 1));
419        self.get_byte_range_with_header(&range_header).await
420    }
421
422    async fn get_byte_range_from(&self, range: RangeFrom<u64>) -> Result<Reader> {
423        let range_header = format!("bytes={}-", range.start);
424        self.get_byte_range_with_header(&range_header).await
425    }
426
427    fn boxed_clone(&self) -> Box<dyn ReaderSource> {
428        Box::new(self.clone())
429    }
430}
431
/// The boxed reader type the client reads from. Outside of wasm32 the reader is additionally
/// required to be `Sync + Send`.
#[cfg(not(target_arch = "wasm32"))]
type Reader = Pin<Box<dyn AsyncRead + Sync + Send>>;
/// The boxed reader type the client reads from (wasm32 variant without `Sync + Send` bounds).
#[cfg(target_arch = "wasm32")]
type Reader = Pin<Box<dyn AsyncRead>>;
436
/// Returns a reader that never yields any bytes. It is used as the client's initial reader and
/// as a placeholder while the real reader is temporarily moved out of the client.
pub(crate) fn empty() -> Reader {
    Box::pin(EmptyReader)
}
440
/// A reader that is always at the end of its stream.
struct EmptyReader;
#[cfg(not(target_arch = "wasm32"))]
impl AsyncRead for EmptyReader {
    /// Completes immediately without adding anything to the buffer.
    fn poll_read(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        _buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}
#[cfg(target_arch = "wasm32")]
impl AsyncRead for EmptyReader {
    /// Completes immediately, reporting that zero bytes were read.
    fn poll_read(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        _buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        Poll::Ready(Ok(0))
    }
}
462
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;

    /// Reading a range that covers the whole input yields the input.
    #[tokio::test]
    async fn single_reader() {
        ensure_logging();

        let input = (0..4).collect::<Vec<u8>>();
        let mut reader = HttpClient::test_client(&input);
        reader.set_range(0..4).await.unwrap();

        let mut output = vec![];
        reader.read_to_end(&mut output).await.unwrap();
        assert_eq!(output, input);
    }

    /// Reading again after the range has been exhausted yields nothing.
    #[tokio::test]
    async fn empty_reader() {
        ensure_logging();

        let input = (0..4).collect::<Vec<u8>>();
        let mut reader = HttpClient::test_client(&input);
        reader.set_range(0..4).await.unwrap();
        let mut output = vec![];
        reader.read_to_end(&mut output).await.unwrap();
        assert_eq!(output, input);

        let mut remainder = Vec::<u8>::new();
        reader.read_to_end(&mut remainder).await.unwrap();
        assert!(remainder.is_empty());
    }

    /// An exhausted range can be extended and reading continues where it stopped.
    #[tokio::test]
    async fn extend_range() {
        ensure_logging();

        let input = (0..7).collect::<Vec<u8>>();
        let mut reader = HttpClient::test_client(&input);
        reader.set_range(0..3).await.unwrap();

        let mut output = vec![];
        reader.read_to_end(&mut output).await.unwrap();
        assert_eq!(output, vec![0, 1, 2]);

        reader.append_contiguous_range(3..6).await.unwrap();
        reader.read_to_end(&mut output).await.unwrap();
        assert_eq!(output, vec![0, 1, 2, 3, 4, 5]);
    }

    /// The client works with the `AsyncReadExt` integer helpers.
    #[tokio::test]
    async fn read_le_u32() {
        // 140 + (1 << 8) = 396 in little endian byte order.
        let input: [u8; 4] = [140, 1, 0, 0];
        let mut reader = HttpClient::test_client(&input);
        reader.set_range(0..4).await.unwrap();
        let result = reader.read_u32_le().await.unwrap();
        assert_eq!(result, 396);
    }

    /// Splitting moves the unread data to the child and leaves the parent empty.
    #[tokio::test]
    async fn split_off() {
        let input = (0..8).collect::<Vec<u8>>();
        let mut parent_reader = HttpClient::test_client(&input);
        parent_reader.set_range(0..7).await.unwrap();

        let mut output = [0; 4];
        parent_reader.read_exact(&mut output).await.unwrap();
        assert_eq!(output, [0, 1, 2, 3]);

        let mut child_reader = parent_reader.split_off();

        let mut remainder = vec![];
        parent_reader.read_to_end(&mut remainder).await.unwrap();
        assert!(remainder.is_empty());

        let mut output = [0; 4];
        child_reader.append_contiguous_range(7..8).await.unwrap();
        child_reader.read_exact(&mut output).await.unwrap();
        assert_eq!(output, [4, 5, 6, 7]);
    }

    /// Extending to the end yields the current range followed by the rest of the input.
    #[tokio::test]
    async fn extend_to_end() {
        let input = (0..8).collect::<Vec<u8>>();
        let mut reader = HttpClient::test_client(&input);

        reader.set_range(4..5).await.unwrap();
        reader.extend_to_end().await.unwrap();

        let mut output = vec![];
        reader.read_to_end(&mut output).await.unwrap();

        assert_eq!(output, [4, 5, 6, 7])
    }

    /// Fast forwarding discards the bytes in front of the target position.
    #[tokio::test]
    async fn fast_forward() {
        let input = (0..8).collect::<Vec<u8>>();
        let mut reader = HttpClient::test_client(&input);

        reader.set_range(2..7).await.unwrap();
        reader.fast_forward(3).await.unwrap();
        let next = reader.read_u8().await.unwrap();
        assert_eq!(next, 3);
    }

    /// Fast forwarding to a position in front of the current one panics.
    // Pin the panic message so that an unrelated panic cannot make this test pass.
    #[should_panic(expected = "can't rewind")]
    #[tokio::test]
    async fn fast_forward_cannot_rewind() {
        let input = (0..8).collect::<Vec<u8>>();
        let mut reader = HttpClient::test_client(&input);

        reader.set_range(2..7).await.unwrap();
        reader.fast_forward(3).await.unwrap();
        let next = reader.read_u8().await.unwrap();
        assert_eq!(next, 3);

        // The client is at position 4 now, so going back to 2 panics.
        reader.fast_forward(2).await.unwrap();
    }

    /// Initializes `env_logger` once, no matter how many tests call it.
    fn ensure_logging() {
        static ONCE: std::sync::Once = std::sync::Once::new();
        ONCE.call_once(|| env_logger::builder().format_timestamp_millis().init());
    }
}
591}