Skip to main content

rust_web_server/state/
mod.rs

1//! Shared application state and state-aware routing.
2//!
3//! [`AppWithState<S>`] combines a typed state value (database pools, config,
4//! caches) with route registration.  Routes are tried first; requests that do
5//! not match fall through to the built-in [`App`] controller chain (static
6//! files, healthz, metrics, …).
7//!
8//! State is stored as an [`Arc<S>`] and shared across all handlers. Handlers
9//! receive an immutable `&S` reference alongside the request context.
10//!
11//! # Example
12//!
13//! ```rust,no_run
14//! use rust_web_server::state::AppWithState;
15//! use rust_web_server::response::{Response, STATUS_CODE_REASON_PHRASE};
16//! use rust_web_server::range::Range;
17//! use rust_web_server::mime_type::MimeType;
18//! use rust_web_server::core::New;
19//!
20//! struct AppState {
21//!     greeting: String,
22//! }
23//!
24//! let app = AppWithState::new(AppState { greeting: "Hello".to_string() })
25//!     .get("/greet", |_req, _params, _conn, state| {
26//!         let mut r = Response::new();
27//!         r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
28//!         r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
29//!         r.content_range_list = vec![
30//!             Range::get_content_range(
31//!                 state.greeting.as_bytes().to_vec(),
32//!                 MimeType::TEXT_PLAIN.to_string(),
33//!             )
34//!         ];
35//!         r
36//!     })
37//!     .get("/users/:id", |_req, params, _conn, state| {
38//!         let id = params.get("id").unwrap_or("?");
39//!         let body = format!("{}, user {}!", state.greeting, id);
40//!         let mut r = Response::new();
41//!         r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
42//!         r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
43//!         r.content_range_list = vec![
44//!             Range::get_content_range(body.into_bytes(), MimeType::TEXT_PLAIN.to_string())
45//!         ];
46//!         r
47//!     });
48//! ```
49
50#[cfg(test)]
51mod tests;
52
53use std::sync::Arc;
54
55use crate::app::App;
56use crate::application::Application;
57use crate::core::New;
58use crate::middleware::{Middleware, WithMiddleware};
59use crate::request::Request;
60use crate::response::Response;
61use crate::router::{PathParams, Router};
62use crate::server::ConnectionInfo;
63
64/// An [`Application`] that combines user-defined state-aware routes with the
65/// built-in [`App`] controller chain as a fallback.
66///
67/// Routes are matched in registration order. The first match wins; unmatched
68/// requests are forwarded to [`App`] (static files, health probes, etc.).
69#[derive(Clone)]
70pub struct AppWithState<S> {
71    state: Arc<S>,
72    router: Router,
73}
74
75impl<S: Send + Sync + 'static> AppWithState<S> {
76    /// Create a new `AppWithState` wrapping `state`.
77    ///
78    /// `state` is stored behind an `Arc` so it can be shared across threads
79    /// without cloning. Register routes with the builder methods.
80    pub fn new(state: S) -> Self {
81        AppWithState {
82            state: Arc::new(state),
83            router: Router::new(),
84        }
85    }
86
87    /// Return a reference to the shared state.
88    pub fn state(&self) -> &S {
89        &self.state
90    }
91
92    /// Register a `GET` handler for `pattern`.
93    pub fn get<F>(mut self, pattern: &str, handler: F) -> Self
94    where
95        F: Fn(&Request, &PathParams, &ConnectionInfo, &S) -> Response + Send + Sync + 'static,
96    {
97        let state = Arc::clone(&self.state);
98        self.router = self.router.get(pattern, move |req, params, conn| {
99            handler(req, params, conn, &state)
100        });
101        self
102    }
103
104    /// Register a `POST` handler for `pattern`.
105    pub fn post<F>(mut self, pattern: &str, handler: F) -> Self
106    where
107        F: Fn(&Request, &PathParams, &ConnectionInfo, &S) -> Response + Send + Sync + 'static,
108    {
109        let state = Arc::clone(&self.state);
110        self.router = self.router.post(pattern, move |req, params, conn| {
111            handler(req, params, conn, &state)
112        });
113        self
114    }
115
116    /// Register a `PUT` handler for `pattern`.
117    pub fn put<F>(mut self, pattern: &str, handler: F) -> Self
118    where
119        F: Fn(&Request, &PathParams, &ConnectionInfo, &S) -> Response + Send + Sync + 'static,
120    {
121        let state = Arc::clone(&self.state);
122        self.router = self.router.put(pattern, move |req, params, conn| {
123            handler(req, params, conn, &state)
124        });
125        self
126    }
127
128    /// Register a `PATCH` handler for `pattern`.
129    pub fn patch<F>(mut self, pattern: &str, handler: F) -> Self
130    where
131        F: Fn(&Request, &PathParams, &ConnectionInfo, &S) -> Response + Send + Sync + 'static,
132    {
133        let state = Arc::clone(&self.state);
134        self.router = self.router.patch(pattern, move |req, params, conn| {
135            handler(req, params, conn, &state)
136        });
137        self
138    }
139
140    /// Register a `DELETE` handler for `pattern`.
141    pub fn delete<F>(mut self, pattern: &str, handler: F) -> Self
142    where
143        F: Fn(&Request, &PathParams, &ConnectionInfo, &S) -> Response + Send + Sync + 'static,
144    {
145        let state = Arc::clone(&self.state);
146        self.router = self.router.delete(pattern, move |req, params, conn| {
147            handler(req, params, conn, &state)
148        });
149        self
150    }
151
152    /// Return a snapshot of all registered routes as `(method, pattern)` pairs.
153    pub fn route_entries(&self) -> Vec<crate::router::RouteInfo> {
154        self.router.route_entries()
155    }
156
157    /// Attach an MCP server to this application. Requests that do not match
158    /// the MCP endpoint (`POST /mcp`) are forwarded to `self`, so all
159    /// previously registered routes remain active.
160    ///
161    /// ```rust,no_run
162    /// use rust_web_server::app::App;
163    /// use rust_web_server::mcp::{McpContent, extract_arg};
164    /// use rust_web_server::response::{Response, STATUS_CODE_REASON_PHRASE};
165    /// use rust_web_server::core::New;
166    ///
167    /// struct Db { url: String }
168    ///
169    /// let app = App::with_state(Db { url: "postgres://localhost/mydb".to_string() })
170    ///     .get("/api/users", |_req, _params, _conn, _db| {
171    ///         let mut r = Response::new();
172    ///         r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
173    ///         r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
174    ///         r
175    ///     })
176    ///     .mcp("my-server", "1.0")
177    ///     .tool("list_users", "List all users", "{}", |_| {
178    ///         Ok(McpContent::json(r#"[{"id":1,"name":"Alice"}]"#))
179    ///     });
180    /// ```
181    pub fn mcp(self, name: impl Into<String>, version: impl Into<String>) -> crate::mcp::McpServer {
182        crate::mcp::McpServer::new(name, version).wrap(self)
183    }
184
185    /// Wrap this application in a middleware layer.
186    ///
187    /// Enables fluent composition:
188    ///
189    /// ```rust,no_run
190    /// use rust_web_server::app::App;
191    /// use rust_web_server::core::New;
192    /// use rust_web_server::middleware::RateLimitLayer;
193    /// use rust_web_server::response::Response;
194    ///
195    /// let app = App::with_state(())
196    ///     .get("/ping", |_, _, _, _| Response::new())
197    ///     .wrap(RateLimitLayer);
198    /// ```
199    pub fn wrap<M: Middleware + 'static>(self, layer: M) -> WithMiddleware<AppWithState<S>> {
200        WithMiddleware::new(self).wrap(layer)
201    }
202}
203
204impl<S: Send + Sync + 'static> Application for AppWithState<S> {
205    fn execute(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
206        if let Some(response) = self.router.handle(request, connection) {
207            return Ok(response);
208        }
209        App::new().execute(request, connection)
210    }
211}