1use std::any::Any;
4use std::future::Future;
5use std::pin::Pin;
6
7mod db;
8mod endpoint;
9mod error;
10mod hooks;
11mod init;
12mod password;
13mod rate_limit;
14mod schema;
15
16pub use db::{
17 PluginDatabaseAfterHookHandler, PluginDatabaseAfterInput, PluginDatabaseBeforeAction,
18 PluginDatabaseBeforeHookHandler, PluginDatabaseBeforeInput, PluginDatabaseHook,
19 PluginDatabaseHookContext, PluginDatabaseOperation, PluginMigration, PluginMigrationBody,
20 PluginMigrationStep,
21};
22pub use endpoint::PluginEndpoint;
23pub use error::PluginErrorCode;
24pub use hooks::{
25 async_after_hook_handler, async_before_hook_handler, PluginAfterHook, PluginAfterHookAction,
26 PluginAfterHookFuture, PluginAfterHookHandler, PluginAsyncAfterHook,
27 PluginAsyncAfterHookHandler, PluginAsyncBeforeHook, PluginAsyncBeforeHookHandler,
28 PluginBeforeHook, PluginBeforeHookAction, PluginBeforeHookFuture, PluginBeforeHookHandler,
29 PluginEndpointHooks, PluginHookMatcher,
30};
31pub use init::{PluginInitHandler, PluginInitOutput};
32pub use password::{
33 PluginPasswordValidationInput, PluginPasswordValidationRejection, PluginPasswordValidator,
34 PluginPasswordValidatorFuture, PluginPasswordValidatorHandler,
35};
36pub use rate_limit::PluginRateLimitRule;
37pub use schema::PluginSchemaContribution;
38
39use crate::api::{ApiRequest, ApiResponse, AsyncAuthEndpoint, Body};
40use crate::context::AuthContext;
41use crate::error::RustAuthError;
42#[cfg(feature = "oauth")]
43use rustauth_oauth::oauth2::SocialOAuthProvider;
44use serde_json::Value;
45use std::fmt;
46use std::sync::Arc;
47
/// Body type used by plugins; an alias of [`Body`].
pub type PluginBody = Body;
/// Request type handed to plugin hooks and middleware; an alias of [`ApiRequest`].
pub type PluginRequest = ApiRequest;
/// Response type produced or inspected by plugins; an alias of [`ApiResponse`].
pub type PluginResponse = ApiResponse;
/// Boxed `Send` future returned by async middleware
/// ([`PluginAsyncMiddlewareHandler`]).
///
/// It resolves to `Ok(Some(response))`, `Ok(None)`, or an error.
/// NOTE(review): presumably `Some` short-circuits the request with that
/// response and `None` lets processing continue. Confirm against the dispatcher.
pub type PluginMiddlewareFuture<'a> =
    Pin<Box<dyn Future<Output = Result<Option<PluginResponse>, RustAuthError>> + Send + 'a>>;
/// Synchronous request hook. It takes the request by value and returns a
/// [`PluginRequestAction`]: either `Continue` with a (possibly rewritten)
/// request, or `Respond` with a response.
pub type PluginOnRequest = Arc<
    dyn Fn(&AuthContext, PluginRequest) -> Result<PluginRequestAction, RustAuthError> + Send + Sync,
>;
/// Synchronous response hook. It takes ownership of the response and returns
/// the response to use in its place, which may be the same value modified.
pub type PluginOnResponse = Arc<
    dyn Fn(&AuthContext, &PluginRequest, PluginResponse) -> Result<PluginResponse, RustAuthError>
        + Send
        + Sync,
>;
/// Boxed `Send` future returned by [`PluginOnResponseAsync`]. It carries no
/// value, only success or an error.
pub type PluginOnResponseAsyncFuture<'a> =
    Pin<Box<dyn Future<Output = Result<(), RustAuthError>> + Send + 'a>>;
/// Async response observer. It sees the response only through `&PluginResponse`,
/// and its future resolves to `()`, so it cannot replace the response.
pub type PluginOnResponseAsync = Arc<
    dyn for<'a> Fn(
            &'a AuthContext,
            &'a PluginRequest,
            &'a PluginResponse,
        ) -> PluginOnResponseAsyncFuture<'a>
        + Send
        + Sync,
>;
/// Synchronous middleware callback. It returns `Ok(Some(response))`,
/// `Ok(None)`, or an error. NOTE(review): presumably `Some` short-circuits
/// the request. Confirm against the dispatcher.
pub type PluginMiddlewareHandler = Arc<
    dyn Fn(&AuthContext, &PluginRequest) -> Result<Option<PluginResponse>, RustAuthError>
        + Send
        + Sync,
>;
/// Async middleware callback. It borrows the context and request for the
/// lifetime of the returned [`PluginMiddlewareFuture`].
pub type PluginAsyncMiddlewareHandler = Arc<
    dyn for<'a> Fn(&'a AuthContext, &'a PluginRequest) -> PluginMiddlewareFuture<'a> + Send + Sync,
>;
83
/// A plugin definition: an id plus everything the plugin contributes
/// (endpoints, hooks, middleware, schema, rate limits, migrations and so on).
///
/// Build one with [`AuthPlugin::new`] and the `with_*` builder methods. The
/// callback aliases in this module are `Arc`-based, so cloning a plugin
/// shares its closures rather than duplicating them.
#[derive(Clone)]
pub struct AuthPlugin {
    /// Identifier given to [`AuthPlugin::new`].
    pub id: String,
    /// Optional version string, set by `with_version`.
    pub version: Option<String>,
    /// Optional free-form JSON options, set by `with_options`.
    pub options: Option<Value>,
    /// Endpoints contributed by the plugin, appended by `with_endpoint`.
    pub endpoints: Vec<AsyncAuthEndpoint>,
    /// Synchronous path-scoped middleware, appended by `with_middleware`.
    pub middlewares: Vec<PluginMiddleware>,
    /// Async path-scoped middleware, appended by `with_async_middleware`.
    pub async_middlewares: Vec<PluginAsyncMiddleware>,
    /// Single synchronous request hook (the last `with_on_request` call wins).
    pub on_request: Option<PluginOnRequest>,
    /// Single synchronous response hook (the last `with_on_response` call wins).
    pub on_response: Option<PluginOnResponse>,
    /// Single async response observer (the last `with_on_response_async` call wins).
    pub on_response_async: Option<PluginOnResponseAsync>,
    /// Optional init handler, set by `with_init`.
    pub init: Option<PluginInitHandler>,
    /// Schema contributions, appended by `with_schema`.
    pub schema: Vec<PluginSchemaContribution>,
    /// Rate-limit rules, appended by `with_rate_limit`.
    pub rate_limit: Vec<PluginRateLimitRule>,
    /// Sync and async before/after endpoint hooks registered by the
    /// `with_*_hook` and `with_async_*_handler` builders.
    pub hooks: PluginEndpointHooks,
    /// Error codes declared by the plugin.
    pub error_codes: Vec<PluginErrorCode>,
    /// Database hooks, appended by `with_database_hook`.
    pub database_hooks: Vec<PluginDatabaseHook>,
    /// Migrations, appended by `with_migration`.
    pub migrations: Vec<PluginMigration>,
    /// Social OAuth providers; only present with the `oauth` feature.
    #[cfg(feature = "oauth")]
    pub social_providers: Vec<Arc<dyn SocialOAuthProvider>>,
    /// Password validators, appended by `with_password_validator`.
    pub password_validators: Vec<PluginPasswordValidator>,
    /// Type-erased plugin state. Set by `with_state` and read back with
    /// [`AuthPlugin::state`].
    pub state: Option<Arc<dyn Any + Send + Sync>>,
}
107
108impl AuthPlugin {
109 pub fn new(id: impl Into<String>) -> Self {
110 Self {
111 id: id.into(),
112 version: None,
113 options: None,
114 endpoints: Vec::new(),
115 middlewares: Vec::new(),
116 async_middlewares: Vec::new(),
117 on_request: None,
118 on_response: None,
119 on_response_async: None,
120 init: None,
121 schema: Vec::new(),
122 rate_limit: Vec::new(),
123 hooks: PluginEndpointHooks::default(),
124 error_codes: Vec::new(),
125 database_hooks: Vec::new(),
126 migrations: Vec::new(),
127 #[cfg(feature = "oauth")]
128 social_providers: Vec::new(),
129 password_validators: Vec::new(),
130 state: None,
131 }
132 }
133
134 pub fn with_version(mut self, version: impl Into<String>) -> Self {
135 self.version = Some(version.into());
136 self
137 }
138
139 pub fn with_options(mut self, options: Value) -> Self {
140 self.options = Some(options);
141 self
142 }
143
144 pub fn with_endpoint(mut self, endpoint: AsyncAuthEndpoint) -> Self {
145 self.endpoints.push(endpoint);
146 self
147 }
148
149 pub fn with_init<F>(mut self, init: F) -> Self
150 where
151 F: Fn(&AuthContext) -> Result<PluginInitOutput, RustAuthError> + Send + Sync + 'static,
152 {
153 self.init = Some(Arc::new(init));
154 self
155 }
156
157 pub fn with_schema(mut self, contribution: PluginSchemaContribution) -> Self {
158 self.schema.push(contribution);
159 self
160 }
161
162 pub fn with_rate_limit(mut self, rule: PluginRateLimitRule) -> Self {
163 self.rate_limit.push(rule);
164 self
165 }
166
167 pub fn with_before_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
168 where
169 F: Fn(&AuthContext, PluginRequest) -> Result<PluginBeforeHookAction, RustAuthError>
170 + Send
171 + Sync
172 + 'static,
173 {
174 self.hooks.before.push(PluginBeforeHook {
175 matcher: PluginHookMatcher::path(path),
176 handler: Arc::new(hook),
177 });
178 self
179 }
180
181 pub fn with_after_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
182 where
183 F: Fn(
184 &AuthContext,
185 &PluginRequest,
186 PluginResponse,
187 ) -> Result<PluginAfterHookAction, RustAuthError>
188 + Send
189 + Sync
190 + 'static,
191 {
192 self.hooks.after.push(PluginAfterHook {
193 matcher: PluginHookMatcher::path(path),
194 handler: Arc::new(hook),
195 });
196 self
197 }
198
199 pub fn with_async_before_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
200 where
201 F: for<'a> Fn(&'a AuthContext, PluginRequest) -> PluginBeforeHookFuture<'a>
202 + Send
203 + Sync
204 + 'static,
205 {
206 self.hooks.async_before.push(PluginAsyncBeforeHook {
207 matcher: PluginHookMatcher::path(path),
208 handler: Arc::new(hook),
209 });
210 self
211 }
212
213 pub fn with_async_after_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
214 where
215 F: for<'a> Fn(
216 &'a AuthContext,
217 &'a PluginRequest,
218 PluginResponse,
219 ) -> PluginAfterHookFuture<'a>
220 + Send
221 + Sync
222 + 'static,
223 {
224 self.hooks.async_after.push(PluginAsyncAfterHook {
225 matcher: PluginHookMatcher::path(path),
226 handler: Arc::new(hook),
227 });
228 self
229 }
230
231 pub fn with_async_after_handler<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
233 where
234 F: for<'a> Fn(AuthContext, &'a PluginRequest, PluginResponse) -> Fut
235 + Send
236 + Sync
237 + Clone
238 + 'static,
239 for<'a> Fut: Future<Output = Result<PluginAfterHookAction, RustAuthError>> + Send + 'a,
240 {
241 self.with_async_after_hook(path, hooks::async_after_hook_handler(handler))
242 }
243
244 pub fn with_async_before_handler<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
246 where
247 F: Fn(AuthContext, PluginRequest) -> Fut + Send + Sync + Clone + 'static,
248 Fut: Future<Output = Result<PluginBeforeHookAction, RustAuthError>> + Send + 'static,
249 {
250 self.with_async_before_hook(path, hooks::async_before_hook_handler(handler))
251 }
252
253 pub fn with_error_code(mut self, error_code: PluginErrorCode) -> Self {
254 self.error_codes.push(error_code);
255 self
256 }
257
258 pub fn with_database_hook(mut self, hook: PluginDatabaseHook) -> Self {
259 self.database_hooks.push(hook);
260 self
261 }
262
263 pub fn with_migration(mut self, migration: PluginMigration) -> Self {
264 self.migrations.push(migration);
265 self
266 }
267
268 #[cfg(feature = "oauth")]
269 pub fn with_social_provider(
270 mut self,
271 provider: impl Into<Arc<dyn SocialOAuthProvider>>,
272 ) -> Self {
273 self.social_providers.push(provider.into());
274 self
275 }
276
277 pub fn with_password_validator<F>(mut self, validator: F) -> Self
278 where
279 F: for<'a> Fn(
280 &'a AuthContext,
281 PluginPasswordValidationInput,
282 ) -> PluginPasswordValidatorFuture<'a>
283 + Send
284 + Sync
285 + 'static,
286 {
287 self.password_validators.push(PluginPasswordValidator {
288 handler: Arc::new(validator),
289 });
290 self
291 }
292
293 pub fn with_state<T>(mut self, state: T) -> Self
294 where
295 T: Any + Send + Sync + 'static,
296 {
297 self.state = Some(Arc::new(state));
298 self
299 }
300
301 pub fn state<T>(&self) -> Option<Arc<T>>
302 where
303 T: Any + Send + Sync + 'static,
304 {
305 self.state
306 .as_ref()
307 .and_then(|state| Arc::clone(state).downcast::<T>().ok())
308 }
309
310 pub fn with_middleware<F>(mut self, path: impl Into<String>, middleware: F) -> Self
311 where
312 F: Fn(&AuthContext, &PluginRequest) -> Result<Option<PluginResponse>, RustAuthError>
313 + Send
314 + Sync
315 + 'static,
316 {
317 self.middlewares.push(PluginMiddleware {
318 path: path.into(),
319 handler: Arc::new(middleware),
320 });
321 self
322 }
323
324 pub fn with_async_middleware<F>(mut self, path: impl Into<String>, middleware: F) -> Self
325 where
326 F: for<'a> Fn(&'a AuthContext, &'a PluginRequest) -> PluginMiddlewareFuture<'a>
327 + Send
328 + Sync
329 + 'static,
330 {
331 self.async_middlewares.push(PluginAsyncMiddleware {
332 path: path.into(),
333 handler: Arc::new(middleware),
334 });
335 self
336 }
337
338 pub fn with_on_request<F>(mut self, hook: F) -> Self
339 where
340 F: Fn(&AuthContext, PluginRequest) -> Result<PluginRequestAction, RustAuthError>
341 + Send
342 + Sync
343 + 'static,
344 {
345 self.on_request = Some(Arc::new(hook));
346 self
347 }
348
349 pub fn with_on_response<F>(mut self, hook: F) -> Self
350 where
351 F: Fn(
352 &AuthContext,
353 &PluginRequest,
354 PluginResponse,
355 ) -> Result<PluginResponse, RustAuthError>
356 + Send
357 + Sync
358 + 'static,
359 {
360 self.on_response = Some(Arc::new(hook));
361 self
362 }
363
364 pub fn with_on_response_async<F>(mut self, hook: F) -> Self
367 where
368 F: for<'a> Fn(
369 &'a AuthContext,
370 &'a PluginRequest,
371 &'a PluginResponse,
372 ) -> PluginOnResponseAsyncFuture<'a>
373 + Send
374 + Sync
375 + 'static,
376 {
377 self.on_response_async = Some(Arc::new(hook));
378 self
379 }
380}
381
382impl fmt::Debug for AuthPlugin {
383 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
384 formatter
385 .debug_struct("AuthPlugin")
386 .field("id", &self.id)
387 .field("version", &self.version)
388 .field("options", &self.options)
389 .field("endpoints", &self.endpoints.len())
390 .field("middlewares", &self.middlewares)
391 .field("async_middlewares", &self.async_middlewares)
392 .field("on_request", &self.on_request.as_ref().map(|_| "<hook>"))
393 .field("on_response", &self.on_response.as_ref().map(|_| "<hook>"))
394 .field(
395 "on_response_async",
396 &self.on_response_async.as_ref().map(|_| "<hook>"),
397 )
398 .field("init", &self.init.as_ref().map(|_| "<init>"))
399 .field("schema", &self.schema)
400 .field("rate_limit", &self.rate_limit)
401 .field("hooks", &self.hooks)
402 .field("error_codes", &self.error_codes)
403 .field("database_hooks", &self.database_hooks)
404 .field("migrations", &self.migrations)
405 .field("social_providers", &debug_social_providers(self))
406 .field("password_validators", &self.password_validators)
407 .field("state", &self.state.as_ref().map(|_| "<state>"))
408 .finish()
409 }
410}
411
412#[cfg(feature = "oauth")]
413fn debug_social_providers(plugin: &AuthPlugin) -> Vec<&str> {
414 plugin
415 .social_providers
416 .iter()
417 .map(|provider| provider.id())
418 .collect()
419}
420
421#[cfg(not(feature = "oauth"))]
422fn debug_social_providers(_plugin: &AuthPlugin) -> Vec<&'static str> {
423 Vec::new()
424}
425
426#[derive(Clone)]
427pub struct PluginMiddleware {
428 pub path: String,
429 pub handler: PluginMiddlewareHandler,
430}
431
432impl fmt::Debug for PluginMiddleware {
433 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
434 formatter
435 .debug_struct("PluginMiddleware")
436 .field("path", &self.path)
437 .field("handler", &"<middleware>")
438 .finish()
439 }
440}
441
442#[derive(Clone)]
443pub struct PluginAsyncMiddleware {
444 pub path: String,
445 pub handler: PluginAsyncMiddlewareHandler,
446}
447
448impl fmt::Debug for PluginAsyncMiddleware {
449 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
450 formatter
451 .debug_struct("PluginAsyncMiddleware")
452 .field("path", &self.path)
453 .field("handler", &"<async middleware>")
454 .finish()
455 }
456}
457
/// Outcome of a [`PluginOnRequest`] hook.
pub enum PluginRequestAction {
    /// Proceed with this request, which the hook may have modified since it
    /// received the request by value.
    Continue(PluginRequest),
    /// Answer with this response instead. NOTE(review): presumably the
    /// endpoint is then skipped. Confirm against the dispatcher.
    Respond(PluginResponse),
}