tower_api_client/
pagination.rs

1//! Constructs for wrapping a paginated API.
2use crate::request::Request;
3use futures::{ready, Stream};
4use pin_project_lite::pin_project;
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use tower::Service;
9
/// A [`Request`] whose results are spread over several pages.
///
/// [`PaginationStream`] drives implementors of this trait: it seeds its state
/// from [`get_page`](Self::get_page), calls
/// [`update_request`](Self::update_request) with the current page (when there
/// is one) before each call to the service, and calls
/// [`next_page`](Self::next_page) on every successful response to decide
/// whether to keep going.
///
/// `Clone` is required because the stream sends a fresh clone of the request
/// to the service for every page.
pub trait PaginatedRequest: Request + Clone {
    /// Data describing a single page; its concrete meaning is up to the
    /// implementor.
    type PaginationData;

    /// Returns the pagination data the stream should start from.
    ///
    /// Called once when the stream is constructed. `None` means the first
    /// request is sent without a preceding `update_request` call.
    fn get_page(&self) -> Option<Self::PaginationData>;

    /// Computes the pagination data for the page that follows `response`.
    ///
    /// `prev_page` is the page the stream held when `response` was requested
    /// (`None` if the stream started without one). Returning `None` ends the
    /// stream after `response` has been yielded.
    fn next_page(
        &self,
        prev_page: Option<&Self::PaginationData>,
        response: &Self::Response,
    ) -> Option<Self::PaginationData>;

    /// Updates the request in place so that it targets `page`.
    ///
    /// Called by the stream immediately before the request is cloned and
    /// handed to the service.
    fn update_request(&mut self, page: &Self::PaginationData);
}
20
pin_project! {
    /// A [`Stream`] that walks a paginated API by sending `request` through
    /// `svc` once per page.
    ///
    /// Each item is the result of one call to the service. The `Svc: Service<R>`
    /// bound is on the struct itself because the `future` field names
    /// `Svc::Future`.
    pub struct PaginationStream<Svc: Service<R>, T, R> {
        // Position in the page sequence (start, a later page, or finished).
        state: State<T>,
        // Service every page request is sent through.
        svc: Svc,
        // The in-flight call, if any. Boxed and pinned so it can be polled in
        // place; `None` whenever no call is outstanding.
        future: Option<Pin<Box<Svc::Future>>>,
        // Request template: updated with the current page, then cloned for
        // each call to the service.
        request: R,
    }
}
29
30impl<Svc: Service<R>, T, R: PaginatedRequest<PaginationData = T>> PaginationStream<Svc, T, R> {
31    pub(crate) fn new(svc: Svc, request: R) -> Self {
32        let page = request.get_page();
33        Self {
34            state: State::Start(page),
35            svc,
36            future: None,
37            request,
38        }
39    }
40}
41
42impl<Svc, T, R> Stream for PaginationStream<Svc, T, R>
43where
44    T: Clone + std::fmt::Debug,
45    Svc: Service<R, Response = R::Response>,
46    Svc::Future: Unpin,
47    R: PaginatedRequest<PaginationData = T>,
48{
49    type Item = Result<Svc::Response, Svc::Error>;
50
51    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
52        let this = self.project();
53        let mut page = match this.state {
54            State::Start(None) => None,
55            State::Start(Some(state)) | State::Next(state) => Some(state.clone()),
56            State::End => {
57                return Poll::Ready(None);
58            }
59        };
60
61        loop {
62            match this.future {
63                Some(fut) => {
64                    let response = ready!(fut.as_mut().poll(cx));
65                    // The future has completed, so we replace it with none to make sure it doesn't
66                    // get polled again
67                    *this.future = None;
68                    let response = response?;
69                    page = this.request.next_page(page.as_ref(), &response);
70                    if let Some(page) = page {
71                        *this.state = State::Next(page)
72                    } else {
73                        *this.state = State::End
74                    }
75
76                    return Poll::Ready(Some(Ok(response)));
77                }
78                None => {
79                    if let Err(e) = ready!(this.svc.poll_ready(cx)) {
80                        return Poll::Ready(Some(Err(e)));
81                    }
82
83                    if let Some(page) = page.as_ref() {
84                        this.request.update_request(page);
85                    }
86
87                    *this.future = Some(Box::pin(this.svc.call(this.request.clone())));
88                }
89            }
90        }
91    }
92}
93
/// Where a paginated stream currently is in its sequence of pages.
#[derive(Clone, Debug)]
pub enum State<T> {
    /// No page has been fetched yet; holds the pagination data for the
    /// initial request, if there is any.
    Start(Option<T>),
    /// Pagination is under way; holds the pagination data for the page to
    /// continue with.
    Next(T),
    /// The last page has been reached; there is nothing more to fetch.
    End,
}

impl<T> Default for State<T> {
    /// Starts at the beginning with no pagination data.
    fn default() -> Self {
        Self::Start(None)
    }
}