Skip to main content

rustauth_core/api/
router.rs

1use http::StatusCode;
2use serde_json::Value;
3
4use crate::context::request_state::{
5    run_with_request_state, set_current_request_path, set_request_external,
6};
7use crate::context::AuthContext;
8use crate::error::RustAuthError;
9use crate::plugin::{PluginBeforeHookAction, PluginRequestAction};
10use crate::rate_limit::{consume_rate_limit, on_request_rate_limit};
11use crate::utils::url::normalize_pathname;
12
13use super::endpoint::{
14    run_endpoint_middlewares, validate_async_endpoint_request, ApiRequest, ApiResponse,
15    AsyncAuthEndpoint, AuthEndpoint, EndpointInfo, EndpointKind,
16};
17use super::error::{api_error, rate_limit_response, response, ApiErrorCode};
18use super::on_api_error::handle_on_api_error;
19use super::openapi::build_openapi_schema;
20use super::path::{match_path_pattern, route_pathname, PathParams};
21use super::plugin_pipeline::{
22    endpoint_operation_id, finalize_response, finalize_response_async, plugin_async_endpoints,
23    run_after_hooks, run_async_after_hooks, run_async_before_hooks, run_before_hooks,
24    run_matching_async_middlewares, run_matching_middlewares, run_on_request_plugins,
25    validate_endpoint_conflicts,
26};
27use super::security::validate_request_security;
28
29#[derive(Clone)]
30pub struct AuthRouter {
31    context: AuthContext,
32    endpoints: Vec<AuthEndpoint>,
33    async_endpoints: Vec<AsyncAuthEndpoint>,
34}
35
36impl AuthRouter {
37    pub fn new(context: AuthContext, endpoints: Vec<AuthEndpoint>) -> Self {
38        let async_endpoints = plugin_async_endpoints(&context, Vec::new());
39        Self {
40            context,
41            endpoints,
42            async_endpoints,
43        }
44    }
45
46    pub fn try_new(
47        context: AuthContext,
48        endpoints: Vec<AuthEndpoint>,
49    ) -> Result<Self, RustAuthError> {
50        let async_endpoints = plugin_async_endpoints(&context, Vec::new());
51        validate_endpoint_conflicts(&endpoints, &async_endpoints)?;
52        Ok(Self {
53            context,
54            endpoints,
55            async_endpoints,
56        })
57    }
58
59    pub fn with_async_endpoints(
60        context: AuthContext,
61        endpoints: Vec<AuthEndpoint>,
62        async_endpoints: Vec<AsyncAuthEndpoint>,
63    ) -> Result<Self, RustAuthError> {
64        let async_endpoints = plugin_async_endpoints(&context, async_endpoints);
65        validate_endpoint_conflicts(&endpoints, &async_endpoints)?;
66        Ok(Self {
67            context,
68            endpoints,
69            async_endpoints,
70        })
71    }
72
73    pub fn endpoint_registry(&self) -> Vec<EndpointInfo> {
74        let sync_endpoints = self.endpoints.iter().map(|endpoint| EndpointInfo {
75            path: endpoint.path.clone(),
76            method: endpoint.method.clone(),
77            kind: EndpointKind::Sync,
78            operation_id: None,
79            allowed_media_types: Vec::new(),
80        });
81        let async_endpoints = self
82            .async_endpoints
83            .iter()
84            .filter(|endpoint| !endpoint.options.server_only)
85            .map(|endpoint| EndpointInfo {
86                path: endpoint.path.clone(),
87                method: endpoint.method.clone(),
88                kind: EndpointKind::Async,
89                operation_id: endpoint
90                    .options
91                    .operation_id
92                    .clone()
93                    .or_else(|| endpoint.options.openapi.as_ref()?.operation_id.clone()),
94                allowed_media_types: endpoint.options.allowed_media_types.clone(),
95            });
96        sync_endpoints.chain(async_endpoints).collect()
97    }
98
99    pub fn openapi_schema(&self) -> Value {
100        build_openapi_schema(&self.context, &self.async_endpoints)
101    }
102
103    pub fn handle(&self, request: ApiRequest) -> Result<ApiResponse, RustAuthError> {
104        let request_for_error = request.clone();
105        match self.handle_inner(request) {
106            Ok(response) => Ok(response),
107            Err(error) => handle_on_api_error(&self.context, &request_for_error, error),
108        }
109    }
110
111    fn handle_inner(&self, mut request: ApiRequest) -> Result<ApiResponse, RustAuthError> {
112        let normalized_path =
113            normalize_pathname(&request.uri().to_string(), &self.context.base_path);
114        if self
115            .context
116            .disabled_paths
117            .iter()
118            .any(|item| item == &normalized_path)
119        {
120            return finalize_response(
121                &self.context,
122                &request,
123                api_error(StatusCode::NOT_FOUND, ApiErrorCode::NotFound)?,
124            );
125        }
126        let finalize_request = request.clone();
127        request = match run_on_request_plugins(&self.context, request)? {
128            PluginRequestAction::Continue(request) => request,
129            PluginRequestAction::Respond(response) => {
130                return finalize_response(&self.context, &finalize_request, response);
131            }
132        };
133        if let Some(rejection) = validate_request_security(&self.context, &request, false)? {
134            return finalize_response(&self.context, &request, rejection);
135        }
136        let path = route_pathname(
137            &request.uri().to_string(),
138            &self.context.base_path,
139            self.context.options.advanced.skip_trailing_slashes,
140        );
141        let Some((endpoint, params)) = self.endpoints.iter().find_map(|endpoint| {
142            (endpoint.method == *request.method())
143                .then(|| match_path_pattern(&endpoint.path, &path).map(|params| (endpoint, params)))
144                .flatten()
145        }) else {
146            if self.async_endpoints.iter().any(|endpoint| {
147                endpoint.method == *request.method()
148                    && !endpoint.options.server_only
149                    && match_path_pattern(&endpoint.path, &path).is_some()
150            }) {
151                return Err(RustAuthError::Api(
152                    "async endpoint requires AuthRouter::handle_async".to_owned(),
153                ));
154            }
155            return finalize_response(
156                &self.context,
157                &request,
158                api_error(StatusCode::NOT_FOUND, ApiErrorCode::NotFound)?,
159            );
160        };
161        request.extensions_mut().insert(PathParams::new(params));
162        if let Some(response) = run_matching_middlewares(&self.context, &request, &path)? {
163            return finalize_response(&self.context, &request, response);
164        }
165        if let Some(rejection) = on_request_rate_limit(&self.context, &request)? {
166            return finalize_response(&self.context, &request, rate_limit_response(rejection)?);
167        }
168        let finalize_request = request.clone();
169        request = match run_before_hooks(&self.context, request, &endpoint.method, &path, None)? {
170            PluginBeforeHookAction::Continue(request) => request,
171            PluginBeforeHookAction::Respond(response) => {
172                return finalize_response(&self.context, &finalize_request, response);
173            }
174        };
175        let response = (endpoint.handler)(&self.context, request.clone())?;
176        let response = run_after_hooks(
177            &self.context,
178            &request,
179            response,
180            &endpoint.method,
181            &path,
182            None,
183        )?;
184        finalize_response(&self.context, &request, response)
185    }
186
187    pub async fn handle_async(&self, request: ApiRequest) -> Result<ApiResponse, RustAuthError> {
188        let request_for_error = request.clone();
189        match run_with_request_state(self.handle_async_scoped(request, true)).await {
190            Ok(response) => Ok(response),
191            Err(error) => handle_on_api_error(&self.context, &request_for_error, error),
192        }
193    }
194
195    /// Handle a request from trusted server-side code.
196    ///
197    /// Runs the same pipeline as [`handle_async`](Self::handle_async) but marks
198    /// the request as non-internet-facing, allowing endpoints to honor
199    /// server-only inputs (such as an explicit user id) that must never be
200    /// trusted from internet clients.
201    pub async fn handle_async_server(
202        &self,
203        request: ApiRequest,
204    ) -> Result<ApiResponse, RustAuthError> {
205        let request_for_error = request.clone();
206        match run_with_request_state(self.handle_async_scoped(request, false)).await {
207            Ok(response) => Ok(response),
208            Err(error) => handle_on_api_error(&self.context, &request_for_error, error),
209        }
210    }
211
212    async fn handle_async_scoped(
213        &self,
214        mut request: ApiRequest,
215        external: bool,
216    ) -> Result<ApiResponse, RustAuthError> {
217        set_request_external(external)?;
218        let normalized_path =
219            normalize_pathname(&request.uri().to_string(), &self.context.base_path);
220        if self
221            .context
222            .disabled_paths
223            .iter()
224            .any(|item| item == &normalized_path)
225        {
226            return finalize_response_async(
227                &self.context,
228                &request,
229                api_error(StatusCode::NOT_FOUND, ApiErrorCode::NotFound)?,
230            )
231            .await;
232        }
233        let finalize_request = request.clone();
234        request = match run_on_request_plugins(&self.context, request)? {
235            PluginRequestAction::Continue(request) => request,
236            PluginRequestAction::Respond(response) => {
237                return finalize_response_async(&self.context, &finalize_request, response).await;
238            }
239        };
240        let path = route_pathname(
241            &request.uri().to_string(),
242            &self.context.base_path,
243            self.context.options.advanced.skip_trailing_slashes,
244        );
245        let async_endpoint = self.async_endpoints.iter().find_map(|endpoint| {
246            (endpoint.method == *request.method())
247                .then(|| match_path_pattern(&endpoint.path, &path).map(|params| (endpoint, params)))
248                .flatten()
249        });
250        let sync_endpoint = self.endpoints.iter().find_map(|endpoint| {
251            (endpoint.method == *request.method())
252                .then(|| match_path_pattern(&endpoint.path, &path).map(|params| (endpoint, params)))
253                .flatten()
254        });
255        let bypass_origin_security = async_endpoint.as_ref().is_some_and(|(endpoint, _)| {
256            !endpoint.options.server_only && endpoint.options.bypass_origin_security
257        });
258        if let Some(rejection) =
259            validate_request_security(&self.context, &request, bypass_origin_security)?
260        {
261            return finalize_response_async(&self.context, &request, rejection).await;
262        }
263        if async_endpoint.is_none() && sync_endpoint.is_none() {
264            return finalize_response_async(
265                &self.context,
266                &request,
267                api_error(StatusCode::NOT_FOUND, ApiErrorCode::NotFound)?,
268            )
269            .await;
270        }
271        // Consume the route rate limit before plugin middlewares so that security
272        // middlewares (such as CAPTCHA) returning a rejection cannot bypass route
273        // throttling or force repeated outbound provider calls.
274        if let Some(rejection) = consume_rate_limit(&self.context, &request).await? {
275            return finalize_response_async(
276                &self.context,
277                &request,
278                rate_limit_response(rejection)?,
279            )
280            .await;
281        }
282        if let Some(response) = run_matching_middlewares(&self.context, &request, &path)? {
283            return finalize_response_async(&self.context, &request, response).await;
284        }
285        if let Some(response) =
286            run_matching_async_middlewares(&self.context, &request, &path).await?
287        {
288            return finalize_response_async(&self.context, &request, response).await;
289        }
290        if let Some((endpoint, params)) = async_endpoint {
291            if endpoint.options.server_only && external {
292                return finalize_response_async(
293                    &self.context,
294                    &request,
295                    api_error(StatusCode::NOT_FOUND, ApiErrorCode::NotFound)?,
296                )
297                .await;
298            }
299            set_current_request_path(path.clone())?;
300            request.extensions_mut().insert(PathParams::new(params));
301            if let Some(response) = validate_async_endpoint_request(endpoint, &request)? {
302                return finalize_response_async(&self.context, &request, response).await;
303            }
304            if let Some(response) =
305                run_endpoint_middlewares(&self.context, endpoint, &request).await?
306            {
307                return finalize_response_async(&self.context, &request, response).await;
308            }
309            let finalize_request = request.clone();
310            request = match run_before_hooks(
311                &self.context,
312                request,
313                &endpoint.method,
314                &path,
315                endpoint_operation_id(endpoint),
316            )? {
317                PluginBeforeHookAction::Continue(request) => request,
318                PluginBeforeHookAction::Respond(response) => {
319                    return finalize_response_async(&self.context, &finalize_request, response)
320                        .await;
321                }
322            };
323            let finalize_request = request.clone();
324            request = match run_async_before_hooks(
325                &self.context,
326                request,
327                &endpoint.method,
328                &path,
329                endpoint_operation_id(endpoint),
330            )
331            .await?
332            {
333                PluginBeforeHookAction::Continue(request) => request,
334                PluginBeforeHookAction::Respond(response) => {
335                    return finalize_response_async(&self.context, &finalize_request, response)
336                        .await;
337                }
338            };
339            let response = (endpoint.handler)(&self.context, request.clone()).await?;
340            let response = run_after_hooks(
341                &self.context,
342                &request,
343                response,
344                &endpoint.method,
345                &path,
346                endpoint_operation_id(endpoint),
347            )?;
348            let response = run_async_after_hooks(
349                &self.context,
350                &request,
351                response,
352                &endpoint.method,
353                &path,
354                endpoint_operation_id(endpoint),
355            )
356            .await?;
357            return finalize_response_async(&self.context, &request, response).await;
358        }
359        if let Some((endpoint, params)) = sync_endpoint {
360            set_current_request_path(path.clone())?;
361            request.extensions_mut().insert(PathParams::new(params));
362            let finalize_request = request.clone();
363            request = match run_before_hooks(&self.context, request, &endpoint.method, &path, None)?
364            {
365                PluginBeforeHookAction::Continue(request) => request,
366                PluginBeforeHookAction::Respond(response) => {
367                    return finalize_response_async(&self.context, &finalize_request, response)
368                        .await;
369                }
370            };
371            let finalize_request = request.clone();
372            request =
373                match run_async_before_hooks(&self.context, request, &endpoint.method, &path, None)
374                    .await?
375                {
376                    PluginBeforeHookAction::Continue(request) => request,
377                    PluginBeforeHookAction::Respond(response) => {
378                        return finalize_response_async(&self.context, &finalize_request, response)
379                            .await;
380                    }
381                };
382            let response = (endpoint.handler)(&self.context, request.clone())?;
383            let response = run_after_hooks(
384                &self.context,
385                &request,
386                response,
387                &endpoint.method,
388                &path,
389                None,
390            )?;
391            let response = run_async_after_hooks(
392                &self.context,
393                &request,
394                response,
395                &endpoint.method,
396                &path,
397                None,
398            )
399            .await?;
400            return finalize_response_async(&self.context, &request, response).await;
401        }
402        finalize_response_async(
403            &self.context,
404            &request,
405            api_error(StatusCode::NOT_FOUND, ApiErrorCode::NotFound)?,
406        )
407        .await
408    }
409}
410
411pub fn ok_endpoint(
412    _context: &AuthContext,
413    _request: ApiRequest,
414) -> Result<ApiResponse, RustAuthError> {
415    response(StatusCode::OK, b"OK".to_vec())
416}
417
418pub fn core_endpoints() -> Vec<AuthEndpoint> {
419    vec![AuthEndpoint {
420        path: "/ok".to_owned(),
421        method: http::Method::GET,
422        handler: ok_endpoint,
423    }]
424}