rust_web_server/async_state/
mod.rs1#[cfg(test)]
47mod tests;
48
49use std::collections::HashMap;
50use std::future::Future;
51use std::pin::Pin;
52use std::sync::Arc;
53
54use crate::app::App;
55use crate::application::Application;
56use crate::core::New;
57use crate::request::Request;
58use crate::response::Response;
59use crate::router::PathParams;
60use crate::server::ConnectionInfo;
61
/// An owned (`'static`), `Send`, heap-allocated future with its concrete type
/// erased, resolving to a `T`.
///
/// Erasing the future type is what lets handlers with different `async`
/// bodies be stored side by side in one route table.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
63
64type AsyncHandlerFn<S> = Box<
65 dyn Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> BoxFuture<Response> + Send + Sync,
66>;
67
/// One parsed piece of a route pattern, as produced by `parse_pattern`.
enum Segment {
    /// Plain text that must equal the corresponding path segment exactly.
    Literal(String),
    /// Written `:name` in a pattern; captures one path segment under `name`
    /// (the stored string has the `:` removed).
    Param(String),
    /// Written `*name` in a pattern; captures all remaining path segments,
    /// joined with `/`, under `name` (the stored string has the `*` removed).
    /// `try_match` only accepts it as the final segment of a pattern.
    Wildcard(String),
}
75
76fn parse_pattern(pattern: &str) -> Vec<Segment> {
77 if pattern == "/" {
78 return vec![];
79 }
80 pattern
81 .split('/')
82 .filter(|s| !s.is_empty())
83 .map(|seg| {
84 if let Some(name) = seg.strip_prefix(':') {
85 Segment::Param(name.to_string())
86 } else if let Some(name) = seg.strip_prefix('*') {
87 Segment::Wildcard(name.to_string())
88 } else {
89 Segment::Literal(seg.to_string())
90 }
91 })
92 .collect()
93}
94
95fn try_match(pattern: &[Segment], path: &[&str]) -> Option<HashMap<String, String>> {
96 let mut params = HashMap::new();
97 let mut pi = 0;
98
99 for (si, seg) in pattern.iter().enumerate() {
100 match seg {
101 Segment::Literal(lit) => {
102 if pi >= path.len() || path[pi] != lit.as_str() {
103 return None;
104 }
105 pi += 1;
106 }
107 Segment::Param(name) => {
108 if pi >= path.len() {
109 return None;
110 }
111 params.insert(name.clone(), path[pi].to_string());
112 pi += 1;
113 }
114 Segment::Wildcard(name) => {
115 if si != pattern.len() - 1 {
116 return None;
117 }
118 params.insert(name.clone(), path[pi..].join("/"));
119 pi = path.len();
120 }
121 }
122 }
123
124 if pi == path.len() { Some(params) } else { None }
125}
126
127struct AsyncRoute<S> {
130 method: String,
131 segments: Vec<Segment>,
132 handler: AsyncHandlerFn<S>,
133}
134
135pub struct AsyncAppWithState<S> {
143 state: Arc<S>,
144 routes: Vec<AsyncRoute<S>>,
145}
146
147impl<S: Send + Sync + 'static> AsyncAppWithState<S> {
148 pub fn new(state: S) -> Self {
150 AsyncAppWithState { state: Arc::new(state), routes: Vec::new() }
151 }
152
153 pub fn state(&self) -> &S {
155 &self.state
156 }
157
158 fn add<F, Fut>(mut self, method: &str, pattern: &str, handler: F) -> Self
159 where
160 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
161 Fut: Future<Output = Response> + Send + 'static,
162 {
163 self.routes.push(AsyncRoute {
164 method: method.to_string(),
165 segments: parse_pattern(pattern),
166 handler: Box::new(move |req, params, conn, state| Box::pin(handler(req, params, conn, state))),
167 });
168 self
169 }
170
171 pub fn get<F, Fut>(self, pattern: &str, handler: F) -> Self
173 where
174 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
175 Fut: Future<Output = Response> + Send + 'static,
176 {
177 self.add("GET", pattern, handler)
178 }
179
180 pub fn post<F, Fut>(self, pattern: &str, handler: F) -> Self
182 where
183 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
184 Fut: Future<Output = Response> + Send + 'static,
185 {
186 self.add("POST", pattern, handler)
187 }
188
189 pub fn put<F, Fut>(self, pattern: &str, handler: F) -> Self
191 where
192 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
193 Fut: Future<Output = Response> + Send + 'static,
194 {
195 self.add("PUT", pattern, handler)
196 }
197
198 pub fn patch<F, Fut>(self, pattern: &str, handler: F) -> Self
200 where
201 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
202 Fut: Future<Output = Response> + Send + 'static,
203 {
204 self.add("PATCH", pattern, handler)
205 }
206
207 pub fn delete<F, Fut>(self, pattern: &str, handler: F) -> Self
209 where
210 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
211 Fut: Future<Output = Response> + Send + 'static,
212 {
213 self.add("DELETE", pattern, handler)
214 }
215
216 async fn execute_async(
217 &self,
218 request: &Request,
219 connection: &ConnectionInfo,
220 ) -> Result<Response, String> {
221 let path = request.request_uri.split('?').next().unwrap_or(&request.request_uri);
222 let path_segs: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
223
224 for route in &self.routes {
225 if route.method != request.method {
226 continue;
227 }
228 if let Some(params_map) = try_match(&route.segments, &path_segs) {
229 let params = PathParams::from_map(params_map);
230 let fut = (route.handler)(
231 request.clone(),
232 params,
233 connection.clone(),
234 Arc::clone(&self.state),
235 );
236 return Ok(fut.await);
237 }
238 }
239
240 App::new().execute(request, connection)
241 }
242}
243
244impl<S: Send + Sync + 'static> Application for AsyncAppWithState<S> {
245 fn execute(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
246 let request = request.clone();
247 let connection = connection.clone();
248 match tokio::runtime::Handle::try_current() {
249 Ok(_) => {
250 std::thread::scope(|s| {
253 s.spawn(|| {
254 tokio::runtime::Builder::new_current_thread()
255 .enable_all()
256 .build()
257 .unwrap()
258 .block_on(self.execute_async(&request, &connection))
259 })
260 .join()
261 .unwrap()
262 })
263 }
264 Err(_) => {
265 tokio::runtime::Builder::new_current_thread()
267 .enable_all()
268 .build()
269 .unwrap()
270 .block_on(self.execute_async(&request, &connection))
271 }
272 }
273 }
274}