Skip to main content

rust_web_server/async_state/
mod.rs

1//! Async-capable state-aware application — requires the `http2` feature (tokio).
2//!
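//! [`AsyncAppWithState<S>`] is the async counterpart to [`AppWithState<S>`]:
//! handlers are closures that return a future (typically an `async move`
//! block), so they can `await` database queries, HTTP clients, or any other
//! async I/O. Note that [`Application::execute`] still blocks its calling
//! thread until the handler's future has completed.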
6//!
7//! The sync [`Application`] bridge works in any calling context: when inside
8//! an existing tokio runtime (HTTP/2 / HTTP/3), it spawns a scoped OS thread
9//! with its own single-threaded runtime; when called from the HTTP/1.1
10//! thread-pool (no runtime), it creates a temporary single-threaded runtime.
11//!
12//! Unmatched routes fall through to the built-in [`App`] controller chain.
13//!
14//! # Example
15//!
16//! ```rust,no_run
17//! use std::sync::Arc;
18//! use rust_web_server::async_state::AsyncAppWithState;
19//! use rust_web_server::core::New;
20//! use rust_web_server::response::{Response, STATUS_CODE_REASON_PHRASE};
21//! use rust_web_server::range::Range;
22//! use rust_web_server::mime_type::MimeType;
23//! use rust_web_server::router::PathParams;
24//! use rust_web_server::request::Request;
25//! use rust_web_server::server::ConnectionInfo;
26//!
27//! struct AppState {
28//!     greeting: String,
29//! }
30//!
31//! let app = AsyncAppWithState::new(AppState { greeting: "Hello".to_string() })
32//!     .get("/greet/:name", |_req, params, _conn, state| async move {
33//!         let name = params.get("name").unwrap_or("world");
34//!         let body = format!("{}, {}!", state.greeting, name);
35//!         let mut r = Response::new();
36//!         r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
37//!         r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
38//!         r.content_range_list = vec![
39//!             Range::get_content_range(body.into_bytes(), MimeType::TEXT_PLAIN.to_string())
40//!         ];
41//!         r
42//!     });
43//! ```
44//!
45//! [`AppWithState<S>`]: crate::state::AppWithState
46
47#[cfg(test)]
48mod tests;
49
50use std::future::Future;
51use std::pin::Pin;
52use std::sync::Arc;
53
54use crate::app::App;
55use crate::application::Application;
56use crate::core::New;
57use crate::request::Request;
58use crate::response::Response;
59use crate::router::matcher::{self, Segment};
60use crate::router::PathParams;
61use crate::server::ConnectionInfo;
62#[cfg(feature = "openapi")]
63use crate::mime_type::MimeType;
64#[cfg(feature = "openapi")]
65use crate::range::Range;
66#[cfg(feature = "openapi")]
67use crate::response::STATUS_CODE_REASON_PHRASE;
68
69type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
70
71type AsyncHandlerFn<S> = Arc<
72    dyn Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> BoxFuture<Response> + Send + Sync,
73>;
74
75// ── AsyncRoute ────────────────────────────────────────────────────────────────
76
77#[derive(Clone)]
78struct AsyncRoute<S> {
79    method: String,
80    segments: Vec<Segment>,
81    handler: AsyncHandlerFn<S>,
82}
83
84// ── AsyncAppWithState ─────────────────────────────────────────────────────────
85
86/// An [`Application`] whose route handlers are `async` functions.
87///
88/// State is stored as `Arc<S>` and passed by value (cheap clone) to each
89/// handler invocation. Handlers receive owned `Request`, `PathParams`, and
90/// `ConnectionInfo` values so the returned future is `'static`.
91#[derive(Clone)]
92pub struct AsyncAppWithState<S> {
93    state: Arc<S>,
94    routes: Vec<AsyncRoute<S>>,
95    /// When `Some`, the fallback `App` is pinned to this config (see
96    /// [`App::with_config`]); when `None`, the fallback reads
97    /// `RWS_CONFIG_*` env vars per request via `App::new()`, same as `App`'s
98    /// own default.
99    config: Option<Arc<crate::server_config::ServerConfig>>,
100}
101
102impl<S: Send + Sync + 'static> AsyncAppWithState<S> {
103    /// Create a new `AsyncAppWithState` wrapping `state`.
104    pub fn new(state: S) -> Self {
105        AsyncAppWithState { state: Arc::new(state), routes: Vec::new(), config: None }
106    }
107
108    /// Pin the fallback [`App`] (used for any request this app's own routes
109    /// don't match) to an explicit [`crate::server_config::ServerConfig`],
110    /// instead of reading `RWS_CONFIG_*` environment variables per request.
111    ///
112    /// Mirrors [`App::with_config`] / [`crate::state::AppWithState::with_config`].
113    pub fn with_config(mut self, config: crate::server_config::ServerConfig) -> Self {
114        self.config = Some(Arc::new(config));
115        self
116    }
117
118    /// Return a reference to the shared state.
119    pub fn state(&self) -> &S {
120        &self.state
121    }
122
123    /// Return a snapshot of all registered routes as `(method, pattern)` pairs.
124    pub fn route_entries(&self) -> Vec<crate::router::RouteInfo> {
125        self.routes
126            .iter()
127            .map(|r| crate::router::RouteInfo {
128                method: r.method.clone(),
129                pattern: matcher::segments_to_pattern(&r.segments),
130            })
131            .collect()
132    }
133
134    /// The fallback `App` for requests this app's own routes don't match —
135    /// pinned to `self.config` if set, otherwise `App::new()`'s default
136    /// per-request env read.
137    fn fallback_app(&self) -> App {
138        match &self.config {
139            Some(c) => App::with_config((**c).clone()),
140            None => App::new(),
141        }
142    }
143
144    fn add<F, Fut>(mut self, method: &str, pattern: &str, handler: F) -> Self
145    where
146        F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
147        Fut: Future<Output = Response> + Send + 'static,
148    {
149        self.routes.push(AsyncRoute {
150            method: method.to_string(),
151            segments: matcher::parse_pattern(pattern),
152            handler: Arc::new(move |req, params, conn, state| Box::pin(handler(req, params, conn, state))),
153        });
154        self
155    }
156
157    /// Register an async `GET` handler for `pattern`.
158    pub fn get<F, Fut>(self, pattern: &str, handler: F) -> Self
159    where
160        F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
161        Fut: Future<Output = Response> + Send + 'static,
162    {
163        self.add("GET", pattern, handler)
164    }
165
166    /// Register an async `POST` handler for `pattern`.
167    pub fn post<F, Fut>(self, pattern: &str, handler: F) -> Self
168    where
169        F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
170        Fut: Future<Output = Response> + Send + 'static,
171    {
172        self.add("POST", pattern, handler)
173    }
174
175    /// Register an async `PUT` handler for `pattern`.
176    pub fn put<F, Fut>(self, pattern: &str, handler: F) -> Self
177    where
178        F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
179        Fut: Future<Output = Response> + Send + 'static,
180    {
181        self.add("PUT", pattern, handler)
182    }
183
184    /// Register an async `PATCH` handler for `pattern`.
185    pub fn patch<F, Fut>(self, pattern: &str, handler: F) -> Self
186    where
187        F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
188        Fut: Future<Output = Response> + Send + 'static,
189    {
190        self.add("PATCH", pattern, handler)
191    }
192
193    /// Register an async `DELETE` handler for `pattern`.
194    pub fn delete<F, Fut>(self, pattern: &str, handler: F) -> Self
195    where
196        F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
197        Fut: Future<Output = Response> + Send + 'static,
198    {
199        self.add("DELETE", pattern, handler)
200    }
201
202    /// Add `GET /openapi.json` (a generated OpenAPI 3.0 document covering
203    /// every route registered so far) and `GET /docs` (Swagger UI, loaded
204    /// from a CDN, pointed at `/openapi.json`).
205    ///
206    /// Call this *after* registering your routes — routes added afterward
207    /// still work but won't appear in the generated spec, since it's built
208    /// once at this call rather than read dynamically per request.
209    ///
210    /// Requires the `openapi` feature (and `http2`, since `AsyncAppWithState`
211    /// already requires it).
212    #[cfg(feature = "openapi")]
213    pub fn openapi(self, config: crate::openapi::OpenApiConfig) -> Self {
214        let spec_json = Arc::new(crate::openapi::build_spec(&config, &self.route_entries()));
215        let html = Arc::new(crate::openapi::swagger_ui_html("/openapi.json"));
216
217        let spec_for_route = Arc::clone(&spec_json);
218        self.get("/openapi.json", move |_req, _params, _conn, _state| {
219            let spec_for_route = Arc::clone(&spec_for_route);
220            async move {
221                let mut r = Response::new();
222                r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
223                r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
224                r.content_range_list = vec![Range::get_content_range(
225                    spec_for_route.as_bytes().to_vec(),
226                    MimeType::APPLICATION_JSON.to_string(),
227                )];
228                r
229            }
230        })
231        .get("/docs", move |_req, _params, _conn, _state| {
232            let html = Arc::clone(&html);
233            async move {
234                let mut r = Response::new();
235                r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
236                r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
237                r.content_range_list = vec![Range::get_content_range(
238                    html.as_bytes().to_vec(),
239                    MimeType::TEXT_HTML.to_string(),
240                )];
241                r
242            }
243        })
244    }
245
246    async fn execute_async(
247        &self,
248        request: &Request,
249        connection: &ConnectionInfo,
250    ) -> Result<Response, String> {
251        let path = request.request_uri.split('?').next().unwrap_or(&request.request_uri);
252        let path_segs: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
253
254        for route in &self.routes {
255            if route.method != request.method {
256                continue;
257            }
258            if let Some(params_map) = matcher::try_match(&route.segments, &path_segs) {
259                let params = PathParams::from_map(params_map);
260                let fut = (route.handler)(
261                    request.clone(),
262                    params,
263                    connection.clone(),
264                    Arc::clone(&self.state),
265                );
266                return Ok(fut.await);
267            }
268        }
269
270        self.fallback_app().execute(request, connection)
271    }
272}
273
274impl<S: Send + Sync + 'static> Application for AsyncAppWithState<S> {
275    fn execute(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
276        let request = request.clone();
277        let connection = connection.clone();
278        match tokio::runtime::Handle::try_current() {
279            Ok(_) => {
280                // Inside an existing runtime: run the future on a scoped OS thread
281                // with its own single-threaded runtime to avoid blocking the event loop.
282                std::thread::scope(|s| {
283                    s.spawn(|| {
284                        tokio::runtime::Builder::new_current_thread()
285                            .enable_all()
286                            .build()
287                            .unwrap()
288                            .block_on(self.execute_async(&request, &connection))
289                    })
290                    .join()
291                    .unwrap()
292                })
293            }
294            Err(_) => {
295                // Not inside any runtime (HTTP/1.1 thread pool): create a temporary one.
296                tokio::runtime::Builder::new_current_thread()
297                    .enable_all()
298                    .build()
299                    .unwrap()
300                    .block_on(self.execute_async(&request, &connection))
301            }
302        }
303    }
304}