rust_web_server/async_state/
mod.rs1#[cfg(test)]
48mod tests;
49
50use std::future::Future;
51use std::pin::Pin;
52use std::sync::Arc;
53
54use crate::app::App;
55use crate::application::Application;
56use crate::core::New;
57use crate::request::Request;
58use crate::response::Response;
59use crate::router::matcher::{self, Segment};
60use crate::router::PathParams;
61use crate::server::ConnectionInfo;
62
/// A pinned, heap-allocated, type-erased future that resolves to `T`.
///
/// It is `Send` and `'static`, so it owns everything it needs and may be moved
/// across threads before (or while) being polled.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + 'static + Send>>;
64
65type AsyncHandlerFn<S> = Arc<
66 dyn Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> BoxFuture<Response> + Send + Sync,
67>;
68
69#[derive(Clone)]
72struct AsyncRoute<S> {
73 method: String,
74 segments: Vec<Segment>,
75 handler: AsyncHandlerFn<S>,
76}
77
78#[derive(Clone)]
86pub struct AsyncAppWithState<S> {
87 state: Arc<S>,
88 routes: Vec<AsyncRoute<S>>,
89}
90
91impl<S: Send + Sync + 'static> AsyncAppWithState<S> {
92 pub fn new(state: S) -> Self {
94 AsyncAppWithState { state: Arc::new(state), routes: Vec::new() }
95 }
96
97 pub fn state(&self) -> &S {
99 &self.state
100 }
101
102 fn add<F, Fut>(mut self, method: &str, pattern: &str, handler: F) -> Self
103 where
104 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
105 Fut: Future<Output = Response> + Send + 'static,
106 {
107 self.routes.push(AsyncRoute {
108 method: method.to_string(),
109 segments: matcher::parse_pattern(pattern),
110 handler: Arc::new(move |req, params, conn, state| Box::pin(handler(req, params, conn, state))),
111 });
112 self
113 }
114
115 pub fn get<F, Fut>(self, pattern: &str, handler: F) -> Self
117 where
118 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
119 Fut: Future<Output = Response> + Send + 'static,
120 {
121 self.add("GET", pattern, handler)
122 }
123
124 pub fn post<F, Fut>(self, pattern: &str, handler: F) -> Self
126 where
127 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
128 Fut: Future<Output = Response> + Send + 'static,
129 {
130 self.add("POST", pattern, handler)
131 }
132
133 pub fn put<F, Fut>(self, pattern: &str, handler: F) -> Self
135 where
136 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
137 Fut: Future<Output = Response> + Send + 'static,
138 {
139 self.add("PUT", pattern, handler)
140 }
141
142 pub fn patch<F, Fut>(self, pattern: &str, handler: F) -> Self
144 where
145 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
146 Fut: Future<Output = Response> + Send + 'static,
147 {
148 self.add("PATCH", pattern, handler)
149 }
150
151 pub fn delete<F, Fut>(self, pattern: &str, handler: F) -> Self
153 where
154 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
155 Fut: Future<Output = Response> + Send + 'static,
156 {
157 self.add("DELETE", pattern, handler)
158 }
159
160 async fn execute_async(
161 &self,
162 request: &Request,
163 connection: &ConnectionInfo,
164 ) -> Result<Response, String> {
165 let path = request.request_uri.split('?').next().unwrap_or(&request.request_uri);
166 let path_segs: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
167
168 for route in &self.routes {
169 if route.method != request.method {
170 continue;
171 }
172 if let Some(params_map) = matcher::try_match(&route.segments, &path_segs) {
173 let params = PathParams::from_map(params_map);
174 let fut = (route.handler)(
175 request.clone(),
176 params,
177 connection.clone(),
178 Arc::clone(&self.state),
179 );
180 return Ok(fut.await);
181 }
182 }
183
184 App::new().execute(request, connection)
185 }
186}
187
188impl<S: Send + Sync + 'static> Application for AsyncAppWithState<S> {
189 fn execute(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
190 let request = request.clone();
191 let connection = connection.clone();
192 match tokio::runtime::Handle::try_current() {
193 Ok(_) => {
194 std::thread::scope(|s| {
197 s.spawn(|| {
198 tokio::runtime::Builder::new_current_thread()
199 .enable_all()
200 .build()
201 .unwrap()
202 .block_on(self.execute_async(&request, &connection))
203 })
204 .join()
205 .unwrap()
206 })
207 }
208 Err(_) => {
209 tokio::runtime::Builder::new_current_thread()
211 .enable_all()
212 .build()
213 .unwrap()
214 .block_on(self.execute_async(&request, &connection))
215 }
216 }
217 }
218}