1use std::net::SocketAddr;
2use http::Method;
3use crate::{
4 Request, Response, Router, Handler,
5 middleware::{MiddlewareStack, Middleware},
6 error_pages::ErrorPages,
7 server::serve,
8 extractors::state::{StateMap, RequestStateExt},
9};
10
11pub struct App {
13 router: Router,
14 middleware: MiddlewareStack,
15 error_pages: ErrorPages,
16 state: StateMap,
17 #[cfg(feature = "api")]
18 pub(crate) api_docs: Option<crate::api::ApiDocBuilder>,
19 #[cfg(not(feature = "api"))]
20 _phantom: std::marker::PhantomData<()>,
21}
22
23impl App {
24 pub fn new() -> Self {
26 Self {
27 router: Router::new(),
28 middleware: MiddlewareStack::new(),
29 error_pages: ErrorPages::new(),
30 state: StateMap::new(),
31 #[cfg(feature = "api")]
32 api_docs: None,
33 #[cfg(not(feature = "api"))]
34 _phantom: std::marker::PhantomData,
35 }
36 }
37
38 pub fn with_state<T>(mut self, state: T) -> Self
40 where
41 T: Clone + Send + Sync + 'static,
42 {
43 self.state.insert(state);
44 self
45 }
46
47 pub fn middleware<M>(mut self, middleware: M) -> Self
49 where
50 M: Middleware,
51 {
52 self.middleware.add(middleware);
53 self
54 }
55
56 pub fn route<H, T>(mut self, method: Method, path: &str, handler: H) -> Self
58 where
59 H: Handler<T>,
60 {
61 let handler_fn = crate::handler::into_handler_fn(handler);
62 self.router.route(method, path, handler_fn);
63 self
64 }
65
66 pub fn get<H, T>(self, path: &str, handler: H) -> Self
68 where
69 H: Handler<T>,
70 {
71 self.route(Method::GET, path, handler)
72 }
73
74 pub fn post<H, T>(self, path: &str, handler: H) -> Self
76 where
77 H: Handler<T>,
78 {
79 self.route(Method::POST, path, handler)
80 }
81
82 pub fn put<H, T>(self, path: &str, handler: H) -> Self
84 where
85 H: Handler<T>,
86 {
87 self.route(Method::PUT, path, handler)
88 }
89
90 pub fn delete<H, T>(self, path: &str, handler: H) -> Self
92 where
93 H: Handler<T>,
94 {
95 self.route(Method::DELETE, path, handler)
96 }
97
98 pub fn patch<H, T>(self, path: &str, handler: H) -> Self
100 where
101 H: Handler<T>,
102 {
103 self.route(Method::PATCH, path, handler)
104 }
105
106 pub fn options<H, T>(self, path: &str, handler: H) -> Self
108 where
109 H: Handler<T>,
110 {
111 self.route(Method::OPTIONS, path, handler)
112 }
113
114 pub fn head<H, T>(self, path: &str, handler: H) -> Self
116 where
117 H: Handler<T>,
118 {
119 self.route(Method::HEAD, path, handler)
120 }
121
122 pub fn not_found<H, T>(mut self, handler: H) -> Self
124 where
125 H: Handler<T>,
126 {
127 let handler_fn = crate::handler::into_handler_fn(handler);
128 self.router.not_found(handler_fn);
129 self
130 }
131
132 pub fn mount(mut self, prefix: &str, other: Router) -> Self {
134 let prefix = prefix.trim_end_matches('/');
136
137 for (method, path, handler) in other.get_all_routes() {
139 let prefixed_path = if prefix.is_empty() {
140 path
141 } else {
142 format!("{}{}", prefix, path)
143 };
144
145 self.router.route(method, &prefixed_path, handler);
146 }
147
148 self
149 }
150
151 pub fn error_pages(mut self, error_pages: ErrorPages) -> Self {
153 self.error_pages = error_pages;
154 self
155 }
156
157 pub fn custom_404(mut self, html: String) -> Self {
159 self.error_pages = self.error_pages.custom_404(html);
160 self
161 }
162
163 pub fn custom_500(mut self, html: String) -> Self {
165 self.error_pages = self.error_pages.custom_500(html);
166 self
167 }
168
169 pub fn plain_error_pages(mut self) -> Self {
171 self.error_pages = self.error_pages.without_default_styling();
172 self
173 }
174
175 #[cfg(feature = "websocket")]
177 pub fn websocket<F, Fut>(self, path: &str, handler: F) -> Self
178 where
179 F: Fn(crate::websocket::WebSocketConnection) -> Fut + Send + Sync + 'static,
180 Fut: std::future::Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>> + Send + 'static,
181 {
182 let handler = std::sync::Arc::new(handler);
183 self.get::<_, (Request,)>(path, move |req: Request| {
184 let _handler = handler.clone();
185 async move {
186 if crate::websocket::is_websocket_upgrade_request(&req) {
188 crate::websocket::websocket_upgrade(req).await
189 } else {
190 Response::bad_request().body("WebSocket upgrade required")
191 }
192 }
193 })
194 }
195
196 #[cfg(not(feature = "websocket"))]
198 pub fn websocket<F, Fut>(self, _path: &str, _handler: F) -> Self
199 where
200 F: Fn() -> Fut + Send + Sync + 'static,
201 Fut: std::future::Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>> + Send + 'static,
202 {
203 self
205 }
206
207 pub async fn listen(self, addr: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
209 let addr: SocketAddr = addr.parse()?;
210 println!("🔥 Torch server starting on http://{}", addr);
211 serve(addr, self).await
212 }
213
214 pub(crate) async fn handle_request(&self, mut req: Request) -> Response {
216 req.set_state_map(self.state.clone());
218
219 let router = self.router.clone();
220 let error_pages = self.error_pages.clone();
221
222 let response = self.middleware
223 .execute(req, move |req| {
224 let router = router.clone();
225 Box::pin(async move { router.route_request(req).await })
226 })
227 .await;
228
229 let status_code = response.status_code().as_u16();
231 if status_code >= 400 && self.should_render_error_page(&response) {
232 let dummy_req = Request::new();
234 error_pages.render_error(status_code, None, &dummy_req)
235 } else {
236 response
237 }
238 }
239
240 fn should_render_error_page(&self, response: &Response) -> bool {
242 let content_type = response.headers().get("content-type")
245 .and_then(|v| v.to_str().ok())
246 .unwrap_or("");
247
248 !content_type.starts_with("text/html") &&
250 !content_type.starts_with("application/json") &&
251 response.body_data().len() < 100 }
253}
254
255impl Default for App {
256 fn default() -> Self {
257 Self::new()
258 }
259}
260
261pub fn app() -> App {
263 App::new()
264}
265
266impl App {
268 pub fn with_logging() -> Self {
270 Self::new().middleware(crate::middleware::logger())
271 }
272
273 pub fn with_cors() -> Self {
275 Self::new().middleware(crate::middleware::cors())
276 }
277
278 pub fn with_defaults() -> Self {
280 Self::new()
281 .middleware(crate::middleware::logger())
283 .middleware(crate::production::MetricsCollector::new())
284 .middleware(crate::production::PerformanceMonitor)
285
286 .middleware(crate::security::SecurityHeaders::new())
288 .middleware(crate::security::RequestId)
289 .middleware(crate::security::InputValidator)
290
291 .middleware(crate::middleware::cors())
293
294 .middleware(crate::production::RequestTimeout::new(std::time::Duration::from_secs(30)))
296 .middleware(crate::production::RequestSizeLimit::new(16 * 1024 * 1024)) .middleware(crate::production::health_check())
298 }
299
300 pub fn with_security() -> Self {
302 Self::new()
303 .middleware(crate::security::SecurityHeaders::new())
304 .middleware(crate::security::RequestId)
305 .middleware(crate::security::InputValidator)
306 }
307
308 pub fn with_monitoring() -> Self {
310 Self::new()
311 .middleware(crate::middleware::logger())
312 .middleware(crate::production::MetricsCollector::new())
313 .middleware(crate::production::PerformanceMonitor)
314 .middleware(crate::production::health_check())
315 }
316}
317
// NOTE(review): this module is gated behind a custom `disabled_for_now` cfg, so
// it is not compiled in normal builds; the gate is left as found.
#[cfg(disabled_for_now)]
mod tests {
    use super::*;
    use std::pin::Pin;
    use std::future::Future;
    use crate::Response;

    /// Boxed response future as produced by the middleware closure under test.
    type BoxedResponse = Pin<Box<dyn Future<Output = Response> + Send + 'static>>;
    /// The "next" callback handed to the middleware closure under test.
    type Next = Box<dyn Fn(Request) -> BoxedResponse + Send + Sync>;

    /// Builds a `GET /` request with an empty body.
    async fn root_get_request() -> Request {
        let (parts, _body) = http::Request::builder()
            .method("GET")
            .uri("/")
            .body(())
            .unwrap()
            .into_parts();

        Request::from_hyper(parts, Vec::new()).await.unwrap()
    }

    /// A registered `GET /` handler's body reaches the caller unchanged.
    #[tokio::test]
    async fn test_app_creation() {
        let app = App::new()
            .get("/", |_req: Request| async {
                Response::ok().body("Hello, World!")
            })
            .post("/users", |_req: Request| async {
                Response::ok().body("User created")
            });

        let response = app.handle_request(root_get_request().await).await;
        assert_eq!(response.body_data(), b"Hello, World!");
    }

    /// A closure middleware can decorate the response produced downstream.
    #[tokio::test]
    async fn test_app_with_middleware() {
        let app = App::new()
            .middleware(|req: Request, next: Next| -> BoxedResponse {
                Box::pin(async move { next(req).await.header("X-Test", "middleware") })
            })
            .get("/", |_req: Request| async {
                Response::ok().body("Hello")
            });

        let response = app.handle_request(root_get_request().await).await;
        assert_eq!(response.headers().get("X-Test").unwrap(), "middleware");
    }

    /// Routes and middleware can be chained in a single builder expression.
    #[test]
    fn test_app_builder_pattern() {
        let _app = App::new()
            .get("/", |_req: Request| async { Response::ok() })
            .post("/users", |_req: Request| async { Response::ok() })
            .middleware(crate::middleware::logger())
            .middleware(crate::middleware::cors());
    }

    /// Each preset constructor can be called.
    #[test]
    fn test_convenience_constructors() {
        let _app1 = App::with_logging();
        let _app2 = App::with_cors();
        let _app3 = App::with_security();
        let _app4 = App::with_defaults();
    }
}