1use std::future::Future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::rc::Rc;
5use std::task::{Context, Poll};
6
7use crate::cell::Cell;
8use crate::{Service, ServiceFactory};
9
10pub(crate) struct AndThenApplyFn<A, B, F, Fut, Res, Err>
12where
13 A: Service,
14 B: Service,
15 F: FnMut(A::Response, &mut B) -> Fut,
16 Fut: Future<Output = Result<Res, Err>>,
17 Err: From<A::Error> + From<B::Error>,
18{
19 srv: Cell<(A, B, F)>,
20 r: PhantomData<(Fut, Res, Err)>,
21}
22
23impl<A, B, F, Fut, Res, Err> AndThenApplyFn<A, B, F, Fut, Res, Err>
24where
25 A: Service,
26 B: Service,
27 F: FnMut(A::Response, &mut B) -> Fut,
28 Fut: Future<Output = Result<Res, Err>>,
29 Err: From<A::Error> + From<B::Error>,
30{
31 pub(crate) fn new(a: A, b: B, f: F) -> Self {
33 Self {
34 srv: Cell::new((a, b, f)),
35 r: PhantomData,
36 }
37 }
38}
39
40impl<A, B, F, Fut, Res, Err> Clone for AndThenApplyFn<A, B, F, Fut, Res, Err>
41where
42 A: Service,
43 B: Service,
44 F: FnMut(A::Response, &mut B) -> Fut,
45 Fut: Future<Output = Result<Res, Err>>,
46 Err: From<A::Error> + From<B::Error>,
47{
48 fn clone(&self) -> Self {
49 AndThenApplyFn {
50 srv: self.srv.clone(),
51 r: PhantomData,
52 }
53 }
54}
55
56impl<A, B, F, Fut, Res, Err> Service for AndThenApplyFn<A, B, F, Fut, Res, Err>
57where
58 A: Service,
59 B: Service,
60 F: FnMut(A::Response, &mut B) -> Fut,
61 Fut: Future<Output = Result<Res, Err>>,
62 Err: From<A::Error> + From<B::Error>,
63{
64 type Request = A::Request;
65 type Response = Res;
66 type Error = Err;
67 type Future = AndThenApplyFnFuture<A, B, F, Fut, Res, Err>;
68
69 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
70 let inner = self.srv.get_mut();
71 let not_ready = inner.0.poll_ready(cx)?.is_pending();
72 if inner.1.poll_ready(cx)?.is_pending() || not_ready {
73 Poll::Pending
74 } else {
75 Poll::Ready(Ok(()))
76 }
77 }
78
79 fn call(&mut self, req: A::Request) -> Self::Future {
80 let fut = self.srv.get_mut().0.call(req);
81 AndThenApplyFnFuture {
82 state: State::A(fut, Some(self.srv.clone())),
83 }
84 }
85}
86
87#[pin_project::pin_project]
88pub(crate) struct AndThenApplyFnFuture<A, B, F, Fut, Res, Err>
89where
90 A: Service,
91 B: Service,
92 F: FnMut(A::Response, &mut B) -> Fut,
93 Fut: Future<Output = Result<Res, Err>>,
94 Err: From<A::Error>,
95 Err: From<B::Error>,
96{
97 #[pin]
98 state: State<A, B, F, Fut, Res, Err>,
99}
100
101#[pin_project::pin_project]
102enum State<A, B, F, Fut, Res, Err>
103where
104 A: Service,
105 B: Service,
106 F: FnMut(A::Response, &mut B) -> Fut,
107 Fut: Future<Output = Result<Res, Err>>,
108 Err: From<A::Error>,
109 Err: From<B::Error>,
110{
111 A(#[pin] A::Future, Option<Cell<(A, B, F)>>),
112 B(#[pin] Fut),
113 Empty,
114}
115
116impl<A, B, F, Fut, Res, Err> Future for AndThenApplyFnFuture<A, B, F, Fut, Res, Err>
117where
118 A: Service,
119 B: Service,
120 F: FnMut(A::Response, &mut B) -> Fut,
121 Fut: Future<Output = Result<Res, Err>>,
122 Err: From<A::Error> + From<B::Error>,
123{
124 type Output = Result<Res, Err>;
125
126 #[pin_project::project]
127 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
128 let mut this = self.as_mut().project();
129
130 #[project]
131 match this.state.as_mut().project() {
132 State::A(fut, b) => match fut.poll(cx)? {
133 Poll::Ready(res) => {
134 let mut b = b.take().unwrap();
135 this.state.set(State::Empty);
136 let b = b.get_mut();
137 let fut = (&mut b.2)(res, &mut b.1);
138 this.state.set(State::B(fut));
139 self.poll(cx)
140 }
141 Poll::Pending => Poll::Pending,
142 },
143 State::B(fut) => fut.poll(cx).map(|r| {
144 this.state.set(State::Empty);
145 r
146 }),
147 State::Empty => panic!("future must not be polled after it returned `Poll::Ready`"),
148 }
149 }
150}
151
/// `ServiceFactory` counterpart of `AndThenApplyFn`: builds service `A` and
/// service `B` from their factories and links them with `F`.
pub(crate) struct AndThenApplyFnFactory<A, B, F, Fut, Res, Err> {
    /// Factory for `A`, factory for `B`, and the linking function `F`,
    /// shared (via `Rc`) between clones of this factory.
    srv: Rc<(A, B, F)>,
    r: PhantomData<(Fut, Res, Err)>,
}
157
158impl<A, B, F, Fut, Res, Err> AndThenApplyFnFactory<A, B, F, Fut, Res, Err>
159where
160 A: ServiceFactory,
161 B: ServiceFactory<Config = A::Config, InitError = A::InitError>,
162 F: FnMut(A::Response, &mut B::Service) -> Fut + Clone,
163 Fut: Future<Output = Result<Res, Err>>,
164 Err: From<A::Error> + From<B::Error>,
165{
166 pub(crate) fn new(a: A, b: B, f: F) -> Self {
168 Self {
169 srv: Rc::new((a, b, f)),
170 r: PhantomData,
171 }
172 }
173}
174
175impl<A, B, F, Fut, Res, Err> Clone for AndThenApplyFnFactory<A, B, F, Fut, Res, Err> {
176 fn clone(&self) -> Self {
177 Self {
178 srv: self.srv.clone(),
179 r: PhantomData,
180 }
181 }
182}
183
184impl<A, B, F, Fut, Res, Err> ServiceFactory for AndThenApplyFnFactory<A, B, F, Fut, Res, Err>
185where
186 A: ServiceFactory,
187 A::Config: Clone,
188 B: ServiceFactory<Config = A::Config, InitError = A::InitError>,
189 F: FnMut(A::Response, &mut B::Service) -> Fut + Clone,
190 Fut: Future<Output = Result<Res, Err>>,
191 Err: From<A::Error> + From<B::Error>,
192{
193 type Request = A::Request;
194 type Response = Res;
195 type Error = Err;
196 type Service = AndThenApplyFn<A::Service, B::Service, F, Fut, Res, Err>;
197 type Config = A::Config;
198 type InitError = A::InitError;
199 type Future = AndThenApplyFnFactoryResponse<A, B, F, Fut, Res, Err>;
200
201 fn new_service(&self, cfg: A::Config) -> Self::Future {
202 let srv = &*self.srv;
203 AndThenApplyFnFactoryResponse {
204 a: None,
205 b: None,
206 f: srv.2.clone(),
207 fut_a: srv.0.new_service(cfg.clone()),
208 fut_b: srv.1.new_service(cfg),
209 }
210 }
211}
212
213#[pin_project::pin_project]
214pub(crate) struct AndThenApplyFnFactoryResponse<A, B, F, Fut, Res, Err>
215where
216 A: ServiceFactory,
217 B: ServiceFactory<Config = A::Config, InitError = A::InitError>,
218 F: FnMut(A::Response, &mut B::Service) -> Fut + Clone,
219 Fut: Future<Output = Result<Res, Err>>,
220 Err: From<A::Error>,
221 Err: From<B::Error>,
222{
223 #[pin]
224 fut_b: B::Future,
225 #[pin]
226 fut_a: A::Future,
227 f: F,
228 a: Option<A::Service>,
229 b: Option<B::Service>,
230}
231
232impl<A, B, F, Fut, Res, Err> Future for AndThenApplyFnFactoryResponse<A, B, F, Fut, Res, Err>
233where
234 A: ServiceFactory,
235 B: ServiceFactory<Config = A::Config, InitError = A::InitError>,
236 F: FnMut(A::Response, &mut B::Service) -> Fut + Clone,
237 Fut: Future<Output = Result<Res, Err>>,
238 Err: From<A::Error> + From<B::Error>,
239{
240 type Output =
241 Result<AndThenApplyFn<A::Service, B::Service, F, Fut, Res, Err>, A::InitError>;
242
243 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
244 let this = self.project();
245
246 if this.a.is_none() {
247 if let Poll::Ready(service) = this.fut_a.poll(cx)? {
248 *this.a = Some(service);
249 }
250 }
251
252 if this.b.is_none() {
253 if let Poll::Ready(service) = this.fut_b.poll(cx)? {
254 *this.b = Some(service);
255 }
256 }
257
258 if this.a.is_some() && this.b.is_some() {
259 Poll::Ready(Ok(AndThenApplyFn {
260 srv: Cell::new((
261 this.a.take().unwrap(),
262 this.b.take().unwrap(),
263 this.f.clone(),
264 )),
265 r: PhantomData,
266 }))
267 } else {
268 Poll::Pending
269 }
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    use futures_util::future::{lazy, ok, Ready, TryFutureExt};

    use crate::{fn_service, pipeline, pipeline_factory, Service, ServiceFactory};

    /// Trivial service that is always ready and echoes its `()` request.
    #[derive(Clone)]
    struct Srv;

    impl Service for Srv {
        type Request = ();
        type Response = ();
        type Error = ();
        type Future = Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: Self::Request) -> Self::Future {
            ok(req)
        }
    }

    /// The combined service is ready and pairs the request with `Srv`'s output.
    #[scrappy_rt::test]
    async fn test_service() {
        let mut combined = pipeline(|r: &'static str| ok(r)).and_then_apply_fn(
            Srv,
            |req: &'static str, s| s.call(()).map_ok(move |res| (req, res)),
        );

        let readiness = lazy(|cx| combined.poll_ready(cx)).await;
        assert_eq!(readiness, Poll::Ready(Ok(())));

        let outcome = combined.call("srv").await;
        assert_eq!(outcome, Ok(("srv", ())));
    }

    /// Same as `test_service`, but the combined service is built by a factory.
    #[scrappy_rt::test]
    async fn test_service_factory() {
        let factory = pipeline_factory(|| ok::<_, ()>(fn_service(|r: &'static str| ok(r))))
            .and_then_apply_fn(
                || ok(Srv),
                |req: &'static str, s| s.call(()).map_ok(move |res| (req, res)),
            );

        let mut combined = factory.new_service(()).await.unwrap();

        let readiness = lazy(|cx| combined.poll_ready(cx)).await;
        assert_eq!(readiness, Poll::Ready(Ok(())));

        let outcome = combined.call("srv").await;
        assert_eq!(outcome, Ok(("srv", ())));
    }
}