rust_web_server/async_state/
mod.rs1#[cfg(test)]
48mod tests;
49
50use std::collections::HashMap;
51use std::future::Future;
52use std::pin::Pin;
53use std::sync::Arc;
54
55use crate::app::App;
56use crate::application::Application;
57use crate::core::New;
58use crate::request::Request;
59use crate::response::Response;
60use crate::router::PathParams;
61use crate::server::ConnectionInfo;
62
/// A heap-allocated, type-erased future that can be moved across threads.
///
/// The trait object carries the default `'static` bound of `Box<dyn ...>`, so the future
/// may not borrow from its surroundings.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
64
65type AsyncHandlerFn<S> = Box<
66 dyn Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> BoxFuture<Response> + Send + Sync,
67>;
68
/// One component of a parsed route pattern.
enum Segment {
    /// Must equal the path segment exactly (e.g. `users` in `/users/:id`).
    Literal(String),
    /// Captures any single path segment under the given name (`:id` is stored as `"id"`).
    Param(String),
    /// Captures every remaining path segment, re-joined with `/`, under the given name
    /// (`*rest` is stored as `"rest"`). It only ever matches as the final segment of a pattern.
    Wildcard(String),
}

/// Parses a route pattern such as `/users/:id/*rest` into its segments.
///
/// The pattern is split on `/` and empty pieces are ignored. Leading, trailing and doubled
/// slashes therefore have no effect, and the root pattern `/` yields an empty list. A piece
/// starting with `:` becomes a [`Segment::Param`], one starting with `*` becomes a
/// [`Segment::Wildcard`], and anything else is a [`Segment::Literal`].
fn parse_pattern(pattern: &str) -> Vec<Segment> {
    let mut segments = Vec::new();
    for raw in pattern.split('/') {
        if raw.is_empty() {
            continue;
        }
        // ':' and '*' are single-byte ASCII, so dropping the first byte stays on a char
        // boundary; any other leading byte (including multi-byte UTF-8) is a literal.
        let segment = match raw.as_bytes()[0] {
            b':' => Segment::Param(raw[1..].to_owned()),
            b'*' => Segment::Wildcard(raw[1..].to_owned()),
            _ => Segment::Literal(raw.to_owned()),
        };
        segments.push(segment);
    }
    segments
}

/// Matches a request path, already split into non-empty segments, against a parsed pattern.
///
/// Returns the captured parameters (name to value) when the whole path is consumed by the
/// pattern, and `None` otherwise. A wildcard may capture zero segments, in which case its
/// value is the empty string. A wildcard that is not the last pattern segment never matches.
fn try_match(pattern: &[Segment], path: &[&str]) -> Option<HashMap<String, String>> {
    let mut captured = HashMap::new();
    // The part of the path not yet consumed by earlier pattern segments.
    let mut rest: &[&str] = path;
    let final_position = pattern.len().saturating_sub(1);

    for (position, segment) in pattern.iter().enumerate() {
        match segment {
            Segment::Literal(expected) => {
                let (head, tail) = rest.split_first()?;
                if *head != expected.as_str() {
                    return None;
                }
                rest = tail;
            }
            Segment::Param(key) => {
                let (head, tail) = rest.split_first()?;
                captured.insert(key.clone(), head.to_string());
                rest = tail;
            }
            Segment::Wildcard(key) => {
                if position != final_position {
                    return None;
                }
                captured.insert(key.clone(), rest.join("/"));
                rest = &[];
            }
        }
    }

    rest.is_empty().then_some(captured)
}
127
128struct AsyncRoute<S> {
131 method: String,
132 segments: Vec<Segment>,
133 handler: AsyncHandlerFn<S>,
134}
135
136pub struct AsyncAppWithState<S> {
144 state: Arc<S>,
145 routes: Vec<AsyncRoute<S>>,
146}
147
148impl<S: Send + Sync + 'static> AsyncAppWithState<S> {
149 pub fn new(state: S) -> Self {
151 AsyncAppWithState { state: Arc::new(state), routes: Vec::new() }
152 }
153
154 pub fn state(&self) -> &S {
156 &self.state
157 }
158
159 fn add<F, Fut>(mut self, method: &str, pattern: &str, handler: F) -> Self
160 where
161 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
162 Fut: Future<Output = Response> + Send + 'static,
163 {
164 self.routes.push(AsyncRoute {
165 method: method.to_string(),
166 segments: parse_pattern(pattern),
167 handler: Box::new(move |req, params, conn, state| Box::pin(handler(req, params, conn, state))),
168 });
169 self
170 }
171
172 pub fn get<F, Fut>(self, pattern: &str, handler: F) -> Self
174 where
175 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
176 Fut: Future<Output = Response> + Send + 'static,
177 {
178 self.add("GET", pattern, handler)
179 }
180
181 pub fn post<F, Fut>(self, pattern: &str, handler: F) -> Self
183 where
184 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
185 Fut: Future<Output = Response> + Send + 'static,
186 {
187 self.add("POST", pattern, handler)
188 }
189
190 pub fn put<F, Fut>(self, pattern: &str, handler: F) -> Self
192 where
193 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
194 Fut: Future<Output = Response> + Send + 'static,
195 {
196 self.add("PUT", pattern, handler)
197 }
198
199 pub fn patch<F, Fut>(self, pattern: &str, handler: F) -> Self
201 where
202 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
203 Fut: Future<Output = Response> + Send + 'static,
204 {
205 self.add("PATCH", pattern, handler)
206 }
207
208 pub fn delete<F, Fut>(self, pattern: &str, handler: F) -> Self
210 where
211 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
212 Fut: Future<Output = Response> + Send + 'static,
213 {
214 self.add("DELETE", pattern, handler)
215 }
216
217 async fn execute_async(
218 &self,
219 request: &Request,
220 connection: &ConnectionInfo,
221 ) -> Result<Response, String> {
222 let path = request.request_uri.split('?').next().unwrap_or(&request.request_uri);
223 let path_segs: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
224
225 for route in &self.routes {
226 if route.method != request.method {
227 continue;
228 }
229 if let Some(params_map) = try_match(&route.segments, &path_segs) {
230 let params = PathParams::from_map(params_map);
231 let fut = (route.handler)(
232 request.clone(),
233 params,
234 connection.clone(),
235 Arc::clone(&self.state),
236 );
237 return Ok(fut.await);
238 }
239 }
240
241 App::new().execute(request, connection)
242 }
243}
244
245impl<S: Send + Sync + 'static> Application for AsyncAppWithState<S> {
246 fn execute(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
247 let request = request.clone();
248 let connection = connection.clone();
249 match tokio::runtime::Handle::try_current() {
250 Ok(_) => {
251 std::thread::scope(|s| {
254 s.spawn(|| {
255 tokio::runtime::Builder::new_current_thread()
256 .enable_all()
257 .build()
258 .unwrap()
259 .block_on(self.execute_async(&request, &connection))
260 })
261 .join()
262 .unwrap()
263 })
264 }
265 Err(_) => {
266 tokio::runtime::Builder::new_current_thread()
268 .enable_all()
269 .build()
270 .unwrap()
271 .block_on(self.execute_async(&request, &connection))
272 }
273 }
274 }
275}