1use core::{fmt, ops::Deref};
4
5use crate::{context::WebContext, error::Error, handler::FromRequest};
6
/// Borrow a typed part of the application state.
///
/// Implement `BorrowState<T>` for an application state type to let handlers
/// extract a `&T` from it (for example through `StateRef<'_, T>`). `T` may be
/// unsized, such as `str` or `dyn Any + Send + Sync`.
#[diagnostic::on_unimplemented(
    message = "`{T}` can not be borrowed from {Self}",
    label = "{Self} must impl BorrowState trait for borrowing {T} from app state",
    note = "consider adding `impl BorrowState<{T}> for {Self}`"
)]
pub trait BorrowState<T>
where
    T: ?Sized,
{
    /// Returns a reference to the `T` held by (or reachable from) `self`.
    fn borrow(&self) -> &T;
}
19
20impl<T> BorrowState<T> for T
21where
22 T: ?Sized,
23{
24 fn borrow(&self) -> &T {
25 self
26 }
27}
28
29macro_rules! pointer_impl {
30 ($t: path) => {
31 impl<T> BorrowState<T> for $t
32 where
33 T: ?Sized,
34 {
35 fn borrow(&self) -> &T {
36 &*self
37 }
38 }
39 };
40}
41
42pointer_impl!(std::boxed::Box<T>);
43pointer_impl!(std::rc::Rc<T>);
44pointer_impl!(std::sync::Arc<T>);
45
/// Extractor that borrows a value of type `S` from the application state.
///
/// `S` may be the whole state type or any type the state implements
/// `BorrowState<S>` for. Unsized targets such as `str` or
/// `dyn Any + Send + Sync` are allowed.
pub struct StateRef<'a, S: ?Sized>(pub &'a S);
51
52impl<S> fmt::Debug for StateRef<'_, S>
53where
54 S: fmt::Debug,
55{
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 write!(f, "StateRef({:?})", self.0)
58 }
59}
60
61impl<S> fmt::Display for StateRef<'_, S>
62where
63 S: fmt::Display,
64{
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 write!(f, "StateRef({})", self.0)
67 }
68}
69
70impl<S> Deref for StateRef<'_, S> {
71 type Target = S;
72
73 #[inline]
74 fn deref(&self) -> &Self::Target {
75 self.0
76 }
77}
78
79impl<'a, 'r, C, B, T> FromRequest<'a, WebContext<'r, C, B>> for StateRef<'a, T>
80where
81 C: BorrowState<T>,
82 T: ?Sized + 'static,
83{
84 type Type<'b> = StateRef<'b, T>;
85 type Error = Error;
86
87 #[inline]
88 async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
89 Ok(StateRef(ctx.state().borrow()))
90 }
91}
92
/// Extractor that yields an owned copy of a value of type `S` from the
/// application state.
///
/// Unlike `StateRef`, the extracted value is owned. Its `FromRequest` impl
/// therefore requires `S: Clone` and clones the borrowed state.
pub struct StateOwn<S>(pub S);
96
97impl<S> fmt::Debug for StateOwn<S>
98where
99 S: fmt::Debug,
100{
101 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102 write!(f, "StateOwn({:?})", self.0)
103 }
104}
105
106impl<S> fmt::Display for StateOwn<S>
107where
108 S: fmt::Display,
109{
110 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111 write!(f, "StateOwn({})", self.0)
112 }
113}
114
115impl<S> Deref for StateOwn<S> {
116 type Target = S;
117
118 #[inline]
119 fn deref(&self) -> &Self::Target {
120 &self.0
121 }
122}
123
124impl<'a, 'r, C, B, T> FromRequest<'a, WebContext<'r, C, B>> for StateOwn<T>
125where
126 C: BorrowState<T>,
127 T: Clone,
128{
129 type Type<'b> = StateOwn<T>;
130 type Error = Error;
131
132 #[inline]
133 async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
134 Ok(StateOwn(ctx.state().borrow().clone()))
135 }
136}
137
// Tests driving `StateRef` extraction through a full `App` request cycle.
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use xitca_unsafe_collection::futures::NowOrPanic;

    use crate::{App, handler::handler_service, http::WebRequest, route::get, service::Service};

    use super::*;

    // Application state with three fields, each exposed through its own
    // `BorrowState` impl below.
    #[derive(Clone, Debug)]
    struct State {
        field1: String,
        field2: u32,
        field3: Arc<dyn std::any::Any + Send + Sync>,
    }

    impl BorrowState<String> for State {
        fn borrow(&self) -> &String {
            &self.field1
        }
    }

    impl BorrowState<u32> for State {
        fn borrow(&self) -> &u32 {
            &self.field2
        }
    }

    // Borrows an unsized trait object out of the `Arc` field.
    impl BorrowState<dyn std::any::Any + Send + Sync> for State {
        fn borrow(&self) -> &(dyn std::any::Any + Send + Sync) {
            &*self.field3
        }
    }

    // Handler mixing several extractors:
    // - field borrows (`String`, `u32`, `dyn Any`);
    // - a borrow of the whole `State` (resolved through the blanket
    //   `impl BorrowState<T> for T`);
    // - the raw context.
    // It asserts that all of them observe the state built in `state_extract`.
    async fn handler(
        StateRef(state): StateRef<'_, String>,
        StateRef(state2): StateRef<'_, u32>,
        StateRef(state3): StateRef<'_, State>,
        StateRef(state4): StateRef<'_, dyn std::any::Any + Send + Sync>,
        ctx: &WebContext<'_, State>,
    ) -> String {
        assert_eq!("state", state);
        assert_eq!(&996, state2);
        assert_eq!(state, ctx.state().field1.as_str());
        assert_eq!(state3.field1, ctx.state().field1);
        // `field3` was built from a `String`, so the downcast must succeed.
        assert!(state4.downcast_ref::<String>().is_some());
        state.to_string()
    }

    #[test]
    fn state_extract() {
        let state = State {
            field1: String::from("state"),
            field2: 996,
            field3: Arc::new(String::new()),
        };

        // Build the app, create its service, then send one default request.
        // The checks live in `handler`'s asserts; the final `unwrap` panics if
        // the request returned an error.
        App::new()
            .with_state(state)
            .at("/", get(handler_service(handler)))
            .finish()
            .call(())
            .now_or_panic()
            .ok()
            .unwrap()
            .call(WebRequest::default())
            .now_or_panic()
            .unwrap();
    }

    // The root app stores its state as `Arc<dyn Any + Send + Sync>`, and the
    // nested app under "/scope" stores it as `Arc<i32>`. Both handlers
    // therefore borrow through the `Arc` impl generated by `pointer_impl!`.
    // NOTE(review): only one default request is sent. If its path is "/",
    // `handler2` is type-checked but not invoked; confirm the default URI.
    #[test]
    fn state_extract_deref() {
        use std::{any::Any, sync::Arc};

        async fn handler(StateRef(state): StateRef<'_, dyn Any + Send + Sync>) -> String {
            state.downcast_ref::<i32>().unwrap().to_string()
        }

        async fn handler2(StateRef(state): StateRef<'_, i32>) -> String {
            state.to_string()
        }

        let state = Arc::new(996);

        App::new()
            .with_state(state.clone() as Arc<dyn Any + Send + Sync>)
            .at("/", get(handler_service(handler)))
            .at(
                "/scope",
                App::new().with_state(state).at("/", handler_service(handler2)),
            )
            .finish()
            .call(())
            .now_or_panic()
            .ok()
            .unwrap()
            .call(WebRequest::default())
            .now_or_panic()
            .unwrap();
    }
}