1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use http::{header, Request, Response, StatusCode};
6use serde_json::Value;
7
8use crate::context::AuthContext;
9use crate::error::RustAuthError;
10
11use super::body::parse_request_body;
12use super::error::ApiErrorResponse;
13use super::openapi::OpenApiOperation;
14use super::schema::BodySchema;
15
16pub type Body = Vec<u8>;
17pub type ApiRequest = Request<Body>;
18pub type ApiResponse = Response<Body>;
19pub type EndpointHandler = fn(&AuthContext, ApiRequest) -> Result<ApiResponse, RustAuthError>;
20pub type EndpointFuture<'a> =
21 Pin<Box<dyn Future<Output = Result<ApiResponse, RustAuthError>> + Send + 'a>>;
22pub type AsyncEndpointHandler =
23 Arc<dyn for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync>;
24pub type EndpointMiddlewareFuture<'a> =
25 Pin<Box<dyn Future<Output = Result<Option<ApiResponse>, RustAuthError>> + Send + 'a>>;
26pub type EndpointMiddlewareHandler = Arc<
27 dyn for<'a> Fn(&'a AuthContext, &'a ApiRequest) -> EndpointMiddlewareFuture<'a> + Send + Sync,
28>;
29
/// Per-request override of the base URL.
///
/// When present in a request's extensions it takes precedence over the context's
/// configured base URL in [`request_base_url`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RequestBaseUrl(
    /// The base URL to use for this request.
    pub String,
);
32
33pub fn request_base_url<'a>(context: &'a AuthContext, request: Option<&'a ApiRequest>) -> &'a str {
34 request
35 .and_then(|request| request.extensions().get::<RequestBaseUrl>())
36 .map_or(context.base_url.as_str(), |base_url| base_url.0.as_str())
37}
38
39#[derive(Clone)]
40pub struct EndpointMiddleware {
41 pub handler: EndpointMiddlewareHandler,
42}
43
44impl EndpointMiddleware {
45 pub fn new<F>(handler: F) -> Self
46 where
47 F: for<'a> Fn(&'a AuthContext, &'a ApiRequest) -> EndpointMiddlewareFuture<'a>
48 + Send
49 + Sync
50 + 'static,
51 {
52 Self {
53 handler: Arc::new(handler),
54 }
55 }
56}
57
58#[derive(Clone, Default)]
59pub struct AuthEndpointOptions {
60 pub operation_id: Option<String>,
61 pub allowed_media_types: Vec<String>,
62 pub body_schema: Option<BodySchema>,
63 pub middlewares: Vec<EndpointMiddleware>,
64 pub openapi: Option<OpenApiOperation>,
65 pub server_only: bool,
66 pub hide_from_openapi: bool,
67 pub bypass_origin_security: bool,
68}
69
70impl AuthEndpointOptions {
71 pub fn new() -> Self {
72 Self::default()
73 }
74
75 #[must_use]
76 pub fn operation_id(mut self, operation_id: impl Into<String>) -> Self {
77 self.operation_id = Some(operation_id.into());
78 self
79 }
80
81 #[must_use]
82 pub fn allowed_media_types<I, S>(mut self, media_types: I) -> Self
83 where
84 I: IntoIterator<Item = S>,
85 S: Into<String>,
86 {
87 self.allowed_media_types = media_types.into_iter().map(Into::into).collect();
88 self
89 }
90
91 #[must_use]
92 pub fn body_schema(mut self, schema: BodySchema) -> Self {
93 self.body_schema = Some(schema);
94 self
95 }
96
97 #[must_use]
98 pub fn middleware(mut self, middleware: EndpointMiddleware) -> Self {
99 self.middlewares.push(middleware);
100 self
101 }
102
103 #[must_use]
104 pub fn openapi(mut self, operation: OpenApiOperation) -> Self {
105 self.openapi = Some(operation);
106 self
107 }
108
109 #[must_use]
110 pub fn server_only(mut self) -> Self {
111 self.server_only = true;
112 self
113 }
114
115 #[must_use]
116 pub fn hide_from_openapi(mut self) -> Self {
117 self.hide_from_openapi = true;
118 self
119 }
120
121 #[must_use]
122 pub fn bypass_origin_security(mut self) -> Self {
123 self.bypass_origin_security = true;
124 self
125 }
126}
127
/// Whether an endpoint is synchronous or asynchronous.
///
/// NOTE(review): presumably `Sync` corresponds to [`AuthEndpoint`] and `Async` to
/// [`AsyncAuthEndpoint`]. Confirm where [`EndpointInfo`] is built.
///
/// `Hash` is derived alongside `Eq`, so the kind can key hash maps and sets
/// (for example, grouping endpoints by kind).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EndpointKind {
    /// Synchronous handler.
    Sync,
    /// Asynchronous handler.
    Async,
}
133
134#[derive(Debug, Clone, PartialEq, Eq)]
135pub struct EndpointInfo {
136 pub path: String,
137 pub method: http::Method,
138 pub kind: EndpointKind,
139 pub operation_id: Option<String>,
140 pub allowed_media_types: Vec<String>,
141}
142
143#[derive(Clone)]
144pub struct AuthEndpoint {
145 pub path: String,
146 pub method: http::Method,
147 pub handler: EndpointHandler,
148}
149
150#[derive(Clone)]
151pub struct AsyncAuthEndpoint {
152 pub path: String,
153 pub method: http::Method,
154 pub handler: AsyncEndpointHandler,
155 pub options: AuthEndpointOptions,
156}
157
158impl AsyncAuthEndpoint {
159 pub fn new<F>(path: impl Into<String>, method: http::Method, handler: F) -> Self
160 where
161 F: for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync + 'static,
162 {
163 Self {
164 path: path.into(),
165 method,
166 handler: Arc::new(handler),
167 options: AuthEndpointOptions::default(),
168 }
169 }
170}
171
172pub fn create_auth_endpoint_raw<F>(
173 path: impl Into<String>,
174 method: http::Method,
175 options: AuthEndpointOptions,
176 handler: F,
177) -> AsyncAuthEndpoint
178where
179 F: for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync + 'static,
180{
181 AsyncAuthEndpoint {
182 path: path.into(),
183 method,
184 handler: Arc::new(handler),
185 options,
186 }
187}
188
189pub fn async_auth_handler<F, Fut>(
191 handler: F,
192) -> impl for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync + Clone + 'static
193where
194 F: Fn(AuthContext, ApiRequest) -> Fut + Send + Sync + Clone + 'static,
195 Fut: Future<Output = Result<ApiResponse, RustAuthError>> + Send + 'static,
196{
197 move |context: &AuthContext, request: ApiRequest| {
198 let handler = handler.clone();
199 let context = context.clone();
200 Box::pin(async move { handler(context, request).await })
201 }
202}
203
204pub fn create_auth_endpoint<F, Fut>(
206 path: impl Into<String>,
207 method: http::Method,
208 options: AuthEndpointOptions,
209 handler: F,
210) -> AsyncAuthEndpoint
211where
212 F: Fn(AuthContext, ApiRequest) -> Fut + Send + Sync + Clone + 'static,
213 Fut: Future<Output = Result<ApiResponse, RustAuthError>> + Send + 'static,
214{
215 create_auth_endpoint_raw(path, method, options, async_auth_handler(handler))
216}
217
218pub fn async_endpoint_middleware<F, Fut>(handler: F) -> EndpointMiddleware
220where
221 F: for<'a> Fn(AuthContext, &'a ApiRequest) -> Fut + Send + Sync + Clone + 'static,
222 for<'a> Fut: Future<Output = Result<Option<ApiResponse>, RustAuthError>> + Send + 'a,
223{
224 EndpointMiddleware::new(move |context: &AuthContext, request: &ApiRequest| {
225 let handler = handler.clone();
226 let context = context.clone();
227 Box::pin(async move { handler(context, request).await })
228 })
229}
230
231pub(super) fn validate_async_endpoint_request(
232 endpoint: &AsyncAuthEndpoint,
233 request: &ApiRequest,
234) -> Result<Option<ApiResponse>, RustAuthError> {
235 if endpoint.options.allowed_media_types.is_empty() && endpoint.options.body_schema.is_none() {
236 return Ok(None);
237 }
238
239 let content_type = request
240 .headers()
241 .get(header::CONTENT_TYPE)
242 .and_then(|value| value.to_str().ok())
243 .and_then(|value| value.split(';').next())
244 .map(str::trim)
245 .filter(|value| !value.is_empty());
246
247 if !endpoint.options.allowed_media_types.is_empty() {
248 let Some(content_type) = content_type else {
249 return invalid_request_response(
250 StatusCode::UNSUPPORTED_MEDIA_TYPE,
251 "UNSUPPORTED_MEDIA_TYPE",
252 "Missing Content-Type",
253 )
254 .map(Some);
255 };
256 if !endpoint
257 .options
258 .allowed_media_types
259 .iter()
260 .any(|allowed| allowed.eq_ignore_ascii_case(content_type))
261 {
262 return invalid_request_response(
263 StatusCode::UNSUPPORTED_MEDIA_TYPE,
264 "UNSUPPORTED_MEDIA_TYPE",
265 "Unsupported Content-Type",
266 )
267 .map(Some);
268 }
269 }
270
271 if let Some(schema) = &endpoint.options.body_schema {
272 let body = match parse_request_body::<Value>(request) {
273 Ok(body) => body,
274 Err(error) => {
275 return invalid_request_response(
276 StatusCode::BAD_REQUEST,
277 "INVALID_REQUEST_BODY",
278 &error.to_string(),
279 )
280 .map(Some);
281 }
282 };
283 if let Err(message) = schema.validate(&body) {
284 return invalid_request_response(
285 StatusCode::BAD_REQUEST,
286 "INVALID_REQUEST_BODY",
287 &message,
288 )
289 .map(Some);
290 }
291 }
292
293 Ok(None)
294}
295
296pub(super) async fn run_endpoint_middlewares(
297 context: &AuthContext,
298 endpoint: &AsyncAuthEndpoint,
299 request: &ApiRequest,
300) -> Result<Option<ApiResponse>, RustAuthError> {
301 for middleware in &endpoint.options.middlewares {
302 if let Some(response) = (middleware.handler)(context, request).await? {
303 return Ok(Some(response));
304 }
305 }
306 Ok(None)
307}
308
309fn invalid_request_response(
310 status: StatusCode,
311 code: &str,
312 message: &str,
313) -> Result<ApiResponse, RustAuthError> {
314 let body = serde_json::to_vec(&ApiErrorResponse {
315 code: code.to_owned(),
316 message: message.to_owned(),
317 original_message: None,
318 })
319 .map_err(|error| RustAuthError::Serialization {
320 context: "serializing API error response",
321 message: error.to_string(),
322 })?;
323
324 Response::builder()
325 .status(status)
326 .header(header::CONTENT_TYPE, "application/json")
327 .body(body)
328 .map_err(|error| RustAuthError::Serialization {
329 context: "building API error response",
330 message: error.to_string(),
331 })
332}