Skip to main content

rust_web_server/async_state/
mod.rs

1//! Async-capable state-aware application — requires the `http2` feature (tokio).
2//!
3//! [`AsyncAppWithState<S>`] is the async counterpart to [`AppWithState<S>`]:
4//! handlers are `async fn` closures that can `await` database queries, HTTP
5//! clients, or any other async I/O without blocking the OS thread.
6//!
//! The sync [`Application`] bridge works in any calling context. When called
//! from inside an existing tokio runtime (HTTP/2 / HTTP/3), it runs the handler
//! on a scoped OS thread with its own single-threaded runtime. When called from
//! the HTTP/1.1 thread pool (no runtime), it creates a temporary single-threaded
//! runtime. In both cases the calling thread blocks until the handler finishes.
11//!
12//! Unmatched routes fall through to the built-in [`App`] controller chain.
13//!
14//! # Example
15//!
16//! ```rust,no_run
17//! use std::sync::Arc;
18//! use rust_web_server::async_state::AsyncAppWithState;
19//! use rust_web_server::response::{Response, STATUS_CODE_REASON_PHRASE};
20//! use rust_web_server::range::Range;
21//! use rust_web_server::mime_type::MimeType;
22//! use rust_web_server::router::PathParams;
23//! use rust_web_server::request::Request;
24//! use rust_web_server::server::ConnectionInfo;
25//!
26//! struct AppState {
27//!     greeting: String,
28//! }
29//!
30//! let app = AsyncAppWithState::new(AppState { greeting: "Hello".to_string() })
31//!     .get("/greet/:name", |_req, params, _conn, state| async move {
32//!         let name = params.get("name").unwrap_or("world");
33//!         let body = format!("{}, {}!", state.greeting, name);
34//!         let mut r = Response::new();
35//!         r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
36//!         r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
37//!         r.content_range_list = vec![
38//!             Range::get_content_range(body.into_bytes(), MimeType::TEXT_PLAIN.to_string())
39//!         ];
40//!         r
41//!     });
42//! ```
43//!
44//! [`AppWithState<S>`]: crate::state::AppWithState
45
46#[cfg(test)]
47mod tests;
48
49use std::collections::HashMap;
50use std::future::Future;
51use std::pin::Pin;
52use std::sync::Arc;
53
54use crate::app::App;
55use crate::application::Application;
56use crate::core::New;
57use crate::request::Request;
58use crate::response::Response;
59use crate::router::PathParams;
60use crate::server::ConnectionInfo;
61
/// A heap-allocated, pinned, type-erased future that is `Send` and owns all of
/// its data (`'static`), resolving to `T`.
type BoxFuture<T> = Pin<
    Box<dyn Future<Output = T> + Send + 'static>,
>;
63
64type AsyncHandlerFn<S> = Box<
65    dyn Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> BoxFuture<Response> + Send + Sync,
66>;
67
68// ── Internal pattern matching (mirrors Router) ────────────────────────────────
69
/// One piece of a parsed route pattern (see `parse_pattern`).
///
/// Derives `Debug`/`PartialEq` so parsed patterns can be inspected and compared
/// (e.g. in tests), and `Clone` so they can be duplicated if needed.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Segment {
    /// Must equal the corresponding path piece exactly.
    Literal(String),
    /// `:name`: captures exactly one path piece under `name`.
    Param(String),
    /// `*name`: captures all remaining path pieces (joined with `/`) under `name`.
    Wildcard(String),
}
75
76fn parse_pattern(pattern: &str) -> Vec<Segment> {
77    if pattern == "/" {
78        return vec![];
79    }
80    pattern
81        .split('/')
82        .filter(|s| !s.is_empty())
83        .map(|seg| {
84            if let Some(name) = seg.strip_prefix(':') {
85                Segment::Param(name.to_string())
86            } else if let Some(name) = seg.strip_prefix('*') {
87                Segment::Wildcard(name.to_string())
88            } else {
89                Segment::Literal(seg.to_string())
90            }
91        })
92        .collect()
93}
94
95fn try_match(pattern: &[Segment], path: &[&str]) -> Option<HashMap<String, String>> {
96    let mut params = HashMap::new();
97    let mut pi = 0;
98
99    for (si, seg) in pattern.iter().enumerate() {
100        match seg {
101            Segment::Literal(lit) => {
102                if pi >= path.len() || path[pi] != lit.as_str() {
103                    return None;
104                }
105                pi += 1;
106            }
107            Segment::Param(name) => {
108                if pi >= path.len() {
109                    return None;
110                }
111                params.insert(name.clone(), path[pi].to_string());
112                pi += 1;
113            }
114            Segment::Wildcard(name) => {
115                if si != pattern.len() - 1 {
116                    return None;
117                }
118                params.insert(name.clone(), path[pi..].join("/"));
119                pi = path.len();
120            }
121        }
122    }
123
124    if pi == path.len() { Some(params) } else { None }
125}
126
127// ── AsyncRoute ────────────────────────────────────────────────────────────────
128
129struct AsyncRoute<S> {
130    method: String,
131    segments: Vec<Segment>,
132    handler: AsyncHandlerFn<S>,
133}
134
135// ── AsyncAppWithState ─────────────────────────────────────────────────────────
136
137/// An [`Application`] whose route handlers are `async` functions.
138///
139/// State is stored as `Arc<S>` and passed by value (cheap clone) to each
140/// handler invocation. Handlers receive owned `Request`, `PathParams`, and
141/// `ConnectionInfo` values so the returned future is `'static`.
142pub struct AsyncAppWithState<S> {
143    state: Arc<S>,
144    routes: Vec<AsyncRoute<S>>,
145}
146
147impl<S: Send + Sync + 'static> AsyncAppWithState<S> {
148    /// Create a new `AsyncAppWithState` wrapping `state`.
149    pub fn new(state: S) -> Self {
150        AsyncAppWithState { state: Arc::new(state), routes: Vec::new() }
151    }
152
153    /// Return a reference to the shared state.
154    pub fn state(&self) -> &S {
155        &self.state
156    }
157
158    fn add<F, Fut>(mut self, method: &str, pattern: &str, handler: F) -> Self
159    where
160        F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
161        Fut: Future<Output = Response> + Send + 'static,
162    {
163        self.routes.push(AsyncRoute {
164            method: method.to_string(),
165            segments: parse_pattern(pattern),
166            handler: Box::new(move |req, params, conn, state| Box::pin(handler(req, params, conn, state))),
167        });
168        self
169    }
170
171    /// Register an async `GET` handler for `pattern`.
172    pub fn get<F, Fut>(self, pattern: &str, handler: F) -> Self
173    where
174        F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
175        Fut: Future<Output = Response> + Send + 'static,
176    {
177        self.add("GET", pattern, handler)
178    }
179
180    /// Register an async `POST` handler for `pattern`.
181    pub fn post<F, Fut>(self, pattern: &str, handler: F) -> Self
182    where
183        F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
184        Fut: Future<Output = Response> + Send + 'static,
185    {
186        self.add("POST", pattern, handler)
187    }
188
189    /// Register an async `PUT` handler for `pattern`.
190    pub fn put<F, Fut>(self, pattern: &str, handler: F) -> Self
191    where
192        F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
193        Fut: Future<Output = Response> + Send + 'static,
194    {
195        self.add("PUT", pattern, handler)
196    }
197
198    /// Register an async `PATCH` handler for `pattern`.
199    pub fn patch<F, Fut>(self, pattern: &str, handler: F) -> Self
200    where
201        F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
202        Fut: Future<Output = Response> + Send + 'static,
203    {
204        self.add("PATCH", pattern, handler)
205    }
206
207    /// Register an async `DELETE` handler for `pattern`.
208    pub fn delete<F, Fut>(self, pattern: &str, handler: F) -> Self
209    where
210        F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
211        Fut: Future<Output = Response> + Send + 'static,
212    {
213        self.add("DELETE", pattern, handler)
214    }
215
216    async fn execute_async(
217        &self,
218        request: &Request,
219        connection: &ConnectionInfo,
220    ) -> Result<Response, String> {
221        let path = request.request_uri.split('?').next().unwrap_or(&request.request_uri);
222        let path_segs: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
223
224        for route in &self.routes {
225            if route.method != request.method {
226                continue;
227            }
228            if let Some(params_map) = try_match(&route.segments, &path_segs) {
229                let params = PathParams::from_map(params_map);
230                let fut = (route.handler)(
231                    request.clone(),
232                    params,
233                    connection.clone(),
234                    Arc::clone(&self.state),
235                );
236                return Ok(fut.await);
237            }
238        }
239
240        App::new().execute(request, connection)
241    }
242}
243
244impl<S: Send + Sync + 'static> Application for AsyncAppWithState<S> {
245    fn execute(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
246        let request = request.clone();
247        let connection = connection.clone();
248        match tokio::runtime::Handle::try_current() {
249            Ok(_) => {
250                // Inside an existing runtime: run the future on a scoped OS thread
251                // with its own single-threaded runtime to avoid blocking the event loop.
252                std::thread::scope(|s| {
253                    s.spawn(|| {
254                        tokio::runtime::Builder::new_current_thread()
255                            .enable_all()
256                            .build()
257                            .unwrap()
258                            .block_on(self.execute_async(&request, &connection))
259                    })
260                    .join()
261                    .unwrap()
262                })
263            }
264            Err(_) => {
265                // Not inside any runtime (HTTP/1.1 thread pool): create a temporary one.
266                tokio::runtime::Builder::new_current_thread()
267                    .enable_all()
268                    .build()
269                    .unwrap()
270                    .block_on(self.execute_async(&request, &connection))
271            }
272        }
273    }
274}