torch_web/
app.rs

1use std::net::SocketAddr;
2use http::Method;
3use crate::{
4    Request, Response, Router, Handler,
5    middleware::{MiddlewareStack, Middleware},
6    error_pages::ErrorPages,
7    server::serve,
8};
9
10/// Your app's starting point - where all the magic happens
11pub struct App {
12    router: Router,
13    middleware: MiddlewareStack,
14    error_pages: ErrorPages,
15    #[cfg(feature = "api")]
16    pub(crate) api_docs: Option<crate::api::ApiDocBuilder>,
17    #[cfg(not(feature = "api"))]
18    _phantom: std::marker::PhantomData<()>,
19}
20
21impl App {
22    /// Start with a clean slate
23    pub fn new() -> Self {
24        Self {
25            router: Router::new(),
26            middleware: MiddlewareStack::new(),
27            error_pages: ErrorPages::new(),
28            #[cfg(feature = "api")]
29            api_docs: None,
30            #[cfg(not(feature = "api"))]
31            _phantom: std::marker::PhantomData,
32        }
33    }
34
35    /// Stack up middleware for request processing
36    pub fn middleware<M>(mut self, middleware: M) -> Self
37    where
38        M: Middleware,
39    {
40        self.middleware.add(middleware);
41        self
42    }
43
44    /// Map any HTTP method to a path with your handler
45    pub fn route<H, T>(mut self, method: Method, path: &str, handler: H) -> Self
46    where
47        H: Handler<T>,
48    {
49        let handler_fn = crate::handler::into_handler_fn(handler);
50        self.router.route(method, path, handler_fn);
51        self
52    }
53
54    /// Handle GET requests - perfect for serving pages and fetching data
55    pub fn get<H, T>(self, path: &str, handler: H) -> Self
56    where
57        H: Handler<T>,
58    {
59        self.route(Method::GET, path, handler)
60    }
61
62    /// Handle POST requests - for creating new resources
63    pub fn post<H, T>(self, path: &str, handler: H) -> Self
64    where
65        H: Handler<T>,
66    {
67        self.route(Method::POST, path, handler)
68    }
69
70    /// Handle PUT requests - for updating entire resources
71    pub fn put<H, T>(self, path: &str, handler: H) -> Self
72    where
73        H: Handler<T>,
74    {
75        self.route(Method::PUT, path, handler)
76    }
77
78    /// Handle DELETE requests - for removing resources
79    pub fn delete<H, T>(self, path: &str, handler: H) -> Self
80    where
81        H: Handler<T>,
82    {
83        self.route(Method::DELETE, path, handler)
84    }
85
86    /// Handle PATCH requests - for partial updates
87    pub fn patch<H, T>(self, path: &str, handler: H) -> Self
88    where
89        H: Handler<T>,
90    {
91        self.route(Method::PATCH, path, handler)
92    }
93
94    /// Handle OPTIONS requests - usually for CORS preflight
95    pub fn options<H, T>(self, path: &str, handler: H) -> Self
96    where
97        H: Handler<T>,
98    {
99        self.route(Method::OPTIONS, path, handler)
100    }
101
102    /// Handle HEAD requests - like GET but without the body
103    pub fn head<H, T>(self, path: &str, handler: H) -> Self
104    where
105        H: Handler<T>,
106    {
107        self.route(Method::HEAD, path, handler)
108    }
109
110    /// Catch requests that don't match any route
111    pub fn not_found<H, T>(mut self, handler: H) -> Self
112    where
113        H: Handler<T>,
114    {
115        let handler_fn = crate::handler::into_handler_fn(handler);
116        self.router.not_found(handler_fn);
117        self
118    }
119
120    /// Mount another router at a specific path prefix
121    pub fn mount(mut self, prefix: &str, other: Router) -> Self {
122        // Merge routes from the other router with the prefix
123        let prefix = prefix.trim_end_matches('/');
124
125        // Get all routes from the other router and add them with prefix
126        for (method, path, handler) in other.get_all_routes() {
127            let prefixed_path = if prefix.is_empty() {
128                path
129            } else {
130                format!("{}{}", prefix, path)
131            };
132
133            self.router.route(method, &prefixed_path, handler);
134        }
135
136        self
137    }
138
139    /// Configure error pages
140    pub fn error_pages(mut self, error_pages: ErrorPages) -> Self {
141        self.error_pages = error_pages;
142        self
143    }
144
145    /// Set a custom 404 page
146    pub fn custom_404(mut self, html: String) -> Self {
147        self.error_pages = self.error_pages.custom_404(html);
148        self
149    }
150
151    /// Set a custom 500 page
152    pub fn custom_500(mut self, html: String) -> Self {
153        self.error_pages = self.error_pages.custom_500(html);
154        self
155    }
156
157    /// Disable default error page styling (use plain HTML)
158    pub fn plain_error_pages(mut self) -> Self {
159        self.error_pages = self.error_pages.without_default_styling();
160        self
161    }
162
163    /// Add a WebSocket endpoint
164    #[cfg(feature = "websocket")]
165    pub fn websocket<F, Fut>(self, path: &str, handler: F) -> Self
166    where
167        F: Fn(crate::websocket::WebSocketConnection) -> Fut + Send + Sync + 'static,
168        Fut: std::future::Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>> + Send + 'static,
169    {
170        let handler = std::sync::Arc::new(handler);
171        self.get(path, move |req: Request| {
172            let _handler = handler.clone();
173            async move {
174                // Check if this is a WebSocket upgrade request
175                if crate::websocket::is_websocket_upgrade_request(&req) {
176                    crate::websocket::websocket_upgrade(req).await
177                } else {
178                    Response::bad_request().body("WebSocket upgrade required")
179                }
180            }
181        })
182    }
183
184    /// Add a WebSocket endpoint (no-op when websocket feature is disabled)
185    #[cfg(not(feature = "websocket"))]
186    pub fn websocket<F, Fut>(self, _path: &str, _handler: F) -> Self
187    where
188        F: Fn() -> Fut + Send + Sync + 'static,
189        Fut: std::future::Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>> + Send + 'static,
190    {
191        // WebSocket feature not enabled, return self unchanged
192        self
193    }
194
195    /// Fire up the server and start handling requests
196    pub async fn listen(self, addr: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
197        let addr: SocketAddr = addr.parse()?;
198        println!("🔥 Torch server starting on http://{}", addr);
199        serve(addr, self).await
200    }
201
202    /// Process incoming requests through middleware and routing
203    pub(crate) async fn handle_request(&self, req: Request) -> Response {
204        let router = self.router.clone();
205        let error_pages = self.error_pages.clone();
206
207        let response = self.middleware
208            .execute(req, move |req| {
209                let router = router.clone();
210                Box::pin(async move { router.route_request(req).await })
211            })
212            .await;
213
214        // Check if this is an error response that should be rendered with error pages
215        let status_code = response.status_code().as_u16();
216        if status_code >= 400 && self.should_render_error_page(&response) {
217            // Create a simple request for error page rendering
218            let dummy_req = Request::new();
219            error_pages.render_error(status_code, None, &dummy_req)
220        } else {
221            response
222        }
223    }
224
225    /// Check if we should render an error page for this response
226    fn should_render_error_page(&self, response: &Response) -> bool {
227        // Only render error pages for responses that look like default error responses
228        // (i.e., they have simple text bodies, not custom HTML)
229        let content_type = response.headers().get("content-type")
230            .and_then(|v| v.to_str().ok())
231            .unwrap_or("");
232
233        // Don't override responses that are already HTML or have custom content types
234        !content_type.starts_with("text/html") &&
235        !content_type.starts_with("application/json") &&
236        response.body_data().len() < 100 // Simple heuristic for basic error messages
237    }
238}
239
240impl Default for App {
241    fn default() -> Self {
242        Self::new()
243    }
244}
245
246/// Quick way to create a new app
247pub fn app() -> App {
248    App::new()
249}
250
251/// Pre-configured apps for common scenarios
252impl App {
253    /// Create a new app with logging middleware
254    pub fn with_logging() -> Self {
255        Self::new().middleware(crate::middleware::logger())
256    }
257
258    /// Create a new app with CORS middleware
259    pub fn with_cors() -> Self {
260        Self::new().middleware(crate::middleware::cors())
261    }
262
263    /// Create a production-ready app with security, monitoring, and performance middleware
264    pub fn with_defaults() -> Self {
265        Self::new()
266            // Request logging and monitoring
267            .middleware(crate::middleware::logger())
268            .middleware(crate::production::MetricsCollector::new())
269            .middleware(crate::production::PerformanceMonitor)
270
271            // Security middleware
272            .middleware(crate::security::SecurityHeaders::new())
273            .middleware(crate::security::RequestId)
274            .middleware(crate::security::InputValidator)
275
276            // CORS support
277            .middleware(crate::middleware::cors())
278
279            // Production features
280            .middleware(crate::production::RequestTimeout::new(std::time::Duration::from_secs(30)))
281            .middleware(crate::production::RequestSizeLimit::new(16 * 1024 * 1024)) // 16MB
282            .middleware(crate::production::health_check())
283    }
284
285    /// Create an app with basic security features
286    pub fn with_security() -> Self {
287        Self::new()
288            .middleware(crate::security::SecurityHeaders::new())
289            .middleware(crate::security::RequestId)
290            .middleware(crate::security::InputValidator)
291    }
292
293    /// Create an app with monitoring and metrics
294    pub fn with_monitoring() -> Self {
295        Self::new()
296            .middleware(crate::middleware::logger())
297            .middleware(crate::production::MetricsCollector::new())
298            .middleware(crate::production::PerformanceMonitor)
299            .middleware(crate::production::health_check())
300    }
301}
302
#[cfg(disabled_for_now)]
mod tests {
    use super::*;
    use crate::Response;
    use std::future::Future;
    use std::pin::Pin;

    /// Boxed future returned by a middleware and by its `next` continuation.
    type BoxedResponse = Pin<Box<dyn Future<Output = Response> + Send + 'static>>;
    /// Continuation handed to a test middleware.
    type Next = Box<dyn Fn(Request) -> BoxedResponse + Send + Sync>;

    /// Builds a body-less framework `Request` with the given method and URI.
    async fn build_request(method: &str, uri: &str) -> Request {
        let (parts, ()) = http::Request::builder()
            .method(method)
            .uri(uri)
            .body(())
            .unwrap()
            .into_parts();
        Request::from_hyper(parts, Vec::new()).await.unwrap()
    }

    #[tokio::test]
    async fn test_app_creation() {
        let app = App::new()
            .get("/", |_req: Request| async { Response::ok().body("Hello, World!") })
            .post("/users", |_req: Request| async { Response::ok().body("User created") });

        // The GET route should answer with its body.
        let response = app.handle_request(build_request("GET", "/").await).await;
        assert_eq!(response.body_data(), b"Hello, World!");
    }

    #[tokio::test]
    async fn test_app_with_middleware() {
        let app = App::new()
            .middleware(|req: Request, next: Next| -> BoxedResponse {
                Box::pin(async move { next(req).await.header("X-Test", "middleware") })
            })
            .get("/", |_req: Request| async { Response::ok().body("Hello") });

        let response = app.handle_request(build_request("GET", "/").await).await;
        assert_eq!(response.headers().get("X-Test").unwrap(), "middleware");
    }

    #[test]
    fn test_app_builder_pattern() {
        let _app = App::new()
            .get("/", |_req: Request| async { Response::ok() })
            .post("/users", |_req: Request| async { Response::ok() })
            .middleware(crate::middleware::logger())
            .middleware(crate::middleware::cors());
    }

    #[test]
    fn test_convenience_constructors() {
        let _with_logging = App::with_logging();
        let _with_cors = App::with_cors();
        let _with_security = App::with_security();
        let _with_defaults = App::with_defaults();
    }
}