1use http::StatusCode;
2use serde_json::Value;
3
4use crate::context::request_state::{
5 run_with_request_state, set_current_request_path, set_request_external,
6};
7use crate::context::AuthContext;
8use crate::error::RustAuthError;
9use crate::plugin::{PluginBeforeHookAction, PluginRequestAction};
10use crate::rate_limit::{consume_rate_limit, on_request_rate_limit};
11use crate::utils::url::normalize_pathname;
12
13use super::endpoint::{
14 run_endpoint_middlewares, validate_async_endpoint_request, ApiRequest, ApiResponse,
15 AsyncAuthEndpoint, AuthEndpoint, EndpointInfo, EndpointKind,
16};
17use super::error::{api_error, rate_limit_response, response, ApiErrorCode};
18use super::on_api_error::handle_on_api_error;
19use super::openapi::build_openapi_schema;
20use super::path::{match_path_pattern, route_pathname, PathParams};
21use super::plugin_pipeline::{
22 endpoint_operation_id, finalize_response, finalize_response_async, plugin_async_endpoints,
23 run_after_hooks, run_async_after_hooks, run_async_before_hooks, run_before_hooks,
24 run_matching_async_middlewares, run_matching_middlewares, run_on_request_plugins,
25 validate_endpoint_conflicts,
26};
27use super::security::validate_request_security;
28
/// Dispatches [`ApiRequest`]s to registered synchronous and asynchronous auth endpoints.
///
/// The router owns the shared [`AuthContext`] and two endpoint tables. The async table is
/// produced by the constructors through `plugin_async_endpoints`.
#[derive(Clone)]
pub struct AuthRouter {
    /// Shared state handed to every plugin hook, middleware, and endpoint handler.
    context: AuthContext,
    /// Synchronous endpoints; reachable from both `handle` and the async entry points.
    endpoints: Vec<AuthEndpoint>,
    /// Asynchronous endpoints; only executed by `handle_async` / `handle_async_server`.
    async_endpoints: Vec<AsyncAuthEndpoint>,
}
35
36impl AuthRouter {
37 pub fn new(context: AuthContext, endpoints: Vec<AuthEndpoint>) -> Self {
38 let async_endpoints = plugin_async_endpoints(&context, Vec::new());
39 Self {
40 context,
41 endpoints,
42 async_endpoints,
43 }
44 }
45
46 pub fn try_new(
47 context: AuthContext,
48 endpoints: Vec<AuthEndpoint>,
49 ) -> Result<Self, RustAuthError> {
50 let async_endpoints = plugin_async_endpoints(&context, Vec::new());
51 validate_endpoint_conflicts(&endpoints, &async_endpoints)?;
52 Ok(Self {
53 context,
54 endpoints,
55 async_endpoints,
56 })
57 }
58
59 pub fn with_async_endpoints(
60 context: AuthContext,
61 endpoints: Vec<AuthEndpoint>,
62 async_endpoints: Vec<AsyncAuthEndpoint>,
63 ) -> Result<Self, RustAuthError> {
64 let async_endpoints = plugin_async_endpoints(&context, async_endpoints);
65 validate_endpoint_conflicts(&endpoints, &async_endpoints)?;
66 Ok(Self {
67 context,
68 endpoints,
69 async_endpoints,
70 })
71 }
72
73 pub fn endpoint_registry(&self) -> Vec<EndpointInfo> {
74 let sync_endpoints = self.endpoints.iter().map(|endpoint| EndpointInfo {
75 path: endpoint.path.clone(),
76 method: endpoint.method.clone(),
77 kind: EndpointKind::Sync,
78 operation_id: None,
79 allowed_media_types: Vec::new(),
80 });
81 let async_endpoints = self
82 .async_endpoints
83 .iter()
84 .filter(|endpoint| !endpoint.options.server_only)
85 .map(|endpoint| EndpointInfo {
86 path: endpoint.path.clone(),
87 method: endpoint.method.clone(),
88 kind: EndpointKind::Async,
89 operation_id: endpoint
90 .options
91 .operation_id
92 .clone()
93 .or_else(|| endpoint.options.openapi.as_ref()?.operation_id.clone()),
94 allowed_media_types: endpoint.options.allowed_media_types.clone(),
95 });
96 sync_endpoints.chain(async_endpoints).collect()
97 }
98
99 pub fn openapi_schema(&self) -> Value {
100 build_openapi_schema(&self.context, &self.async_endpoints)
101 }
102
103 pub fn handle(&self, request: ApiRequest) -> Result<ApiResponse, RustAuthError> {
104 let request_for_error = request.clone();
105 match self.handle_inner(request) {
106 Ok(response) => Ok(response),
107 Err(error) => handle_on_api_error(&self.context, &request_for_error, error),
108 }
109 }
110
111 fn handle_inner(&self, mut request: ApiRequest) -> Result<ApiResponse, RustAuthError> {
112 let normalized_path =
113 normalize_pathname(&request.uri().to_string(), &self.context.base_path);
114 if self
115 .context
116 .disabled_paths
117 .iter()
118 .any(|item| item == &normalized_path)
119 {
120 return finalize_response(
121 &self.context,
122 &request,
123 api_error(StatusCode::NOT_FOUND, ApiErrorCode::NotFound)?,
124 );
125 }
126 let finalize_request = request.clone();
127 request = match run_on_request_plugins(&self.context, request)? {
128 PluginRequestAction::Continue(request) => request,
129 PluginRequestAction::Respond(response) => {
130 return finalize_response(&self.context, &finalize_request, response);
131 }
132 };
133 if let Some(rejection) = validate_request_security(&self.context, &request, false)? {
134 return finalize_response(&self.context, &request, rejection);
135 }
136 let path = route_pathname(
137 &request.uri().to_string(),
138 &self.context.base_path,
139 self.context.options.advanced.skip_trailing_slashes,
140 );
141 let Some((endpoint, params)) = self.endpoints.iter().find_map(|endpoint| {
142 (endpoint.method == *request.method())
143 .then(|| match_path_pattern(&endpoint.path, &path).map(|params| (endpoint, params)))
144 .flatten()
145 }) else {
146 if self.async_endpoints.iter().any(|endpoint| {
147 endpoint.method == *request.method()
148 && !endpoint.options.server_only
149 && match_path_pattern(&endpoint.path, &path).is_some()
150 }) {
151 return Err(RustAuthError::Api(
152 "async endpoint requires AuthRouter::handle_async".to_owned(),
153 ));
154 }
155 return finalize_response(
156 &self.context,
157 &request,
158 api_error(StatusCode::NOT_FOUND, ApiErrorCode::NotFound)?,
159 );
160 };
161 request.extensions_mut().insert(PathParams::new(params));
162 if let Some(response) = run_matching_middlewares(&self.context, &request, &path)? {
163 return finalize_response(&self.context, &request, response);
164 }
165 if let Some(rejection) = on_request_rate_limit(&self.context, &request)? {
166 return finalize_response(&self.context, &request, rate_limit_response(rejection)?);
167 }
168 let finalize_request = request.clone();
169 request = match run_before_hooks(&self.context, request, &endpoint.method, &path, None)? {
170 PluginBeforeHookAction::Continue(request) => request,
171 PluginBeforeHookAction::Respond(response) => {
172 return finalize_response(&self.context, &finalize_request, response);
173 }
174 };
175 let response = (endpoint.handler)(&self.context, request.clone())?;
176 let response = run_after_hooks(
177 &self.context,
178 &request,
179 response,
180 &endpoint.method,
181 &path,
182 None,
183 )?;
184 finalize_response(&self.context, &request, response)
185 }
186
187 pub async fn handle_async(&self, request: ApiRequest) -> Result<ApiResponse, RustAuthError> {
188 let request_for_error = request.clone();
189 match run_with_request_state(self.handle_async_scoped(request, true)).await {
190 Ok(response) => Ok(response),
191 Err(error) => handle_on_api_error(&self.context, &request_for_error, error),
192 }
193 }
194
195 pub async fn handle_async_server(
202 &self,
203 request: ApiRequest,
204 ) -> Result<ApiResponse, RustAuthError> {
205 let request_for_error = request.clone();
206 match run_with_request_state(self.handle_async_scoped(request, false)).await {
207 Ok(response) => Ok(response),
208 Err(error) => handle_on_api_error(&self.context, &request_for_error, error),
209 }
210 }
211
212 async fn handle_async_scoped(
213 &self,
214 mut request: ApiRequest,
215 external: bool,
216 ) -> Result<ApiResponse, RustAuthError> {
217 set_request_external(external)?;
218 let normalized_path =
219 normalize_pathname(&request.uri().to_string(), &self.context.base_path);
220 if self
221 .context
222 .disabled_paths
223 .iter()
224 .any(|item| item == &normalized_path)
225 {
226 return finalize_response_async(
227 &self.context,
228 &request,
229 api_error(StatusCode::NOT_FOUND, ApiErrorCode::NotFound)?,
230 )
231 .await;
232 }
233 let finalize_request = request.clone();
234 request = match run_on_request_plugins(&self.context, request)? {
235 PluginRequestAction::Continue(request) => request,
236 PluginRequestAction::Respond(response) => {
237 return finalize_response_async(&self.context, &finalize_request, response).await;
238 }
239 };
240 let path = route_pathname(
241 &request.uri().to_string(),
242 &self.context.base_path,
243 self.context.options.advanced.skip_trailing_slashes,
244 );
245 let async_endpoint = self.async_endpoints.iter().find_map(|endpoint| {
246 (endpoint.method == *request.method())
247 .then(|| match_path_pattern(&endpoint.path, &path).map(|params| (endpoint, params)))
248 .flatten()
249 });
250 let sync_endpoint = self.endpoints.iter().find_map(|endpoint| {
251 (endpoint.method == *request.method())
252 .then(|| match_path_pattern(&endpoint.path, &path).map(|params| (endpoint, params)))
253 .flatten()
254 });
255 let bypass_origin_security = async_endpoint.as_ref().is_some_and(|(endpoint, _)| {
256 !endpoint.options.server_only && endpoint.options.bypass_origin_security
257 });
258 if let Some(rejection) =
259 validate_request_security(&self.context, &request, bypass_origin_security)?
260 {
261 return finalize_response_async(&self.context, &request, rejection).await;
262 }
263 if async_endpoint.is_none() && sync_endpoint.is_none() {
264 return finalize_response_async(
265 &self.context,
266 &request,
267 api_error(StatusCode::NOT_FOUND, ApiErrorCode::NotFound)?,
268 )
269 .await;
270 }
271 if let Some(rejection) = consume_rate_limit(&self.context, &request).await? {
275 return finalize_response_async(
276 &self.context,
277 &request,
278 rate_limit_response(rejection)?,
279 )
280 .await;
281 }
282 if let Some(response) = run_matching_middlewares(&self.context, &request, &path)? {
283 return finalize_response_async(&self.context, &request, response).await;
284 }
285 if let Some(response) =
286 run_matching_async_middlewares(&self.context, &request, &path).await?
287 {
288 return finalize_response_async(&self.context, &request, response).await;
289 }
290 if let Some((endpoint, params)) = async_endpoint {
291 if endpoint.options.server_only && external {
292 return finalize_response_async(
293 &self.context,
294 &request,
295 api_error(StatusCode::NOT_FOUND, ApiErrorCode::NotFound)?,
296 )
297 .await;
298 }
299 set_current_request_path(path.clone())?;
300 request.extensions_mut().insert(PathParams::new(params));
301 if let Some(response) = validate_async_endpoint_request(endpoint, &request)? {
302 return finalize_response_async(&self.context, &request, response).await;
303 }
304 if let Some(response) =
305 run_endpoint_middlewares(&self.context, endpoint, &request).await?
306 {
307 return finalize_response_async(&self.context, &request, response).await;
308 }
309 let finalize_request = request.clone();
310 request = match run_before_hooks(
311 &self.context,
312 request,
313 &endpoint.method,
314 &path,
315 endpoint_operation_id(endpoint),
316 )? {
317 PluginBeforeHookAction::Continue(request) => request,
318 PluginBeforeHookAction::Respond(response) => {
319 return finalize_response_async(&self.context, &finalize_request, response)
320 .await;
321 }
322 };
323 let finalize_request = request.clone();
324 request = match run_async_before_hooks(
325 &self.context,
326 request,
327 &endpoint.method,
328 &path,
329 endpoint_operation_id(endpoint),
330 )
331 .await?
332 {
333 PluginBeforeHookAction::Continue(request) => request,
334 PluginBeforeHookAction::Respond(response) => {
335 return finalize_response_async(&self.context, &finalize_request, response)
336 .await;
337 }
338 };
339 let response = (endpoint.handler)(&self.context, request.clone()).await?;
340 let response = run_after_hooks(
341 &self.context,
342 &request,
343 response,
344 &endpoint.method,
345 &path,
346 endpoint_operation_id(endpoint),
347 )?;
348 let response = run_async_after_hooks(
349 &self.context,
350 &request,
351 response,
352 &endpoint.method,
353 &path,
354 endpoint_operation_id(endpoint),
355 )
356 .await?;
357 return finalize_response_async(&self.context, &request, response).await;
358 }
359 if let Some((endpoint, params)) = sync_endpoint {
360 set_current_request_path(path.clone())?;
361 request.extensions_mut().insert(PathParams::new(params));
362 let finalize_request = request.clone();
363 request = match run_before_hooks(&self.context, request, &endpoint.method, &path, None)?
364 {
365 PluginBeforeHookAction::Continue(request) => request,
366 PluginBeforeHookAction::Respond(response) => {
367 return finalize_response_async(&self.context, &finalize_request, response)
368 .await;
369 }
370 };
371 let finalize_request = request.clone();
372 request =
373 match run_async_before_hooks(&self.context, request, &endpoint.method, &path, None)
374 .await?
375 {
376 PluginBeforeHookAction::Continue(request) => request,
377 PluginBeforeHookAction::Respond(response) => {
378 return finalize_response_async(&self.context, &finalize_request, response)
379 .await;
380 }
381 };
382 let response = (endpoint.handler)(&self.context, request.clone())?;
383 let response = run_after_hooks(
384 &self.context,
385 &request,
386 response,
387 &endpoint.method,
388 &path,
389 None,
390 )?;
391 let response = run_async_after_hooks(
392 &self.context,
393 &request,
394 response,
395 &endpoint.method,
396 &path,
397 None,
398 )
399 .await?;
400 return finalize_response_async(&self.context, &request, response).await;
401 }
402 finalize_response_async(
403 &self.context,
404 &request,
405 api_error(StatusCode::NOT_FOUND, ApiErrorCode::NotFound)?,
406 )
407 .await
408 }
409}
410
411pub fn ok_endpoint(
412 _context: &AuthContext,
413 _request: ApiRequest,
414) -> Result<ApiResponse, RustAuthError> {
415 response(StatusCode::OK, b"OK".to_vec())
416}
417
418pub fn core_endpoints() -> Vec<AuthEndpoint> {
419 vec![AuthEndpoint {
420 path: "/ok".to_owned(),
421 method: http::Method::GET,
422 handler: ok_endpoint,
423 }]
424}