scrappy_service/
and_then_apply_fn.rs

1use std::future::Future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::rc::Rc;
5use std::task::{Context, Poll};
6
7use crate::cell::Cell;
8use crate::{Service, ServiceFactory};
9
10/// `Apply` service combinator
11pub(crate) struct AndThenApplyFn<A, B, F, Fut, Res, Err>
12where
13    A: Service,
14    B: Service,
15    F: FnMut(A::Response, &mut B) -> Fut,
16    Fut: Future<Output = Result<Res, Err>>,
17    Err: From<A::Error> + From<B::Error>,
18{
19    srv: Cell<(A, B, F)>,
20    r: PhantomData<(Fut, Res, Err)>,
21}
22
23impl<A, B, F, Fut, Res, Err> AndThenApplyFn<A, B, F, Fut, Res, Err>
24where
25    A: Service,
26    B: Service,
27    F: FnMut(A::Response, &mut B) -> Fut,
28    Fut: Future<Output = Result<Res, Err>>,
29    Err: From<A::Error> + From<B::Error>,
30{
31    /// Create new `Apply` combinator
32    pub(crate) fn new(a: A, b: B, f: F) -> Self {
33        Self {
34            srv: Cell::new((a, b, f)),
35            r: PhantomData,
36        }
37    }
38}
39
40impl<A, B, F, Fut, Res, Err> Clone for AndThenApplyFn<A, B, F, Fut, Res, Err>
41where
42    A: Service,
43    B: Service,
44    F: FnMut(A::Response, &mut B) -> Fut,
45    Fut: Future<Output = Result<Res, Err>>,
46    Err: From<A::Error> + From<B::Error>,
47{
48    fn clone(&self) -> Self {
49        AndThenApplyFn {
50            srv: self.srv.clone(),
51            r: PhantomData,
52        }
53    }
54}
55
56impl<A, B, F, Fut, Res, Err> Service for AndThenApplyFn<A, B, F, Fut, Res, Err>
57where
58    A: Service,
59    B: Service,
60    F: FnMut(A::Response, &mut B) -> Fut,
61    Fut: Future<Output = Result<Res, Err>>,
62    Err: From<A::Error> + From<B::Error>,
63{
64    type Request = A::Request;
65    type Response = Res;
66    type Error = Err;
67    type Future = AndThenApplyFnFuture<A, B, F, Fut, Res, Err>;
68
69    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
70        let inner = self.srv.get_mut();
71        let not_ready = inner.0.poll_ready(cx)?.is_pending();
72        if inner.1.poll_ready(cx)?.is_pending() || not_ready {
73            Poll::Pending
74        } else {
75            Poll::Ready(Ok(()))
76        }
77    }
78
79    fn call(&mut self, req: A::Request) -> Self::Future {
80        let fut = self.srv.get_mut().0.call(req);
81        AndThenApplyFnFuture {
82            state: State::A(fut, Some(self.srv.clone())),
83        }
84    }
85}
86
87#[pin_project::pin_project]
88pub(crate) struct AndThenApplyFnFuture<A, B, F, Fut, Res, Err>
89where
90    A: Service,
91    B: Service,
92    F: FnMut(A::Response, &mut B) -> Fut,
93    Fut: Future<Output = Result<Res, Err>>,
94    Err: From<A::Error>,
95    Err: From<B::Error>,
96{
97    #[pin]
98    state: State<A, B, F, Fut, Res, Err>,
99}
100
101#[pin_project::pin_project]
102enum State<A, B, F, Fut, Res, Err>
103where
104    A: Service,
105    B: Service,
106    F: FnMut(A::Response, &mut B) -> Fut,
107    Fut: Future<Output = Result<Res, Err>>,
108    Err: From<A::Error>,
109    Err: From<B::Error>,
110{
111    A(#[pin] A::Future, Option<Cell<(A, B, F)>>),
112    B(#[pin] Fut),
113    Empty,
114}
115
116impl<A, B, F, Fut, Res, Err> Future for AndThenApplyFnFuture<A, B, F, Fut, Res, Err>
117where
118    A: Service,
119    B: Service,
120    F: FnMut(A::Response, &mut B) -> Fut,
121    Fut: Future<Output = Result<Res, Err>>,
122    Err: From<A::Error> + From<B::Error>,
123{
124    type Output = Result<Res, Err>;
125
126    #[pin_project::project]
127    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
128        let mut this = self.as_mut().project();
129
130        #[project]
131        match this.state.as_mut().project() {
132            State::A(fut, b) => match fut.poll(cx)? {
133                Poll::Ready(res) => {
134                    let mut b = b.take().unwrap();
135                    this.state.set(State::Empty);
136                    let b = b.get_mut();
137                    let fut = (&mut b.2)(res, &mut b.1);
138                    this.state.set(State::B(fut));
139                    self.poll(cx)
140                }
141                Poll::Pending => Poll::Pending,
142            },
143            State::B(fut) => fut.poll(cx).map(|r| {
144                this.state.set(State::Empty);
145                r
146            }),
147            State::Empty => panic!("future must not be polled after it returned `Poll::Ready`"),
148        }
149    }
150}
151
/// Factory that creates `AndThenApplyFn` services.
///
/// The two inner factories and the combining function are stored behind a
/// single `Rc`.
pub(crate) struct AndThenApplyFnFactory<A, B, F, Fut, Res, Err> {
    /// First factory, second factory and the combining function.
    srv: Rc<(A, B, F)>,
    /// Marker for the type parameters that no field stores directly.
    r: PhantomData<(Fut, Res, Err)>,
}
157
158impl<A, B, F, Fut, Res, Err> AndThenApplyFnFactory<A, B, F, Fut, Res, Err>
159where
160    A: ServiceFactory,
161    B: ServiceFactory<Config = A::Config, InitError = A::InitError>,
162    F: FnMut(A::Response, &mut B::Service) -> Fut + Clone,
163    Fut: Future<Output = Result<Res, Err>>,
164    Err: From<A::Error> + From<B::Error>,
165{
166    /// Create new `ApplyNewService` new service instance
167    pub(crate) fn new(a: A, b: B, f: F) -> Self {
168        Self {
169            srv: Rc::new((a, b, f)),
170            r: PhantomData,
171        }
172    }
173}
174
175impl<A, B, F, Fut, Res, Err> Clone for AndThenApplyFnFactory<A, B, F, Fut, Res, Err> {
176    fn clone(&self) -> Self {
177        Self {
178            srv: self.srv.clone(),
179            r: PhantomData,
180        }
181    }
182}
183
184impl<A, B, F, Fut, Res, Err> ServiceFactory for AndThenApplyFnFactory<A, B, F, Fut, Res, Err>
185where
186    A: ServiceFactory,
187    A::Config: Clone,
188    B: ServiceFactory<Config = A::Config, InitError = A::InitError>,
189    F: FnMut(A::Response, &mut B::Service) -> Fut + Clone,
190    Fut: Future<Output = Result<Res, Err>>,
191    Err: From<A::Error> + From<B::Error>,
192{
193    type Request = A::Request;
194    type Response = Res;
195    type Error = Err;
196    type Service = AndThenApplyFn<A::Service, B::Service, F, Fut, Res, Err>;
197    type Config = A::Config;
198    type InitError = A::InitError;
199    type Future = AndThenApplyFnFactoryResponse<A, B, F, Fut, Res, Err>;
200
201    fn new_service(&self, cfg: A::Config) -> Self::Future {
202        let srv = &*self.srv;
203        AndThenApplyFnFactoryResponse {
204            a: None,
205            b: None,
206            f: srv.2.clone(),
207            fut_a: srv.0.new_service(cfg.clone()),
208            fut_b: srv.1.new_service(cfg),
209        }
210    }
211}
212
213#[pin_project::pin_project]
214pub(crate) struct AndThenApplyFnFactoryResponse<A, B, F, Fut, Res, Err>
215where
216    A: ServiceFactory,
217    B: ServiceFactory<Config = A::Config, InitError = A::InitError>,
218    F: FnMut(A::Response, &mut B::Service) -> Fut + Clone,
219    Fut: Future<Output = Result<Res, Err>>,
220    Err: From<A::Error>,
221    Err: From<B::Error>,
222{
223    #[pin]
224    fut_b: B::Future,
225    #[pin]
226    fut_a: A::Future,
227    f: F,
228    a: Option<A::Service>,
229    b: Option<B::Service>,
230}
231
232impl<A, B, F, Fut, Res, Err> Future for AndThenApplyFnFactoryResponse<A, B, F, Fut, Res, Err>
233where
234    A: ServiceFactory,
235    B: ServiceFactory<Config = A::Config, InitError = A::InitError>,
236    F: FnMut(A::Response, &mut B::Service) -> Fut + Clone,
237    Fut: Future<Output = Result<Res, Err>>,
238    Err: From<A::Error> + From<B::Error>,
239{
240    type Output =
241        Result<AndThenApplyFn<A::Service, B::Service, F, Fut, Res, Err>, A::InitError>;
242
243    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
244        let this = self.project();
245
246        if this.a.is_none() {
247            if let Poll::Ready(service) = this.fut_a.poll(cx)? {
248                *this.a = Some(service);
249            }
250        }
251
252        if this.b.is_none() {
253            if let Poll::Ready(service) = this.fut_b.poll(cx)? {
254                *this.b = Some(service);
255            }
256        }
257
258        if this.a.is_some() && this.b.is_some() {
259            Poll::Ready(Ok(AndThenApplyFn {
260                srv: Cell::new((
261                    this.a.take().unwrap(),
262                    this.b.take().unwrap(),
263                    this.f.clone(),
264                )),
265                r: PhantomData,
266            }))
267        } else {
268            Poll::Pending
269        }
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    use futures_util::future::{lazy, ok, Ready, TryFutureExt};

    use crate::{fn_service, pipeline, pipeline_factory, Service, ServiceFactory};

    /// Minimal second service: always ready, answers `()` with `Ok(())`.
    #[derive(Clone)]
    struct Srv;

    impl Service for Srv {
        type Request = ();
        type Response = ();
        type Error = ();
        type Future = Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: Self::Request) -> Self::Future {
            ok(())
        }
    }

    /// The combinator built from two services is ready and pairs the request
    /// with the response of the second service.
    #[scrappy_rt::test]
    async fn test_service() {
        let mut srv = pipeline(|r: &'static str| ok(r))
            .and_then_apply_fn(Srv, |req: &'static str, s| {
                s.call(()).map_ok(move |res| (req, res))
            });

        let readiness = lazy(|cx| srv.poll_ready(cx)).await;
        assert_eq!(readiness, Poll::Ready(Ok(())));

        let outcome = srv.call("srv").await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), ("srv", ()));
    }

    /// Same scenario as `test_service`, but the service is created through
    /// the factory first.
    #[scrappy_rt::test]
    async fn test_service_factory() {
        let factory = pipeline_factory(|| ok::<_, ()>(fn_service(|r: &'static str| ok(r))))
            .and_then_apply_fn(
                || ok(Srv),
                |req: &'static str, s| s.call(()).map_ok(move |res| (req, res)),
            );
        let mut srv = factory.new_service(()).await.unwrap();

        let readiness = lazy(|cx| srv.poll_ready(cx)).await;
        assert_eq!(readiness, Poll::Ready(Ok(())));

        let outcome = srv.call("srv").await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), ("srv", ()));
    }
}