Skip to main content

rust_web_server/async_state/
mod.rs

1//! Async-capable state-aware application — requires the `http2` feature (tokio).
2//!
3//! [`AsyncAppWithState<S>`] is the async counterpart to [`AppWithState<S>`]:
4//! handlers are `async fn` closures that can `await` database queries, HTTP
5//! clients, or any other async I/O without blocking the OS thread.
6//!
7//! The sync [`Application`] bridge works in any calling context: when inside
8//! an existing tokio runtime (HTTP/2 / HTTP/3), it spawns a scoped OS thread
9//! with its own single-threaded runtime; when called from the HTTP/1.1
10//! thread-pool (no runtime), it creates a temporary single-threaded runtime.
11//!
12//! Unmatched routes fall through to the built-in [`App`] controller chain.
13//!
14//! # Example
15//!
16//! ```rust,no_run
17//! use std::sync::Arc;
18//! use rust_web_server::async_state::AsyncAppWithState;
19//! use rust_web_server::core::New;
20//! use rust_web_server::response::{Response, STATUS_CODE_REASON_PHRASE};
21//! use rust_web_server::range::Range;
22//! use rust_web_server::mime_type::MimeType;
23//! use rust_web_server::router::PathParams;
24//! use rust_web_server::request::Request;
25//! use rust_web_server::server::ConnectionInfo;
26//!
27//! struct AppState {
28//!     greeting: String,
29//! }
30//!
31//! let app = AsyncAppWithState::new(AppState { greeting: "Hello".to_string() })
32//!     .get("/greet/:name", |_req, params, _conn, state| async move {
33//!         let name = params.get("name").unwrap_or("world");
34//!         let body = format!("{}, {}!", state.greeting, name);
35//!         let mut r = Response::new();
36//!         r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
37//!         r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
38//!         r.content_range_list = vec![
39//!             Range::get_content_range(body.into_bytes(), MimeType::TEXT_PLAIN.to_string())
40//!         ];
41//!         r
42//!     });
43//! ```
44//!
45//! [`AppWithState<S>`]: crate::state::AppWithState
46
47#[cfg(test)]
48mod tests;
49
50use std::collections::HashMap;
51use std::future::Future;
52use std::pin::Pin;
53use std::sync::Arc;
54
55use crate::app::App;
56use crate::application::Application;
57use crate::core::New;
58use crate::request::Request;
59use crate::response::Response;
60use crate::router::PathParams;
61use crate::server::ConnectionInfo;
62
63type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
64
65type AsyncHandlerFn<S> = Arc<
66    dyn Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> BoxFuture<Response> + Send + Sync,
67>;
68
69// ── Internal pattern matching (mirrors Router) ────────────────────────────────
70
/// One parsed component of a route pattern.
///
/// `Debug`, `PartialEq` and `Eq` are derived in addition to `Clone` so that
/// parsed patterns can be printed in diagnostics and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Segment {
    /// Fixed text that must equal the corresponding path segment exactly.
    Literal(String),
    /// Written `:name` in a pattern; captures exactly one path segment under `name`.
    Param(String),
    /// Written `*name` in a pattern; captures every remaining path segment,
    /// joined with `/`, under `name`. Only matches as the last pattern segment.
    Wildcard(String),
}
77
78fn parse_pattern(pattern: &str) -> Vec<Segment> {
79    if pattern == "/" {
80        return vec![];
81    }
82    pattern
83        .split('/')
84        .filter(|s| !s.is_empty())
85        .map(|seg| {
86            if let Some(name) = seg.strip_prefix(':') {
87                Segment::Param(name.to_string())
88            } else if let Some(name) = seg.strip_prefix('*') {
89                Segment::Wildcard(name.to_string())
90            } else {
91                Segment::Literal(seg.to_string())
92            }
93        })
94        .collect()
95}
96
97fn try_match(pattern: &[Segment], path: &[&str]) -> Option<HashMap<String, String>> {
98    let mut params = HashMap::new();
99    let mut pi = 0;
100
101    for (si, seg) in pattern.iter().enumerate() {
102        match seg {
103            Segment::Literal(lit) => {
104                if pi >= path.len() || path[pi] != lit.as_str() {
105                    return None;
106                }
107                pi += 1;
108            }
109            Segment::Param(name) => {
110                if pi >= path.len() {
111                    return None;
112                }
113                params.insert(name.clone(), path[pi].to_string());
114                pi += 1;
115            }
116            Segment::Wildcard(name) => {
117                if si != pattern.len() - 1 {
118                    return None;
119                }
120                params.insert(name.clone(), path[pi..].join("/"));
121                pi = path.len();
122            }
123        }
124    }
125
126    if pi == path.len() { Some(params) } else { None }
127}
128
129// ── AsyncRoute ────────────────────────────────────────────────────────────────
130
131#[derive(Clone)]
132struct AsyncRoute<S> {
133    method: String,
134    segments: Vec<Segment>,
135    handler: AsyncHandlerFn<S>,
136}
137
138// ── AsyncAppWithState ─────────────────────────────────────────────────────────
139
140/// An [`Application`] whose route handlers are `async` functions.
141///
142/// State is stored as `Arc<S>` and passed by value (cheap clone) to each
143/// handler invocation. Handlers receive owned `Request`, `PathParams`, and
144/// `ConnectionInfo` values so the returned future is `'static`.
145#[derive(Clone)]
146pub struct AsyncAppWithState<S> {
147    state: Arc<S>,
148    routes: Vec<AsyncRoute<S>>,
149}
150
151impl<S: Send + Sync + 'static> AsyncAppWithState<S> {
152    /// Create a new `AsyncAppWithState` wrapping `state`.
153    pub fn new(state: S) -> Self {
154        AsyncAppWithState { state: Arc::new(state), routes: Vec::new() }
155    }
156
157    /// Return a reference to the shared state.
158    pub fn state(&self) -> &S {
159        &self.state
160    }
161
162    fn add<F, Fut>(mut self, method: &str, pattern: &str, handler: F) -> Self
163    where
164        F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
165        Fut: Future<Output = Response> + Send + 'static,
166    {
167        self.routes.push(AsyncRoute {
168            method: method.to_string(),
169            segments: parse_pattern(pattern),
170            handler: Arc::new(move |req, params, conn, state| Box::pin(handler(req, params, conn, state))),
171        });
172        self
173    }
174
175    /// Register an async `GET` handler for `pattern`.
176    pub fn get<F, Fut>(self, pattern: &str, handler: F) -> Self
177    where
178        F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
179        Fut: Future<Output = Response> + Send + 'static,
180    {
181        self.add("GET", pattern, handler)
182    }
183
184    /// Register an async `POST` handler for `pattern`.
185    pub fn post<F, Fut>(self, pattern: &str, handler: F) -> Self
186    where
187        F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
188        Fut: Future<Output = Response> + Send + 'static,
189    {
190        self.add("POST", pattern, handler)
191    }
192
193    /// Register an async `PUT` handler for `pattern`.
194    pub fn put<F, Fut>(self, pattern: &str, handler: F) -> Self
195    where
196        F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
197        Fut: Future<Output = Response> + Send + 'static,
198    {
199        self.add("PUT", pattern, handler)
200    }
201
202    /// Register an async `PATCH` handler for `pattern`.
203    pub fn patch<F, Fut>(self, pattern: &str, handler: F) -> Self
204    where
205        F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
206        Fut: Future<Output = Response> + Send + 'static,
207    {
208        self.add("PATCH", pattern, handler)
209    }
210
211    /// Register an async `DELETE` handler for `pattern`.
212    pub fn delete<F, Fut>(self, pattern: &str, handler: F) -> Self
213    where
214        F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
215        Fut: Future<Output = Response> + Send + 'static,
216    {
217        self.add("DELETE", pattern, handler)
218    }
219
220    async fn execute_async(
221        &self,
222        request: &Request,
223        connection: &ConnectionInfo,
224    ) -> Result<Response, String> {
225        let path = request.request_uri.split('?').next().unwrap_or(&request.request_uri);
226        let path_segs: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
227
228        for route in &self.routes {
229            if route.method != request.method {
230                continue;
231            }
232            if let Some(params_map) = try_match(&route.segments, &path_segs) {
233                let params = PathParams::from_map(params_map);
234                let fut = (route.handler)(
235                    request.clone(),
236                    params,
237                    connection.clone(),
238                    Arc::clone(&self.state),
239                );
240                return Ok(fut.await);
241            }
242        }
243
244        App::new().execute(request, connection)
245    }
246}
247
248impl<S: Send + Sync + 'static> Application for AsyncAppWithState<S> {
249    fn execute(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
250        let request = request.clone();
251        let connection = connection.clone();
252        match tokio::runtime::Handle::try_current() {
253            Ok(_) => {
254                // Inside an existing runtime: run the future on a scoped OS thread
255                // with its own single-threaded runtime to avoid blocking the event loop.
256                std::thread::scope(|s| {
257                    s.spawn(|| {
258                        tokio::runtime::Builder::new_current_thread()
259                            .enable_all()
260                            .build()
261                            .unwrap()
262                            .block_on(self.execute_async(&request, &connection))
263                    })
264                    .join()
265                    .unwrap()
266                })
267            }
268            Err(_) => {
269                // Not inside any runtime (HTTP/1.1 thread pool): create a temporary one.
270                tokio::runtime::Builder::new_current_thread()
271                    .enable_all()
272                    .build()
273                    .unwrap()
274                    .block_on(self.execute_async(&request, &connection))
275            }
276        }
277    }
278}