xitca_web/handler/types/
state.rs

//! type extractor for application state.
2
3use core::{fmt, ops::Deref};
4
5use crate::{context::WebContext, error::Error, handler::FromRequest};
6
/// Borrowing contract between an application state type and one of its typed parts.
///
/// An implementor hands out a shared reference to a `T` reachable from `&self`.
/// The same type may implement this trait several times, once for every `T`
/// it is able to lend out.
#[diagnostic::on_unimplemented(
    message = "`{T}` can not be borrowed from {Self}",
    label = "{Self} must impl BorrowState trait for borrowing {T} from app state",
    note = "consider add `impl BorrowState<{T}> for {Self}`"
)]
pub trait BorrowState<T: ?Sized> {
    /// Returns a shared reference to the `T` held by `self`.
    fn borrow(&self) -> &T;
}
19
20impl<T> BorrowState<T> for T
21where
22    T: ?Sized,
23{
24    fn borrow(&self) -> &T {
25        self
26    }
27}
28
29macro_rules! pointer_impl {
30    ($t: path) => {
31        impl<T> BorrowState<T> for $t
32        where
33            T: ?Sized,
34        {
35            fn borrow(&self) -> &T {
36                &*self
37            }
38        }
39    };
40}
41
42pointer_impl!(std::boxed::Box<T>);
43pointer_impl!(std::rc::Rc<T>);
44pointer_impl!(std::sync::Arc<T>);
45
/// App state extractor that borrows from the application state.
///
/// `S` is either the application state type itself or any type the state can lend
/// out through its `BorrowState<S>` implementation. `S` may be unsized (for example
/// `str`, a slice or a trait object); the formatting and [`Deref`] impls below are
/// available for unsized `S` as well.
pub struct StateRef<'a, S>(pub &'a S)
where
    S: ?Sized;

/// Formats as `StateRef(<inner Debug output>)`.
///
/// Formatter flags such as `#` or width are not forwarded to the inner value.
impl<S> fmt::Debug for StateRef<'_, S>
where
    S: ?Sized + fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "StateRef({:?})", self.0)
    }
}

/// Formats as `StateRef(<inner Display output>)`.
///
/// Formatter flags such as width or fill are not forwarded to the inner value.
impl<S> fmt::Display for StateRef<'_, S>
where
    S: ?Sized + fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "StateRef({})", self.0)
    }
}

/// Gives transparent access to the borrowed state.
impl<S> Deref for StateRef<'_, S>
where
    // Without `?Sized` the impl would silently exclude unsized states even though
    // the struct itself accepts them.
    S: ?Sized,
{
    type Target = S;

    #[inline]
    fn deref(&self) -> &Self::Target {
        self.0
    }
}
78
79impl<'a, 'r, C, B, T> FromRequest<'a, WebContext<'r, C, B>> for StateRef<'a, T>
80where
81    C: BorrowState<T>,
82    T: ?Sized + 'static,
83{
84    type Type<'b> = StateRef<'b, T>;
85    type Error = Error;
86
87    #[inline]
88    async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
89        Ok(StateRef(ctx.state().borrow()))
90    }
91}
92
/// App state extractor that yields an owned value.
///
/// `S` is either the application state type itself or any type the state can lend
/// out through its `BorrowState<S>` implementation; the extractor obtains the owned
/// value by cloning that borrow.
pub struct StateOwn<S>(pub S);

/// Formats as `StateOwn(<inner Debug output>)`.
impl<S: fmt::Debug> fmt::Debug for StateOwn<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(inner) = self;
        write!(f, "StateOwn({:?})", inner)
    }
}

/// Formats as `StateOwn(<inner Display output>)`.
impl<S: fmt::Display> fmt::Display for StateOwn<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(inner) = self;
        write!(f, "StateOwn({})", inner)
    }
}

/// Gives transparent read access to the owned state.
impl<S> Deref for StateOwn<S> {
    type Target = S;

    #[inline]
    fn deref(&self) -> &S {
        let Self(inner) = self;
        inner
    }
}
123
124impl<'a, 'r, C, B, T> FromRequest<'a, WebContext<'r, C, B>> for StateOwn<T>
125where
126    C: BorrowState<T>,
127    T: Clone,
128{
129    type Type<'b> = StateOwn<T>;
130    type Error = Error;
131
132    #[inline]
133    async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
134        Ok(StateOwn(ctx.state().borrow().clone()))
135    }
136}
137
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use xitca_unsafe_collection::futures::NowOrPanic;

    use crate::{App, handler::handler_service, http::WebRequest, route::get, service::Service};

    use super::*;

    /// Type-erased, thread-safe payload used to exercise unsized state extraction.
    type DynAny = dyn std::any::Any + Send + Sync;

    /// Application state exposing three differently typed fields.
    #[derive(Clone, Debug)]
    struct State {
        field1: String,
        field2: u32,
        field3: Arc<DynAny>,
    }

    impl BorrowState<String> for State {
        fn borrow(&self) -> &String {
            &self.field1
        }
    }

    impl BorrowState<u32> for State {
        fn borrow(&self) -> &u32 {
            &self.field2
        }
    }

    impl BorrowState<DynAny> for State {
        fn borrow(&self) -> &DynAny {
            self.field3.as_ref()
        }
    }

    /// Extracts every field, the whole state and the raw context, then checks that
    /// they all agree with each other.
    async fn handler(
        StateRef(name): StateRef<'_, String>,
        StateRef(number): StateRef<'_, u32>,
        StateRef(whole): StateRef<'_, State>,
        StateRef(erased): StateRef<'_, DynAny>,
        ctx: &WebContext<'_, State>,
    ) -> String {
        assert_eq!("state", name);
        assert_eq!(&996, number);
        assert_eq!(name, ctx.state().field1.as_str());
        assert_eq!(whole.field1, ctx.state().field1);
        assert!(erased.downcast_ref::<String>().is_some());
        name.to_string()
    }

    #[test]
    fn state_extract() {
        let app = App::new()
            .with_state(State {
                field1: String::from("state"),
                field2: 996,
                field3: Arc::new(String::new()),
            })
            .at("/", get(handler_service(handler)));

        // Build the service first, then drive a single default request through it.
        let service = app.finish().call(()).now_or_panic().ok().unwrap();
        service.call(WebRequest::default()).now_or_panic().unwrap();
    }

    #[test]
    fn state_extract_deref() {
        /// Reads the state through its type-erased form.
        async fn erased_handler(StateRef(erased): StateRef<'_, DynAny>) -> String {
            erased.downcast_ref::<i32>().unwrap().to_string()
        }

        /// Reads the state through the concrete pointee of the `Arc`.
        async fn typed_handler(StateRef(number): StateRef<'_, i32>) -> String {
            number.to_string()
        }

        let shared = Arc::new(996);

        let app = App::new()
            .with_state(shared.clone() as Arc<DynAny>)
            .at("/", get(handler_service(erased_handler)))
            .at(
                "/scope",
                App::new().with_state(shared).at("/", handler_service(typed_handler)),
            );

        let service = app.finish().call(()).now_or_panic().ok().unwrap();
        service.call(WebRequest::default()).now_or_panic().unwrap();
    }
}