// rustauth_core/api/endpoint.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use http::{header, Request, Response, StatusCode};
6use serde_json::Value;
7
8use crate::context::AuthContext;
9use crate::error::RustAuthError;
10
11use super::body::parse_request_body;
12use super::error::ApiErrorResponse;
13use super::openapi::OpenApiOperation;
14use super::schema::BodySchema;
15
16pub type Body = Vec<u8>;
17pub type ApiRequest = Request<Body>;
18pub type ApiResponse = Response<Body>;
19pub type EndpointHandler = fn(&AuthContext, ApiRequest) -> Result<ApiResponse, RustAuthError>;
20pub type EndpointFuture<'a> =
21    Pin<Box<dyn Future<Output = Result<ApiResponse, RustAuthError>> + Send + 'a>>;
22pub type AsyncEndpointHandler =
23    Arc<dyn for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync>;
24pub type EndpointMiddlewareFuture<'a> =
25    Pin<Box<dyn Future<Output = Result<Option<ApiResponse>, RustAuthError>> + Send + 'a>>;
26pub type EndpointMiddlewareHandler = Arc<
27    dyn for<'a> Fn(&'a AuthContext, &'a ApiRequest) -> EndpointMiddlewareFuture<'a> + Send + Sync,
28>;
29
/// Per-request override of the base URL, stored in the request's extensions
/// and consulted by [`request_base_url`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RequestBaseUrl(pub String);
32
33pub fn request_base_url<'a>(context: &'a AuthContext, request: Option<&'a ApiRequest>) -> &'a str {
34    request
35        .and_then(|request| request.extensions().get::<RequestBaseUrl>())
36        .map_or(context.base_url.as_str(), |base_url| base_url.0.as_str())
37}
38
39#[derive(Clone)]
40pub struct EndpointMiddleware {
41    pub handler: EndpointMiddlewareHandler,
42}
43
44impl EndpointMiddleware {
45    pub fn new<F>(handler: F) -> Self
46    where
47        F: for<'a> Fn(&'a AuthContext, &'a ApiRequest) -> EndpointMiddlewareFuture<'a>
48            + Send
49            + Sync
50            + 'static,
51    {
52        Self {
53            handler: Arc::new(handler),
54        }
55    }
56}
57
58#[derive(Clone, Default)]
59pub struct AuthEndpointOptions {
60    pub operation_id: Option<String>,
61    pub allowed_media_types: Vec<String>,
62    pub body_schema: Option<BodySchema>,
63    pub middlewares: Vec<EndpointMiddleware>,
64    pub openapi: Option<OpenApiOperation>,
65    pub server_only: bool,
66    pub hide_from_openapi: bool,
67    pub bypass_origin_security: bool,
68}
69
70impl AuthEndpointOptions {
71    pub fn new() -> Self {
72        Self::default()
73    }
74
75    #[must_use]
76    pub fn operation_id(mut self, operation_id: impl Into<String>) -> Self {
77        self.operation_id = Some(operation_id.into());
78        self
79    }
80
81    #[must_use]
82    pub fn allowed_media_types<I, S>(mut self, media_types: I) -> Self
83    where
84        I: IntoIterator<Item = S>,
85        S: Into<String>,
86    {
87        self.allowed_media_types = media_types.into_iter().map(Into::into).collect();
88        self
89    }
90
91    #[must_use]
92    pub fn body_schema(mut self, schema: BodySchema) -> Self {
93        self.body_schema = Some(schema);
94        self
95    }
96
97    #[must_use]
98    pub fn middleware(mut self, middleware: EndpointMiddleware) -> Self {
99        self.middlewares.push(middleware);
100        self
101    }
102
103    #[must_use]
104    pub fn openapi(mut self, operation: OpenApiOperation) -> Self {
105        self.openapi = Some(operation);
106        self
107    }
108
109    #[must_use]
110    pub fn server_only(mut self) -> Self {
111        self.server_only = true;
112        self
113    }
114
115    #[must_use]
116    pub fn hide_from_openapi(mut self) -> Self {
117        self.hide_from_openapi = true;
118        self
119    }
120
121    #[must_use]
122    pub fn bypass_origin_security(mut self) -> Self {
123        self.bypass_origin_security = true;
124        self
125    }
126}
127
/// Whether an endpoint's handler runs synchronously or asynchronously.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EndpointKind {
    /// Synchronous endpoint (cf. [`AuthEndpoint`] / [`EndpointHandler`]).
    Sync,
    /// Asynchronous endpoint (cf. [`AsyncAuthEndpoint`] / [`AsyncEndpointHandler`]).
    Async,
}
133
134#[derive(Debug, Clone, PartialEq, Eq)]
135pub struct EndpointInfo {
136    pub path: String,
137    pub method: http::Method,
138    pub kind: EndpointKind,
139    pub operation_id: Option<String>,
140    pub allowed_media_types: Vec<String>,
141}
142
143#[derive(Clone)]
144pub struct AuthEndpoint {
145    pub path: String,
146    pub method: http::Method,
147    pub handler: EndpointHandler,
148}
149
150#[derive(Clone)]
151pub struct AsyncAuthEndpoint {
152    pub path: String,
153    pub method: http::Method,
154    pub handler: AsyncEndpointHandler,
155    pub options: AuthEndpointOptions,
156}
157
158impl AsyncAuthEndpoint {
159    pub fn new<F>(path: impl Into<String>, method: http::Method, handler: F) -> Self
160    where
161        F: for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync + 'static,
162    {
163        Self {
164            path: path.into(),
165            method,
166            handler: Arc::new(handler),
167            options: AuthEndpointOptions::default(),
168        }
169    }
170}
171
172pub fn create_auth_endpoint_raw<F>(
173    path: impl Into<String>,
174    method: http::Method,
175    options: AuthEndpointOptions,
176    handler: F,
177) -> AsyncAuthEndpoint
178where
179    F: for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync + 'static,
180{
181    AsyncAuthEndpoint {
182        path: path.into(),
183        method,
184        handler: Arc::new(handler),
185        options,
186    }
187}
188
189/// Wraps an async handler so endpoint authors do not need `Box::pin`.
190pub fn async_auth_handler<F, Fut>(
191    handler: F,
192) -> impl for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync + Clone + 'static
193where
194    F: Fn(AuthContext, ApiRequest) -> Fut + Send + Sync + Clone + 'static,
195    Fut: Future<Output = Result<ApiResponse, RustAuthError>> + Send + 'static,
196{
197    move |context: &AuthContext, request: ApiRequest| {
198        let handler = handler.clone();
199        let context = context.clone();
200        Box::pin(async move { handler(context, request).await })
201    }
202}
203
204/// Defines an async auth endpoint without manual `Box::pin`.
205pub fn create_auth_endpoint<F, Fut>(
206    path: impl Into<String>,
207    method: http::Method,
208    options: AuthEndpointOptions,
209    handler: F,
210) -> AsyncAuthEndpoint
211where
212    F: Fn(AuthContext, ApiRequest) -> Fut + Send + Sync + Clone + 'static,
213    Fut: Future<Output = Result<ApiResponse, RustAuthError>> + Send + 'static,
214{
215    create_auth_endpoint_raw(path, method, options, async_auth_handler(handler))
216}
217
218/// Wraps async endpoint middleware so authors do not need `Box::pin`.
219pub fn async_endpoint_middleware<F, Fut>(handler: F) -> EndpointMiddleware
220where
221    F: for<'a> Fn(AuthContext, &'a ApiRequest) -> Fut + Send + Sync + Clone + 'static,
222    for<'a> Fut: Future<Output = Result<Option<ApiResponse>, RustAuthError>> + Send + 'a,
223{
224    EndpointMiddleware::new(move |context: &AuthContext, request: &ApiRequest| {
225        let handler = handler.clone();
226        let context = context.clone();
227        Box::pin(async move { handler(context, request).await })
228    })
229}
230
231pub(super) fn validate_async_endpoint_request(
232    endpoint: &AsyncAuthEndpoint,
233    request: &ApiRequest,
234) -> Result<Option<ApiResponse>, RustAuthError> {
235    if endpoint.options.allowed_media_types.is_empty() && endpoint.options.body_schema.is_none() {
236        return Ok(None);
237    }
238
239    let content_type = request
240        .headers()
241        .get(header::CONTENT_TYPE)
242        .and_then(|value| value.to_str().ok())
243        .and_then(|value| value.split(';').next())
244        .map(str::trim)
245        .filter(|value| !value.is_empty());
246
247    if !endpoint.options.allowed_media_types.is_empty() {
248        let Some(content_type) = content_type else {
249            return invalid_request_response(
250                StatusCode::UNSUPPORTED_MEDIA_TYPE,
251                "UNSUPPORTED_MEDIA_TYPE",
252                "Missing Content-Type",
253            )
254            .map(Some);
255        };
256        if !endpoint
257            .options
258            .allowed_media_types
259            .iter()
260            .any(|allowed| allowed.eq_ignore_ascii_case(content_type))
261        {
262            return invalid_request_response(
263                StatusCode::UNSUPPORTED_MEDIA_TYPE,
264                "UNSUPPORTED_MEDIA_TYPE",
265                "Unsupported Content-Type",
266            )
267            .map(Some);
268        }
269    }
270
271    if let Some(schema) = &endpoint.options.body_schema {
272        let body = match parse_request_body::<Value>(request) {
273            Ok(body) => body,
274            Err(error) => {
275                return invalid_request_response(
276                    StatusCode::BAD_REQUEST,
277                    "INVALID_REQUEST_BODY",
278                    &error.to_string(),
279                )
280                .map(Some);
281            }
282        };
283        if let Err(message) = schema.validate(&body) {
284            return invalid_request_response(
285                StatusCode::BAD_REQUEST,
286                "INVALID_REQUEST_BODY",
287                &message,
288            )
289            .map(Some);
290        }
291    }
292
293    Ok(None)
294}
295
296pub(super) async fn run_endpoint_middlewares(
297    context: &AuthContext,
298    endpoint: &AsyncAuthEndpoint,
299    request: &ApiRequest,
300) -> Result<Option<ApiResponse>, RustAuthError> {
301    for middleware in &endpoint.options.middlewares {
302        if let Some(response) = (middleware.handler)(context, request).await? {
303            return Ok(Some(response));
304        }
305    }
306    Ok(None)
307}
308
309fn invalid_request_response(
310    status: StatusCode,
311    code: &str,
312    message: &str,
313) -> Result<ApiResponse, RustAuthError> {
314    let body = serde_json::to_vec(&ApiErrorResponse {
315        code: code.to_owned(),
316        message: message.to_owned(),
317        original_message: None,
318    })
319    .map_err(|error| RustAuthError::Serialization {
320        context: "serializing API error response",
321        message: error.to_string(),
322    })?;
323
324    Response::builder()
325        .status(status)
326        .header(header::CONTENT_TYPE, "application/json")
327        .body(body)
328        .map_err(|error| RustAuthError::Serialization {
329            context: "building API error response",
330            message: error.to_string(),
331        })
332}