rust_web_server/async_state/
mod.rs1#[cfg(test)]
48mod tests;
49
50use std::collections::HashMap;
51use std::future::Future;
52use std::pin::Pin;
53use std::sync::Arc;
54
55use crate::app::App;
56use crate::application::Application;
57use crate::core::New;
58use crate::request::Request;
59use crate::response::Response;
60use crate::router::PathParams;
61use crate::server::ConnectionInfo;
62
/// A heap-allocated, type-erased future resolving to `T`.
///
/// It is `Send`, so it may move between threads. It holds no non-`'static` borrows.
type BoxFuture<T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'static>>;
64
65type AsyncHandlerFn<S> = Arc<
66 dyn Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> BoxFuture<Response> + Send + Sync,
67>;
68
/// One component of a parsed route pattern.
#[derive(Clone, Debug, PartialEq, Eq)]
enum Segment {
    /// A plain segment that must equal the corresponding path segment exactly.
    Literal(String),
    /// Written `:name` in a pattern; captures exactly one path segment under `name`.
    Param(String),
    /// Written `*name` in a pattern; captures all remaining path segments (joined with `/`)
    /// under `name`. Only honoured as the last segment of a pattern.
    Wildcard(String),
}
77
78fn parse_pattern(pattern: &str) -> Vec<Segment> {
79 if pattern == "/" {
80 return vec![];
81 }
82 pattern
83 .split('/')
84 .filter(|s| !s.is_empty())
85 .map(|seg| {
86 if let Some(name) = seg.strip_prefix(':') {
87 Segment::Param(name.to_string())
88 } else if let Some(name) = seg.strip_prefix('*') {
89 Segment::Wildcard(name.to_string())
90 } else {
91 Segment::Literal(seg.to_string())
92 }
93 })
94 .collect()
95}
96
97fn try_match(pattern: &[Segment], path: &[&str]) -> Option<HashMap<String, String>> {
98 let mut params = HashMap::new();
99 let mut pi = 0;
100
101 for (si, seg) in pattern.iter().enumerate() {
102 match seg {
103 Segment::Literal(lit) => {
104 if pi >= path.len() || path[pi] != lit.as_str() {
105 return None;
106 }
107 pi += 1;
108 }
109 Segment::Param(name) => {
110 if pi >= path.len() {
111 return None;
112 }
113 params.insert(name.clone(), path[pi].to_string());
114 pi += 1;
115 }
116 Segment::Wildcard(name) => {
117 if si != pattern.len() - 1 {
118 return None;
119 }
120 params.insert(name.clone(), path[pi..].join("/"));
121 pi = path.len();
122 }
123 }
124 }
125
126 if pi == path.len() { Some(params) } else { None }
127}
128
129#[derive(Clone)]
132struct AsyncRoute<S> {
133 method: String,
134 segments: Vec<Segment>,
135 handler: AsyncHandlerFn<S>,
136}
137
138#[derive(Clone)]
146pub struct AsyncAppWithState<S> {
147 state: Arc<S>,
148 routes: Vec<AsyncRoute<S>>,
149}
150
151impl<S: Send + Sync + 'static> AsyncAppWithState<S> {
152 pub fn new(state: S) -> Self {
154 AsyncAppWithState { state: Arc::new(state), routes: Vec::new() }
155 }
156
157 pub fn state(&self) -> &S {
159 &self.state
160 }
161
162 fn add<F, Fut>(mut self, method: &str, pattern: &str, handler: F) -> Self
163 where
164 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
165 Fut: Future<Output = Response> + Send + 'static,
166 {
167 self.routes.push(AsyncRoute {
168 method: method.to_string(),
169 segments: parse_pattern(pattern),
170 handler: Arc::new(move |req, params, conn, state| Box::pin(handler(req, params, conn, state))),
171 });
172 self
173 }
174
175 pub fn get<F, Fut>(self, pattern: &str, handler: F) -> Self
177 where
178 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
179 Fut: Future<Output = Response> + Send + 'static,
180 {
181 self.add("GET", pattern, handler)
182 }
183
184 pub fn post<F, Fut>(self, pattern: &str, handler: F) -> Self
186 where
187 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
188 Fut: Future<Output = Response> + Send + 'static,
189 {
190 self.add("POST", pattern, handler)
191 }
192
193 pub fn put<F, Fut>(self, pattern: &str, handler: F) -> Self
195 where
196 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
197 Fut: Future<Output = Response> + Send + 'static,
198 {
199 self.add("PUT", pattern, handler)
200 }
201
202 pub fn patch<F, Fut>(self, pattern: &str, handler: F) -> Self
204 where
205 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
206 Fut: Future<Output = Response> + Send + 'static,
207 {
208 self.add("PATCH", pattern, handler)
209 }
210
211 pub fn delete<F, Fut>(self, pattern: &str, handler: F) -> Self
213 where
214 F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
215 Fut: Future<Output = Response> + Send + 'static,
216 {
217 self.add("DELETE", pattern, handler)
218 }
219
220 async fn execute_async(
221 &self,
222 request: &Request,
223 connection: &ConnectionInfo,
224 ) -> Result<Response, String> {
225 let path = request.request_uri.split('?').next().unwrap_or(&request.request_uri);
226 let path_segs: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
227
228 for route in &self.routes {
229 if route.method != request.method {
230 continue;
231 }
232 if let Some(params_map) = try_match(&route.segments, &path_segs) {
233 let params = PathParams::from_map(params_map);
234 let fut = (route.handler)(
235 request.clone(),
236 params,
237 connection.clone(),
238 Arc::clone(&self.state),
239 );
240 return Ok(fut.await);
241 }
242 }
243
244 App::new().execute(request, connection)
245 }
246}
247
248impl<S: Send + Sync + 'static> Application for AsyncAppWithState<S> {
249 fn execute(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
250 let request = request.clone();
251 let connection = connection.clone();
252 match tokio::runtime::Handle::try_current() {
253 Ok(_) => {
254 std::thread::scope(|s| {
257 s.spawn(|| {
258 tokio::runtime::Builder::new_current_thread()
259 .enable_all()
260 .build()
261 .unwrap()
262 .block_on(self.execute_async(&request, &connection))
263 })
264 .join()
265 .unwrap()
266 })
267 }
268 Err(_) => {
269 tokio::runtime::Builder::new_current_thread()
271 .enable_all()
272 .build()
273 .unwrap()
274 .block_on(self.execute_async(&request, &connection))
275 }
276 }
277 }
278}