1use std::net::SocketAddr;
2use http::Method;
3use crate::{
4 Request, Response, Router, Handler,
5 middleware::{MiddlewareStack, Middleware},
6 error_pages::ErrorPages,
7 server::serve,
8};
9
10pub struct App {
12 router: Router,
13 middleware: MiddlewareStack,
14 error_pages: ErrorPages,
15 #[cfg(feature = "api")]
16 pub(crate) api_docs: Option<crate::api::ApiDocBuilder>,
17 #[cfg(not(feature = "api"))]
18 _phantom: std::marker::PhantomData<()>,
19}
20
21impl App {
22 pub fn new() -> Self {
24 Self {
25 router: Router::new(),
26 middleware: MiddlewareStack::new(),
27 error_pages: ErrorPages::new(),
28 #[cfg(feature = "api")]
29 api_docs: None,
30 #[cfg(not(feature = "api"))]
31 _phantom: std::marker::PhantomData,
32 }
33 }
34
35 pub fn middleware<M>(mut self, middleware: M) -> Self
37 where
38 M: Middleware,
39 {
40 self.middleware.add(middleware);
41 self
42 }
43
44 pub fn route<H, T>(mut self, method: Method, path: &str, handler: H) -> Self
46 where
47 H: Handler<T>,
48 {
49 let handler_fn = crate::handler::into_handler_fn(handler);
50 self.router.route(method, path, handler_fn);
51 self
52 }
53
54 pub fn get<H, T>(self, path: &str, handler: H) -> Self
56 where
57 H: Handler<T>,
58 {
59 self.route(Method::GET, path, handler)
60 }
61
62 pub fn post<H, T>(self, path: &str, handler: H) -> Self
64 where
65 H: Handler<T>,
66 {
67 self.route(Method::POST, path, handler)
68 }
69
70 pub fn put<H, T>(self, path: &str, handler: H) -> Self
72 where
73 H: Handler<T>,
74 {
75 self.route(Method::PUT, path, handler)
76 }
77
78 pub fn delete<H, T>(self, path: &str, handler: H) -> Self
80 where
81 H: Handler<T>,
82 {
83 self.route(Method::DELETE, path, handler)
84 }
85
86 pub fn patch<H, T>(self, path: &str, handler: H) -> Self
88 where
89 H: Handler<T>,
90 {
91 self.route(Method::PATCH, path, handler)
92 }
93
94 pub fn options<H, T>(self, path: &str, handler: H) -> Self
96 where
97 H: Handler<T>,
98 {
99 self.route(Method::OPTIONS, path, handler)
100 }
101
102 pub fn head<H, T>(self, path: &str, handler: H) -> Self
104 where
105 H: Handler<T>,
106 {
107 self.route(Method::HEAD, path, handler)
108 }
109
110 pub fn not_found<H, T>(mut self, handler: H) -> Self
112 where
113 H: Handler<T>,
114 {
115 let handler_fn = crate::handler::into_handler_fn(handler);
116 self.router.not_found(handler_fn);
117 self
118 }
119
120 pub fn mount(mut self, prefix: &str, other: Router) -> Self {
122 let prefix = prefix.trim_end_matches('/');
124
125 for (method, path, handler) in other.get_all_routes() {
127 let prefixed_path = if prefix.is_empty() {
128 path
129 } else {
130 format!("{}{}", prefix, path)
131 };
132
133 self.router.route(method, &prefixed_path, handler);
134 }
135
136 self
137 }
138
139 pub fn error_pages(mut self, error_pages: ErrorPages) -> Self {
141 self.error_pages = error_pages;
142 self
143 }
144
145 pub fn custom_404(mut self, html: String) -> Self {
147 self.error_pages = self.error_pages.custom_404(html);
148 self
149 }
150
151 pub fn custom_500(mut self, html: String) -> Self {
153 self.error_pages = self.error_pages.custom_500(html);
154 self
155 }
156
157 pub fn plain_error_pages(mut self) -> Self {
159 self.error_pages = self.error_pages.without_default_styling();
160 self
161 }
162
163 #[cfg(feature = "websocket")]
165 pub fn websocket<F, Fut>(self, path: &str, handler: F) -> Self
166 where
167 F: Fn(crate::websocket::WebSocketConnection) -> Fut + Send + Sync + 'static,
168 Fut: std::future::Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>> + Send + 'static,
169 {
170 let handler = std::sync::Arc::new(handler);
171 self.get(path, move |req: Request| {
172 let _handler = handler.clone();
173 async move {
174 if crate::websocket::is_websocket_upgrade_request(&req) {
176 crate::websocket::websocket_upgrade(req).await
177 } else {
178 Response::bad_request().body("WebSocket upgrade required")
179 }
180 }
181 })
182 }
183
184 #[cfg(not(feature = "websocket"))]
186 pub fn websocket<F, Fut>(self, _path: &str, _handler: F) -> Self
187 where
188 F: Fn() -> Fut + Send + Sync + 'static,
189 Fut: std::future::Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>> + Send + 'static,
190 {
191 self
193 }
194
195 pub async fn listen(self, addr: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
197 let addr: SocketAddr = addr.parse()?;
198 println!("🔥 Torch server starting on http://{}", addr);
199 serve(addr, self).await
200 }
201
202 pub(crate) async fn handle_request(&self, req: Request) -> Response {
204 let router = self.router.clone();
205 let error_pages = self.error_pages.clone();
206
207 let response = self.middleware
208 .execute(req, move |req| {
209 let router = router.clone();
210 Box::pin(async move { router.route_request(req).await })
211 })
212 .await;
213
214 let status_code = response.status_code().as_u16();
216 if status_code >= 400 && self.should_render_error_page(&response) {
217 let dummy_req = Request::new();
219 error_pages.render_error(status_code, None, &dummy_req)
220 } else {
221 response
222 }
223 }
224
225 fn should_render_error_page(&self, response: &Response) -> bool {
227 let content_type = response.headers().get("content-type")
230 .and_then(|v| v.to_str().ok())
231 .unwrap_or("");
232
233 !content_type.starts_with("text/html") &&
235 !content_type.starts_with("application/json") &&
236 response.body_data().len() < 100 }
238}
239
240impl Default for App {
241 fn default() -> Self {
242 Self::new()
243 }
244}
245
246pub fn app() -> App {
248 App::new()
249}
250
251impl App {
253 pub fn with_logging() -> Self {
255 Self::new().middleware(crate::middleware::logger())
256 }
257
258 pub fn with_cors() -> Self {
260 Self::new().middleware(crate::middleware::cors())
261 }
262
263 pub fn with_defaults() -> Self {
265 Self::new()
266 .middleware(crate::middleware::logger())
268 .middleware(crate::production::MetricsCollector::new())
269 .middleware(crate::production::PerformanceMonitor)
270
271 .middleware(crate::security::SecurityHeaders::new())
273 .middleware(crate::security::RequestId)
274 .middleware(crate::security::InputValidator)
275
276 .middleware(crate::middleware::cors())
278
279 .middleware(crate::production::RequestTimeout::new(std::time::Duration::from_secs(30)))
281 .middleware(crate::production::RequestSizeLimit::new(16 * 1024 * 1024)) .middleware(crate::production::health_check())
283 }
284
285 pub fn with_security() -> Self {
287 Self::new()
288 .middleware(crate::security::SecurityHeaders::new())
289 .middleware(crate::security::RequestId)
290 .middleware(crate::security::InputValidator)
291 }
292
293 pub fn with_monitoring() -> Self {
295 Self::new()
296 .middleware(crate::middleware::logger())
297 .middleware(crate::production::MetricsCollector::new())
298 .middleware(crate::production::PerformanceMonitor)
299 .middleware(crate::production::health_check())
300 }
301}
302
// Unit tests for `App`.
//
// NOTE(review): this module is compiled only when `--cfg disabled_for_now` is
// passed to rustc (e.g. via RUSTFLAGS), so a normal `cargo test` skips it.
// Presumably it was switched off because it no longer builds against the
// current `Request`/middleware APIs; TODO confirm and re-enable as
// `#[cfg(test)]`.
#[cfg(disabled_for_now)]
mod tests {
    use super::*;
    use std::pin::Pin;
    use std::future::Future;
    use crate::Response;

    // Routing: a GET to "/" is dispatched to the handler registered for "/".
    #[tokio::test]
    async fn test_app_creation() {
        let app = App::new()
            .get("/", |_req: Request| async {
                Response::ok().body("Hello, World!")
            })
            .post("/users", |_req: Request| async {
                Response::ok().body("User created")
            });

        // Build a `Request` from the head (`Parts`) of an `http::Request`; the
        // empty `Vec` is presumably the body bytes. Verify against
        // `Request::from_hyper`.
        let req = Request::from_hyper(
            http::Request::builder()
                .method("GET")
                .uri("/")
                .body(())
                .unwrap()
                .into_parts()
                .0,
            Vec::new(),
        )
        .await
        .unwrap();

        let response = app.handle_request(req).await;
        assert_eq!(response.body_data(), b"Hello, World!");
    }

    // Middleware: a closure middleware calls `next` and then adds an
    // `X-Test` header to the downstream response.
    #[tokio::test]
    async fn test_app_with_middleware() {
        let app = App::new()
            .middleware(|req: Request, next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>| -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
                Box::pin(async move {
                    let mut response = next(req).await;
                    response = response.header("X-Test", "middleware");
                    response
                })
            })
            .get("/", |_req: Request| async {
                Response::ok().body("Hello")
            });

        let req = Request::from_hyper(
            http::Request::builder()
                .method("GET")
                .uri("/")
                .body(())
                .unwrap()
                .into_parts()
                .0,
            Vec::new(),
        )
        .await
        .unwrap();

        let response = app.handle_request(req).await;
        assert_eq!(response.headers().get("X-Test").unwrap(), "middleware");
    }

    // Builder chaining of routes and middleware type-checks and builds.
    #[test]
    fn test_app_builder_pattern() {
        let _app = App::new()
            .get("/", |_req: Request| async { Response::ok() })
            .post("/users", |_req: Request| async { Response::ok() })
            .middleware(crate::middleware::logger())
            .middleware(crate::middleware::cors());
    }

    // Every convenience constructor builds without panicking.
    #[test]
    fn test_convenience_constructors() {
        let _app1 = App::with_logging();
        let _app2 = App::with_cors();
        let _app3 = App::with_security();
        let _app4 = App::with_defaults();
    }
}