Skip to main content

rustapi_core/router/
core.rs

1use super::conflict::{RouteConflictError, RouteInfo};
2use super::match_::{
3    convert_path_params, normalize_path_for_comparison, normalize_prefix, RouteMatch,
4};
5use super::method_router::MethodRouter;
6use crate::path_params::PathParams;
7use crate::typed_path::TypedPath;
8use http::{Extensions, Method};
9use matchit::Router as MatchitRouter;
10use std::collections::HashMap;
11use std::sync::Arc;
12
13/// Main router
14#[derive(Clone)]
15pub struct Router {
16    inner: MatchitRouter<MethodRouter>,
17    pub(super) state: Arc<Extensions>,
18    /// Track registered routes for conflict detection
19    registered_routes: HashMap<String, RouteInfo>,
20    /// Store MethodRouters for nesting support (keyed by matchit path)
21    method_routers: HashMap<String, MethodRouter>,
22    /// Track state type IDs for merging (type name -> whether it's set)
23    /// This is a workaround since Extensions doesn't support iteration
24    state_type_ids: Vec<std::any::TypeId>,
25}
26
27impl Router {
28    /// Create a new router
29    pub fn new() -> Self {
30        Self {
31            inner: MatchitRouter::new(),
32            state: Arc::new(Extensions::new()),
33            registered_routes: HashMap::new(),
34            method_routers: HashMap::new(),
35            state_type_ids: Vec::new(),
36        }
37    }
38
39    /// Add a typed route using a TypedPath
40    pub fn typed<P: TypedPath>(self, method_router: MethodRouter) -> Self {
41        self.route(P::PATH, method_router)
42    }
43
44    /// Add a route
45    pub fn route(mut self, path: &str, method_router: MethodRouter) -> Self {
46        // Convert {param} style to :param for matchit
47        let matchit_path = convert_path_params(path);
48
49        // Get the methods being registered
50        let methods: Vec<Method> = method_router.handlers.keys().cloned().collect();
51
52        // Store a clone of the MethodRouter for nesting support
53        self.method_routers
54            .insert(matchit_path.clone(), method_router.clone());
55
56        match self.inner.insert(matchit_path.clone(), method_router) {
57            Ok(_) => {
58                // Track the registered route
59                self.registered_routes.insert(
60                    matchit_path.clone(),
61                    RouteInfo {
62                        path: path.to_string(),
63                        methods,
64                    },
65                );
66            }
67            Err(e) => {
68                // Remove the method_router we just added since registration failed
69                self.method_routers.remove(&matchit_path);
70
71                // Find the existing conflicting route
72                let existing_path = self
73                    .find_conflicting_route(&matchit_path)
74                    .map(|info| info.path.clone())
75                    .unwrap_or_else(|| "<unknown>".to_string());
76
77                let conflict_error = RouteConflictError {
78                    new_path: path.to_string(),
79                    method: methods.first().cloned(),
80                    existing_path,
81                    details: e.to_string(),
82                };
83
84                panic!("{}", conflict_error);
85            }
86        }
87        self
88    }
89
90    /// Find a conflicting route by checking registered routes
91    fn find_conflicting_route(&self, matchit_path: &str) -> Option<&RouteInfo> {
92        // Try to find an exact match first
93        if let Some(info) = self.registered_routes.get(matchit_path) {
94            return Some(info);
95        }
96
97        // Try to find a route that would conflict (same structure but different param names)
98        let normalized_new = normalize_path_for_comparison(matchit_path);
99
100        for (registered_path, info) in &self.registered_routes {
101            let normalized_existing = normalize_path_for_comparison(registered_path);
102            if normalized_new == normalized_existing {
103                return Some(info);
104            }
105        }
106
107        None
108    }
109
110    /// Add application state
111    pub fn state<S: Clone + Send + Sync + 'static>(mut self, state: S) -> Self {
112        let type_id = std::any::TypeId::of::<S>();
113        let extensions = Arc::make_mut(&mut self.state);
114        extensions.insert(state);
115        if !self.state_type_ids.contains(&type_id) {
116            self.state_type_ids.push(type_id);
117        }
118        self
119    }
120
121    /// Check if state of a given type exists
122    pub fn has_state<S: 'static>(&self) -> bool {
123        self.state_type_ids.contains(&std::any::TypeId::of::<S>())
124    }
125
126    /// Get state type IDs (for testing and debugging)
127    pub fn state_type_ids(&self) -> &[std::any::TypeId] {
128        &self.state_type_ids
129    }
130
131    /// Nest another router under a prefix
132    ///
133    /// All routes from the nested router will be registered with the prefix
134    /// prepended to their paths. State from the nested router is merged into
135    /// the parent router (parent state takes precedence for type conflicts).
136    ///
137    /// # State Merging
138    ///
139    /// When nesting routers with state:
140    /// - If the parent router has state of type T, it is preserved (parent wins)
141    /// - If only the nested router has state of type T, it is added to the parent
142    /// - State type tracking is merged to enable proper conflict detection
143    ///
144    /// Note: Due to limitations of `http::Extensions`, automatic state merging
145    /// requires using the `merge_state` method for specific types.
146    ///
147    /// # Example
148    ///
149    /// ```rust,ignore
150    /// use rustapi_core::{Router, get};
151    ///
152    /// async fn list_users() -> &'static str { "List users" }
153    /// async fn get_user() -> &'static str { "Get user" }
154    ///
155    /// let users_router = Router::new()
156    ///     .route("/", get(list_users))
157    ///     .route("/{id}", get(get_user));
158    ///
159    /// let app = Router::new()
160    ///     .nest("/api/users", users_router);
161    ///
162    /// // Routes are now:
163    /// // GET /api/users/
164    /// // GET /api/users/{id}
165    /// ```
166    ///
167    /// # Nesting with State
168    ///
169    /// The `nest` method automatically tracks state types from the nested router to prevent
170    /// conflicts, but it does NOT automatically merge the state values instance by instance.
171    /// You should distinctively add state to the parent, or use `merge_state` if you want
172    /// to pull a specific state object from the child.
173    ///
174    /// ```rust,ignore
175    /// use rustapi_core::Router;
176    /// use std::sync::Arc;
177    ///
178    /// #[derive(Clone)]
179    /// struct Database { /* ... */ }
180    ///
181    /// let db = Database { /* ... */ };
182    ///
183    /// // Option 1: Add state to the parent (Recommended)
184    /// let api = Router::new()
185    ///     .nest("/v1", Router::new()
186    ///         .route("/users", get(list_users))) // Needs Database
187    ///     .state(db);
188    ///
189    /// // Option 2: Define specific state in sub-router and merge explicitly
190    /// let sub_router = Router::new()
191    ///     .state(Database { /* ... */ })
192    ///     .route("/items", get(list_items));
193    ///
194    /// let app = Router::new()
195    ///     .merge_state::<Database>(&sub_router) // Pulls Database from sub_router
196    ///     .nest("/api", sub_router);
197    /// ```
198    pub fn nest(mut self, prefix: &str, router: Router) -> Self {
199        // 1. Normalize the prefix
200        let normalized_prefix = normalize_prefix(prefix);
201
202        // 2. Merge state type IDs from nested router
203        // Parent state takes precedence - we only track types, actual values
204        // are handled by merge_state calls or by the user adding state to parent
205        for type_id in &router.state_type_ids {
206            if !self.state_type_ids.contains(type_id) {
207                self.state_type_ids.push(*type_id);
208            }
209        }
210
211        // 3. Collect routes from the nested router before consuming it
212        // We need to iterate over registered_routes and get the corresponding MethodRouters
213        let nested_routes: Vec<(String, RouteInfo, MethodRouter)> = router
214            .registered_routes
215            .into_iter()
216            .filter_map(|(matchit_path, route_info)| {
217                router
218                    .method_routers
219                    .get(&matchit_path)
220                    .map(|mr| (matchit_path, route_info, mr.clone()))
221            })
222            .collect();
223
224        // 4. Register each nested route with the prefix
225        for (matchit_path, route_info, method_router) in nested_routes {
226            // Build the prefixed path
227            // The matchit_path already has the :param format
228            // The route_info.path has the {param} format
229            let prefixed_matchit_path = if matchit_path == "/" {
230                normalized_prefix.clone()
231            } else {
232                format!("{}{}", normalized_prefix, matchit_path)
233            };
234
235            let prefixed_display_path = if route_info.path == "/" {
236                normalized_prefix.clone()
237            } else {
238                format!("{}{}", normalized_prefix, route_info.path)
239            };
240
241            // Store the MethodRouter for future nesting
242            self.method_routers
243                .insert(prefixed_matchit_path.clone(), method_router.clone());
244
245            // Try to insert into the matchit router
246            match self
247                .inner
248                .insert(prefixed_matchit_path.clone(), method_router)
249            {
250                Ok(_) => {
251                    // Track the registered route
252                    self.registered_routes.insert(
253                        prefixed_matchit_path,
254                        RouteInfo {
255                            path: prefixed_display_path,
256                            methods: route_info.methods,
257                        },
258                    );
259                }
260                Err(e) => {
261                    // Remove the method_router we just added since registration failed
262                    self.method_routers.remove(&prefixed_matchit_path);
263
264                    // Find the existing conflicting route
265                    let existing_path = self
266                        .find_conflicting_route(&prefixed_matchit_path)
267                        .map(|info| info.path.clone())
268                        .unwrap_or_else(|| "<unknown>".to_string());
269
270                    let conflict_error = RouteConflictError {
271                        new_path: prefixed_display_path,
272                        method: route_info.methods.first().cloned(),
273                        existing_path,
274                        details: e.to_string(),
275                    };
276
277                    panic!("{}", conflict_error);
278                }
279            }
280        }
281
282        self
283    }
284
285    /// Merge state from another router into this one
286    ///
287    /// This method allows explicit state merging when nesting routers.
288    /// Parent state takes precedence - if the parent already has state of type S,
289    /// the nested state is ignored.
290    ///
291    /// # Example
292    ///
293    /// ```rust,ignore
294    /// #[derive(Clone)]
295    /// struct DbPool(String);
296    ///
297    /// let nested = Router::new().state(DbPool("nested".to_string()));
298    /// let parent = Router::new()
299    ///     .merge_state::<DbPool>(&nested); // Adds DbPool from nested
300    /// ```
301    pub fn merge_state<S: Clone + Send + Sync + 'static>(mut self, other: &Router) -> Self {
302        let type_id = std::any::TypeId::of::<S>();
303
304        // Parent wins - only merge if parent doesn't have this state type
305        if !self.state_type_ids.contains(&type_id) {
306            // Try to get the state from the other router
307            if let Some(state) = other.state.get::<S>() {
308                let extensions = Arc::make_mut(&mut self.state);
309                extensions.insert(state.clone());
310                self.state_type_ids.push(type_id);
311            }
312        }
313
314        self
315    }
316
317    /// Match a request and return the handler + params
318    pub fn match_route(&self, path: &str, method: &Method) -> RouteMatch<'_> {
319        match self.inner.at(path) {
320            Ok(matched) => {
321                let method_router = matched.value;
322
323                if let Some(handler) = method_router.get_handler(method) {
324                    // Use stack-optimized PathParams (avoids heap allocation for ≤4 params)
325                    let params: PathParams = matched
326                        .params
327                        .iter()
328                        .map(|(k, v)| (k.to_string(), v.to_string()))
329                        .collect();
330
331                    RouteMatch::Found { handler, params }
332                } else {
333                    RouteMatch::MethodNotAllowed {
334                        allowed: method_router.allowed_methods(),
335                    }
336                }
337            }
338            Err(_) => RouteMatch::NotFound,
339        }
340    }
341
342    /// Get shared state
343    pub fn state_ref(&self) -> Arc<Extensions> {
344        self.state.clone()
345    }
346
347    /// Get registered routes (for testing and debugging)
348    pub fn registered_routes(&self) -> &HashMap<String, RouteInfo> {
349        &self.registered_routes
350    }
351
352    /// Get method routers (for OpenAPI integration during nesting)
353    pub fn method_routers(&self) -> &HashMap<String, MethodRouter> {
354        &self.method_routers
355    }
356}
357
358impl Default for Router {
359    fn default() -> Self {
360        Self::new()
361    }
362}