torch_web/
app.rs

1use std::net::SocketAddr;
2use http::Method;
3use crate::{
4    Request, Response, Router, Handler,
5    middleware::{MiddlewareStack, Middleware},
6    error_pages::ErrorPages,
7    server::serve,
8    extractors::state::{StateMap, RequestStateExt},
9};
10
11/// Your app's starting point - where all the magic happens
12pub struct App {
13    router: Router,
14    middleware: MiddlewareStack,
15    error_pages: ErrorPages,
16    state: StateMap,
17    #[cfg(feature = "api")]
18    pub(crate) api_docs: Option<crate::api::ApiDocBuilder>,
19    #[cfg(not(feature = "api"))]
20    _phantom: std::marker::PhantomData<()>,
21}
22
23impl App {
24    /// Start with a clean slate
25    pub fn new() -> Self {
26        Self {
27            router: Router::new(),
28            middleware: MiddlewareStack::new(),
29            error_pages: ErrorPages::new(),
30            state: StateMap::new(),
31            #[cfg(feature = "api")]
32            api_docs: None,
33            #[cfg(not(feature = "api"))]
34            _phantom: std::marker::PhantomData,
35        }
36    }
37
38    /// Add application state that can be accessed in handlers
39    pub fn with_state<T>(mut self, state: T) -> Self
40    where
41        T: Clone + Send + Sync + 'static,
42    {
43        self.state.insert(state);
44        self
45    }
46
47    /// Stack up middleware for request processing
48    pub fn middleware<M>(mut self, middleware: M) -> Self
49    where
50        M: Middleware,
51    {
52        self.middleware.add(middleware);
53        self
54    }
55
56    /// Map any HTTP method to a path with your handler
57    pub fn route<H, T>(mut self, method: Method, path: &str, handler: H) -> Self
58    where
59        H: Handler<T>,
60    {
61        let handler_fn = crate::handler::into_handler_fn(handler);
62        self.router.route(method, path, handler_fn);
63        self
64    }
65
66    /// Handle GET requests - perfect for serving pages and fetching data
67    pub fn get<H, T>(self, path: &str, handler: H) -> Self
68    where
69        H: Handler<T>,
70    {
71        self.route(Method::GET, path, handler)
72    }
73
74    /// Handle POST requests - for creating new resources
75    pub fn post<H, T>(self, path: &str, handler: H) -> Self
76    where
77        H: Handler<T>,
78    {
79        self.route(Method::POST, path, handler)
80    }
81
82    /// Handle PUT requests - for updating entire resources
83    pub fn put<H, T>(self, path: &str, handler: H) -> Self
84    where
85        H: Handler<T>,
86    {
87        self.route(Method::PUT, path, handler)
88    }
89
90    /// Handle DELETE requests - for removing resources
91    pub fn delete<H, T>(self, path: &str, handler: H) -> Self
92    where
93        H: Handler<T>,
94    {
95        self.route(Method::DELETE, path, handler)
96    }
97
98    /// Handle PATCH requests - for partial updates
99    pub fn patch<H, T>(self, path: &str, handler: H) -> Self
100    where
101        H: Handler<T>,
102    {
103        self.route(Method::PATCH, path, handler)
104    }
105
106    /// Handle OPTIONS requests - usually for CORS preflight
107    pub fn options<H, T>(self, path: &str, handler: H) -> Self
108    where
109        H: Handler<T>,
110    {
111        self.route(Method::OPTIONS, path, handler)
112    }
113
114    /// Handle HEAD requests - like GET but without the body
115    pub fn head<H, T>(self, path: &str, handler: H) -> Self
116    where
117        H: Handler<T>,
118    {
119        self.route(Method::HEAD, path, handler)
120    }
121
122    /// Catch requests that don't match any route
123    pub fn not_found<H, T>(mut self, handler: H) -> Self
124    where
125        H: Handler<T>,
126    {
127        let handler_fn = crate::handler::into_handler_fn(handler);
128        self.router.not_found(handler_fn);
129        self
130    }
131
132    /// Mount another router at a specific path prefix
133    pub fn mount(mut self, prefix: &str, other: Router) -> Self {
134        // Merge routes from the other router with the prefix
135        let prefix = prefix.trim_end_matches('/');
136
137        // Get all routes from the other router and add them with prefix
138        for (method, path, handler) in other.get_all_routes() {
139            let prefixed_path = if prefix.is_empty() {
140                path
141            } else {
142                format!("{}{}", prefix, path)
143            };
144
145            self.router.route(method, &prefixed_path, handler);
146        }
147
148        self
149    }
150
151    /// Configure error pages
152    pub fn error_pages(mut self, error_pages: ErrorPages) -> Self {
153        self.error_pages = error_pages;
154        self
155    }
156
157    /// Set a custom 404 page
158    pub fn custom_404(mut self, html: String) -> Self {
159        self.error_pages = self.error_pages.custom_404(html);
160        self
161    }
162
163    /// Set a custom 500 page
164    pub fn custom_500(mut self, html: String) -> Self {
165        self.error_pages = self.error_pages.custom_500(html);
166        self
167    }
168
169    /// Disable default error page styling (use plain HTML)
170    pub fn plain_error_pages(mut self) -> Self {
171        self.error_pages = self.error_pages.without_default_styling();
172        self
173    }
174
175    /// Add a WebSocket endpoint
176    #[cfg(feature = "websocket")]
177    pub fn websocket<F, Fut>(self, path: &str, handler: F) -> Self
178    where
179        F: Fn(crate::websocket::WebSocketConnection) -> Fut + Send + Sync + 'static,
180        Fut: std::future::Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>> + Send + 'static,
181    {
182        let handler = std::sync::Arc::new(handler);
183        self.get::<_, (Request,)>(path, move |req: Request| {
184            let _handler = handler.clone();
185            async move {
186                // Check if this is a WebSocket upgrade request
187                if crate::websocket::is_websocket_upgrade_request(&req) {
188                    crate::websocket::websocket_upgrade(req).await
189                } else {
190                    Response::bad_request().body("WebSocket upgrade required")
191                }
192            }
193        })
194    }
195
196    /// Add a WebSocket endpoint (no-op when websocket feature is disabled)
197    #[cfg(not(feature = "websocket"))]
198    pub fn websocket<F, Fut>(self, _path: &str, _handler: F) -> Self
199    where
200        F: Fn() -> Fut + Send + Sync + 'static,
201        Fut: std::future::Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>> + Send + 'static,
202    {
203        // WebSocket feature not enabled, return self unchanged
204        self
205    }
206
207    /// Fire up the server and start handling requests
208    pub async fn listen(self, addr: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
209        let addr: SocketAddr = addr.parse()?;
210        println!("🔥 Torch server starting on http://{}", addr);
211        serve(addr, self).await
212    }
213
214    /// Process incoming requests through middleware and routing
215    pub(crate) async fn handle_request(&self, mut req: Request) -> Response {
216        // Inject application state into the request
217        req.set_state_map(self.state.clone());
218
219        let router = self.router.clone();
220        let error_pages = self.error_pages.clone();
221
222        let response = self.middleware
223            .execute(req, move |req| {
224                let router = router.clone();
225                Box::pin(async move { router.route_request(req).await })
226            })
227            .await;
228
229        // Check if this is an error response that should be rendered with error pages
230        let status_code = response.status_code().as_u16();
231        if status_code >= 400 && self.should_render_error_page(&response) {
232            // Create a simple request for error page rendering
233            let dummy_req = Request::new();
234            error_pages.render_error(status_code, None, &dummy_req)
235        } else {
236            response
237        }
238    }
239
240    /// Check if we should render an error page for this response
241    fn should_render_error_page(&self, response: &Response) -> bool {
242        // Only render error pages for responses that look like default error responses
243        // (i.e., they have simple text bodies, not custom HTML)
244        let content_type = response.headers().get("content-type")
245            .and_then(|v| v.to_str().ok())
246            .unwrap_or("");
247
248        // Don't override responses that are already HTML or have custom content types
249        !content_type.starts_with("text/html") &&
250        !content_type.starts_with("application/json") &&
251        response.body_data().len() < 100 // Simple heuristic for basic error messages
252    }
253}
254
255impl Default for App {
256    fn default() -> Self {
257        Self::new()
258    }
259}
260
261/// Quick way to create a new app
262pub fn app() -> App {
263    App::new()
264}
265
266/// Pre-configured apps for common scenarios
267impl App {
268    /// Create a new app with logging middleware
269    pub fn with_logging() -> Self {
270        Self::new().middleware(crate::middleware::logger())
271    }
272
273    /// Create a new app with CORS middleware
274    pub fn with_cors() -> Self {
275        Self::new().middleware(crate::middleware::cors())
276    }
277
278    /// Create a production-ready app with security, monitoring, and performance middleware
279    pub fn with_defaults() -> Self {
280        Self::new()
281            // Request logging and monitoring
282            .middleware(crate::middleware::logger())
283            .middleware(crate::production::MetricsCollector::new())
284            .middleware(crate::production::PerformanceMonitor)
285
286            // Security middleware
287            .middleware(crate::security::SecurityHeaders::new())
288            .middleware(crate::security::RequestId)
289            .middleware(crate::security::InputValidator)
290
291            // CORS support
292            .middleware(crate::middleware::cors())
293
294            // Production features
295            .middleware(crate::production::RequestTimeout::new(std::time::Duration::from_secs(30)))
296            .middleware(crate::production::RequestSizeLimit::new(16 * 1024 * 1024)) // 16MB
297            .middleware(crate::production::health_check())
298    }
299
300    /// Create an app with basic security features
301    pub fn with_security() -> Self {
302        Self::new()
303            .middleware(crate::security::SecurityHeaders::new())
304            .middleware(crate::security::RequestId)
305            .middleware(crate::security::InputValidator)
306    }
307
308    /// Create an app with monitoring and metrics
309    pub fn with_monitoring() -> Self {
310        Self::new()
311            .middleware(crate::middleware::logger())
312            .middleware(crate::production::MetricsCollector::new())
313            .middleware(crate::production::PerformanceMonitor)
314            .middleware(crate::production::health_check())
315    }
316}
317
#[cfg(disabled_for_now)]
mod tests {
    use super::*;
    use crate::Response;
    use std::future::Future;
    use std::pin::Pin;

    /// Boxed response future returned by the hand-written middleware below.
    type BoxedResponse = Pin<Box<dyn Future<Output = Response> + Send + 'static>>;

    /// Build a body-less torch `Request` for `method` and `uri`.
    async fn make_request(method: &str, uri: &str) -> Request {
        let (parts, _body) = http::Request::builder()
            .method(method)
            .uri(uri)
            .body(())
            .unwrap()
            .into_parts();
        Request::from_hyper(parts, Vec::new()).await.unwrap()
    }

    #[tokio::test]
    async fn test_app_creation() {
        let app = App::new()
            .get("/", |_req: Request| async { Response::ok().body("Hello, World!") })
            .post("/users", |_req: Request| async { Response::ok().body("User created") });

        let response = app.handle_request(make_request("GET", "/").await).await;
        assert_eq!(response.body_data(), b"Hello, World!");
    }

    #[tokio::test]
    async fn test_app_with_middleware() {
        let tagging_middleware = |req: Request,
                                  next: Box<dyn Fn(Request) -> BoxedResponse + Send + Sync>|
         -> BoxedResponse {
            Box::pin(async move { next(req).await.header("X-Test", "middleware") })
        };

        let app = App::new()
            .middleware(tagging_middleware)
            .get("/", |_req: Request| async { Response::ok().body("Hello") });

        let response = app.handle_request(make_request("GET", "/").await).await;
        assert_eq!(response.headers().get("X-Test").unwrap(), "middleware");
    }

    #[test]
    fn test_app_builder_pattern() {
        let _app = App::new()
            .get("/", |_req: Request| async { Response::ok() })
            .post("/users", |_req: Request| async { Response::ok() })
            .middleware(crate::middleware::logger())
            .middleware(crate::middleware::cors());
    }

    #[test]
    fn test_convenience_constructors() {
        let _with_logging = App::with_logging();
        let _with_cors = App::with_cors();
        let _with_security = App::with_security();
        let _with_defaults = App::with_defaults();
    }
}