// rust_web_server/async_state/mod.rs
#[cfg(test)]
48mod tests;
49
50use std::future::Future;
51use std::pin::Pin;
52use std::sync::Arc;
53
54use crate::app::App;
55use crate::application::Application;
56use crate::core::New;
57use crate::request::Request;
58use crate::response::Response;
59use crate::router::matcher::{self, Segment};
60use crate::router::PathParams;
61use crate::server::ConnectionInfo;
62#[cfg(feature = "openapi")]
63use crate::mime_type::MimeType;
64#[cfg(feature = "openapi")]
65use crate::range::Range;
66#[cfg(feature = "openapi")]
67use crate::response::STATUS_CODE_REASON_PHRASE;
68
/// A heap-allocated, pinned, type-erased future resolving to `T`.
///
/// It is `Send`, so it may be driven from any thread. The boxed trait object takes the default
/// `'static` object lifetime (a `Box<dyn …>` written in a type alias), so it cannot borrow from
/// whoever created it.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
70
71type AsyncHandlerFn<S> = Arc<
72 dyn Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> BoxFuture<Response> + Send + Sync,
73>;
74
75#[derive(Clone)]
78struct AsyncRoute<S> {
79 method: String,
80 segments: Vec<Segment>,
81 handler: AsyncHandlerFn<S>,
82}
83
84#[derive(Clone)]
92pub struct AsyncAppWithState<S> {
93 state: Arc<S>,
94 routes: Vec<AsyncRoute<S>>,
95 config: Option<Arc<crate::server_config::ServerConfig>>,
100}
101
102impl<S: Send + Sync + 'static> AsyncAppWithState<S> {
103 pub fn new(state: S) -> Self {
105 AsyncAppWithState { state: Arc::new(state), routes: Vec::new(), config: None }
106 }
107
108 pub fn with_config(mut self, config: crate::server_config::ServerConfig) -> Self {
114 self.config = Some(Arc::new(config));
115 self
116 }
117
118 pub fn state(&self) -> &S {
120 &self.state
121 }
122
123 pub fn route_entries(&self) -> Vec<crate::router::RouteInfo> {
125 self.routes
126 .iter()
127 .map(|r| crate::router::RouteInfo {
128 method: r.method.clone(),
129 pattern: matcher::segments_to_pattern(&r.segments),
130 })
131 .collect()
132 }
133
134 fn fallback_app(&self) -> App {
138 match &self.config {
139 Some(c) => App::with_config((**c).clone()),
140 None => App::new(),
141 }
142 }
143
144 fn add<F, Fut>(mut self, method: &str, pattern: &str, handler: F) -> Self
145 where
146 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
147 Fut: Future<Output = Response> + Send + 'static,
148 {
149 self.routes.push(AsyncRoute {
150 method: method.to_string(),
151 segments: matcher::parse_pattern(pattern),
152 handler: Arc::new(move |req, params, conn, state| Box::pin(handler(req, params, conn, state))),
153 });
154 self
155 }
156
157 pub fn get<F, Fut>(self, pattern: &str, handler: F) -> Self
159 where
160 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
161 Fut: Future<Output = Response> + Send + 'static,
162 {
163 self.add("GET", pattern, handler)
164 }
165
166 pub fn post<F, Fut>(self, pattern: &str, handler: F) -> Self
168 where
169 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
170 Fut: Future<Output = Response> + Send + 'static,
171 {
172 self.add("POST", pattern, handler)
173 }
174
175 pub fn put<F, Fut>(self, pattern: &str, handler: F) -> Self
177 where
178 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
179 Fut: Future<Output = Response> + Send + 'static,
180 {
181 self.add("PUT", pattern, handler)
182 }
183
184 pub fn patch<F, Fut>(self, pattern: &str, handler: F) -> Self
186 where
187 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
188 Fut: Future<Output = Response> + Send + 'static,
189 {
190 self.add("PATCH", pattern, handler)
191 }
192
193 pub fn delete<F, Fut>(self, pattern: &str, handler: F) -> Self
195 where
196 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
197 Fut: Future<Output = Response> + Send + 'static,
198 {
199 self.add("DELETE", pattern, handler)
200 }
201
202 #[cfg(feature = "openapi")]
213 pub fn openapi(self, config: crate::openapi::OpenApiConfig) -> Self {
214 let spec_json = Arc::new(crate::openapi::build_spec(&config, &self.route_entries()));
215 let html = Arc::new(crate::openapi::swagger_ui_html("/openapi.json"));
216
217 let spec_for_route = Arc::clone(&spec_json);
218 self.get("/openapi.json", move |_req, _params, _conn, _state| {
219 let spec_for_route = Arc::clone(&spec_for_route);
220 async move {
221 let mut r = Response::new();
222 r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
223 r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
224 r.content_range_list = vec![Range::get_content_range(
225 spec_for_route.as_bytes().to_vec(),
226 MimeType::APPLICATION_JSON.to_string(),
227 )];
228 r
229 }
230 })
231 .get("/docs", move |_req, _params, _conn, _state| {
232 let html = Arc::clone(&html);
233 async move {
234 let mut r = Response::new();
235 r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
236 r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
237 r.content_range_list = vec![Range::get_content_range(
238 html.as_bytes().to_vec(),
239 MimeType::TEXT_HTML.to_string(),
240 )];
241 r
242 }
243 })
244 }
245
246 async fn execute_async(
247 &self,
248 request: &Request,
249 connection: &ConnectionInfo,
250 ) -> Result<Response, String> {
251 let path = request.request_uri.split('?').next().unwrap_or(&request.request_uri);
252 let path_segs: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
253
254 for route in &self.routes {
255 if route.method != request.method {
256 continue;
257 }
258 if let Some(params_map) = matcher::try_match(&route.segments, &path_segs) {
259 let params = PathParams::from_map(params_map);
260 let fut = (route.handler)(
261 request.clone(),
262 params,
263 connection.clone(),
264 Arc::clone(&self.state),
265 );
266 return Ok(fut.await);
267 }
268 }
269
270 self.fallback_app().execute(request, connection)
271 }
272}
273
274impl<S: Send + Sync + 'static> Application for AsyncAppWithState<S> {
275 fn execute(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
276 let request = request.clone();
277 let connection = connection.clone();
278 match tokio::runtime::Handle::try_current() {
279 Ok(_) => {
280 std::thread::scope(|s| {
283 s.spawn(|| {
284 tokio::runtime::Builder::new_current_thread()
285 .enable_all()
286 .build()
287 .unwrap()
288 .block_on(self.execute_async(&request, &connection))
289 })
290 .join()
291 .unwrap()
292 })
293 }
294 Err(_) => {
295 tokio::runtime::Builder::new_current_thread()
297 .enable_all()
298 .build()
299 .unwrap()
300 .block_on(self.execute_async(&request, &connection))
301 }
302 }
303 }
304}