wsi_streamer/server/
routes.rs

1//! Router configuration for WSI Streamer.
2//!
//! This module defines the HTTP routes and applies middleware for signed-URL
//! authentication, CORS, and (optionally) request tracing.
5//!
6//! # Route Structure
7//!
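//! ```text
//! /health                                    - Health check (public)
//! /view/{slide_id}                           - HTML slide viewer (public)
//! /tiles/{slide_id}/{level}/{x}/{y}[.jpg]    - Tile endpoint (protected)
//! /slides                                    - List slides (protected)
//! /slides/{slide_id}                         - Slide metadata (protected)
//! /slides/{slide_id}/dzi                     - DZI descriptor (protected)
//! /slides/{slide_id}/thumbnail               - Slide thumbnail (protected)
//! ```
//!
//! "Protected" routes require a signed URL only when authentication is enabled;
//! with authentication disabled every route is public.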
13//!
14//! # Example
15//!
16//! ```ignore
17//! use wsi_streamer::server::routes::{create_router, RouterConfig};
18//! use wsi_streamer::tile::TileService;
19//! use wsi_streamer::slide::SlideRegistry;
20//!
21//! // Create the tile service
22//! let registry = SlideRegistry::new(source);
23//! let tile_service = TileService::new(registry);
24//!
25//! // Configure and create router
26//! let config = RouterConfig::new("my-secret-key")
27//!     .with_cors_origins(vec!["https://example.com".to_string()]);
28//!
29//! let router = create_router(tile_service, config);
30//!
31//! // Run the server
32//! let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
33//! axum::serve(listener, router).await?;
34//! ```
35
36use std::time::Duration;
37
38use axum::{middleware, routing::get, Router};
39use http::header::{AUTHORIZATION, CONTENT_TYPE};
40use http::Method;
41use tower_http::cors::{Any, CorsLayer};
42use tower_http::trace::TraceLayer;
43
44use super::auth::SignedUrlAuth;
45use super::handlers::{
46    dzi_descriptor_handler, health_handler, slide_metadata_handler, slides_handler,
47    thumbnail_handler, tile_handler, viewer_handler, AppState,
48};
49use crate::slide::SlideSource;
50use crate::tile::TileService;
51
52// =============================================================================
53// Router Configuration
54// =============================================================================
55
/// Configuration for the HTTP router.
///
/// Start from [`RouterConfig::new`] (authentication on) or
/// [`RouterConfig::without_auth`] (authentication off) and adjust with the
/// `with_*` builder methods.
#[derive(Clone)]
pub struct RouterConfig {
    /// Secret key for signed URL authentication
    pub auth_secret: String,

    /// Whether authentication is enabled for the `/tiles` and `/slides` routes
    pub auth_enabled: bool,

    /// Allowed CORS origins (`None` = allow any origin, empty vec = allow none)
    pub cors_origins: Option<Vec<String>>,

    /// Cache-Control max-age in seconds
    pub cache_max_age: u32,

    /// Whether to enable request tracing
    pub enable_tracing: bool,
}

// Hand-written instead of derived so the auth secret never leaks into logs or
// panic messages. Only whether a secret is set is shown.
impl std::fmt::Debug for RouterConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let secret = if self.auth_secret.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        f.debug_struct("RouterConfig")
            .field("auth_secret", &secret)
            .field("auth_enabled", &self.auth_enabled)
            .field("cors_origins", &self.cors_origins)
            .field("cache_max_age", &self.cache_max_age)
            .field("enable_tracing", &self.enable_tracing)
            .finish()
    }
}

impl RouterConfig {
    /// Create a new router configuration with the given auth secret.
    ///
    /// By default:
    /// - Authentication is enabled
    /// - CORS allows any origin
    /// - Cache max-age is 1 hour (3600 seconds)
    /// - Tracing is enabled
    #[must_use]
    pub fn new(auth_secret: impl Into<String>) -> Self {
        Self {
            auth_secret: auth_secret.into(),
            auth_enabled: true,
            cors_origins: None, // Allow any origin by default
            cache_max_age: 3600,
            enable_tracing: true,
        }
    }

    /// Create a configuration with authentication disabled and an empty secret.
    ///
    /// All other settings match [`RouterConfig::new`].
    ///
    /// **Warning**: This should only be used for development/testing. Calling
    /// `with_auth_enabled(true)` on the result re-enables auth with an empty
    /// secret.
    #[must_use]
    pub fn without_auth() -> Self {
        // Derive from `new` so the shared defaults live in one place.
        Self {
            auth_enabled: false,
            ..Self::new(String::new())
        }
    }

    /// Restrict CORS to the given origins.
    ///
    /// Pass an empty vec to disallow all cross-origin requests. To allow any
    /// origin (the default when this method is never called), use
    /// [`RouterConfig::with_cors_any_origin`].
    #[must_use]
    pub fn with_cors_origins(mut self, origins: Vec<String>) -> Self {
        self.cors_origins = Some(origins);
        self
    }

    /// Allow any CORS origin.
    #[must_use]
    pub fn with_cors_any_origin(mut self) -> Self {
        self.cors_origins = None;
        self
    }

    /// Set the Cache-Control max-age in seconds.
    #[must_use]
    pub fn with_cache_max_age(mut self, seconds: u32) -> Self {
        self.cache_max_age = seconds;
        self
    }

    /// Enable or disable authentication.
    #[must_use]
    pub fn with_auth_enabled(mut self, enabled: bool) -> Self {
        self.auth_enabled = enabled;
        self
    }

    /// Enable or disable request tracing.
    #[must_use]
    pub fn with_tracing(mut self, enabled: bool) -> Self {
        self.enable_tracing = enabled;
        self
    }
}
139
140// =============================================================================
141// Router Builder
142// =============================================================================
143
144/// Create the main application router.
145///
146/// This function builds the complete Axum router with:
147/// - Public routes (health check)
148/// - Protected routes (tile API with optional auth)
149/// - CORS configuration
150/// - Request tracing (optional)
151///
152/// # Arguments
153///
154/// * `tile_service` - The tile service for handling tile requests
155/// * `config` - Router configuration
156///
157/// # Returns
158///
159/// A configured Axum router ready to be served.
160pub fn create_router<S>(tile_service: TileService<S>, config: RouterConfig) -> Router
161where
162    S: SlideSource + 'static,
163{
164    // Create application state with auth info for viewer token generation
165    let app_state = if config.auth_enabled {
166        let auth = SignedUrlAuth::new(&config.auth_secret);
167        AppState::with_cache_max_age(tile_service, config.cache_max_age).with_auth(auth.clone())
168    } else {
169        AppState::with_cache_max_age(tile_service, config.cache_max_age)
170    };
171
172    // Create the auth layer if enabled
173    let auth = SignedUrlAuth::new(&config.auth_secret);
174
175    // Build CORS layer
176    let cors = build_cors_layer(&config);
177
178    // Build the router
179    let router = if config.auth_enabled {
180        build_protected_router(app_state, auth, cors)
181    } else {
182        build_public_router(app_state, cors)
183    };
184
185    // Add tracing if enabled
186    if config.enable_tracing {
187        router.layer(TraceLayer::new_for_http())
188    } else {
189        router
190    }
191}
192
193/// Build router with authentication on tile and slides routes.
194fn build_protected_router<S>(app_state: AppState<S>, auth: SignedUrlAuth, cors: CorsLayer) -> Router
195where
196    S: SlideSource + 'static,
197{
198    // Protected tile routes (require authentication)
199    // Uses {filename} to capture both "{y}" and "{y}.jpg" formats
200    // Auth middleware is applied to the nested router AFTER nesting so it sees the full /tiles/... path
201    let tile_routes = Router::new()
202        .route("/{slide_id}/{level}/{x}/{filename}", get(tile_handler::<S>))
203        .with_state(app_state.clone());
204
205    // Protected slides routes (require authentication)
206    let slides_routes = Router::new()
207        .route("/", get(slides_handler::<S>))
208        .route("/{slide_id}", get(slide_metadata_handler::<S>))
209        .route("/{slide_id}/dzi", get(dzi_descriptor_handler::<S>))
210        .route("/{slide_id}/thumbnail", get(thumbnail_handler::<S>))
211        .with_state(app_state.clone());
212
213    // Create nested routes with auth applied AFTER nesting
214    let protected_routes = Router::new()
215        .nest("/tiles", tile_routes)
216        .nest("/slides", slides_routes)
217        .layer(middleware::from_fn_with_state(
218            auth,
219            super::auth::auth_middleware,
220        ));
221
222    // Public routes (no auth required)
223    // The viewer is public because it's just HTML - tile requests are still protected
224    let public_routes = Router::new()
225        .route("/health", get(health_handler))
226        .route("/view/{slide_id}", get(viewer_handler::<S>))
227        .with_state(app_state);
228
229    // Combine routes
230    Router::new()
231        .merge(protected_routes)
232        .merge(public_routes)
233        .layer(cors)
234}
235
236/// Build router without authentication (for development/testing).
237fn build_public_router<S>(app_state: AppState<S>, cors: CorsLayer) -> Router
238where
239    S: SlideSource + 'static,
240{
241    // All routes are public
242    // Uses {filename} to capture both "{y}" and "{y}.jpg" formats
243    Router::new()
244        .route("/health", get(health_handler))
245        .route(
246            "/tiles/{slide_id}/{level}/{x}/{filename}",
247            get(tile_handler::<S>),
248        )
249        .route("/slides", get(slides_handler::<S>))
250        .route("/slides/{slide_id}", get(slide_metadata_handler::<S>))
251        .route("/slides/{slide_id}/dzi", get(dzi_descriptor_handler::<S>))
252        .route("/slides/{slide_id}/thumbnail", get(thumbnail_handler::<S>))
253        .route("/view/{slide_id}", get(viewer_handler::<S>))
254        .with_state(app_state)
255        .layer(cors)
256}
257
258/// Build the CORS layer based on configuration.
259fn build_cors_layer(config: &RouterConfig) -> CorsLayer {
260    let cors = CorsLayer::new()
261        .allow_methods([Method::GET, Method::HEAD, Method::OPTIONS])
262        .allow_headers([AUTHORIZATION, CONTENT_TYPE])
263        .max_age(Duration::from_secs(86400)); // 24 hours
264
265    match &config.cors_origins {
266        None => cors.allow_origin(Any),
267        Some(origins) if origins.is_empty() => {
268            // No origins allowed - this effectively disables CORS
269            cors
270        }
271        Some(origins) => {
272            // Parse origins into HeaderValues
273            let parsed_origins: Vec<_> = origins.iter().filter_map(|o| o.parse().ok()).collect();
274            cors.allow_origin(parsed_origins)
275        }
276    }
277}
278
279// =============================================================================
280// Convenience Functions
281// =============================================================================
282
283/// Create a development router with authentication disabled.
284///
285/// **Warning**: This should only be used for local development and testing.
286/// Never use this in production.
287pub fn create_dev_router<S>(tile_service: TileService<S>) -> Router
288where
289    S: SlideSource + 'static,
290{
291    create_router(tile_service, RouterConfig::without_auth())
292}
293
294/// Create a production router with the given secret key.
295///
296/// Uses secure defaults:
297/// - Authentication enabled
298/// - 1 hour cache max-age
299/// - Tracing enabled
300/// - CORS allows any origin (configure as needed)
301pub fn create_production_router<S>(
302    tile_service: TileService<S>,
303    auth_secret: impl Into<String>,
304) -> Router
305where
306    S: SlideSource + 'static,
307{
308    create_router(tile_service, RouterConfig::new(auth_secret))
309}
310
311// =============================================================================
312// Tests
313// =============================================================================
314
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_router_config_defaults() {
        let config = RouterConfig::new("secret");
        assert_eq!(config.auth_secret, "secret");
        assert!(config.auth_enabled);
        assert!(config.cors_origins.is_none());
        assert_eq!(config.cache_max_age, 3600);
        assert!(config.enable_tracing);
    }

    #[test]
    fn test_router_config_without_auth() {
        let config = RouterConfig::without_auth();
        assert!(!config.auth_enabled);
        assert!(config.auth_secret.is_empty());

        // Apart from auth, `without_auth` must keep the same defaults as `new`.
        let reference = RouterConfig::new("unused");
        assert_eq!(config.cors_origins, reference.cors_origins);
        assert_eq!(config.cache_max_age, reference.cache_max_age);
        assert_eq!(config.enable_tracing, reference.enable_tracing);
    }

    #[test]
    fn test_router_config_builder() {
        let config = RouterConfig::new("secret")
            .with_cors_origins(vec!["https://example.com".to_string()])
            .with_cache_max_age(7200)
            .with_auth_enabled(false)
            .with_tracing(false);

        assert_eq!(config.auth_secret, "secret");
        assert!(!config.auth_enabled);
        assert_eq!(
            config.cors_origins,
            Some(vec!["https://example.com".to_string()])
        );
        assert_eq!(config.cache_max_age, 7200);
        assert!(!config.enable_tracing);
    }

    #[test]
    fn test_router_config_cors_any() {
        let config = RouterConfig::new("secret")
            .with_cors_origins(vec!["https://example.com".to_string()])
            .with_cors_any_origin();

        assert!(config.cors_origins.is_none());
    }

    #[test]
    fn test_router_config_empty_cors_origins_is_some() {
        // An explicit empty list ("allow none") must stay distinct from `None`
        // ("allow any").
        let config = RouterConfig::new("secret").with_cors_origins(vec![]);
        assert_eq!(config.cors_origins, Some(Vec::new()));
    }

    // The build_cors_layer tests below only check that construction does not
    // panic for each configuration shape.

    #[test]
    fn test_build_cors_layer_any_origin() {
        let config = RouterConfig::new("secret");
        let _cors = build_cors_layer(&config);
    }

    #[test]
    fn test_build_cors_layer_specific_origins() {
        let config = RouterConfig::new("secret").with_cors_origins(vec![
            "https://example.com".to_string(),
            "https://other.com".to_string(),
        ]);
        let _cors = build_cors_layer(&config);
    }

    #[test]
    fn test_build_cors_layer_empty_origins() {
        let config = RouterConfig::new("secret").with_cors_origins(vec![]);
        let _cors = build_cors_layer(&config);
    }

    #[test]
    fn test_build_cors_layer_skips_unparsable_origin() {
        // A control character makes the entry an invalid header value. It must
        // be dropped rather than abort router construction.
        let config = RouterConfig::new("secret").with_cors_origins(vec![
            "https://example.com".to_string(),
            "bad\norigin".to_string(),
        ]);
        let _cors = build_cors_layer(&config);
    }
}