1use core::fmt;
4
5use xitca_service::Service;
6
7use crate::http::{BorrowReq, BorrowReqMut};
8
/// Middleware builder that owns a state-producing closure.
///
/// The closure takes no arguments and returns a future resolving to
/// `Result<C, CErr>`, where `C` is the state type.
pub struct ContextBuilder<CF> {
    /// Closure used to asynchronously construct the state value.
    builder: CF,
}

impl<CF, Fut, C, CErr> ContextBuilder<CF>
where
    Fut: Future<Output = Result<C, CErr>>,
    CF: Fn() -> Fut,
{
    /// Create a builder from the given state-constructing closure.
    ///
    /// The closure is only stored here; it is not invoked by this function.
    pub fn new(builder: CF) -> Self {
        ContextBuilder { builder }
    }
}
60
/// Request wrapper pairing an owned request value with a borrowed state.
pub struct Context<'a, Req, C> {
    /// The wrapped request, owned by the context.
    req: Req,
    /// Shared reference to the state; outlives the context by `'a`.
    state: &'a C,
}

impl<'a, Req, C> Context<'a, Req, C> {
    /// Consume the context and return its pieces as `(request, &state)`.
    #[inline]
    pub fn into_parts(self) -> (Req, &'a C) {
        let Self { req, state } = self;
        (req, state)
    }
}
76
77impl<T, Req, C> BorrowReq<T> for Context<'_, Req, C>
80where
81 Req: BorrowReq<T>,
82{
83 #[inline]
84 fn borrow(&self) -> &T {
85 self.req.borrow()
86 }
87}
88
89impl<T, Req, C> BorrowReqMut<T> for Context<'_, Req, C>
90where
91 Req: BorrowReqMut<T>,
92{
93 #[inline]
94 fn borrow_mut(&mut self) -> &mut T {
95 self.req.borrow_mut()
96 }
97}
98
/// Type-erased error used when building the context service: any `'static`
/// value that implements `Debug`, boxed.
type Error = Box<dyn fmt::Debug + 'static>;
100
101impl<CF, Fut, C, CErr, S, E> Service<Result<S, E>> for ContextBuilder<CF>
102where
103 CF: Fn() -> Fut,
104 Fut: Future<Output = Result<C, CErr>>,
105 C: 'static,
106 CErr: fmt::Debug + 'static,
107 E: fmt::Debug + 'static,
108{
109 type Response = service::ContextService<C, S>;
110 type Error = Error;
111
112 async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
113 let service = res.map_err(|e| Box::new(e) as Error)?;
114 let state = (self.builder)().await.map_err(|e| Box::new(e) as Error)?;
115 Ok(service::ContextService { service, state })
116 }
117}
118
119mod service {
120 use xitca_service::ready::ReadyService;
121
122 use super::*;
123
124 pub struct ContextService<C, S> {
125 pub(super) state: C,
126 pub(super) service: S,
127 }
128
129 impl<Req, C, S, Res, Err> Service<Req> for ContextService<C, S>
130 where
131 S: for<'c> Service<Context<'c, Req, C>, Response = Res, Error = Err>,
132 {
133 type Response = Res;
134 type Error = Err;
135
136 #[inline]
137 async fn call(&self, req: Req) -> Result<Self::Response, Self::Error> {
138 self.service
139 .call(Context {
140 req,
141 state: &self.state,
142 })
143 .await
144 }
145 }
146
147 impl<C, S> ReadyService for ContextService<C, S>
148 where
149 S: ReadyService,
150 {
151 type Ready = S::Ready;
152
153 #[inline]
154 async fn ready(&self) -> Self::Ready {
155 self.service.ready().await
156 }
157 }
158}
159
#[cfg(feature = "router")]
mod router_impl {
    use core::marker::PhantomData;

    use xitca_service::object::ServiceObject;

    use crate::util::service::router::{IntoObject, PathGen, RouteGen, RouteObject};

    use super::*;

    /// Boxed, type-erased service object accepting a [`Context`] of any
    /// state-borrow lifetime.
    pub type ContextObject<Req, C, Res, Err> =
        Box<dyn for<'c> ServiceObject<Context<'c, Req, C>, Response = Res, Error = Err>>;

    impl<C, I, Arg, Req, Res, Err> IntoObject<I, Arg> for Context<'_, Req, C>
    where
        C: 'static,
        Req: 'static,
        I: Service<Arg> + RouteGen + Send + Sync + 'static,
        I::Response: for<'c> Service<Context<'c, Req, C>, Response = Res, Error = Err> + 'static,
    {
        type Object = RouteObject<Arg, ContextObject<Req, C, Res, Err>, I::Error>;

        /// Wrap `inner` so the service it builds is boxed into a
        /// [`ContextObject`], then erase the wrapper behind a `RouteObject`.
        fn into_object(inner: I) -> Self::Object {
            /// Adapter around `I` that boxes the built service.
            ///
            /// The `PhantomData<fn(Req, C)>` only records the request and
            /// state types; no value of either is stored.
            struct ObjectBuilder<I, Req, C>(I, PhantomData<fn(Req, C)>);

            // Path generation is delegated to the wrapped builder.
            impl<I, Req, C> PathGen for ObjectBuilder<I, Req, C>
            where
                I: PathGen,
            {
                fn path_gen(&mut self, prefix: &str) -> String {
                    self.0.path_gen(prefix)
                }
            }

            // Route generation is delegated to the wrapped builder's type.
            impl<I, Req, C> RouteGen for ObjectBuilder<I, Req, C>
            where
                I: RouteGen,
            {
                type Route<R> = I::Route<R>;

                fn route_gen<R>(route: R) -> Self::Route<R> {
                    I::route_gen(route)
                }
            }

            impl<C, I, Arg, Req, Res, Err> Service<Arg> for ObjectBuilder<I, Req, C>
            where
                I: Service<Arg> + RouteGen,
                I::Response: for<'c> Service<Context<'c, Req, C>, Response = Res, Error = Err> + 'static,
            {
                type Response = ContextObject<Req, C, Res, Err>;
                type Error = I::Error;

                /// Build the inner service and box it on success; errors are
                /// passed through untouched.
                async fn call(&self, arg: Arg) -> Result<Self::Response, Self::Error> {
                    match self.0.call(arg).await {
                        Ok(service) => Ok(Box::new(service) as ContextObject<Req, C, Res, Err>),
                        Err(e) => Err(e),
                    }
                }
            }

            let builder: ObjectBuilder<I, Req, C> = ObjectBuilder(inner, PhantomData);
            RouteObject(Box::new(builder))
        }
    }
}
220
#[cfg(test)]
mod test {
    use std::convert::Infallible;

    use xitca_service::{ServiceExt, fn_service};
    use xitca_unsafe_collection::futures::NowOrPanic;

    use crate::http::{Request, Response};

    use super::*;

    /// Second-stage context type used to check that state can be carried
    /// through an `and_then` pipeline.
    struct Context2<'a, ST> {
        req: Request<()>,
        state: &'a ST,
    }

    /// First pipeline stage: splits the `Context` and repackages it.
    async fn into_context(req: Context<'_, Request<()>, String>) -> Result<Context2<'_, String>, Infallible> {
        let (request, state) = req.into_parts();
        assert_eq!(state, "string_state");
        Ok(Context2 { req: request, state })
    }

    /// Second pipeline stage: checks state and request, then responds.
    async fn ctx_handler(ctx: Context2<'_, String>) -> Result<Response<()>, Infallible> {
        let Context2 { req, state } = ctx;
        assert_eq!(state, "string_state");
        assert_eq!(req.method().as_str(), "GET");
        Ok(Response::new(()))
    }

    #[test]
    fn test_state_and_then() {
        let builder = fn_service(into_context)
            .and_then(fn_service(ctx_handler))
            .enclosed(ContextBuilder::new(|| async {
                Ok::<_, Infallible>(String::from("string_state"))
            }));

        // Build the service first, then drive a single request through it.
        let service = builder.call(()).now_or_panic().ok().unwrap();
        let res = service.call(Request::default()).now_or_panic().unwrap();

        assert_eq!(res.status().as_u16(), 200);
    }

    #[cfg(feature = "router")]
    #[test]
    fn test_state_in_router() {
        use crate::{
            http::RequestExt,
            util::service::{Router, route::get},
        };

        /// Pass-through middleware function: forwards the context untouched.
        async fn passthrough<S, Req, C, Res, Err>(service: &S, req: Context<'_, Req, C>) -> Result<Res, Err>
        where
            S: for<'c> Service<Context<'c, Req, C>, Response = Res, Error = Err>,
        {
            service.call(req).await
        }

        // Router with a single GET "/" route that asserts on the state.
        let router = || {
            let handler = fn_service(async |req: Context<'_, Request<RequestExt<()>>, String>| {
                let (_, state) = req.into_parts();
                assert_eq!(state, "string_state");
                Ok::<_, Infallible>(Response::new(()))
            });
            Router::new().insert("/", get(handler))
        };

        let router_with_ctx = || {
            router().enclosed_fn(passthrough).enclosed(ContextBuilder::new(|| async {
                Ok::<_, Infallible>(String::from("string_state"))
            }))
        };

        fn bound_check<T: Send + Sync>(_: T) {}

        bound_check(router_with_ctx());

        // Router wrapped by a function middleware and the context builder.
        let builder = router_with_ctx();
        let service = builder.call(()).now_or_panic().ok().unwrap();
        let res = service.call(Request::default()).now_or_panic().unwrap();

        assert_eq!(res.status().as_u16(), 200);

        // Nested router enclosed directly by the context builder.
        let nested_builder = router()
            .insert("/nest", router())
            .enclosed(ContextBuilder::new(|| async {
                Ok::<_, Infallible>(String::from("string_state"))
            }));
        let nested_service = nested_builder.call(()).now_or_panic().ok().unwrap();
        let res = nested_service.call(Request::default()).now_or_panic().unwrap();

        assert_eq!(res.status().as_u16(), 200);
    }
}