Skip to main content

rustauth_core/
plugin.rs

1//! Plugin contracts for RustAuth extensions.
2
3use std::any::Any;
4use std::future::Future;
5use std::pin::Pin;
6
7mod db;
8mod endpoint;
9mod error;
10mod hooks;
11mod init;
12mod password;
13mod rate_limit;
14mod schema;
15
16pub use db::{
17    PluginDatabaseAfterHookHandler, PluginDatabaseAfterInput, PluginDatabaseBeforeAction,
18    PluginDatabaseBeforeHookHandler, PluginDatabaseBeforeInput, PluginDatabaseHook,
19    PluginDatabaseHookContext, PluginDatabaseOperation, PluginMigration, PluginMigrationBody,
20    PluginMigrationStep,
21};
22pub use endpoint::PluginEndpoint;
23pub use error::PluginErrorCode;
24pub use hooks::{
25    async_after_hook_handler, async_before_hook_handler, PluginAfterHook, PluginAfterHookAction,
26    PluginAfterHookFuture, PluginAfterHookHandler, PluginAsyncAfterHook,
27    PluginAsyncAfterHookHandler, PluginAsyncBeforeHook, PluginAsyncBeforeHookHandler,
28    PluginBeforeHook, PluginBeforeHookAction, PluginBeforeHookFuture, PluginBeforeHookHandler,
29    PluginEndpointHooks, PluginHookMatcher,
30};
31pub use init::{PluginInitHandler, PluginInitOutput};
32pub use password::{
33    PluginPasswordValidationInput, PluginPasswordValidationRejection, PluginPasswordValidator,
34    PluginPasswordValidatorFuture, PluginPasswordValidatorHandler,
35};
36pub use rate_limit::PluginRateLimitRule;
37pub use schema::PluginSchemaContribution;
38
39use crate::api::{ApiRequest, ApiResponse, AsyncAuthEndpoint, Body};
40use crate::context::AuthContext;
41use crate::error::RustAuthError;
42#[cfg(feature = "oauth")]
43use rustauth_oauth::oauth2::SocialOAuthProvider;
44use serde_json::Value;
45use std::fmt;
46use std::sync::Arc;
47
48/// Alias for [`Body`]; prefer [`Body`] or [`ApiRequest`] in new code.
49pub type PluginBody = Body;
50/// Alias for [`ApiRequest`]; prefer [`ApiRequest`] in new code.
51pub type PluginRequest = ApiRequest;
52/// Alias for [`ApiResponse`]; prefer [`ApiResponse`] in new code.
53pub type PluginResponse = ApiResponse;
54pub type PluginMiddlewareFuture<'a> =
55    Pin<Box<dyn Future<Output = Result<Option<PluginResponse>, RustAuthError>> + Send + 'a>>;
56pub type PluginOnRequest = Arc<
57    dyn Fn(&AuthContext, PluginRequest) -> Result<PluginRequestAction, RustAuthError> + Send + Sync,
58>;
59pub type PluginOnResponse = Arc<
60    dyn Fn(&AuthContext, &PluginRequest, PluginResponse) -> Result<PluginResponse, RustAuthError>
61        + Send
62        + Sync,
63>;
64pub type PluginOnResponseAsyncFuture<'a> =
65    Pin<Box<dyn Future<Output = Result<(), RustAuthError>> + Send + 'a>>;
66pub type PluginOnResponseAsync = Arc<
67    dyn for<'a> Fn(
68            &'a AuthContext,
69            &'a PluginRequest,
70            &'a PluginResponse,
71        ) -> PluginOnResponseAsyncFuture<'a>
72        + Send
73        + Sync,
74>;
75pub type PluginMiddlewareHandler = Arc<
76    dyn Fn(&AuthContext, &PluginRequest) -> Result<Option<PluginResponse>, RustAuthError>
77        + Send
78        + Sync,
79>;
80pub type PluginAsyncMiddlewareHandler = Arc<
81    dyn for<'a> Fn(&'a AuthContext, &'a PluginRequest) -> PluginMiddlewareFuture<'a> + Send + Sync,
82>;
83
84#[derive(Clone)]
85pub struct AuthPlugin {
86    pub id: String,
87    pub version: Option<String>,
88    pub options: Option<Value>,
89    pub endpoints: Vec<AsyncAuthEndpoint>,
90    pub middlewares: Vec<PluginMiddleware>,
91    pub async_middlewares: Vec<PluginAsyncMiddleware>,
92    pub on_request: Option<PluginOnRequest>,
93    pub on_response: Option<PluginOnResponse>,
94    pub on_response_async: Option<PluginOnResponseAsync>,
95    pub init: Option<PluginInitHandler>,
96    pub schema: Vec<PluginSchemaContribution>,
97    pub rate_limit: Vec<PluginRateLimitRule>,
98    pub hooks: PluginEndpointHooks,
99    pub error_codes: Vec<PluginErrorCode>,
100    pub database_hooks: Vec<PluginDatabaseHook>,
101    pub migrations: Vec<PluginMigration>,
102    #[cfg(feature = "oauth")]
103    pub social_providers: Vec<Arc<dyn SocialOAuthProvider>>,
104    pub password_validators: Vec<PluginPasswordValidator>,
105    pub state: Option<Arc<dyn Any + Send + Sync>>,
106}
107
108impl AuthPlugin {
109    pub fn new(id: impl Into<String>) -> Self {
110        Self {
111            id: id.into(),
112            version: None,
113            options: None,
114            endpoints: Vec::new(),
115            middlewares: Vec::new(),
116            async_middlewares: Vec::new(),
117            on_request: None,
118            on_response: None,
119            on_response_async: None,
120            init: None,
121            schema: Vec::new(),
122            rate_limit: Vec::new(),
123            hooks: PluginEndpointHooks::default(),
124            error_codes: Vec::new(),
125            database_hooks: Vec::new(),
126            migrations: Vec::new(),
127            #[cfg(feature = "oauth")]
128            social_providers: Vec::new(),
129            password_validators: Vec::new(),
130            state: None,
131        }
132    }
133
134    pub fn with_version(mut self, version: impl Into<String>) -> Self {
135        self.version = Some(version.into());
136        self
137    }
138
139    pub fn with_options(mut self, options: Value) -> Self {
140        self.options = Some(options);
141        self
142    }
143
144    pub fn with_endpoint(mut self, endpoint: AsyncAuthEndpoint) -> Self {
145        self.endpoints.push(endpoint);
146        self
147    }
148
149    pub fn with_init<F>(mut self, init: F) -> Self
150    where
151        F: Fn(&AuthContext) -> Result<PluginInitOutput, RustAuthError> + Send + Sync + 'static,
152    {
153        self.init = Some(Arc::new(init));
154        self
155    }
156
157    pub fn with_schema(mut self, contribution: PluginSchemaContribution) -> Self {
158        self.schema.push(contribution);
159        self
160    }
161
162    pub fn with_rate_limit(mut self, rule: PluginRateLimitRule) -> Self {
163        self.rate_limit.push(rule);
164        self
165    }
166
167    pub fn with_before_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
168    where
169        F: Fn(&AuthContext, PluginRequest) -> Result<PluginBeforeHookAction, RustAuthError>
170            + Send
171            + Sync
172            + 'static,
173    {
174        self.hooks.before.push(PluginBeforeHook {
175            matcher: PluginHookMatcher::path(path),
176            handler: Arc::new(hook),
177        });
178        self
179    }
180
181    pub fn with_after_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
182    where
183        F: Fn(
184                &AuthContext,
185                &PluginRequest,
186                PluginResponse,
187            ) -> Result<PluginAfterHookAction, RustAuthError>
188            + Send
189            + Sync
190            + 'static,
191    {
192        self.hooks.after.push(PluginAfterHook {
193            matcher: PluginHookMatcher::path(path),
194            handler: Arc::new(hook),
195        });
196        self
197    }
198
199    pub fn with_async_before_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
200    where
201        F: for<'a> Fn(&'a AuthContext, PluginRequest) -> PluginBeforeHookFuture<'a>
202            + Send
203            + Sync
204            + 'static,
205    {
206        self.hooks.async_before.push(PluginAsyncBeforeHook {
207            matcher: PluginHookMatcher::path(path),
208            handler: Arc::new(hook),
209        });
210        self
211    }
212
213    pub fn with_async_after_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
214    where
215        F: for<'a> Fn(
216                &'a AuthContext,
217                &'a PluginRequest,
218                PluginResponse,
219            ) -> PluginAfterHookFuture<'a>
220            + Send
221            + Sync
222            + 'static,
223    {
224        self.hooks.async_after.push(PluginAsyncAfterHook {
225            matcher: PluginHookMatcher::path(path),
226            handler: Arc::new(hook),
227        });
228        self
229    }
230
231    /// Registers an async after-hook without manual `Box::pin`.
232    pub fn with_async_after_handler<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
233    where
234        F: for<'a> Fn(AuthContext, &'a PluginRequest, PluginResponse) -> Fut
235            + Send
236            + Sync
237            + Clone
238            + 'static,
239        for<'a> Fut: Future<Output = Result<PluginAfterHookAction, RustAuthError>> + Send + 'a,
240    {
241        self.with_async_after_hook(path, hooks::async_after_hook_handler(handler))
242    }
243
244    /// Registers an async before-hook without manual `Box::pin`.
245    pub fn with_async_before_handler<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
246    where
247        F: Fn(AuthContext, PluginRequest) -> Fut + Send + Sync + Clone + 'static,
248        Fut: Future<Output = Result<PluginBeforeHookAction, RustAuthError>> + Send + 'static,
249    {
250        self.with_async_before_hook(path, hooks::async_before_hook_handler(handler))
251    }
252
253    pub fn with_error_code(mut self, error_code: PluginErrorCode) -> Self {
254        self.error_codes.push(error_code);
255        self
256    }
257
258    pub fn with_database_hook(mut self, hook: PluginDatabaseHook) -> Self {
259        self.database_hooks.push(hook);
260        self
261    }
262
263    pub fn with_migration(mut self, migration: PluginMigration) -> Self {
264        self.migrations.push(migration);
265        self
266    }
267
268    #[cfg(feature = "oauth")]
269    pub fn with_social_provider(
270        mut self,
271        provider: impl Into<Arc<dyn SocialOAuthProvider>>,
272    ) -> Self {
273        self.social_providers.push(provider.into());
274        self
275    }
276
277    pub fn with_password_validator<F>(mut self, validator: F) -> Self
278    where
279        F: for<'a> Fn(
280                &'a AuthContext,
281                PluginPasswordValidationInput,
282            ) -> PluginPasswordValidatorFuture<'a>
283            + Send
284            + Sync
285            + 'static,
286    {
287        self.password_validators.push(PluginPasswordValidator {
288            handler: Arc::new(validator),
289        });
290        self
291    }
292
293    pub fn with_state<T>(mut self, state: T) -> Self
294    where
295        T: Any + Send + Sync + 'static,
296    {
297        self.state = Some(Arc::new(state));
298        self
299    }
300
301    pub fn state<T>(&self) -> Option<Arc<T>>
302    where
303        T: Any + Send + Sync + 'static,
304    {
305        self.state
306            .as_ref()
307            .and_then(|state| Arc::clone(state).downcast::<T>().ok())
308    }
309
310    pub fn with_middleware<F>(mut self, path: impl Into<String>, middleware: F) -> Self
311    where
312        F: Fn(&AuthContext, &PluginRequest) -> Result<Option<PluginResponse>, RustAuthError>
313            + Send
314            + Sync
315            + 'static,
316    {
317        self.middlewares.push(PluginMiddleware {
318            path: path.into(),
319            handler: Arc::new(middleware),
320        });
321        self
322    }
323
324    pub fn with_async_middleware<F>(mut self, path: impl Into<String>, middleware: F) -> Self
325    where
326        F: for<'a> Fn(&'a AuthContext, &'a PluginRequest) -> PluginMiddlewareFuture<'a>
327            + Send
328            + Sync
329            + 'static,
330    {
331        self.async_middlewares.push(PluginAsyncMiddleware {
332            path: path.into(),
333            handler: Arc::new(middleware),
334        });
335        self
336    }
337
338    pub fn with_on_request<F>(mut self, hook: F) -> Self
339    where
340        F: Fn(&AuthContext, PluginRequest) -> Result<PluginRequestAction, RustAuthError>
341            + Send
342            + Sync
343            + 'static,
344    {
345        self.on_request = Some(Arc::new(hook));
346        self
347    }
348
349    pub fn with_on_response<F>(mut self, hook: F) -> Self
350    where
351        F: Fn(
352                &AuthContext,
353                &PluginRequest,
354                PluginResponse,
355            ) -> Result<PluginResponse, RustAuthError>
356            + Send
357            + Sync
358            + 'static,
359    {
360        self.on_response = Some(Arc::new(hook));
361        self
362    }
363
364    /// Async hook run during async response finalization after session hydration
365    /// and before synchronous `on_response` hooks.
366    pub fn with_on_response_async<F>(mut self, hook: F) -> Self
367    where
368        F: for<'a> Fn(
369                &'a AuthContext,
370                &'a PluginRequest,
371                &'a PluginResponse,
372            ) -> PluginOnResponseAsyncFuture<'a>
373            + Send
374            + Sync
375            + 'static,
376    {
377        self.on_response_async = Some(Arc::new(hook));
378        self
379    }
380}
381
382impl fmt::Debug for AuthPlugin {
383    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
384        formatter
385            .debug_struct("AuthPlugin")
386            .field("id", &self.id)
387            .field("version", &self.version)
388            .field("options", &self.options)
389            .field("endpoints", &self.endpoints.len())
390            .field("middlewares", &self.middlewares)
391            .field("async_middlewares", &self.async_middlewares)
392            .field("on_request", &self.on_request.as_ref().map(|_| "<hook>"))
393            .field("on_response", &self.on_response.as_ref().map(|_| "<hook>"))
394            .field(
395                "on_response_async",
396                &self.on_response_async.as_ref().map(|_| "<hook>"),
397            )
398            .field("init", &self.init.as_ref().map(|_| "<init>"))
399            .field("schema", &self.schema)
400            .field("rate_limit", &self.rate_limit)
401            .field("hooks", &self.hooks)
402            .field("error_codes", &self.error_codes)
403            .field("database_hooks", &self.database_hooks)
404            .field("migrations", &self.migrations)
405            .field("social_providers", &debug_social_providers(self))
406            .field("password_validators", &self.password_validators)
407            .field("state", &self.state.as_ref().map(|_| "<state>"))
408            .finish()
409    }
410}
411
412#[cfg(feature = "oauth")]
413fn debug_social_providers(plugin: &AuthPlugin) -> Vec<&str> {
414    plugin
415        .social_providers
416        .iter()
417        .map(|provider| provider.id())
418        .collect()
419}
420
421#[cfg(not(feature = "oauth"))]
422fn debug_social_providers(_plugin: &AuthPlugin) -> Vec<&'static str> {
423    Vec::new()
424}
425
426#[derive(Clone)]
427pub struct PluginMiddleware {
428    pub path: String,
429    pub handler: PluginMiddlewareHandler,
430}
431
432impl fmt::Debug for PluginMiddleware {
433    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
434        formatter
435            .debug_struct("PluginMiddleware")
436            .field("path", &self.path)
437            .field("handler", &"<middleware>")
438            .finish()
439    }
440}
441
442#[derive(Clone)]
443pub struct PluginAsyncMiddleware {
444    pub path: String,
445    pub handler: PluginAsyncMiddlewareHandler,
446}
447
448impl fmt::Debug for PluginAsyncMiddleware {
449    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
450        formatter
451            .debug_struct("PluginAsyncMiddleware")
452            .field("path", &self.path)
453            .field("handler", &"<async middleware>")
454            .finish()
455    }
456}
457
458pub enum PluginRequestAction {
459    Continue(PluginRequest),
460    Respond(PluginResponse),
461}